From 39bba731fa523a54956a9622dd4d2a44ffb33f7b Mon Sep 17 00:00:00 2001 From: pm Date: Wed, 27 Nov 2024 12:46:49 +0100 Subject: [PATCH 01/14] support unary expressions like NOT --- static_sqlite_macros/src/schema.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/static_sqlite_macros/src/schema.rs b/static_sqlite_macros/src/schema.rs index f72b173..f5ab9ce 100644 --- a/static_sqlite_macros/src/schema.rs +++ b/static_sqlite_macros/src/schema.rs @@ -473,7 +473,8 @@ fn expr_columns<'a>(expr: &'a Expr) -> Vec> { cols.extend(columns); } cols - } + }, + Expr::UnaryOp { expr, .. } => expr_columns(expr), expr => todo!("expr_columns rest of the ops {expr}"), } } From 0c27271a482c9eee16931ed32cc5811a1dcffe99 Mon Sep 17 00:00:00 2001 From: pm Date: Fri, 21 Mar 2025 20:14:32 +0100 Subject: [PATCH 02/14] add support for bindings that are not table-columns --- static_sqlite_macros/src/lib.rs | 127 ++++++++++++++++++++++---------- tests/integration_test.rs | 85 +++++++++++++++++++++ 2 files changed, 174 insertions(+), 38 deletions(-) diff --git a/static_sqlite_macros/src/lib.rs b/static_sqlite_macros/src/lib.rs index 350bcd4..bc7d738 100644 --- a/static_sqlite_macros/src/lib.rs +++ b/static_sqlite_macros/src/lib.rs @@ -424,49 +424,65 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result quote! { Vec }, - "INTEGER" => quote! { i64 }, - "REAL" | "DOUBLE" => quote! { f64 }, - "TEXT" => quote! { impl ToString }, - _ => unimplemented!("Sqlite fn arg not supported"), - }; - let field_name = Ident::new(&field.column_name, expr.ident.span()); - let not_null = field.not_null; - let pk = field.pk; - match (pk, not_null) { - (0, 0) => quote! { #field_name: Option<#field_type> }, - _ => quote! { #field_name: #field_type }, + .map(|fieldname| { + // if fieldname is in the form of __ or ____ then extract the name and type + let parts_in_fieldname = fieldname.split("__").collect::>(); + if parts_in_fieldname.len() == 2 || parts_in_fieldname.len() == 3 { + let column_type = parts_in_fieldname[1]; + let field_name = Ident::new(&fieldname.to_string().to_lowercase(), expr.ident.span()); + let field_type = create_fn_argument_type(fieldname, column_type); + let nullable = parts_in_fieldname.len() == 3 && parts_in_fieldname[2] == "nullable"; + match nullable { + true => quote! { #field_name: Option<#field_type> }, + false => quote! { #field_name: #field_type }, + } + // otherwise is has to be a column name + } else if let Some(field) = input_schema_rows + .iter() + .find(|row| &row.column_name == fieldname) + { + let field_type = create_fn_argument_type( fieldname, field.column_type.as_str()); + let field_name = Ident::new(&field.column_name, expr.ident.span()); + let not_null = field.not_null; + let pk = field.pk; + match (pk, not_null) { + (0, 0) => quote! { #field_name: Option<#field_type> }, + _ => quote! { #field_name: #field_type }, + } + } else { + unimplemented!( + "field {:?} not found in schema and has no __ suffix", + fieldname + ); } }) .collect::>(); - let params = input_schema_rows + let params = inputs .iter() - .map(|field| { - let not_null = field.not_null; - let name = Ident::new(&field.column_name, expr.ident.span()); - match field.column_type.as_str() { - "BLOB" => { - quote! { #name.into() } - } - "INTEGER" => quote! { #name.into() }, - "REAL" | "DOUBLE" => quote! { #name.into() }, - "TEXT" => match not_null { - 1 => quote! { - #name.to_string().into() - }, - 0 => quote! { - match #name { - Some(val) => val.to_string().into(), - None => static_sqlite::Value::Null - } - }, - _ => unreachable!(), - }, - _ => unimplemented!("Sqlite param not supported"), + .map(|fieldname| { + // if fieldname is in the form of __ or ____ then extract the name and type + let parts_in_fieldname = fieldname.split("__").collect::>(); + if parts_in_fieldname.len() == 2 || parts_in_fieldname.len() == 3 { + let field_name = Ident::new(&fieldname.to_string().to_lowercase(), expr.ident.span()); + let not_null = if parts_in_fieldname.len() == 3 && parts_in_fieldname[2] == "nullable" { 0 } else { 1 }; + let type_hint = parts_in_fieldname[1]; + return create_binding_value(type_hint, not_null, field_name); + // otherwise is has to be a column name + } else if let Some(field) = input_schema_rows + .iter() + .find(|row| &row.column_name == fieldname) + { + let not_null = field.not_null; + let name = Ident::new(&field.column_name, expr.ident.span()); + create_binding_value(field.column_type.as_str(), not_null, name) + } else { + unimplemented!( + "field {:?} not found in schema and has no __ suffix", + fieldname + ); } }) .collect::>(); @@ -489,6 +505,7 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result Result> { let rows: Vec<#pascal_case> = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; Ok(rows) @@ -498,6 +515,40 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result TokenStream { + match column_type { + "BLOB" => quote! { Vec }, + "INTEGER" => quote! { i64 }, + "REAL" | "DOUBLE" => quote! { f64 }, + "TEXT" => quote! { impl ToString }, + _ => unimplemented!("type {:?} not supported for fn arg {:?}", column_type, fieldname), + } +} + +fn create_binding_value(field_type: &str, not_null: i64, name: Ident) -> TokenStream { + match field_type { + "BLOB" => { + quote! { #name.into() } + } + "INTEGER" => quote! { #name.into() }, + "REAL" | "DOUBLE" => quote! { #name.into() }, + "TEXT" => match not_null { + 1 => quote! { + + #name.to_string().into() + }, + 0 => quote! { + match #name { + Some(val) => val.to_string().into(), + None => static_sqlite::Value::Null + } + }, + _ => unreachable!(), + }, + _ => unimplemented!("Sqlite param not supported"), + } +} + fn join_table_names(expr: &&SqlExpr) -> Vec { let mut output = vec![]; visit_relations(&expr.statements, |rel| { diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 791e729..f974b07 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -226,6 +226,91 @@ async fn crud_works() -> Result<()> { Ok(()) } +#[tokio::test] +async fn parameters_that_are_not_in_the_schema_work() -> Result<()> { + sql! { + let migrate = r#" + create table User ( + id integer primary key, + name text unique not null + ); + + create table Post ( + id integer primary key, + user_id integer not null references User(id), + name text unique not null + ); + "#; + + let insert_user = r#" + insert into User (name) values (:name) returning * + "#; + + let insert_post = r#" + insert into Post (user_id, name) values (:user_id, :name) returning * + "#; + let select_posts = r#" + select * from Post where id = :id AND id = :id__INTEGER AND name = :id__INTEGER AND name = :name AND :ff__TEXT="sdd" + "#; + } + + let db = static_sqlite::open(":memory:").await?; + let _ = migrate(&db).await?; + let user1 = insert_user(&db, "user1").await?.first_row()?; + insert_post(&db, user1.id, "user 1 - post1").await?.first_row()?; + insert_post(&db, user1.id, "user 1 - post2").await?.first_row()?; + let user2 = insert_user(&db, "user2").await?.first_row()?; + insert_post(&db, user2.id, "user 2 - post1").await?.first_row()?; + insert_post(&db, user2.id, "user 2 - post2").await?.first_row()?; + + let posts = select_posts(&db, 1, 2, "Hello", "sdd").await?; + println!("{:?}", posts); + + + Ok(()) +} + + + +#[tokio::test] +async fn duplicate_column_names_in_one_query_work() -> Result<()> { + sql! { + let migrate = r#" + create table User ( + id integer primary key, + name text not null + ); + + create table Post ( + id integer primary key, + user_id integer not null references User(id), + name text not null + ); + "#; + + let insert_user = r#"insert into User (name) values (:name) returning *"#; + let insert_post = r#"insert into Post (user_id, name) values (:user_id, :name) returning *"#; + let select_posts_by_user_id = r#"select p.id, p.name, u.name as user_name from Post p, User u where p.user_id = u.id AND u.id = :id"#; + let select_posts_all = r#"select p.id, p.name, u.name as user_name from Post p, User u where p.user_id = u.id"#; + } + + let db = static_sqlite::open(":memory:").await?; + let _ = migrate(&db).await?; + let user1 = insert_user(&db, "user1").await?.first_row()?; + insert_post(&db, user1.id, "user 1 - post1").await?.first_row()?; + insert_post(&db, user1.id, "user 1 - post2").await?.first_row()?; + let user2 = insert_user(&db, "user2").await?.first_row()?; + insert_post(&db, user2.id, "user 2 - post1").await?; + insert_post(&db, user2.id, "user 2 - post2").await?; + + let posts = select_posts_by_user_id(&db, 2).await?; + println!("{:?}", posts); + let posts = select_posts_all(&db).await?; + println!("{:?}", posts); + + Ok(()) +} + #[test] fn ui() { let t = trybuild::TestCases::new(); From 0d2834ab39134d37239d8f956fa66e5936044b50 Mon Sep 17 00:00:00 2001 From: pm Date: Fri, 21 Mar 2025 20:14:32 +0100 Subject: [PATCH 03/14] add support for type hints in bindings and output-columns --- README.md | 77 ++++++++++++ static_sqlite_core/src/ffi.rs | 3 +- static_sqlite_macros/src/lib.rs | 199 ++++++++++++++++++++------------ tests/integration_test.rs | 105 ++++++++++++++--- 4 files changed, 289 insertions(+), 95 deletions(-) diff --git a/README.md b/README.md index 51369dc..598edfe 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,83 @@ async fn main() -> Result<()> { cargo add --git https://github.com/swlkr/static_sqlite ``` + +# Example with aliased columns and type-hints + +Sometimes the type of either a bound parameter or a returned column can not be inferred by +sqlite / static_sqlite (see [sqlite3 docs](https://www.sqlite.org/c3ref/column_decltype.html)) + +In this case you can use type-hints to help the static_sqlite to use the correct type. + +To use type-hints your parameter or column name needs to follow the following format: + +``` +__ +``` + +or + +``` +____ +``` + +If not explicitly specified, the parameter or column is assumed to be NOT NULL. + +sql! { + let migrate = r#" + create table User ( + id integer primary key, + name text unique not null + ); + create table Friendship ( + id integer primary key, + user_id integer not null references User(id), + friend_id integer not null references User(id) + ); + "#; + + let insert_user = r#" + insert into User (name) + values (:name) + returning * + "#; + let create_friendship = r#" + insert into Friendship (user_id, friend_id) + values (:user_id, :friend_id) + returning * + "#; + let get_friendship = r#" + select + u1.name as friend1_name__TEXT, + u2.name as friend2_name__TEXT + from Friendship, User as u1, User as u2 + where Friendship.user_id = u1.id + and Friendship.friend_id = u2.id + and Friendship.id = :friendship_id__INTEGER + "#; +} + + +#[tokio::main] +async fn main() -> Result<()> { + let db = static_sqlite::open(":memory:").await?; + let _ = migrate(&db).await?; + insert_user(&db, "swlkr").await?; + insert_user(&db, "toolbar23").await?; + create_friendship(&db, 1, 2).await?; + + let friends = get_friendship(&db, 1).await?; + + assert_eq!(friends.len(), 1); + assert_eq!(friends.first().unwrap().friend1_name, "swlkr"); + assert_eq!(friends.first().unwrap().friend2_name, "toolbar23"); + + Ok(()) +} +``` + + + # Treesitter ``` diff --git a/static_sqlite_core/src/ffi.rs b/static_sqlite_core/src/ffi.rs index e17cff3..dc04e5f 100644 --- a/static_sqlite_core/src/ffi.rs +++ b/static_sqlite_core/src/ffi.rs @@ -287,8 +287,7 @@ impl Sqlite { let count = sqlite3_column_count(stmt); for i in 0..count { let name_ptr = sqlite3_column_name(stmt, i); - - if !name_ptr.is_null() { + if !name_ptr.is_null() { let name = CStr::from_ptr(name_ptr).to_string_lossy().into_owned(); columns.push(name); } diff --git a/static_sqlite_macros/src/lib.rs b/static_sqlite_macros/src/lib.rs index bc7d738..e24d8d9 100644 --- a/static_sqlite_macros/src/lib.rs +++ b/static_sqlite_macros/src/lib.rs @@ -411,6 +411,7 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result Result {} }; } - let input_schema_rows: Vec<&&SchemaRow> = inputs - .iter() - .filter_map(|col_name| schema_rows.iter().find(|row| &row.column_name == col_name)) - .collect(); let fn_args = inputs .iter() - .map(|fieldname| { - // if fieldname is in the form of __ or ____ then extract the name and type - let parts_in_fieldname = fieldname.split("__").collect::>(); - if parts_in_fieldname.len() == 2 || parts_in_fieldname.len() == 3 { - let column_type = parts_in_fieldname[1]; - let field_name = Ident::new(&fieldname.to_string().to_lowercase(), expr.ident.span()); - let field_type = create_fn_argument_type(fieldname, column_type); - let nullable = parts_in_fieldname.len() == 3 && parts_in_fieldname[2] == "nullable"; - match nullable { - true => quote! { #field_name: Option<#field_type> }, - false => quote! { #field_name: #field_type }, - } - // otherwise is has to be a column name - } else if let Some(field) = input_schema_rows - .iter() - .find(|row| &row.column_name == fieldname) - { - let field_type = create_fn_argument_type( fieldname, field.column_type.as_str()); - let field_name = Ident::new(&field.column_name, expr.ident.span()); - let not_null = field.not_null; - let pk = field.pk; - match (pk, not_null) { - (0, 0) => quote! { #field_name: Option<#field_type> }, - _ => quote! { #field_name: #field_type }, + .map(|aliases_column_name| { + match parse_type_hinted_column_name(aliases_column_name, &schema_rows) { + TypedToken::FromTypeHint(type_hint) => { + let field_name = Ident::new(&type_hint.alias, expr.ident.span()); + let field_type = create_fn_argument_type(&type_hint.alias, &type_hint.column_type); + match type_hint.not_null { + 0 => quote! { #field_name: Option<#field_type> }, + _ => quote! { #field_name: #field_type }, + } + }, + TypedToken::FromSchemaRow(schema_row) => { + let field_name = Ident::new(&schema_row.column_name, expr.ident.span()); + let field_type = create_fn_argument_type(aliases_column_name, &schema_row.column_type); + match (schema_row.pk, schema_row.not_null) { + (0, 0) => quote! { #field_name: Option<#field_type> }, + _ => quote! { #field_name: #field_type }, + } } - } else { - unimplemented!( - "field {:?} not found in schema and has no __ suffix", - fieldname - ); } }) .collect::>(); + let params = inputs .iter() - .map(|fieldname| { - // if fieldname is in the form of __ or ____ then extract the name and type - let parts_in_fieldname = fieldname.split("__").collect::>(); - if parts_in_fieldname.len() == 2 || parts_in_fieldname.len() == 3 { - let field_name = Ident::new(&fieldname.to_string().to_lowercase(), expr.ident.span()); - let not_null = if parts_in_fieldname.len() == 3 && parts_in_fieldname[2] == "nullable" { 0 } else { 1 }; - let type_hint = parts_in_fieldname[1]; - return create_binding_value(type_hint, not_null, field_name); - // otherwise is has to be a column name - } else if let Some(field) = input_schema_rows - .iter() - .find(|row| &row.column_name == fieldname) - { - let not_null = field.not_null; - let name = Ident::new(&field.column_name, expr.ident.span()); - create_binding_value(field.column_type.as_str(), not_null, name) - } else { - unimplemented!( - "field {:?} not found in schema and has no __ suffix", - fieldname - ); + .map(|aliases_column_name| { + match parse_type_hinted_column_name(aliases_column_name, &schema_rows) { + TypedToken::FromTypeHint(type_hint) => { + let field_name = Ident::new(&type_hint.alias, expr.ident.span()); + create_binding_value(&type_hint.column_type, type_hint.not_null, field_name) + }, + TypedToken::FromSchemaRow(schema_row) => { + let field_name = Ident::new(&schema_row.column_name, expr.ident.span()); + create_binding_value(&schema_row.column_type, schema_row.not_null, field_name) + } } + }) .collect::>(); + let ident = &expr.ident; let outputs = output_column_names(db, expr)?; let pascal_case = snake_to_pascal_case(&ident); - let cols: Vec = outputs - .iter() - .filter_map(|col_name| { - schema_rows - .iter() - .find(|row| &row.column_name == col_name) - .cloned() - .cloned() - }) - .collect(); - let struct_tokens = struct_tokens(expr.ident.span(), &pascal_case, &cols); + + let output_typed = outputs.iter().map(|output| parse_type_hinted_column_name(output, &schema_rows)).collect::>(); + + let struct_tokens = struct_tokens(expr.ident.span(), &pascal_case, &output_typed); + + let sql = &expr.sql; output.push(quote! { #struct_tokens #[doc = #sql] #[allow(non_snake_case)] + #[allow(non_snake_case)] pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> Result> { let rows: Vec<#pascal_case> = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; Ok(rows) @@ -549,6 +522,54 @@ fn create_binding_value(field_type: &str, not_null: i64, name: Ident) -> TokenSt } } +#[derive(Debug, Clone)] +struct TypeHintedToken { + name: String, + alias: String, + column_type: String, + not_null: i64, +} + +#[derive(Debug, Clone)] +enum TypedToken { FromTypeHint(TypeHintedToken), FromSchemaRow(SchemaRow) } + +/* + * Parses a type hint and returns a TypedColumnOrParameter + * + * If the alias is in the form of __ or ____ then it is a type hint + * Otherwise it is a column name + * + */ +fn parse_type_hinted_column_name(alias: &str, schema_rows: &Vec<&SchemaRow>) -> TypedToken { + let parts = alias.split("__").collect::>(); + let result = match parts.len() { + 1 => TypedToken::FromSchemaRow( + match schema_rows.iter().find(|row| &row.column_name == alias) { + Some(row) => (**row).clone(), + None => panic!("Column {:?} referenced in binding or column not found in schema, maybe you forgot to add the type hint?", alias), + } + ), + 2 => TypedToken::FromTypeHint(TypeHintedToken { + alias: alias.to_string(), + name: parts[0].to_string(), + column_type: parts[1].to_string(), + not_null: 1, + }), + 3 => TypedToken::FromTypeHint(TypeHintedToken { + alias: alias.to_string(), + name: parts[0].to_string(), + column_type: parts[1].to_string(), + not_null: match parts[2].to_lowercase().as_str() { + "nullable" => 0, + "not_null" => 1, + _ => panic!("Invalid type hint: {:?}, last part must be nullable or not_null", alias), + }, + }), + _ => panic!("Invalid type hint: {:?}", alias), + }; + result +} + fn join_table_names(expr: &&SqlExpr) -> Vec { let mut output = vec![]; visit_relations(&expr.statements, |rel| { @@ -584,16 +605,32 @@ fn structs_tokens(span: Span, schema: &Schema) -> Vec { .iter() .map(|(table, cols)| { let ident = proc_macro2::Ident::new(&table, span); - struct_tokens(span, &ident, cols) + let typed_tokens: Vec = cols.iter() + .map(|col| TypedToken::FromSchemaRow(col.clone())) + .collect(); + struct_tokens(span, &ident, &typed_tokens) }) .collect() } -fn struct_tokens(span: Span, ident: &Ident, cols: &Vec) -> TokenStream { - let struct_fields = cols.iter().map(|row| { - let field_type = field_type(row); - let name = Ident::new(&row.column_name, span); - let optional = match (row.not_null, row.pk) { + +fn struct_tokens(span: Span, ident: &Ident, output_typed: &[TypedToken]) -> TokenStream { + let struct_fields = output_typed.iter().map(|row| { + let field_type = match row { + TypedToken::FromTypeHint(type_hint) => field_type_from_datatype_name(&type_hint.column_type), + TypedToken::FromSchemaRow(schema_row) => field_type(schema_row), + }; + let name = match row { + TypedToken::FromTypeHint(type_hint) => Ident::new(&type_hint.name, span), + TypedToken::FromSchemaRow(schema_row) => Ident::new(&schema_row.column_name, span), + }; + let optional = match ( match row { + TypedToken::FromTypeHint(type_hint) => type_hint.not_null, + TypedToken::FromSchemaRow(schema_row) => schema_row.not_null, + }, match row { + TypedToken::FromTypeHint(_) => 0, + TypedToken::FromSchemaRow(schema_row) => schema_row.pk, + }) { (0, 0) => true, (0, 1) | (1, 0) | (1, 1) => false, _ => unreachable!(), @@ -604,9 +641,15 @@ fn struct_tokens(span: Span, ident: &Ident, cols: &Vec) -> TokenStrea false => quote! { pub #name: #field_type }, } }); - let match_stmt = cols.iter().map(|field| { - let name = Ident::new(&field.column_name, span); - let lit_str = LitStr::new(&field.column_name, span); + let match_stmt = output_typed.iter().map(|row| { + let name = Ident::new(match row { + TypedToken::FromTypeHint(type_hint) => &type_hint.name, + TypedToken::FromSchemaRow(schema_row) => &schema_row.column_name, + }, span); + let lit_str = LitStr::new(match row { + TypedToken::FromTypeHint(type_hint) => &type_hint.alias, + TypedToken::FromSchemaRow(schema_row) => &schema_row.column_name, + }, span); quote! { #lit_str => row.#name = value.try_into()? @@ -634,8 +677,14 @@ fn struct_tokens(span: Span, ident: &Ident, cols: &Vec) -> TokenStrea tokens } + fn field_type(row: &SchemaRow) -> TokenStream { - match row.column_type.as_str() { + field_type_from_datatype_name(&row.column_type) +} + + +fn field_type_from_datatype_name(datatype_name: &str) -> TokenStream { + match datatype_name { "BLOB" => quote! { Vec }, "INTEGER" => quote! { i64 }, "REAL" | "DOUBLE" => quote! { f64 }, diff --git a/tests/integration_test.rs b/tests/integration_test.rs index f974b07..def1b84 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -273,44 +273,113 @@ async fn parameters_that_are_not_in_the_schema_work() -> Result<()> { #[tokio::test] -async fn duplicate_column_names_in_one_query_work() -> Result<()> { +async fn example_friendshipworks() -> Result<()> { + use static_sqlite::{sql, Result, self}; + sql! { let migrate = r#" create table User ( id integer primary key, - name text not null + name text unique not null ); - create table Post ( + create table Friendship ( id integer primary key, user_id integer not null references User(id), - name text not null + friend_id integer not null references User(id) ); "#; - let insert_user = r#"insert into User (name) values (:name) returning *"#; - let insert_post = r#"insert into Post (user_id, name) values (:user_id, :name) returning *"#; - let select_posts_by_user_id = r#"select p.id, p.name, u.name as user_name from Post p, User u where p.user_id = u.id AND u.id = :id"#; - let select_posts_all = r#"select p.id, p.name, u.name as user_name from Post p, User u where p.user_id = u.id"#; + let insert_user = r#" + insert into User (name) + values (:name) + returning * + "#; + let create_friendship = r#" + insert into Friendship (user_id, friend_id) + values (:user_id, :friend_id) + returning * + "#; + let get_friendship = r#" + SELECT + u1.name as friend1_name__TEXT, + u2.name as friend2_name__TEXT + FROM Friendship, User as u1, User as u2 + WHERE Friendship.user_id = u1.id + AND Friendship.friend_id = u2.id + AND Friendship.id = :friendship_id__INTEGER + "#; } let db = static_sqlite::open(":memory:").await?; let _ = migrate(&db).await?; - let user1 = insert_user(&db, "user1").await?.first_row()?; - insert_post(&db, user1.id, "user 1 - post1").await?.first_row()?; - insert_post(&db, user1.id, "user 1 - post2").await?.first_row()?; - let user2 = insert_user(&db, "user2").await?.first_row()?; - insert_post(&db, user2.id, "user 2 - post1").await?; - insert_post(&db, user2.id, "user 2 - post2").await?; + insert_user(&db, "swlkr").await?; + insert_user(&db, "toolbar23").await?; + create_friendship(&db, 1, 2).await?; - let posts = select_posts_by_user_id(&db, 2).await?; - println!("{:?}", posts); - let posts = select_posts_all(&db).await?; - println!("{:?}", posts); + let friends = get_friendship(&db, 1).await?; + + assert_eq!(friends.len(), 1); + assert_eq!(friends.first().unwrap().friend1_name, "swlkr"); + assert_eq!(friends.first().unwrap().friend2_name, "toolbar23"); Ok(()) + +} + + + +#[tokio::test] +async fn duplicate_column_names_in_one_quer2y_work() -> Result<()> { +sql! { + let migrate = r#" + CREATE TABLE IF NOT EXISTS Identifiers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + entity_type INTEGER NOT NULL, + identifier_type TEXT NOT NULL, + identifier_value TEXT NOT NULL, + UNIQUE(entity_type, identifier_type, identifier_value) + ); + + CREATE TABLE IF NOT EXISTS MappingChanges ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + from_identifier INTEGER NOT NULL REFERENCES Identifiers(id), + to_identifier_previous INTEGER NOT NULL REFERENCES Identifiers(id), + to_identifier_new INTEGER NOT NULL REFERENCES Identifiers(id), + timestamp INTEGER NOT NULL + ); + "#; + + let get_changes = r#" + SELECT + mc.id, + mc.timestamp, + f.entity_type, + f.identifier_type, + f.identifier_value, + op.identifier_type as old_type__TEXT__NULLABLE, + op.identifier_value as old_value__TEXT__NULLABLE, + n.identifier_type as new_type__TEXT__NULLABLE, + n.identifier_value as new_value__TEXT__NULLABLE + FROM MappingChanges mc, Identifiers f, Identifiers op, Identifiers n + WHERE mc.from_identifier = f.id + AND mc.to_identifier_previous = op.id + AND mc.to_identifier_new = n.id + AND mc.timestamp > :since__INTEGER + ORDER BY mc.timestamp ASC + "#; } +let db = static_sqlite::open(":memory:").await?; +let _ = migrate(&db).await?; +let changes = get_changes(&db, 1).await?; +println!("{:?}", changes); + +Ok(()) +} + + + #[test] fn ui() { let t = trybuild::TestCases::new(); From 214f98a0574e242adba9e9be20b8dc378f589442 Mon Sep 17 00:00:00 2001 From: pm Date: Sat, 22 Mar 2025 15:43:57 +0100 Subject: [PATCH 04/14] cleanup --- static_sqlite_core/src/ffi.rs | 3 ++- static_sqlite_macros/src/lib.rs | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/static_sqlite_core/src/ffi.rs b/static_sqlite_core/src/ffi.rs index dc04e5f..e17cff3 100644 --- a/static_sqlite_core/src/ffi.rs +++ b/static_sqlite_core/src/ffi.rs @@ -287,7 +287,8 @@ impl Sqlite { let count = sqlite3_column_count(stmt); for i in 0..count { let name_ptr = sqlite3_column_name(stmt, i); - if !name_ptr.is_null() { + + if !name_ptr.is_null() { let name = CStr::from_ptr(name_ptr).to_string_lossy().into_owned(); columns.push(name); } diff --git a/static_sqlite_macros/src/lib.rs b/static_sqlite_macros/src/lib.rs index e24d8d9..e02b715 100644 --- a/static_sqlite_macros/src/lib.rs +++ b/static_sqlite_macros/src/lib.rs @@ -411,7 +411,6 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result Result Result> { let rows: Vec<#pascal_case> = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; Ok(rows) From ff00577a53f2bed13499743bd422a85a1d0cd640 Mon Sep 17 00:00:00 2001 From: pm Date: Sat, 22 Mar 2025 17:58:37 +0100 Subject: [PATCH 05/14] fix readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 598edfe..13c9242 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ or If not explicitly specified, the parameter or column is assumed to be NOT NULL. +```rust sql! { let migrate = r#" create table User ( From a99d4b2b75bc2ccba576eae9c93615ef477e3d70 Mon Sep 17 00:00:00 2001 From: pm Date: Sun, 30 Mar 2025 22:12:38 +0200 Subject: [PATCH 06/14] add support for Stream> instead of Result> as output for queries. --- Cargo.lock | 135 ++++++++++++++++++++ Cargo.toml | 1 + README.md | 41 ++++++ src/lib.rs | 2 +- static_sqlite_async/Cargo.toml | 2 + static_sqlite_async/src/lib.rs | 34 ++++- static_sqlite_core/Cargo.toml | 1 - static_sqlite_core/src/ffi.rs | 219 +++++++++++++++++++++++--------- static_sqlite_macros/src/lib.rs | 8 +- tests/integration_test.rs | 91 +++++++++---- 10 files changed, 443 insertions(+), 91 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 82f750f..9e656b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,34 @@ dependencies = [ "memchr", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + [[package]] name = "backtrace" version = "0.3.74" @@ -120,6 +148,95 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "gimli" version = "0.31.0" @@ -231,6 +348,12 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "prettyplease" version = "0.2.25" @@ -353,6 +476,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = "sqlparser" version = "0.52.0" @@ -378,6 +510,7 @@ dependencies = [ name = "static_sqlite" version = "0.1.0" dependencies = [ + "futures", "static_sqlite_async", "static_sqlite_core", "static_sqlite_macros", @@ -389,7 +522,9 @@ dependencies = [ name = "static_sqlite_async" version = "0.1.0" dependencies = [ + "async-stream", "crossbeam-channel", + "futures", "static_sqlite_core", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 1e72bd0..92214a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ resolver = "2" static_sqlite_macros = { path = "static_sqlite_macros", version = "0.1.0" } static_sqlite_core = { path = "static_sqlite_core", version = "0.1.0" } static_sqlite_async = { path = "static_sqlite_async", version = "0.1.0" } +futures = { version = "0.3" } [dev-dependencies] tokio = { version = "1", features = ["rt", "sync", "macros"] } diff --git a/README.md b/README.md index 13c9242..e72ddbf 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,47 @@ async fn main() -> Result<()> { cargo add --git https://github.com/swlkr/static_sqlite ``` +# Example for Streams + +If you don't want to read the whole result set into memory, you can get the result +as a futures::Stream over items of the derived type. The fn with the postfix _stream is automatically +created. + +``` + sql! { + let migrate = r#" + create table Row ( + txt text + ) + "#; + + let insert_row = r#" + insert into Row (txt) values (:txt) returning * + "#; + + let select_rows = r#" + select * from Row + "#; + } + + let db = static_sqlite::open(":memory:").await?; + migrate(&db).await?; + + insert_row(&db, Some("test1")).await?.first_row()?; + insert_row(&db, Some("test2")).await?.first_row()?; + insert_row(&db, Some("test3")).await?.first_row()?; + insert_row(&db, Some("test4")).await?.first_row()?; + + let f = select_rows_stream(&db).await?; + + pin_mut!(f); + + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test1".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test2".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test3".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test4".into())); + +``` # Example with aliased columns and type-hints diff --git a/src/lib.rs b/src/lib.rs index 1b2404a..50535af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ extern crate self as static_sqlite; pub use static_sqlite_async::{ - execute, execute_all, open, query, rows, Error, FromRow, Result, Savepoint, Sqlite, Value, + execute, execute_all, iter, open, query, rows, Error, FromRow, Result, Savepoint, Sqlite, Value, }; pub use static_sqlite_core::FirstRow; pub use static_sqlite_macros::sql; diff --git a/static_sqlite_async/Cargo.toml b/static_sqlite_async/Cargo.toml index 8b22739..04d61bb 100644 --- a/static_sqlite_async/Cargo.toml +++ b/static_sqlite_async/Cargo.toml @@ -7,3 +7,5 @@ edition = "2021" static_sqlite_core = { path = "../static_sqlite_core", version = "0.1.0" } tokio = { version = "1", features = ["sync"] } crossbeam-channel = { version = "0.5" } +futures = { version = "0.3" } +async-stream = "0.3" diff --git a/static_sqlite_async/src/lib.rs b/static_sqlite_async/src/lib.rs index f9597b0..35cd9cf 100644 --- a/static_sqlite_async/src/lib.rs +++ b/static_sqlite_async/src/lib.rs @@ -1,8 +1,9 @@ // Inspired by the incredible tokio-rusqlite crate // https://github.com/programatik29/tokio-rusqlite/blob/master/src/lib.rs -use static_sqlite_core as core; use crossbeam_channel::Sender; +pub use futures::Stream; +use static_sqlite_core as core; use tokio::sync::oneshot; pub use static_sqlite_core::*; @@ -33,9 +34,7 @@ impl Sqlite { return Ok(()); } - result - .unwrap() - .map_err(|e| Error::Sqlite(e.to_string())) + result.unwrap().map_err(|e| Error::Sqlite(e.to_string())) } pub async fn call(&self, function: F) -> Result @@ -126,6 +125,33 @@ pub async fn query( conn.call(move |conn| conn.query(sql, ¶ms)).await } +pub async fn iter( + conn: &Sqlite, + sql: &'static str, + params: Vec, +) -> Result>> { + let (sender, receiver) = std::sync::mpsc::channel(); + + conn.sender + .send(Message::Execute(Box::new(move |conn| { + let value = conn.iter(sql, ¶ms).unwrap(); + + for item in value { + let res = sender.send(item); + if res.is_err() { + break; + } + } + }))) + .map_err(|_| Error::ConnectionClosed)?; + + Ok(async_stream::stream! { + for item in receiver { + yield item; + } + }) +} + pub async fn rows( conn: Sqlite, sql: &'static str, diff --git a/static_sqlite_core/Cargo.toml b/static_sqlite_core/Cargo.toml index c8eff2c..be2367c 100644 --- a/static_sqlite_core/Cargo.toml +++ b/static_sqlite_core/Cargo.toml @@ -6,4 +6,3 @@ edition = "2021" [dependencies] thiserror = "1" static_sqlite_ffi = { path = "../static_sqlite_ffi" } - diff --git a/static_sqlite_core/src/ffi.rs b/static_sqlite_core/src/ffi.rs index e17cff3..bd7c50c 100644 --- a/static_sqlite_core/src/ffi.rs +++ b/static_sqlite_core/src/ffi.rs @@ -9,6 +9,7 @@ use static_sqlite_ffi::{ use std::{ ffi::{c_char, c_int, CStr, CString, NulError}, + marker::PhantomData, num::TryFromIntError, ops::Deref, str::Utf8Error, @@ -16,6 +17,7 @@ use std::{ const SQLITE_ROW: i32 = static_sqlite_ffi::SQLITE_ROW as i32; const SQLITE_DONE: i32 = static_sqlite_ffi::SQLITE_DONE as i32; +const SQLITE_NULL: i32 = static_sqlite_ffi::SQLITE_NULL as i32; #[derive(thiserror::Error, Debug)] pub enum Error { @@ -160,25 +162,7 @@ impl Sqlite { .to_string_lossy() .into_owned(); - let value = match sqlite3_column_type(stmt, i) { - 1 => Value::Integer(sqlite3_column_int64(stmt, i)), - 2 => Value::Real(sqlite3_column_double(stmt, i)), - 3 => { - let text = - CStr::from_ptr(sqlite3_column_text(stmt, i) as *const c_char) - .to_string_lossy() - .into_owned(); - Value::Text(text) - } - 4 => { - let len = sqlite3_column_bytes(stmt, i) as usize; - let ptr = sqlite3_column_text(stmt, i); - let slice = std::slice::from_raw_parts(ptr, len); - Value::Blob(slice.to_vec()) - } - _ => Value::Null, - }; - + let value = Self::get_column_value(stmt, i)?; values.push((name, value)); } @@ -186,23 +170,84 @@ impl Sqlite { rows.push(row); } - if sqlite3_finalize(stmt) != 0 { - let error = CStr::from_ptr(sqlite3_errmsg(self.db)) - .to_string_lossy() - .into_owned(); - if error.starts_with("UNIQUE constraint failed: ") { - return Err(Error::UniqueConstraint( - error.replace("UNIQUE constraint failed: ", ""), + Self::finalize_statement(self.db, stmt)?; + + Ok(rows) + } + } + + unsafe fn get_column_value(stmt: *mut sqlite3_stmt, i: c_int) -> Result { + match sqlite3_column_type(stmt, i) { + x if x == static_sqlite_ffi::SQLITE_INTEGER as i32 => { + Ok(Value::Integer(sqlite3_column_int64(stmt, i))) + } + x if x == static_sqlite_ffi::SQLITE_FLOAT as i32 => { + Ok(Value::Real(sqlite3_column_double(stmt, i))) + } + x if x == static_sqlite_ffi::SQLITE_TEXT as i32 => { + let text_ptr = sqlite3_column_text(stmt, i) as *const c_char; + if text_ptr.is_null() { + Ok(Value::Text(String::new())) + } else { + let text = CStr::from_ptr(text_ptr).to_str()?.to_owned(); + Ok(Value::Text(text)) + } + } + x if x == static_sqlite_ffi::SQLITE_BLOB as i32 => { + let len = sqlite3_column_bytes(stmt, i); + if len < 0 { + return Err(Error::Sqlite( + "SQLite returned negative length for BLOB column".into(), )); + } + let len = len as usize; + let ptr = static_sqlite_ffi::sqlite3_column_blob(stmt, i); + if ptr.is_null() { + if len == 0 { + Ok(Value::Blob(vec![])) + } else { + Err(Error::Sqlite("SQLite returned null pointer for non-empty BLOB column (likely out of memory)".into())) + } } else { - return Err(Error::Sqlite(error)); + let slice = std::slice::from_raw_parts(ptr as *const u8, len); + Ok(Value::Blob(slice.to_vec())) } } + x if x == static_sqlite_ffi::SQLITE_NULL as i32 => Ok(Value::Null), + _ => Err(Error::Sqlite(format!( + "Unexpected column type {}", + sqlite3_column_type(stmt, i) + ))), + } + } - Ok(rows) + unsafe fn finalize_statement(db: *mut sqlite3, stmt: *mut sqlite3_stmt) -> Result<()> { + let rc = sqlite3_finalize(stmt); + if rc != static_sqlite_ffi::SQLITE_OK as i32 { + let error = CStr::from_ptr(sqlite3_errmsg(db)) + .to_string_lossy() + .into_owned(); + if error.starts_with("UNIQUE constraint failed: ") { + Err(Error::UniqueConstraint( + error.replace("UNIQUE constraint failed: ", ""), + )) + } else { + Err(Error::Sqlite(error)) + } + } else { + Ok(()) } } + pub fn iter<'a, T: FromRow + 'a>( + &'a self, + sql: &str, + params: &[Value], + ) -> Result> + 'a> { + let stmt = self.prepare(sql, params)?; + Ok(SqliteIterator::new(self, stmt)) + } + pub fn rows(&self, sql: &str, params: &[Value]) -> Result>> { unsafe { let stmt = self.prepare(sql, params)?; @@ -216,43 +261,14 @@ impl Sqlite { .to_string_lossy() .into_owned(); - let value = match sqlite3_column_type(stmt, i) { - 1 => Value::Integer(sqlite3_column_int64(stmt, i)), - 2 => Value::Real(sqlite3_column_double(stmt, i)), - 3 => { - let text = - CStr::from_ptr(sqlite3_column_text(stmt, i) as *const c_char) - .to_string_lossy() - .into_owned(); - Value::Text(text) - } - 4 => { - let len = sqlite3_column_bytes(stmt, i) as usize; - let ptr = sqlite3_column_text(stmt, i); - let slice = std::slice::from_raw_parts(ptr, len); - Value::Blob(slice.to_vec()) - } - _ => Value::Null, - }; - + let value = Self::get_column_value(stmt, i)?; values.push((name, value)); } rows.push(values); } - if sqlite3_finalize(stmt) != 0 { - let error = CStr::from_ptr(sqlite3_errmsg(self.db)) - .to_string_lossy() - .into_owned(); - if error.starts_with("UNIQUE constraint failed: ") { - return Err(Error::UniqueConstraint( - error.replace("UNIQUE constraint failed: ", ""), - )); - } else { - return Err(Error::Sqlite(error)); - } - } + Self::finalize_statement(self.db, stmt)?; Ok(rows) } @@ -581,3 +597,88 @@ impl From<()> for Value { Value::Null } } + +#[derive(Debug)] +pub struct SqliteIterator<'a, T: FromRow> { + db: &'a Sqlite, + stmt: *mut sqlite3_stmt, + finished: bool, + _marker: PhantomData, +} + +impl<'a, T: FromRow> SqliteIterator<'a, T> { + fn new(db: &'a Sqlite, stmt: *mut sqlite3_stmt) -> Self { + SqliteIterator { + db, + stmt, + finished: false, + _marker: PhantomData, + } + } +} + +impl<'a, T: FromRow> Iterator for SqliteIterator<'a, T> { + type Item = Result; + + fn next(&mut self) -> Option { + if self.finished { + return None; + } + + unsafe { + match sqlite3_step(self.stmt) { + SQLITE_ROW => { + let column_count = sqlite3_column_count(self.stmt); + let mut values: Vec<(String, Value)> = vec![]; + + for i in 0..column_count { + let name_ptr = sqlite3_column_name(self.stmt, i); + let name = if name_ptr.is_null() { + format!("column_{}", i) + } else { + match CStr::from_ptr(name_ptr).to_str() { + Ok(s) => s.to_owned(), + Err(e) => return Some(Err(e.into())), + } + }; + + match Sqlite::get_column_value(self.stmt, i) { + Ok(value) => values.push((name, value)), + Err(e) => { + self.finished = true; + return Some(Err(e)); + } + } + } + + match T::from_row(values) { + Ok(row) => Some(Ok(row)), + Err(e) => { + self.finished = true; + Some(Err(e)) + } + } + } + SQLITE_DONE => { + self.finished = true; + None + } + _ => { + self.finished = true; + let error = CStr::from_ptr(sqlite3_errmsg(self.db.db)) + .to_string_lossy() + .into_owned(); + Some(Err(Error::Sqlite(error))) + } + } + } + } +} + +impl<'a, T: FromRow> Drop for SqliteIterator<'a, T> { + fn drop(&mut self) { + unsafe { + let _ = Sqlite::finalize_statement(self.db.db, self.stmt); + } + } +} diff --git a/static_sqlite_macros/src/lib.rs b/static_sqlite_macros/src/lib.rs index e02b715..ce052a0 100644 --- a/static_sqlite_macros/src/lib.rs +++ b/static_sqlite_macros/src/lib.rs @@ -463,6 +463,7 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result>(); let ident = &expr.ident; + let ident_stream = Ident::new(&format!("{}_stream", ident), expr.ident.span()); let outputs = output_column_names(db, expr)?; let pascal_case = snake_to_pascal_case(&ident); @@ -470,7 +471,6 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result Result = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; Ok(rows) } + + #[doc = #sql] + #[allow(non_snake_case)] + pub async fn #ident_stream(db: &static_sqlite::Sqlite, #(#fn_args),*) -> Result>> { + static_sqlite::iter(db, #sql, vec![#(#params,)*]).await + } }) } Ok(output) diff --git a/tests/integration_test.rs b/tests/integration_test.rs index def1b84..fb5ee7a 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,5 +1,6 @@ +use futures::pin_mut; +use futures::StreamExt; use static_sqlite::{sql, FirstRow, Result, Sqlite}; - #[tokio::test] async fn option_type_works() -> Result<()> { sql! { @@ -24,6 +25,44 @@ async fn option_type_works() -> Result<()> { Ok(()) } +#[tokio::test] +async fn stream_works() -> Result<()> { + sql! { + let migrate = r#" + create table Row ( + txt text + ) + "#; + + let insert_row = r#" + insert into Row (txt) values (:txt) returning * + "#; + + let select_rows = r#" + select * from Row + "#; + } + + let db = static_sqlite::open(":memory:").await?; + migrate(&db).await?; + + insert_row(&db, Some("test1")).await?.first_row()?; + insert_row(&db, Some("test2")).await?.first_row()?; + insert_row(&db, Some("test3")).await?.first_row()?; + insert_row(&db, Some("test4")).await?.first_row()?; + + let f = select_rows_stream(&db).await?; + + pin_mut!(f); + + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test1".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test2".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test3".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test4".into())); + + Ok(()) +} + #[tokio::test] async fn it_works() -> Result<()> { sql! { @@ -118,7 +157,8 @@ async fn it_works() -> Result<()> { Some(2.0), Some(vec![0xFE, 0xED]), ) - .await?.first_row()?; + .await? + .first_row()?; assert_eq!( row, @@ -162,6 +202,7 @@ async fn readme_works() -> Result<()> { values (:name) returning * "#; + } let db = static_sqlite::open(":memory:").await?; @@ -257,24 +298,29 @@ async fn parameters_that_are_not_in_the_schema_work() -> Result<()> { let db = static_sqlite::open(":memory:").await?; let _ = migrate(&db).await?; let user1 = insert_user(&db, "user1").await?.first_row()?; - insert_post(&db, user1.id, "user 1 - post1").await?.first_row()?; - insert_post(&db, user1.id, "user 1 - post2").await?.first_row()?; + insert_post(&db, user1.id, "user 1 - post1") + .await? + .first_row()?; + insert_post(&db, user1.id, "user 1 - post2") + .await? + .first_row()?; let user2 = insert_user(&db, "user2").await?.first_row()?; - insert_post(&db, user2.id, "user 2 - post1").await?.first_row()?; - insert_post(&db, user2.id, "user 2 - post2").await?.first_row()?; + insert_post(&db, user2.id, "user 2 - post1") + .await? + .first_row()?; + insert_post(&db, user2.id, "user 2 - post2") + .await? + .first_row()?; let posts = select_posts(&db, 1, 2, "Hello", "sdd").await?; println!("{:?}", posts); - Ok(()) } - - #[tokio::test] async fn example_friendshipworks() -> Result<()> { - use static_sqlite::{sql, Result, self}; + use static_sqlite::{self, sql, Result}; sql! { let migrate = r#" @@ -324,15 +370,12 @@ async fn example_friendshipworks() -> Result<()> { assert_eq!(friends.first().unwrap().friend2_name, "toolbar23"); Ok(()) - } - - #[tokio::test] -async fn duplicate_column_names_in_one_quer2y_work() -> Result<()> { -sql! { - let migrate = r#" +async fn duplicate_column_names_in_one_query_work() -> Result<()> { + sql! { + let migrate = r#" CREATE TABLE IF NOT EXISTS Identifiers ( id INTEGER PRIMARY KEY AUTOINCREMENT, entity_type INTEGER NOT NULL, @@ -350,7 +393,7 @@ sql! { ); "#; - let get_changes = r#" + let get_changes = r#" SELECT mc.id, mc.timestamp, @@ -368,18 +411,16 @@ sql! { AND mc.timestamp > :since__INTEGER ORDER BY mc.timestamp ASC "#; -} + } -let db = static_sqlite::open(":memory:").await?; -let _ = migrate(&db).await?; -let changes = get_changes(&db, 1).await?; -println!("{:?}", changes); + let db = static_sqlite::open(":memory:").await?; + let _ = migrate(&db).await?; + let changes = get_changes(&db, 1).await?; + println!("{:?}", changes); -Ok(()) + Ok(()) } - - #[test] fn ui() { let t = trybuild::TestCases::new(); From 3fe5498d29981da94761e97e22ac77ecdc14a6b3 Mon Sep 17 00:00:00 2001 From: pm Date: Sun, 30 Mar 2025 22:25:07 +0200 Subject: [PATCH 07/14] clean imports, fix Result-mixup --- Cargo.lock | 7 +++ Cargo.toml | 2 +- src/lib.rs | 3 +- static_sqlite_async/src/lib.rs | 4 +- static_sqlite_core/src/ffi.rs | 1 - static_sqlite_macros/src/lib.rs | 90 +++++++++++++++++++++------------ tests/integration_test.rs | 4 +- 7 files changed, 71 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9e656b5..5effb29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "anyhow" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" + [[package]] name = "async-stream" version = "0.3.6" @@ -510,6 +516,7 @@ dependencies = [ name = "static_sqlite" version = "0.1.0" dependencies = [ + "anyhow", "futures", "static_sqlite_async", "static_sqlite_core", diff --git a/Cargo.toml b/Cargo.toml index 92214a0..ee11665 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ static_sqlite_macros = { path = "static_sqlite_macros", version = "0.1.0" } static_sqlite_core = { path = "static_sqlite_core", version = "0.1.0" } static_sqlite_async = { path = "static_sqlite_async", version = "0.1.0" } futures = { version = "0.3" } - +anyhow = "1.0.97" [dev-dependencies] tokio = { version = "1", features = ["rt", "sync", "macros"] } trybuild = "1.0" diff --git a/src/lib.rs b/src/lib.rs index 50535af..a7182bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ extern crate self as static_sqlite; pub use static_sqlite_async::{ - execute, execute_all, iter, open, query, rows, Error, FromRow, Result, Savepoint, Sqlite, Value, + execute, execute_all, open, query, rows, stream, Error, FromRow, Result, Savepoint, Sqlite, + Value, }; pub use static_sqlite_core::FirstRow; pub use static_sqlite_macros::sql; diff --git a/static_sqlite_async/src/lib.rs b/static_sqlite_async/src/lib.rs index 35cd9cf..4a887ef 100644 --- a/static_sqlite_async/src/lib.rs +++ b/static_sqlite_async/src/lib.rs @@ -125,11 +125,11 @@ pub async fn query( conn.call(move |conn| conn.query(sql, ¶ms)).await } -pub async fn iter( +pub async fn stream( conn: &Sqlite, sql: &'static str, params: Vec, -) -> Result>> { +) -> Result>> { let (sender, receiver) = std::sync::mpsc::channel(); conn.sender diff --git a/static_sqlite_core/src/ffi.rs b/static_sqlite_core/src/ffi.rs index bd7c50c..8443479 100644 --- a/static_sqlite_core/src/ffi.rs +++ b/static_sqlite_core/src/ffi.rs @@ -17,7 +17,6 @@ use std::{ const SQLITE_ROW: i32 = static_sqlite_ffi::SQLITE_ROW as i32; const SQLITE_DONE: i32 = static_sqlite_ffi::SQLITE_DONE as i32; -const SQLITE_NULL: i32 = static_sqlite_ffi::SQLITE_NULL as i32; #[derive(thiserror::Error, Debug)] pub enum Error { diff --git a/static_sqlite_macros/src/lib.rs b/static_sqlite_macros/src/lib.rs index ce052a0..5ed8096 100644 --- a/static_sqlite_macros/src/lib.rs +++ b/static_sqlite_macros/src/lib.rs @@ -380,7 +380,7 @@ fn migrate_fn(expr: &SqlExpr) -> TokenStream { let SqlExpr { ident, sql, .. } = expr; quote! { - pub async fn #ident(sqlite: &static_sqlite::Sqlite) -> Result<()> { + pub async fn #ident(sqlite: &static_sqlite::Sqlite) -> static_sqlite::Result<()> { let sql = #sql.to_string(); let _ = static_sqlite::execute_all(&sqlite, "create table if not exists __migrations__ (sql text primary key not null);".into()).await?; for stmt in sql.split(";").filter(|s| !s.trim().is_empty()) { @@ -427,15 +427,17 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result { let field_name = Ident::new(&type_hint.alias, expr.ident.span()); - let field_type = create_fn_argument_type(&type_hint.alias, &type_hint.column_type); + let field_type = + create_fn_argument_type(&type_hint.alias, &type_hint.column_type); match type_hint.not_null { 0 => quote! { #field_name: Option<#field_type> }, _ => quote! { #field_name: #field_type }, } - }, + } TypedToken::FromSchemaRow(schema_row) => { let field_name = Ident::new(&schema_row.column_name, expr.ident.span()); - let field_type = create_fn_argument_type(aliases_column_name, &schema_row.column_type); + let field_type = + create_fn_argument_type(aliases_column_name, &schema_row.column_type); match (schema_row.pk, schema_row.not_null) { (0, 0) => quote! { #field_name: Option<#field_type> }, _ => quote! { #field_name: #field_type }, @@ -452,13 +454,16 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result { let field_name = Ident::new(&type_hint.alias, expr.ident.span()); create_binding_value(&type_hint.column_type, type_hint.not_null, field_name) - }, + } TypedToken::FromSchemaRow(schema_row) => { let field_name = Ident::new(&schema_row.column_name, expr.ident.span()); - create_binding_value(&schema_row.column_type, schema_row.not_null, field_name) + create_binding_value( + &schema_row.column_type, + schema_row.not_null, + field_name, + ) } } - }) .collect::>(); @@ -467,7 +472,10 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result>(); + let output_typed = outputs + .iter() + .map(|output| parse_type_hinted_column_name(output, &schema_rows)) + .collect::>(); let struct_tokens = struct_tokens(expr.ident.span(), &pascal_case, &output_typed); @@ -477,15 +485,15 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result Result> { + pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result> { let rows: Vec<#pascal_case> = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; Ok(rows) } #[doc = #sql] #[allow(non_snake_case)] - pub async fn #ident_stream(db: &static_sqlite::Sqlite, #(#fn_args),*) -> Result>> { - static_sqlite::iter(db, #sql, vec![#(#params,)*]).await + pub async fn #ident_stream(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result>> { + static_sqlite::stream(db, #sql, vec![#(#params,)*]).await } }) } @@ -498,7 +506,11 @@ fn create_fn_argument_type(fieldname: &String, column_type: &str) -> TokenStream "INTEGER" => quote! { i64 }, "REAL" | "DOUBLE" => quote! { f64 }, "TEXT" => quote! { impl ToString }, - _ => unimplemented!("type {:?} not supported for fn arg {:?}", column_type, fieldname), + _ => unimplemented!( + "type {:?} not supported for fn arg {:?}", + column_type, + fieldname + ), } } @@ -535,7 +547,10 @@ struct TypeHintedToken { } #[derive(Debug, Clone)] -enum TypedToken { FromTypeHint(TypeHintedToken), FromSchemaRow(SchemaRow) } +enum TypedToken { + FromTypeHint(TypeHintedToken), + FromSchemaRow(SchemaRow), +} /* * Parses a type hint and returns a TypedColumnOrParameter @@ -609,7 +624,8 @@ fn structs_tokens(span: Span, schema: &Schema) -> Vec { .iter() .map(|(table, cols)| { let ident = proc_macro2::Ident::new(&table, span); - let typed_tokens: Vec = cols.iter() + let typed_tokens: Vec = cols + .iter() .map(|col| TypedToken::FromSchemaRow(col.clone())) .collect(); struct_tokens(span, &ident, &typed_tokens) @@ -617,24 +633,28 @@ fn structs_tokens(span: Span, schema: &Schema) -> Vec { .collect() } - fn struct_tokens(span: Span, ident: &Ident, output_typed: &[TypedToken]) -> TokenStream { let struct_fields = output_typed.iter().map(|row| { let field_type = match row { - TypedToken::FromTypeHint(type_hint) => field_type_from_datatype_name(&type_hint.column_type), + TypedToken::FromTypeHint(type_hint) => { + field_type_from_datatype_name(&type_hint.column_type) + } TypedToken::FromSchemaRow(schema_row) => field_type(schema_row), }; let name = match row { TypedToken::FromTypeHint(type_hint) => Ident::new(&type_hint.name, span), TypedToken::FromSchemaRow(schema_row) => Ident::new(&schema_row.column_name, span), }; - let optional = match ( match row { - TypedToken::FromTypeHint(type_hint) => type_hint.not_null, - TypedToken::FromSchemaRow(schema_row) => schema_row.not_null, - }, match row { - TypedToken::FromTypeHint(_) => 0, - TypedToken::FromSchemaRow(schema_row) => schema_row.pk, - }) { + let optional = match ( + match row { + TypedToken::FromTypeHint(type_hint) => type_hint.not_null, + TypedToken::FromSchemaRow(schema_row) => schema_row.not_null, + }, + match row { + TypedToken::FromTypeHint(_) => 0, + TypedToken::FromSchemaRow(schema_row) => schema_row.pk, + }, + ) { (0, 0) => true, (0, 1) | (1, 0) | (1, 1) => false, _ => unreachable!(), @@ -646,14 +666,20 @@ fn struct_tokens(span: Span, ident: &Ident, output_typed: &[TypedToken]) -> Toke } }); let match_stmt = output_typed.iter().map(|row| { - let name = Ident::new(match row { - TypedToken::FromTypeHint(type_hint) => &type_hint.name, - TypedToken::FromSchemaRow(schema_row) => &schema_row.column_name, - }, span); - let lit_str = LitStr::new(match row { - TypedToken::FromTypeHint(type_hint) => &type_hint.alias, - TypedToken::FromSchemaRow(schema_row) => &schema_row.column_name, - }, span); + let name = Ident::new( + match row { + TypedToken::FromTypeHint(type_hint) => &type_hint.name, + TypedToken::FromSchemaRow(schema_row) => &schema_row.column_name, + }, + span, + ); + let lit_str = LitStr::new( + match row { + TypedToken::FromTypeHint(type_hint) => &type_hint.alias, + TypedToken::FromSchemaRow(schema_row) => &schema_row.column_name, + }, + span, + ); quote! { #lit_str => row.#name = value.try_into()? @@ -681,12 +707,10 @@ fn struct_tokens(span: Span, ident: &Ident, output_typed: &[TypedToken]) -> Toke tokens } - fn field_type(row: &SchemaRow) -> TokenStream { field_type_from_datatype_name(&row.column_type) } - fn field_type_from_datatype_name(datatype_name: &str) -> TokenStream { match datatype_name { "BLOB" => quote! { Vec }, diff --git a/tests/integration_test.rs b/tests/integration_test.rs index fb5ee7a..65d08cc 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -13,7 +13,7 @@ async fn option_type_works() -> Result<()> { let insert_row = r#" insert into Row (txt) values (:txt) returning * "#; - } + }; let db = static_sqlite::open(":memory:").await?; let _k = migrate(&db).await?; @@ -320,7 +320,7 @@ async fn parameters_that_are_not_in_the_schema_work() -> Result<()> { #[tokio::test] async fn example_friendshipworks() -> Result<()> { - use static_sqlite::{self, sql, Result}; + use static_sqlite::{self, sql}; sql! { let migrate = r#" From 067da4a795aed29f8df6a03c73db8d4baaff65d2 Mon Sep 17 00:00:00 2001 From: pm Date: Sun, 30 Mar 2025 23:38:01 +0200 Subject: [PATCH 08/14] add query_first --- README.md | 36 +++++++++++++++++++++++++++++ src/lib.rs | 4 ++-- static_sqlite_async/src/lib.rs | 8 +++++++ static_sqlite_core/src/ffi.rs | 17 ++++++++++++++ static_sqlite_core/src/lib.rs | 8 +++++++ static_sqlite_macros/src/lib.rs | 8 +++++++ tests/integration_test.rs | 40 +++++++++++++++++++++++++++++++++ 7 files changed, 119 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e72ddbf..825eb05 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,42 @@ async fn main() -> Result<()> { cargo add --git https://github.com/swlkr/static_sqlite ``` + +# Example for First + +With Sqlite you often do small queries that just return on row. For this a fn with the postfix _first +is automatically created. + +``` + sql! { + let migrate = r#" + create table Row ( + id integer primary key autoincrement, + txt text NOT NULL + ) + "#; + + let insert_row = r#" + insert into Row (txt) values (:txt) returning * + "#; + + let select_row = r#" + select * from Row where id = :id + "#; + } + + let db = static_sqlite::open(":memory:").await?; + migrate(&db).await?; + + insert_row(&db, "test1").await?.first_row()?; + insert_row(&db, "test2").await?.first_row()?; + + match select_row_first(&db, 1).await? { + Some(row) => assert_eq!(row.txt, "test1"), + None => panic!("Row 1 not found"), + } +``` + # Example for Streams If you don't want to read the whole result set into memory, you can get the result diff --git a/src/lib.rs b/src/lib.rs index a7182bf..76ddb83 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ extern crate self as static_sqlite; pub use static_sqlite_async::{ - execute, execute_all, open, query, rows, stream, Error, FromRow, Result, Savepoint, Sqlite, - Value, + execute, execute_all, open, query, query_first, rows, stream, Error, FromRow, Result, + Savepoint, Sqlite, Value, }; pub use static_sqlite_core::FirstRow; pub use static_sqlite_macros::sql; diff --git a/static_sqlite_async/src/lib.rs b/static_sqlite_async/src/lib.rs index 4a887ef..d6e69c7 100644 --- a/static_sqlite_async/src/lib.rs +++ b/static_sqlite_async/src/lib.rs @@ -125,6 +125,14 @@ pub async fn query( conn.call(move |conn| conn.query(sql, ¶ms)).await } +pub async fn query_first( + conn: &Sqlite, + sql: &'static str, + params: Vec, +) -> Result> { + conn.call(move |conn| conn.query_first(sql, ¶ms)).await +} + pub async fn stream( conn: &Sqlite, sql: &'static str, diff --git a/static_sqlite_core/src/ffi.rs b/static_sqlite_core/src/ffi.rs index 8443479..dd7ea34 100644 --- a/static_sqlite_core/src/ffi.rs +++ b/static_sqlite_core/src/ffi.rs @@ -34,6 +34,8 @@ pub enum Error { ConnectionClosed, #[error("sqlite row not found")] RowNotFound, + #[error("sqlite returned too many rows in result")] + TooManyRowsInResult, #[error(transparent)] Utf8Error(#[from] Utf8Error), } @@ -175,6 +177,21 @@ impl Sqlite { } } + pub fn query_first( + &self, + sql: &'static str, + params: &[Value], + ) -> Result> { + match self.query(sql, params) { + Ok(rows) => Ok(match rows.len() { + 0 => None, + 1 => Some(rows.into_iter().nth(0).unwrap()), + _ => return Err(Error::RowNotFound), + }), + Err(e) => Err(e), + } + } + unsafe fn get_column_value(stmt: *mut sqlite3_stmt, i: c_int) -> Result { match sqlite3_column_type(stmt, i) { x if x == static_sqlite_ffi::SQLITE_INTEGER as i32 => { diff --git a/static_sqlite_core/src/lib.rs b/static_sqlite_core/src/lib.rs index b5ce9e7..850008d 100644 --- a/static_sqlite_core/src/lib.rs +++ b/static_sqlite_core/src/lib.rs @@ -21,6 +21,14 @@ pub fn query( conn.query(sql, params) } +pub fn query_first( + conn: &Sqlite, + sql: &'static str, + params: &[Value], +) -> Result> { + conn.query_first(sql, params) +} + pub fn rows(conn: &Sqlite, sql: &str, params: &[Value]) -> Result>> { conn.rows(sql, params) } diff --git a/static_sqlite_macros/src/lib.rs b/static_sqlite_macros/src/lib.rs index 5ed8096..b3d4d19 100644 --- a/static_sqlite_macros/src/lib.rs +++ b/static_sqlite_macros/src/lib.rs @@ -469,6 +469,7 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result Result static_sqlite::Result>> { static_sqlite::stream(db, #sql, vec![#(#params,)*]).await } + + + #[doc = #sql] + #[allow(non_snake_case)] + pub async fn #ident_first(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result> { + static_sqlite::query_first(db, #sql, vec![#(#params,)*]).await + } }) } Ok(output) diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 65d08cc..0284317 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -63,6 +63,46 @@ async fn stream_works() -> Result<()> { Ok(()) } +#[tokio::test] +async fn query_first_works() -> Result<()> { + sql! { + let migrate = r#" + create table Row ( + id integer primary key autoincrement, + txt text NOT NULL + ) + "#; + + let insert_row = r#" + insert into Row (txt) values (:txt) returning * + "#; + + let select_row = r#" + select * from Row where id = :id + "#; + } + + let db = static_sqlite::open(":memory:").await?; + migrate(&db).await?; + + insert_row(&db, "test1").await?.first_row()?; + insert_row(&db, "test2").await?.first_row()?; + insert_row(&db, "test3").await?.first_row()?; + insert_row(&db, "test4").await?.first_row()?; + + match select_row_first(&db, 1).await? { + Some(row) => assert_eq!(row.txt, "test1"), + None => panic!("Row 1 not found"), + } + + match select_row_first(&db, 2).await? { + Some(row) => assert_eq!(row.txt, "test2"), + None => panic!("Row 2 not found"), + } + + Ok(()) +} + #[tokio::test] async fn it_works() -> Result<()> { sql! { From 01973b9823ce217fa34a4203a6c677c2301ec577 Mon Sep 17 00:00:00 2001 From: pm Date: Mon, 31 Mar 2025 08:49:03 +0200 Subject: [PATCH 09/14] only generate one version for each statement, postfix _stream / _first used to select the variants --- README.md | 20 ++++++------ static_sqlite_macros/src/lib.rs | 57 +++++++++++++++++++++------------ tests/integration_test.rs | 16 +++++---- 3 files changed, 57 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 825eb05..fb58c6d 100644 --- a/README.md +++ b/README.md @@ -51,10 +51,11 @@ cargo add --git https://github.com/swlkr/static_sqlite # Example for First -With Sqlite you often do small queries that just return on row. For this a fn with the postfix _first -is automatically created. +If the name of your statement ends with "_first", the created fn return an Option with the first value instead of a Vec. -``` +I the query returns more than one rows, it throws an error. + +```rust sql! { let migrate = r#" create table Row ( @@ -86,12 +87,12 @@ is automatically created. # Example for Streams -If you don't want to read the whole result set into memory, you can get the result -as a futures::Stream over items of the derived type. The fn with the postfix _stream is automatically -created. +If the name of your statement ends with "_stream", the created fn return an async Stream instead of a Vec. -``` - sql! { +This way you can iterate over large result sets. + +```rust +sql! { let migrate = r#" create table Row ( txt text @@ -102,7 +103,7 @@ created. insert into Row (txt) values (:txt) returning * "#; - let select_rows = r#" + let select_rows_stream = r#" select * from Row "#; } @@ -123,6 +124,7 @@ created. assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test2".into())); assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test3".into())); assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test4".into())); +} ``` diff --git a/static_sqlite_macros/src/lib.rs b/static_sqlite_macros/src/lib.rs index b3d4d19..b06f076 100644 --- a/static_sqlite_macros/src/lib.rs +++ b/static_sqlite_macros/src/lib.rs @@ -395,6 +395,12 @@ fn migrate_fn(expr: &SqlExpr) -> TokenStream { } } +enum FunctionType { + QueryVec, + QueryOption, + Stream, +} + fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result> { let mut output = vec![]; for expr in exprs { @@ -467,9 +473,15 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result>(); + let ident_type = if expr.ident.to_string().ends_with("_stream") { + FunctionType::Stream + } else if expr.ident.to_string().ends_with("_first") { + FunctionType::QueryOption + } else { + FunctionType::QueryVec + }; + let ident = &expr.ident; - let ident_stream = Ident::new(&format!("{}_stream", ident), expr.ident.span()); - let ident_first = Ident::new(&format!("{}_first", ident), expr.ident.span()); let outputs = output_column_names(db, expr)?; let pascal_case = snake_to_pascal_case(&ident); @@ -481,28 +493,33 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result static_sqlite::Result> { - let rows: Vec<#pascal_case> = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; - Ok(rows) - } + let fn_tokens = match ident_type { + FunctionType::QueryVec => quote! { + pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result> { + let rows: Vec<#pascal_case> = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; + Ok(rows) + } + }, + FunctionType::QueryOption => quote! { + pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result> { + static_sqlite::query_first(db, #sql, vec![#(#params,)*]).await + } - #[doc = #sql] - #[allow(non_snake_case)] - pub async fn #ident_stream(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result>> { - static_sqlite::stream(db, #sql, vec![#(#params,)*]).await - } + }, + FunctionType::Stream => quote! { + pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result>> { + static_sqlite::stream(db, #sql, vec![#(#params,)*]).await + } + }, + }; + + output.push(quote! { + #struct_tokens + + #fn_tokens - #[doc = #sql] - #[allow(non_snake_case)] - pub async fn #ident_first(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result> { - static_sqlite::query_first(db, #sql, vec![#(#params,)*]).await - } }) } Ok(output) diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 0284317..5d2bd2a 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -38,7 +38,7 @@ async fn stream_works() -> Result<()> { insert into Row (txt) values (:txt) returning * "#; - let select_rows = r#" + let select_rows_stream = r#" select * from Row "#; } @@ -73,11 +73,11 @@ async fn query_first_works() -> Result<()> { ) "#; - let insert_row = r#" + let insert_row_first = r#" insert into Row (txt) values (:txt) returning * "#; - let select_row = r#" + let select_row_first = r#" select * from Row where id = :id "#; } @@ -85,10 +85,10 @@ async fn query_first_works() -> Result<()> { let db = static_sqlite::open(":memory:").await?; migrate(&db).await?; - insert_row(&db, "test1").await?.first_row()?; - insert_row(&db, "test2").await?.first_row()?; - insert_row(&db, "test3").await?.first_row()?; - insert_row(&db, "test4").await?.first_row()?; + assert_eq!(insert_row_first(&db, "test1").await?.unwrap().txt, "test1"); + assert_eq!(insert_row_first(&db, "test2").await?.unwrap().txt, "test2"); + assert_eq!(insert_row_first(&db, "test3").await?.unwrap().txt, "test3"); + assert_eq!(insert_row_first(&db, "test4").await?.unwrap().txt, "test4"); match select_row_first(&db, 1).await? { Some(row) => assert_eq!(row.txt, "test1"), @@ -395,6 +395,8 @@ async fn example_friendshipworks() -> Result<()> { AND Friendship.friend_id = u2.id AND Friendship.id = :friendship_id__INTEGER "#; + + } let db = static_sqlite::open(":memory:").await?; From ed2e506bc182a35a8f4839aa64668c6dd9a4e13a Mon Sep 17 00:00:00 2001 From: pm Date: Mon, 31 Mar 2025 08:50:12 +0200 Subject: [PATCH 10/14] cleanup --- static_sqlite_macros/src/lib.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/static_sqlite_macros/src/lib.rs b/static_sqlite_macros/src/lib.rs index b06f076..18ab829 100644 --- a/static_sqlite_macros/src/lib.rs +++ b/static_sqlite_macros/src/lib.rs @@ -473,7 +473,7 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result>(); - let ident_type = if expr.ident.to_string().ends_with("_stream") { + let fn_type = if expr.ident.to_string().ends_with("_stream") { FunctionType::Stream } else if expr.ident.to_string().ends_with("_first") { FunctionType::QueryOption @@ -494,7 +494,7 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result quote! { pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result> { let rows: Vec<#pascal_case> = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; @@ -505,8 +505,6 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result static_sqlite::Result> { static_sqlite::query_first(db, #sql, vec![#(#params,)*]).await } - - }, FunctionType::Stream => quote! { pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result>> { From 4fc4e1ed07a85d1d18477f5514e66c899954f51a Mon Sep 17 00:00:00 2001 From: pm Date: Mon, 31 Mar 2025 10:28:25 +0200 Subject: [PATCH 11/14] add transaction support --- static_sqlite_async/src/lib.rs | 12 ++++++++ static_sqlite_core/src/ffi.rs | 21 ++++++++++++++ tests/integration_test.rs | 53 +++++++++++++++++++++++++++++++++- 3 files changed, 85 insertions(+), 1 deletion(-) diff --git a/static_sqlite_async/src/lib.rs b/static_sqlite_async/src/lib.rs index d6e69c7..2fba3fa 100644 --- a/static_sqlite_async/src/lib.rs +++ b/static_sqlite_async/src/lib.rs @@ -53,6 +53,18 @@ impl Sqlite { receiver.await.map_err(|_| Error::ConnectionClosed)? } + + pub async fn begin_transaction(&self) -> Result<()> { + self.call(move |conn| conn.begin_transaction()).await + } + + pub async fn commit_transaction(&self) -> Result<()> { + self.call(move |conn| conn.commit_transaction()).await + } + + pub async fn rollback_transaction(&self) -> Result<()> { + self.call(move |conn| conn.rollback_transaction()).await + } } pub async fn open(path: impl ToString) -> Result { diff --git a/static_sqlite_core/src/ffi.rs b/static_sqlite_core/src/ffi.rs index dd7ea34..9e55565 100644 --- a/static_sqlite_core/src/ffi.rs +++ b/static_sqlite_core/src/ffi.rs @@ -150,6 +150,27 @@ impl Sqlite { self.execute(sql, vec![]) } + pub fn begin_transaction(&self) -> Result<()> { + match self.execute("BEGIN TRANSACTION", vec![]) { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + + pub fn commit_transaction(&self) -> Result<()> { + match self.execute("COMMIT TRANSACTION", vec![]) { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + + pub fn rollback_transaction(&self) -> Result<()> { + match self.execute("ROLLBACK TRANSACTION", vec![]) { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + pub fn query(&self, sql: &'static str, params: &[Value]) -> Result> { unsafe { let stmt = self.prepare(sql, params)?; diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 5d2bd2a..d9fd026 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -255,6 +255,57 @@ async fn readme_works() -> Result<()> { Ok(()) } +#[tokio::test] +async fn transaction_works() -> Result<()> { + sql! { + let migrate = r#" + create table Item ( + id integer primary key + ); + "#; + + let insert_item = r#" + insert into Item (id) + values (:id) + returning * + "#; + + let get_item_first = r#" + select id from Item where id = :id + "#; + + } + let db = static_sqlite::open(":memory:").await?; + let _ = migrate(&db).await?; + + // being; insert; commmit + db.begin_transaction().await?; + insert_item(&db, 1).await?; + let item1 = get_item_first(&db, 1).await?; + assert_eq!(item1.is_some(), true); + db.commit_transaction().await?; + + // begin; insert; rollback + db.begin_transaction().await?; + + insert_item(&db, 2).await?; + let item2_in_transaction = get_item_first(&db, 2).await?; + assert_eq!(item2_in_transaction.is_some(), true); + + db.rollback_transaction().await?; + + let item2_after_rollback = get_item_first(&db, 2).await?; + assert_eq!(item2_after_rollback.is_some(), false); + + // rollback without begin + match db.rollback_transaction().await { + Ok(_) => panic!("should fail because no transaction is in progress"), + Err(_) => (), + } + + Ok(()) +} + #[tokio::test] async fn crud_works() -> Result<()> { sql! { @@ -282,8 +333,8 @@ async fn crud_works() -> Result<()> { let all_users = r#" select id, name from User "#; - } + } let db = static_sqlite::open(":memory:").await?; let _ = migrate(&db).await?; let user = insert_user(&db, "swlkr").await?.first_row()?; From 7dc415d51a7ee9dbefd28e5f4d4fe165f33fbe47 Mon Sep 17 00:00:00 2001 From: pm Date: Mon, 31 Mar 2025 10:37:19 +0200 Subject: [PATCH 12/14] readme updated for transaction --- README.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/README.md b/README.md index fb58c6d..021247f 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,25 @@ cargo add --git https://github.com/swlkr/static_sqlite ``` +# Example for Transactions + +Use the methods begin_transaction, commit_transaction and rollback_transaction to manage Sqlite transactions. + + +```rust + + // migration and sql-fn definition goes here + + let db = static_sqlite::open(":memory:").await?; + + migrate(&db).await?; + + db.begin_transaction()?; + insert_row(&db, "test1").await?.first_row()?; + insert_row(&db, "test2").await?.first_row()?; + db.commit_transaction()?; +``` + # Example for First If the name of your statement ends with "_first", the created fn return an Option with the first value instead of a Vec. From f66d93c05c984e67a0967f0ec7e776a33a62e416 Mon Sep 17 00:00:00 2001 From: pm Date: Wed, 2 Apr 2025 08:56:57 +0200 Subject: [PATCH 13/14] fix error: TooManyRowsInResult instead of RowNotFound for multiple rows. --- static_sqlite_core/src/ffi.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/static_sqlite_core/src/ffi.rs b/static_sqlite_core/src/ffi.rs index dd7ea34..1ef4ed3 100644 --- a/static_sqlite_core/src/ffi.rs +++ b/static_sqlite_core/src/ffi.rs @@ -186,7 +186,7 @@ impl Sqlite { Ok(rows) => Ok(match rows.len() { 0 => None, 1 => Some(rows.into_iter().nth(0).unwrap()), - _ => return Err(Error::RowNotFound), + _ => return Err(Error::TooManyRowsInResult), }), Err(e) => Err(e), } From 0bc54450c0a881debf51954ac10a0570bf50ddb9 Mon Sep 17 00:00:00 2001 From: pm Date: Wed, 2 Apr 2025 11:21:27 +0200 Subject: [PATCH 14/14] must allow snake-case here to allow the type-hints without warning, another reason to handle this on another level --- static_sqlite_macros/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/static_sqlite_macros/src/lib.rs b/static_sqlite_macros/src/lib.rs index 18ab829..2c1599d 100644 --- a/static_sqlite_macros/src/lib.rs +++ b/static_sqlite_macros/src/lib.rs @@ -496,17 +496,20 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result quote! { + #[allow(non_snake_case)] pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result> { let rows: Vec<#pascal_case> = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; Ok(rows) } }, FunctionType::QueryOption => quote! { + #[allow(non_snake_case)] pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result> { static_sqlite::query_first(db, #sql, vec![#(#params,)*]).await } }, FunctionType::Stream => quote! { + #[allow(non_snake_case)] pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result>> { static_sqlite::stream(db, #sql, vec![#(#params,)*]).await }