Compare commits

..

3 commits

Author SHA1 Message Date
d3aae47d2c docs: still WIP 2026-01-13 21:41:03 +01:00
8f7d11226f refactor: apply clippy rules 2026-01-13 21:20:50 +01:00
d205d722aa style: apply cargo fmt 2026-01-13 20:47:19 +01:00
25 changed files with 521 additions and 510 deletions

View file

@ -1,8 +1,12 @@
# [WIP] sqlxgentools
Tools to generate SQL migrations and Rust SQLx repositories code from models structs to use with a SQLite database.
Little tool to generate SQLite migrations files and Rust SQLx repositories code, all from models structs.
Will be used in [minauthator](https://forge.lefuturiste.fr/mbess/minauthator).
Still very much a work in progress, but it can already be used in your next Rust app if you don't mind some limitations, such as the lack of incremental migrations and little quirks here and there.
## Getting started
- [Quick start tutorial](./docs/tutorials/quick_start.md)
## Project context
@ -18,10 +22,19 @@ Will be used in [minauthator](https://forge.lefuturiste.fr/mbess/minauthator).
- Provide a full ORM interface
## Included crates
This project is split into 3 published crates.
- [`sqlxgentools_cli`](https://crates.io/crates/sqlxgentools_cli), used to parse, generate migrations and repositories.
- [`sqlxgentools_attrs`](https://crates.io/crates/sqlxgentools_attrs), provides proc macros.
- [`sqlxgentools_misc`](https://crates.io/crates/sqlxgentools_misc), provides data types and traits (optional).
## Features
- [x] generate migrations
- [x] from scratch
- [ ] incremental migration
- [ ] up migration
- [ ] down migration
- [x] generate repositories
@ -29,27 +42,10 @@ Will be used in [minauthator](https://forge.lefuturiste.fr/mbess/minauthator).
- [x] get_by_id
- [x] insert
- [x] insert_many
- [ ] generate custom by
- [x] co-exist with custom repository
- [x] custom get_by, get_many_by
- [x] get_many_of (from one-to-many relations)
## Usage
## Contributions
### Generate initial CREATE TABLE sqlite migration
Questions, remarks and contributions are very much welcome.
cargo run --bin sqlx-generator -- ./path/to/project generate-create-migrations > migrations/all.sql
sqlx-generator \
-m path/to/models \
gen-repositories \
-o path/to/repositories
sqlx-generator \
-m path/to/models \
gen-migrations \
-o path/to/migrations/all.sql
### Generate repositories code
not implemented yet
cargo run --bin sqlx-generator -- ./path/to/project generate-repositories

View file

View file

0
docs/references/cli.md Normal file
View file

View file

View file

@ -6,3 +6,34 @@ Steps:
- Generate migrations
- Generate repositories
- Use repositories in your code
### CLI installation
The [sqlxgentools_cli crate](https://crates.io/crates/sqlxgentools_cli) provides the CLI,
which can be installed globally on your machine (or at least for your user).
cargo install sqlxgentools_cli
### Project installation
Install the `sqlxgentools_attrs` crate
### Declare your models
TODO
### Generate migrations
Change directory into your project root.
sqlx-generator -m path/to/models_module gen-migrations -o path/to/migrations/all.sql
### Generate repositories
Change directory into your project root.
sqlx-generator -m path/to/models_module gen-repositories -o path/to/repositories_module
### Use the repositories
TODO

View file

@ -0,0 +1 @@
TODO

View file

@ -1,11 +1,11 @@
use anyhow::Context;
use std::str::FromStr;
use std::path::PathBuf;
use anyhow::Result;
use std::str::FromStr;
use fully_pub::fully_pub;
use sqlx::{
Pool, Sqlite, sqlite::{SqliteConnectOptions, SqlitePoolOptions},
sqlite::{SqliteConnectOptions, SqlitePoolOptions},
Pool, Sqlite,
};
/// database storage interface
@ -13,17 +13,8 @@ use sqlx::{
#[derive(Clone, Debug)]
struct Database(Pool<Sqlite>);
/// Initialize database
pub async fn provide_database(sqlite_db_path: &str) -> Result<Database> {
let path = PathBuf::from(sqlite_db_path);
let is_db_initialization = !path.exists();
// // database does not exists, trying to create it
// if path
// .parent()
// .filter(|pp| pp.exists())
// Err(anyhow!("Could not find parent directory of the db location.")));
let conn_str = format!("sqlite://{sqlite_db_path}");
let pool = SqlitePoolOptions::new()
@ -31,11 +22,6 @@ pub async fn provide_database(sqlite_db_path: &str) -> Result<Database> {
.connect_with(SqliteConnectOptions::from_str(&conn_str)?.create_if_missing(true))
.await
.context("could not connect to database_url")?;
// if is_db_initialization {
// initialize_db(Database(pool.clone())).await?;
// }
Ok(Database(pool))
}

View file

@ -1,3 +1,3 @@
pub mod repositories;
pub mod db;
pub mod models;
pub mod repositories;

View file

@ -1,21 +1,24 @@
use anyhow::{Context, Result};
use anyhow::Result;
use chrono::Utc;
use sqlx::types::Json;
use sqlxgentools_misc::ForeignRef;
use crate::{db::provide_database, models::user::{User, UserToken}, repositories::user_token_repository::UserTokenRepository};
use crate::{
db::provide_database,
models::user::{User, UserToken},
repositories::user_token_repository::UserTokenRepository,
};
pub mod db;
pub mod models;
pub mod repositories;
pub mod db;
#[tokio::main]
async fn main() -> Result<()> {
println!("Sandbox");
let users = vec![
let users = [
User {
id: "idu1".into(),
handle: "john.doe".into(),
@ -24,7 +27,7 @@ async fn main() -> Result<()> {
last_login_at: None,
status: models::user::UserStatus::Invited,
groups: Json(vec![]),
avatar_bytes: None
avatar_bytes: None,
},
User {
id: "idu2".into(),
@ -34,7 +37,7 @@ async fn main() -> Result<()> {
last_login_at: None,
status: models::user::UserStatus::Invited,
groups: Json(vec![]),
avatar_bytes: None
avatar_bytes: None,
},
User {
id: "idu3".into(),
@ -44,52 +47,53 @@ async fn main() -> Result<()> {
last_login_at: None,
status: models::user::UserStatus::Invited,
groups: Json(vec![]),
avatar_bytes: None
}
avatar_bytes: None,
},
];
let user_token = UserToken {
let _user_token = UserToken {
id: "idtoken1".into(),
secret: "4LP5A3F3XBV5NM8VXRGZG3QDXO9PNAC0".into(),
last_use_time: None,
creation_time: Utc::now(),
expiration_time: Utc::now(),
user_id: ForeignRef::new(&users.get(0).unwrap())
user_id: ForeignRef::new(users.first().unwrap()),
};
let db = provide_database("tmp/db.db").await?;
let user_token_repo = UserTokenRepository::new(db);
user_token_repo.insert_many(&vec![
UserToken {
id: "idtoken2".into(),
secret: "4LP5A3F3XBV5NM8VXRGZG3QDXO9PNAC0".into(),
last_use_time: None,
creation_time: Utc::now(),
expiration_time: Utc::now(),
user_id: ForeignRef::new(&users.get(0).unwrap())
},
UserToken {
id: "idtoken3".into(),
secret: "CBHR6G41KSEMR1AI".into(),
last_use_time: None,
creation_time: Utc::now(),
expiration_time: Utc::now(),
user_id: ForeignRef::new(&users.get(1).unwrap())
},
UserToken {
id: "idtoken4".into(),
secret: "CBHR6G41KSEMR1AI".into(),
last_use_time: None,
creation_time: Utc::now(),
expiration_time: Utc::now(),
user_id: ForeignRef::new(&users.get(1).unwrap())
}
]).await?;
let user_tokens = user_token_repo.get_many_user_tokens_by_usersss(
vec!["idu2".into()]
).await?;
user_token_repo
.insert_many(&vec![
UserToken {
id: "idtoken2".into(),
secret: "4LP5A3F3XBV5NM8VXRGZG3QDXO9PNAC0".into(),
last_use_time: None,
creation_time: Utc::now(),
expiration_time: Utc::now(),
user_id: ForeignRef::new(users.first().unwrap()),
},
UserToken {
id: "idtoken3".into(),
secret: "CBHR6G41KSEMR1AI".into(),
last_use_time: None,
creation_time: Utc::now(),
expiration_time: Utc::now(),
user_id: ForeignRef::new(users.get(1).unwrap()),
},
UserToken {
id: "idtoken4".into(),
secret: "CBHR6G41KSEMR1AI".into(),
last_use_time: None,
creation_time: Utc::now(),
expiration_time: Utc::now(),
user_id: ForeignRef::new(users.get(1).unwrap()),
},
])
.await?;
let user_tokens = user_token_repo
.get_many_user_tokens_by_usersss(vec!["idu2".into()])
.await?;
dbg!(&user_tokens);
Ok(())
}

View file

@ -1,8 +1,8 @@
use chrono::{DateTime, Utc};
use sqlx::types::Json;
use fully_pub::fully_pub;
use sqlx::types::Json;
use sqlxgentools_attrs::{SqlGeneratorDerive, SqlGeneratorModelWithId, sql_generator_model};
use sqlxgentools_attrs::{sql_generator_model, SqlGeneratorDerive, SqlGeneratorModelWithId};
use sqlxgentools_misc::{DatabaseLine, ForeignRef};
#[derive(sqlx::Type, Clone, Debug, PartialEq)]
@ -11,37 +11,36 @@ enum UserStatus {
Disabled,
Invited,
Active,
Archived
Archived,
}
#[derive(SqlGeneratorDerive, SqlGeneratorModelWithId, sqlx::FromRow, Debug, Clone)]
#[sql_generator_model(table_name="usersss")]
#[sql_generator_model(table_name = "usersss")]
#[fully_pub]
struct User {
#[sql_generator_field(is_primary=true)]
#[sql_generator_field(is_primary = true)]
id: String,
#[sql_generator_field(is_unique=true)]
#[sql_generator_field(is_unique = true)]
handle: String,
full_name: Option<String>,
prefered_color: Option<i64>,
last_login_at: Option<DateTime<Utc>>,
status: UserStatus,
groups: Json<Vec<String>>,
avatar_bytes: Option<Vec<u8>>
avatar_bytes: Option<Vec<u8>>,
}
#[derive(SqlGeneratorDerive, SqlGeneratorModelWithId, sqlx::FromRow, Debug, Clone)]
#[sql_generator_model(table_name="user_tokens")]
#[sql_generator_model(table_name = "user_tokens")]
#[fully_pub]
struct UserToken {
#[sql_generator_field(is_primary=true)]
#[sql_generator_field(is_primary = true)]
id: String,
secret: String,
last_use_time: Option<DateTime<Utc>>,
creation_time: DateTime<Utc>,
expiration_time: DateTime<Utc>,
#[sql_generator_field(reverse_relation_name="user_tokens")] // to generate get_user_tokens_of_user(&user_id)
user_id: ForeignRef<User>
#[sql_generator_field(reverse_relation_name = "user_tokens")]
// to generate get_user_tokens_of_user(&user_id)
user_id: ForeignRef<User>,
}

View file

@ -1,5 +1,5 @@
use crate::models::user::User;
use crate::db::Database;
use crate::models::user::User;
pub struct UserRepository {
db: Database,
}
@ -8,7 +8,9 @@ impl UserRepository {
UserRepository { db }
}
pub async fn get_all(&self) -> Result<Vec<User>, sqlx::Error> {
sqlx::query_as::<_, User>("SELECT * FROM usersss").fetch_all(&self.db.0).await
sqlx::query_as::<_, User>("SELECT * FROM usersss")
.fetch_all(&self.db.0)
.await
}
pub async fn get_by_id(&self, item_id: &str) -> Result<User, sqlx::Error> {
sqlx::query_as::<_, User>("SELECT * FROM usersss WHERE id = $1")
@ -16,10 +18,7 @@ impl UserRepository {
.fetch_one(&self.db.0)
.await
}
pub async fn get_many_by_id(
&self,
items_ids: &[&str],
) -> Result<Vec<User>, sqlx::Error> {
pub async fn get_many_by_id(&self, items_ids: &[&str]) -> Result<Vec<User>, sqlx::Error> {
if items_ids.is_empty() {
return Ok(vec![]);
}
@ -27,9 +26,7 @@ impl UserRepository {
.map(|i| format!("${}", i))
.collect::<Vec<String>>()
.join(",");
let query_sql = format!(
"SELECT * FROM usersss WHERE id IN ({})", placeholder_params
);
let query_sql = format!("SELECT * FROM usersss WHERE id IN ({})", placeholder_params);
let mut query = sqlx::query_as::<_, User>(&query_sql);
for id in items_ids {
query = query.bind(id);
@ -43,8 +40,8 @@ impl UserRepository {
.bind(&entity.id)
.bind(&entity.handle)
.bind(&entity.full_name)
.bind(&entity.prefered_color)
.bind(&entity.last_login_at)
.bind(entity.prefered_color)
.bind(entity.last_login_at)
.bind(&entity.status)
.bind(&entity.groups)
.bind(&entity.avatar_bytes)
@ -59,8 +56,11 @@ impl UserRepository {
.map(|c| c.to_vec())
.map(|x| {
format!(
"({})", x.iter().map(| i | format!("${}", i)).collect:: < Vec <
String >> ().join(", ")
"({})",
x.iter()
.map(|i| format!("${}", i))
.collect::<Vec<String>>()
.join(", ")
)
})
.collect::<Vec<String>>()
@ -75,8 +75,8 @@ impl UserRepository {
.bind(&entity.id)
.bind(&entity.handle)
.bind(&entity.full_name)
.bind(&entity.prefered_color)
.bind(&entity.last_login_at)
.bind(entity.prefered_color)
.bind(entity.last_login_at)
.bind(&entity.status)
.bind(&entity.groups)
.bind(&entity.avatar_bytes);
@ -84,11 +84,7 @@ impl UserRepository {
query.execute(&self.db.0).await?;
Ok(())
}
pub async fn update_by_id(
&self,
item_id: &str,
entity: &User,
) -> Result<(), sqlx::Error> {
pub async fn update_by_id(&self, item_id: &str, entity: &User) -> Result<(), sqlx::Error> {
sqlx::query(
"UPDATE usersss SET id = $2, handle = $3, full_name = $4, prefered_color = $5, last_login_at = $6, status = $7, groups = $8, avatar_bytes = $9 WHERE id = $1",
)
@ -96,8 +92,8 @@ impl UserRepository {
.bind(&entity.id)
.bind(&entity.handle)
.bind(&entity.full_name)
.bind(&entity.prefered_color)
.bind(&entity.last_login_at)
.bind(entity.prefered_color)
.bind(entity.last_login_at)
.bind(&entity.status)
.bind(&entity.groups)
.bind(&entity.avatar_bytes)

View file

@ -1,5 +1,5 @@
use crate::models::user::UserToken;
use crate::db::Database;
use crate::models::user::UserToken;
pub struct UserTokenRepository {
db: Database,
}
@ -18,10 +18,7 @@ impl UserTokenRepository {
.fetch_one(&self.db.0)
.await
}
pub async fn get_many_by_id(
&self,
items_ids: &[&str],
) -> Result<Vec<UserToken>, sqlx::Error> {
pub async fn get_many_by_id(&self, items_ids: &[&str]) -> Result<Vec<UserToken>, sqlx::Error> {
if items_ids.is_empty() {
return Ok(vec![]);
}
@ -30,7 +27,8 @@ impl UserTokenRepository {
.collect::<Vec<String>>()
.join(",");
let query_sql = format!(
"SELECT * FROM user_tokens WHERE id IN ({})", placeholder_params
"SELECT * FROM user_tokens WHERE id IN ({})",
placeholder_params
);
let mut query = sqlx::query_as::<_, UserToken>(&query_sql);
for id in items_ids {
@ -44,26 +42,26 @@ impl UserTokenRepository {
)
.bind(&entity.id)
.bind(&entity.secret)
.bind(&entity.last_use_time)
.bind(&entity.creation_time)
.bind(&entity.expiration_time)
.bind(entity.last_use_time)
.bind(entity.creation_time)
.bind(entity.expiration_time)
.bind(&entity.user_id.target_id)
.execute(&self.db.0)
.await?;
Ok(())
}
pub async fn insert_many(
&self,
entities: &Vec<UserToken>,
) -> Result<(), sqlx::Error> {
pub async fn insert_many(&self, entities: &Vec<UserToken>) -> Result<(), sqlx::Error> {
let values_templates: String = (1..(6usize * entities.len() + 1))
.collect::<Vec<usize>>()
.chunks(6usize)
.map(|c| c.to_vec())
.map(|x| {
format!(
"({})", x.iter().map(| i | format!("${}", i)).collect:: < Vec <
String >> ().join(", ")
"({})",
x.iter()
.map(|i| format!("${}", i))
.collect::<Vec<String>>()
.join(", ")
)
})
.collect::<Vec<String>>()
@ -77,28 +75,24 @@ impl UserTokenRepository {
query = query
.bind(&entity.id)
.bind(&entity.secret)
.bind(&entity.last_use_time)
.bind(&entity.creation_time)
.bind(&entity.expiration_time)
.bind(entity.last_use_time)
.bind(entity.creation_time)
.bind(entity.expiration_time)
.bind(&entity.user_id.target_id);
}
query.execute(&self.db.0).await?;
Ok(())
}
pub async fn update_by_id(
&self,
item_id: &str,
entity: &UserToken,
) -> Result<(), sqlx::Error> {
pub async fn update_by_id(&self, item_id: &str, entity: &UserToken) -> Result<(), sqlx::Error> {
sqlx::query(
"UPDATE user_tokens SET id = $2, secret = $3, last_use_time = $4, creation_time = $5, expiration_time = $6, user_id = $7 WHERE id = $1",
)
.bind(item_id)
.bind(&entity.id)
.bind(&entity.secret)
.bind(&entity.last_use_time)
.bind(&entity.creation_time)
.bind(&entity.expiration_time)
.bind(entity.last_use_time)
.bind(entity.creation_time)
.bind(entity.expiration_time)
.bind(&entity.user_id.target_id)
.execute(&self.db.0)
.await?;
@ -132,7 +126,8 @@ impl UserTokenRepository {
.collect::<Vec<String>>()
.join(",");
let query_tmpl = format!(
"SELECT * FROM user_tokens WHERE user_id IN ({})", placeholder_params
"SELECT * FROM user_tokens WHERE user_id IN ({})",
placeholder_params
);
let mut query = sqlx::query_as::<_, UserToken>(&query_tmpl);
for id in items_ids {

View file

@ -5,7 +5,10 @@
use std::assert_matches::assert_matches;
use chrono::Utc;
use sandbox::{models::user::{User, UserStatus}, repositories::user_repository::UserRepository};
use sandbox::{
models::user::{User, UserStatus},
repositories::user_repository::UserRepository,
};
use sqlx::{types::Json, Pool, Sqlite};
#[sqlx::test(fixtures("../src/migrations/all.sql"))]
@ -22,43 +25,40 @@ async fn test_user_repository_create_read_update_delete(pool: Pool<Sqlite>) -> s
last_login_at: Some(Utc::now()),
status: UserStatus::Invited,
groups: Json(vec!["artists".into()]),
avatar_bytes: vec![0x00]
avatar_bytes: vec![0x00],
};
assert_matches!(user_repo.insert(&new_user).await, Ok(()));
assert_matches!(
user_repo.insert(&new_user).await,
Ok(())
);
assert_matches!(
user_repo.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()).await,
user_repo
.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into())
.await,
Ok(User { .. })
);
assert_matches!(
user_repo.get_by_id("ffffffff-0000-4040-0000-000000000000".into()).await,
user_repo
.get_by_id("ffffffff-0000-4040-0000-000000000000".into())
.await,
Err(sqlx::Error::RowNotFound)
);
// Insert Many
let bunch_of_users: Vec<User> = (0..10).map(|pid| User {
id: format!("ffffffff-0000-4000-0010-{:0>8}", pid),
handle: format!("user num {}", pid),
full_name: None,
prefered_color: None,
last_login_at: None,
status: UserStatus::Invited,
groups: Json(vec![]),
avatar_bytes: vec![]
}).collect();
assert_matches!(
user_repo.insert_many(&bunch_of_users).await,
Ok(())
);
let bunch_of_users: Vec<User> = (0..10)
.map(|pid| User {
id: format!("ffffffff-0000-4000-0010-{:0>8}", pid),
handle: format!("user num {}", pid),
full_name: None,
prefered_color: None,
last_login_at: None,
status: UserStatus::Invited,
groups: Json(vec![]),
avatar_bytes: vec![],
})
.collect();
assert_matches!(user_repo.insert_many(&bunch_of_users).await, Ok(()));
// Read many all
let read_all_res = user_repo.get_all().await;
assert_matches!(
read_all_res,
Ok(..)
);
assert_matches!(read_all_res, Ok(..));
let all_users = read_all_res.unwrap();
assert_eq!(all_users.len(), 11);
@ -69,16 +69,18 @@ async fn test_user_repository_create_read_update_delete(pool: Pool<Sqlite>) -> s
user_repo.update_by_id(&new_user.id, &updated_user).await,
Ok(())
);
let user_from_db = user_repo.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()).await.unwrap();
let user_from_db = user_repo
.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into())
.await
.unwrap();
assert_eq!(user_from_db.status, UserStatus::Disabled);
// Delete
assert_matches!(user_repo.delete_by_id(&new_user.id).await, Ok(()));
assert_matches!(
user_repo.delete_by_id(&new_user.id).await,
Ok(())
);
assert_matches!(
user_repo.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()).await,
user_repo
.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into())
.await,
Err(sqlx::Error::RowNotFound)
);

View file

@ -1,6 +1,6 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{DeriveInput, Fields, parse_macro_input};
use syn::{parse_macro_input, DeriveInput, Fields};
#[proc_macro_attribute]
pub fn sql_generator_model(_attr: TokenStream, item: TokenStream) -> TokenStream {
@ -21,7 +21,7 @@ pub fn derive_sql_generator_model_with_id(input: TokenStream) -> TokenStream {
if let syn::Data::Struct(data) = input.data {
if let Fields::Named(fields) = data.fields {
for field in fields.named {
if field.ident.as_ref().map_or(false, |ident| ident == "id") {
if field.ident.as_ref().is_some_and(|ident| ident == "id") {
let expanded = quote! {
impl DatabaseLine for #name {
fn id(&self) -> String {
@ -38,4 +38,3 @@ pub fn derive_sql_generator_model_with_id(input: TokenStream) -> TokenStream {
// If `id` field is not found, return an error
panic!("Expected struct with a named field `id` of type String")
}

View file

@ -1,8 +1,7 @@
use anyhow::{Result, anyhow};
use anyhow::{anyhow, Result};
use crate::models::{Field, Model};
// Implementations
impl Field {
/// return sqlite type
@ -21,7 +20,7 @@ impl Field {
"DateTime" => Some("DATETIME".into()),
"Json" => Some("TEXT".into()),
"Vec<u8>" => Some("BLOB".into()),
_ => Some("TEXT".into())
_ => Some("TEXT".into()),
}
}
}
@ -35,8 +34,10 @@ pub fn generate_create_table_sql(models: &[Model]) -> Result<String> {
let mut fields_sql: Vec<String> = vec![];
for field in model.fields.iter() {
let mut additions: String = "".into();
let sql_type = field.sql_type()
.ok_or(anyhow!(format!("Could not find SQL type for field {}", field.name)))?;
let sql_type = field.sql_type().ok_or(anyhow!(format!(
"Could not find SQL type for field {}",
field.name
)))?;
if !field.is_nullable {
additions.push_str(" NOT NULL");
}
@ -46,20 +47,15 @@ pub fn generate_create_table_sql(models: &[Model]) -> Result<String> {
if field.is_primary {
additions.push_str(" PRIMARY KEY");
}
fields_sql.push(
format!("\t{: <#18}\t{}{}", field.name, sql_type, additions)
);
fields_sql.push(format!("\t{: <#18}\t{}{}", field.name, sql_type, additions));
}
sql_code.push_str(
&format!(
"CREATE TABLE {} (\n{}\n);\n",
model.table_name,
fields_sql.join(",\n")
)
);
sql_code.push_str(&format!(
"CREATE TABLE {} (\n{}\n);\n",
model.table_name,
fields_sql.join(",\n")
));
}
Ok(sql_code)
}

View file

@ -8,13 +8,12 @@ pub mod repositories;
#[fully_pub]
enum SourceNode {
File(String),
Directory(Vec<SourceNodeContainer>)
Directory(Vec<SourceNodeContainer>),
}
#[derive(Serialize, Debug)]
#[fully_pub]
struct SourceNodeContainer {
name: String,
inner: SourceNode
inner: SourceNode,
}

View file

@ -1,12 +1,14 @@
use anyhow::Result;
use proc_macro2::{TokenStream, Ident};
use heck::ToSnakeCase;
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote};
use syn::File;
use heck::ToSnakeCase;
use crate::{generators::repositories::relations::gen_get_many_of_related_entity_method, models::{Field, FieldForeignMode, Model}};
use crate::generators::{SourceNode, SourceNodeContainer};
use crate::{
generators::repositories::relations::gen_get_many_of_related_entity_method,
models::{Field, FieldForeignMode, Model},
};
fn gen_get_all_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
@ -23,7 +25,10 @@ fn gen_get_all_method(model: &Model) -> TokenStream {
fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, query_field.name);
let select_query = format!(
"SELECT * FROM {} WHERE {} = $1",
model.table_name, query_field.name
);
let func_name_ident = format_ident!("get_by_{}", query_field.name);
@ -40,7 +45,10 @@ fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let select_query_tmpl = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, query_field.name);
let select_query_tmpl = format!(
"SELECT * FROM {} WHERE {} IN ({{}})",
model.table_name, query_field.name
);
let func_name_ident = format_ident!("get_many_by_{}", query_field.name);
@ -66,21 +74,21 @@ fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStre
}
fn get_mutation_fields(model: &Model) -> (Vec<&Field>, Vec<&Field>) {
let normal_field_names: Vec<&Field> = model.fields.iter()
.filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false })
let normal_field_names: Vec<&Field> = model
.fields
.iter()
.filter(|f| match f.foreign_mode {
FieldForeignMode::NotRef => true,
FieldForeignMode::ForeignRef(_) => false,
})
.collect();
let foreign_keys_field_names: Vec<&Field> = model.fields.iter()
.filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true })
.collect();
(normal_field_names, foreign_keys_field_names)
}
fn get_mutation_fields_ident(model: &Model) -> (Vec<&Field>, Vec<&Field>) {
let normal_field_names: Vec<&Field> = model.fields.iter()
.filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false })
.collect();
let foreign_keys_field_names: Vec<&Field> = model.fields.iter()
.filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true })
let foreign_keys_field_names: Vec<&Field> = model
.fields
.iter()
.filter(|f| match f.foreign_mode {
FieldForeignMode::NotRef => false,
FieldForeignMode::ForeignRef(_) => true,
})
.collect();
(normal_field_names, foreign_keys_field_names)
}
@ -88,26 +96,31 @@ fn get_mutation_fields_ident(model: &Model) -> (Vec<&Field>, Vec<&Field>) {
fn gen_insert_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let value_templates = (1..(model.fields.len()+1))
let value_templates = (1..(model.fields.len() + 1))
.map(|i| format!("${}", i))
.collect::<Vec<String>>()
.join(", ");
let (normal_fields, foreign_keys_fields) = get_mutation_fields(model);
let (normal_field_idents, foreign_keys_field_idents) = (
normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>(),
foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>()
normal_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
foreign_keys_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
);
let sql_columns = [normal_fields, foreign_keys_fields].concat()
let sql_columns = [normal_fields, foreign_keys_fields]
.concat()
.iter()
.map(|f| f.name.clone())
.collect::<Vec<String>>()
.join(", ");
let insert_query = format!(
"INSERT INTO {} ({}) VALUES ({})",
model.table_name,
sql_columns,
value_templates
model.table_name, sql_columns, value_templates
);
// foreign keys must be inserted first, we sort the columns so that foreign keys are first
@ -126,19 +139,26 @@ fn gen_insert_method(model: &Model) -> TokenStream {
fn gen_insert_many_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let sql_columns = model.fields.iter()
let sql_columns = model
.fields
.iter()
.map(|f| f.name.clone())
.collect::<Vec<String>>()
.join(", ");
let base_insert_query = format!(
"INSERT INTO {} ({}) VALUES {{}} ON CONFLICT DO NOTHING",
model.table_name,
sql_columns
model.table_name, sql_columns
);
let (normal_fields, foreign_keys_fields) = get_mutation_fields(model);
let (normal_field_idents, foreign_keys_field_idents) = (
normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>(),
foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>()
normal_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
foreign_keys_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
);
let fields_count = model.fields.len();
@ -174,32 +194,39 @@ fn gen_insert_many_method(model: &Model) -> TokenStream {
}
}
fn gen_update_by_id_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let primary_key = &model.fields.iter()
let primary_key = &model
.fields
.iter()
.find(|f| f.is_primary)
.expect("A model must have at least one primary key")
.name;
let (normal_fields, foreign_keys_fields) = get_mutation_fields(model);
let (normal_field_idents, foreign_keys_field_idents) = (
normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>(),
foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>()
normal_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
foreign_keys_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
);
let sql_columns = [normal_fields, foreign_keys_fields].concat()
let sql_columns = [normal_fields, foreign_keys_fields]
.concat()
.iter()
.map(|f| f.name.clone())
.collect::<Vec<String>>();
let set_statements = sql_columns.iter()
let set_statements = sql_columns
.iter()
.enumerate()
.map(|(i, column_name)| format!("{} = ${}", column_name, i+2))
.map(|(i, column_name)| format!("{} = ${}", column_name, i + 2))
.collect::<Vec<String>>()
.join(", ");
let update_query = format!(
"UPDATE {} SET {} WHERE {} = $1",
model.table_name,
set_statements,
primary_key
model.table_name, set_statements, primary_key
);
let func_name_ident = format_ident!("update_by_{}", primary_key);
@ -218,7 +245,9 @@ fn gen_update_by_id_method(model: &Model) -> TokenStream {
}
fn gen_delete_by_id_method(model: &Model) -> TokenStream {
let primary_key = &model.fields.iter()
let primary_key = &model
.fields
.iter()
.find(|f| f.is_primary)
.expect("A model must have at least one primary key")
.name;
@ -226,8 +255,7 @@ fn gen_delete_by_id_method(model: &Model) -> TokenStream {
let func_name_ident = format_ident!("delete_by_{}", primary_key);
let query = format!(
"DELETE FROM {} WHERE {} = $1",
model.table_name,
primary_key
model.table_name, primary_key
);
quote! {
@ -243,7 +271,9 @@ fn gen_delete_by_id_method(model: &Model) -> TokenStream {
}
fn gen_delete_many_by_id_method(model: &Model) -> TokenStream {
let primary_key = &model.fields.iter()
let primary_key = &model
.fields
.iter()
.find(|f| f.is_primary)
.expect("A model must have at least one primary key")
.name;
@ -251,8 +281,7 @@ fn gen_delete_many_by_id_method(model: &Model) -> TokenStream {
let func_name_ident = format_ident!("delete_many_by_{}", primary_key);
let delete_query_tmpl = format!(
"DELETE FROM {} WHERE {} IN ({{}})",
model.table_name,
primary_key
model.table_name, primary_key
);
quote! {
@ -278,8 +307,10 @@ fn gen_delete_many_by_id_method(model: &Model) -> TokenStream {
}
}
pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<SourceNodeContainer> {
pub fn generate_repository_file(
_all_models: &[Model],
model: &Model,
) -> Result<SourceNodeContainer> {
let resource_name = model.name.clone();
let resource_module_ident = format_ident!("{}", &model.module_path.first().unwrap());
@ -290,15 +321,19 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<S
let get_all_method_code = gen_get_all_method(model);
let get_by_id_method_code = gen_get_by_field_method(
model,
model.fields.iter()
.find(|f| f.is_primary == true)
.expect("Expected at least one primary key on the model.")
model
.fields
.iter()
.find(|f| f.is_primary)
.expect("Expected at least one primary key on the model."),
);
let get_many_by_id_method_code = gen_get_many_by_field_method(
model,
model.fields.iter()
.find(|f| f.is_primary == true)
.expect("Expected at least one primary key on the model.")
model
.fields
.iter()
.find(|f| f.is_primary)
.expect("Expected at least one primary key on the model."),
);
let insert_method_code = gen_insert_method(model);
let insert_many_method_code = gen_insert_many_method(model);
@ -306,40 +341,37 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<S
let delete_by_id_method_code = gen_delete_by_id_method(model);
let delete_many_by_id_method_code = gen_delete_many_by_id_method(model);
let query_by_field_methods: Vec<TokenStream> = model
.fields
.iter()
.filter(|f| f.is_query_entrypoint)
.map(|field| gen_get_by_field_method(model, field))
.collect();
let query_many_by_field_methods: Vec<TokenStream> = model
.fields
.iter()
.filter(|f| f.is_query_entrypoint)
.map(|field| gen_get_many_by_field_method(model, field))
.collect();
let query_by_field_methods: Vec<TokenStream> =
model.fields.iter()
.filter(|f| f.is_query_entrypoint)
.map(|field|
gen_get_by_field_method(
model,
&field
)
)
.collect();
let query_many_by_field_methods: Vec<TokenStream> =
model.fields.iter()
.filter(|f| f.is_query_entrypoint)
.map(|field|
gen_get_many_by_field_method(
model,
&field
)
)
.collect();
let fields_with_foreign_refs: Vec<&Field> = model.fields.iter().filter(|f|
match f.foreign_mode { FieldForeignMode::ForeignRef(_) => true, FieldForeignMode::NotRef => false }
).collect();
let related_entity_methods_codes: Vec<TokenStream> = fields_with_foreign_refs.iter().map(|field|
gen_get_many_of_related_entity_method(model, &field)
).collect();
let fields_with_foreign_refs: Vec<&Field> = model
.fields
.iter()
.filter(|f| match f.foreign_mode {
FieldForeignMode::ForeignRef(_) => true,
FieldForeignMode::NotRef => false,
})
.collect();
let related_entity_methods_codes: Vec<TokenStream> = fields_with_foreign_refs
.iter()
.map(|field| gen_get_many_of_related_entity_method(model, field))
.collect();
// TODO: add import line
let base_repository_code: TokenStream = quote! {
use crate::models::#resource_module_ident::#resource_ident;
use crate::db::Database;
pub struct #repository_ident {
db: Database
}
@ -370,7 +402,7 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<S
#(#query_by_field_methods)*
#(#query_many_by_field_methods)*
#(#related_entity_methods_codes)*
}
};
@ -380,6 +412,6 @@ pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<S
Ok(SourceNodeContainer {
name: format!("{}_repository.rs", model.name.to_snake_case()),
inner: SourceNode::File(pretty)
inner: SourceNode::File(pretty),
})
}

View file

@ -13,7 +13,7 @@ pub fn generate_repositories_source_files(models: &[Model]) -> Result<SourceNode
nodes.push(base::generate_repository_file(models, model)?);
// nodes.push(relations::generate_repository_file(model)?);
}
let mut mod_index_code: String = String::new();
for node in &nodes {
let module_name = node.name.replace(".rs", "");
@ -21,11 +21,10 @@ pub fn generate_repositories_source_files(models: &[Model]) -> Result<SourceNode
}
nodes.push(SourceNodeContainer {
name: "mod.rs".into(),
inner: SourceNode::File(mod_index_code.to_string())
inner: SourceNode::File(mod_index_code.to_string()),
});
Ok(SourceNodeContainer {
name: "".into(),
inner: SourceNode::Directory(nodes)
inner: SourceNode::Directory(nodes),
})
}

View file

@ -5,17 +5,23 @@ use crate::models::{Field, FieldForeignMode, Model};
/// method that can be used to retreive a list of entities of type X that are the children of a parent type Y
/// ex: get all comments of a post
pub fn gen_get_many_of_related_entity_method(model: &Model, foreign_key_field: &Field) -> TokenStream {
pub fn gen_get_many_of_related_entity_method(
model: &Model,
foreign_key_field: &Field,
) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name);
let foreign_ref_params = match &foreign_key_field.foreign_mode {
FieldForeignMode::ForeignRef(params) => params,
FieldForeignMode::NotRef => {
panic!("Expected foreign key");
}
}
};
let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, foreign_key_field.name);
let select_query = format!(
"SELECT * FROM {} WHERE {} = $1",
model.table_name, foreign_key_field.name
);
let func_name_ident = format_ident!("get_many_of_{}", foreign_ref_params.target_resource_name);
@ -28,4 +34,3 @@ pub fn gen_get_many_of_related_entity_method(model: &Model, foreign_key_field: &
}
}
}

View file

@ -1,22 +1,19 @@
use std::{ffi::OsStr, path::Path};
use attribute_derive::FromAttr;
use std::{ffi::OsStr, path::Path};
use anyhow::{anyhow, Result};
use argh::FromArgs;
use anyhow::{Result, anyhow};
use crate::generators::{SourceNode, SourceNodeContainer};
// use gen_migrations::generate_create_table_sql;
// use gen_repositories::{generate_repositories_source_files, SourceNodeContainer};
pub mod generators;
pub mod models;
pub mod parse_models;
pub mod generators;
#[derive(FromAttr, PartialEq, Debug, Default)]
#[attribute(ident = sql_generator_model)]
pub struct SqlGeneratorModelAttr {
table_name: Option<String>
table_name: Option<String>,
}
#[derive(FromAttr, PartialEq, Debug, Default)]
@ -25,20 +22,19 @@ pub struct SqlGeneratorFieldAttr {
is_primary: Option<bool>,
is_unique: Option<bool>,
reverse_relation_name: Option<String>,
/// to indicate that this field will be used to obtains entities
/// our framework will generate methods for all fields that is an entrypoint
is_query_entrypoint: Option<bool>
is_query_entrypoint: Option<bool>,
}
#[derive(FromArgs, PartialEq, Debug)]
/// Generate SQL CREATE TABLE migrations
#[argh(subcommand, name = "gen-migrations")]
struct GenerateMigration {
/// path of file where to write all in one generated SQL migration
#[argh(option, short = 'o')]
output: Option<String>
output: Option<String>,
}
#[derive(FromArgs, PartialEq, Debug)]
@ -47,7 +43,7 @@ struct GenerateMigration {
struct GenerateRepositories {
/// path of the directory that contains repositories
#[argh(option, short = 'o')]
output: Option<String>
output: Option<String>,
}
#[derive(FromArgs, PartialEq, Debug)]
@ -67,18 +63,18 @@ struct GeneratorArgs {
/// path of the directory containing models
#[argh(option, short = 'm')]
models_path: Option<String>,
#[argh(subcommand)]
nested: GeneratorArgsSubCommands
nested: GeneratorArgsSubCommands,
}
fn write_source_code(base_path: &Path, snc: SourceNodeContainer) -> Result<()> {
let path = base_path.join(snc.name);
match snc.inner {
SourceNode::File(code) => {
println!("writing file {:?}", path);
println!("Writing file {:?}.", path);
std::fs::write(path, code)?;
},
}
SourceNode::Directory(dir) => {
for node in dir {
write_source_code(&path, node)?;
@ -92,11 +88,14 @@ pub fn main() -> Result<()> {
let args: GeneratorArgs = argh::from_env();
let project_root = &args.project_root.unwrap_or(".".to_string());
let project_root_path = Path::new(&project_root);
eprintln!("Using project root at: {:?}", &project_root_path.canonicalize()?);
eprintln!(
"Using project root at: {:?}",
&project_root_path.canonicalize()?
);
if !project_root_path.exists() {
return Err(anyhow!("Could not resolve project root path."));
}
// check Cargo.toml
let main_manifest_location = "Cargo.toml";
let main_manifest_path = project_root_path.join(main_manifest_location);
@ -117,13 +116,20 @@ pub fn main() -> Result<()> {
if !models_mod_path.exists() {
return Err(anyhow!("Could not resolve models modules."));
}
if models_mod_path.file_name().map(|x| x == OsStr::new("mod.rs")).unwrap_or(false) {
if models_mod_path
.file_name()
.map(|x| x == OsStr::new("mod.rs"))
.unwrap_or(false)
{
models_mod_path.pop();
}
eprintln!("Found models in project, parsing models");
let models = parse_models::parse_models_from_module(&models_mod_path)?;
dbg!(&models);
eprintln!(
"Found and parsed a grand total of {} sqlxgentools compatible models.",
models.len()
);
match args.nested {
GeneratorArgsSubCommands::GenerateRepositories(opts) => {
eprintln!("Generating repositories…");
@ -134,16 +140,15 @@ pub fn main() -> Result<()> {
return Err(anyhow!("Could not resolve repositories modules."));
}
let snc = generators::repositories::generate_repositories_source_files(&models)?;
dbg!(&snc);
write_source_code(&repositories_mod_path, snc)?;
},
}
GeneratorArgsSubCommands::GenerateMigration(opts) => {
eprintln!("Generating migrations…");
let sql_code = generators::migrations::generate_create_table_sql(&models)?;
if let Some(out_location) = opts.output {
let output_path = Path::new(&out_location);
let write_res = std::fs::write(output_path, sql_code);
eprintln!("{:?}", write_res);
let _write_res = std::fs::write(output_path, sql_code);
// TODO: check if write result is an error and return error message.
} else {
println!("{}", sql_code);
}

View file

@ -7,7 +7,7 @@ struct Model {
module_path: Vec<String>,
name: String,
table_name: String,
fields: Vec<Field>
fields: Vec<Field>,
}
impl Model {
@ -29,7 +29,6 @@ impl Model {
// }
}
#[derive(Debug, Clone)]
#[fully_pub]
struct ForeignRefParams {
@ -41,12 +40,11 @@ struct ForeignRefParams {
// target_resource_name_plural: String
}
#[derive(Debug, Clone)]
#[fully_pub]
enum FieldForeignMode {
ForeignRef(ForeignRefParams),
NotRef
NotRef,
}
#[derive(Debug, Clone)]
@ -58,6 +56,5 @@ struct Field {
is_unique: bool,
is_primary: bool,
is_query_entrypoint: bool,
foreign_mode: FieldForeignMode
foreign_mode: FieldForeignMode,
}

View file

@ -1,11 +1,14 @@
use std::{fs, path::Path};
use attribute_derive::FromAttr;
use std::{fs, path::Path};
use anyhow::{Result, anyhow};
use anyhow::{anyhow, Result};
use convert_case::{Case, Casing};
use syn::{GenericArgument, PathArguments, Type};
use syn::{GenericArgument, Type};
use crate::{SqlGeneratorFieldAttr, SqlGeneratorModelAttr, models::{Field, FieldForeignMode, ForeignRefParams, Model}};
use crate::{
models::{Field, FieldForeignMode, ForeignRefParams, Model},
SqlGeneratorFieldAttr, SqlGeneratorModelAttr,
};
fn extract_generic_type(base_segments: Vec<String>, ty: &Type) -> Option<&Type> {
// If it is not `TypePath`, it is not possible to be `Option<T>`, return `None`
@ -52,40 +55,19 @@ fn extract_generic_type(base_segments: Vec<String>, ty: &Type) -> Option<&Type>
fn get_type_first_ident(inp: &Type) -> Option<String> {
match inp {
Type::Path(field_type_path) => {
Some(field_type_path.path.segments.get(0).unwrap().ident.to_string())
},
_ => {
None
}
Type::Path(field_type_path) => Some(
field_type_path
.path
.segments
.get(0)
.unwrap()
.ident
.to_string(),
),
_ => None,
}
}
fn get_first_generic_arg_type_ident(inp: &Type) -> Option<String> {
if let Type::Path(field_type_path) = inp {
if let PathArguments::AngleBracketed(args) = &field_type_path.path.segments.get(0).unwrap().arguments {
if args.args.is_empty() {
None
} else {
if let GenericArgument::Type(arg_type) = args.args.get(0).unwrap() {
if let Type::Path(arg_type_path) = arg_type {
Some(arg_type_path.path.segments.get(0).unwrap().ident.to_string())
} else {
None
}
} else {
None
}
}
} else {
None
}
} else {
None
}
}
fn parse_model_attribute(item: &syn::ItemStruct) -> Result<Option<SqlGeneratorModelAttr>> {
for attr in item.attrs.iter() {
let attr_ident = match attr.path().get_ident() {
@ -101,9 +83,12 @@ fn parse_model_attribute(item: &syn::ItemStruct) -> Result<Option<SqlGeneratorMo
match SqlGeneratorModelAttr::from_attribute(attr) {
Ok(v) => {
return Ok(Some(v));
},
}
Err(err) => {
return Err(anyhow!("Failed to parse sql_generator_model attribute macro: {}", err));
return Err(anyhow!(
"Failed to parse sql_generator_model attribute macro: {}",
err
));
}
};
}
@ -125,9 +110,13 @@ fn parse_field_attribute(field: &syn::Field) -> Result<Option<SqlGeneratorFieldA
match SqlGeneratorFieldAttr::from_attribute(attr) {
Ok(v) => {
return Ok(Some(v));
},
}
Err(err) => {
return Err(anyhow!("Failed to parse sql_generator_field attribute macro on field {:?}, {}", field, err));
return Err(anyhow!(
"Failed to parse sql_generator_field attribute macro on field {:?}, {}",
field,
err
));
}
};
}
@ -136,10 +125,7 @@ fn parse_field_attribute(field: &syn::Field) -> Result<Option<SqlGeneratorFieldA
/// Take struct name as source, apply snake case and pluralize with a s
fn generate_table_name_from_struct_name(struct_name: &str) -> String {
format!(
"{}s",
struct_name.to_case(Case::Snake)
)
format!("{}s", struct_name.to_case(Case::Snake))
}
/// Scan for models struct in a rust file and return a struct representing the model
@ -150,127 +136,130 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
let mut models: Vec<Model> = vec![];
for item in parsed_file.items {
match item {
syn::Item::Struct(itemval) => {
let model_name = itemval.ident.to_string();
let model_attrs = match parse_model_attribute(&itemval)? {
Some(v) => v,
None => {
// we require model struct to have the `sql_generator_model` attribute
continue;
}
if let syn::Item::Struct(itemval) = item {
let model_name = itemval.ident.to_string();
let model_attrs = match parse_model_attribute(&itemval)? {
Some(v) => v,
None => {
// we require model struct to have the `sql_generator_model` attribute
continue;
}
};
let mut fields: Vec<Field> = vec![];
for field in itemval.fields.iter() {
let field_name = field.ident.clone().unwrap().to_string();
let field_type = field.ty.clone();
let mut output_field = Field {
name: field_name,
rust_type: "Unknown".into(),
is_nullable: false,
is_primary: false,
is_unique: false,
is_query_entrypoint: false,
foreign_mode: FieldForeignMode::NotRef,
};
let mut fields: Vec<Field> = vec![];
for field in itemval.fields.iter() {
let field_name = field.ident.clone().unwrap().to_string();
let field_type = field.ty.clone();
println!("field {} {:?}", field_name, field_type);
let mut output_field = Field {
name: field_name,
rust_type: "Unknown".into(),
is_nullable: false,
is_primary: false,
is_unique: false,
is_query_entrypoint: false,
foreign_mode: FieldForeignMode::NotRef
let first_type: String = match get_type_first_ident(&field_type) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract ident from Option inner type"));
}
};
let mut final_type = first_type.clone();
if first_type == "Option" {
output_field.is_nullable = true;
let inner_type = match extract_generic_type(
vec![
"Option".into(),
"std:option:Option".into(),
"core:option:Option".into(),
],
&field_type,
) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract type from Option"));
}
};
let first_type: String = match get_type_first_ident(&field_type) {
final_type = match get_type_first_ident(inner_type) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract ident from Option inner type"));
}
};
let mut final_type = first_type.clone();
if first_type == "Option" {
output_field.is_nullable = true;
let inner_type = match extract_generic_type(
vec!["Option".into(), "std:option:Option".into(), "core:option:Option".into()],
&field_type
) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract type from Option"));
}
};
final_type = match get_type_first_ident(inner_type) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract ident from Option inner type"));
}
}
}
if first_type == "Vec" {
let inner_type = match extract_generic_type(
vec!["Vec".into()],
&field_type
) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract type from Vec"));
}
};
final_type = match get_type_first_ident(inner_type) {
Some(v) => format!("Vec<{}>", v),
None => {
return Err(anyhow!("Could not extract ident from Vec inner type"));
}
}
}
output_field.rust_type = final_type;
let field_attrs_opt = parse_field_attribute(field)?;
if first_type == "ForeignRef" {
let attrs = match &field_attrs_opt {
Some(attrs) => attrs,
None => {
return Err(anyhow!("Found a ForeignRef type but did not found attributes."))
}
};
let rrn = match &attrs.reverse_relation_name {
Some(rrn) => rrn.clone(),
None => {
return Err(anyhow!("Found a ForeignRef type but did not found reverse_relation_name attribute."))
}
};
let extract_res = extract_generic_type(vec!["ForeignRef".into()], &field_type)
.and_then(|t| get_type_first_ident(t));
let target_type_name = match extract_res {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract inner type from ForeignRef."));
}
};
output_field.foreign_mode = FieldForeignMode::ForeignRef(
ForeignRefParams {
reverse_relation_name: rrn,
target_resource_name: target_type_name.to_case(Case::Snake)
}
);
}
// parse attribute
if let Some(field_attr) = field_attrs_opt {
output_field.is_primary = field_attr.is_primary.unwrap_or_default();
output_field.is_unique = field_attr.is_unique.unwrap_or_default();
output_field.is_query_entrypoint = field_attr.is_query_entrypoint.unwrap_or_default();
}
fields.push(output_field);
}
models.push(Model {
module_path: vec![source_code_path.file_stem().unwrap().to_str().unwrap().to_string()],
name: model_name.clone(),
table_name: model_attrs.table_name
.unwrap_or(generate_table_name_from_struct_name(&model_name)),
fields
})
},
_ => {}
if first_type == "Vec" {
let inner_type = match extract_generic_type(vec!["Vec".into()], &field_type) {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract type from Vec"));
}
};
final_type = match get_type_first_ident(inner_type) {
Some(v) => format!("Vec<{}>", v),
None => {
return Err(anyhow!("Could not extract ident from Vec inner type"));
}
}
}
output_field.rust_type = final_type;
let field_attrs_opt = parse_field_attribute(field)?;
if first_type == "ForeignRef" {
let attrs = match &field_attrs_opt {
Some(attrs) => attrs,
None => {
return Err(anyhow!(
"Found a ForeignRef type but did not found attributes."
))
}
};
let rrn = match &attrs.reverse_relation_name {
Some(rrn) => rrn.clone(),
None => {
return Err(anyhow!("Found a ForeignRef type but did not found reverse_relation_name attribute."))
}
};
let extract_res = extract_generic_type(vec!["ForeignRef".into()], &field_type)
.and_then(get_type_first_ident);
let target_type_name = match extract_res {
Some(v) => v,
None => {
return Err(anyhow!("Could not extract inner type from ForeignRef."));
}
};
output_field.foreign_mode = FieldForeignMode::ForeignRef(ForeignRefParams {
reverse_relation_name: rrn,
target_resource_name: target_type_name.to_case(Case::Snake),
});
}
// parse attribute
if let Some(field_attr) = field_attrs_opt {
output_field.is_primary = field_attr.is_primary.unwrap_or_default();
output_field.is_unique = field_attr.is_unique.unwrap_or_default();
output_field.is_query_entrypoint =
field_attr.is_query_entrypoint.unwrap_or_default();
}
fields.push(output_field);
}
models.push(Model {
module_path: vec![source_code_path
.file_stem()
.unwrap()
.to_str()
.unwrap()
.to_string()],
name: model_name.clone(),
table_name: model_attrs
.table_name
.unwrap_or(generate_table_name_from_struct_name(&model_name)),
fields,
})
}
}
Ok(models)
@ -281,7 +270,7 @@ fn parse_models_from_module_inner(module_path: &Path) -> Result<Vec<Model>> {
let mut models: Vec<Model> = vec![];
if module_path.is_file() {
println!("Parsing models from path {:?}.", module_path);
println!("Looking for models to parse from path {:?}.", module_path);
models.extend(parse_models(module_path)?);
return Ok(models);
}
@ -296,23 +285,6 @@ fn parse_models_from_module_inner(module_path: &Path) -> Result<Vec<Model>> {
Ok(models)
}
// fn complete_models(original_models: Vec<Model>) -> Result<Vec<Model>> {
// let mut new_models: Vec<Model> = vec![];
// for model in original_models {
// for original_field in model.fields {
// let mut field = original_field
// match original_field.foreign_mode {
// FieldForeignMode::NotRef => {},
// FieldForeignMode::ForeignRef(ref_params) => {
// }
// }
// }
// }
// Ok(new_models)
// }
/// Scan for models struct in a rust file and return a struct representing the model
pub fn parse_models_from_module(module_path: &Path) -> Result<Vec<Model>> {
let models = parse_models_from_module_inner(module_path)?;

View file

@ -1,6 +1,6 @@
[package]
name = "sqlxgentools_misc"
description = "Various misc class to use in applications that use sqlxgentools"
description = "Various data types and traits to use in a sqlxgentools-enabled codebase."
publish = true
edition.workspace = true
authors.workspace = true

View file

@ -12,23 +12,20 @@ use sqlx_core::error::BoxDynError;
use sqlx_core::types::Type;
use sqlx_sqlite::{Sqlite, SqliteArgumentValue};
#[fully_pub]
trait DatabaseLine {
fn id(&self) -> String;
}
/// Wrapper to mark a model field as foreign
/// You can use a generic argument inside ForeignRef to point to the target model
#[derive(Clone, Debug)]
#[fully_pub]
struct ForeignRef<T: Sized + DatabaseLine> {
pub target_type: PhantomData<T>,
pub target_id: String
pub target_id: String,
}
// Implement serde Serialize for ForeignRef
impl<T: Sized + DatabaseLine> Serialize for ForeignRef<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@ -40,22 +37,20 @@ impl<T: Sized + DatabaseLine> Serialize for ForeignRef<T> {
}
}
impl<T: Sized + DatabaseLine> ForeignRef<T> {
pub fn new(entity: &T) -> ForeignRef<T> {
pub fn new(entity: &T) -> ForeignRef<T> {
ForeignRef {
target_type: PhantomData,
target_id: entity.id()
target_id: entity.id(),
}
}
}
impl<'r, DB: Database, T: Sized + DatabaseLine> Decode<'r, DB> for ForeignRef<T>
where
// we want to delegate some of the work to string decoding so let's make sure strings
// are supported by the database
&'r str: Decode<'r, DB>
&'r str: Decode<'r, DB>,
{
fn decode(
value: <DB as Database>::ValueRef<'r>,
@ -66,7 +61,7 @@ where
Ok(ForeignRef::<T> {
target_type: PhantomData,
target_id: ref_val
target_id: ref_val,
})
}
}
@ -84,9 +79,11 @@ impl<T: DatabaseLine + Sized> Type<Sqlite> for ForeignRef<T> {
}
impl<T: DatabaseLine + Sized> Encode<'_, Sqlite> for ForeignRef<T> {
fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'_>>) -> Result<IsNull, BoxDynError> {
fn encode_by_ref(
&self,
args: &mut Vec<SqliteArgumentValue<'_>>,
) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Text(self.target_id.clone().into()));
Ok(IsNull::No)
}
}