Matthieu Bessat 2025-01-27 09:08:04 +01:00
parent c15e69a6c4
commit 09791951d9
13 changed files with 250 additions and 128 deletions

View file

@@ -0,0 +1,22 @@
[package]
name = "generator_cli"
edition = "2021"
[[bin]]
name = "sqlx-generator"
path = "src/main.rs"
[dependencies]
anyhow = "1.0.95"
argh = "0.1.13"
attribute-derive = "0.10.3"
convert_case = "0.6.0"
fully_pub = "0.1.4"
heck = "0.5.0"
prettyplease = "0.2.25"
proc-macro2 = "1.0.92"
quote = "1.0.38"
serde = "1.0.216"
serde_json = "1.0.134"
structmeta = "0.3.0"
syn = { version = "2.0.92", features = ["extra-traits", "full", "parsing"] }

View file

@@ -0,0 +1,61 @@
use anyhow::{Result, anyhow};
use crate::models::{Field, Model};
// Implementations
impl Field {
/// Return the SQLite column type corresponding to this field's Rust type
fn sql_type(&self) -> Option<String> {
// for now, we just match against the rust type string representation
match self.rust_type.as_str() {
"u64" => Some("INTEGER".into()),
"u32" => Some("INTEGER".into()),
"i32" => Some("INTEGER".into()),
"i64" => Some("INTEGER".into()),
"f64" => Some("REAL".into()),
"f32" => Some("REAL".into()),
"String" => Some("TEXT".into()),
"DateTime" => Some("DATETIME".into()),
"Json" => Some("TEXT".into()),
"Vec<u8>" => Some("BLOB".into()),
_ => Some("TEXT".into())
}
}
}
/// Generate CREATE TABLE statements for all parsed models
pub fn generate_create_table_sql(models: &Vec<Model>) -> Result<String> {
let mut sql_code: String = "".into();
for model in models.iter() {
let mut fields_sql: Vec<String> = vec![];
for field in model.fields.iter() {
let mut additions: String = "".into();
let sql_type = field.sql_type()
.ok_or_else(|| anyhow!("Could not find SQL type for field {}", field.name))?;
if !field.is_nullable {
additions.push_str(" NOT NULL");
}
if field.is_unique {
additions.push_str(" UNIQUE");
}
if field.is_primary {
additions.push_str(" PRIMARY KEY");
}
fields_sql.push(
format!("\t{: <#18}\t{}{}", field.name, sql_type, additions)
);
}
sql_code.push_str(
&format!(
"CREATE TABLE {} (\n{}\n);\n",
model.table_name,
fields_sql.join(",\n")
)
);
}
Ok(sql_code)
}
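
A minimal sketch of how this generator could be exercised, using the `Model` and `Field` structs introduced later in this commit; the model, its fields, and the placement of the test are invented for illustration, and the output comment only shows the rough shape of the emitted SQL:

#[cfg(test)]
mod tests {
    use crate::gen_migrations::generate_create_table_sql;
    use crate::models::{Field, Model};

    #[test]
    fn generates_a_create_table_statement() -> anyhow::Result<()> {
        // hand-built model, as the parser would produce it for a hypothetical `User` struct
        let models = vec![Model {
            module_path: vec!["user".into()],
            name: "User".into(),
            table_name: "users".into(),
            fields: vec![
                Field { name: "id".into(), rust_type: "String".into(), is_nullable: false, is_unique: false, is_primary: true },
                Field { name: "email".into(), rust_type: "String".into(), is_nullable: false, is_unique: true, is_primary: false },
                Field { name: "age".into(), rust_type: "u32".into(), is_nullable: true, is_unique: false, is_primary: false },
            ],
        }];
        let sql = generate_create_table_sql(&models)?;
        // Expected shape (column padding aside):
        // CREATE TABLE users (
        //     id     TEXT NOT NULL PRIMARY KEY,
        //     email  TEXT NOT NULL UNIQUE,
        //     age    INTEGER
        // );
        assert!(sql.contains("CREATE TABLE users"));
        Ok(())
    }
}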

View file

@@ -0,0 +1,166 @@
use anyhow::Result;
use fully_pub::fully_pub;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use serde::Serialize;
use syn::File;
use heck::ToSnakeCase;
use crate::models::Model;
fn gen_get_all_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let error_msg = format!("Failed to fetch resource {:?}", model.name.clone());
let select_query = format!("SELECT * FROM {}", model.table_name);
quote! {
pub async fn get_all(&self) -> Result<Vec<#resource_ident>> {
sqlx::query_as::<_, #resource_ident>(#select_query)
.fetch_all(&self.db.0)
.await
.context(#error_msg)
}
}
}
fn gen_get_by_id_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let error_msg = format!("Failed to fetch resource {:?}", model.name.clone());
let primary_key = &model.fields.iter()
.find(|f| f.is_primary)
.expect("A model must have at least one primary key")
.name;
let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, primary_key);
let func_name_ident = format_ident!("get_by_{}", primary_key);
quote! {
pub async fn #func_name_ident(&self, item_id: &str) -> Result<#resource_ident> {
sqlx::query_as::<_, #resource_ident>(#select_query)
.bind(item_id)
.fetch_one(&self.db.0)
.await
.context(#error_msg)
}
}
}
fn gen_insert_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let error_msg = format!("Failed to insert resource {:?}", model.name.clone());
let sql_columns = model.fields.iter()
.map(|f| f.name.clone())
.collect::<Vec<String>>()
.join(", ");
let value_templates = (1..(model.fields.len()+1))
.map(|i| format!("${}", i))
.collect::<Vec<String>>()
.join(", ");
let insert_query = format!(
"INSERT INTO {} ({}) VALUES ({})",
model.table_name,
sql_columns,
value_templates
);
let field_names: Vec<proc_macro2::Ident> = model.fields.iter()
.map(|f| format_ident!("{}", &f.name))
.collect();
quote! {
pub async fn insert(&self, entity: &#resource_ident) -> Result<()> {
sqlx::query(#insert_query)
#( .bind( &entity.#field_names ) )*
.execute(&self.db.0)
.await
.context(#error_msg)?;
Ok(())
}
}
}
fn generate_repository_file(model: &Model) -> Result<SourceNodeContainer> {
let resource_name = model.name.clone();
let resource_module_ident = format_ident!("{}", &model.module_path.get(0).unwrap());
let resource_ident = format_ident!("{}", &resource_name);
let repository_ident = format_ident!("{}Repository", resource_ident);
let get_all_method_code = gen_get_all_method(model);
let get_by_id_method_code = gen_get_by_id_method(model);
let insert_method_code = gen_insert_method(model);
// TODO: add import line
let base_repository_code: TokenStream = quote! {
use crate::models::#resource_module_ident::#resource_ident;
use crate::services::database::Database;
use anyhow::{Result, Context};
pub struct #repository_ident {
db: Database
}
impl #repository_ident {
pub fn new(db: Database) -> Self {
#repository_ident {
db
}
}
#get_all_method_code
#get_by_id_method_code
#insert_method_code
}
};
// convert TokenStream into rust code as string
let parse_res: syn::Result<File> = syn::parse2(base_repository_code);
let pretty = prettyplease::unparse(&parse_res?);
Ok(SourceNodeContainer {
name: format!("{}_repository.rs", model.name.to_snake_case()),
inner: SourceNode::File(pretty)
})
}
#[derive(Serialize, Debug)]
#[fully_pub]
enum SourceNode {
File(String),
Directory(Vec<SourceNodeContainer>)
}
#[derive(Serialize, Debug)]
#[fully_pub]
struct SourceNodeContainer {
name: String,
inner: SourceNode
}
/// Generate base repositories for all models
pub fn generate_repositories_source_files(models: &Vec<Model>) -> Result<SourceNodeContainer> {
let mut nodes: Vec<SourceNodeContainer> = vec![];
for model in models.iter() {
let snc = generate_repository_file(model)?;
nodes.push(snc)
}
let mut mod_index_code: String = String::new();
for node in &nodes {
let module_name = node.name.replace(".rs", "");
mod_index_code.push_str(&format!("pub mod {module_name};\n"));
}
nodes.push(SourceNodeContainer {
name: "mod.rs".into(),
inner: SourceNode::File(mod_index_code.to_string())
});
Ok(SourceNodeContainer {
name: "".into(),
inner: SourceNode::Directory(nodes)
})
}
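
For a hypothetical `User` model with fields `id`, `email` and `age` (names invented here), the code above would write a `user_repository.rs` along these lines (modulo `prettyplease` formatting), plus a `mod.rs` declaring `pub mod user_repository;`. This assumes the target crate's `Database` newtype wraps an sqlx pool in its first field and that `User` implements `sqlx::FromRow`:

use crate::models::user::User;
use crate::services::database::Database;
use anyhow::{Result, Context};

pub struct UserRepository {
    db: Database,
}

impl UserRepository {
    pub fn new(db: Database) -> Self {
        UserRepository { db }
    }

    pub async fn get_all(&self) -> Result<Vec<User>> {
        sqlx::query_as::<_, User>("SELECT * FROM users")
            .fetch_all(&self.db.0)
            .await
            .context("Failed to fetch resource \"User\"")
    }

    pub async fn get_by_id(&self, item_id: &str) -> Result<User> {
        sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
            .bind(item_id)
            .fetch_one(&self.db.0)
            .await
            .context("Failed to fetch resource \"User\"")
    }

    pub async fn insert(&self, entity: &User) -> Result<()> {
        sqlx::query("INSERT INTO users (id, email, age) VALUES ($1, $2, $3)")
            .bind(&entity.id)
            .bind(&entity.email)
            .bind(&entity.age)
            .execute(&self.db.0)
            .await
            .context("Failed to insert resource \"User\"")?;
        Ok(())
    }
}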

View file

@@ -0,0 +1,140 @@
use std::{ffi::OsStr, path::Path};
use attribute_derive::FromAttr;
use argh::FromArgs;
use anyhow::{Result, anyhow};
use gen_migrations::generate_create_table_sql;
use gen_repositories::{generate_repositories_source_files, SourceNodeContainer};
pub mod models;
pub mod parse_models;
pub mod gen_migrations;
pub mod gen_repositories;
#[derive(FromAttr, PartialEq, Debug, Default)]
#[attribute(ident = sql_generator_model)]
pub struct SqlGeneratorModelAttr {
table_name: Option<String>
}
#[derive(FromAttr, PartialEq, Debug, Default)]
#[attribute(ident = sql_generator_field)]
pub struct SqlGeneratorFieldAttr {
is_primary: Option<bool>,
is_unique: Option<bool>
}
#[derive(FromArgs, PartialEq, Debug)]
/// Generate SQL CREATE TABLE migrations
#[argh(subcommand, name = "gen-migrations")]
struct GenerateMigration {
/// path of the file to write the combined generated SQL migration to
#[argh(option, short = 'o')]
output: Option<String>
}
#[derive(FromArgs, PartialEq, Debug)]
/// Generate Rust SQLx repositories code
#[argh(subcommand, name = "gen-repositories")]
struct GenerateRepositories {
}
#[derive(FromArgs, PartialEq, Debug)]
#[argh(subcommand)]
enum GeneratorArgsSubCommands {
GenerateMigration(GenerateMigration),
GenerateRepositories(GenerateRepositories),
}
#[derive(FromArgs)]
/// SQLX Generator args
struct GeneratorArgs {
/// whether or not to debug
#[argh(switch, short = 'd')]
debug: bool,
/// path where to find Cargo.toml
#[argh(option)]
project_root: Option<String>,
#[argh(subcommand)]
nested: GeneratorArgsSubCommands
}
fn write_source_code(base_path: &Path, snc: SourceNodeContainer) -> Result<()> {
let path = base_path.join(snc.name);
match snc.inner {
gen_repositories::SourceNode::File(code) => {
println!("writing file {:?}", path);
std::fs::write(path, code)?;
},
gen_repositories::SourceNode::Directory(dir) => {
for node in dir {
write_source_code(&path, node)?;
}
}
}
Ok(())
}
pub fn main() -> Result<()> {
let args: GeneratorArgs = argh::from_env();
let project_root = &args.project_root.unwrap_or(".".to_string());
let project_root_path = Path::new(&project_root);
eprintln!("Using project root at: {:?}", &project_root_path.canonicalize()?);
if !project_root_path.exists() {
return Err(anyhow!("Could not resolve project root path."));
}
// check Cargo.toml
let main_manifest_location = "Cargo.toml";
let main_manifest_path = project_root_path.join(main_manifest_location);
if !main_manifest_path.exists() {
return Err(anyhow!("Could not find Cargo.toml in project root."));
}
// search for the models module
let models_mod_location = "src/models.rs";
let mut models_mod_path = project_root_path.join(models_mod_location);
if !models_mod_path.exists() {
let models_mod_location = "src/models/mod.rs";
models_mod_path = project_root_path.join(models_mod_location);
}
if !models_mod_path.exists() {
return Err(anyhow!("Could not resolve models modules."));
}
if models_mod_path.file_name().map(|x| x == OsStr::new("mod.rs")).unwrap_or(false) {
models_mod_path.pop();
}
eprintln!("Found models in project, parsing models");
let models = parse_models::parse_models_from_module(&models_mod_path)?;
if args.debug {
dbg!(&models);
}
match args.nested {
GeneratorArgsSubCommands::GenerateRepositories(_) => {
eprintln!("Generating repositories…");
// search for a repository module
let repositories_mod_location = "src/repositories";
let repositories_mod_path = project_root_path.join(repositories_mod_location);
if !repositories_mod_path.exists() {
return Err(anyhow!("Could not resolve repositories modules."));
}
let snc = generate_repositories_source_files(&models)?;
if args.debug {
dbg!(&snc);
}
write_source_code(&repositories_mod_path, snc)?;
},
GeneratorArgsSubCommands::GenerateMigration(opts) => {
eprintln!("Generating migrations…");
let sql_code = generate_create_table_sql(&models)?;
if let Some(out_location) = opts.output {
let output_path = Path::new(&out_location);
std::fs::write(output_path, sql_code)?;
eprintln!("Wrote migrations to {:?}", output_path);
} else {
println!("{}", sql_code);
}
}
}
Ok(())
}
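
Going by the argh definitions above, an invocation of the binary would presumably look like `sqlx-generator --project-root path/to/app gen-migrations -o migrations/init.sql` to write the combined CREATE TABLE script (or print it to stdout when `-o` is omitted), and `sqlx-generator --project-root path/to/app gen-repositories` to regenerate the files under `src/repositories`; the project path and migration file name here are only illustrative.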

View file

@@ -0,0 +1,21 @@
// BASE MODELS
use fully_pub::fully_pub;
#[derive(Debug)]
#[fully_pub]
struct Model {
module_path: Vec<String>,
name: String,
table_name: String,
fields: Vec<Field>
}
#[derive(Debug)]
#[fully_pub]
struct Field {
name: String,
rust_type: String,
is_nullable: bool,
is_unique: bool,
is_primary: bool
}

View file

@@ -0,0 +1,236 @@
use std::{fs, path::Path};
use attribute_derive::FromAttr;
use anyhow::{Result, anyhow};
use convert_case::{Case, Casing};
use syn::Type;
use crate::{models::{Field, Model}, SqlGeneratorFieldAttr, SqlGeneratorModelAttr};
fn extract_generic_type(base_segments: Vec<String>, ty: &syn::Type) -> Option<&syn::Type> {
// Only a `Type::Path` can match one of the wrapper types we look for (e.g. `Option<T>`); anything else yields `None`
if let syn::Type::Path(syn::TypePath { qself: None, path }) = ty {
// Join the path segments into a single string (e.g. `std:option:Option`) so it can be
// compared against the accepted spellings of the wrapper listed in `base_segments`
let segments_str = &path
.segments
.iter()
.map(|segment| segment.ident.to_string())
.collect::<Vec<_>>()
.join(":");
// If the joined path matches one of the accepted spellings, keep the last `PathSegment` (the wrapper itself)
let option_segment = base_segments
.iter()
.find(|s| segments_str == *s)
.and_then(|_| path.segments.last());
let inner_type = option_segment
// Take the angle-bracketed generic arguments of that segment;
// a wrapper written without generics cannot be `Option<T>`, so it yields `None`
.and_then(|path_seg| match &path_seg.arguments {
syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
args,
..
}) => args.first(),
_ => None,
})
// Keep the generic argument only if it is a type (not a lifetime or const argument)
.and_then(|generic_arg| match generic_arg {
syn::GenericArgument::Type(ty) => Some(ty),
_ => None,
});
// Return `T` in `Option<T>`
return inner_type;
}
None
}
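
As a quick illustration of the helper above, a hypothetical test placed next to this function could check that the `String` inside `Option<String>` is recovered:

#[cfg(test)]
mod extract_generic_type_tests {
    use super::extract_generic_type;

    #[test]
    fn extracts_the_inner_type_of_an_option() {
        // parse a type the same way the model parser sees it
        let ty: syn::Type = syn::parse_str("Option<String>").unwrap();
        let inner = extract_generic_type(
            vec!["Option".into(), "std:option:Option".into(), "core:option:Option".into()],
            &ty,
        );
        // the inner type should be the `String` path
        assert!(matches!(inner, Some(syn::Type::Path(_))));
    }
}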
fn get_type_first_ident(inp: &Type) -> Option<String> {
match inp {
syn::Type::Path(field_type_path) => {
// use `first()` so an empty type path does not panic
field_type_path.path.segments.first().map(|segment| segment.ident.to_string())
},
_ => None
}
}
fn parse_model_attribute(item: &syn::ItemStruct) -> Result<Option<SqlGeneratorModelAttr>> {
for attr in item.attrs.iter() {
if !attr.path().is_ident("sql_generator_model") {
continue;
}
match SqlGeneratorModelAttr::from_attribute(attr) {
Ok(v) => {
return Ok(Some(v));
},
Err(err) => {
return Err(anyhow!("Failed to parse sql_generator_model attribute macro: {}", err));
}
};
}
Ok(None)
}
fn parse_field_attribute(field: &syn::Field) -> Result<Option<SqlGeneratorFieldAttr>> {
for attr in field.attrs.iter() {
if !attr.path().is_ident("sql_generator_field") {
continue;
}
match SqlGeneratorFieldAttr::from_attribute(attr) {
Ok(v) => {
return Ok(Some(v));
},
Err(err) => {
return Err(anyhow!("Failed to parse sql_generator_field attribute macro: {}", err));
}
};
}
Ok(None)
}
/// Derive a table name from the struct name: convert it to snake_case and pluralize by appending an "s"
fn generate_table_name_from_struct_name(struct_name: &str) -> String {
format!("{}s", struct_name.to_case(Case::Snake))
}
/// Scan a Rust source file for model structs and return their parsed representations
pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
let models_code = fs::read_to_string(source_code_path)?;
let parsed_file = syn::parse_file(&models_code)?;
let mut models: Vec<Model> = vec![];
for item in parsed_file.items {
match item {
syn::Item::Struct(itemval) => {
let model_name = itemval.ident.to_string();
let model_attrs = match parse_model_attribute(&itemval)? {
Some(v) => v,
None => {
// we require model struct to have the `sql_generator_model` attribute
continue;
}
};
let mut fields: Vec<Field> = vec![];
for field in itemval.fields.iter() {
let field_name = field.ident.clone().unwrap().to_string();
let field_type = field.ty.clone();
// println!("field {}", field_name);
let mut output_field = Field {
name: field_name,
rust_type: "Unknown".into(),
is_nullable: false,
is_primary: false,
is_unique: false
};
let first_type: String = match get_type_first_ident(&field_type) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract ident from Option inner type"));
}
};
let mut final_type = first_type.clone();
if first_type == "Option" {
output_field.is_nullable = true;
let inner_type = match extract_generic_type(
vec!["Option".into(), "std:option:Option".into(), "core:option:Option".into()],
&field_type
) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract type from Option"));
}
};
final_type = match get_type_first_ident(inner_type) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract ident from Option inner type"));
}
}
}
if first_type == "Vec" {
let inner_type = match extract_generic_type(
vec!["Vec".into()],
&field_type
) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract type from Vec"));
}
};
final_type = match get_type_first_ident(inner_type) {
Some(v) => format!("Vec<{}>", v),
None => {
return Err(anyhow!("Could not extract ident from Vec inner type"));
}
}
}
output_field.rust_type = final_type;
// parse attribute
if let Some(field_attr) = parse_field_attribute(field)? {
output_field.is_primary = field_attr.is_primary.unwrap_or_default();
output_field.is_unique = field_attr.is_unique.unwrap_or_default();
}
fields.push(output_field);
}
models.push(Model {
module_path: vec![source_code_path.file_stem().unwrap().to_str().unwrap().to_string()],
name: model_name.clone(),
table_name: model_attrs.table_name
.unwrap_or(generate_table_name_from_struct_name(&model_name)),
fields
})
},
_ => {}
}
}
Ok(models)
}
/// Recursively scan a models module (a single file or a directory of files) and collect all parsed models
pub fn parse_models_from_module(module_path: &Path) -> Result<Vec<Model>> {
let mut models: Vec<Model> = vec![];
if module_path.is_file() {
models.extend(parse_models(&module_path)?);
return Ok(models);
}
let entries = fs::read_dir(module_path)
.map_err(|err| anyhow!("Could not scan models directory. {:?}", err))?;
for dir_entry_res in entries {
let file_path = dir_entry_res?.path();
models.extend(parse_models_from_module(&file_path)?)
}
Ok(models)
}
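
To make the expected input concrete, here is a sketch of a model definition this parser is written to pick up, assuming the consuming crate exposes the `sql_generator_model` and `sql_generator_field` attributes in some way (for instance via a companion macro crate that is not part of this commit); the file path, struct, and field names are illustrative:

// Hypothetical src/models/user.rs in the target project.
use chrono::{DateTime, Utc};

#[derive(Debug, sqlx::FromRow)]                 // FromRow is what the generated query_as calls expect
#[sql_generator_model(table_name = "users")]    // without table_name, "User" would default to "users" anyway
pub struct User {
    #[sql_generator_field(is_primary = true)]
    pub id: String,                             // -> id TEXT NOT NULL PRIMARY KEY
    #[sql_generator_field(is_unique = true)]
    pub email: String,                          // -> email TEXT NOT NULL UNIQUE
    pub age: Option<u32>,                       // Option<_> marks the column nullable -> age INTEGER
    pub created_at: DateTime<Utc>,              // -> created_at DATETIME NOT NULL
}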