diff --git a/.gitignore b/.gitignore index ea8c4bf..4a65cde 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,5 @@ +# IDEs +.idea + +# Rust /target diff --git a/Cargo.lock b/Cargo.lock index 883719a..4b2d22c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -111,7 +111,7 @@ version = "3.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c11d40217d16aee8508cc8e5fde8b4ff24639758608e5374e731b53f85749fb9" dependencies = [ - "heck", + "heck 0.4.0", "proc-macro-error", "proc-macro2", "quote", @@ -288,6 +288,15 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +[[package]] +name = "heck" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "heck" version = "0.4.0" @@ -338,10 +347,11 @@ dependencies = [ "anyhow", "chrono", "clap", - "heck", + "heck 0.4.0", "indexmap", "postgres", "postgres-types", + "sea-query", "tempfile", "thiserror", "time 0.3.9", @@ -662,6 +672,42 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sea-query" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a1de9d0334895f7eb51bd1603c3c9f6737413f905a134001f1e198f43bfd70" +dependencies = [ + "chrono", + "postgres-types", + "sea-query-derive", + "sea-query-driver", +] + +[[package]] +name = "sea-query-derive" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34cdc022b4f606353fe5dc85b09713a04e433323b70163e81513b141c6ae6eb5" +dependencies = [ + "heck 0.3.3", + "proc-macro2", + "quote", + "syn", + "thiserror", +] + +[[package]] +name = "sea-query-driver" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d7f0cae2e7ebb2affc378c40bc343c8197181d601d6755c3e66f1bd18cac253" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "sha2" version = "0.10.2" @@ -918,6 +964,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-segmentation" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e8820f5d777f6224dc4be3632222971ac30164d4a258d595640799554ebfd99" + [[package]] name = "version_check" version = "0.9.4" diff --git a/Cargo.toml b/Cargo.toml index 1ee2c96..8f183b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ heck = "0.4.0" indexmap = { version = "1.8.2" } postgres = { version = "0.19.3", optional = true } postgres-types = { version = "0.2.3", features = ["derive"] } +sea-query = { version = "0.26.3", default-features = false, features = ["backend-postgres", "derive", "with-chrono"], optional = true } thiserror = "1.0.31" time = { version = "0.3.9", features = ["parsing"] } @@ -23,4 +24,6 @@ postgres = { version = "0.19.3", features = ["with-chrono-0_4", ] } tempfile = "^3.3.0" [features] -default = ["postgres", ] +default = ["postgres", "sql"] +# Enables SQL query builder functionality. +sql = ["dep:sea-query"] diff --git a/README.md b/README.md index c963ecd..5395271 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,13 @@ # instant-models -Run tests: +## Generate Rust code with the CLI ```shell -cargo test -- --nocapture +cargo run --bin cli --features="postgres clap sql" -- -t "accounts" > accounts.rs ``` -## Generate Rust code with the cli - - ```shell -cargo run --bin cli --features="postgres clap" -- -t "accounts" > accounts.rs -``` - - -```shell -$ cargo run --bin cli --features="postgres clap" -- --help +$ cargo run --bin cli --features="postgres clap sql" -- --help instant-models 0.1.0 Generate Rust code from postgres table @@ -32,3 +24,79 @@ OPTIONS: -t, --table-name Name of the table to generate -V, --version Print version information ``` + +## Query Builder + +When the `sql` feature is enabled, generated tables include code to compose simple SQL queries. + +E.g. Consider the following generated table structure. + +```rust +pub struct Accounts { + pub user_id: i32, + pub username: String, + pub password: String, + pub email: String, + pub created_on: chrono::naive::NaiveDateTime, + pub last_login: Option, +} +``` + +We can construct a query to retrieve the username and email address of all users: +```rust +// SELECT "username", "email" FROM "accounts" +let select: String = Accounts::query() + .select(|a| [a.username, a.email]) + .to_string(); +``` + +Conditionals can be specified using `filter` and combined using bitwise operators: +```rust +// SELECT "username", "email" FROM "accounts" +// WHERE "last_login" IS NOT NULL AND ("user_id" = 1 OR "username" != "admin") +let select: String = Accounts::query() + .select(|a| [a.username, a.email]) + .filter(|a| Sql::is_not_null(a.last_login) & (Sql::eq(a.user_id, 1) | Sql::ne(a.username, "admin"))) + .to_string(); +``` + +### Fetching Queries + +With the `postgres` feature enabled, the query can be excuted directly using the [postgres](https://crates.io/crates/postgres) crate: + +```rust +use postgres::{Config, NoTls, Row}; + +// Connect to a Postgres database. +let client = &mut Config::new() + .user("postgres") + .password("postgres") + .host("127.0.0.1") + .port(5432) + .dbname("postgres") + .connect(NoTls) + .unwrap(); + +// Fetch all rows. +let rows: Vec = Accounts::query() + .select(|a| [a.username, a.email]) + .fetch(client, &[]) + .unwrap(); +``` + +## Development + + +### Tests + +Start a local, ephemeral Postgres instance: + +```shell +docker run -it -p 127.0.0.1:5432:5432 --rm -e POSTGRES_PASSWORD=postgres postgres +``` + +Run all tests: + +```shell +cargo test -- --nocapture +``` diff --git a/src/bin/cli.rs b/src/bin/cli.rs index cdf5ff5..6cd203e 100644 --- a/src/bin/cli.rs +++ b/src/bin/cli.rs @@ -45,4 +45,6 @@ fn main() { println!("{}", struct_bldr.build_type()); println!("\n{}", struct_bldr.build_new_type()); println!("\n{}", struct_bldr.build_type_methods()); + #[cfg(feature = "sql")] + println!("\n{}", struct_bldr.build_field_identifiers()); } diff --git a/src/lib.rs b/src/lib.rs index 5d6d554..ecdb6cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,11 @@ -mod struct_builder; +pub use column::*; +#[cfg(feature = "sql")] +pub use sql::*; pub use struct_builder::*; +pub use types::*; mod column; -pub use column::*; - +#[cfg(feature = "sql")] +mod sql; +mod struct_builder; mod types; -pub use types::*; diff --git a/src/sql/field.rs b/src/sql/field.rs new file mode 100644 index 0000000..c6a4c59 --- /dev/null +++ b/src/sql/field.rs @@ -0,0 +1,101 @@ +use std::borrow::Cow; +use std::marker::PhantomData; + +use crate::Table; + +/// SQL column definition with the Rust type. +// TODO: add table type back-reference? +pub struct Field { + pub name: &'static str, + // TODO: replace sea_query. + pub iden: Table::IdenType, + pub typ: PhantomData, + pub table: PhantomData, +} + +impl Field { + pub const fn new(name: &'static str, iden: Table::IdenType) -> Self { + Self { + name, + iden, + typ: PhantomData::, + table: PhantomData::
, + } + } + + pub fn table() -> Table::IdenType { + Table::table() + } +} + +// TODO: replace sea_query. +impl sea_query::IntoIden for Field { + fn into_iden(self) -> sea_query::DynIden { + self.iden.into_iden() + } +} + +/// Helper trait for converting tuples of fields into an iterator. +pub trait FieldList { + type IntoIter: Iterator; + + fn into_iter(self) -> Self::IntoIter; +} + +macro_rules! impl_field_list { + ( $( $name:ident.$idx:tt )+ ) => { + impl<$($name),+> FieldList for ($($name,)+) + where $($name: sea_query::IntoIden,)+ + { + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + vec![$(self.$idx.into_iden(),)+].into_iter() + } + } + }; +} + +// If you need to select more than 12 fields in a single query, open an issue. +impl_field_list!(A.0); +impl_field_list!(A.0 B.1); +impl_field_list!(A.0 B.1 C.2); +impl_field_list!(A.0 B.1 C.2 D.3); +impl_field_list!(A.0 B.1 C.2 D.3 E.4); +impl_field_list!(A.0 B.1 C.2 D.3 E.4 F.5); +impl_field_list!(A.0 B.1 C.2 D.3 E.4 F.5 G.6); +impl_field_list!(A.0 B.1 C.2 D.3 E.4 F.5 G.6 H.7); +impl_field_list!(A.0 B.1 C.2 D.3 E.4 F.5 G.6 H.7 I.8); +impl_field_list!(A.0 B.1 C.2 D.3 E.4 F.5 G.6 H.7 I.8 J.9); +impl_field_list!(A.0 B.1 C.2 D.3 E.4 F.5 G.6 H.7 I.8 J.9 K.10); +impl_field_list!(A.0 B.1 C.2 D.3 E.4 F.5 G.6 H.7 I.8 J.9 K.10 L.11); + +/// Marker trait to indicate which types and fields can be compared. +pub trait Compatible {} + +impl Compatible> for Field {} + +impl Compatible> for Field, T2> {} + +impl Compatible, T1>> for Field {} + +impl Compatible> for Type {} + +impl Compatible> for Option {} + +impl Compatible, T>> for Type {} + +macro_rules! impl_compatible { + ( $t:ty | $( $s:ty ),+ ) => ($( + impl Compatible> for $s {} + impl Compatible> for Option<$s> {} + impl Compatible, T>> for $s {} + impl Compatible, T>> for Option<$s> {} + + impl Compatible> for Field<$s, T2> {} + impl Compatible> for Field, T2> {} + impl Compatible, T1>> for Field<$s, T2> {} + )*) +} + +impl_compatible!(String | &str, Cow<'static, str>); diff --git a/src/sql/mod.rs b/src/sql/mod.rs new file mode 100644 index 0000000..6a125f3 --- /dev/null +++ b/src/sql/mod.rs @@ -0,0 +1,7 @@ +mod field; +mod query; +mod table; + +pub use field::*; +pub use query::*; +pub use table::*; diff --git a/src/sql/query.rs b/src/sql/query.rs new file mode 100644 index 0000000..7c039f3 --- /dev/null +++ b/src/sql/query.rs @@ -0,0 +1,320 @@ +use crate::{Combine, Compatible, Field, FieldList, Sources, Table}; +use sea_query::{BinOper, IntoIden}; +use std::fmt::{Display, Formatter}; +use std::marker::PhantomData; + +#[derive(Default)] +pub struct SqlQuery { + sources: PhantomData, + // TODO: replace sea_query. + query: sea_query::SelectStatement, +} + +// TODO: replace sea_query. +impl SqlQuery { + pub fn new() -> SqlQuery { + let mut query = sea_query::SelectStatement::new(); + for table in T::tables() { + query.from(table); + } + Self { + query, + sources: PhantomData::, + } + } + + pub fn select(mut self, columns: F) -> Self + where + F: FnOnce(T::SOURCES) -> I, + I: FieldList, + { + self.query.columns(columns(T::sources()).into_iter()); + self + } + + pub fn filter(mut self, conditions: F) -> Self + where + F: FnOnce(T::SOURCES) -> Sql, + { + self.query.cond_where(conditions(T::sources()).cond); + self + } + + pub fn limit(mut self, limit: u64) -> Self { + self.query.limit(limit); + self + } + + pub fn from(mut self) -> SqlQuery + where + T: Combine, + { + use sea_query::IntoTableRef; + self.query.from(O::table().into_table_ref()); + SqlQuery { + sources: PhantomData::, + query: self.query, + } + } + + pub fn join(mut self, on: F) -> SqlQuery + where + // TODO: restrict join to only tables with foreign keys. + T: Combine, + T::COMBINED: Sources, + F: FnOnce(::SOURCES) -> Sql, + { + // TODO: join on foreign keys automatically, or add them to a list and handle them later. + let on_condition: Sql = on(::sources()); + self.query + .join(sea_query::JoinType::Join, O::table(), on_condition.cond); + SqlQuery { + sources: PhantomData::, + query: self.query, + } + } +} + +impl Display for SqlQuery { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + self.query.to_string(sea_query::PostgresQueryBuilder) + ) + } +} + +#[cfg(feature = "postgres")] +impl SqlQuery { + /// Executes a query, returning the resulting rows. + pub fn fetch( + self, + client: &mut postgres::Client, + params: &[&(dyn postgres_types::ToSql + Sync)], + ) -> Result, postgres::Error> { + client.query(&self.to_string(), params) + } +} + +/// SQL condition for e.g. WHERE, ON, HAVING clauses. +/// +/// Can be composed using bitwise operators `&` for AND, `|` for OR. +pub struct Sql { + // TODO: replace sea_query. + cond: sea_query::Cond, +} + +impl Sql { + pub fn eq(left: Left, right: Right) -> Self + where + Left: Into, + Right: Compatible + IntoValueOrFieldRef, + { + let left_col = left.into().into_tbl_expr(); + let condition: sea_query::SimpleExpr = match right.into_value_or_field_ref() { + ValueOrFieldRef::Value(value) => left_col.eq(value), + ValueOrFieldRef::FieldRef(right_col) => { + left_col.equals(right_col.table, right_col.column) + } + }; + Self { + cond: sea_query::Cond::all().add(condition), + } + } + + pub fn ne(left: Left, right: Right) -> Self + where + Left: Into, + Right: Compatible + IntoValueOrFieldRef, + { + let left_col = left.into().into_tbl_expr(); + let condition: sea_query::SimpleExpr = match right.into_value_or_field_ref() { + ValueOrFieldRef::Value(value) => left_col.ne(value), + ValueOrFieldRef::FieldRef(right_col) => { + left_col.binary(BinOper::NotEqual, right_col.into_tbl_expr()) + } + }; + Self { + cond: sea_query::Cond::all().add(condition), + } + } + + pub fn gt(left: Left, right: Right) -> Self + where + Left: Into, + Right: Compatible + IntoValueOrFieldRef, + { + let left_col = left.into().into_tbl_expr(); + let condition: sea_query::SimpleExpr = match right.into_value_or_field_ref() { + ValueOrFieldRef::Value(value) => left_col.gt(value), + ValueOrFieldRef::FieldRef(right_col) => { + left_col.greater_than(right_col.into_tbl_expr()) + } + }; + Self { + cond: sea_query::Cond::all().add(condition), + } + } + + pub fn gte(left: Left, right: Right) -> Self + where + Left: Into, + Right: Compatible + IntoValueOrFieldRef, + { + let left_col = left.into().into_tbl_expr(); + let condition: sea_query::SimpleExpr = match right.into_value_or_field_ref() { + ValueOrFieldRef::Value(value) => left_col.gte(value), + ValueOrFieldRef::FieldRef(right_col) => { + left_col.greater_or_equal(right_col.into_tbl_expr()) + } + }; + Self { + cond: sea_query::Cond::all().add(condition), + } + } + + pub fn lt(left: Left, right: Right) -> Self + where + Left: Into, + Right: Compatible + IntoValueOrFieldRef, + { + let left_col = left.into().into_tbl_expr(); + let condition: sea_query::SimpleExpr = match right.into_value_or_field_ref() { + ValueOrFieldRef::Value(value) => left_col.lt(value), + ValueOrFieldRef::FieldRef(right_col) => left_col.less_than(right_col.into_tbl_expr()), + }; + Self { + cond: sea_query::Cond::all().add(condition), + } + } + + pub fn lte(left: Left, right: Right) -> Self + where + Left: Into, + Right: Compatible + IntoValueOrFieldRef, + { + let left_col = left.into().into_tbl_expr(); + let condition: sea_query::SimpleExpr = match right.into_value_or_field_ref() { + ValueOrFieldRef::Value(value) => left_col.lte(value), + ValueOrFieldRef::FieldRef(right_col) => { + left_col.less_or_equal(right_col.into_tbl_expr()) + } + }; + Self { + cond: sea_query::Cond::all().add(condition), + } + } + + pub fn is(left: Left, right: Right) -> Self + where + Left: Into, + Right: Compatible + Into, + { + let left_col = left.into().into_tbl_expr(); + Self { + cond: sea_query::Cond::all().add(left_col.is(right.into())), + } + } + + pub fn is_not(left: Left, right: Right) -> Self + where + Left: Into, + Right: Compatible + Into, + { + let left_col = left.into().into_tbl_expr(); + Self { + cond: sea_query::Cond::all().add(left_col.is_not(right.into())), + } + } + + pub fn is_null(col: T) -> Self + where + T: Into, + { + let column = col.into().into_tbl_expr(); + Self { + cond: sea_query::Cond::all().add(column.is_null()), + } + } + + pub fn is_not_null(col: T) -> Self + where + T: Into, + { + let column = col.into().into_tbl_expr(); + Self { + cond: sea_query::Cond::all().add(column.is_not_null()), + } + } + + // TODO: port rest of conditions. +} + +impl std::ops::BitAnd for Sql { + type Output = Sql; + + fn bitand(self, rhs: Self) -> Self::Output { + Sql { + cond: sea_query::Cond::all().add(self.cond).add(rhs.cond), + } + } +} + +impl std::ops::BitOr for Sql { + type Output = Sql; + + fn bitor(self, rhs: Self) -> Self::Output { + Sql { + cond: sea_query::Cond::any().add(self.cond).add(rhs.cond), + } + } +} + +/// Field reference with explicit table and column identifiers. +pub struct FieldRef { + table: sea_query::DynIden, + column: sea_query::DynIden, +} + +impl FieldRef { + pub fn into_tbl_expr(self) -> sea_query::Expr { + sea_query::Expr::tbl(self.table, self.column) + } +} + +impl From> for FieldRef { + fn from(field: Field) -> FieldRef { + FieldRef { + table: Table::table().into_iden(), + column: field.iden.into_iden(), + } + } +} + +pub enum ValueOrFieldRef { + Value(sea_query::Value), + FieldRef(FieldRef), +} + +pub trait IntoValueOrFieldRef { + fn into_value_or_field_ref(self) -> ValueOrFieldRef; +} + +impl IntoValueOrFieldRef for Field { + fn into_value_or_field_ref(self) -> ValueOrFieldRef { + ValueOrFieldRef::FieldRef(FieldRef { + table: Table::table().into_iden(), + column: self.iden.into_iden(), + }) + } +} + +impl IntoValueOrFieldRef for V +where + V: Into, +{ + fn into_value_or_field_ref(self) -> ValueOrFieldRef { + ValueOrFieldRef::Value(self.into()) + } +} diff --git a/src/sql/table.rs b/src/sql/table.rs new file mode 100644 index 0000000..91fdd9f --- /dev/null +++ b/src/sql/table.rs @@ -0,0 +1,100 @@ +use crate::SqlQuery; + +/// A SQL table with field identifiers. +pub trait Table { + type IdenType: sea_query::Iden; + type FieldsType; + const FIELDS: Self::FieldsType; + + fn query() -> SqlQuery + where + Self: 'static, + { + SqlQuery::new() + } + + /// Returns the table identifier. + // TODO: replace sea_query. + fn table() -> Self::IdenType; + + /// Returns a struct with the SQL column identifiers. + fn fields() -> Self::FieldsType { + Self::FIELDS + } +} + +/// Represents one or more SQL [Tables](Table). +pub trait Sources { + type SOURCES; + + /// Table field definitions. Returns a tuple if more than one table is referenced. + fn sources() -> Self::SOURCES; + + /// List of all table identifiers. + fn tables() -> Vec; +} + +impl Sources for T { + type SOURCES = T::FieldsType; + + fn sources() -> Self::SOURCES { + T::fields() + } + + fn tables() -> Vec { + use sea_query::IntoTableRef; + vec![T::table().into_table_ref()] + } +} + +/// Helper trait for combining tuples of SQL tables. +/// +/// E.g. +/// - A + B => (A,B). +/// - (A,B) + C => (A,B,C). +pub trait Combine { + type COMBINED; +} + +impl Combine for A { + type COMBINED = (A, B); +} + +/// Implement [Sources] for tuples of [Tables](Table). E.g. (A,B), (A,B,C), etc. +/// +/// Pass `true` as the last parameter if the tuple can be combined, omit it if not (maximum size). +macro_rules! impl_sources_tuple { + ( $( $name:ident )+ ) => { + impl<$($name: Table + 'static),+> Sources for ($($name,)+) + { + type SOURCES = ($($name::FieldsType,)+); + + fn sources() -> Self::SOURCES { + ($($name::fields(),)+) + } + + fn tables() -> Vec { + use sea_query::IntoTableRef; + vec![$($name::table().into_table_ref(),)+] + } + } + }; + ( $( $name:ident )+, $joinable:expr ) => { + impl_sources_tuple!($($name)+); + + impl Combine for ($($name,)+) { + type COMBINED = ($($name,)+ Z); + } + }; +} + +// If you need to join more than 10 tables in a single query, open an issue. +impl_sources_tuple! { A B, true } +impl_sources_tuple! { A B C, true } +impl_sources_tuple! { A B C D, true } +impl_sources_tuple! { A B C D E, true } +impl_sources_tuple! { A B C D E F, true } +impl_sources_tuple! { A B C D E F G, true } +impl_sources_tuple! { A B C D E F G H, true } +impl_sources_tuple! { A B C D E F G H I, true } +impl_sources_tuple! { A B C D E F G H I J } diff --git a/src/struct_builder.rs b/src/struct_builder.rs index 0648b07..5cfad5c 100644 --- a/src/struct_builder.rs +++ b/src/struct_builder.rs @@ -1,9 +1,11 @@ -use crate::{Column, Constraint, NewValue, Type}; -use heck::{AsSnakeCase, AsUpperCamelCase}; -use indexmap::IndexMap; use std::borrow::Cow; use std::str::FromStr; +use heck::{AsSnakeCase, AsUpperCamelCase}; +use indexmap::IndexMap; + +use crate::{Column, Constraint, NewValue, Type}; + #[derive(Debug, PartialEq)] pub struct StructBuilder { pub name: Cow<'static, str>, @@ -29,6 +31,43 @@ impl StructBuilder { } } + #[cfg(feature = "postgres")] + pub fn new_from_conn( + client: &mut postgres::Client, + table_name: &str, + ) -> Result { + let mut struct_bldr = Self::new(table_name.to_string().into()); + let mut col_index: IndexMap = IndexMap::new(); + for row in client.query("SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_name = $1;", &[&table_name])? { + let column_name: &str = row.get(0); + let is_nullable: &str = row.get(1); + let data_type: &str = row.get(2); + let col = Column::new(column_name.to_string().into(), Type::from_str(data_type)?).set_null(is_nullable == "YES"); + col_index.insert(column_name.to_string(), col); + } + + for row in client.query("SELECT a.column_name, a.constraint_name, b.constraint_type FROM information_schema.constraint_column_usage AS a JOIN information_schema.table_constraints AS b ON a.constraint_name = b.constraint_name WHERE a.table_name = $1", &[&table_name])? { + let column_name: &str = row.get(0); + let constraint_name: &str = row.get(1); + let constraint_type: &str = row.get(2); + if let Some(col) = col_index.get_mut(&column_name.to_string()) { + match constraint_type { + "UNIQUE" => { col.unique = true; } + "PRIMARY KEY" => { col.primary_key = true; } + other => panic!("unknown constraint type: {}", other), + } + } else { + panic!("got constraint for unknown column: column_name {column_name}, constraint_name {constraint_name} constraint_type {constraint_type}"); + } + } + + for (_, col) in col_index.into_iter() { + struct_bldr.add_column(col); + } + + Ok(struct_bldr) + } + pub fn add_column(&mut self, val: Column) -> &mut Self { self.columns.insert(val.name.clone(), val); self @@ -48,7 +87,7 @@ impl StructBuilder { " pub {},", NewValue { val: col, - lifetime: Some("a") + lifetime: Some("a"), } )); acc.push('\n'); @@ -57,8 +96,7 @@ impl StructBuilder { format!( r#"pub struct {}New<'a> {{ -{}}} - "#, +{}}}"#, AsUpperCamelCase(&self.name), columns ) @@ -110,8 +148,7 @@ impl StructBuilder { }} Ok(()) }} -}} - "#, +}}"#, AsUpperCamelCase(&self.name), ) } @@ -164,41 +201,107 @@ impl StructBuilder { } */ - #[cfg(feature = "postgres")] - pub fn new_from_conn( - client: &mut postgres::Client, - table_name: &str, - ) -> Result { - let mut struct_bldr = Self::new(table_name.to_string().into()); - let mut col_index: IndexMap = IndexMap::new(); - for row in client.query("SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_name = $1;", &[&table_name])? { - let column_name: &str = row.get(0); - let is_nullable: &str = row.get(1); - let data_type: &str = row.get(2); - let col = Column::new(column_name.to_string().into(), Type::from_str(data_type)?).set_null(is_nullable == "YES"); - col_index.insert(column_name.to_string(), col); - } + /// Generates a helper enum and struct to allow accessing field identifiers when building + /// SQL queries. + #[cfg(feature = "sql")] + pub fn build_field_identifiers(&self) -> String { + let mut output: String = String::new(); - for row in client.query("SELECT a.column_name, a.constraint_name, b.constraint_type FROM information_schema.constraint_column_usage AS a JOIN information_schema.table_constraints AS b ON a.constraint_name = b.constraint_name WHERE a.table_name = $1", &[&table_name])? { - let column_name: &str = row.get(0); - let constraint_name: &str = row.get(1); - let constraint_type: &str = row.get(2); - if let Some(col) = col_index.get_mut(&column_name.to_string()) { - match constraint_type { - "UNIQUE" => {col.unique = true;}, - "PRIMARY KEY" => {col.primary_key = true;}, - other => panic!("unknown constraint type: {}", other), + // Generate enum with sea_query field identifiers: `Iden`. + // TODO: use a proc-macro to derive this instead? + // TODO: replace sea_query. + let struct_name = format!("{}", AsUpperCamelCase(&self.name)); + let enum_name = format!("{}Iden", struct_name); + let column_names = self.columns.values().fold(String::new(), |mut acc, col| { + acc.push_str(&format!(" {},\n", AsUpperCamelCase(&col.name))); + acc + }); + let match_iden_columns = self.columns.values().fold(String::new(), |mut acc, col| { + acc.push_str(&format!( + " Self::{} => \"{}\",\n", + AsUpperCamelCase(&col.name), + &col.name + )); + acc + }); + output.push_str(&format!( + r#" +#[derive(Copy, Clone)] +pub enum {} {{ + Table, +{}}} + +impl sea_query::Iden for {} {{ + fn unquoted(&self, s: &mut dyn std::fmt::Write) {{ + write!( + s, + "{{}}", + match self {{ + Self::Table => "{}", +{} }}).expect("{} failed to write"); + }} +}} +"#, + enum_name, column_names, enum_name, &self.name, match_iden_columns, enum_name, + )); + + // Generate fields struct: `Fields`. + // TODO: derive with proc-macro instead? + let fields_name = format!("{}Fields", AsUpperCamelCase(&self.name)); + let fields = self.columns.values().fold(String::new(), |mut acc, col| { + if col.null { + acc.push_str(&format!( + " pub {}: ::instant_models::Field, {}>,\n", + AsSnakeCase(&col.name), + col.r#type, + struct_name + )); + } else { + acc.push_str(&format!( + " pub {}: ::instant_models::Field<{}, {}>,\n", + AsSnakeCase(&col.name), + col.r#type, + struct_name + )); } - } else { - panic!("got constraint for unknown column: column_name {column_name}, constraint_name {constraint_name} constraint_type {constraint_type}"); - } - } + acc + }); + output.push_str(&format!( + r#" +pub struct {} {{ +{}}} +"#, + fields_name, fields, + )); - for (_, col) in col_index.into_iter() { - struct_bldr.add_column(col); - } + // Implement Table for the struct. + // TODO: derive with proc-macro instead? + let fields_instance = self.columns.values().fold(String::new(), |mut acc, col| { + acc.push_str(&format!( + " {}: ::instant_models::Field::new(\"{}\", {}::{}),\n", + AsSnakeCase(&col.name), + col.name, + enum_name, + AsUpperCamelCase(&col.name) + )); + acc + }); + output.push_str(&format!( + r#" +impl instant_models::Table for {} {{ + type IdenType = {}; + type FieldsType = {}; + const FIELDS: Self::FieldsType = {} {{ +{} }}; - Ok(struct_bldr) + fn table() -> Self::IdenType {{ + {}::Table + }} +}}"#, + struct_name, enum_name, fields_name, fields_name, fields_instance, enum_name, + )); + + output } } @@ -211,10 +314,91 @@ impl std::fmt::Display for StructBuilder { write!( fmt, r#"pub struct {} {{ -{}}} - "#, +{}}}"#, AsUpperCamelCase(&self.name), columns ) } } + +#[cfg(test)] +mod tests { + use super::*; + + fn mock_struct_builder() -> StructBuilder { + let mut columns = IndexMap::new(); + columns.insert( + "user_id".into(), + Column::new("user_id".into(), Type::from_str("integer").unwrap()), + ); + columns.insert( + "username".into(), + Column::new("username".into(), Type::from_str("text").unwrap()), + ); + columns.insert( + "email".into(), + Column::new("email".into(), Type::from_str("text").unwrap()).set_null(true), + ); + + let constraints = vec![Constraint::PrimaryKey { + name: "pk_user_id".into(), + columns: vec!["user_id".into()], + }]; + + StructBuilder { + name: "accounts".into(), + columns, + constraints, + } + } + + #[test] + fn test_build_field_identifiers() { + let builder: StructBuilder = mock_struct_builder(); + assert_eq!( + builder.build_field_identifiers(), + r##" +#[derive(Copy, Clone)] +pub enum AccountsIden { + Table, + UserId, + Username, + Email, +} + +impl sea_query::Iden for AccountsIden { + fn unquoted(&self, s: &mut dyn std::fmt::Write) { + write!( + s, + "{}", + match self { + Self::Table => "accounts", + Self::UserId => "user_id", + Self::Username => "username", + Self::Email => "email", + }).expect("AccountsIden failed to write"); + } +} + +pub struct AccountsFields { + pub user_id: ::instant_models::Field, + pub username: ::instant_models::Field, + pub email: ::instant_models::Field, Accounts>, +} + +impl instant_models::Table for Accounts { + type IdenType = AccountsIden; + type FieldsType = AccountsFields; + const FIELDS: Self::FieldsType = AccountsFields { + user_id: ::instant_models::Field::new("user_id", AccountsIden::UserId), + username: ::instant_models::Field::new("username", AccountsIden::Username), + email: ::instant_models::Field::new("email", AccountsIden::Email), + }; + + fn table() -> Self::IdenType { + AccountsIden::Table + } +}"## + ); + } +} diff --git a/tests/accounts.rs b/tests/accounts.rs index 3a8ab55..dc50f3b 100644 --- a/tests/accounts.rs +++ b/tests/accounts.rs @@ -1,25 +1,24 @@ -#![allow(dead_code)] -use postgres::{Config, NoTls}; // +use instant_models::{Sql, Table}; +use postgres::{Config, NoTls}; // Example generated with // `cargo run --bin cli --features="postgres clap" -- -t "accounts" > accounts.rs` -// pub struct Accounts { pub user_id: i32, + pub created_on: chrono::naive::NaiveDateTime, + pub last_login: Option, pub username: String, pub password: String, pub email: String, - pub created_on: chrono::naive::NaiveDateTime, - pub last_login: Option, } pub struct AccountsNew<'a> { + pub created_on: chrono::naive::NaiveDateTime, + pub last_login: Option, pub username: &'a str, pub password: &'a str, pub email: &'a str, - pub created_on: chrono::naive::NaiveDateTime, - pub last_login: Option, } impl Accounts { @@ -27,16 +26,16 @@ impl Accounts { client: &mut postgres::Client, slice: &[AccountsNew<'_>], ) -> Result<(), postgres::Error> { - let statement = client.prepare("INSERT INTO accounts(username, password, email, created_on, last_login) VALUES($1, $2, $3, $4, $5);")?; + let statement = client.prepare("INSERT INTO accounts(created_on, last_login, username, password, email) VALUES($1, $2, $3, $4, $5);")?; for entry in slice { client.execute( &statement, &[ + &entry.created_on, + &entry.last_login, &entry.username, &entry.password, &entry.email, - &entry.created_on, - &entry.last_login, ], )?; } @@ -44,8 +43,64 @@ impl Accounts { } } +#[derive(Copy, Clone)] +pub enum AccountsIden { + Table, + UserId, + CreatedOn, + LastLogin, + Username, + Password, + Email, +} + +impl sea_query::Iden for AccountsIden { + fn unquoted(&self, s: &mut dyn std::fmt::Write) { + write!( + s, + "{}", + match self { + Self::Table => "accounts", + Self::UserId => "user_id", + Self::CreatedOn => "created_on", + Self::LastLogin => "last_login", + Self::Username => "username", + Self::Password => "password", + Self::Email => "email", + } + ) + .expect("AccountsIden failed to write"); + } +} + +pub struct AccountsFields { + pub user_id: ::instant_models::Field, + pub created_on: ::instant_models::Field, + pub last_login: ::instant_models::Field, Accounts>, + pub username: ::instant_models::Field, + pub password: ::instant_models::Field, + pub email: ::instant_models::Field, +} + +impl instant_models::Table for Accounts { + type IdenType = AccountsIden; + type FieldsType = AccountsFields; + const FIELDS: Self::FieldsType = AccountsFields { + user_id: ::instant_models::Field::new("user_id", AccountsIden::UserId), + created_on: ::instant_models::Field::new("created_on", AccountsIden::CreatedOn), + last_login: ::instant_models::Field::new("last_login", AccountsIden::LastLogin), + username: ::instant_models::Field::new("username", AccountsIden::Username), + password: ::instant_models::Field::new("password", AccountsIden::Password), + email: ::instant_models::Field::new("email", AccountsIden::Email), + }; + + fn table() -> AccountsIden { + AccountsIden::Table + } +} + #[test] -fn test_accounts() { +fn test_accounts_insert() { let client = &mut Config::new() .user("postgres") .password("postgres") @@ -104,10 +159,13 @@ fn test_accounts() { Accounts::insert_slice(client, &[new_val_1, new_val_2, new_val_3, new_val_4]).unwrap(); - for row in client - .query("SELECT user_id, username FROM accounts;", &[]) - .unwrap() - { + let select_query = Accounts::query().select(|a| (a.user_id, a.username)); + assert_eq!( + select_query.to_string(), + r#"SELECT "user_id", "username" FROM "accounts""# + ); + + for row in select_query.fetch(client, &[]).unwrap() { let id: i32 = row.get(0); let name: &str = row.get(1); println!("found person: {} {}", id, name); @@ -116,3 +174,61 @@ fn test_accounts() { // clean up what we did client.batch_execute(r#"DELETE FROM accounts;"#).unwrap(); } + +#[test] +fn test_accounts_query() { + // SELECT single column. + let select = Accounts::query().select(|a| (a.user_id,)).to_string(); + assert_eq!(select, r#"SELECT "user_id" FROM "accounts""#); + + // SELECT multiple columns. + let select_multiple = Accounts::query() + .select(|a| (a.user_id, a.username, a.email)) + .to_string(); + assert_eq!( + select_multiple, + r#"SELECT "user_id", "username", "email" FROM "accounts""# + ); + + // SELECT WHERE single condition. + let select_where = Accounts::query() + .select(|a| (a.user_id,)) + .filter(|a| Sql::is_null(a.last_login)) + .to_string(); + assert_eq!( + select_where, + r#"SELECT "user_id" FROM "accounts" WHERE "accounts"."last_login" IS NULL"# + ); + + // SELECT WHERE AND. + let select_where_and = Accounts::query() + .select(|a| (a.user_id,)) + .filter(|a| Sql::is_null(a.last_login) & Sql::is_not_null(a.created_on)) + .to_string(); + assert_eq!( + select_where_and, + r#"SELECT "user_id" FROM "accounts" WHERE "accounts"."last_login" IS NULL AND "accounts"."created_on" IS NOT NULL"# + ); + + // SELECT WHERE OR. + let select_where_or = Accounts::query() + .select(|a| (a.user_id,)) + .filter(|a| Sql::is_null(a.last_login) | Sql::is_not_null(a.created_on)) + .to_string(); + assert_eq!( + select_where_or, + r#"SELECT "user_id" FROM "accounts" WHERE "accounts"."last_login" IS NULL OR "accounts"."created_on" IS NOT NULL"# + ); + + // SELECT WHERE AND OR. + let select_where_and_or = Accounts::query() + .select(|a| (a.user_id,)) + .filter(|a| { + Sql::is_null(a.last_login) & (Sql::is_not_null(a.created_on) | Sql::eq(a.user_id, 1)) + }) + .to_string(); + assert_eq!( + select_where_and_or, + r#"SELECT "user_id" FROM "accounts" WHERE "accounts"."last_login" IS NULL AND ("accounts"."created_on" IS NOT NULL OR "accounts"."user_id" = 1)"# + ); +} diff --git a/tests/basic.rs b/tests/basic.rs index 5df01dd..5ed7059 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -28,6 +28,8 @@ fn create_cargo_project(builder: StructBuilder) -> Result<(), anyhow::Error> { file.write_all(builder.build_type().as_bytes())?; file.write_all(builder.build_new_type().as_bytes())?; file.write_all(builder.build_type_methods().as_bytes())?; + #[cfg(feature = "sql")] + file.write_all(builder.build_field_identifiers().as_bytes())?; drop(file); let mut manifest_file = OpenOptions::new() .write(true) diff --git a/tests/join.rs b/tests/join.rs new file mode 100644 index 0000000..d409074 --- /dev/null +++ b/tests/join.rs @@ -0,0 +1,259 @@ +use instant_models::{Field, Sql, Table}; +use std::fmt::Write; + +// Example manually constructed. + +pub struct Accounts { + pub user_id: i32, + pub created_on: chrono::naive::NaiveDateTime, + pub last_login: Option, + pub username: String, + pub password: String, + pub email: String, +} + +pub struct AccountsNew<'a> { + pub created_on: chrono::naive::NaiveDateTime, + pub last_login: Option, + pub username: &'a str, + pub password: &'a str, + pub email: &'a str, +} + +impl Accounts { + pub fn insert_slice( + client: &mut postgres::Client, + slice: &[AccountsNew<'_>], + ) -> Result<(), postgres::Error> { + let statement = client.prepare("INSERT INTO accounts(created_on, last_login, username, password, email) VALUES($1, $2, $3, $4, $5);")?; + for entry in slice { + client.execute( + &statement, + &[ + &entry.created_on, + &entry.last_login, + &entry.username, + &entry.password, + &entry.email, + ], + )?; + } + Ok(()) + } +} + +#[derive(Copy, Clone)] +pub enum AccountsIden { + Table, + UserId, + CreatedOn, + LastLogin, + Username, + Password, + Email, +} + +impl sea_query::Iden for AccountsIden { + fn unquoted(&self, s: &mut dyn std::fmt::Write) { + write!( + s, + "{}", + match self { + Self::Table => "accounts", + Self::UserId => "user_id", + Self::CreatedOn => "created_on", + Self::LastLogin => "last_login", + Self::Username => "username", + Self::Password => "password", + Self::Email => "email", + } + ) + .expect("AccountsIden failed to write"); + } +} + +pub struct AccountsFields { + pub user_id: ::instant_models::Field, + pub created_on: ::instant_models::Field, + pub last_login: ::instant_models::Field, Accounts>, + pub username: ::instant_models::Field, + pub password: ::instant_models::Field, + pub email: ::instant_models::Field, +} + +impl instant_models::Table for Accounts { + type IdenType = AccountsIden; + type FieldsType = AccountsFields; + const FIELDS: Self::FieldsType = AccountsFields { + user_id: ::instant_models::Field::new("user_id", AccountsIden::UserId), + created_on: ::instant_models::Field::new("created_on", AccountsIden::CreatedOn), + last_login: ::instant_models::Field::new("last_login", AccountsIden::LastLogin), + username: ::instant_models::Field::new("username", AccountsIden::Username), + password: ::instant_models::Field::new("password", AccountsIden::Password), + email: ::instant_models::Field::new("email", AccountsIden::Email), + }; + + fn table() -> AccountsIden { + AccountsIden::Table + } +} + +pub struct Access { + pub user: i32, + pub domain: i32, + pub role: String, +} + +#[derive(Copy, Clone)] +pub enum AccessIden { + Table, + User, + Domain, + Role, +} + +impl sea_query::Iden for AccessIden { + fn unquoted(&self, s: &mut dyn Write) { + write!( + s, + "{}", + match self { + Self::Table => "access", + Self::User => "user", + Self::Domain => "domain", + Self::Role => "role", + } + ) + .expect("AccessIden failed to write"); + } +} + +pub struct AccessFields { + pub user: Field, + pub domain: Field, + pub role: Field, +} + +impl Table for Access { + type IdenType = AccessIden; + type FieldsType = AccessFields; + const FIELDS: Self::FieldsType = AccessFields { + user: Field::new("user", AccessIden::User), + domain: Field::new("domain", AccessIden::Domain), + role: Field::new("role", AccessIden::Role), + }; + + fn table() -> Self::IdenType { + AccessIden::Table + } +} + +pub struct Examples { + pub id: i32, + pub example: String, + pub active: bool, +} + +#[derive(Copy, Clone)] +pub enum ExamplesIden { + Table, + Id, + Example, + Active, +} + +pub struct ExamplesFields { + pub id: Field, + pub example: Field, + pub active: Field, +} + +impl sea_query::Iden for ExamplesIden { + fn unquoted(&self, s: &mut dyn Write) { + write!( + s, + "{}", + match self { + Self::Table => "examples", + Self::Id => "id", + Self::Example => "example", + Self::Active => "active", + } + ) + .expect("ExampleIden failed to write"); + } +} + +impl Table for Examples { + type IdenType = ExamplesIden; + type FieldsType = ExamplesFields; + const FIELDS: Self::FieldsType = ExamplesFields { + id: Field::new("id", ExamplesIden::Id), + example: Field::new("example", ExamplesIden::Example), + active: Field::new("active", ExamplesIden::Active), + }; + + fn table() -> ExamplesIden { + ExamplesIden::Table + } +} + +#[test] +fn test_query_from_chain() { + let expected = r#"SELECT "user_id", "username", "password", "email" +FROM "accounts", "access", "examples" +WHERE "accounts"."username" = 'foo' +AND ("accounts"."last_login" IS NOT NULL OR "accounts"."created_on" <> '1970-01-01 00:00:00') +AND ("accounts"."user_id" = "access"."user" AND "access"."role" = 'DomainAdmin') +AND ("accounts"."user_id" = "examples"."id" AND "examples"."active" IS TRUE) +LIMIT 1"#; + + let user = "foo"; + let role = "DomainAdmin"; + let timestamp = chrono::NaiveDateTime::from_timestamp(0, 0); + + let query = Accounts::query() + .filter(|a| { + Sql::eq(a.username, user) + & (Sql::is_not_null(a.last_login) | Sql::ne(a.created_on, timestamp)) + }) + .from::() + .filter(|(a, acl)| Sql::eq(a.user_id, acl.user) & Sql::eq(acl.role, role)) + .from::() + .filter(|(a, .., ex)| Sql::eq(a.user_id, ex.id) & Sql::is(ex.active, true)) + .select(|(a, ..)| (a.user_id, a.username, a.password, a.email)) + .limit(1) + .to_string(); + + assert_eq!(query, expected.replace('\n', " ")); +} + +#[test] +fn test_query_join_chain() { + let expected = r#"SELECT "user_id", "username", "password", "email" +FROM "accounts" +JOIN "access" ON "accounts"."user_id" = "access"."user" +JOIN "examples" ON "accounts"."user_id" = "examples"."id" +WHERE "accounts"."username" = 'foo' +AND ("accounts"."last_login" IS NOT NULL OR "accounts"."created_on" <> '1970-01-01 00:00:00') +AND ("examples"."active" IS TRUE AND "access"."role" = 'DomainAdmin') +LIMIT 1"#; + + let user = "foo"; + let role = "DomainAdmin"; + let timestamp = chrono::NaiveDateTime::from_timestamp(0, 0); + + let query = Accounts::query() + .filter(|a| { + Sql::eq(a.username, user) + & (Sql::is_not_null(a.last_login) | Sql::ne(a.created_on, timestamp)) + }) + .join::(|(a, acl)| Sql::eq(a.user_id, acl.user)) + .join::(|(a, .., ex)| Sql::eq(a.user_id, ex.id)) + .filter(|(.., acl, ex)| Sql::is(ex.active, true) & Sql::eq(acl.role, role)) + .select(|(a, ..)| (a.user_id, a.username, a.password, a.email)) + .limit(1) + .to_string(); + + assert_eq!(query, expected.replace('\n', " ")); +}