feat: one-to-many relation helper

Allow one to specify that a field of a model is a foreign key.
It will generate a bunch of helper methods to query related entities
from one entity.
This commit is contained in:
Matthieu Bessat 2025-11-11 17:10:47 +01:00
parent 32ef1f7b33
commit 5f45671b74
25 changed files with 764 additions and 140 deletions

View file

@ -0,0 +1,295 @@
use anyhow::Result;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::File;
use heck::ToSnakeCase;
use crate::{generators::repositories::relations::{gen_get_many_by_related_entities_method, gen_get_many_by_related_entity_method}, models::{Field, FieldForeignMode, Model}};
use crate::generators::{SourceNode, SourceNodeContainer};
fn gen_get_all_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let select_query = format!("SELECT * FROM {}", model.table_name);
quote! {
pub async fn get_all(&self) -> Result<Vec<#resource_ident>, sqlx::Error> {
sqlx::query_as::<_, #resource_ident>(#select_query)
.fetch_all(&self.db.0)
.await
}
}
}
fn gen_get_by_id_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let primary_key = &model.fields.iter()
.find(|f| f.is_primary)
.expect("A model must have at least one primary key")
.name;
let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, primary_key);
let func_name_ident = format_ident!("get_by_{}", primary_key);
quote! {
pub async fn #func_name_ident(&self, item_id: &str) -> Result<#resource_ident, sqlx::Error> {
sqlx::query_as::<_, #resource_ident>(#select_query)
.bind(item_id)
.fetch_one(&self.db.0)
.await
}
}
}
fn gen_get_many_by_id_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let primary_key = &model.fields.iter()
.find(|f| f.is_primary)
.expect("A model must have at least one primary key")
.name;
let select_query_tmpl = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, primary_key);
let func_name_ident = format_ident!("get_many_by_{}", primary_key);
quote! {
pub async fn #func_name_ident(&self, items_ids: &[&str]) -> Result<Vec<#resource_ident>, sqlx::Error> {
if items_ids.is_empty() {
return Ok(vec![])
}
let placeholder_params: String = (1..=(items_ids.len()))
.map(|i| format!("${}", i))
.collect::<Vec<String>>()
.join(",");
let query_sql = format!(#select_query_tmpl, placeholder_params);
let mut query = sqlx::query_as::<_, #resource_ident>(&query_sql);
for id in items_ids {
query = query.bind(id)
}
query
.fetch_all(&self.db.0)
.await
}
}
}
fn get_mutation_fields(model: &Model) -> (Vec<proc_macro2::Ident>, Vec<proc_macro2::Ident>) {
let normal_field_names: Vec<proc_macro2::Ident> = model.fields.iter()
.filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false })
.map(|f| format_ident!("{}", &f.name))
.collect();
let foreign_keys_field_names: Vec<proc_macro2::Ident> = model.fields.iter()
.filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true })
.map(|f| format_ident!("{}", &f.name))
.collect();
(normal_field_names, foreign_keys_field_names)
}
fn gen_insert_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let sql_columns = model.fields.iter()
.map(|f| f.name.clone())
.collect::<Vec<String>>()
.join(", ");
let value_templates = (1..(model.fields.len()+1))
.map(|i| format!("${}", i))
.collect::<Vec<String>>()
.join(", ");
let insert_query = format!(
"INSERT INTO {} ({}) VALUES ({})",
model.table_name,
sql_columns,
value_templates
);
let (normal_field_names, foreign_keys_field_names) = get_mutation_fields(model);
quote! {
pub async fn insert(&self, entity: &#resource_ident) -> Result<(), sqlx::Error> {
sqlx::query(#insert_query)
#( .bind( &entity.#normal_field_names ) )*
#( .bind( &entity.#foreign_keys_field_names.target_id) )*
.execute(&self.db.0)
.await?;
Ok(())
}
}
}
fn gen_insert_many_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let sql_columns = model.fields.iter()
.map(|f| f.name.clone())
.collect::<Vec<String>>()
.join(", ");
let base_insert_query = format!(
"INSERT INTO {} ({}) VALUES {{}} ON CONFLICT DO NOTHING",
model.table_name,
sql_columns
);
let (normal_field_names, foreign_keys_field_names) = get_mutation_fields(model);
let fields_count = model.fields.len();
quote! {
pub async fn insert_many(&self, entities: &Vec<#resource_ident>) -> Result<(), sqlx::Error> {
let values_templates: String = (1..(#fields_count*entities.len()+1))
.collect::<Vec<usize>>()
.chunks(#fields_count)
.map(|c| c.to_vec())
.map(|x| format!(
"({})",
x.iter()
.map(|i| format!("${}", i))
.collect::<Vec<String>>()
.join(", ")
))
.collect::<Vec<String>>()
.join(", ");
let query_sql = format!(#base_insert_query, values_templates);
let mut query = sqlx::query(&query_sql);
for entity in entities {
query = query
#( .bind( &entity.#normal_field_names ) )*
#( .bind( &entity.#foreign_keys_field_names.target_id) )*;
}
query
.execute(&self.db.0)
.await?;
Ok(())
}
}
}
fn gen_update_by_id_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let primary_key = &model.fields.iter()
.find(|f| f.is_primary)
.expect("A model must have at least one primary key")
.name;
let set_statements = model.fields.iter().enumerate()
.map(|(i, field)| format!("{} = ${}", field.name, i+2))
.collect::<Vec<String>>()
.join(", ");
let func_name_ident = format_ident!("update_by_{}", primary_key);
let update_query = format!(
"UPDATE {} SET {} WHERE {} = $1",
model.table_name,
set_statements,
primary_key
);
let (normal_field_names, foreign_keys_field_names) = get_mutation_fields(model);
quote! {
pub async fn #func_name_ident(&self, item_id: &str, entity: &#resource_ident) -> Result<(), sqlx::Error> {
sqlx::query(#update_query)
.bind(item_id)
#( .bind( &entity.#normal_field_names ) )*
#( .bind( &entity.#foreign_keys_field_names.target_id) )*
.execute(&self.db.0)
.await?;
Ok(())
}
}
}
fn gen_delete_by_id_method(model: &Model) -> TokenStream {
let primary_key = &model.fields.iter()
.find(|f| f.is_primary)
.expect("A model must have at least one primary key")
.name;
let func_name_ident = format_ident!("delete_by_{}", primary_key);
let query = format!(
"DELETE FROM {} WHERE {} = $1",
model.table_name,
primary_key
);
quote! {
pub async fn #func_name_ident(&self, item_id: &str) -> Result<(), sqlx::Error> {
sqlx::query(#query)
.bind(item_id)
.execute(&self.db.0)
.await?;
Ok(())
}
}
}
pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<SourceNodeContainer> {
let resource_name = model.name.clone();
let resource_module_ident = format_ident!("{}", &model.module_path.first().unwrap());
let resource_ident = format_ident!("{}", &resource_name);
let repository_ident = format_ident!("{}Repository", resource_ident);
let get_all_method_code = gen_get_all_method(model);
let get_by_id_method_code = gen_get_by_id_method(model);
let get_many_by_id_method_code = gen_get_many_by_id_method(model);
let insert_method_code = gen_insert_method(model);
let insert_many_method_code = gen_insert_many_method(model);
let update_by_id_method_code = gen_update_by_id_method(model);
let delete_by_id_method_code = gen_delete_by_id_method(model);
let fields_with_foreign_refs: Vec<&Field> = model.fields.iter().filter(|f|
match f.foreign_mode { FieldForeignMode::ForeignRef(_) => true, FieldForeignMode::NotRef => false }
).collect();
let related_entity_methods_codes: Vec<TokenStream> = fields_with_foreign_refs.iter().map(|field|
gen_get_many_by_related_entity_method(model, &field)
).collect();
let related_entities_methods_codes: Vec<TokenStream> = fields_with_foreign_refs.iter().map(|field|
gen_get_many_by_related_entities_method(all_models, model, &field)
).collect();
// TODO: add import line
let base_repository_code: TokenStream = quote! {
use crate::models::#resource_module_ident::#resource_ident;
use crate::db::Database;
pub struct #repository_ident {
db: Database
}
impl #repository_ident {
pub fn new(db: Database) -> Self {
#repository_ident {
db
}
}
#get_all_method_code
#get_by_id_method_code
#get_many_by_id_method_code
#insert_method_code
#insert_many_method_code
#update_by_id_method_code
#delete_by_id_method_code
#(#related_entity_methods_codes)*
#(#related_entities_methods_codes)*
}
};
// convert TokenStream into rust code as string
let parse_res: syn::Result<File> = syn::parse2(base_repository_code);
let pretty = prettyplease::unparse(&parse_res?);
Ok(SourceNodeContainer {
name: format!("{}_repository.rs", model.name.to_snake_case()),
inner: SourceNode::File(pretty)
})
}

View file

@ -0,0 +1,31 @@
pub mod base;
pub mod relations;
use anyhow::Result;
use crate::generators::{SourceNode, SourceNodeContainer};
use crate::models::Model;
/// Generate base repositories for all models
pub fn generate_repositories_source_files(models: &[Model]) -> Result<SourceNodeContainer> {
let mut nodes: Vec<SourceNodeContainer> = vec![];
for model in models.iter() {
nodes.push(base::generate_repository_file(models, model)?);
// nodes.push(relations::generate_repository_file(model)?);
}
let mut mod_index_code: String = String::new();
for node in &nodes {
let module_name = node.name.replace(".rs", "");
mod_index_code.push_str(&format!("pub mod {module_name};\n"));
}
nodes.push(SourceNodeContainer {
name: "mod.rs".into(),
inner: SourceNode::File(mod_index_code.to_string())
});
Ok(SourceNodeContainer {
name: "".into(),
inner: SourceNode::Directory(nodes)
})
}

View file

@ -0,0 +1,71 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use crate::models::{Field, FieldForeignMode, Model};
pub fn gen_get_many_by_related_entity_method(model: &Model, foreign_key_field: &Field) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let foreign_ref_params = match &foreign_key_field.foreign_mode {
FieldForeignMode::ForeignRef(params) => params,
FieldForeignMode::NotRef => {
panic!("Expected foreign key");
}
};
let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, foreign_key_field.name);
let func_name_ident = format_ident!("get_many_{}_by_{}", foreign_ref_params.reverse_relation_name, foreign_ref_params.target_resource_name);
quote! {
pub async fn #func_name_ident(&self, item_id: &str) -> Result<Vec<#resource_ident>, sqlx::Error> {
sqlx::query_as::<_, #resource_ident>(#select_query)
.bind(item_id)
.fetch_all(&self.db.0)
.await
}
}
}
pub fn gen_get_many_by_related_entities_method(all_models: &[Model], model: &Model, foreign_key_field: &Field) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let foreign_ref_params = match &foreign_key_field.foreign_mode {
FieldForeignMode::ForeignRef(params) => params,
FieldForeignMode::NotRef => {
panic!("Expected foreign key");
}
};
let select_query = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, foreign_key_field.name);
let target_resource = all_models.iter()
.find(|m| m.name.to_lowercase() == foreign_ref_params.target_resource_name.to_lowercase())
.expect("Could not find foreign ref target type associated resource");
let func_name_ident = format_ident!("get_many_{}_by_{}", foreign_ref_params.reverse_relation_name, target_resource.table_name);
quote! {
pub async fn #func_name_ident(&self, items_ids: Vec<String>) -> Result<Vec<#resource_ident>, sqlx::Error> {
if items_ids.is_empty() {
return Ok(vec![])
}
let placeholder_params: String = (1..=(items_ids.len()))
.map(|i| format!("${i}"))
.collect::<Vec<String>>()
.join(",");
let query_tmpl = format!(
#select_query,
placeholder_params
);
let mut query = sqlx::query_as::<_, #resource_ident>(&query_tmpl);
for id in items_ids {
query = query.bind(id)
}
query
.fetch_all(&self.db.0)
.await
}
}
}