From e3ce64222633c6a7815946a0da51a12ce113ce36 Mon Sep 17 00:00:00 2001 From: Matthieu Bessat Date: Mon, 5 Jan 2026 14:07:27 +0100 Subject: [PATCH] WIP --- lib/sqlxgentools_attrs/src/lib.rs | 2 +- .../src/generators/repositories/base.rs | 148 ++++++++++++------ lib/sqlxgentools_cli/src/main.rs | 6 +- lib/sqlxgentools_cli/src/models.rs | 6 +- lib/sqlxgentools_cli/src/parse_models.rs | 5 +- 5 files changed, 115 insertions(+), 52 deletions(-) diff --git a/lib/sqlxgentools_attrs/src/lib.rs b/lib/sqlxgentools_attrs/src/lib.rs index 5dd1506..f88bf77 100644 --- a/lib/sqlxgentools_attrs/src/lib.rs +++ b/lib/sqlxgentools_attrs/src/lib.rs @@ -12,7 +12,7 @@ pub fn derive_sql_generator_model(_input: TokenStream) -> TokenStream { TokenStream::new() } -#[proc_macro_derive(SqlGeneratorModelWithId, attributes(sql_generator_field))] +#[proc_macro_derive(SqlGeneratorModelWithId)] pub fn derive_sql_generator_model_with_id(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let name = input.ident; diff --git a/lib/sqlxgentools_cli/src/generators/repositories/base.rs b/lib/sqlxgentools_cli/src/generators/repositories/base.rs index 51f61cb..340229d 100644 --- a/lib/sqlxgentools_cli/src/generators/repositories/base.rs +++ b/lib/sqlxgentools_cli/src/generators/repositories/base.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use proc_macro2::TokenStream; +use proc_macro2::{TokenStream, Ident}; use quote::{format_ident, quote}; use syn::File; use heck::ToSnakeCase; @@ -21,49 +21,42 @@ fn gen_get_all_method(model: &Model) -> TokenStream { } } -fn gen_get_by_id_method(model: &Model) -> TokenStream { +fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream { let resource_ident = format_ident!("{}", &model.name); - let primary_key = &model.fields.iter() - .find(|f| f.is_primary) - .expect("A model must have at least one primary key") - .name; - let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, primary_key); + let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, query_field.name); - let func_name_ident = format_ident!("get_by_{}", primary_key); + let func_name_ident = format_ident!("get_by_{}", query_field.name); quote! { - pub async fn #func_name_ident(&self, item_id: &str) -> Result<#resource_ident, sqlx::Error> { + // FIXME: Value is not necesssarly a string, it can be an int or a bool + pub async fn #func_name_ident(&self, value: &str) -> Result<#resource_ident, sqlx::Error> { sqlx::query_as::<_, #resource_ident>(#select_query) - .bind(item_id) + .bind(value) .fetch_one(&self.db.0) .await } } } -fn gen_get_many_by_id_method(model: &Model) -> TokenStream { +fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStream { let resource_ident = format_ident!("{}", &model.name); - let primary_key = &model.fields.iter() - .find(|f| f.is_primary) - .expect("A model must have at least one primary key") - .name; - let select_query_tmpl = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, primary_key); + let select_query_tmpl = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, query_field.name); - let func_name_ident = format_ident!("get_many_by_{}", primary_key); + let func_name_ident = format_ident!("get_many_by_{}", query_field.name); quote! { - pub async fn #func_name_ident(&self, items_ids: &[&str]) -> Result, sqlx::Error> { - if items_ids.is_empty() { + pub async fn #func_name_ident(&self, values: &[&str]) -> Result, sqlx::Error> { + if values.is_empty() { return Ok(vec![]) } - let placeholder_params: String = (1..=(items_ids.len())) + let placeholder_params: String = (1..=(values.len())) .map(|i| format!("${}", i)) .collect::>() .join(","); let query_sql = format!(#select_query_tmpl, placeholder_params); let mut query = sqlx::query_as::<_, #resource_ident>(&query_sql); - for id in items_ids { - query = query.bind(id) + for value in values { + query = query.bind(value) } query .fetch_all(&self.db.0) @@ -72,41 +65,57 @@ fn gen_get_many_by_id_method(model: &Model) -> TokenStream { } } -fn get_mutation_fields(model: &Model) -> (Vec, Vec) { - let normal_field_names: Vec = model.fields.iter() +fn get_mutation_fields(model: &Model) -> (Vec<&Field>, Vec<&Field>) { + let normal_field_names: Vec<&Field> = model.fields.iter() .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false }) - .map(|f| format_ident!("{}", &f.name)) .collect(); - let foreign_keys_field_names: Vec = model.fields.iter() + let foreign_keys_field_names: Vec<&Field> = model.fields.iter() + .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true }) + .collect(); + (normal_field_names, foreign_keys_field_names) +} + +fn get_mutation_fields_ident(model: &Model) -> (Vec<&Field>, Vec<&Field>) { + let normal_field_names: Vec<&Field> = model.fields.iter() + .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false }) + .collect(); + let foreign_keys_field_names: Vec<&Field> = model.fields.iter() .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true }) - .map(|f| format_ident!("{}", &f.name)) .collect(); (normal_field_names, foreign_keys_field_names) } fn gen_insert_method(model: &Model) -> TokenStream { let resource_ident = format_ident!("{}", &model.name); - let sql_columns = model.fields.iter() - .map(|f| f.name.clone()) - .collect::>() - .join(", "); + let value_templates = (1..(model.fields.len()+1)) .map(|i| format!("${}", i)) .collect::>() .join(", "); + let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); + let (normal_field_idents, foreign_keys_field_idents) = ( + normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>(), + foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>() + ); + + let sql_columns = [normal_fields, foreign_keys_fields].concat() + .iter() + .map(|f| f.name.clone()) + .collect::>() + .join(", "); let insert_query = format!( "INSERT INTO {} ({}) VALUES ({})", model.table_name, sql_columns, value_templates ); - let (normal_field_names, foreign_keys_field_names) = get_mutation_fields(model); + // foreign keys must be inserted first, we sort the columns so that foreign keys are first quote! { pub async fn insert(&self, entity: &#resource_ident) -> Result<(), sqlx::Error> { sqlx::query(#insert_query) - #( .bind( &entity.#normal_field_names ) )* - #( .bind( &entity.#foreign_keys_field_names.target_id) )* + #( .bind( &entity.#normal_field_idents ) )* + #( .bind( &entity.#foreign_keys_field_idents.target_id) )* .execute(&self.db.0) .await?; @@ -126,7 +135,11 @@ fn gen_insert_many_method(model: &Model) -> TokenStream { model.table_name, sql_columns ); - let (normal_field_names, foreign_keys_field_names) = get_mutation_fields(model); + let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); + let (normal_field_idents, foreign_keys_field_idents) = ( + normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>(), + foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>() + ); let fields_count = model.fields.len(); quote! { @@ -149,8 +162,8 @@ fn gen_insert_many_method(model: &Model) -> TokenStream { let mut query = sqlx::query(&query_sql); for entity in entities { query = query - #( .bind( &entity.#normal_field_names ) )* - #( .bind( &entity.#foreign_keys_field_names.target_id) )*; + #( .bind( &entity.#normal_field_idents ) )* + #( .bind( &entity.#foreign_keys_field_idents.target_id) )*; } query .execute(&self.db.0) @@ -168,26 +181,34 @@ fn gen_update_by_id_method(model: &Model) -> TokenStream { .find(|f| f.is_primary) .expect("A model must have at least one primary key") .name; - let set_statements = model.fields.iter().enumerate() - .map(|(i, field)| format!("{} = ${}", field.name, i+2)) + let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); + let (normal_field_idents, foreign_keys_field_idents) = ( + normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>(), + foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>() + ); + let sql_columns = [normal_fields, foreign_keys_fields].concat() + .iter() + .map(|f| f.name.clone()) + .collect::>(); + let set_statements = sql_columns.iter() + .enumerate() + .map(|(i, column_name)| format!("{} = ${}", column_name, i+2)) .collect::>() .join(", "); - - let func_name_ident = format_ident!("update_by_{}", primary_key); let update_query = format!( "UPDATE {} SET {} WHERE {} = $1", model.table_name, set_statements, primary_key ); - let (normal_field_names, foreign_keys_field_names) = get_mutation_fields(model); + let func_name_ident = format_ident!("update_by_{}", primary_key); quote! { pub async fn #func_name_ident(&self, item_id: &str, entity: &#resource_ident) -> Result<(), sqlx::Error> { sqlx::query(#update_query) .bind(item_id) - #( .bind( &entity.#normal_field_names ) )* - #( .bind( &entity.#foreign_keys_field_names.target_id) )* + #( .bind( &entity.#normal_field_idents ) )* + #( .bind( &entity.#foreign_keys_field_idents.target_id) )* .execute(&self.db.0) .await?; @@ -231,14 +252,45 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result = + model.fields.iter() + .filter(|f| f.is_query_entrypoint) + .map(|field| + gen_get_by_field_method( + model, + &field + ) + ) + .collect(); + let query_many_by_field_methods: Vec = + model.fields.iter() + .filter(|f| f.is_query_entrypoint) + .map(|field| + gen_get_many_by_field_method( + model, + &field + ) + ) + .collect(); + let fields_with_foreign_refs: Vec<&Field> = model.fields.iter().filter(|f| match f.foreign_mode { FieldForeignMode::ForeignRef(_) => true, FieldForeignMode::NotRef => false } ).collect(); @@ -278,6 +330,10 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result, is_unique: Option, - reverse_relation_name: Option + reverse_relation_name: Option, + + /// to indicate that this field will be used to obtains entities + /// our framework will generate methods for all fields that is an entrypoint + is_query_entrypoint: Option } diff --git a/lib/sqlxgentools_cli/src/models.rs b/lib/sqlxgentools_cli/src/models.rs index e16af70..8be8dac 100644 --- a/lib/sqlxgentools_cli/src/models.rs +++ b/lib/sqlxgentools_cli/src/models.rs @@ -57,7 +57,7 @@ struct Field { is_nullable: bool, is_unique: bool, is_primary: bool, - foreign_mode: FieldForeignMode, - // is_foreign_ref: bool, - // reverse_relation_name: Option + is_query_entrypoint: bool, + foreign_mode: FieldForeignMode } + diff --git a/lib/sqlxgentools_cli/src/parse_models.rs b/lib/sqlxgentools_cli/src/parse_models.rs index 0f8b5d8..396bbb7 100644 --- a/lib/sqlxgentools_cli/src/parse_models.rs +++ b/lib/sqlxgentools_cli/src/parse_models.rs @@ -127,7 +127,7 @@ fn parse_field_attribute(field: &syn::Field) -> Result { - return Err(anyhow!("Failed to parse sql_generator_field attribute macro: {}", err)); + return Err(anyhow!("Failed to parse sql_generator_field attribute macro on field {:?}, {}", field, err)); } }; } @@ -173,6 +173,7 @@ pub fn parse_models(source_code_path: &Path) -> Result> { is_nullable: false, is_primary: false, is_unique: false, + is_query_entrypoint: false, foreign_mode: FieldForeignMode::NotRef }; @@ -256,6 +257,7 @@ pub fn parse_models(source_code_path: &Path) -> Result> { if let Some(field_attr) = field_attrs_opt { output_field.is_primary = field_attr.is_primary.unwrap_or_default(); output_field.is_unique = field_attr.is_unique.unwrap_or_default(); + output_field.is_query_entrypoint = field_attr.is_query_entrypoint.unwrap_or_default(); } fields.push(output_field); @@ -279,6 +281,7 @@ fn parse_models_from_module_inner(module_path: &Path) -> Result> { let mut models: Vec = vec![]; if module_path.is_file() { + println!("Parsing models from path {:?}.", module_path); models.extend(parse_models(module_path)?); return Ok(models); }