diff --git a/crates/lance-graph/src/datafusion_planner/analysis.rs b/crates/lance-graph/src/datafusion_planner/analysis.rs index 113072c..f078e77 100644 --- a/crates/lance-graph/src/datafusion_planner/analysis.rs +++ b/crates/lance-graph/src/datafusion_planner/analysis.rs @@ -38,13 +38,18 @@ pub struct RelationshipInstance { pub struct PlanningContext<'a> { pub analysis: &'a QueryAnalysis, pub(crate) relationship_instance_idx: HashMap, + pub parameters: &'a HashMap, } impl<'a> PlanningContext<'a> { - pub fn new(analysis: &'a QueryAnalysis) -> Self { + pub fn new( + analysis: &'a QueryAnalysis, + parameters: &'a HashMap, + ) -> Self { Self { analysis, relationship_instance_idx: HashMap::new(), + parameters, } } @@ -383,7 +388,8 @@ mod tests { required_datasets: HashSet::new(), }; - let mut ctx = PlanningContext::new(&analysis); + let empty_params = HashMap::new(); + let mut ctx = PlanningContext::new(&analysis, &empty_params); // First call should return first instance let inst1 = ctx.next_relationship_instance("KNOWS").unwrap(); diff --git a/crates/lance-graph/src/datafusion_planner/builder/aggregate_ops.rs b/crates/lance-graph/src/datafusion_planner/builder/aggregate_ops.rs index 6feb2c2..0f9cbec 100644 --- a/crates/lance-graph/src/datafusion_planner/builder/aggregate_ops.rs +++ b/crates/lance-graph/src/datafusion_planner/builder/aggregate_ops.rs @@ -3,6 +3,7 @@ //! Aggregation operations: Projection with aggregates and grouping +use crate::datafusion_planner::analysis::PlanningContext; use crate::datafusion_planner::DataFusionPlanner; use crate::error::Result; use crate::logical_plan::*; @@ -11,6 +12,7 @@ use datafusion::logical_expr::{col, LogicalPlan, LogicalPlanBuilder}; impl DataFusionPlanner { pub(crate) fn build_project_with_aggregates( &self, + ctx: &mut PlanningContext, input_plan: LogicalPlan, projections: &[ProjectionItem], ) -> Result { @@ -21,7 +23,7 @@ impl DataFusionPlanner { let mut agg_aliases = Vec::new(); for p in projections { - let expr = super::super::expression::to_df_value_expr(&p.expression); + let expr = super::super::expression::to_df_value_expr(&p.expression, ctx.parameters); if super::super::expression::contains_aggregate(&p.expression) { // Aggregate expressions get aliased @@ -44,7 +46,8 @@ impl DataFusionPlanner { for p in projections { if !super::super::expression::contains_aggregate(&p.expression) { // Re-create the expression and apply alias - let expr = super::super::expression::to_df_value_expr(&p.expression); + let expr = + super::super::expression::to_df_value_expr(&p.expression, ctx.parameters); let aliased = if let Some(alias) = &p.alias { expr.alias(alias) } else { diff --git a/crates/lance-graph/src/datafusion_planner/builder/basic_ops.rs b/crates/lance-graph/src/datafusion_planner/builder/basic_ops.rs index 8f1c893..10b3c87 100644 --- a/crates/lance-graph/src/datafusion_planner/builder/basic_ops.rs +++ b/crates/lance-graph/src/datafusion_planner/builder/basic_ops.rs @@ -17,7 +17,7 @@ impl DataFusionPlanner { predicate: &crate::ast::BooleanExpression, ) -> Result { let input_plan = self.build_operator(ctx, input)?; - let expr = super::super::expression::to_df_boolean_expr(predicate); + let expr = super::super::expression::to_df_boolean_expr(predicate, ctx.parameters); LogicalPlanBuilder::from(input_plan) .filter(expr) .map_err(|e| self.plan_error("Failed to build filter", e))? @@ -39,21 +39,23 @@ impl DataFusionPlanner { .any(|p| super::super::expression::contains_aggregate(&p.expression)); if has_aggregates { - self.build_project_with_aggregates(input_plan, projections) + self.build_project_with_aggregates(ctx, input_plan, projections) } else { - self.build_simple_project(input_plan, projections) + self.build_simple_project(ctx, input_plan, projections) } } pub(crate) fn build_simple_project( &self, + ctx: &mut PlanningContext, input_plan: LogicalPlan, projections: &[ProjectionItem], ) -> Result { let exprs: Vec = projections .iter() .map(|p| { - let expr = super::super::expression::to_df_value_expr(&p.expression); + let expr = + super::super::expression::to_df_value_expr(&p.expression, ctx.parameters); // Apply alias if provided, otherwise use Cypher dot notation // Normalize alias to lowercase for case-insensitive behavior if let Some(alias) = &p.alias { @@ -98,7 +100,8 @@ impl DataFusionPlanner { let sort_exprs: Vec = sort_items .iter() .map(|item| { - let expr = super::super::expression::to_df_value_expr(&item.expression); + let expr = + super::super::expression::to_df_value_expr(&item.expression, ctx.parameters); let asc = matches!(item.direction, crate::ast::SortDirection::Ascending); SortExpr { expr, @@ -160,7 +163,7 @@ impl DataFusionPlanner { }; // Convert expression to DataFusion Expr - let df_expr = super::super::expression::to_df_value_expr(expression); + let df_expr = super::super::expression::to_df_value_expr(expression, ctx.parameters); // We project the list expression first (aliased as the target alias temporarily) // DataFusion unnest takes a column name. diff --git a/crates/lance-graph/src/datafusion_planner/builder/expand_ops.rs b/crates/lance-graph/src/datafusion_planner/builder/expand_ops.rs index 85231ce..5522fe5 100644 --- a/crates/lance-graph/src/datafusion_planner/builder/expand_ops.rs +++ b/crates/lance-graph/src/datafusion_planner/builder/expand_ops.rs @@ -59,7 +59,7 @@ impl DataFusionPlanner { // Build relationship scan with qualified columns and property filters let rel_scan = - self.build_relationship_scan(&rel_instance, rel_source, relationship_properties)?; + self.build_relationship_scan(ctx, &rel_instance, rel_source, relationship_properties)?; // Join source node with relationship let source_params = SourceJoinParams { @@ -297,6 +297,7 @@ impl DataFusionPlanner { // Build target node scan and join let target_scan = self.build_qualified_target_scan( + ctx, catalog, &target_label, target_variable, diff --git a/crates/lance-graph/src/datafusion_planner/builder/join_builder.rs b/crates/lance-graph/src/datafusion_planner/builder/join_builder.rs index 019d771..caddf87 100644 --- a/crates/lance-graph/src/datafusion_planner/builder/join_builder.rs +++ b/crates/lance-graph/src/datafusion_planner/builder/join_builder.rs @@ -515,7 +515,8 @@ mod tests { // Analyze both patterns to build the context let left_analysis = analysis::analyze(&expand_left).unwrap(); - let left_ctx = analysis::PlanningContext::new(&left_analysis); + let empty_params = std::collections::HashMap::new(); + let left_ctx = analysis::PlanningContext::new(&left_analysis, &empty_params); // Test the key inference logic directly let (left_keys, right_keys) = diff --git a/crates/lance-graph/src/datafusion_planner/config_helpers.rs b/crates/lance-graph/src/datafusion_planner/config_helpers.rs index c8172b7..075c5c4 100644 --- a/crates/lance-graph/src/datafusion_planner/config_helpers.rs +++ b/crates/lance-graph/src/datafusion_planner/config_helpers.rs @@ -140,7 +140,8 @@ mod tests { analysis .var_to_label .insert("b".to_string(), "Person".to_string()); - let ctx = PlanningContext::new(&analysis); + let empty_params = std::collections::HashMap::new(); + let ctx = PlanningContext::new(&analysis, &empty_params); let (label, node_map) = planner .get_target_node_mapping(&ctx, "b") @@ -156,7 +157,8 @@ mod tests { analysis .var_to_label .insert("a".to_string(), "Person".to_string()); - let ctx = PlanningContext::new(&analysis); + let empty_params = std::collections::HashMap::new(); + let ctx = PlanningContext::new(&analysis, &empty_params); let (label, node_map) = planner .get_target_node_mapping(&ctx, "_temp_a_1") @@ -169,7 +171,8 @@ mod tests { fn test_get_target_node_mapping_invalid_temp_variable() { let planner = planner_with_basic_config(); let analysis = QueryAnalysis::default(); - let ctx = PlanningContext::new(&analysis); + let empty_params = std::collections::HashMap::new(); + let ctx = PlanningContext::new(&analysis, &empty_params); let err = planner .get_target_node_mapping(&ctx, "_temp_invalid") @@ -185,7 +188,8 @@ mod tests { analysis .var_to_label .insert("a".to_string(), "Person".to_string()); - let ctx = PlanningContext::new(&analysis); + let empty_params = std::collections::HashMap::new(); + let ctx = PlanningContext::new(&analysis, &empty_params); let err = planner.get_target_node_mapping(&ctx, "c").unwrap_err(); let msg = format!("{}", err); @@ -203,7 +207,8 @@ mod tests { analysis .var_to_label .insert("b".to_string(), "Organization".to_string()); - let ctx = PlanningContext::new(&analysis); + let empty_params = std::collections::HashMap::new(); + let ctx = PlanningContext::new(&analysis, &empty_params); let err = planner.get_target_node_mapping(&ctx, "b").unwrap_err(); let msg = format!("{}", err); diff --git a/crates/lance-graph/src/datafusion_planner/expression.rs b/crates/lance-graph/src/datafusion_planner/expression.rs index 2f4dd7e..c4a6b58 100644 --- a/crates/lance-graph/src/datafusion_planner/expression.rs +++ b/crates/lance-graph/src/datafusion_planner/expression.rs @@ -18,12 +18,38 @@ use datafusion_functions_aggregate::count::count_distinct; use datafusion_functions_aggregate::min_max::max; use datafusion_functions_aggregate::min_max::min; use datafusion_functions_aggregate::sum::sum; +use std::collections::HashMap; + +/// Helper function to convert serde_json::Value to DataFusion ScalarValue +fn json_to_scalar(value: &serde_json::Value) -> datafusion::scalar::ScalarValue { + use datafusion::scalar::ScalarValue; + match value { + serde_json::Value::Null => ScalarValue::Null, + serde_json::Value::Bool(b) => ScalarValue::Boolean(Some(*b)), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + ScalarValue::Int64(Some(i)) + } else if let Some(f) = n.as_f64() { + ScalarValue::Float64(Some(f)) + } else { + ScalarValue::Null + } + } + serde_json::Value::String(s) => ScalarValue::Utf8(Some(s.clone())), + serde_json::Value::Array(_) | serde_json::Value::Object(_) => ScalarValue::Null, // Complex types not supported yet + } +} /// Helper function to create LIKE expressions with consistent settings -fn create_like_expr(expression: &ValueExpression, pattern: &str, case_insensitive: bool) -> Expr { +fn create_like_expr( + expression: &ValueExpression, + pattern: &str, + case_insensitive: bool, + parameters: &HashMap, +) -> Expr { Expr::Like(datafusion::logical_expr::Like { negated: false, - expr: Box::new(to_df_value_expr(expression)), + expr: Box::new(to_df_value_expr(expression, parameters)), pattern: Box::new(lit(pattern.to_string())), escape_char: None, case_insensitive, @@ -31,7 +57,10 @@ fn create_like_expr(expression: &ValueExpression, pattern: &str, case_insensitiv } /// Convert BooleanExpression to DataFusion Expr -pub(crate) fn to_df_boolean_expr(expr: &BooleanExpression) -> Expr { +pub(crate) fn to_df_boolean_expr( + expr: &BooleanExpression, + parameters: &HashMap, +) -> Expr { use crate::ast::{BooleanExpression as BE, ComparisonOperator as CO}; match expr { BE::Comparison { @@ -39,8 +68,8 @@ pub(crate) fn to_df_boolean_expr(expr: &BooleanExpression) -> Expr { operator, right, } => { - let l = to_df_value_expr(left); - let r = to_df_value_expr(right); + let l = to_df_value_expr(left, parameters); + let r = to_df_value_expr(right, parameters); let op = match operator { CO::Equal => Operator::Eq, CO::NotEqual => Operator::NotEq, @@ -57,57 +86,66 @@ pub(crate) fn to_df_boolean_expr(expr: &BooleanExpression) -> Expr { } BE::In { expression, list } => { use datafusion::logical_expr::expr::InList as DFInList; - let expr = to_df_value_expr(expression); - let list_exprs = list.iter().map(to_df_value_expr).collect::>(); + let expr = to_df_value_expr(expression, parameters); + let list_exprs = list + .iter() + .map(|e| to_df_value_expr(e, parameters)) + .collect::>(); Expr::InList(DFInList::new(Box::new(expr), list_exprs, false)) } BE::And(l, r) => Expr::BinaryExpr(BinaryExpr { - left: Box::new(to_df_boolean_expr(l)), + left: Box::new(to_df_boolean_expr(l, parameters)), op: Operator::And, - right: Box::new(to_df_boolean_expr(r)), + right: Box::new(to_df_boolean_expr(r, parameters)), }), BE::Or(l, r) => Expr::BinaryExpr(BinaryExpr { - left: Box::new(to_df_boolean_expr(l)), + left: Box::new(to_df_boolean_expr(l, parameters)), op: Operator::Or, - right: Box::new(to_df_boolean_expr(r)), + right: Box::new(to_df_boolean_expr(r, parameters)), }), - BE::Not(inner) => Expr::Not(Box::new(to_df_boolean_expr(inner))), + BE::Not(inner) => Expr::Not(Box::new(to_df_boolean_expr(inner, parameters))), BE::Exists(prop) => Expr::IsNotNull(Box::new(to_df_value_expr( &ValueExpression::Property(prop.clone()), + parameters, ))), - BE::IsNull(expression) => Expr::IsNull(Box::new(to_df_value_expr(expression))), - BE::IsNotNull(expression) => Expr::IsNotNull(Box::new(to_df_value_expr(expression))), + BE::IsNull(expression) => Expr::IsNull(Box::new(to_df_value_expr(expression, parameters))), + BE::IsNotNull(expression) => { + Expr::IsNotNull(Box::new(to_df_value_expr(expression, parameters))) + } BE::Like { expression, pattern, - } => create_like_expr(expression, pattern, false), + } => create_like_expr(expression, pattern, false, parameters), BE::ILike { expression, pattern, - } => create_like_expr(expression, pattern, true), + } => create_like_expr(expression, pattern, true, parameters), BE::Contains { expression, substring, } => { // CONTAINS is equivalent to LIKE '%substring%' let pattern = format!("%{}%", substring); - create_like_expr(expression, &pattern, false) + create_like_expr(expression, &pattern, false, parameters) } BE::StartsWith { expression, prefix } => { // STARTS WITH is equivalent to LIKE 'prefix%' let pattern = format!("{}%", prefix); - create_like_expr(expression, &pattern, false) + create_like_expr(expression, &pattern, false, parameters) } BE::EndsWith { expression, suffix } => { // ENDS WITH is equivalent to LIKE '%suffix' let pattern = format!("%{}", suffix); - create_like_expr(expression, &pattern, false) + create_like_expr(expression, &pattern, false, parameters) } } } /// Convert ValueExpression to DataFusion Expr -pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { +pub(crate) fn to_df_value_expr( + expr: &ValueExpression, + parameters: &HashMap, +) -> Expr { use crate::ast::{PropertyValue as PV, ValueExpression as VE}; match expr { VE::Property(prop) => { @@ -122,7 +160,14 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { VE::Literal(PV::Null) => { datafusion::logical_expr::Expr::Literal(datafusion::scalar::ScalarValue::Null, None) } - VE::Literal(PV::Parameter(_)) => lit(0), + VE::Literal(PV::Parameter(name)) => { + // Handle parameter in literal (if that ever happens, though usually it's separate) + if let Some(value) = parameters.get(name) { + Expr::Literal(json_to_scalar(value), None) + } else { + col(format!("${}", name)) + } + } VE::Literal(PV::Property(prop)) => { // Create qualified column name: variable__property (lowercase for case-insensitivity) col(qualify_column(&prop.variable, &prop.property)) @@ -131,7 +176,7 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { match name.to_lowercase().as_str() { "tolower" | "lower" => { if args.len() == 1 { - let arg_expr = to_df_value_expr(&args[0]); + let arg_expr = to_df_value_expr(&args[0], parameters); lower().call(vec![arg_expr]) } else { // Invalid argument count - return NULL @@ -140,7 +185,7 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } "toupper" | "upper" => { if args.len() == 1 { - let arg_expr = to_df_value_expr(&args[0]); + let arg_expr = to_df_value_expr(&args[0], parameters); upper().call(vec![arg_expr]) } else { // Invalid argument count - return NULL @@ -182,7 +227,7 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } } else { // COUNT(p.property) - count non-null values of that property - to_df_value_expr(&args[0]) + to_df_value_expr(&args[0], parameters) }; // Use DataFusion's count or count_distinct @@ -198,7 +243,7 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } "sum" => { if args.len() == 1 { - let arg_expr = to_df_value_expr(&args[0]); + let arg_expr = to_df_value_expr(&args[0], parameters); sum(arg_expr) } else { lit(0) @@ -206,7 +251,7 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } "avg" => { if args.len() == 1 { - let arg_expr = to_df_value_expr(&args[0]); + let arg_expr = to_df_value_expr(&args[0], parameters); avg(arg_expr) } else { lit(0) @@ -214,7 +259,7 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } "min" => { if args.len() == 1 { - let arg_expr = to_df_value_expr(&args[0]); + let arg_expr = to_df_value_expr(&args[0], parameters); min(arg_expr) } else { lit(0) @@ -222,7 +267,7 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } "max" => { if args.len() == 1 { - let arg_expr = to_df_value_expr(&args[0]); + let arg_expr = to_df_value_expr(&args[0], parameters); max(arg_expr) } else { lit(0) @@ -230,7 +275,7 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } "collect" => { if args.len() == 1 { - let arg_expr = to_df_value_expr(&args[0]); + let arg_expr = to_df_value_expr(&args[0], parameters); array_agg(arg_expr) } else { lit(0) @@ -253,8 +298,8 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { right, } => { use crate::ast::ArithmeticOperator as AO; - let l = to_df_value_expr(left); - let r = to_df_value_expr(right); + let l = to_df_value_expr(left, parameters); + let r = to_df_value_expr(right, parameters); let op = match operator { AO::Add => Operator::Plus, AO::Subtract => Operator::Minus, @@ -275,8 +320,8 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } => { // Create UDF for vector distance computation let udf = udf::create_vector_distance_udf(metric); - let left_expr = to_df_value_expr(left); - let right_expr = to_df_value_expr(right); + let left_expr = to_df_value_expr(left, parameters); + let right_expr = to_df_value_expr(right, parameters); Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf( udf, vec![left_expr, right_expr], @@ -289,8 +334,8 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } => { // Create UDF for vector similarity computation let udf = udf::create_vector_similarity_udf(metric); - let left_expr = to_df_value_expr(left); - let right_expr = to_df_value_expr(right); + let left_expr = to_df_value_expr(left, parameters); + let right_expr = to_df_value_expr(right, parameters); Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf( udf, vec![left_expr, right_expr], @@ -316,18 +361,11 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { lit(scalar) } VE::Parameter(name) => { - // TODO: Implement proper parameter resolution - // Parameters ($param) should be resolved to literal values from the query's - // parameter map (CypherQuery::parameters()) before or during planning. - // - // Current limitation: This creates a column reference as a placeholder, - // which will fail at execution if the column doesn't exist. - // - // Proper fix requires one of: - // 1. Resolve parameters during semantic analysis (substitute before planning) - // 2. Pass parameter map to to_df_value_expr and resolve here - // 3. Use DataFusion's parameter binding mechanism - col(format!("${}", name)) + if let Some(value) = parameters.get(name) { + Expr::Literal(json_to_scalar(value), None) + } else { + col(format!("${}", name)) + } } } } @@ -420,7 +458,7 @@ mod tests { right: ValueExpression::Literal(PropertyValue::Integer(30)), }; - let df_expr = to_df_boolean_expr(&expr); + let df_expr = to_df_boolean_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!(s.contains("p__age"), "Should contain qualified column"); assert!( @@ -451,7 +489,7 @@ mod tests { right: ValueExpression::Literal(PropertyValue::Integer(30)), }; - let df_expr = to_df_boolean_expr(&expr); + let df_expr = to_df_boolean_expr(&expr, &std::collections::HashMap::new()); // Should successfully translate without panicking assert!(format!("{:?}", df_expr).contains("p__age")); } @@ -479,7 +517,7 @@ mod tests { }), ); - let df_expr = to_df_boolean_expr(&expr); + let df_expr = to_df_boolean_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!(s.contains("And"), "Should contain AND operator"); assert!(s.contains("p__age"), "Should contain column reference"); @@ -507,7 +545,7 @@ mod tests { }), ); - let df_expr = to_df_boolean_expr(&expr); + let df_expr = to_df_boolean_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!(s.contains("Or"), "Should contain OR operator"); } @@ -524,7 +562,7 @@ mod tests { right: ValueExpression::Literal(PropertyValue::Boolean(true)), })); - let df_expr = to_df_boolean_expr(&expr); + let df_expr = to_df_boolean_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!(s.contains("Not"), "Should contain NOT operator"); } @@ -536,7 +574,7 @@ mod tests { property: "email".into(), }); - let df_expr = to_df_boolean_expr(&expr); + let df_expr = to_df_boolean_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!( s.contains("IsNotNull") || s.contains("p__email"), @@ -557,7 +595,8 @@ mod tests { ], }; - if let Expr::InList(in_list) = to_df_boolean_expr(&expr) { + if let Expr::InList(in_list) = to_df_boolean_expr(&expr, &std::collections::HashMap::new()) + { assert!(!in_list.negated); assert_eq!(in_list.list.len(), 2); match *in_list.expr { @@ -581,7 +620,8 @@ mod tests { pattern: "A%".into(), }; - if let Expr::Like(like_expr) = to_df_boolean_expr(&expr) { + if let Expr::Like(like_expr) = to_df_boolean_expr(&expr, &std::collections::HashMap::new()) + { assert!(!like_expr.negated, "Should not be negated"); assert!(!like_expr.case_insensitive, "Should be case sensitive"); assert_eq!(like_expr.escape_char, None, "Should have no escape char"); @@ -611,7 +651,8 @@ mod tests { pattern: "alice%".into(), }; - if let Expr::Like(like_expr) = to_df_boolean_expr(&expr) { + if let Expr::Like(like_expr) = to_df_boolean_expr(&expr, &std::collections::HashMap::new()) + { assert!(!like_expr.negated, "Should not be negated"); assert!( like_expr.case_insensitive, @@ -652,7 +693,8 @@ mod tests { pattern: "Test%".into(), }; - if let Expr::Like(like) = to_df_boolean_expr(&like_expr) { + if let Expr::Like(like) = to_df_boolean_expr(&like_expr, &std::collections::HashMap::new()) + { assert!( !like.case_insensitive, "LIKE should be case-sensitive (case_insensitive = false)" @@ -670,7 +712,9 @@ mod tests { pattern: "Test%".into(), }; - if let Expr::Like(ilike) = to_df_boolean_expr(&ilike_expr) { + if let Expr::Like(ilike) = + to_df_boolean_expr(&ilike_expr, &std::collections::HashMap::new()) + { assert!( ilike.case_insensitive, "ILIKE should be case-insensitive (case_insensitive = true)" @@ -690,7 +734,7 @@ mod tests { pattern: "%@example.com".into(), }; - let df_expr = to_df_boolean_expr(&expr); + let df_expr = to_df_boolean_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!( s.contains("Like") || s.contains("like"), @@ -709,7 +753,8 @@ mod tests { substring: "ali".into(), }; - if let Expr::Like(like_expr) = to_df_boolean_expr(&expr) { + if let Expr::Like(like_expr) = to_df_boolean_expr(&expr, &std::collections::HashMap::new()) + { assert!(!like_expr.negated, "Should not be negated"); assert!(!like_expr.case_insensitive, "Should be case sensitive"); assert_eq!(like_expr.escape_char, None, "Should have no escape char"); @@ -745,7 +790,8 @@ mod tests { prefix: "admin".into(), }; - if let Expr::Like(like_expr) = to_df_boolean_expr(&expr) { + if let Expr::Like(like_expr) = to_df_boolean_expr(&expr, &std::collections::HashMap::new()) + { assert!(!like_expr.negated, "Should not be negated"); assert!(!like_expr.case_insensitive, "Should be case sensitive"); @@ -784,7 +830,8 @@ mod tests { suffix: "@example.com".into(), }; - if let Expr::Like(like_expr) = to_df_boolean_expr(&expr) { + if let Expr::Like(like_expr) = to_df_boolean_expr(&expr, &std::collections::HashMap::new()) + { assert!(!like_expr.negated, "Should not be negated"); assert!(!like_expr.case_insensitive, "Should be case sensitive"); @@ -824,7 +871,8 @@ mod tests { substring: "Test".into(), }; - if let Expr::Like(like_expr) = to_df_boolean_expr(&expr) { + if let Expr::Like(like_expr) = to_df_boolean_expr(&expr, &std::collections::HashMap::new()) + { assert!( !like_expr.case_insensitive, "CONTAINS should be case-sensitive by default" @@ -842,7 +890,8 @@ mod tests { substring: "test".into(), }; - if let Expr::Like(like_expr) = to_df_boolean_expr(&expr) { + if let Expr::Like(like_expr) = to_df_boolean_expr(&expr, &std::collections::HashMap::new()) + { match *like_expr.expr { Expr::Column(ref col_expr) => { assert_eq!( @@ -869,7 +918,7 @@ mod tests { property: "name".into(), }); - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert_eq!( s, @@ -880,7 +929,7 @@ mod tests { #[test] fn test_value_expr_literal_integer() { let expr = ValueExpression::Literal(PropertyValue::Integer(42)); - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!(s.contains("42") || s.contains("Int64(42)")); } @@ -888,7 +937,7 @@ mod tests { #[test] fn test_value_expr_literal_float() { let expr = ValueExpression::Literal(PropertyValue::Float(std::f64::consts::PI)); - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!(s.contains("3.14") || s.contains("Float64")); } @@ -896,7 +945,7 @@ mod tests { #[test] fn test_value_expr_literal_string() { let expr = ValueExpression::Literal(PropertyValue::String("hello".into())); - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!(s.contains("hello") || s.contains("Utf8")); } @@ -904,7 +953,7 @@ mod tests { #[test] fn test_value_expr_literal_boolean() { let expr = ValueExpression::Literal(PropertyValue::Boolean(true)); - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!(s.contains("true") || s.contains("Boolean")); } @@ -912,7 +961,7 @@ mod tests { #[test] fn test_value_expr_literal_null() { let expr = ValueExpression::Literal(PropertyValue::Null); - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); // Null literals are translated to Literal with Null value assert!(s.contains("Literal"), "Should be a Literal expression"); @@ -930,7 +979,7 @@ mod tests { right: Box::new(ValueExpression::Literal(PropertyValue::Integer(5))), }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); // Arithmetic expressions should now return a BinaryExpr with Plus operator assert!(s.contains("BinaryExpr"), "Should be a BinaryExpr"); @@ -959,7 +1008,7 @@ mod tests { right: Box::new(ValueExpression::Literal(PropertyValue::Integer(2))), }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); // Should translate to BinaryExpr with the correct operator assert!( @@ -983,7 +1032,7 @@ mod tests { distinct: false, }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!( s.contains("count") || s.contains("Count"), @@ -1002,7 +1051,7 @@ mod tests { distinct: false, }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!( s.contains("count") || s.contains("Count"), @@ -1022,7 +1071,7 @@ mod tests { distinct: false, }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!( s.contains("sum") || s.contains("Sum"), @@ -1042,7 +1091,7 @@ mod tests { distinct: false, }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!( s.contains("avg") || s.contains("Avg"), @@ -1062,7 +1111,7 @@ mod tests { distinct: false, }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!( s.contains("min") || s.contains("Min"), @@ -1082,7 +1131,7 @@ mod tests { distinct: false, }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!( s.contains("max") || s.contains("Max"), @@ -1101,7 +1150,7 @@ mod tests { })], }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); // Should be a ScalarFunction with lower assert!( @@ -1122,7 +1171,7 @@ mod tests { })], }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); // Should be a ScalarFunction with upper assert!( @@ -1144,7 +1193,7 @@ mod tests { })], }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!( s.contains("lower") || s.contains("Lower"), @@ -1164,7 +1213,7 @@ mod tests { })], }; - let df_expr = to_df_value_expr(&expr); + let df_expr = to_df_value_expr(&expr, &std::collections::HashMap::new()); let s = format!("{:?}", df_expr); assert!( s.contains("upper") || s.contains("Upper"), @@ -1190,7 +1239,7 @@ mod tests { substring: "offer".into(), }; - let df_expr = to_df_boolean_expr(&contains_expr); + let df_expr = to_df_boolean_expr(&contains_expr, &HashMap::new()); let s = format!("{:?}", df_expr); // Should be a Like expression with lower() on the column, not lit(0) diff --git a/crates/lance-graph/src/datafusion_planner/join_ops.rs b/crates/lance-graph/src/datafusion_planner/join_ops.rs index a1e83ba..c2dc375 100644 --- a/crates/lance-graph/src/datafusion_planner/join_ops.rs +++ b/crates/lance-graph/src/datafusion_planner/join_ops.rs @@ -106,6 +106,7 @@ impl DataFusionPlanner { for (k, v) in params.target_properties.iter() { let lit_expr = super::expression::to_df_value_expr( &crate::ast::ValueExpression::Literal(v.clone()), + ctx.parameters, ); let filter_expr = Expr::BinaryExpr(BinaryExpr { left: Box::new(col(k.to_lowercase())), diff --git a/crates/lance-graph/src/datafusion_planner/mod.rs b/crates/lance-graph/src/datafusion_planner/mod.rs index a5387f9..28af511 100644 --- a/crates/lance-graph/src/datafusion_planner/mod.rs +++ b/crates/lance-graph/src/datafusion_planner/mod.rs @@ -34,6 +34,7 @@ use crate::error::Result; use crate::logical_plan::LogicalOperator; use crate::source_catalog::GraphSourceCatalog; use datafusion::logical_expr::LogicalPlan; +use std::collections::HashMap; use std::sync::Arc; /// Planner abstraction for graph-to-physical planning @@ -45,6 +46,7 @@ pub trait GraphPhysicalPlanner { pub struct DataFusionPlanner { pub(crate) config: GraphConfig, pub(crate) catalog: Option>, + pub(crate) parameters: HashMap, } impl DataFusionPlanner { @@ -52,6 +54,7 @@ impl DataFusionPlanner { Self { config, catalog: None, + parameters: HashMap::new(), } } @@ -59,9 +62,15 @@ impl DataFusionPlanner { Self { config, catalog: Some(catalog), + parameters: HashMap::new(), } } + pub fn with_parameters(mut self, params: HashMap) -> Self { + self.parameters = params; + self + } + /// Helper to convert DataFusion builder errors into GraphError::PlanError with context pub(crate) fn plan_error( &self, @@ -81,7 +90,7 @@ impl GraphPhysicalPlanner for DataFusionPlanner { let analysis = analysis::analyze(logical_plan)?; // Phase 2: Build execution plan with context - let mut ctx = PlanningContext::new(&analysis); + let mut ctx = PlanningContext::new(&analysis, &self.parameters); self.build_operator(&mut ctx, logical_plan) } } diff --git a/crates/lance-graph/src/datafusion_planner/scan_ops.rs b/crates/lance-graph/src/datafusion_planner/scan_ops.rs index 912f07d..7bead6d 100644 --- a/crates/lance-graph/src/datafusion_planner/scan_ops.rs +++ b/crates/lance-graph/src/datafusion_planner/scan_ops.rs @@ -21,7 +21,7 @@ impl DataFusionPlanner { /// Build a qualified node scan with property filters and column aliasing pub(crate) fn build_scan( &self, - _ctx: &PlanningContext, + ctx: &PlanningContext, variable: &str, label: &str, properties: &HashMap, @@ -46,6 +46,7 @@ impl DataFusionPlanner { .map(|(k, v)| { let lit_expr = super::expression::to_df_value_expr( &crate::ast::ValueExpression::Literal(v.clone()), + ctx.parameters, ); Expr::BinaryExpr(BinaryExpr { left: Box::new(col(k)), @@ -122,6 +123,7 @@ impl DataFusionPlanner { /// Build a qualified relationship scan with property filters pub(crate) fn build_relationship_scan( &self, + ctx: &PlanningContext, rel_instance: &RelationshipInstance, rel_source: Arc, relationship_properties: &HashMap, @@ -140,6 +142,7 @@ impl DataFusionPlanner { for (k, v) in relationship_properties.iter() { let lit_expr = super::expression::to_df_value_expr( &crate::ast::ValueExpression::Literal(v.clone()), + ctx.parameters, ); let filter_expr = Expr::BinaryExpr(BinaryExpr { left: Box::new(col(k)), @@ -259,6 +262,7 @@ impl DataFusionPlanner { /// Build a qualified target node scan with property filters pub(crate) fn build_qualified_target_scan( &self, + ctx: &PlanningContext, catalog: &Arc, target_label: &str, target_variable: &str, @@ -285,6 +289,7 @@ impl DataFusionPlanner { for (k, v) in target_properties.iter() { let lit_expr = super::expression::to_df_value_expr( &crate::ast::ValueExpression::Literal(v.clone()), + ctx.parameters, ); let filter_expr = Expr::BinaryExpr(BinaryExpr { left: Box::new(col(k)), diff --git a/crates/lance-graph/src/parser.rs b/crates/lance-graph/src/parser.rs index 8887ce2..d0e3a93 100644 --- a/crates/lance-graph/src/parser.rs +++ b/crates/lance-graph/src/parser.rs @@ -589,9 +589,8 @@ fn parse_vector_similarity(input: &str) -> IResult<&str, ValueExpression> { // Parse parameter reference: $name fn parse_parameter(input: &str) -> IResult<&str, ValueExpression> { - let (input, _) = char('$')(input)?; - let (input, name) = identifier(input)?; - Ok((input, ValueExpression::Parameter(name.to_string()))) + let (input, name) = parameter(input)?; + Ok((input, ValueExpression::Parameter(name))) } // Parse a function call: function_name(args) @@ -973,9 +972,18 @@ fn boolean_literal(input: &str) -> IResult<&str, bool> { // Parse a parameter reference fn parameter(input: &str) -> IResult<&str, String> { - let (input, _) = char('$')(input)?; - let (input, name) = identifier(input)?; - Ok((input, name.to_string())) + alt(( + // $param + map(preceded(char('$'), identifier), |s| s.to_string()), + // @param + map(preceded(char('@'), identifier), |s| s.to_string()), + // :param + map(preceded(char(':'), identifier), |s| s.to_string()), + // {param} + map(delimited(char('{'), identifier, char('}')), |s| { + s.to_string() + }), + ))(input) } // Parse comma with optional whitespace @@ -1699,6 +1707,92 @@ mod tests { } } + #[test] + fn test_parse_multiple_parameters() { + let query = "MATCH (p:Person) WHERE p.age > $min_age AND p.age < $max_age RETURN p"; + let result = parse_cypher_query(query); + assert!( + result.is_ok(), + "Multiple parameters should parse successfully" + ); + + let ast = result.unwrap(); + let where_clause = ast.where_clause.expect("Expected WHERE clause"); + + match where_clause.expression { + BooleanExpression::And(left, right) => { + // Check left: p.age > $min_age + match *left { + BooleanExpression::Comparison { + right: val_right, .. + } => match val_right { + ValueExpression::Parameter(name) => { + assert_eq!(name, "min_age"); + } + _ => panic!("Expected Parameter min_age"), + }, + _ => panic!("Expected comparison on left"), + } + + // Check right: p.age < $max_age + match *right { + BooleanExpression::Comparison { + right: val_right, .. + } => match val_right { + ValueExpression::Parameter(name) => { + assert_eq!(name, "max_age"); + } + _ => panic!("Expected Parameter max_age"), + }, + _ => panic!("Expected comparison on right"), + } + } + _ => panic!("Expected AND expression"), + } + } + + #[test] + fn test_parse_parameter_formats() { + // Test @param + let query = "MATCH (p:Person) WHERE p.age > @min_age RETURN p"; + let result = parse_cypher_query(query); + assert!(result.is_ok(), "@param should parse successfully"); + let where_clause = result.unwrap().where_clause.expect("Expected WHERE clause"); + match where_clause.expression { + BooleanExpression::Comparison { right, .. } => match right { + ValueExpression::Parameter(name) => assert_eq!(name, "min_age"), + _ => panic!("Expected Parameter for @param"), + }, + _ => panic!("Expected comparison"), + } + + // Test :param + let query = "MATCH (p:Person) WHERE p.age > :min_age RETURN p"; + let result = parse_cypher_query(query); + assert!(result.is_ok(), ":param should parse successfully"); + let where_clause = result.unwrap().where_clause.expect("Expected WHERE clause"); + match where_clause.expression { + BooleanExpression::Comparison { right, .. } => match right { + ValueExpression::Parameter(name) => assert_eq!(name, "min_age"), + _ => panic!("Expected Parameter for :param"), + }, + _ => panic!("Expected comparison"), + } + + // Test {param} + let query = "MATCH (p:Person) WHERE p.age > {min_age} RETURN p"; + let result = parse_cypher_query(query); + assert!(result.is_ok(), "{{param}} should parse successfully"); + let where_clause = result.unwrap().where_clause.expect("Expected WHERE clause"); + match where_clause.expression { + BooleanExpression::Comparison { right, .. } => match right { + ValueExpression::Parameter(name) => assert_eq!(name, "min_age"), + _ => panic!("Expected Parameter for {{param}}"), + }, + _ => panic!("Expected comparison"), + } + } + #[test] fn test_vector_distance_metrics() { for metric in &["cosine", "l2", "dot"] { diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs index 8f5f017..7efb945 100644 --- a/crates/lance-graph/src/query.rs +++ b/crates/lance-graph/src/query.rs @@ -782,7 +782,8 @@ impl CypherQuery { let logical_plan = logical_planner.plan(&self.ast)?; // Phase 3: DataFusion Logical Plan - let df_planner = DataFusionPlanner::with_catalog(config.clone(), catalog); + let df_planner = DataFusionPlanner::with_catalog(config.clone(), catalog) + .with_parameters(self.parameters.clone()); let df_logical_plan = df_planner.plan(&logical_plan)?; Ok((logical_plan, df_logical_plan)) @@ -1489,12 +1490,16 @@ mod tests { fn test_query_with_parameters() { let mut params = HashMap::new(); params.insert("minAge".to_string(), serde_json::Value::Number(30.into())); + params.insert("maxAge".to_string(), serde_json::Value::Number(50.into())); - let query = CypherQuery::new("MATCH (n:Person) WHERE n.age > $minAge RETURN n.name") - .unwrap() - .with_parameters(params); + let query = CypherQuery::new( + "MATCH (n:Person) WHERE n.age > $minAge AND n.age < $maxAge RETURN n.name", + ) + .unwrap() + .with_parameters(params); assert!(query.parameters().contains_key("minAge")); + assert!(query.parameters().contains_key("maxAge")); } #[test] diff --git a/crates/lance-graph/tests/test_datafusion_pipeline.rs b/crates/lance-graph/tests/test_datafusion_pipeline.rs index 243db94..dde77b4 100644 --- a/crates/lance-graph/tests/test_datafusion_pipeline.rs +++ b/crates/lance-graph/tests/test_datafusion_pipeline.rs @@ -4951,3 +4951,52 @@ async fn test_unwind_then_match() { assert_eq!(rows, expected); } + +#[tokio::test] +async fn test_datafusion_parameter_filtering_age() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + + let mut params = HashMap::new(); + // Filter for people older than 30 (Bob:35, David:40) + params.insert("min_age".to_string(), serde_json::json!(30)); + + let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > $min_age RETURN p.name, p.age") + .unwrap() + .with_config(config) + .with_parameters(params); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // Should return 2 people (Bob:35, David:40) + assert_eq!(result.num_rows(), 2); + assert_eq!(result.num_columns(), 2); + + let names = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let ages = result + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + let mut results = Vec::new(); + for i in 0..result.num_rows() { + results.push((names.value(i).to_string(), ages.value(i))); + } + + results.sort(); + assert_eq!( + results, + vec![("Bob".to_string(), 35), ("David".to_string(), 40)] + ); +} diff --git a/python/python/tests/test_cypher_engine.py b/python/python/tests/test_cypher_engine.py index 84372f6..89fcc72 100644 --- a/python/python/tests/test_cypher_engine.py +++ b/python/python/tests/test_cypher_engine.py @@ -166,3 +166,54 @@ def test_cypher_engine_config_access(graph_env): assert "person" in engine_config.node_labels() # case-insensitive assert "company" in engine_config.node_labels() + + +def test_cypher_parameter_syntax(graph_env): + """Test various Cypher parameter syntaxes ($ @ : {}).""" + config, datasets = graph_env + + # 1. Test $param + query_dollar = CypherQuery( + "MATCH (p:Person) WHERE p.age > $age RETURN p.name" + ).with_config(config) + result = query_dollar.with_parameter("age", 30).execute(datasets) + data = result.to_pydict() + assert set(data["p.name"]) == {"Bob", "David"} + + # 2. Test @param + query_at = CypherQuery( + "MATCH (p:Person) WHERE p.age > @age RETURN p.name" + ).with_config(config) + result = query_at.with_parameter("age", 30).execute(datasets) + data = result.to_pydict() + assert set(data["p.name"]) == {"Bob", "David"} + + # 3. Test :param + query_colon = CypherQuery( + "MATCH (p:Person) WHERE p.age > :age RETURN p.name" + ).with_config(config) + result = query_colon.with_parameter("age", 30).execute(datasets) + data = result.to_pydict() + assert set(data["p.name"]) == {"Bob", "David"} + + # 4. Test {param} + query_curly = CypherQuery( + "MATCH (p:Person) WHERE p.age > {age} RETURN p.name" + ).with_config(config) + result = query_curly.with_parameter("age", 30).execute(datasets) + data = result.to_pydict() + assert set(data["p.name"]) == {"Bob", "David"} + + # 5. Test multiple parameters + query_multi = CypherQuery( + "MATCH (p:Person) WHERE p.age > $min_age AND p.age < $max_age RETURN p.name" + ).with_config(config) + result = ( + query_multi.with_parameter("min_age", 25) + .with_parameter("max_age", 35) + .execute(datasets) + ) + data = result.to_pydict() + # Should get Alice (28), Carol (29), Bob (34) + # David is 42 (excluded) + assert set(data["p.name"]) == {"Alice", "Carol", "Bob"}