From 534ed8341962b0c93251aeb2d111e61c6b71ae67 Mon Sep 17 00:00:00 2001 From: Matthieu Bessat Date: Tue, 13 Jan 2026 20:45:37 +0100 Subject: [PATCH] feat(repositories): add methods get_many_of_related_entity and delete_many_by_id --- TODO.md | 2 +- lib/sqlxgentools_attrs/src/lib.rs | 2 +- .../src/generators/repositories/base.rs | 196 +++++++++++++----- .../src/generators/repositories/relations.rs | 48 +---- lib/sqlxgentools_cli/src/main.rs | 6 +- lib/sqlxgentools_cli/src/models.rs | 6 +- lib/sqlxgentools_cli/src/parse_models.rs | 7 +- 7 files changed, 162 insertions(+), 105 deletions(-) diff --git a/TODO.md b/TODO.md index c00fc17..5339cc5 100644 --- a/TODO.md +++ b/TODO.md @@ -14,7 +14,7 @@ - insert - update - delete_by_id - - custom queries +- [ ] delete_many - [ ] Config file for project - configure models path diff --git a/lib/sqlxgentools_attrs/src/lib.rs b/lib/sqlxgentools_attrs/src/lib.rs index 5dd1506..f88bf77 100644 --- a/lib/sqlxgentools_attrs/src/lib.rs +++ b/lib/sqlxgentools_attrs/src/lib.rs @@ -12,7 +12,7 @@ pub fn derive_sql_generator_model(_input: TokenStream) -> TokenStream { TokenStream::new() } -#[proc_macro_derive(SqlGeneratorModelWithId, attributes(sql_generator_field))] +#[proc_macro_derive(SqlGeneratorModelWithId)] pub fn derive_sql_generator_model_with_id(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let name = input.ident; diff --git a/lib/sqlxgentools_cli/src/generators/repositories/base.rs b/lib/sqlxgentools_cli/src/generators/repositories/base.rs index 51f61cb..d2f3609 100644 --- a/lib/sqlxgentools_cli/src/generators/repositories/base.rs +++ b/lib/sqlxgentools_cli/src/generators/repositories/base.rs @@ -1,10 +1,10 @@ use anyhow::Result; -use proc_macro2::TokenStream; +use proc_macro2::{TokenStream, Ident}; use quote::{format_ident, quote}; use syn::File; use heck::ToSnakeCase; -use crate::{generators::repositories::relations::{gen_get_many_by_related_entities_method, gen_get_many_by_related_entity_method}, models::{Field, FieldForeignMode, Model}}; +use crate::{generators::repositories::relations::gen_get_many_of_related_entity_method, models::{Field, FieldForeignMode, Model}}; use crate::generators::{SourceNode, SourceNodeContainer}; @@ -21,49 +21,42 @@ fn gen_get_all_method(model: &Model) -> TokenStream { } } -fn gen_get_by_id_method(model: &Model) -> TokenStream { +fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream { let resource_ident = format_ident!("{}", &model.name); - let primary_key = &model.fields.iter() - .find(|f| f.is_primary) - .expect("A model must have at least one primary key") - .name; - let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, primary_key); + let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, query_field.name); - let func_name_ident = format_ident!("get_by_{}", primary_key); + let func_name_ident = format_ident!("get_by_{}", query_field.name); quote! { - pub async fn #func_name_ident(&self, item_id: &str) -> Result<#resource_ident, sqlx::Error> { + // FIXME: Value is not necesssarly a string, it can be an int or a bool + pub async fn #func_name_ident(&self, value: &str) -> Result<#resource_ident, sqlx::Error> { sqlx::query_as::<_, #resource_ident>(#select_query) - .bind(item_id) + .bind(value) .fetch_one(&self.db.0) .await } } } -fn gen_get_many_by_id_method(model: &Model) -> TokenStream { +fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStream { let resource_ident = format_ident!("{}", &model.name); - let primary_key = &model.fields.iter() - .find(|f| f.is_primary) - .expect("A model must have at least one primary key") - .name; - let select_query_tmpl = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, primary_key); + let select_query_tmpl = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, query_field.name); - let func_name_ident = format_ident!("get_many_by_{}", primary_key); + let func_name_ident = format_ident!("get_many_by_{}", query_field.name); quote! { - pub async fn #func_name_ident(&self, items_ids: &[&str]) -> Result, sqlx::Error> { - if items_ids.is_empty() { + pub async fn #func_name_ident(&self, values: &[&str]) -> Result, sqlx::Error> { + if values.is_empty() { return Ok(vec![]) } - let placeholder_params: String = (1..=(items_ids.len())) + let placeholder_params: String = (1..=(values.len())) .map(|i| format!("${}", i)) .collect::>() .join(","); let query_sql = format!(#select_query_tmpl, placeholder_params); let mut query = sqlx::query_as::<_, #resource_ident>(&query_sql); - for id in items_ids { - query = query.bind(id) + for value in values { + query = query.bind(value) } query .fetch_all(&self.db.0) @@ -72,41 +65,57 @@ fn gen_get_many_by_id_method(model: &Model) -> TokenStream { } } -fn get_mutation_fields(model: &Model) -> (Vec, Vec) { - let normal_field_names: Vec = model.fields.iter() +fn get_mutation_fields(model: &Model) -> (Vec<&Field>, Vec<&Field>) { + let normal_field_names: Vec<&Field> = model.fields.iter() .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false }) - .map(|f| format_ident!("{}", &f.name)) .collect(); - let foreign_keys_field_names: Vec = model.fields.iter() + let foreign_keys_field_names: Vec<&Field> = model.fields.iter() + .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true }) + .collect(); + (normal_field_names, foreign_keys_field_names) +} + +fn get_mutation_fields_ident(model: &Model) -> (Vec<&Field>, Vec<&Field>) { + let normal_field_names: Vec<&Field> = model.fields.iter() + .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false }) + .collect(); + let foreign_keys_field_names: Vec<&Field> = model.fields.iter() .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true }) - .map(|f| format_ident!("{}", &f.name)) .collect(); (normal_field_names, foreign_keys_field_names) } fn gen_insert_method(model: &Model) -> TokenStream { let resource_ident = format_ident!("{}", &model.name); - let sql_columns = model.fields.iter() - .map(|f| f.name.clone()) - .collect::>() - .join(", "); + let value_templates = (1..(model.fields.len()+1)) .map(|i| format!("${}", i)) .collect::>() .join(", "); + let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); + let (normal_field_idents, foreign_keys_field_idents) = ( + normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>(), + foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>() + ); + + let sql_columns = [normal_fields, foreign_keys_fields].concat() + .iter() + .map(|f| f.name.clone()) + .collect::>() + .join(", "); let insert_query = format!( "INSERT INTO {} ({}) VALUES ({})", model.table_name, sql_columns, value_templates ); - let (normal_field_names, foreign_keys_field_names) = get_mutation_fields(model); + // foreign keys must be inserted first, we sort the columns so that foreign keys are first quote! { pub async fn insert(&self, entity: &#resource_ident) -> Result<(), sqlx::Error> { sqlx::query(#insert_query) - #( .bind( &entity.#normal_field_names ) )* - #( .bind( &entity.#foreign_keys_field_names.target_id) )* + #( .bind( &entity.#normal_field_idents ) )* + #( .bind( &entity.#foreign_keys_field_idents.target_id) )* .execute(&self.db.0) .await?; @@ -126,7 +135,11 @@ fn gen_insert_many_method(model: &Model) -> TokenStream { model.table_name, sql_columns ); - let (normal_field_names, foreign_keys_field_names) = get_mutation_fields(model); + let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); + let (normal_field_idents, foreign_keys_field_idents) = ( + normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>(), + foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>() + ); let fields_count = model.fields.len(); quote! { @@ -149,8 +162,8 @@ fn gen_insert_many_method(model: &Model) -> TokenStream { let mut query = sqlx::query(&query_sql); for entity in entities { query = query - #( .bind( &entity.#normal_field_names ) )* - #( .bind( &entity.#foreign_keys_field_names.target_id) )*; + #( .bind( &entity.#normal_field_idents ) )* + #( .bind( &entity.#foreign_keys_field_idents.target_id) )*; } query .execute(&self.db.0) @@ -168,26 +181,34 @@ fn gen_update_by_id_method(model: &Model) -> TokenStream { .find(|f| f.is_primary) .expect("A model must have at least one primary key") .name; - let set_statements = model.fields.iter().enumerate() - .map(|(i, field)| format!("{} = ${}", field.name, i+2)) + let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); + let (normal_field_idents, foreign_keys_field_idents) = ( + normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>(), + foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::>() + ); + let sql_columns = [normal_fields, foreign_keys_fields].concat() + .iter() + .map(|f| f.name.clone()) + .collect::>(); + let set_statements = sql_columns.iter() + .enumerate() + .map(|(i, column_name)| format!("{} = ${}", column_name, i+2)) .collect::>() .join(", "); - - let func_name_ident = format_ident!("update_by_{}", primary_key); let update_query = format!( "UPDATE {} SET {} WHERE {} = $1", model.table_name, set_statements, primary_key ); - let (normal_field_names, foreign_keys_field_names) = get_mutation_fields(model); + let func_name_ident = format_ident!("update_by_{}", primary_key); quote! { pub async fn #func_name_ident(&self, item_id: &str, entity: &#resource_ident) -> Result<(), sqlx::Error> { sqlx::query(#update_query) .bind(item_id) - #( .bind( &entity.#normal_field_names ) )* - #( .bind( &entity.#foreign_keys_field_names.target_id) )* + #( .bind( &entity.#normal_field_idents ) )* + #( .bind( &entity.#foreign_keys_field_idents.target_id) )* .execute(&self.db.0) .await?; @@ -221,6 +242,42 @@ fn gen_delete_by_id_method(model: &Model) -> TokenStream { } } +fn gen_delete_many_by_id_method(model: &Model) -> TokenStream { + let primary_key = &model.fields.iter() + .find(|f| f.is_primary) + .expect("A model must have at least one primary key") + .name; + + let func_name_ident = format_ident!("delete_many_by_{}", primary_key); + let delete_query_tmpl = format!( + "DELETE FROM {} WHERE {} IN ({{}})", + model.table_name, + primary_key + ); + + quote! { + pub async fn #func_name_ident(&self, ids: &[&str]) -> Result<(), sqlx::Error> { + if ids.is_empty() { + return Ok(()) + } + let placeholder_params: String = (1..=(ids.len())) + .map(|i| format!("${}", i)) + .collect::>() + .join(","); + let query_sql = format!(#delete_query_tmpl, placeholder_params); + let mut query = sqlx::query(&query_sql); + for item_id in ids { + query = query.bind(item_id) + } + query + .execute(&self.db.0) + .await?; + + Ok(()) + } + } +} + pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result { let resource_name = model.name.clone(); @@ -231,22 +288,51 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result = + model.fields.iter() + .filter(|f| f.is_query_entrypoint) + .map(|field| + gen_get_by_field_method( + model, + &field + ) + ) + .collect(); + let query_many_by_field_methods: Vec = + model.fields.iter() + .filter(|f| f.is_query_entrypoint) + .map(|field| + gen_get_many_by_field_method( + model, + &field + ) + ) + .collect(); + let fields_with_foreign_refs: Vec<&Field> = model.fields.iter().filter(|f| match f.foreign_mode { FieldForeignMode::ForeignRef(_) => true, FieldForeignMode::NotRef => false } ).collect(); let related_entity_methods_codes: Vec = fields_with_foreign_refs.iter().map(|field| - gen_get_many_by_related_entity_method(model, &field) - ).collect(); - let related_entities_methods_codes: Vec = fields_with_foreign_refs.iter().map(|field| - gen_get_many_by_related_entities_method(all_models, model, &field) + gen_get_many_of_related_entity_method(model, &field) ).collect(); // TODO: add import line @@ -278,10 +364,14 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result TokenStream { +/// method that can be used to retreive a list of entities of type X that are the children of a parent type Y +/// ex: get all comments of a post +pub fn gen_get_many_of_related_entity_method(model: &Model, foreign_key_field: &Field) -> TokenStream { let resource_ident = format_ident!("{}", &model.name); let foreign_ref_params = match &foreign_key_field.foreign_mode { @@ -15,7 +17,7 @@ pub fn gen_get_many_by_related_entity_method(model: &Model, foreign_key_field: & let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, foreign_key_field.name); - let func_name_ident = format_ident!("get_many_{}_by_{}", foreign_ref_params.reverse_relation_name, foreign_ref_params.target_resource_name); + let func_name_ident = format_ident!("get_many_of_{}", foreign_ref_params.target_resource_name); quote! { pub async fn #func_name_ident(&self, item_id: &str) -> Result, sqlx::Error> { @@ -27,45 +29,3 @@ pub fn gen_get_many_by_related_entity_method(model: &Model, foreign_key_field: & } } -pub fn gen_get_many_by_related_entities_method(all_models: &[Model], model: &Model, foreign_key_field: &Field) -> TokenStream { - let resource_ident = format_ident!("{}", &model.name); - - let foreign_ref_params = match &foreign_key_field.foreign_mode { - FieldForeignMode::ForeignRef(params) => params, - FieldForeignMode::NotRef => { - panic!("Expected foreign key"); - } - }; - - let select_query = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, foreign_key_field.name); - - let target_resource = all_models.iter() - .find(|m| m.name.to_lowercase() == foreign_ref_params.target_resource_name.to_lowercase()) - .expect("Could not find foreign ref target type associated resource"); - - let func_name_ident = format_ident!("get_many_{}_by_{}", foreign_ref_params.reverse_relation_name, target_resource.table_name); - - quote! { - pub async fn #func_name_ident(&self, items_ids: Vec) -> Result, sqlx::Error> { - if items_ids.is_empty() { - return Ok(vec![]) - } - let placeholder_params: String = (1..=(items_ids.len())) - .map(|i| format!("${i}")) - .collect::>() - .join(","); - let query_tmpl = format!( - #select_query, - placeholder_params - ); - let mut query = sqlx::query_as::<_, #resource_ident>(&query_tmpl); - for id in items_ids { - query = query.bind(id) - } - - query - .fetch_all(&self.db.0) - .await - } - } -} diff --git a/lib/sqlxgentools_cli/src/main.rs b/lib/sqlxgentools_cli/src/main.rs index 17cfc13..d6c1423 100644 --- a/lib/sqlxgentools_cli/src/main.rs +++ b/lib/sqlxgentools_cli/src/main.rs @@ -24,7 +24,11 @@ pub struct SqlGeneratorModelAttr { pub struct SqlGeneratorFieldAttr { is_primary: Option, is_unique: Option, - reverse_relation_name: Option + reverse_relation_name: Option, + + /// to indicate that this field will be used to obtains entities + /// our framework will generate methods for all fields that is an entrypoint + is_query_entrypoint: Option } diff --git a/lib/sqlxgentools_cli/src/models.rs b/lib/sqlxgentools_cli/src/models.rs index e16af70..8be8dac 100644 --- a/lib/sqlxgentools_cli/src/models.rs +++ b/lib/sqlxgentools_cli/src/models.rs @@ -57,7 +57,7 @@ struct Field { is_nullable: bool, is_unique: bool, is_primary: bool, - foreign_mode: FieldForeignMode, - // is_foreign_ref: bool, - // reverse_relation_name: Option + is_query_entrypoint: bool, + foreign_mode: FieldForeignMode } + diff --git a/lib/sqlxgentools_cli/src/parse_models.rs b/lib/sqlxgentools_cli/src/parse_models.rs index 0f8b5d8..bea7a43 100644 --- a/lib/sqlxgentools_cli/src/parse_models.rs +++ b/lib/sqlxgentools_cli/src/parse_models.rs @@ -127,7 +127,7 @@ fn parse_field_attribute(field: &syn::Field) -> Result { - return Err(anyhow!("Failed to parse sql_generator_field attribute macro: {}", err)); + return Err(anyhow!("Failed to parse sql_generator_field attribute macro on field {:?}, {}", field, err)); } }; } @@ -173,6 +173,7 @@ pub fn parse_models(source_code_path: &Path) -> Result> { is_nullable: false, is_primary: false, is_unique: false, + is_query_entrypoint: false, foreign_mode: FieldForeignMode::NotRef }; @@ -247,7 +248,7 @@ pub fn parse_models(source_code_path: &Path) -> Result> { output_field.foreign_mode = FieldForeignMode::ForeignRef( ForeignRefParams { reverse_relation_name: rrn, - target_resource_name: target_type_name.to_lowercase() + target_resource_name: target_type_name.to_case(Case::Snake) } ); } @@ -256,6 +257,7 @@ pub fn parse_models(source_code_path: &Path) -> Result> { if let Some(field_attr) = field_attrs_opt { output_field.is_primary = field_attr.is_primary.unwrap_or_default(); output_field.is_unique = field_attr.is_unique.unwrap_or_default(); + output_field.is_query_entrypoint = field_attr.is_query_entrypoint.unwrap_or_default(); } fields.push(output_field); @@ -279,6 +281,7 @@ fn parse_models_from_module_inner(module_path: &Path) -> Result> { let mut models: Vec = vec![]; if module_path.is_file() { + println!("Parsing models from path {:?}.", module_path); models.extend(parse_models(module_path)?); return Ok(models); }