diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto index b0618b971..9a5e50408 100644 --- a/native-engine/auron-planner/proto/auron.proto +++ b/native-engine/auron-planner/proto/auron.proto @@ -129,6 +129,7 @@ enum WindowFunction { ROW_NUMBER = 0; RANK = 1; DENSE_RANK = 2; + LEAD = 3; } enum AggFunction { diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs index 84a625734..7e0b838df 100644 --- a/native-engine/auron-planner/src/planner.rs +++ b/native-engine/auron-planner/src/planner.rs @@ -636,6 +636,7 @@ impl PhysicalPlanner { protobuf::WindowFunction::DenseRank => { WindowFunction::RankLike(WindowRankType::DenseRank) } + protobuf::WindowFunction::Lead => WindowFunction::Lead, }, protobuf::WindowFunctionType::Agg => match w.agg_func() { protobuf::AggFunction::Min => WindowFunction::Agg(AggFunction::Min), diff --git a/native-engine/datafusion-ext-plans/src/window/mod.rs b/native-engine/datafusion-ext-plans/src/window/mod.rs index a9e9da29d..ead885bb9 100644 --- a/native-engine/datafusion-ext-plans/src/window/mod.rs +++ b/native-engine/datafusion-ext-plans/src/window/mod.rs @@ -23,8 +23,8 @@ use crate::{ agg::{AggFunction, agg::create_agg}, window::{ processors::{ - agg_processor::AggProcessor, rank_processor::RankProcessor, - row_number_processor::RowNumberProcessor, + agg_processor::AggProcessor, lead_processor::LeadProcessor, + rank_processor::RankProcessor, row_number_processor::RowNumberProcessor, }, window_context::WindowContext, }, @@ -36,6 +36,7 @@ pub mod window_context; #[derive(Debug, Clone, Copy)] pub enum WindowFunction { RankLike(WindowRankType), + Lead, Agg(AggFunction), } @@ -87,6 +88,7 @@ impl WindowExpr { WindowFunction::RankLike(WindowRankType::DenseRank) => { Ok(Box::new(RankProcessor::new(true))) } + WindowFunction::Lead => Ok(Box::new(LeadProcessor::new(self.children.clone()))), WindowFunction::Agg(agg_func) => { let agg = create_agg( 
agg_func.clone(), @@ -98,4 +100,8 @@ impl WindowExpr { } } } + + pub fn requires_full_partition(&self) -> bool { + matches!(self.func, WindowFunction::Lead) + } } diff --git a/native-engine/datafusion-ext-plans/src/window/processors/lead_processor.rs b/native-engine/datafusion-ext-plans/src/window/processors/lead_processor.rs new file mode 100644 index 000000000..c8b73face --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/window/processors/lead_processor.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::sync::Arc; + +use arrow::{ + array::ArrayRef, datatypes::DataType, record_batch::RecordBatch, +}; +use datafusion::{ + common::{DataFusionError, Result, ScalarValue}, + physical_expr::PhysicalExprRef, +}; +use datafusion_ext_commons::arrow::cast::cast; + +use crate::window::{WindowFunctionProcessor, window_context::WindowContext}; + +pub struct LeadProcessor { + children: Vec<PhysicalExprRef>, +} + +impl LeadProcessor { + pub fn new(children: Vec<PhysicalExprRef>) -> Self { + Self { children } + } +} + +impl WindowFunctionProcessor for LeadProcessor { + fn process_batch(&mut self, context: &WindowContext, batch: &RecordBatch) -> Result<ArrayRef> { + assert_eq!( + self.children.len(), + 3, + "lead expects input/offset/default children", + ); + + let input_values = self.children[0] + .evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows()))?; + + let offset_values = self.children[1] + .evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows()))?; + let offset_values = if offset_values.data_type() == &DataType::Int32 { + offset_values + } else { + cast(&offset_values, &DataType::Int32)? + }; + let offset = match ScalarValue::try_from_array(&offset_values, 0)? { + ScalarValue::Int32(Some(offset)) => offset as i64, + other => { + return Err(DataFusionError::Execution(format!( + "lead offset must be a non-null foldable integer, got {other:?}", + ))); + } + }; + + let default_values = self.children[2] + .evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows()))?; + let default_values = if default_values.data_type() == input_values.data_type() { + default_values + } else { + cast(&default_values, input_values.data_type())?
+ }; + + let mut partition_starts = vec![0usize; batch.num_rows()]; + let mut partition_ends = vec![batch.num_rows(); batch.num_rows()]; + if context.has_partition() && batch.num_rows() > 0 { + let partition_rows = context.get_partition_rows(batch)?; + let mut partition_start = 0usize; + for row_idx in 1..=batch.num_rows() { + let is_boundary = row_idx == batch.num_rows() + || partition_rows.row(row_idx).as_ref() + != partition_rows.row(partition_start).as_ref(); + if is_boundary { + for idx in partition_start..row_idx { + partition_starts[idx] = partition_start; + partition_ends[idx] = row_idx; + } + partition_start = row_idx; + } + } + } + + let mut output = Vec::with_capacity(batch.num_rows()); + for row_idx in 0..batch.num_rows() { + let target_idx = row_idx as i64 + offset; + let partition_start = partition_starts[row_idx] as i64; + let partition_end = partition_ends[row_idx] as i64; + let value = if target_idx >= partition_start && target_idx < partition_end { + ScalarValue::try_from_array(&input_values, target_idx as usize)? + } else { + ScalarValue::try_from_array(&default_values, row_idx)? + }; + output.push(value); + } + + ScalarValue::iter_to_array(output) + } +} diff --git a/native-engine/datafusion-ext-plans/src/window/processors/mod.rs b/native-engine/datafusion-ext-plans/src/window/processors/mod.rs index 7d4a72b55..c92bc5d19 100644 --- a/native-engine/datafusion-ext-plans/src/window/processors/mod.rs +++ b/native-engine/datafusion-ext-plans/src/window/processors/mod.rs @@ -14,5 +14,6 @@ // limitations under the License. 
pub mod agg_processor; +pub mod lead_processor; pub mod rank_processor; pub mod row_number_processor; diff --git a/native-engine/datafusion-ext-plans/src/window/window_context.rs b/native-engine/datafusion-ext-plans/src/window/window_context.rs index a76eb1253..1c5f68f12 100644 --- a/native-engine/datafusion-ext-plans/src/window/window_context.rs +++ b/native-engine/datafusion-ext-plans/src/window/window_context.rs @@ -167,4 +167,10 @@ impl WindowContext { .collect::<Result<Vec<_>>>()?, )?) } + + pub fn requires_full_partition(&self) -> bool { + self.window_exprs + .iter() + .any(|expr| expr.requires_full_partition()) + } } diff --git a/native-engine/datafusion-ext-plans/src/window_exec.rs b/native-engine/datafusion-ext-plans/src/window_exec.rs index 5bb698eec..36e63cc9f 100644 --- a/native-engine/datafusion-ext-plans/src/window_exec.rs +++ b/native-engine/datafusion-ext-plans/src/window_exec.rs @@ -17,6 +17,7 @@ use std::{any::Any, fmt::Formatter, sync::Arc}; use arrow::{ array::{Array, ArrayRef, Int32Array}, + compute::concat_batches, datatypes::SchemaRef, record_batch::{RecordBatch, RecordBatchOptions}, }; @@ -37,7 +38,7 @@ use once_cell::sync::OnceCell; use crate::{ common::execution_context::ExecutionContext, - window::{WindowExpr, window_context::WindowContext}, + window::{WindowExpr, WindowFunctionProcessor, window_context::WindowContext}, }; #[derive(Debug)] @@ -217,45 +218,28 @@ fn execute_window( .map(|expr: &WindowExpr| expr.create_processor(&window_ctx)) .collect::<Result<Vec<_>>>()?; - while let Some(mut batch) = input.next().await.transpose()?
{ - let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer(); - let mut window_cols: Vec<ArrayRef> = processors - .iter_mut() - .map(|processor| processor.process_batch(&window_ctx, &batch)) - .collect::<Result<_>>()?; - - if let Some(group_limit) = window_ctx.group_limit { - assert_eq!(window_cols.len(), 1); - let limited = arrow::compute::kernels::cmp::lt_eq( - &window_cols[0], - &Int32Array::new_scalar(group_limit as i32), - )?; - window_cols[0] = arrow::compute::filter(&window_cols[0], &limited)?; - batch = arrow::compute::filter_record_batch(&batch, &limited)?; + if window_ctx.requires_full_partition() { + let mut staging_batches = vec![]; + while let Some(batch) = input.next().await.transpose()? { + staging_batches.push(batch); + } + + if !staging_batches.is_empty() { + let batch = concat_batches(&window_ctx.input_schema, &staging_batches)?; + let output_batch = + process_window_batch(batch, &window_ctx, processors.as_mut_slice())?; + exec_ctx + .baseline_metrics() + .record_output(output_batch.num_rows()); + sender.send(output_batch).await; } + return Ok(()); + } - let outputs: Vec<ArrayRef> = batch - .columns() - .iter() - .cloned() - .chain(if window_ctx.output_window_cols { - window_cols - } else { - vec![] - }) - .zip(window_ctx.output_schema.fields()) - .map(|(array, field)| { - if array.data_type() != field.data_type() { - return cast(&array, field.data_type()); - } - Ok(array.clone()) - }) - .collect::<Result<_>>()?; - let output_batch = RecordBatch::try_new_with_options( - window_ctx.output_schema.clone(), - outputs, - &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), - )?; + while let Some(batch) = input.next().await.transpose()?
{ + let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer(); + let output_batch = + process_window_batch(batch, &window_ctx, processors.as_mut_slice())?; exec_ctx .baseline_metrics() .record_output(output_batch.num_rows()); @@ -265,6 +249,50 @@ })) } +fn process_window_batch( + mut batch: RecordBatch, + window_ctx: &WindowContext, + processors: &mut [Box<dyn WindowFunctionProcessor>], +) -> Result<RecordBatch> { + let mut window_cols: Vec<ArrayRef> = processors + .iter_mut() + .map(|processor| processor.process_batch(window_ctx, &batch)) + .collect::<Result<_>>()?; + + if let Some(group_limit) = window_ctx.group_limit { + assert_eq!(window_cols.len(), 1); + let limited = arrow::compute::kernels::cmp::lt_eq( + &window_cols[0], + &Int32Array::new_scalar(group_limit as i32), + )?; + window_cols[0] = arrow::compute::filter(&window_cols[0], &limited)?; + batch = arrow::compute::filter_record_batch(&batch, &limited)?; + } + + let outputs: Vec<ArrayRef> = batch + .columns() + .iter() + .cloned() + .chain(if window_ctx.output_window_cols { + window_cols + } else { + vec![] + }) + .zip(window_ctx.output_schema.fields()) + .map(|(array, field)| { + if array.data_type() != field.data_type() { + return cast(&array, field.data_type()); + } + Ok(array.clone()) + }) + .collect::<Result<_>>()?; + Ok(RecordBatch::try_new_with_options( + window_ctx.output_schema.clone(), + outputs, + &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), + )?)
+} + #[cfg(test)] mod test { use std::sync::Arc; @@ -273,9 +301,13 @@ mod test { use datafusion::{ assert_batches_eq, common::Result, - physical_expr::{PhysicalSortExpr, expressions::Column}, + physical_expr::{ + PhysicalSortExpr, + expressions::{Column, Literal}, + }, physical_plan::{ExecutionPlan, test::TestMemoryExec}, prelude::SessionContext, + scalar::ScalarValue, }; use crate::{ @@ -491,4 +523,65 @@ mod test { assert_batches_eq!(expected, &batches); Ok(()) } + + #[tokio::test] + async fn test_window_lead_across_batches() -> Result<(), Box<dyn std::error::Error>> { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + + let batch1 = build_table_i32( + ("a1", &vec![1, 1]), + ("b1", &vec![10, 20]), + ("c1", &vec![0, 0]), + )?; + let batch2 = build_table_i32( + ("a1", &vec![1, 2]), + ("b1", &vec![30, 40]), + ("c1", &vec![0, 0]), + )?; + let schema = batch1.schema(); + let input = Arc::new(TestMemoryExec::try_new( + &[vec![batch1, batch2]], + schema, + None, + )?); + + let window_exprs = vec![WindowExpr::new( + WindowFunction::Lead, + vec![ + Arc::new(Column::new("b1", 1)), + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Arc::new(Literal::new(ScalarValue::Int32(Some(-1)))), + ], + Arc::new(Field::new("b1_lead", DataType::Int32, false)), + DataType::Int32, + )]; + + let window = Arc::new(WindowExec::try_new( + input, + window_exprs, + vec![Arc::new(Column::new("a1", 0))], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("b1", 1)), + options: Default::default(), + }], + None, + true, + )?); + + let stream = window.execute(0, task_ctx)?; + let batches = datafusion::physical_plan::common::collect(stream).await?; + let expected = vec![ + "+----+----+----+---------+", + "| a1 | b1 | c1 | b1_lead |", + "+----+----+----+---------+", + "| 1 | 10 | 0 | 20 |", + "| 1 | 20 | 0 | 30 |", + "| 1 | 30 | 0 | -1 |", + "| 2 | 40 | 0 | -1 |", + "+----+----+----+---------+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } } diff --git
a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala new file mode 100644 index 000000000..5f1e78fb8 --- /dev/null +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.auron + +import org.apache.spark.sql.AuronQueryTest +import org.apache.spark.sql.execution.auron.plan.NativeWindowBase + +import org.apache.auron.util.AuronTestUtils + +class AuronWindowSuite extends AuronQueryTest with BaseAuronSQLSuite with AuronSQLTestHelper { + + test("lead window function") { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | v, + | lead(v) over (partition by grp order by id) as next_v, + | lead(v, 2, 'fallback') over (partition by grp order by id) as next2_v + |from t1 + |""".stripMargin) + } + } + } + + test("lead window function with ignore nulls falls back") { + if (AuronTestUtils.isSparkV32OrGreater) { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + val df = checkSparkAnswer("""select + | id, + | grp, + | lead(v, 1, 'fallback') ignore nulls + | over (partition by grp order by id) as next_non_null_v + |from t1 + |""".stripMargin) + val plan = stripAQEPlan(df.queryExecution.executedPlan) + assert(plan.collectFirst { case _: NativeWindowBase => true }.isEmpty) + } + } + } + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala index fad61ff09..4b1fc4c03 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Ascending import 
org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.DenseRank import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Lead import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.expressions.NullsFirst import org.apache.spark.sql.catalyst.expressions.Rank @@ -89,6 +90,11 @@ abstract class NativeWindowBase( override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + private def leadIgnoreNulls(expr: Lead): Boolean = + expr.getClass.getMethods + .find(method => method.getName == "ignoreNulls" && method.getParameterCount == 0) + .exists(method => method.invoke(expr).asInstanceOf[Boolean]) + private def nativeWindowExprs = windowExpression.map { named => val field = NativeConverters.convertField(Util.getSchema(named :: Nil).fields(0)) val windowExprBuilder = pb.WindowExprNode.newBuilder().setField(field) @@ -118,6 +124,17 @@ abstract class NativeWindowBase( windowExprBuilder.setFuncType(pb.WindowFunctionType.Window) windowExprBuilder.setWindowFunc(pb.WindowFunction.DENSE_RANK) + case e: Lead => + assert( + spec.frameSpecification == e.frame, + s"window frame not supported: ${spec.frameSpecification}") + assert(!leadIgnoreNulls(e), "window function not supported: lead with IGNORE NULLS") + windowExprBuilder.setFuncType(pb.WindowFunctionType.Window) + windowExprBuilder.setWindowFunc(pb.WindowFunction.LEAD) + windowExprBuilder.addChildren(NativeConverters.convertExpr(e.input)) + windowExprBuilder.addChildren(NativeConverters.convertExpr(e.offset)) + windowExprBuilder.addChildren(NativeConverters.convertExpr(e.default)) + case e: Sum => assert( spec.frameSpecification == RowNumber().frame, // only supports RowFrame(Unbounde, CurrentRow)