This commit is contained in:
Matthieu Bessat 2026-01-05 14:07:27 +01:00
parent 5f45671b74
commit 7268180d4d
5 changed files with 64 additions and 29 deletions

View file

@ -12,7 +12,7 @@ pub fn derive_sql_generator_model(_input: TokenStream) -> TokenStream {
TokenStream::new() TokenStream::new()
} }
#[proc_macro_derive(SqlGeneratorModelWithId, attributes(sql_generator_field))] #[proc_macro_derive(SqlGeneratorModelWithId)]
pub fn derive_sql_generator_model_with_id(input: TokenStream) -> TokenStream { pub fn derive_sql_generator_model_with_id(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);
let name = input.ident; let name = input.ident;

View file

@ -21,49 +21,42 @@ fn gen_get_all_method(model: &Model) -> TokenStream {
} }
} }
fn gen_get_by_id_method(model: &Model) -> TokenStream { fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let primary_key = &model.fields.iter() let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, query_field.name);
.find(|f| f.is_primary)
.expect("A model must have at least one primary key")
.name;
let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, primary_key);
let func_name_ident = format_ident!("get_by_{}", primary_key); let func_name_ident = format_ident!("get_by_{}", query_field.name);
quote! { quote! {
pub async fn #func_name_ident(&self, item_id: &str) -> Result<#resource_ident, sqlx::Error> { // FIXME: Value is not necesssarly a string, it can be an int or a bool
pub async fn #func_name_ident(&self, value: &str) -> Result<#resource_ident, sqlx::Error> {
sqlx::query_as::<_, #resource_ident>(#select_query) sqlx::query_as::<_, #resource_ident>(#select_query)
.bind(item_id) .bind(value)
.fetch_one(&self.db.0) .fetch_one(&self.db.0)
.await .await
} }
} }
} }
fn gen_get_many_by_id_method(model: &Model) -> TokenStream { fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let primary_key = &model.fields.iter() let select_query_tmpl = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, query_field.name);
.find(|f| f.is_primary)
.expect("A model must have at least one primary key")
.name;
let select_query_tmpl = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, primary_key);
let func_name_ident = format_ident!("get_many_by_{}", primary_key); let func_name_ident = format_ident!("get_many_by_{}", query_field.name);
quote! { quote! {
pub async fn #func_name_ident(&self, items_ids: &[&str]) -> Result<Vec<#resource_ident>, sqlx::Error> { pub async fn #func_name_ident(&self, values: &[&str]) -> Result<Vec<#resource_ident>, sqlx::Error> {
if items_ids.is_empty() { if values.is_empty() {
return Ok(vec![]) return Ok(vec![])
} }
let placeholder_params: String = (1..=(items_ids.len())) let placeholder_params: String = (1..=(values.len()))
.map(|i| format!("${}", i)) .map(|i| format!("${}", i))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(","); .join(",");
let query_sql = format!(#select_query_tmpl, placeholder_params); let query_sql = format!(#select_query_tmpl, placeholder_params);
let mut query = sqlx::query_as::<_, #resource_ident>(&query_sql); let mut query = sqlx::query_as::<_, #resource_ident>(&query_sql);
for id in items_ids { for value in values {
query = query.bind(id) query = query.bind(value)
} }
query query
.fetch_all(&self.db.0) .fetch_all(&self.db.0)
@ -231,14 +224,45 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<S
let repository_ident = format_ident!("{}Repository", resource_ident); let repository_ident = format_ident!("{}Repository", resource_ident);
let get_all_method_code = gen_get_all_method(model); let get_all_method_code = gen_get_all_method(model);
let get_by_id_method_code = gen_get_by_id_method(model); let get_by_id_method_code = gen_get_by_field_method(
let get_many_by_id_method_code = gen_get_many_by_id_method(model); model,
model.fields.iter()
.find(|f| f.is_primary == true)
.expect("Expected at least one primary key on the model.")
);
let get_many_by_id_method_code = gen_get_many_by_field_method(
model,
model.fields.iter()
.find(|f| f.is_primary == true)
.expect("Expected at least one primary key on the model.")
);
let insert_method_code = gen_insert_method(model); let insert_method_code = gen_insert_method(model);
let insert_many_method_code = gen_insert_many_method(model); let insert_many_method_code = gen_insert_many_method(model);
let update_by_id_method_code = gen_update_by_id_method(model); let update_by_id_method_code = gen_update_by_id_method(model);
let delete_by_id_method_code = gen_delete_by_id_method(model); let delete_by_id_method_code = gen_delete_by_id_method(model);
let query_by_field_methods: Vec<TokenStream> =
model.fields.iter()
.filter(|f| f.is_query_entrypoint)
.map(|field|
gen_get_by_field_method(
model,
&field
)
)
.collect();
let query_many_by_field_methods: Vec<TokenStream> =
model.fields.iter()
.filter(|f| f.is_query_entrypoint)
.map(|field|
gen_get_many_by_field_method(
model,
&field
)
)
.collect();
let fields_with_foreign_refs: Vec<&Field> = model.fields.iter().filter(|f| let fields_with_foreign_refs: Vec<&Field> = model.fields.iter().filter(|f|
match f.foreign_mode { FieldForeignMode::ForeignRef(_) => true, FieldForeignMode::NotRef => false } match f.foreign_mode { FieldForeignMode::ForeignRef(_) => true, FieldForeignMode::NotRef => false }
).collect(); ).collect();
@ -279,6 +303,10 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<S
#delete_by_id_method_code #delete_by_id_method_code
#(#query_by_field_methods)*
#(#query_many_by_field_methods)*
#(#related_entity_methods_codes)* #(#related_entity_methods_codes)*
#(#related_entities_methods_codes)* #(#related_entities_methods_codes)*

View file

@ -24,7 +24,11 @@ pub struct SqlGeneratorModelAttr {
pub struct SqlGeneratorFieldAttr { pub struct SqlGeneratorFieldAttr {
is_primary: Option<bool>, is_primary: Option<bool>,
is_unique: Option<bool>, is_unique: Option<bool>,
reverse_relation_name: Option<String> reverse_relation_name: Option<String>,
/// to indicate that this field will be used to obtains entities
/// our framework will generate methods for all fields that is an entrypoint
is_query_entrypoint: Option<bool>
} }

View file

@ -57,7 +57,7 @@ struct Field {
is_nullable: bool, is_nullable: bool,
is_unique: bool, is_unique: bool,
is_primary: bool, is_primary: bool,
foreign_mode: FieldForeignMode, is_query_entrypoint: bool,
// is_foreign_ref: bool, foreign_mode: FieldForeignMode
// reverse_relation_name: Option<String>
} }

View file

@ -127,7 +127,7 @@ fn parse_field_attribute(field: &syn::Field) -> Result<Option<SqlGeneratorFieldA
return Ok(Some(v)); return Ok(Some(v));
}, },
Err(err) => { Err(err) => {
return Err(anyhow!("Failed to parse sql_generator_field attribute macro: {}", err)); return Err(anyhow!("Failed to parse sql_generator_field attribute macro on field {:?}, {}", field, err));
} }
}; };
} }
@ -173,6 +173,7 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
is_nullable: false, is_nullable: false,
is_primary: false, is_primary: false,
is_unique: false, is_unique: false,
is_query_entrypoint: false,
foreign_mode: FieldForeignMode::NotRef foreign_mode: FieldForeignMode::NotRef
}; };
@ -256,6 +257,7 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
if let Some(field_attr) = field_attrs_opt { if let Some(field_attr) = field_attrs_opt {
output_field.is_primary = field_attr.is_primary.unwrap_or_default(); output_field.is_primary = field_attr.is_primary.unwrap_or_default();
output_field.is_unique = field_attr.is_unique.unwrap_or_default(); output_field.is_unique = field_attr.is_unique.unwrap_or_default();
output_field.is_query_entrypoint = field_attr.is_query_entrypoint.unwrap_or_default();
} }
fields.push(output_field); fields.push(output_field);
@ -279,6 +281,7 @@ fn parse_models_from_module_inner(module_path: &Path) -> Result<Vec<Model>> {
let mut models: Vec<Model> = vec![]; let mut models: Vec<Model> = vec![];
if module_path.is_file() { if module_path.is_file() {
println!("Parsing models from path {:?}.", module_path);
models.extend(parse_models(module_path)?); models.extend(parse_models(module_path)?);
return Ok(models); return Ok(models);
} }