Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions dev/diffs/4.0.1.diff
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,21 @@ index aa3d02dc2fb..c4f878d9908 100644
-- Test cases with unicode_rtrim.
WITH t(c1) AS (SELECT replace(listagg(DISTINCT col1 COLLATE unicode_rtrim) COLLATE utf8_binary, ' ', '') FROM (VALUES ('xbc '), ('xbc '), ('a'), ('xbc'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'xbc') FROM t;
WITH t(c1) AS (SELECT listagg(col1) WITHIN GROUP (ORDER BY col1 COLLATE unicode_rtrim) FROM (VALUES ('abc '), ('abc\n'), ('abc'), ('x'))) SELECT replace(replace(c1, ' ', ''), '\n', '$') FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
index 0000000..0000000 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
@@ -6,6 +6,10 @@
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/aggregates.sql#L352-L605

-- Test aggregate operator with codegen on and off.
+
+-- Floating-point precision difference between DataFusion and JVM for FILTER aggregates
+--SET spark.comet.enabled = false
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The floating-point precision difference shows up as the following test failure:

[info] - postgreSQL/aggregates_part3.sql *** FAILED *** (381 milliseconds)
[info]   postgreSQL/aggregates_part3.sql
[info]   Expected "2828.9682539682[954]", but got "2828.9682539682[517]" Result did not match for query #2
[info]   select sum(1/ten) filter (where ten > 0) from tenk1 (SQLQueryTestSuite.scala:683)

+
--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
index 3a409eea348..26e9aaf215c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
Expand Down
18 changes: 16 additions & 2 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -973,14 +973,28 @@ impl PhysicalPlanner {
.map(|expr| self.create_agg_expr(expr, Arc::clone(&schema)))
.collect();

let num_agg = agg.agg_exprs.len();
let aggr_expr = agg_exprs?.into_iter().map(Arc::new).collect();

// Build per-aggregate filter expressions from the FILTER (WHERE ...) clause.
// Filters are only present in Partial mode; Final/PartialMerge always get None.
let filter_exprs: Result<Vec<Option<Arc<dyn PhysicalExpr>>>, ExecutionError> = agg
.agg_exprs
.iter()
.map(|expr| {
if let Some(f) = expr.filter.as_ref() {
self.create_expr(f, Arc::clone(&schema)).map(Some)
} else {
Ok(None)
}
})
.collect();

let aggregate: Arc<dyn ExecutionPlan> = Arc::new(
datafusion::physical_plan::aggregates::AggregateExec::try_new(
mode,
group_by,
aggr_expr,
vec![None; num_agg], // no filter expressions
filter_exprs?,
Arc::clone(&child.native_plan),
Arc::clone(&schema),
)?,
Expand Down
4 changes: 4 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ message AggExpr {
BloomFilterAgg bloomFilterAgg = 16;
}

// Optional filter expression for SQL FILTER (WHERE ...) clause.
// Only set in Partial aggregation mode; absent in Final/PartialMerge.
optional Expr filter = 89;

// Optional QueryContext for error reporting (contains SQL text and position)
optional QueryContext query_context = 90;

Expand Down
9 changes: 7 additions & 2 deletions native/spark-expr/src/agg_funcs/avg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ where
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
_opt_filter: Option<&arrow::array::BooleanArray>,
opt_filter: Option<&arrow::array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
Expand All @@ -257,7 +257,7 @@ where
self.sums.resize(total_num_groups, T::default_value());

let iter = group_indices.iter().zip(data.iter());
if values.null_count() == 0 {
if opt_filter.is_none() && values.null_count() == 0 {
for (&group_index, &value) in iter {
let sum = &mut self.sums[group_index];
// No overflow checking - Infinity is a valid result
Expand All @@ -266,6 +266,11 @@ where
}
} else {
for (idx, (&group_index, &value)) in iter.enumerate() {
if let Some(f) = opt_filter {
if !f.is_valid(idx) || !f.value(idx) {
continue;
}
}
if values.is_null(idx) {
continue;
}
Expand Down
9 changes: 7 additions & 2 deletions native/spark-expr/src/agg_funcs/avg_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ impl GroupsAccumulator for AvgDecimalGroupsAccumulator {
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
_opt_filter: Option<&arrow::array::BooleanArray>,
opt_filter: Option<&arrow::array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
Expand All @@ -517,12 +517,17 @@ impl GroupsAccumulator for AvgDecimalGroupsAccumulator {
ensure_bit_capacity(&mut self.is_not_null, total_num_groups);

let iter = group_indices.iter().zip(data.iter());
if values.null_count() == 0 {
if opt_filter.is_none() && values.null_count() == 0 {
for (&group_index, &value) in iter {
self.update_single(group_index, value)?;
}
} else {
for (idx, (&group_index, &value)) in iter.enumerate() {
if let Some(f) = opt_filter {
if !f.is_valid(idx) || !f.value(idx) {
continue;
}
}
if values.is_null(idx) {
continue;
}
Expand Down
65 changes: 62 additions & 3 deletions native/spark-expr/src/agg_funcs/sum_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,20 +446,24 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> DFResult<()> {
assert!(opt_filter.is_none(), "opt_filter is not supported yet");
assert_eq!(values.len(), 1);
let values = values[0].as_primitive::<Decimal128Type>();
let data = values.values();

self.resize_helper(total_num_groups);

let iter = group_indices.iter().zip(data.iter());
if values.null_count() == 0 {
if opt_filter.is_none() && values.null_count() == 0 {
for (&group_index, &value) in iter {
self.update_single(group_index, value)?;
}
} else {
for (idx, (&group_index, &value)) in iter.enumerate() {
if let Some(f) = opt_filter {
if !f.is_valid(idx) || !f.value(idx) {
continue;
}
}
if values.is_null(idx) {
continue;
}
Expand Down Expand Up @@ -540,7 +544,10 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> DFResult<()> {
assert!(opt_filter.is_none(), "opt_filter is not supported yet");
debug_assert!(
opt_filter.is_none(),
"opt_filter is not supported in merge_batch"
);

self.resize_helper(total_num_groups);

Expand Down Expand Up @@ -712,4 +719,56 @@ mod tests {
let schema = Schema::new(fields);
RecordBatch::try_new(Arc::new(schema), columns).unwrap()
}

#[test]
fn test_update_batch_with_filter() {
    // Rows whose filter value is `false` must be left out of the group's sum.
    use arrow::array::Decimal128Array;
    use datafusion::logical_expr::{EmitTo, GroupsAccumulator};

    let decimal_type = DataType::Decimal128(10, 2);
    let mut accumulator = SumDecimalGroupsAccumulator::new(
        decimal_type.clone(),
        10,
        EvalMode::Legacy,
        None,
        crate::create_query_context_map(),
    );

    // Four rows, all routed to group 0; only rows 0 and 2 pass the filter.
    let raw_values = vec![100i128, 200, 300, 400];
    let keep_mask = BooleanArray::from(vec![true, false, true, false]);
    let group_of_row: [usize; 4] = [0; 4];
    let input: ArrayRef =
        Arc::new(Decimal128Array::from(raw_values).with_data_type(decimal_type.clone()));

    accumulator
        .update_batch(&[input], &group_of_row, Some(&keep_mask), 1)
        .unwrap();

    let emitted = accumulator.evaluate(EmitTo::All).unwrap();
    let sums = emitted.as_any().downcast_ref::<Decimal128Array>().unwrap();
    // 100 + 300: the two rows whose filter value is true.
    assert_eq!(sums.value(0), 400);
}

#[test]
fn test_update_batch_filter_null_treated_as_exclude() {
    // A null filter entry must exclude its row, exactly like `false` would.
    use arrow::array::Decimal128Array;
    use datafusion::logical_expr::{EmitTo, GroupsAccumulator};

    let decimal_type = DataType::Decimal128(10, 2);
    let mut accumulator = SumDecimalGroupsAccumulator::new(
        decimal_type.clone(),
        10,
        EvalMode::Legacy,
        None,
        crate::create_query_context_map(),
    );

    // Three rows in group 0; the middle row's filter value is null.
    let raw_values = vec![10i128, 20, 30];
    let keep_mask = BooleanArray::from(vec![Some(true), None, Some(true)]);
    let group_of_row: [usize; 3] = [0; 3];
    let input: ArrayRef =
        Arc::new(Decimal128Array::from(raw_values).with_data_type(decimal_type.clone()));

    accumulator
        .update_batch(&[input], &group_of_row, Some(&keep_mask), 1)
        .unwrap();

    let emitted = accumulator.evaluate(EmitTo::All).unwrap();
    let sums = emitted.as_any().downcast_ref::<Decimal128Array>().unwrap();
    assert_eq!(sums.value(0), 40); // 10 + 30 = 40
}
}
Loading
Loading