Compare commits

..

No commits in common. "d3aae47d2c0ba4d6c8b6b6088904a8c812cf485f" and "534ed8341962b0c93251aeb2d111e61c6b71ae67" have entirely different histories.

25 changed files with 504 additions and 515 deletions

View file

@ -1,12 +1,8 @@
# [WIP] sqlxgentools # [WIP] sqlxgentools
Little tool to generate SQLite migrations files and Rust SQLx repositories code, all from models structs. Tools to generate SQL migrations and Rust SQLx repositories code from models structs to use with a SQLite database.
Still very much work in progress, but it can be already used in your next Rust app if you don't mind some limitations like the lack of incremental migrations and little quirks here and there. Will be used in [minauthator](https://forge.lefuturiste.fr/mbess/minauthator).
## Getting started
- [Quick start tutorial](./docs/tutorials/quick_start.md)
## Project context ## Project context
@ -22,19 +18,10 @@ Still very much work in progress, but it can be already used in your next Rust a
- Provide a full a ORM interface - Provide a full a ORM interface
## Included crates
This project is split into 3 published crates.
- [`sqlxgentools_cli`](https://crates.io/crates/sqlxgentools_cli), used to parse, generate migrations and repositories.
- [`sqlxgentools_attrs`](https://crates.io/crates/sqlxgentools_attrs), provides proc macros.
- [`sqlxgentools_misc`](https://crates.io/crates/sqlxgentools_misc), provides data types and traits (optional).
## Features ## Features
- [x] generate migrations - [x] generate migrations
- [x] from scratch - [x] from scratch
- [ ] incremental migration
- [ ] up migration - [ ] up migration
- [ ] down migration - [ ] down migration
- [x] generate repositories - [x] generate repositories
@ -42,10 +29,27 @@ This project is split into 3 published crates.
- [x] get_by_id - [x] get_by_id
- [x] insert - [x] insert
- [x] insert_many - [x] insert_many
- [x] custom get_by, get_many_by - [ ] generate custom by
- [x] get_many_of (from one-to-many relations) - [x] co-exist with custom repository
## Contributions ## Usage
Questions, remarks and contributions is very much welcomed. ### Generate initial CREATE TABLE sqlite migration
cargo run --bin sqlx-generator -- ./path/to/project generate-create-migrations > migrations/all.sql
sqlx-generator \
-m path/to/models \
gen-repositories \
-o path/to/repositories
sqlx-generator \
-m path/to/models \
gen-migrations \
-o path/to/migrations/all.sql
### Generate repositories code
not implemented yet
cargo run --bin sqlx-generator -- ./path/to/project generate-repositories

View file

View file

@ -6,34 +6,3 @@ Steps:
- Generate migrations - Generate migrations
- Generate repositories - Generate repositories
- Use repositories in your code - Use repositories in your code
### CLI installation
The [sqlxgentools_cli crate](https://crates.io/crates/sqlxgentools_cli) provides the CLI,
it can be installed globally on your machine (or at least your user).
cargo install sqlxgentools_cli
### Project installation
Install the `sqlxgentools_attrs` crate
### Declare your models
TODO
### Generate migrations
Change directory into your project root.
sqlx-generator -m path/to/models_module gen-migrations -o path/to/migrations/all.sql
### Generate repositories
Change directory into your project root.
sqlx-generator -m path/to/models_module gen-repositories -o path/to/repositories_module
### Use the repositories
TODO

View file

@ -1 +0,0 @@
TODO

View file

@ -1,11 +1,11 @@
use anyhow::Context; use anyhow::Context;
use anyhow::Result;
use std::str::FromStr; use std::str::FromStr;
use std::path::PathBuf;
use anyhow::Result;
use fully_pub::fully_pub; use fully_pub::fully_pub;
use sqlx::{ use sqlx::{
sqlite::{SqliteConnectOptions, SqlitePoolOptions}, Pool, Sqlite, sqlite::{SqliteConnectOptions, SqlitePoolOptions},
Pool, Sqlite,
}; };
/// database storage interface /// database storage interface
@ -13,8 +13,17 @@ use sqlx::{
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct Database(Pool<Sqlite>); struct Database(Pool<Sqlite>);
/// Initialize database /// Initialize database
pub async fn provide_database(sqlite_db_path: &str) -> Result<Database> { pub async fn provide_database(sqlite_db_path: &str) -> Result<Database> {
let path = PathBuf::from(sqlite_db_path);
let is_db_initialization = !path.exists();
// // database does not exists, trying to create it
// if path
// .parent()
// .filter(|pp| pp.exists())
// Err(anyhow!("Could not find parent directory of the db location.")));
let conn_str = format!("sqlite://{sqlite_db_path}"); let conn_str = format!("sqlite://{sqlite_db_path}");
let pool = SqlitePoolOptions::new() let pool = SqlitePoolOptions::new()
@ -22,6 +31,11 @@ pub async fn provide_database(sqlite_db_path: &str) -> Result<Database> {
.connect_with(SqliteConnectOptions::from_str(&conn_str)?.create_if_missing(true)) .connect_with(SqliteConnectOptions::from_str(&conn_str)?.create_if_missing(true))
.await .await
.context("could not connect to database_url")?; .context("could not connect to database_url")?;
// if is_db_initialization {
// initialize_db(Database(pool.clone())).await?;
// }
Ok(Database(pool)) Ok(Database(pool))
} }

View file

@ -1,3 +1,3 @@
pub mod repositories;
pub mod db; pub mod db;
pub mod models; pub mod models;
pub mod repositories;

View file

@ -1,24 +1,21 @@
use anyhow::Result; use anyhow::{Context, Result};
use chrono::Utc; use chrono::Utc;
use sqlx::types::Json; use sqlx::types::Json;
use sqlxgentools_misc::ForeignRef; use sqlxgentools_misc::ForeignRef;
use crate::{ use crate::{db::provide_database, models::user::{User, UserToken}, repositories::user_token_repository::UserTokenRepository};
db::provide_database,
models::user::{User, UserToken},
repositories::user_token_repository::UserTokenRepository,
};
pub mod db;
pub mod models; pub mod models;
pub mod repositories; pub mod repositories;
pub mod db;
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
println!("Sandbox"); println!("Sandbox");
let users = [ let users = vec![
User { User {
id: "idu1".into(), id: "idu1".into(),
handle: "john.doe".into(), handle: "john.doe".into(),
@ -27,7 +24,7 @@ async fn main() -> Result<()> {
last_login_at: None, last_login_at: None,
status: models::user::UserStatus::Invited, status: models::user::UserStatus::Invited,
groups: Json(vec![]), groups: Json(vec![]),
avatar_bytes: None, avatar_bytes: None
}, },
User { User {
id: "idu2".into(), id: "idu2".into(),
@ -37,7 +34,7 @@ async fn main() -> Result<()> {
last_login_at: None, last_login_at: None,
status: models::user::UserStatus::Invited, status: models::user::UserStatus::Invited,
groups: Json(vec![]), groups: Json(vec![]),
avatar_bytes: None, avatar_bytes: None
}, },
User { User {
id: "idu3".into(), id: "idu3".into(),
@ -47,30 +44,29 @@ async fn main() -> Result<()> {
last_login_at: None, last_login_at: None,
status: models::user::UserStatus::Invited, status: models::user::UserStatus::Invited,
groups: Json(vec![]), groups: Json(vec![]),
avatar_bytes: None, avatar_bytes: None
}, }
]; ];
let _user_token = UserToken { let user_token = UserToken {
id: "idtoken1".into(), id: "idtoken1".into(),
secret: "4LP5A3F3XBV5NM8VXRGZG3QDXO9PNAC0".into(), secret: "4LP5A3F3XBV5NM8VXRGZG3QDXO9PNAC0".into(),
last_use_time: None, last_use_time: None,
creation_time: Utc::now(), creation_time: Utc::now(),
expiration_time: Utc::now(), expiration_time: Utc::now(),
user_id: ForeignRef::new(users.first().unwrap()), user_id: ForeignRef::new(&users.get(0).unwrap())
}; };
let db = provide_database("tmp/db.db").await?; let db = provide_database("tmp/db.db").await?;
let user_token_repo = UserTokenRepository::new(db); let user_token_repo = UserTokenRepository::new(db);
user_token_repo user_token_repo.insert_many(&vec![
.insert_many(&vec![
UserToken { UserToken {
id: "idtoken2".into(), id: "idtoken2".into(),
secret: "4LP5A3F3XBV5NM8VXRGZG3QDXO9PNAC0".into(), secret: "4LP5A3F3XBV5NM8VXRGZG3QDXO9PNAC0".into(),
last_use_time: None, last_use_time: None,
creation_time: Utc::now(), creation_time: Utc::now(),
expiration_time: Utc::now(), expiration_time: Utc::now(),
user_id: ForeignRef::new(users.first().unwrap()), user_id: ForeignRef::new(&users.get(0).unwrap())
}, },
UserToken { UserToken {
id: "idtoken3".into(), id: "idtoken3".into(),
@ -78,7 +74,7 @@ async fn main() -> Result<()> {
last_use_time: None, last_use_time: None,
creation_time: Utc::now(), creation_time: Utc::now(),
expiration_time: Utc::now(), expiration_time: Utc::now(),
user_id: ForeignRef::new(users.get(1).unwrap()), user_id: ForeignRef::new(&users.get(1).unwrap())
}, },
UserToken { UserToken {
id: "idtoken4".into(), id: "idtoken4".into(),
@ -86,14 +82,14 @@ async fn main() -> Result<()> {
last_use_time: None, last_use_time: None,
creation_time: Utc::now(), creation_time: Utc::now(),
expiration_time: Utc::now(), expiration_time: Utc::now(),
user_id: ForeignRef::new(users.get(1).unwrap()), user_id: ForeignRef::new(&users.get(1).unwrap())
}, }
]) ]).await?;
.await?; let user_tokens = user_token_repo.get_many_user_tokens_by_usersss(
let user_tokens = user_token_repo vec!["idu2".into()]
.get_many_user_tokens_by_usersss(vec!["idu2".into()]) ).await?;
.await?;
dbg!(&user_tokens); dbg!(&user_tokens);
Ok(()) Ok(())
} }

View file

@ -1,8 +1,8 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use fully_pub::fully_pub;
use sqlx::types::Json; use sqlx::types::Json;
use fully_pub::fully_pub;
use sqlxgentools_attrs::{sql_generator_model, SqlGeneratorDerive, SqlGeneratorModelWithId}; use sqlxgentools_attrs::{SqlGeneratorDerive, SqlGeneratorModelWithId, sql_generator_model};
use sqlxgentools_misc::{DatabaseLine, ForeignRef}; use sqlxgentools_misc::{DatabaseLine, ForeignRef};
#[derive(sqlx::Type, Clone, Debug, PartialEq)] #[derive(sqlx::Type, Clone, Debug, PartialEq)]
@ -11,36 +11,37 @@ enum UserStatus {
Disabled, Disabled,
Invited, Invited,
Active, Active,
Archived, Archived
} }
#[derive(SqlGeneratorDerive, SqlGeneratorModelWithId, sqlx::FromRow, Debug, Clone)] #[derive(SqlGeneratorDerive, SqlGeneratorModelWithId, sqlx::FromRow, Debug, Clone)]
#[sql_generator_model(table_name = "usersss")] #[sql_generator_model(table_name="usersss")]
#[fully_pub] #[fully_pub]
struct User { struct User {
#[sql_generator_field(is_primary = true)] #[sql_generator_field(is_primary=true)]
id: String, id: String,
#[sql_generator_field(is_unique = true)] #[sql_generator_field(is_unique=true)]
handle: String, handle: String,
full_name: Option<String>, full_name: Option<String>,
prefered_color: Option<i64>, prefered_color: Option<i64>,
last_login_at: Option<DateTime<Utc>>, last_login_at: Option<DateTime<Utc>>,
status: UserStatus, status: UserStatus,
groups: Json<Vec<String>>, groups: Json<Vec<String>>,
avatar_bytes: Option<Vec<u8>>, avatar_bytes: Option<Vec<u8>>
} }
#[derive(SqlGeneratorDerive, SqlGeneratorModelWithId, sqlx::FromRow, Debug, Clone)] #[derive(SqlGeneratorDerive, SqlGeneratorModelWithId, sqlx::FromRow, Debug, Clone)]
#[sql_generator_model(table_name = "user_tokens")] #[sql_generator_model(table_name="user_tokens")]
#[fully_pub] #[fully_pub]
struct UserToken { struct UserToken {
#[sql_generator_field(is_primary = true)] #[sql_generator_field(is_primary=true)]
id: String, id: String,
secret: String, secret: String,
last_use_time: Option<DateTime<Utc>>, last_use_time: Option<DateTime<Utc>>,
creation_time: DateTime<Utc>, creation_time: DateTime<Utc>,
expiration_time: DateTime<Utc>, expiration_time: DateTime<Utc>,
#[sql_generator_field(reverse_relation_name = "user_tokens")] #[sql_generator_field(reverse_relation_name="user_tokens")] // to generate get_user_tokens_of_user(&user_id)
// to generate get_user_tokens_of_user(&user_id) user_id: ForeignRef<User>
user_id: ForeignRef<User>,
} }

View file

@ -1,5 +1,5 @@
use crate::db::Database;
use crate::models::user::User; use crate::models::user::User;
use crate::db::Database;
pub struct UserRepository { pub struct UserRepository {
db: Database, db: Database,
} }
@ -8,9 +8,7 @@ impl UserRepository {
UserRepository { db } UserRepository { db }
} }
pub async fn get_all(&self) -> Result<Vec<User>, sqlx::Error> { pub async fn get_all(&self) -> Result<Vec<User>, sqlx::Error> {
sqlx::query_as::<_, User>("SELECT * FROM usersss") sqlx::query_as::<_, User>("SELECT * FROM usersss").fetch_all(&self.db.0).await
.fetch_all(&self.db.0)
.await
} }
pub async fn get_by_id(&self, item_id: &str) -> Result<User, sqlx::Error> { pub async fn get_by_id(&self, item_id: &str) -> Result<User, sqlx::Error> {
sqlx::query_as::<_, User>("SELECT * FROM usersss WHERE id = $1") sqlx::query_as::<_, User>("SELECT * FROM usersss WHERE id = $1")
@ -18,7 +16,10 @@ impl UserRepository {
.fetch_one(&self.db.0) .fetch_one(&self.db.0)
.await .await
} }
pub async fn get_many_by_id(&self, items_ids: &[&str]) -> Result<Vec<User>, sqlx::Error> { pub async fn get_many_by_id(
&self,
items_ids: &[&str],
) -> Result<Vec<User>, sqlx::Error> {
if items_ids.is_empty() { if items_ids.is_empty() {
return Ok(vec![]); return Ok(vec![]);
} }
@ -26,7 +27,9 @@ impl UserRepository {
.map(|i| format!("${}", i)) .map(|i| format!("${}", i))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(","); .join(",");
let query_sql = format!("SELECT * FROM usersss WHERE id IN ({})", placeholder_params); let query_sql = format!(
"SELECT * FROM usersss WHERE id IN ({})", placeholder_params
);
let mut query = sqlx::query_as::<_, User>(&query_sql); let mut query = sqlx::query_as::<_, User>(&query_sql);
for id in items_ids { for id in items_ids {
query = query.bind(id); query = query.bind(id);
@ -40,8 +43,8 @@ impl UserRepository {
.bind(&entity.id) .bind(&entity.id)
.bind(&entity.handle) .bind(&entity.handle)
.bind(&entity.full_name) .bind(&entity.full_name)
.bind(entity.prefered_color) .bind(&entity.prefered_color)
.bind(entity.last_login_at) .bind(&entity.last_login_at)
.bind(&entity.status) .bind(&entity.status)
.bind(&entity.groups) .bind(&entity.groups)
.bind(&entity.avatar_bytes) .bind(&entity.avatar_bytes)
@ -56,11 +59,8 @@ impl UserRepository {
.map(|c| c.to_vec()) .map(|c| c.to_vec())
.map(|x| { .map(|x| {
format!( format!(
"({})", "({})", x.iter().map(| i | format!("${}", i)).collect:: < Vec <
x.iter() String >> ().join(", ")
.map(|i| format!("${}", i))
.collect::<Vec<String>>()
.join(", ")
) )
}) })
.collect::<Vec<String>>() .collect::<Vec<String>>()
@ -75,8 +75,8 @@ impl UserRepository {
.bind(&entity.id) .bind(&entity.id)
.bind(&entity.handle) .bind(&entity.handle)
.bind(&entity.full_name) .bind(&entity.full_name)
.bind(entity.prefered_color) .bind(&entity.prefered_color)
.bind(entity.last_login_at) .bind(&entity.last_login_at)
.bind(&entity.status) .bind(&entity.status)
.bind(&entity.groups) .bind(&entity.groups)
.bind(&entity.avatar_bytes); .bind(&entity.avatar_bytes);
@ -84,7 +84,11 @@ impl UserRepository {
query.execute(&self.db.0).await?; query.execute(&self.db.0).await?;
Ok(()) Ok(())
} }
pub async fn update_by_id(&self, item_id: &str, entity: &User) -> Result<(), sqlx::Error> { pub async fn update_by_id(
&self,
item_id: &str,
entity: &User,
) -> Result<(), sqlx::Error> {
sqlx::query( sqlx::query(
"UPDATE usersss SET id = $2, handle = $3, full_name = $4, prefered_color = $5, last_login_at = $6, status = $7, groups = $8, avatar_bytes = $9 WHERE id = $1", "UPDATE usersss SET id = $2, handle = $3, full_name = $4, prefered_color = $5, last_login_at = $6, status = $7, groups = $8, avatar_bytes = $9 WHERE id = $1",
) )
@ -92,8 +96,8 @@ impl UserRepository {
.bind(&entity.id) .bind(&entity.id)
.bind(&entity.handle) .bind(&entity.handle)
.bind(&entity.full_name) .bind(&entity.full_name)
.bind(entity.prefered_color) .bind(&entity.prefered_color)
.bind(entity.last_login_at) .bind(&entity.last_login_at)
.bind(&entity.status) .bind(&entity.status)
.bind(&entity.groups) .bind(&entity.groups)
.bind(&entity.avatar_bytes) .bind(&entity.avatar_bytes)

View file

@ -1,5 +1,5 @@
use crate::db::Database;
use crate::models::user::UserToken; use crate::models::user::UserToken;
use crate::db::Database;
pub struct UserTokenRepository { pub struct UserTokenRepository {
db: Database, db: Database,
} }
@ -18,7 +18,10 @@ impl UserTokenRepository {
.fetch_one(&self.db.0) .fetch_one(&self.db.0)
.await .await
} }
pub async fn get_many_by_id(&self, items_ids: &[&str]) -> Result<Vec<UserToken>, sqlx::Error> { pub async fn get_many_by_id(
&self,
items_ids: &[&str],
) -> Result<Vec<UserToken>, sqlx::Error> {
if items_ids.is_empty() { if items_ids.is_empty() {
return Ok(vec![]); return Ok(vec![]);
} }
@ -27,8 +30,7 @@ impl UserTokenRepository {
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(","); .join(",");
let query_sql = format!( let query_sql = format!(
"SELECT * FROM user_tokens WHERE id IN ({})", "SELECT * FROM user_tokens WHERE id IN ({})", placeholder_params
placeholder_params
); );
let mut query = sqlx::query_as::<_, UserToken>(&query_sql); let mut query = sqlx::query_as::<_, UserToken>(&query_sql);
for id in items_ids { for id in items_ids {
@ -42,26 +44,26 @@ impl UserTokenRepository {
) )
.bind(&entity.id) .bind(&entity.id)
.bind(&entity.secret) .bind(&entity.secret)
.bind(entity.last_use_time) .bind(&entity.last_use_time)
.bind(entity.creation_time) .bind(&entity.creation_time)
.bind(entity.expiration_time) .bind(&entity.expiration_time)
.bind(&entity.user_id.target_id) .bind(&entity.user_id.target_id)
.execute(&self.db.0) .execute(&self.db.0)
.await?; .await?;
Ok(()) Ok(())
} }
pub async fn insert_many(&self, entities: &Vec<UserToken>) -> Result<(), sqlx::Error> { pub async fn insert_many(
&self,
entities: &Vec<UserToken>,
) -> Result<(), sqlx::Error> {
let values_templates: String = (1..(6usize * entities.len() + 1)) let values_templates: String = (1..(6usize * entities.len() + 1))
.collect::<Vec<usize>>() .collect::<Vec<usize>>()
.chunks(6usize) .chunks(6usize)
.map(|c| c.to_vec()) .map(|c| c.to_vec())
.map(|x| { .map(|x| {
format!( format!(
"({})", "({})", x.iter().map(| i | format!("${}", i)).collect:: < Vec <
x.iter() String >> ().join(", ")
.map(|i| format!("${}", i))
.collect::<Vec<String>>()
.join(", ")
) )
}) })
.collect::<Vec<String>>() .collect::<Vec<String>>()
@ -75,24 +77,28 @@ impl UserTokenRepository {
query = query query = query
.bind(&entity.id) .bind(&entity.id)
.bind(&entity.secret) .bind(&entity.secret)
.bind(entity.last_use_time) .bind(&entity.last_use_time)
.bind(entity.creation_time) .bind(&entity.creation_time)
.bind(entity.expiration_time) .bind(&entity.expiration_time)
.bind(&entity.user_id.target_id); .bind(&entity.user_id.target_id);
} }
query.execute(&self.db.0).await?; query.execute(&self.db.0).await?;
Ok(()) Ok(())
} }
pub async fn update_by_id(&self, item_id: &str, entity: &UserToken) -> Result<(), sqlx::Error> { pub async fn update_by_id(
&self,
item_id: &str,
entity: &UserToken,
) -> Result<(), sqlx::Error> {
sqlx::query( sqlx::query(
"UPDATE user_tokens SET id = $2, secret = $3, last_use_time = $4, creation_time = $5, expiration_time = $6, user_id = $7 WHERE id = $1", "UPDATE user_tokens SET id = $2, secret = $3, last_use_time = $4, creation_time = $5, expiration_time = $6, user_id = $7 WHERE id = $1",
) )
.bind(item_id) .bind(item_id)
.bind(&entity.id) .bind(&entity.id)
.bind(&entity.secret) .bind(&entity.secret)
.bind(entity.last_use_time) .bind(&entity.last_use_time)
.bind(entity.creation_time) .bind(&entity.creation_time)
.bind(entity.expiration_time) .bind(&entity.expiration_time)
.bind(&entity.user_id.target_id) .bind(&entity.user_id.target_id)
.execute(&self.db.0) .execute(&self.db.0)
.await?; .await?;
@ -126,8 +132,7 @@ impl UserTokenRepository {
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(","); .join(",");
let query_tmpl = format!( let query_tmpl = format!(
"SELECT * FROM user_tokens WHERE user_id IN ({})", "SELECT * FROM user_tokens WHERE user_id IN ({})", placeholder_params
placeholder_params
); );
let mut query = sqlx::query_as::<_, UserToken>(&query_tmpl); let mut query = sqlx::query_as::<_, UserToken>(&query_tmpl);
for id in items_ids { for id in items_ids {

View file

@ -5,10 +5,7 @@
use std::assert_matches::assert_matches; use std::assert_matches::assert_matches;
use chrono::Utc; use chrono::Utc;
use sandbox::{ use sandbox::{models::user::{User, UserStatus}, repositories::user_repository::UserRepository};
models::user::{User, UserStatus},
repositories::user_repository::UserRepository,
};
use sqlx::{types::Json, Pool, Sqlite}; use sqlx::{types::Json, Pool, Sqlite};
#[sqlx::test(fixtures("../src/migrations/all.sql"))] #[sqlx::test(fixtures("../src/migrations/all.sql"))]
@ -25,25 +22,23 @@ async fn test_user_repository_create_read_update_delete(pool: Pool<Sqlite>) -> s
last_login_at: Some(Utc::now()), last_login_at: Some(Utc::now()),
status: UserStatus::Invited, status: UserStatus::Invited,
groups: Json(vec!["artists".into()]), groups: Json(vec!["artists".into()]),
avatar_bytes: vec![0x00], avatar_bytes: vec![0x00]
}; };
assert_matches!(user_repo.insert(&new_user).await, Ok(()));
assert_matches!( assert_matches!(
user_repo user_repo.insert(&new_user).await,
.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()) Ok(())
.await, );
assert_matches!(
user_repo.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()).await,
Ok(User { .. }) Ok(User { .. })
); );
assert_matches!( assert_matches!(
user_repo user_repo.get_by_id("ffffffff-0000-4040-0000-000000000000".into()).await,
.get_by_id("ffffffff-0000-4040-0000-000000000000".into())
.await,
Err(sqlx::Error::RowNotFound) Err(sqlx::Error::RowNotFound)
); );
// Insert Many // Insert Many
let bunch_of_users: Vec<User> = (0..10) let bunch_of_users: Vec<User> = (0..10).map(|pid| User {
.map(|pid| User {
id: format!("ffffffff-0000-4000-0010-{:0>8}", pid), id: format!("ffffffff-0000-4000-0010-{:0>8}", pid),
handle: format!("user num {}", pid), handle: format!("user num {}", pid),
full_name: None, full_name: None,
@ -51,14 +46,19 @@ async fn test_user_repository_create_read_update_delete(pool: Pool<Sqlite>) -> s
last_login_at: None, last_login_at: None,
status: UserStatus::Invited, status: UserStatus::Invited,
groups: Json(vec![]), groups: Json(vec![]),
avatar_bytes: vec![], avatar_bytes: vec![]
}) }).collect();
.collect(); assert_matches!(
assert_matches!(user_repo.insert_many(&bunch_of_users).await, Ok(())); user_repo.insert_many(&bunch_of_users).await,
Ok(())
);
// Read many all // Read many all
let read_all_res = user_repo.get_all().await; let read_all_res = user_repo.get_all().await;
assert_matches!(read_all_res, Ok(..)); assert_matches!(
read_all_res,
Ok(..)
);
let all_users = read_all_res.unwrap(); let all_users = read_all_res.unwrap();
assert_eq!(all_users.len(), 11); assert_eq!(all_users.len(), 11);
@ -69,18 +69,16 @@ async fn test_user_repository_create_read_update_delete(pool: Pool<Sqlite>) -> s
user_repo.update_by_id(&new_user.id, &updated_user).await, user_repo.update_by_id(&new_user.id, &updated_user).await,
Ok(()) Ok(())
); );
let user_from_db = user_repo let user_from_db = user_repo.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()).await.unwrap();
.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into())
.await
.unwrap();
assert_eq!(user_from_db.status, UserStatus::Disabled); assert_eq!(user_from_db.status, UserStatus::Disabled);
// Delete // Delete
assert_matches!(user_repo.delete_by_id(&new_user.id).await, Ok(()));
assert_matches!( assert_matches!(
user_repo user_repo.delete_by_id(&new_user.id).await,
.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()) Ok(())
.await, );
assert_matches!(
user_repo.get_by_id("ffffffff-0000-4000-0000-0000000000c9".into()).await,
Err(sqlx::Error::RowNotFound) Err(sqlx::Error::RowNotFound)
); );

View file

@ -1,6 +1,6 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::quote; use quote::quote;
use syn::{parse_macro_input, DeriveInput, Fields}; use syn::{DeriveInput, Fields, parse_macro_input};
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn sql_generator_model(_attr: TokenStream, item: TokenStream) -> TokenStream { pub fn sql_generator_model(_attr: TokenStream, item: TokenStream) -> TokenStream {
@ -21,7 +21,7 @@ pub fn derive_sql_generator_model_with_id(input: TokenStream) -> TokenStream {
if let syn::Data::Struct(data) = input.data { if let syn::Data::Struct(data) = input.data {
if let Fields::Named(fields) = data.fields { if let Fields::Named(fields) = data.fields {
for field in fields.named { for field in fields.named {
if field.ident.as_ref().is_some_and(|ident| ident == "id") { if field.ident.as_ref().map_or(false, |ident| ident == "id") {
let expanded = quote! { let expanded = quote! {
impl DatabaseLine for #name { impl DatabaseLine for #name {
fn id(&self) -> String { fn id(&self) -> String {
@ -38,3 +38,4 @@ pub fn derive_sql_generator_model_with_id(input: TokenStream) -> TokenStream {
// If `id` field is not found, return an error // If `id` field is not found, return an error
panic!("Expected struct with a named field `id` of type String") panic!("Expected struct with a named field `id` of type String")
} }

View file

@ -1,7 +1,8 @@
use anyhow::{anyhow, Result}; use anyhow::{Result, anyhow};
use crate::models::{Field, Model}; use crate::models::{Field, Model};
// Implementations // Implementations
impl Field { impl Field {
/// return sqlite type /// return sqlite type
@ -20,7 +21,7 @@ impl Field {
"DateTime" => Some("DATETIME".into()), "DateTime" => Some("DATETIME".into()),
"Json" => Some("TEXT".into()), "Json" => Some("TEXT".into()),
"Vec<u8>" => Some("BLOB".into()), "Vec<u8>" => Some("BLOB".into()),
_ => Some("TEXT".into()), _ => Some("TEXT".into())
} }
} }
} }
@ -34,10 +35,8 @@ pub fn generate_create_table_sql(models: &[Model]) -> Result<String> {
let mut fields_sql: Vec<String> = vec![]; let mut fields_sql: Vec<String> = vec![];
for field in model.fields.iter() { for field in model.fields.iter() {
let mut additions: String = "".into(); let mut additions: String = "".into();
let sql_type = field.sql_type().ok_or(anyhow!(format!( let sql_type = field.sql_type()
"Could not find SQL type for field {}", .ok_or(anyhow!(format!("Could not find SQL type for field {}", field.name)))?;
field.name
)))?;
if !field.is_nullable { if !field.is_nullable {
additions.push_str(" NOT NULL"); additions.push_str(" NOT NULL");
} }
@ -47,15 +46,20 @@ pub fn generate_create_table_sql(models: &[Model]) -> Result<String> {
if field.is_primary { if field.is_primary {
additions.push_str(" PRIMARY KEY"); additions.push_str(" PRIMARY KEY");
} }
fields_sql.push(format!("\t{: <#18}\t{}{}", field.name, sql_type, additions)); fields_sql.push(
format!("\t{: <#18}\t{}{}", field.name, sql_type, additions)
);
} }
sql_code.push_str(&format!( sql_code.push_str(
&format!(
"CREATE TABLE {} (\n{}\n);\n", "CREATE TABLE {} (\n{}\n);\n",
model.table_name, model.table_name,
fields_sql.join(",\n") fields_sql.join(",\n")
)); )
);
} }
Ok(sql_code) Ok(sql_code)
} }

View file

@ -8,12 +8,13 @@ pub mod repositories;
#[fully_pub] #[fully_pub]
enum SourceNode { enum SourceNode {
File(String), File(String),
Directory(Vec<SourceNodeContainer>), Directory(Vec<SourceNodeContainer>)
} }
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
#[fully_pub] #[fully_pub]
struct SourceNodeContainer { struct SourceNodeContainer {
name: String, name: String,
inner: SourceNode, inner: SourceNode
} }

View file

@ -1,14 +1,12 @@
use anyhow::Result; use anyhow::Result;
use heck::ToSnakeCase; use proc_macro2::{TokenStream, Ident};
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::File; use syn::File;
use heck::ToSnakeCase;
use crate::{generators::repositories::relations::gen_get_many_of_related_entity_method, models::{Field, FieldForeignMode, Model}};
use crate::generators::{SourceNode, SourceNodeContainer}; use crate::generators::{SourceNode, SourceNodeContainer};
use crate::{
generators::repositories::relations::gen_get_many_of_related_entity_method,
models::{Field, FieldForeignMode, Model},
};
fn gen_get_all_method(model: &Model) -> TokenStream { fn gen_get_all_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
@ -25,10 +23,7 @@ fn gen_get_all_method(model: &Model) -> TokenStream {
fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream { fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let select_query = format!( let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, query_field.name);
"SELECT * FROM {} WHERE {} = $1",
model.table_name, query_field.name
);
let func_name_ident = format_ident!("get_by_{}", query_field.name); let func_name_ident = format_ident!("get_by_{}", query_field.name);
@ -45,10 +40,7 @@ fn gen_get_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStream { fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let select_query_tmpl = format!( let select_query_tmpl = format!("SELECT * FROM {} WHERE {} IN ({{}})", model.table_name, query_field.name);
"SELECT * FROM {} WHERE {} IN ({{}})",
model.table_name, query_field.name
);
let func_name_ident = format_ident!("get_many_by_{}", query_field.name); let func_name_ident = format_ident!("get_many_by_{}", query_field.name);
@ -74,21 +66,21 @@ fn gen_get_many_by_field_method(model: &Model, query_field: &Field) -> TokenStre
} }
fn get_mutation_fields(model: &Model) -> (Vec<&Field>, Vec<&Field>) { fn get_mutation_fields(model: &Model) -> (Vec<&Field>, Vec<&Field>) {
let normal_field_names: Vec<&Field> = model let normal_field_names: Vec<&Field> = model.fields.iter()
.fields .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false })
.iter()
.filter(|f| match f.foreign_mode {
FieldForeignMode::NotRef => true,
FieldForeignMode::ForeignRef(_) => false,
})
.collect(); .collect();
let foreign_keys_field_names: Vec<&Field> = model let foreign_keys_field_names: Vec<&Field> = model.fields.iter()
.fields .filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true })
.iter() .collect();
.filter(|f| match f.foreign_mode { (normal_field_names, foreign_keys_field_names)
FieldForeignMode::NotRef => false, }
FieldForeignMode::ForeignRef(_) => true,
}) fn get_mutation_fields_ident(model: &Model) -> (Vec<&Field>, Vec<&Field>) {
let normal_field_names: Vec<&Field> = model.fields.iter()
.filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => true, FieldForeignMode::ForeignRef(_) => false })
.collect();
let foreign_keys_field_names: Vec<&Field> = model.fields.iter()
.filter(|f| match f.foreign_mode { FieldForeignMode::NotRef => false, FieldForeignMode::ForeignRef(_) => true })
.collect(); .collect();
(normal_field_names, foreign_keys_field_names) (normal_field_names, foreign_keys_field_names)
} }
@ -96,31 +88,26 @@ fn get_mutation_fields(model: &Model) -> (Vec<&Field>, Vec<&Field>) {
fn gen_insert_method(model: &Model) -> TokenStream { fn gen_insert_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let value_templates = (1..(model.fields.len() + 1)) let value_templates = (1..(model.fields.len()+1))
.map(|i| format!("${}", i)) .map(|i| format!("${}", i))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "); .join(", ");
let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); let (normal_fields, foreign_keys_fields) = get_mutation_fields(model);
let (normal_field_idents, foreign_keys_field_idents) = ( let (normal_field_idents, foreign_keys_field_idents) = (
normal_fields normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>(),
.iter() foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
foreign_keys_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
); );
let sql_columns = [normal_fields, foreign_keys_fields] let sql_columns = [normal_fields, foreign_keys_fields].concat()
.concat()
.iter() .iter()
.map(|f| f.name.clone()) .map(|f| f.name.clone())
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "); .join(", ");
let insert_query = format!( let insert_query = format!(
"INSERT INTO {} ({}) VALUES ({})", "INSERT INTO {} ({}) VALUES ({})",
model.table_name, sql_columns, value_templates model.table_name,
sql_columns,
value_templates
); );
// foreign keys must be inserted first, we sort the columns so that foreign keys are first // foreign keys must be inserted first, we sort the columns so that foreign keys are first
@ -139,26 +126,19 @@ fn gen_insert_method(model: &Model) -> TokenStream {
fn gen_insert_many_method(model: &Model) -> TokenStream { fn gen_insert_many_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let sql_columns = model let sql_columns = model.fields.iter()
.fields
.iter()
.map(|f| f.name.clone()) .map(|f| f.name.clone())
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "); .join(", ");
let base_insert_query = format!( let base_insert_query = format!(
"INSERT INTO {} ({}) VALUES {{}} ON CONFLICT DO NOTHING", "INSERT INTO {} ({}) VALUES {{}} ON CONFLICT DO NOTHING",
model.table_name, sql_columns model.table_name,
sql_columns
); );
let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); let (normal_fields, foreign_keys_fields) = get_mutation_fields(model);
let (normal_field_idents, foreign_keys_field_idents) = ( let (normal_field_idents, foreign_keys_field_idents) = (
normal_fields normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>(),
.iter() foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
foreign_keys_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
); );
let fields_count = model.fields.len(); let fields_count = model.fields.len();
@ -194,39 +174,32 @@ fn gen_insert_many_method(model: &Model) -> TokenStream {
} }
} }
fn gen_update_by_id_method(model: &Model) -> TokenStream { fn gen_update_by_id_method(model: &Model) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let primary_key = &model let primary_key = &model.fields.iter()
.fields
.iter()
.find(|f| f.is_primary) .find(|f| f.is_primary)
.expect("A model must have at least one primary key") .expect("A model must have at least one primary key")
.name; .name;
let (normal_fields, foreign_keys_fields) = get_mutation_fields(model); let (normal_fields, foreign_keys_fields) = get_mutation_fields(model);
let (normal_field_idents, foreign_keys_field_idents) = ( let (normal_field_idents, foreign_keys_field_idents) = (
normal_fields normal_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>(),
.iter() foreign_keys_fields.iter().map(|f| format_ident!("{}", &f.name)).collect::<Vec<Ident>>()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
foreign_keys_fields
.iter()
.map(|f| format_ident!("{}", &f.name))
.collect::<Vec<Ident>>(),
); );
let sql_columns = [normal_fields, foreign_keys_fields] let sql_columns = [normal_fields, foreign_keys_fields].concat()
.concat()
.iter() .iter()
.map(|f| f.name.clone()) .map(|f| f.name.clone())
.collect::<Vec<String>>(); .collect::<Vec<String>>();
let set_statements = sql_columns let set_statements = sql_columns.iter()
.iter()
.enumerate() .enumerate()
.map(|(i, column_name)| format!("{} = ${}", column_name, i + 2)) .map(|(i, column_name)| format!("{} = ${}", column_name, i+2))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "); .join(", ");
let update_query = format!( let update_query = format!(
"UPDATE {} SET {} WHERE {} = $1", "UPDATE {} SET {} WHERE {} = $1",
model.table_name, set_statements, primary_key model.table_name,
set_statements,
primary_key
); );
let func_name_ident = format_ident!("update_by_{}", primary_key); let func_name_ident = format_ident!("update_by_{}", primary_key);
@ -245,9 +218,7 @@ fn gen_update_by_id_method(model: &Model) -> TokenStream {
} }
fn gen_delete_by_id_method(model: &Model) -> TokenStream { fn gen_delete_by_id_method(model: &Model) -> TokenStream {
let primary_key = &model let primary_key = &model.fields.iter()
.fields
.iter()
.find(|f| f.is_primary) .find(|f| f.is_primary)
.expect("A model must have at least one primary key") .expect("A model must have at least one primary key")
.name; .name;
@ -255,7 +226,8 @@ fn gen_delete_by_id_method(model: &Model) -> TokenStream {
let func_name_ident = format_ident!("delete_by_{}", primary_key); let func_name_ident = format_ident!("delete_by_{}", primary_key);
let query = format!( let query = format!(
"DELETE FROM {} WHERE {} = $1", "DELETE FROM {} WHERE {} = $1",
model.table_name, primary_key model.table_name,
primary_key
); );
quote! { quote! {
@ -271,9 +243,7 @@ fn gen_delete_by_id_method(model: &Model) -> TokenStream {
} }
fn gen_delete_many_by_id_method(model: &Model) -> TokenStream { fn gen_delete_many_by_id_method(model: &Model) -> TokenStream {
let primary_key = &model let primary_key = &model.fields.iter()
.fields
.iter()
.find(|f| f.is_primary) .find(|f| f.is_primary)
.expect("A model must have at least one primary key") .expect("A model must have at least one primary key")
.name; .name;
@ -281,7 +251,8 @@ fn gen_delete_many_by_id_method(model: &Model) -> TokenStream {
let func_name_ident = format_ident!("delete_many_by_{}", primary_key); let func_name_ident = format_ident!("delete_many_by_{}", primary_key);
let delete_query_tmpl = format!( let delete_query_tmpl = format!(
"DELETE FROM {} WHERE {} IN ({{}})", "DELETE FROM {} WHERE {} IN ({{}})",
model.table_name, primary_key model.table_name,
primary_key
); );
quote! { quote! {
@ -307,10 +278,8 @@ fn gen_delete_many_by_id_method(model: &Model) -> TokenStream {
} }
} }
pub fn generate_repository_file(
_all_models: &[Model], pub fn generate_repository_file(all_models: &[Model], model: &Model) -> Result<SourceNodeContainer> {
model: &Model,
) -> Result<SourceNodeContainer> {
let resource_name = model.name.clone(); let resource_name = model.name.clone();
let resource_module_ident = format_ident!("{}", &model.module_path.first().unwrap()); let resource_module_ident = format_ident!("{}", &model.module_path.first().unwrap());
@ -321,19 +290,15 @@ pub fn generate_repository_file(
let get_all_method_code = gen_get_all_method(model); let get_all_method_code = gen_get_all_method(model);
let get_by_id_method_code = gen_get_by_field_method( let get_by_id_method_code = gen_get_by_field_method(
model, model,
model model.fields.iter()
.fields .find(|f| f.is_primary == true)
.iter() .expect("Expected at least one primary key on the model.")
.find(|f| f.is_primary)
.expect("Expected at least one primary key on the model."),
); );
let get_many_by_id_method_code = gen_get_many_by_field_method( let get_many_by_id_method_code = gen_get_many_by_field_method(
model, model,
model model.fields.iter()
.fields .find(|f| f.is_primary == true)
.iter() .expect("Expected at least one primary key on the model.")
.find(|f| f.is_primary)
.expect("Expected at least one primary key on the model."),
); );
let insert_method_code = gen_insert_method(model); let insert_method_code = gen_insert_method(model);
let insert_many_method_code = gen_insert_many_method(model); let insert_many_method_code = gen_insert_many_method(model);
@ -341,31 +306,34 @@ pub fn generate_repository_file(
let delete_by_id_method_code = gen_delete_by_id_method(model); let delete_by_id_method_code = gen_delete_by_id_method(model);
let delete_many_by_id_method_code = gen_delete_many_by_id_method(model); let delete_many_by_id_method_code = gen_delete_many_by_id_method(model);
let query_by_field_methods: Vec<TokenStream> = model
.fields let query_by_field_methods: Vec<TokenStream> =
.iter() model.fields.iter()
.filter(|f| f.is_query_entrypoint) .filter(|f| f.is_query_entrypoint)
.map(|field| gen_get_by_field_method(model, field)) .map(|field|
gen_get_by_field_method(
model,
&field
)
)
.collect(); .collect();
let query_many_by_field_methods: Vec<TokenStream> = model let query_many_by_field_methods: Vec<TokenStream> =
.fields model.fields.iter()
.iter()
.filter(|f| f.is_query_entrypoint) .filter(|f| f.is_query_entrypoint)
.map(|field| gen_get_many_by_field_method(model, field)) .map(|field|
gen_get_many_by_field_method(
model,
&field
)
)
.collect(); .collect();
let fields_with_foreign_refs: Vec<&Field> = model let fields_with_foreign_refs: Vec<&Field> = model.fields.iter().filter(|f|
.fields match f.foreign_mode { FieldForeignMode::ForeignRef(_) => true, FieldForeignMode::NotRef => false }
.iter() ).collect();
.filter(|f| match f.foreign_mode { let related_entity_methods_codes: Vec<TokenStream> = fields_with_foreign_refs.iter().map(|field|
FieldForeignMode::ForeignRef(_) => true, gen_get_many_of_related_entity_method(model, &field)
FieldForeignMode::NotRef => false, ).collect();
})
.collect();
let related_entity_methods_codes: Vec<TokenStream> = fields_with_foreign_refs
.iter()
.map(|field| gen_get_many_of_related_entity_method(model, field))
.collect();
// TODO: add import line // TODO: add import line
let base_repository_code: TokenStream = quote! { let base_repository_code: TokenStream = quote! {
@ -412,6 +380,6 @@ pub fn generate_repository_file(
Ok(SourceNodeContainer { Ok(SourceNodeContainer {
name: format!("{}_repository.rs", model.name.to_snake_case()), name: format!("{}_repository.rs", model.name.to_snake_case()),
inner: SourceNode::File(pretty), inner: SourceNode::File(pretty)
}) })
} }

View file

@ -21,10 +21,11 @@ pub fn generate_repositories_source_files(models: &[Model]) -> Result<SourceNode
} }
nodes.push(SourceNodeContainer { nodes.push(SourceNodeContainer {
name: "mod.rs".into(), name: "mod.rs".into(),
inner: SourceNode::File(mod_index_code.to_string()), inner: SourceNode::File(mod_index_code.to_string())
}); });
Ok(SourceNodeContainer { Ok(SourceNodeContainer {
name: "".into(), name: "".into(),
inner: SourceNode::Directory(nodes), inner: SourceNode::Directory(nodes)
}) })
} }

View file

@ -5,10 +5,7 @@ use crate::models::{Field, FieldForeignMode, Model};
/// method that can be used to retreive a list of entities of type X that are the children of a parent type Y /// method that can be used to retreive a list of entities of type X that are the children of a parent type Y
/// ex: get all comments of a post /// ex: get all comments of a post
pub fn gen_get_many_of_related_entity_method( pub fn gen_get_many_of_related_entity_method(model: &Model, foreign_key_field: &Field) -> TokenStream {
model: &Model,
foreign_key_field: &Field,
) -> TokenStream {
let resource_ident = format_ident!("{}", &model.name); let resource_ident = format_ident!("{}", &model.name);
let foreign_ref_params = match &foreign_key_field.foreign_mode { let foreign_ref_params = match &foreign_key_field.foreign_mode {
@ -18,10 +15,7 @@ pub fn gen_get_many_of_related_entity_method(
} }
}; };
let select_query = format!( let select_query = format!("SELECT * FROM {} WHERE {} = $1", model.table_name, foreign_key_field.name);
"SELECT * FROM {} WHERE {} = $1",
model.table_name, foreign_key_field.name
);
let func_name_ident = format_ident!("get_many_of_{}", foreign_ref_params.target_resource_name); let func_name_ident = format_ident!("get_many_of_{}", foreign_ref_params.target_resource_name);
@ -34,3 +28,4 @@ pub fn gen_get_many_of_related_entity_method(
} }
} }
} }

View file

@ -1,19 +1,22 @@
use attribute_derive::FromAttr;
use std::{ffi::OsStr, path::Path}; use std::{ffi::OsStr, path::Path};
use attribute_derive::FromAttr;
use anyhow::{anyhow, Result};
use argh::FromArgs; use argh::FromArgs;
use anyhow::{Result, anyhow};
use crate::generators::{SourceNode, SourceNodeContainer}; use crate::generators::{SourceNode, SourceNodeContainer};
pub mod generators; // use gen_migrations::generate_create_table_sql;
// use gen_repositories::{generate_repositories_source_files, SourceNodeContainer};
pub mod models; pub mod models;
pub mod parse_models; pub mod parse_models;
pub mod generators;
#[derive(FromAttr, PartialEq, Debug, Default)] #[derive(FromAttr, PartialEq, Debug, Default)]
#[attribute(ident = sql_generator_model)] #[attribute(ident = sql_generator_model)]
pub struct SqlGeneratorModelAttr { pub struct SqlGeneratorModelAttr {
table_name: Option<String>, table_name: Option<String>
} }
#[derive(FromAttr, PartialEq, Debug, Default)] #[derive(FromAttr, PartialEq, Debug, Default)]
@ -25,16 +28,17 @@ pub struct SqlGeneratorFieldAttr {
/// to indicate that this field will be used to obtains entities /// to indicate that this field will be used to obtains entities
/// our framework will generate methods for all fields that is an entrypoint /// our framework will generate methods for all fields that is an entrypoint
is_query_entrypoint: Option<bool>, is_query_entrypoint: Option<bool>
} }
#[derive(FromArgs, PartialEq, Debug)] #[derive(FromArgs, PartialEq, Debug)]
/// Generate SQL CREATE TABLE migrations /// Generate SQL CREATE TABLE migrations
#[argh(subcommand, name = "gen-migrations")] #[argh(subcommand, name = "gen-migrations")]
struct GenerateMigration { struct GenerateMigration {
/// path of file where to write all in one generated SQL migration /// path of file where to write all in one generated SQL migration
#[argh(option, short = 'o')] #[argh(option, short = 'o')]
output: Option<String>, output: Option<String>
} }
#[derive(FromArgs, PartialEq, Debug)] #[derive(FromArgs, PartialEq, Debug)]
@ -43,7 +47,7 @@ struct GenerateMigration {
struct GenerateRepositories { struct GenerateRepositories {
/// path of the directory that contains repositories /// path of the directory that contains repositories
#[argh(option, short = 'o')] #[argh(option, short = 'o')]
output: Option<String>, output: Option<String>
} }
#[derive(FromArgs, PartialEq, Debug)] #[derive(FromArgs, PartialEq, Debug)]
@ -65,16 +69,16 @@ struct GeneratorArgs {
models_path: Option<String>, models_path: Option<String>,
#[argh(subcommand)] #[argh(subcommand)]
nested: GeneratorArgsSubCommands, nested: GeneratorArgsSubCommands
} }
fn write_source_code(base_path: &Path, snc: SourceNodeContainer) -> Result<()> { fn write_source_code(base_path: &Path, snc: SourceNodeContainer) -> Result<()> {
let path = base_path.join(snc.name); let path = base_path.join(snc.name);
match snc.inner { match snc.inner {
SourceNode::File(code) => { SourceNode::File(code) => {
println!("Writing file {:?}.", path); println!("writing file {:?}", path);
std::fs::write(path, code)?; std::fs::write(path, code)?;
} },
SourceNode::Directory(dir) => { SourceNode::Directory(dir) => {
for node in dir { for node in dir {
write_source_code(&path, node)?; write_source_code(&path, node)?;
@ -88,10 +92,7 @@ pub fn main() -> Result<()> {
let args: GeneratorArgs = argh::from_env(); let args: GeneratorArgs = argh::from_env();
let project_root = &args.project_root.unwrap_or(".".to_string()); let project_root = &args.project_root.unwrap_or(".".to_string());
let project_root_path = Path::new(&project_root); let project_root_path = Path::new(&project_root);
eprintln!( eprintln!("Using project root at: {:?}", &project_root_path.canonicalize()?);
"Using project root at: {:?}",
&project_root_path.canonicalize()?
);
if !project_root_path.exists() { if !project_root_path.exists() {
return Err(anyhow!("Could not resolve project root path.")); return Err(anyhow!("Could not resolve project root path."));
} }
@ -116,19 +117,12 @@ pub fn main() -> Result<()> {
if !models_mod_path.exists() { if !models_mod_path.exists() {
return Err(anyhow!("Could not resolve models modules.")); return Err(anyhow!("Could not resolve models modules."));
} }
if models_mod_path if models_mod_path.file_name().map(|x| x == OsStr::new("mod.rs")).unwrap_or(false) {
.file_name()
.map(|x| x == OsStr::new("mod.rs"))
.unwrap_or(false)
{
models_mod_path.pop(); models_mod_path.pop();
} }
eprintln!("Found models in project, parsing models"); eprintln!("Found models in project, parsing models");
let models = parse_models::parse_models_from_module(&models_mod_path)?; let models = parse_models::parse_models_from_module(&models_mod_path)?;
eprintln!( dbg!(&models);
"Found and parsed a grand total of {} sqlxgentools compatible models.",
models.len()
);
match args.nested { match args.nested {
GeneratorArgsSubCommands::GenerateRepositories(opts) => { GeneratorArgsSubCommands::GenerateRepositories(opts) => {
@ -140,15 +134,16 @@ pub fn main() -> Result<()> {
return Err(anyhow!("Could not resolve repositories modules.")); return Err(anyhow!("Could not resolve repositories modules."));
} }
let snc = generators::repositories::generate_repositories_source_files(&models)?; let snc = generators::repositories::generate_repositories_source_files(&models)?;
dbg!(&snc);
write_source_code(&repositories_mod_path, snc)?; write_source_code(&repositories_mod_path, snc)?;
} },
GeneratorArgsSubCommands::GenerateMigration(opts) => { GeneratorArgsSubCommands::GenerateMigration(opts) => {
eprintln!("Generating migrations…"); eprintln!("Generating migrations…");
let sql_code = generators::migrations::generate_create_table_sql(&models)?; let sql_code = generators::migrations::generate_create_table_sql(&models)?;
if let Some(out_location) = opts.output { if let Some(out_location) = opts.output {
let output_path = Path::new(&out_location); let output_path = Path::new(&out_location);
let _write_res = std::fs::write(output_path, sql_code); let write_res = std::fs::write(output_path, sql_code);
// TODO: check if write result is an error and return error message. eprintln!("{:?}", write_res);
} else { } else {
println!("{}", sql_code); println!("{}", sql_code);
} }

View file

@ -7,7 +7,7 @@ struct Model {
module_path: Vec<String>, module_path: Vec<String>,
name: String, name: String,
table_name: String, table_name: String,
fields: Vec<Field>, fields: Vec<Field>
} }
impl Model { impl Model {
@ -29,6 +29,7 @@ impl Model {
// } // }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
#[fully_pub] #[fully_pub]
struct ForeignRefParams { struct ForeignRefParams {
@ -40,11 +41,12 @@ struct ForeignRefParams {
// target_resource_name_plural: String // target_resource_name_plural: String
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
#[fully_pub] #[fully_pub]
enum FieldForeignMode { enum FieldForeignMode {
ForeignRef(ForeignRefParams), ForeignRef(ForeignRefParams),
NotRef, NotRef
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -56,5 +58,6 @@ struct Field {
is_unique: bool, is_unique: bool,
is_primary: bool, is_primary: bool,
is_query_entrypoint: bool, is_query_entrypoint: bool,
foreign_mode: FieldForeignMode, foreign_mode: FieldForeignMode
} }

View file

@ -1,14 +1,11 @@
use attribute_derive::FromAttr;
use std::{fs, path::Path}; use std::{fs, path::Path};
use attribute_derive::FromAttr;
use anyhow::{anyhow, Result}; use anyhow::{Result, anyhow};
use convert_case::{Case, Casing}; use convert_case::{Case, Casing};
use syn::{GenericArgument, Type}; use syn::{GenericArgument, PathArguments, Type};
use crate::{ use crate::{SqlGeneratorFieldAttr, SqlGeneratorModelAttr, models::{Field, FieldForeignMode, ForeignRefParams, Model}};
models::{Field, FieldForeignMode, ForeignRefParams, Model},
SqlGeneratorFieldAttr, SqlGeneratorModelAttr,
};
fn extract_generic_type(base_segments: Vec<String>, ty: &Type) -> Option<&Type> { fn extract_generic_type(base_segments: Vec<String>, ty: &Type) -> Option<&Type> {
// If it is not `TypePath`, it is not possible to be `Option<T>`, return `None` // If it is not `TypePath`, it is not possible to be `Option<T>`, return `None`
@ -55,19 +52,40 @@ fn extract_generic_type(base_segments: Vec<String>, ty: &Type) -> Option<&Type>
fn get_type_first_ident(inp: &Type) -> Option<String> { fn get_type_first_ident(inp: &Type) -> Option<String> {
match inp { match inp {
Type::Path(field_type_path) => Some( Type::Path(field_type_path) => {
field_type_path Some(field_type_path.path.segments.get(0).unwrap().ident.to_string())
.path },
.segments _ => {
.get(0) None
.unwrap() }
.ident
.to_string(),
),
_ => None,
} }
} }
fn get_first_generic_arg_type_ident(inp: &Type) -> Option<String> {
if let Type::Path(field_type_path) = inp {
if let PathArguments::AngleBracketed(args) = &field_type_path.path.segments.get(0).unwrap().arguments {
if args.args.is_empty() {
None
} else {
if let GenericArgument::Type(arg_type) = args.args.get(0).unwrap() {
if let Type::Path(arg_type_path) = arg_type {
Some(arg_type_path.path.segments.get(0).unwrap().ident.to_string())
} else {
None
}
} else {
None
}
}
} else {
None
}
} else {
None
}
}
fn parse_model_attribute(item: &syn::ItemStruct) -> Result<Option<SqlGeneratorModelAttr>> { fn parse_model_attribute(item: &syn::ItemStruct) -> Result<Option<SqlGeneratorModelAttr>> {
for attr in item.attrs.iter() { for attr in item.attrs.iter() {
let attr_ident = match attr.path().get_ident() { let attr_ident = match attr.path().get_ident() {
@ -83,12 +101,9 @@ fn parse_model_attribute(item: &syn::ItemStruct) -> Result<Option<SqlGeneratorMo
match SqlGeneratorModelAttr::from_attribute(attr) { match SqlGeneratorModelAttr::from_attribute(attr) {
Ok(v) => { Ok(v) => {
return Ok(Some(v)); return Ok(Some(v));
} },
Err(err) => { Err(err) => {
return Err(anyhow!( return Err(anyhow!("Failed to parse sql_generator_model attribute macro: {}", err));
"Failed to parse sql_generator_model attribute macro: {}",
err
));
} }
}; };
} }
@ -110,13 +125,9 @@ fn parse_field_attribute(field: &syn::Field) -> Result<Option<SqlGeneratorFieldA
match SqlGeneratorFieldAttr::from_attribute(attr) { match SqlGeneratorFieldAttr::from_attribute(attr) {
Ok(v) => { Ok(v) => {
return Ok(Some(v)); return Ok(Some(v));
} },
Err(err) => { Err(err) => {
return Err(anyhow!( return Err(anyhow!("Failed to parse sql_generator_field attribute macro on field {:?}, {}", field, err));
"Failed to parse sql_generator_field attribute macro on field {:?}, {}",
field,
err
));
} }
}; };
} }
@ -125,7 +136,10 @@ fn parse_field_attribute(field: &syn::Field) -> Result<Option<SqlGeneratorFieldA
/// Take struct name as source, apply snake case and pluralize with a s /// Take struct name as source, apply snake case and pluralize with a s
fn generate_table_name_from_struct_name(struct_name: &str) -> String { fn generate_table_name_from_struct_name(struct_name: &str) -> String {
format!("{}s", struct_name.to_case(Case::Snake)) format!(
"{}s",
struct_name.to_case(Case::Snake)
)
} }
/// Scan for models struct in a rust file and return a struct representing the model /// Scan for models struct in a rust file and return a struct representing the model
@ -136,7 +150,8 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
let mut models: Vec<Model> = vec![]; let mut models: Vec<Model> = vec![];
for item in parsed_file.items { for item in parsed_file.items {
if let syn::Item::Struct(itemval) = item { match item {
syn::Item::Struct(itemval) => {
let model_name = itemval.ident.to_string(); let model_name = itemval.ident.to_string();
let model_attrs = match parse_model_attribute(&itemval)? { let model_attrs = match parse_model_attribute(&itemval)? {
Some(v) => v, Some(v) => v,
@ -150,6 +165,7 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
for field in itemval.fields.iter() { for field in itemval.fields.iter() {
let field_name = field.ident.clone().unwrap().to_string(); let field_name = field.ident.clone().unwrap().to_string();
let field_type = field.ty.clone(); let field_type = field.ty.clone();
println!("field {} {:?}", field_name, field_type);
let mut output_field = Field { let mut output_field = Field {
name: field_name, name: field_name,
@ -158,7 +174,7 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
is_primary: false, is_primary: false,
is_unique: false, is_unique: false,
is_query_entrypoint: false, is_query_entrypoint: false,
foreign_mode: FieldForeignMode::NotRef, foreign_mode: FieldForeignMode::NotRef
}; };
let first_type: String = match get_type_first_ident(&field_type) { let first_type: String = match get_type_first_ident(&field_type) {
@ -171,12 +187,8 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
if first_type == "Option" { if first_type == "Option" {
output_field.is_nullable = true; output_field.is_nullable = true;
let inner_type = match extract_generic_type( let inner_type = match extract_generic_type(
vec![ vec!["Option".into(), "std:option:Option".into(), "core:option:Option".into()],
"Option".into(), &field_type
"std:option:Option".into(),
"core:option:Option".into(),
],
&field_type,
) { ) {
Some(v) => v, Some(v) => v,
None => { None => {
@ -191,7 +203,10 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
} }
} }
if first_type == "Vec" { if first_type == "Vec" {
let inner_type = match extract_generic_type(vec!["Vec".into()], &field_type) { let inner_type = match extract_generic_type(
vec!["Vec".into()],
&field_type
) {
Some(v) => v, Some(v) => v,
None => { None => {
return Err(anyhow!("Could not extract type from Vec")); return Err(anyhow!("Could not extract type from Vec"));
@ -206,14 +221,13 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
} }
output_field.rust_type = final_type; output_field.rust_type = final_type;
let field_attrs_opt = parse_field_attribute(field)?; let field_attrs_opt = parse_field_attribute(field)?;
if first_type == "ForeignRef" { if first_type == "ForeignRef" {
let attrs = match &field_attrs_opt { let attrs = match &field_attrs_opt {
Some(attrs) => attrs, Some(attrs) => attrs,
None => { None => {
return Err(anyhow!( return Err(anyhow!("Found a ForeignRef type but did not found attributes."))
"Found a ForeignRef type but did not found attributes."
))
} }
}; };
let rrn = match &attrs.reverse_relation_name { let rrn = match &attrs.reverse_relation_name {
@ -224,42 +238,39 @@ pub fn parse_models(source_code_path: &Path) -> Result<Vec<Model>> {
}; };
let extract_res = extract_generic_type(vec!["ForeignRef".into()], &field_type) let extract_res = extract_generic_type(vec!["ForeignRef".into()], &field_type)
.and_then(get_type_first_ident); .and_then(|t| get_type_first_ident(t));
let target_type_name = match extract_res { let target_type_name = match extract_res {
Some(v) => v, Some(v) => v,
None => { None => {
return Err(anyhow!("Could not extract inner type from ForeignRef.")); return Err(anyhow!("Could not extract inner type from ForeignRef."));
} }
}; };
output_field.foreign_mode = FieldForeignMode::ForeignRef(ForeignRefParams { output_field.foreign_mode = FieldForeignMode::ForeignRef(
ForeignRefParams {
reverse_relation_name: rrn, reverse_relation_name: rrn,
target_resource_name: target_type_name.to_case(Case::Snake), target_resource_name: target_type_name.to_case(Case::Snake)
}); }
);
} }
// parse attribute // parse attribute
if let Some(field_attr) = field_attrs_opt { if let Some(field_attr) = field_attrs_opt {
output_field.is_primary = field_attr.is_primary.unwrap_or_default(); output_field.is_primary = field_attr.is_primary.unwrap_or_default();
output_field.is_unique = field_attr.is_unique.unwrap_or_default(); output_field.is_unique = field_attr.is_unique.unwrap_or_default();
output_field.is_query_entrypoint = output_field.is_query_entrypoint = field_attr.is_query_entrypoint.unwrap_or_default();
field_attr.is_query_entrypoint.unwrap_or_default();
} }
fields.push(output_field); fields.push(output_field);
} }
models.push(Model { models.push(Model {
module_path: vec![source_code_path module_path: vec![source_code_path.file_stem().unwrap().to_str().unwrap().to_string()],
.file_stem()
.unwrap()
.to_str()
.unwrap()
.to_string()],
name: model_name.clone(), name: model_name.clone(),
table_name: model_attrs table_name: model_attrs.table_name
.table_name
.unwrap_or(generate_table_name_from_struct_name(&model_name)), .unwrap_or(generate_table_name_from_struct_name(&model_name)),
fields, fields
}) })
},
_ => {}
} }
} }
Ok(models) Ok(models)
@ -270,7 +281,7 @@ fn parse_models_from_module_inner(module_path: &Path) -> Result<Vec<Model>> {
let mut models: Vec<Model> = vec![]; let mut models: Vec<Model> = vec![];
if module_path.is_file() { if module_path.is_file() {
println!("Looking for models to parse from path {:?}.", module_path); println!("Parsing models from path {:?}.", module_path);
models.extend(parse_models(module_path)?); models.extend(parse_models(module_path)?);
return Ok(models); return Ok(models);
} }
@ -285,6 +296,23 @@ fn parse_models_from_module_inner(module_path: &Path) -> Result<Vec<Model>> {
Ok(models) Ok(models)
} }
// fn complete_models(original_models: Vec<Model>) -> Result<Vec<Model>> {
// let mut new_models: Vec<Model> = vec![];
// for model in original_models {
// for original_field in model.fields {
// let mut field = original_field
// match original_field.foreign_mode {
// FieldForeignMode::NotRef => {},
// FieldForeignMode::ForeignRef(ref_params) => {
// }
// }
// }
// }
// Ok(new_models)
// }
/// Scan for models struct in a rust file and return a struct representing the model /// Scan for models struct in a rust file and return a struct representing the model
pub fn parse_models_from_module(module_path: &Path) -> Result<Vec<Model>> { pub fn parse_models_from_module(module_path: &Path) -> Result<Vec<Model>> {
let models = parse_models_from_module_inner(module_path)?; let models = parse_models_from_module_inner(module_path)?;

View file

@ -1,6 +1,6 @@
[package] [package]
name = "sqlxgentools_misc" name = "sqlxgentools_misc"
description = "Various data types and traits to use in a sqlxgentools-enabled codebase." description = "Various misc class to use in applications that use sqlxgentools"
publish = true publish = true
edition.workspace = true edition.workspace = true
authors.workspace = true authors.workspace = true

View file

@ -12,20 +12,23 @@ use sqlx_core::error::BoxDynError;
use sqlx_core::types::Type; use sqlx_core::types::Type;
use sqlx_sqlite::{Sqlite, SqliteArgumentValue}; use sqlx_sqlite::{Sqlite, SqliteArgumentValue};
#[fully_pub] #[fully_pub]
trait DatabaseLine { trait DatabaseLine {
fn id(&self) -> String; fn id(&self) -> String;
} }
/// Wrapper to mark a model field as foreign /// Wrapper to mark a model field as foreign
/// You can use a generic argument inside ForeignRef to point to the target model /// You can use a generic argument inside ForeignRef to point to the target model
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
#[fully_pub] #[fully_pub]
struct ForeignRef<T: Sized + DatabaseLine> { struct ForeignRef<T: Sized + DatabaseLine> {
pub target_type: PhantomData<T>, pub target_type: PhantomData<T>,
pub target_id: String, pub target_id: String
} }
// Implement serde Serialize for ForeignRef // Implement serde Serialize for ForeignRef
impl<T: Sized + DatabaseLine> Serialize for ForeignRef<T> { impl<T: Sized + DatabaseLine> Serialize for ForeignRef<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@ -37,20 +40,22 @@ impl<T: Sized + DatabaseLine> Serialize for ForeignRef<T> {
} }
} }
impl<T: Sized + DatabaseLine> ForeignRef<T> { impl<T: Sized + DatabaseLine> ForeignRef<T> {
pub fn new(entity: &T) -> ForeignRef<T> { pub fn new(entity: &T) -> ForeignRef<T> {
ForeignRef { ForeignRef {
target_type: PhantomData, target_type: PhantomData,
target_id: entity.id(), target_id: entity.id()
} }
} }
} }
impl<'r, DB: Database, T: Sized + DatabaseLine> Decode<'r, DB> for ForeignRef<T> impl<'r, DB: Database, T: Sized + DatabaseLine> Decode<'r, DB> for ForeignRef<T>
where where
// we want to delegate some of the work to string decoding so let's make sure strings // we want to delegate some of the work to string decoding so let's make sure strings
// are supported by the database // are supported by the database
&'r str: Decode<'r, DB>, &'r str: Decode<'r, DB>
{ {
fn decode( fn decode(
value: <DB as Database>::ValueRef<'r>, value: <DB as Database>::ValueRef<'r>,
@ -61,7 +66,7 @@ where
Ok(ForeignRef::<T> { Ok(ForeignRef::<T> {
target_type: PhantomData, target_type: PhantomData,
target_id: ref_val, target_id: ref_val
}) })
} }
} }
@ -79,11 +84,9 @@ impl<T: DatabaseLine + Sized> Type<Sqlite> for ForeignRef<T> {
} }
impl<T: DatabaseLine + Sized> Encode<'_, Sqlite> for ForeignRef<T> { impl<T: DatabaseLine + Sized> Encode<'_, Sqlite> for ForeignRef<T> {
fn encode_by_ref( fn encode_by_ref(&self, args: &mut Vec<SqliteArgumentValue<'_>>) -> Result<IsNull, BoxDynError> {
&self,
args: &mut Vec<SqliteArgumentValue<'_>>,
) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Text(self.target_id.clone().into())); args.push(SqliteArgumentValue::Text(self.target_id.clone().into()));
Ok(IsNull::No) Ok(IsNull::No)
} }
} }