style: apply cargo fmt

Matthieu Bessat 2026-01-13 20:47:19 +01:00
parent 534ed83419
commit d205d722aa
17 changed files with 390 additions and 315 deletions
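
For reference, a formatting-only commit like this one can normally be reproduced and verified from the workspace root (a minimal sketch, assuming the crates rely on the default rustfmt settings and no custom rustfmt.toml):

    cargo fmt --all             # rewrite every workspace crate in place
    cargo fmt --all -- --check  # exit non-zero if any file is still unformatted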

View file

@@ -1,11 +1,12 @@
 use anyhow::Context;
-use std::str::FromStr;
-use std::path::PathBuf;
 use anyhow::Result;
+use std::path::PathBuf;
+use std::str::FromStr;
 use fully_pub::fully_pub;
 use sqlx::{
-    Pool, Sqlite, sqlite::{SqliteConnectOptions, SqlitePoolOptions},
+    sqlite::{SqliteConnectOptions, SqlitePoolOptions},
+    Pool, Sqlite,
 };
 /// database storage interface
@ -13,7 +14,6 @@ use sqlx::{
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct Database(Pool<Sqlite>); struct Database(Pool<Sqlite>);
/// Initialize database /// Initialize database
pub async fn provide_database(sqlite_db_path: &str) -> Result<Database> { pub async fn provide_database(sqlite_db_path: &str) -> Result<Database> {
let path = PathBuf::from(sqlite_db_path); let path = PathBuf::from(sqlite_db_path);
@ -37,5 +37,3 @@ pub async fn provide_database(sqlite_db_path: &str) -> Result<Database> {
Ok(Database(pool)) Ok(Database(pool))
} }

View file

@@ -1,3 +1,3 @@
-pub mod repositories;
 pub mod db;
 pub mod models;
+pub mod repositories;

View file

@@ -4,12 +4,15 @@ use chrono::Utc;
 use sqlx::types::Json;
 use sqlxgentools_misc::ForeignRef;
-use crate::{db::provide_database, models::user::{User, UserToken}, repositories::user_token_repository::UserTokenRepository};
+use crate::{
+    db::provide_database,
+    models::user::{User, UserToken},
+    repositories::user_token_repository::UserTokenRepository,
+};
+pub mod db;
 pub mod models;
 pub mod repositories;
-pub mod db;
 #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
@@ -24,7 +27,7 @@ async fn main() -> Result<()> {
             last_login_at: None,
             status: models::user::UserStatus::Invited,
             groups: Json(vec![]),
-            avatar_bytes: None
+            avatar_bytes: None,
         },
         User {
             id: "idu2".into(),
@@ -34,7 +37,7 @@ async fn main() -> Result<()> {
             last_login_at: None,
             status: models::user::UserStatus::Invited,
             groups: Json(vec![]),
-            avatar_bytes: None
+            avatar_bytes: None,
         },
         User {
             id: "idu3".into(),
@@ -44,8 +47,8 @@ async fn main() -> Result<()> {
             last_login_at: None,
             status: models::user::UserStatus::Invited,
             groups: Json(vec![]),
-            avatar_bytes: None
-        }
+            avatar_bytes: None,
+        },
     ];
     let user_token = UserToken {
         id: "idtoken1".into(),
@@ -53,43 +56,44 @@ async fn main() -> Result<()> {
         last_use_time: None,
         creation_time: Utc::now(),
         expiration_time: Utc::now(),
-        user_id: ForeignRef::new(&users.get(0).unwrap())
+        user_id: ForeignRef::new(&users.get(0).unwrap()),
     };
     let db = provide_database("tmp/db.db").await?;
     let user_token_repo = UserTokenRepository::new(db);
-    user_token_repo.insert_many(&vec![
-        UserToken {
-            id: "idtoken2".into(),
-            secret: "4LP5A3F3XBV5NM8VXRGZG3QDXO9PNAC0".into(),
-            last_use_time: None,
-            creation_time: Utc::now(),
-            expiration_time: Utc::now(),
-            user_id: ForeignRef::new(&users.get(0).unwrap())
-        },
-        UserToken {
-            id: "idtoken3".into(),
-            secret: "CBHR6G41KSEMR1AI".into(),
-            last_use_time: None,
-            creation_time: Utc::now(),
-            expiration_time: Utc::now(),
-            user_id: ForeignRef::new(&users.get(1).unwrap())
-        },
-        UserToken {
-            id: "idtoken4".into(),
-            secret: "CBHR6G41KSEMR1AI".into(),
-            last_use_time: None,
-            creation_time: Utc::now(),
-            expiration_time: Utc::now(),
-            user_id: ForeignRef::new(&users.get(1).unwrap())
-        }
-    ]).await?;
-    let user_tokens = user_token_repo.get_many_user_tokens_by_usersss(
-        vec!["idu2".into()]
-    ).await?;
+    user_token_repo
+        .insert_many(&vec![
+            UserToken {
+                id: "idtoken2".into(),
+                secret: "4LP5A3F3XBV5NM8VXRGZG3QDXO9PNAC0".into(),
+                last_use_time: None,
+                creation_time: Utc::now(),
+                expiration_time: Utc::now(),
+                user_id: ForeignRef::new(&users.get(0).unwrap()),
+            },
+            UserToken {
+                id: "idtoken3".into(),
+                secret: "CBHR6G41KSEMR1AI".into(),
+                last_use_time: None,
+                creation_time: Utc::now(),
+                expiration_time: Utc::now(),
+                user_id: ForeignRef::new(&users.get(1).unwrap()),
+            },
+            UserToken {
+                id: "idtoken4".into(),
+                secret: "CBHR6G41KSEMR1AI".into(),
+                last_use_time: None,
+                creation_time: Utc::now(),
+                expiration_time: Utc::now(),
+                user_id: ForeignRef::new(&users.get(1).unwrap()),
+            },
+        ])
+        .await?;
+    let user_tokens = user_token_repo
+        .get_many_user_tokens_by_usersss(vec!["idu2".into()])
+        .await?;
     dbg!(&user_tokens);
     Ok(())
 }

View file

@ -1,8 +1,8 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use sqlx::types::Json;
use fully_pub::fully_pub; use fully_pub::fully_pub;
use sqlx::types::Json;
use sqlxgentools_attrs::{SqlGeneratorDerive, SqlGeneratorModelWithId, sql_generator_model}; use sqlxgentools_attrs::{sql_generator_model, SqlGeneratorDerive, SqlGeneratorModelWithId};
use sqlxgentools_misc::{DatabaseLine, ForeignRef}; use sqlxgentools_misc::{DatabaseLine, ForeignRef};
#[derive(sqlx::Type, Clone, Debug, PartialEq)] #[derive(sqlx::Type, Clone, Debug, PartialEq)]
@ -11,37 +11,36 @@ enum UserStatus {
Disabled, Disabled,
Invited, Invited,
Active, Active,
Archived Archived,
} }
#[derive(SqlGeneratorDerive, SqlGeneratorModelWithId, sqlx::FromRow, Debug, Clone)] #[derive(SqlGeneratorDerive, SqlGeneratorModelWithId, sqlx::FromRow, Debug, Clone)]
#[sql_generator_model(table_name="usersss")] #[sql_generator_model(table_name = "usersss")]
#[fully_pub] #[fully_pub]
struct User { struct User {
#[sql_generator_field(is_primary=true)] #[sql_generator_field(is_primary = true)]
id: String, id: String,
#[sql_generator_field(is_unique=true)] #[sql_generator_field(is_unique = true)]
handle: String, handle: String,
full_name: Option<String>, full_name: Option<String>,
prefered_color: Option<i64>, prefered_color: Option<i64>,
last_login_at: Option<DateTime<Utc>>, last_login_at: Option<DateTime<Utc>>,
status: UserStatus, status: UserStatus,
groups: Json<Vec<String>>, groups: Json<Vec<String>>,
avatar_bytes: Option<Vec<u8>> avatar_bytes: Option<Vec<u8>>,
} }
#[derive(SqlGeneratorDerive, SqlGeneratorModelWithId, sqlx::FromRow, Debug, Clone)] #[derive(SqlGeneratorDerive, SqlGeneratorModelWithId, sqlx::FromRow, Debug, Clone)]
#[sql_generator_model(table_name="user_tokens")] #[sql_generator_model(table_name = "user_tokens")]
#[fully_pub] #[fully_pub]
struct UserToken { struct UserToken {
#[sql_generator_field(is_primary=true)] #[sql_generator_field(is_primary = true)]
id: String, id: String,
secret: String, secret: String,
last_use_time: Option<DateTime<Utc>>, last_use_time: Option<DateTime<Utc>>,
creation_time: DateTime<Utc>, creation_time: DateTime<Utc>,
expiration_time: DateTime<Utc>, expiration_time: DateTime<Utc>,
-    #[sql_generator_field(reverse_relation_name="user_tokens")] // to generate get_user_tokens_of_user(&user_id)
-    user_id: ForeignRef<User>
+    #[sql_generator_field(reverse_relation_name = "user_tokens")]
+    // to generate get_user_tokens_of_user(&user_id)
+    user_id: ForeignRef<User>,
} }

View file

@ -1,5 +1,5 @@
use crate::models::user::User;
use crate::db::Database; use crate::db::Database;
use crate::models::user::User;
pub struct UserRepository { pub struct UserRepository {
db: Database, db: Database,
} }
@ -8,7 +8,9 @@ impl UserRepository {
UserRepository { db } UserRepository { db }
} }
pub async fn get_all(&self) -> Result<Vec<User>, sqlx::Error> { pub async fn get_all(&self) -> Result<Vec<User>, sqlx::Error> {
-        sqlx::query_as::<_, User>("SELECT * FROM usersss").fetch_all(&self.db.0).await
+        sqlx::query_as::<_, User>("SELECT * FROM usersss")
+            .fetch_all(&self.db.0)
+            .await
} }
pub async fn get_by_id(&self, item_id: &str) -> Result<User, sqlx::Error> { pub async fn get_by_id(&self, item_id: &str) -> Result<User, sqlx::Error> {
sqlx::query_as::<_, User>("SELECT * FROM usersss WHERE id = $1") sqlx::query_as::<_, User>("SELECT * FROM usersss WHERE id = $1")
@ -16,10 +18,7 @@ impl UserRepository {
.fetch_one(&self.db.0) .fetch_one(&self.db.0)
.await .await
} }
-    pub async fn get_many_by_id(
-        &self,
-        items_ids: &[&str],
-    ) -> Result<Vec<User>, sqlx::Error> {
+    pub async fn get_many_by_id(&self, items_ids: &[&str]) -> Result<Vec<User>, sqlx::Error> {
if items_ids.is_empty() { if items_ids.is_empty() {
return Ok(vec![]); return Ok(vec![]);
} }
@ -27,9 +26,7 @@ impl UserRepository {
.map(|i| format!("${}", i)) .map(|i| format!("${}", i))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(","); .join(",");
-        let query_sql = format!(
-            "SELECT * FROM usersss WHERE id IN ({})", placeholder_params
-        );
+        let query_sql = format!("SELECT * FROM usersss WHERE id IN ({})", placeholder_params);
let mut query = sqlx::query_as::<_, User>(&query_sql); let mut query = sqlx::query_as::<_, User>(&query_sql);
for id in items_ids { for id in items_ids {
query = query.bind(id); query = query.bind(id);
@ -59,8 +56,11 @@ impl UserRepository {
.map(|c| c.to_vec()) .map(|c| c.to_vec())
.map(|x| { .map(|x| {
                 format!(
-                    "({})", x.iter().map(| i | format!("${}", i)).collect:: < Vec <
-                    String >> ().join(", ")
+                    "({})",
+                    x.iter()
+                        .map(|i| format!("${}", i))
+                        .collect::<Vec<String>>()
+                        .join(", ")
                 )
}) })
.collect::<Vec<String>>() .collect::<Vec<String>>()
@ -84,11 +84,7 @@ impl UserRepository {
query.execute(&self.db.0).await?; query.execute(&self.db.0).await?;
Ok(()) Ok(())
} }
-    pub async fn update_by_id(
-        &self,
-        item_id: &str,
-        entity: &User,
-    ) -> Result<(), sqlx::Error> {
+    pub async fn update_by_id(&self, item_id: &str, entity: &User) -> Result<(), sqlx::Error> {
sqlx::query( sqlx::query(
"UPDATE usersss SET id = $2, handle = $3, full_name = $4, prefered_color = $5, last_login_at = $6, status = $7, groups = $8, avatar_bytes = $9 WHERE id = $1", "UPDATE usersss SET id = $2, handle = $3, full_name = $4, prefered_color = $5, last_login_at = $6, status = $7, groups = $8, avatar_bytes = $9 WHERE id = $1",
) )

View file

@ -1,5 +1,5 @@
use crate::models::user::UserToken;
use crate::db::Database; use crate::db::Database;
use crate::models::user::UserToken;
pub struct UserTokenRepository { pub struct UserTokenRepository {
db: Database, db: Database,
} }
@ -18,10 +18,7 @@ impl UserTokenRepository {
.fetch_one(&self.db.0) .fetch_one(&self.db.0)
.await .await
} }
-    pub async fn get_many_by_id(
-        &self,
-        items_ids: &[&str],
-    ) -> Result<Vec<UserToken>, sqlx::Error> {
+    pub async fn get_many_by_id(&self, items_ids: &[&str]) -> Result<Vec<UserToken>, sqlx::Error> {
if items_ids.is_empty() { if items_ids.is_empty() {
return Ok(vec![]); return Ok(vec![]);
} }
@ -30,7 +27,8 @@ impl UserTokenRepository {
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(","); .join(",");
         let query_sql = format!(
-            "SELECT * FROM user_tokens WHERE id IN ({})", placeholder_params
+            "SELECT * FROM user_tokens WHERE id IN ({})",
+            placeholder_params
         );
let mut query = sqlx::query_as::<_, UserToken>(&query_sql); let mut query = sqlx::query_as::<_, UserToken>(&query_sql);
for id in items_ids { for id in items_ids {
@ -52,18 +50,18 @@ impl UserTokenRepository {
.await?; .await?;
Ok(()) Ok(())
} }
-    pub async fn insert_many(
-        &self,
-        entities: &Vec<UserToken>,
-    ) -> Result<(), sqlx::Error> {
+    pub async fn insert_many(&self, entities: &Vec<UserToken>) -> Result<(), sqlx::Error> {
let values_templates: String = (1..(6usize * entities.len() + 1)) let values_templates: String = (1..(6usize * entities.len() + 1))
.collect::<Vec<usize>>() .collect::<Vec<usize>>()
.chunks(6usize) .chunks(6usize)
.map(|c| c.to_vec()) .map(|c| c.to_vec())
.map(|x| { .map(|x| {
                 format!(
-                    "({})", x.iter().map(| i | format!("${}", i)).collect:: < Vec <
-                    String >> ().join(", ")
+                    "({})",
+                    x.iter()
+                        .map(|i| format!("${}", i))
+                        .collect::<Vec<String>>()
+                        .join(", ")
                 )
}) })
.collect::<Vec<String>>() .collect::<Vec<String>>()
@ -85,11 +83,7 @@ impl UserTokenRepository {
query.execute(&self.db.0).await?; query.execute(&self.db.0).await?;
Ok(()) Ok(())
} }
-    pub async fn update_by_id(
-        &self,
-        item_id: &str,
-        entity: &UserToken,
-    ) -> Result<(), sqlx::Error> {
+    pub async fn update_by_id(&self, item_id: &str, entity: &UserToken) -> Result<(), sqlx::Error> {
sqlx::query( sqlx::query(
"UPDATE user_tokens SET id = $2, secret = $3, last_use_time = $4, creation_time = $5, expiration_time = $6, user_id = $7 WHERE id = $1", "UPDATE user_tokens SET id = $2, secret = $3, last_use_time = $4, creation_time = $5, expiration_time = $6, user_id = $7 WHERE id = $1",
) )
@ -132,7 +126,8 @@ impl UserTokenRepository {
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(","); .join(",");
         let query_tmpl = format!(
-            "SELECT * FROM user_tokens WHERE user_id IN ({})", placeholder_params
+            "SELECT * FROM user_tokens WHERE user_id IN ({})",
+            placeholder_params
         );
let mut query = sqlx::query_as::<_, UserToken>(&query_tmpl); let mut query = sqlx::query_as::<_, UserToken>(&query_tmpl);
for id in items_ids { for id in items_ids {

View file

@@ -5,7 +5,10 @@
 use std::assert_matches::assert_matches;
 use chrono::Utc;
-use sandbox::{models::user::{User, UserStatus}, repositories::user_repository::UserRepository};
+use sandbox::{
+    models::user::{User, UserStatus},
+    repositories::user_repository::UserRepository,
+};
 use sqlx::{types::Json, Pool, Sqlite};
#[sqlx::test(fixtures("../src/migrations/all.sql"))] #[sqlx::test(fixtures("../src/migrations/all.sql"))]
@ -22,43 +25,40 @@ async fn test_user_repository_create_read_update_delete(pool: Pool<Sqlite>) -> s
         last_login_at: Some(Utc::now()),
         status: UserStatus::Invited,
         groups: Json(vec!["artists".into()]),
-        avatar_bytes: vec![0x00]
+        avatar_bytes: vec![0x00],
     };
-    assert_matches!(
-        user_repo.insert(&new_user).await,
-        Ok(())
-    );
+    assert_matches!(user_repo.insert(&new_user).await, Ok(()));
     assert_matches!(
-        user_repo.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()).await,
+        user_repo
+            .get_by_id("ffffffff-0000-4000-0000-0000000000c9".into())
+            .await,
         Ok(User { .. })
     );
     assert_matches!(
-        user_repo.get_by_id("ffffffff-0000-4040-0000-000000000000".into()).await,
+        user_repo
+            .get_by_id("ffffffff-0000-4040-0000-000000000000".into())
+            .await,
         Err(sqlx::Error::RowNotFound)
     );
     // Insert Many
-    let bunch_of_users: Vec<User> = (0..10).map(|pid| User {
-        id: format!("ffffffff-0000-4000-0010-{:0>8}", pid),
-        handle: format!("user num {}", pid),
-        full_name: None,
-        prefered_color: None,
-        last_login_at: None,
-        status: UserStatus::Invited,
-        groups: Json(vec![]),
-        avatar_bytes: vec![]
-    }).collect();
-    assert_matches!(
-        user_repo.insert_many(&bunch_of_users).await,
-        Ok(())
-    );
+    let bunch_of_users: Vec<User> = (0..10)
+        .map(|pid| User {
+            id: format!("ffffffff-0000-4000-0010-{:0>8}", pid),
+            handle: format!("user num {}", pid),
+            full_name: None,
+            prefered_color: None,
+            last_login_at: None,
+            status: UserStatus::Invited,
+            groups: Json(vec![]),
+            avatar_bytes: vec![],
+        })
+        .collect();
+    assert_matches!(user_repo.insert_many(&bunch_of_users).await, Ok(()));
     // Read many all
     let read_all_res = user_repo.get_all().await;
-    assert_matches!(
-        read_all_res,
-        Ok(..)
-    );
+    assert_matches!(read_all_res, Ok(..));
     let all_users = read_all_res.unwrap();
     assert_eq!(all_users.len(), 11);
@ -69,16 +69,18 @@ async fn test_user_repository_create_read_update_delete(pool: Pool<Sqlite>) -> s
user_repo.update_by_id(&new_user.id, &updated_user).await, user_repo.update_by_id(&new_user.id, &updated_user).await,
Ok(()) Ok(())
); );
-    let user_from_db = user_repo.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()).await.unwrap();
+    let user_from_db = user_repo
+        .get_by_id("ffffffff-0000-4000-0000-0000000000c9".into())
+        .await
+        .unwrap();
     assert_eq!(user_from_db.status, UserStatus::Disabled);
     // Delete
-    assert_matches!(
-        user_repo.delete_by_id(&new_user.id).await,
-        Ok(())
-    );
+    assert_matches!(user_repo.delete_by_id(&new_user.id).await, Ok(()));
     assert_matches!(
-        user_repo.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()).await,
+        user_repo
+            .get_by_id("ffffffff-0000-4000-0000-0000000000c9".into())
+            .await,
         Err(sqlx::Error::RowNotFound)
     );

View file

@ -1,6 +1,6 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::quote; use quote::quote;
use syn::{DeriveInput, Fields, parse_macro_input}; use syn::{parse_macro_input, DeriveInput, Fields};
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn sql_generator_model(_attr: TokenStream, item: TokenStream) -> TokenStream { pub fn sql_generator_model(_attr: TokenStream, item: TokenStream) -> TokenStream {
@ -38,4 +38,3 @@ pub fn derive_sql_generator_model_with_id(input: TokenStream) -> TokenStream {
// If `id` field is not found, return an error // If `id` field is not found, return an error
panic!("Expected struct with a named field `id` of type String") panic!("Expected struct with a named field `id` of type String")
} }

View file

@ -1,8 +1,7 @@
use anyhow::{Result, anyhow}; use anyhow::{anyhow, Result};
use crate::models::{Field, Model}; use crate::models::{Field, Model};
// Implementations // Implementations
impl Field { impl Field {
/// return sqlite type /// return sqlite type
@ -21,7 +20,7 @@ impl Field {
"DateTime" => Some("DATETIME".into()), "DateTime" => Some("DATETIME".into()),
"Json" => Some("TEXT".into()), "Json" => Some("TEXT".into()),
"Vec<u8>" => Some("BLOB".into()), "Vec<u8>" => Some("BLOB".into()),
_ => Some("TEXT".into()) _ => Some("TEXT".into()),
} }
} }
} }
@ -35,8 +34,10 @@ pub fn generate_create_table_sql(models: &[Model]) -> Result<String> {
let mut fields_sql: Vec<String> = vec![]; let mut fields_sql: Vec<String> = vec![];
for field in model.fields.iter() { for field in model.fields.iter() {
let mut additions: String = "".into(); let mut additions: String = "".into();
-            let sql_type = field.sql_type()
-                .ok_or(anyhow!(format!("Could not find SQL type for field {}", field.name)))?;
+            let sql_type = field.sql_type().ok_or(anyhow!(format!(
+                "Could not find SQL type for field {}",
+                field.name
+            )))?;
if !field.is_nullable { if !field.is_nullable {
additions.push_str(" NOT NULL"); additions.push_str(" NOT NULL");
} }
@ -46,20 +47,15 @@ pub fn generate_create_table_sql(models: &[Model]) -> Result<String> {
if field.is_primary { if field.is_primary {
additions.push_str(" PRIMARY KEY"); additions.push_str(" PRIMARY KEY");
} }
-            fields_sql.push(
-                format!("\t{: <#18}\t{}{}", field.name, sql_type, additions)
-            );
+            fields_sql.push(format!("\t{: <#18}\t{}{}", field.name, sql_type, additions));
         }
-        sql_code.push_str(
-            &format!(
-                "CREATE TABLE {} (\n{}\n);\n",
-                model.table_name,
-                fields_sql.join(",\n")
-            )
-        );
+        sql_code.push_str(&format!(
+            "CREATE TABLE {} (\n{}\n);\n",
+            model.table_name,
+            fields_sql.join(",\n")
+        ));
} }
Ok(sql_code) Ok(sql_code)
} }

View file

@ -8,13 +8,12 @@ pub mod repositories;
#[fully_pub] #[fully_pub]
enum SourceNode { enum SourceNode {
File(String), File(String),
Directory(Vec<SourceNodeContainer>) Directory(Vec<SourceNodeContainer>),
} }
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
#[fully_pub] #[fully_pub]
struct SourceNodeContainer { struct SourceNodeContainer {
name: String, name: String,
inner: SourceNode inner: SourceNode,
} }

View file

@ -1,12 +1,14 @@
use anyhow::Result; use anyhow::Result;
-use proc_macro2::{TokenStream, Ident};
+use heck::ToSnakeCase;
+use proc_macro2::{Ident, TokenStream};
 use quote::{format_ident, quote};
 use syn::File;
-use heck::ToSnakeCase;
-use crate::{generators::repositories::relations::gen_get_many_of_related_entity_method, models::{Field, FieldForeignMode, Model}};
 use crate::generators::{SourceNode, SourceNodeContainer};
+use crate::{
+    generators::repositories::relations::gen_get_many_of_related_entity_method,
+    models::{Field, FieldForeignMode, Model},
+};
fn gen_get_all_method(model: &Model) -> TokenStream { fn gen_get_all_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
@ -23,7 +25,10 @@ fn gen_get_all_method(model: &Model) -> TokenStream {
fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream { fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
-    let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, query_field.name);
+    let select_query = format!(
+        "SELECT * FROM {} WHERE {} = $1",
+        model.table_name, query_field.name
+    );
let func_name_ident = format_ident!("get_by_{}", query_field.name); let func_name_ident = format_ident!("get_by_{}", query_field.name);
@ -40,7 +45,10 @@ fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStream { fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
-    let select_query_tmpl = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, query_field.name);
+    let select_query_tmpl = format!(
+        "SELECT * FROM {} WHERE {} IN ({{}})",
+        model.table_name, query_field.name
+    );
let func_name_ident = format_ident!("get_many_by_{}", query_field.name); let func_name_ident = format_ident!("get_many_by_{}", query_field.name);
@ -66,21 +74,41 @@ fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStre
} }
fn get_mutation_fields(model: &Model) -> (Vec<&Field>, Vec<&Field>) { fn get_mutation_fields(model: &Model) -> (Vec<&Field>, Vec<&Field>) {
-    let normal_field_names: Vec<&Field> = model.fields.iter()
-        .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false })
+    let normal_field_names: Vec<&Field> = model
+        .fields
+        .iter()
+        .filter(|f| match f.foreign_mode {
+            FieldForeignMode::NotRef => true,
+            FieldForeignMode::ForeignRef(_) => false,
+        })
         .collect();
-    let foreign_keys_field_names: Vec<&Field> = model.fields.iter()
-        .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true })
+    let foreign_keys_field_names: Vec<&Field> = model
+        .fields
+        .iter()
+        .filter(|f| match f.foreign_mode {
+            FieldForeignMode::NotRef => false,
+            FieldForeignMode::ForeignRef(_) => true,
+        })
         .collect();
     (normal_field_names, foreign_keys_field_names)
 }
 fn get_mutation_fields_ident(model: &Model) -> (Vec<&Field>, Vec<&Field>) {
-    let normal_field_names: Vec<&Field> = model.fields.iter()
-        .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false })
+    let normal_field_names: Vec<&Field> = model
+        .fields
+        .iter()
+        .filter(|f| match f.foreign_mode {
+            FieldForeignMode::NotRef => true,
+            FieldForeignMode::ForeignRef(_) => false,
+        })
         .collect();
-    let foreign_keys_field_names: Vec<&Field> = model.fields.iter()
-        .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true })
+    let foreign_keys_field_names: Vec<&Field> = model
+        .fields
+        .iter()
+        .filter(|f| match f.foreign_mode {
+            FieldForeignMode::NotRef => false,
+            FieldForeignMode::ForeignRef(_) => true,
+        })
         .collect();
     (normal_field_names, foreign_keys_field_names)
 }
@ -88,26 +116,31 @@ fn get_mutation_fields_ident(model: &Model) -> (Vec<&Field>, Vec<&Field>) {
fn gen_insert_method(model: &Model) -> TokenStream { fn gen_insert_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let value_templates = (1..(model.fields.len()+1)) let value_templates = (1..(model.fields.len() + 1))
.map(|i| format!("${}", i)) .map(|i| format!("${}", i))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "); .join(", ");
let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); let (normal_fields, foreign_keys_fields) = get_mutation_fields(model);
let (normal_field_idents, foreign_keys_field_idents) = ( let (normal_field_idents, foreign_keys_field_idents) = (
normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>(), normal_fields
foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>() .iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
foreign_keys_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
); );
let sql_columns = [normal_fields, foreign_keys_fields].concat() let sql_columns = [normal_fields, foreign_keys_fields]
.concat()
.iter() .iter()
.map(|f| f.name.clone()) .map(|f| f.name.clone())
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "); .join(", ");
let insert_query = format!( let insert_query = format!(
"INSERT INTO {} ({}) VALUES ({})", "INSERT INTO {} ({}) VALUES ({})",
model.table_name, model.table_name, sql_columns, value_templates
sql_columns,
value_templates
); );
// foreign keys must be inserted first, we sort the columns so that foreign keys are first // foreign keys must be inserted first, we sort the columns so that foreign keys are first
@ -126,19 +159,26 @@ fn gen_insert_method(model: &Model) -> TokenStream {
fn gen_insert_many_method(model: &Model) -> TokenStream { fn gen_insert_many_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let sql_columns = model.fields.iter() let sql_columns = model
.fields
.iter()
.map(|f| f.name.clone()) .map(|f| f.name.clone())
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "); .join(", ");
let base_insert_query = format!( let base_insert_query = format!(
"INSERT INTO {} ({}) VALUES {{}} ON CONFLICT DO NOTHING", "INSERT INTO {} ({}) VALUES {{}} ON CONFLICT DO NOTHING",
model.table_name, model.table_name, sql_columns
sql_columns
); );
let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); let (normal_fields, foreign_keys_fields) = get_mutation_fields(model);
let (normal_field_idents, foreign_keys_field_idents) = ( let (normal_field_idents, foreign_keys_field_idents) = (
normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>(), normal_fields
foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>() .iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
foreign_keys_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
); );
let fields_count = model.fields.len(); let fields_count = model.fields.len();
@ -174,32 +214,39 @@ fn gen_insert_many_method(model: &Model) -> TokenStream {
} }
} }
fn gen_update_by_id_method(model: &Model) -> TokenStream { fn gen_update_by_id_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let primary_key = &model.fields.iter() let primary_key = &model
.fields
.iter()
.find(|f| f.is_primary) .find(|f| f.is_primary)
.expect("A model must have at least one primary key") .expect("A model must have at least one primary key")
.name; .name;
let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); let (normal_fields, foreign_keys_fields) = get_mutation_fields(model);
let (normal_field_idents, foreign_keys_field_idents) = ( let (normal_field_idents, foreign_keys_field_idents) = (
normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>(), normal_fields
foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>() .iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
foreign_keys_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
); );
let sql_columns = [normal_fields, foreign_keys_fields].concat() let sql_columns = [normal_fields, foreign_keys_fields]
.concat()
.iter() .iter()
.map(|f| f.name.clone()) .map(|f| f.name.clone())
.collect::<Vec<String>>(); .collect::<Vec<String>>();
let set_statements = sql_columns.iter() let set_statements = sql_columns
.iter()
.enumerate() .enumerate()
.map(|(i, column_name)| format!("{} = ${}", column_name, i+2)) .map(|(i, column_name)| format!("{} = ${}", column_name, i + 2))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "); .join(", ");
let update_query = format!( let update_query = format!(
"UPDATE {} SET {} WHERE {} = $1", "UPDATE {} SET {} WHERE {} = $1",
model.table_name, model.table_name, set_statements, primary_key
set_statements,
primary_key
); );
let func_name_ident = format_ident!("update_by_{}", primary_key); let func_name_ident = format_ident!("update_by_{}", primary_key);
@ -218,7 +265,9 @@ fn gen_update_by_id_method(model: &Model) -> TokenStream {
} }
fn gen_delete_by_id_method(model: &Model) -> TokenStream { fn gen_delete_by_id_method(model: &Model) -> TokenStream {
let primary_key = &model.fields.iter() let primary_key = &model
.fields
.iter()
.find(|f| f.is_primary) .find(|f| f.is_primary)
.expect("A model must have at least one primary key") .expect("A model must have at least one primary key")
.name; .name;
@ -226,8 +275,7 @@ fn gen_delete_by_id_method(model: &Model) -> TokenStream {
let func_name_ident = format_ident!("delete_by_{}", primary_key); let func_name_ident = format_ident!("delete_by_{}", primary_key);
let query = format!( let query = format!(
"DELETE FROM {} WHERE {} = $1", "DELETE FROM {} WHERE {} = $1",
model.table_name, model.table_name, primary_key
primary_key
); );
quote! { quote! {
@ -243,7 +291,9 @@ fn gen_delete_by_id_method(model: &Model) -> TokenStream {
} }
fn gen_delete_many_by_id_method(model: &Model) -> TokenStream { fn gen_delete_many_by_id_method(model: &Model) -> TokenStream {
let primary_key = &model.fields.iter() let primary_key = &model
.fields
.iter()
.find(|f| f.is_primary) .find(|f| f.is_primary)
.expect("A model must have at least one primary key") .expect("A model must have at least one primary key")
.name; .name;
@ -251,8 +301,7 @@ fn gen_delete_many_by_id_method(model: &Model) -> TokenStream {
let func_name_ident = format_ident!("delete_many_by_{}", primary_key); let func_name_ident = format_ident!("delete_many_by_{}", primary_key);
let delete_query_tmpl = format!( let delete_query_tmpl = format!(
"DELETE FROM {} WHERE {} IN ({{}})", "DELETE FROM {} WHERE {} IN ({{}})",
model.table_name, model.table_name, primary_key
primary_key
); );
quote! { quote! {
@ -278,8 +327,10 @@ fn gen_delete_many_by_id_method(model: &Model) -> TokenStream {
} }
} }
pub fn generate_repository_file(
pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<SourceNodeContainer> { all_models: &[Model],
model: &Model,
) -> Result<SourceNodeContainer> {
let resource_name = model.name.clone(); let resource_name = model.name.clone();
let resource_module_ident = format_ident!("{}", &model.module_path.first().unwrap()); let resource_module_ident = format_ident!("{}", &model.module_path.first().unwrap());
@ -290,15 +341,19 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<S
let get_all_method_code = gen_get_all_method(model); let get_all_method_code = gen_get_all_method(model);
let get_by_id_method_code = gen_get_by_field_method( let get_by_id_method_code = gen_get_by_field_method(
model, model,
model.fields.iter() model
.fields
.iter()
.find(|f| f.is_primary == true) .find(|f| f.is_primary == true)
.expect("Expected at least one primary key on the model.") .expect("Expected at least one primary key on the model."),
); );
let get_many_by_id_method_code = gen_get_many_by_field_method( let get_many_by_id_method_code = gen_get_many_by_field_method(
model, model,
model.fields.iter() model
.fields
.iter()
.find(|f| f.is_primary == true) .find(|f| f.is_primary == true)
.expect("Expected at least one primary key on the model.") .expect("Expected at least one primary key on the model."),
); );
let insert_method_code = gen_insert_method(model); let insert_method_code = gen_insert_method(model);
let insert_many_method_code = gen_insert_many_method(model); let insert_many_method_code = gen_insert_many_method(model);
@ -306,34 +361,31 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<S
let delete_by_id_method_code = gen_delete_by_id_method(model); let delete_by_id_method_code = gen_delete_by_id_method(model);
let delete_many_by_id_method_code = gen_delete_many_by_id_method(model); let delete_many_by_id_method_code = gen_delete_many_by_id_method(model);
-    let query_by_field_methods: Vec<TokenStream> =
-        model.fields.iter()
-        .filter(|f| f.is_query_entrypoint)
-        .map(|field|
-            gen_get_by_field_method(
-                model,
-                &field
-            )
-        )
-        .collect();
-    let query_many_by_field_methods: Vec<TokenStream> =
-        model.fields.iter()
-        .filter(|f| f.is_query_entrypoint)
-        .map(|field|
-            gen_get_many_by_field_method(
-                model,
-                &field
-            )
-        )
-        .collect();
-    let fields_with_foreign_refs: Vec<&Field> = model.fields.iter().filter(|f|
-        match f.foreign_mode { FieldForeignMode::ForeignRef(_) => true, FieldForeignMode::NotRef => false }
-    ).collect();
-    let related_entity_methods_codes: Vec<TokenStream> = fields_with_foreign_refs.iter().map(|field|
-        gen_get_many_of_related_entity_method(model, &field)
-    ).collect();
+    let query_by_field_methods: Vec<TokenStream> = model
+        .fields
+        .iter()
+        .filter(|f| f.is_query_entrypoint)
+        .map(|field| gen_get_by_field_method(model, &field))
+        .collect();
+    let query_many_by_field_methods: Vec<TokenStream> = model
+        .fields
+        .iter()
+        .filter(|f| f.is_query_entrypoint)
+        .map(|field| gen_get_many_by_field_method(model, &field))
+        .collect();
+    let fields_with_foreign_refs: Vec<&Field> = model
+        .fields
+        .iter()
+        .filter(|f| match f.foreign_mode {
+            FieldForeignMode::ForeignRef(_) => true,
+            FieldForeignMode::NotRef => false,
+        })
+        .collect();
+    let related_entity_methods_codes: Vec<TokenStream> = fields_with_foreign_refs
+        .iter()
+        .map(|field| gen_get_many_of_related_entity_method(model, &field))
+        .collect();
// TODO: add import line // TODO: add import line
let base_repository_code: TokenStream = quote! { let base_repository_code: TokenStream = quote! {
@ -380,6 +432,6 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<S
Ok(SourceNodeContainer { Ok(SourceNodeContainer {
name: format!("{}_repository.rs", model.name.to_snake_case()), name: format!("{}_repository.rs", model.name.to_snake_case()),
inner: SourceNode::File(pretty) inner: SourceNode::File(pretty),
}) })
} }

View file

@ -21,11 +21,10 @@ pub fn generate_repositories_source_files(models: &[Model]) -> Result<SourceNode
} }
nodes.push(SourceNodeContainer { nodes.push(SourceNodeContainer {
name: "mod.rs".into(), name: "mod.rs".into(),
inner: SourceNode::File(mod_index_code.to_string()) inner: SourceNode::File(mod_index_code.to_string()),
}); });
Ok(SourceNodeContainer { Ok(SourceNodeContainer {
name: "".into(), name: "".into(),
inner: SourceNode::Directory(nodes) inner: SourceNode::Directory(nodes),
}) })
} }

View file

@ -5,7 +5,10 @@ use crate::models::{Field, FieldForeignMode, Model};
/// method that can be used to retreive a list of entities of type X that are the children of a parent type Y /// method that can be used to retreive a list of entities of type X that are the children of a parent type Y
/// ex: get all comments of a post /// ex: get all comments of a post
-pub fn gen_get_many_of_related_entity_method(model: &Model, foreign_key_field: &Field) -> TokenStream {
+pub fn gen_get_many_of_related_entity_method(
+    model: &Model,
+    foreign_key_field: &Field,
+) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let foreign_ref_params = match &foreign_key_field.foreign_mode { let foreign_ref_params = match &foreign_key_field.foreign_mode {
@ -15,7 +18,10 @@ pub fn gen_get_many_of_related_entity_method(model: &Model, foreign_key_field: &
} }
}; };
-    let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, foreign_key_field.name);
+    let select_query = format!(
+        "SELECT * FROM {} WHERE {} = $1",
+        model.table_name, foreign_key_field.name
+    );
let func_name_ident = format_ident!("get_many_of_{}", foreign_ref_params.target_resource_name); let func_name_ident = format_ident!("get_many_of_{}", foreign_ref_params.target_resource_name);
@ -28,4 +34,3 @@ pub fn gen_get_many_of_related_entity_method(model: &Model, foreign_key_field: &
} }
} }
} }

View file

@@ -1,22 +1,19 @@
-use std::{ffi::OsStr, path::Path};
 use attribute_derive::FromAttr;
+use std::{ffi::OsStr, path::Path};
+use anyhow::{anyhow, Result};
 use argh::FromArgs;
-use anyhow::{Result, anyhow};
 use crate::generators::{SourceNode, SourceNodeContainer};
-// use gen_migrations::generate_create_table_sql;
-// use gen_repositories::{generate_repositories_source_files, SourceNodeContainer};
+pub mod generators;
 pub mod models;
 pub mod parse_models;
-pub mod generators;
#[derive(FromAttr, PartialEq, Debug, Default)] #[derive(FromAttr, PartialEq, Debug, Default)]
#[attribute(ident = sql_generator_model)] #[attribute(ident = sql_generator_model)]
pub struct SqlGeneratorModelAttr { pub struct SqlGeneratorModelAttr {
table_name: Option<String> table_name: Option<String>,
} }
#[derive(FromAttr, PartialEq, Debug, Default)] #[derive(FromAttr, PartialEq, Debug, Default)]
@ -28,17 +25,16 @@ pub struct SqlGeneratorFieldAttr {
/// to indicate that this field will be used to obtains entities /// to indicate that this field will be used to obtains entities
/// our framework will generate methods for all fields that is an entrypoint /// our framework will generate methods for all fields that is an entrypoint
is_query_entrypoint: Option<bool> is_query_entrypoint: Option<bool>,
} }
#[derive(FromArgs, PartialEq, Debug)] #[derive(FromArgs, PartialEq, Debug)]
/// Generate SQL CREATE TABLE migrations /// Generate SQL CREATE TABLE migrations
#[argh(subcommand, name = "gen-migrations")] #[argh(subcommand, name = "gen-migrations")]
struct GenerateMigration { struct GenerateMigration {
/// path of file where to write all in one generated SQL migration /// path of file where to write all in one generated SQL migration
#[argh(option, short = 'o')] #[argh(option, short = 'o')]
output: Option<String> output: Option<String>,
} }
#[derive(FromArgs, PartialEq, Debug)] #[derive(FromArgs, PartialEq, Debug)]
@ -47,7 +43,7 @@ struct GenerateMigration {
struct GenerateRepositories { struct GenerateRepositories {
/// path of the directory that contains repositories /// path of the directory that contains repositories
#[argh(option, short = 'o')] #[argh(option, short = 'o')]
output: Option<String> output: Option<String>,
} }
#[derive(FromArgs, PartialEq, Debug)] #[derive(FromArgs, PartialEq, Debug)]
@ -69,7 +65,7 @@ struct GeneratorArgs {
models_path: Option<String>, models_path: Option<String>,
#[argh(subcommand)] #[argh(subcommand)]
nested: GeneratorArgsSubCommands nested: GeneratorArgsSubCommands,
} }
fn write_source_code(base_path: &Path, snc: SourceNodeContainer) -> Result<()> { fn write_source_code(base_path: &Path, snc: SourceNodeContainer) -> Result<()> {
@ -78,7 +74,7 @@ fn write_source_code(base_path: &Path, snc: SourceNodeContainer) -> Result<()> {
SourceNode::File(code) => { SourceNode::File(code) => {
println!("writing file {:?}", path); println!("writing file {:?}", path);
std::fs::write(path, code)?; std::fs::write(path, code)?;
}, }
SourceNode::Directory(dir) => { SourceNode::Directory(dir) => {
for node in dir { for node in dir {
write_source_code(&path, node)?; write_source_code(&path, node)?;
@ -92,7 +88,10 @@ pub fn main() -> Result<()> {
let args: GeneratorArgs = argh::from_env(); let args: GeneratorArgs = argh::from_env();
let project_root = &args.project_root.unwrap_or(".".to_string()); let project_root = &args.project_root.unwrap_or(".".to_string());
let project_root_path = Path::new(&project_root); let project_root_path = Path::new(&project_root);
eprintln!("Using project root at: {:?}", &project_root_path.canonicalize()?); eprintln!(
"Using project root at: {:?}",
&project_root_path.canonicalize()?
);
if !project_root_path.exists() { if !project_root_path.exists() {
return Err(anyhow!("Could not resolve project root path.")); return Err(anyhow!("Could not resolve project root path."));
} }
@ -117,7 +116,11 @@ pub fn main() -> Result<()> {
if !models_mod_path.exists() { if !models_mod_path.exists() {
return Err(anyhow!("Could not resolve models modules.")); return Err(anyhow!("Could not resolve models modules."));
} }
-    if models_mod_path.file_name().map(|x| x == OsStr::new("mod.rs")).unwrap_or(false) {
+    if models_mod_path
+        .file_name()
+        .map(|x| x == OsStr::new("mod.rs"))
+        .unwrap_or(false)
+    {
models_mod_path.pop(); models_mod_path.pop();
} }
eprintln!("Found models in project, parsing models"); eprintln!("Found models in project, parsing models");
@ -136,7 +139,7 @@ pub fn main() -> Result<()> {
let snc = generators::repositories::generate_repositories_source_files(&models)?; let snc = generators::repositories::generate_repositories_source_files(&models)?;
dbg!(&snc); dbg!(&snc);
write_source_code(&repositories_mod_path, snc)?; write_source_code(&repositories_mod_path, snc)?;
}, }
GeneratorArgsSubCommands::GenerateMigration(opts) => { GeneratorArgsSubCommands::GenerateMigration(opts) => {
eprintln!("Generating migrations…"); eprintln!("Generating migrations…");
let sql_code = generators::migrations::generate_create_table_sql(&models)?; let sql_code = generators::migrations::generate_create_table_sql(&models)?;

View file

@ -7,7 +7,7 @@ struct Model {
module_path: Vec<String>, module_path: Vec<String>,
name: String, name: String,
table_name: String, table_name: String,
fields: Vec<Field> fields: Vec<Field>,
} }
impl Model { impl Model {
@ -29,7 +29,6 @@ impl Model {
// } // }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
#[fully_pub] #[fully_pub]
struct ForeignRefParams { struct ForeignRefParams {
@ -41,12 +40,11 @@ struct ForeignRefParams {
// target_resource_name_plural: String // target_resource_name_plural: String
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
#[fully_pub] #[fully_pub]
enum FieldForeignMode { enum FieldForeignMode {
ForeignRef(ForeignRefParams), ForeignRef(ForeignRefParams),
NotRef NotRef,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -58,6 +56,5 @@ struct Field {
is_unique: bool, is_unique: bool,
is_primary: bool, is_primary: bool,
is_query_entrypoint: bool, is_query_entrypoint: bool,
foreign_mode: FieldForeignMode foreign_mode: FieldForeignMode,
} }

View file

@@ -1,11 +1,14 @@
-use std::{fs, path::Path};
 use attribute_derive::FromAttr;
+use std::{fs, path::Path};
-use anyhow::{Result, anyhow};
+use anyhow::{anyhow, Result};
 use convert_case::{Case, Casing};
 use syn::{GenericArgument, PathArguments, Type};
-use crate::{SqlGeneratorFieldAttr, SqlGeneratorModelAttr, models::{Field, FieldForeignMode, ForeignRefParams, Model}};
+use crate::{
+    models::{Field, FieldForeignMode, ForeignRefParams, Model},
+    SqlGeneratorFieldAttr, SqlGeneratorModelAttr,
+};
fn extract_generic_type(base_segments: Vec<String>, ty: &Type) -> Option<&Type> { fn extract_generic_type(base_segments: Vec<String>, ty: &Type) -> Option<&Type> {
// If it is not `TypePath`, it is not possible to be `Option<T>`, return `None` // If it is not `TypePath`, it is not possible to be `Option<T>`, return `None`
@ -52,24 +55,38 @@ fn extract_generic_type(base_segments: Vec<String>, ty: &Type) -> Option<&Type>
 fn get_type_first_ident(inp: &Type) -> Option<String> {
     match inp {
-        Type::Path(field_type_path) => {
-            Some(field_type_path.path.segments.get(0).unwrap().ident.to_string())
-        },
-        _ => {
-            None
-        }
+        Type::Path(field_type_path) => Some(
+            field_type_path
+                .path
+                .segments
+                .get(0)
+                .unwrap()
+                .ident
+                .to_string(),
+        ),
+        _ => None,
     }
 }
fn get_first_generic_arg_type_ident(inp: &Type) -> Option<String> { fn get_first_generic_arg_type_ident(inp: &Type) -> Option<String> {
if let Type::Path(field_type_path) = inp { if let Type::Path(field_type_path) = inp {
if let PathArguments::AngleBracketed(args) = &field_type_path.path.segments.get(0).unwrap().arguments { if let PathArguments::AngleBracketed(args) =
&field_type_path.path.segments.get(0).unwrap().arguments
{
if args.args.is_empty() { if args.args.is_empty() {
None None
} else { } else {
if let GenericArgument::Type(arg_type) = args.args.get(0).unwrap() { if let GenericArgument::Type(arg_type) = args.args.get(0).unwrap() {
if let Type::Path(arg_type_path) = arg_type { if let Type::Path(arg_type_path) = arg_type {
Some(arg_type_path.path.segments.get(0).unwrap().ident.to_string()) Some(
arg_type_path
.path
.segments
.get(0)
.unwrap()
.ident
.to_string(),
)
} else { } else {
None None
} }
@ -85,7 +102,6 @@ fn get_first_generic_arg_type_ident(inp: &Type) -> Option<String> {
} }
} }
fn parse_model_attribute(item: &syn::ItemStruct) -> Result<Option<SqlGeneratorModelAttr>> { fn parse_model_attribute(item: &syn::ItemStruct) -> Result<Option<SqlGeneratorModelAttr>> {
for attr in item.attrs.iter() { for attr in item.attrs.iter() {
let attr_ident = match attr.path().get_ident() { let attr_ident = match attr.path().get_ident() {
@ -101,9 +117,12 @@ fn parse_model_attribute(item: &syn::ItemStruct) -> Result<Option<SqlGeneratorMo
match SqlGeneratorModelAttr::from_attribute(attr) { match SqlGeneratorModelAttr::from_attribute(attr) {
Ok(v) => { Ok(v) => {
return Ok(Some(v)); return Ok(Some(v));
}, }
Err(err) => { Err(err) => {
return Err(anyhow!("Failed to parse sql_generator_model attribute macro: {}", err)); return Err(anyhow!(
"Failed to parse sql_generator_model attribute macro: {}",
err
));
} }
}; };
} }
@ -125,9 +144,13 @@ fn parse_field_attribute(field: &syn::Field) -> Result<Option<SqlGeneratorFieldA
match SqlGeneratorFieldAttr::from_attribute(attr) { match SqlGeneratorFieldAttr::from_attribute(attr) {
Ok(v) => { Ok(v) => {
return Ok(Some(v)); return Ok(Some(v));
}, }
Err(err) => { Err(err) => {
return Err(anyhow!("Failed to parse sql_generator_field attribute macro on field {:?}, {}", field, err)); return Err(anyhow!(
"Failed to parse sql_generator_field attribute macro on field {:?}, {}",
field,
err
));
} }
}; };
} }
@ -136,10 +159,7 @@ fn parse_field_attribute(field: &syn::Field) -> Result<Option<SqlGeneratorFieldA
/// Take struct name as source, apply snake case and pluralize with a s /// Take struct name as source, apply snake case and pluralize with a s
fn generate_table_name_from_struct_name(struct_name: &str) -> String { fn generate_table_name_from_struct_name(struct_name: &str) -> String {
format!( format!("{}s", struct_name.to_case(Case::Snake))
"{}s",
struct_name.to_case(Case::Snake)
)
} }
/// Scan for models struct in a rust file and return a struct representing the model /// Scan for models struct in a rust file and return a struct representing the model
@ -174,7 +194,7 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
is_primary: false, is_primary: false,
is_unique: false, is_unique: false,
is_query_entrypoint: false, is_query_entrypoint: false,
foreign_mode: FieldForeignMode::NotRef foreign_mode: FieldForeignMode::NotRef,
}; };
let first_type: String = match get_type_first_ident(&field_type) { let first_type: String = match get_type_first_ident(&field_type) {
@ -187,8 +207,12 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
if first_type == "Option" { if first_type == "Option" {
output_field.is_nullable = true; output_field.is_nullable = true;
let inner_type = match extract_generic_type( let inner_type = match extract_generic_type(
vec!["Option".into(), "std:option:Option".into(), "core:option:Option".into()], vec![
&field_type "Option".into(),
"std:option:Option".into(),
"core:option:Option".into(),
],
&field_type,
) { ) {
Some(v) => v, Some(v) => v,
None => { None => {
@ -198,15 +222,15 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
final_type = match get_type_first_ident(inner_type) { final_type = match get_type_first_ident(inner_type) {
Some(v) => v, Some(v) => v,
None => { None => {
return Err(anyhow!("Could not extract ident from Option inner type")); return Err(anyhow!(
"Could not extract ident from Option inner type"
));
} }
} }
} }
if first_type == "Vec" { if first_type == "Vec" {
let inner_type = match extract_generic_type( let inner_type = match extract_generic_type(vec!["Vec".into()], &field_type)
vec!["Vec".into()], {
&field_type
) {
Some(v) => v, Some(v) => v,
None => { None => {
return Err(anyhow!("Could not extract type from Vec")); return Err(anyhow!("Could not extract type from Vec"));
@ -221,13 +245,14 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
} }
output_field.rust_type = final_type; output_field.rust_type = final_type;
let field_attrs_opt = parse_field_attribute(field)?; let field_attrs_opt = parse_field_attribute(field)?;
if first_type == "ForeignRef" { if first_type == "ForeignRef" {
let attrs = match &field_attrs_opt { let attrs = match &field_attrs_opt {
Some(attrs) => attrs, Some(attrs) => attrs,
None => { None => {
return Err(anyhow!("Found a ForeignRef type but did not found attributes.")) return Err(anyhow!(
"Found a ForeignRef type but did not found attributes."
))
} }
}; };
let rrn = match &attrs.reverse_relation_name { let rrn = match &attrs.reverse_relation_name {
@ -237,39 +262,48 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
} }
}; };
let extract_res = extract_generic_type(vec!["ForeignRef".into()], &field_type) let extract_res =
.and_then(|t| get_type_first_ident(t)); extract_generic_type(vec!["ForeignRef".into()], &field_type)
.and_then(|t| get_type_first_ident(t));
let target_type_name = match extract_res { let target_type_name = match extract_res {
Some(v) => v, Some(v) => v,
None => { None => {
return Err(anyhow!("Could not extract inner type from ForeignRef.")); return Err(anyhow!(
"Could not extract inner type from ForeignRef."
));
} }
}; };
output_field.foreign_mode = FieldForeignMode::ForeignRef( output_field.foreign_mode =
ForeignRefParams { FieldForeignMode::ForeignRef(ForeignRefParams {
reverse_relation_name: rrn, reverse_relation_name: rrn,
target_resource_name: target_type_name.to_case(Case::Snake) target_resource_name: target_type_name.to_case(Case::Snake),
} });
);
} }
// parse attribute // parse attribute
if let Some(field_attr) = field_attrs_opt { if let Some(field_attr) = field_attrs_opt {
output_field.is_primary = field_attr.is_primary.unwrap_or_default(); output_field.is_primary = field_attr.is_primary.unwrap_or_default();
output_field.is_unique = field_attr.is_unique.unwrap_or_default(); output_field.is_unique = field_attr.is_unique.unwrap_or_default();
output_field.is_query_entrypoint = field_attr.is_query_entrypoint.unwrap_or_default(); output_field.is_query_entrypoint =
field_attr.is_query_entrypoint.unwrap_or_default();
} }
fields.push(output_field); fields.push(output_field);
} }
             models.push(Model {
-                module_path: vec![source_code_path.file_stem().unwrap().to_str().unwrap().to_string()],
+                module_path: vec![source_code_path
+                    .file_stem()
+                    .unwrap()
+                    .to_str()
+                    .unwrap()
+                    .to_string()],
                 name: model_name.clone(),
-                table_name: model_attrs.table_name
+                table_name: model_attrs
+                    .table_name
                     .unwrap_or(generate_table_name_from_struct_name(&model_name)),
-                fields
+                fields,
             })
-        },
+        }
_ => {} _ => {}
} }
} }

View file

@ -12,23 +12,20 @@ use sqlx_core::error::BoxDynError;
use sqlx_core::types::Type; use sqlx_core::types::Type;
use sqlx_sqlite::{Sqlite, SqliteArgumentValue}; use sqlx_sqlite::{Sqlite, SqliteArgumentValue};
#[fully_pub] #[fully_pub]
trait DatabaseLine { trait DatabaseLine {
fn id(&self) -> String; fn id(&self) -> String;
} }
/// Wrapper to mark a model field as foreign /// Wrapper to mark a model field as foreign
/// You can use a generic argument inside ForeignRef to point to the target model /// You can use a generic argument inside ForeignRef to point to the target model
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
#[fully_pub] #[fully_pub]
struct ForeignRef<T: Sized + DatabaseLine> { struct ForeignRef<T: Sized + DatabaseLine> {
pub target_type: PhantomData<T>, pub target_type: PhantomData<T>,
pub target_id: String pub target_id: String,
} }
// Implement serde Serialize for ForeignRef // Implement serde Serialize for ForeignRef
impl<T: Sized + DatabaseLine> Serialize for ForeignRef<T> { impl<T: Sized + DatabaseLine> Serialize for ForeignRef<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@ -40,22 +37,20 @@ impl<T: Sized + DatabaseLine> Serialize for ForeignRef<T> {
} }
} }
impl<T: Sized + DatabaseLine> ForeignRef<T> { impl<T: Sized + DatabaseLine> ForeignRef<T> {
pub fn new(entity: &T) -> ForeignRef<T> { pub fn new(entity: &T) -> ForeignRef<T> {
ForeignRef { ForeignRef {
target_type: PhantomData, target_type: PhantomData,
target_id: entity.id() target_id: entity.id(),
} }
} }
} }
impl<'r, DB: Database, T: Sized + DatabaseLine> Decode<'r, DB> for ForeignRef<T> impl<'r, DB: Database, T: Sized + DatabaseLine> Decode<'r, DB> for ForeignRef<T>
where where
// we want to delegate some of the work to string decoding so let's make sure strings // we want to delegate some of the work to string decoding so let's make sure strings
// are supported by the database // are supported by the database
&'r str: Decode<'r, DB> &'r str: Decode<'r, DB>,
{ {
fn decode( fn decode(
value: <DB as Database>::ValueRef<'r>, value: <DB as Database>::ValueRef<'r>,
@ -66,7 +61,7 @@ where
Ok(ForeignRef::<T> { Ok(ForeignRef::<T> {
target_type: PhantomData, target_type: PhantomData,
target_id: ref_val target_id: ref_val,
}) })
} }
} }
@ -84,9 +79,11 @@ impl<T: DatabaseLine + Sized> Type<Sqlite> for ForeignRef<T> {
} }
impl<T: DatabaseLine + Sized> Encode<'_, Sqlite> for ForeignRef<T> { impl<T: DatabaseLine + Sized> Encode<'_, Sqlite> for ForeignRef<T> {
-    fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'_>>) -> Result<IsNull, BoxDynError> {
+    fn encode_by_ref(
+        &self,
+        args: &mut Vec<SqliteArgumentValue<'_>>,
+    ) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Text(self.target_id.clone().into())); args.push(SqliteArgumentValue::Text(self.target_id.clone().into()));
Ok(IsNull::No) Ok(IsNull::No)
} }
} }