diff --git a/native/core/src/execution/expressions/temporal.rs b/native/core/src/execution/expressions/temporal.rs
index ae57cd3b2b..10483e9d1b 100644
--- a/native/core/src/execution/expressions/temporal.rs
+++ b/native/core/src/execution/expressions/temporal.rs
@@ -25,7 +25,8 @@ use datafusion::logical_expr::ScalarUDF;
 use datafusion::physical_expr::{PhysicalExpr, ScalarFunctionExpr};
 use datafusion_comet_proto::spark_expression::Expr;
 use datafusion_comet_spark_expr::{
-    SparkHour, SparkMinute, SparkSecond, SparkUnixTimestamp, TimestampTruncExpr,
+    SparkHour, SparkHoursTransform, SparkMinute, SparkSecond, SparkUnixTimestamp,
+    TimestampTruncExpr,
 };
 
 use crate::execution::{
@@ -160,3 +161,31 @@ impl ExpressionBuilder for TruncTimestampBuilder {
         Ok(Arc::new(TimestampTruncExpr::new(child, format, timezone)))
     }
 }
+
+pub struct HoursTransformBuilder;
+
+impl ExpressionBuilder for HoursTransformBuilder {
+    fn build(
+        &self,
+        spark_expr: &Expr,
+        input_schema: SchemaRef,
+        planner: &PhysicalPlanner,
+    ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
+        let expr = extract_expr!(spark_expr, HoursTransform);
+        let child = planner.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
+        let timezone = expr.timezone.clone();
+        let args = vec![child];
+        let comet_hours_transform =
+            Arc::new(ScalarUDF::new_from_impl(SparkHoursTransform::new(timezone)));
+        let field_ref = Arc::new(Field::new("hours_transform", DataType::Int32, true));
+        let expr: ScalarFunctionExpr = ScalarFunctionExpr::new(
+            "hours_transform",
+            comet_hours_transform,
+            args,
+            field_ref,
+            Arc::new(ConfigOptions::default()),
+        );
+
+        Ok(Arc::new(expr))
+    }
+}
diff --git a/native/core/src/execution/planner/expression_registry.rs b/native/core/src/execution/planner/expression_registry.rs
index bf3904d9c1..919a72a21a 100644
--- a/native/core/src/execution/planner/expression_registry.rs
+++ b/native/core/src/execution/planner/expression_registry.rs
@@ -110,6 +110,7 @@ pub enum ExpressionType {
     Second,
     TruncTimestamp,
     UnixTimestamp,
+    HoursTransform,
 }
 
 /// Registry for expression builders
@@ -310,6 +311,10 @@ impl ExpressionRegistry {
             ExpressionType::TruncTimestamp,
             Box::new(TruncTimestampBuilder),
         );
+        self.builders.insert(
+            ExpressionType::HoursTransform,
+            Box::new(HoursTransformBuilder),
+        );
     }
 
     /// Extract expression type from Spark protobuf expression
@@ -382,6 +387,7 @@ impl ExpressionRegistry {
             Some(ExprStruct::Second(_)) => Ok(ExpressionType::Second),
             Some(ExprStruct::TruncTimestamp(_)) => Ok(ExpressionType::TruncTimestamp),
             Some(ExprStruct::UnixTimestamp(_)) => Ok(ExpressionType::UnixTimestamp),
+            Some(ExprStruct::HoursTransform(_)) => Ok(ExpressionType::HoursTransform),
 
             Some(other) => Err(ExecutionError::GeneralError(format!(
                 "Unsupported expression type: {:?}",
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 32cbc0ce13..29e3926715 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -88,6 +88,7 @@ message Expr {
     UnixTimestamp unix_timestamp = 65;
     FromJson from_json = 66;
     ToCsv to_csv = 67;
+    HoursTransform hours_transform = 68;
   }
 
   // Optional QueryContext for error reporting (contains SQL text and position)
@@ -349,6 +350,11 @@ message Hour {
   string timezone = 2;
 }
 
+message HoursTransform {
+  Expr child = 1;
+  string timezone = 2;
+}
+
 message Minute {
   Expr child = 1;
   string timezone = 2;
diff --git a/native/spark-expr/src/datetime_funcs/hours.rs b/native/spark-expr/src/datetime_funcs/hours.rs
new file mode 100644
index 0000000000..5abf6c45f1
--- /dev/null
+++ b/native/spark-expr/src/datetime_funcs/hours.rs
@@ -0,0 +1,299 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Spark-compatible `hours` V2 partition transform.
+//!
+//! Computes the number of hours since the Unix epoch (1970-01-01 00:00:00 UTC).
+//!
+//! - For `Timestamp(Microsecond, Some(tz))`: applies timezone offset before computing.
+//! - For `Timestamp(Microsecond, None)` (NTZ): uses raw microseconds directly.
+
+use arrow::array::cast::as_primitive_array;
+use arrow::array::types::TimestampMicrosecondType;
+use arrow::array::{Array, Int32Array};
+use arrow::datatypes::{DataType, TimeUnit::Microsecond};
+use arrow::temporal_conversions::as_datetime;
+use chrono::{Offset, TimeZone};
+use datafusion::common::{internal_datafusion_err, DataFusionError};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::{any::Any, fmt::Debug, sync::Arc};
+
+use crate::timezone::Tz;
+
+const MICROS_PER_HOUR: i64 = 3_600_000_000;
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkHoursTransform {
+    signature: Signature,
+    timezone: String,
+}
+
+impl SparkHoursTransform {
+    pub fn new(timezone: String) -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            timezone,
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkHoursTransform {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "hours_transform"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
+        Ok(DataType::Int32)
+    }
+
+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion::common::Result<ColumnarValue> {
+        let args: [ColumnarValue; 1] = args.args.try_into().map_err(|_| {
+            internal_datafusion_err!("hours_transform expects exactly one argument")
+        })?;
+
+        match args {
+            [ColumnarValue::Array(array)] => {
+                let ts_array = as_primitive_array::<TimestampMicrosecondType>(&array);
+                let result: Int32Array = match array.data_type() {
+                    DataType::Timestamp(Microsecond, Some(_)) => {
+                        let tz: Tz = self.timezone.parse().map_err(|e| {
+                            DataFusionError::Execution(format!(
+                                "Failed to parse timezone '{}': {}",
+                                self.timezone, e
+                            ))
+                        })?;
+                        arrow::compute::kernels::arity::try_unary(ts_array, |micros| {
+                            let dt = as_datetime::<TimestampMicrosecondType>(micros).ok_or_else(
+                                || {
+                                    DataFusionError::Execution(format!(
+                                        "Cannot convert {micros} to datetime"
+                                    ))
+                                },
+                            )?;
+                            let offset_secs =
+                                tz.offset_from_utc_datetime(&dt).fix().local_minus_utc() as i64;
+                            let local_micros = micros + offset_secs * 1_000_000;
+                            Ok(local_micros.div_euclid(MICROS_PER_HOUR) as i32)
+                        })?
+                    }
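+                    // NTZ micros are already wall-clock time, so no zone offset is
+                    // applied. `div_euclid` floors toward negative infinity, so
+                    // pre-epoch values land in the correct bucket (-30 min -> hour -1).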
+                    DataType::Timestamp(Microsecond, None) => {
+                        arrow::compute::kernels::arity::unary(ts_array, |micros| {
+                            micros.div_euclid(MICROS_PER_HOUR) as i32
+                        })
+                    }
+                    other => {
+                        return Err(DataFusionError::Execution(format!(
+                            "hours_transform does not support input type: {:?}",
+                            other
+                        )));
+                    }
+                };
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            _ => Err(DataFusionError::Execution(
+                "hours_transform(scalar) should be folded on Spark JVM side.".to_string(),
+            )),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::TimestampMicrosecondArray;
+    use arrow::datatypes::Field;
+    use datafusion::config::ConfigOptions;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_hours_transform_utc() {
+        let udf = SparkHoursTransform::new("UTC".to_string());
+        // 2023-10-01 14:50:00 UTC = 1696171800 seconds = 1696171800000000 micros
+        // Expected hours since epoch = 1696171800000000 / 3600000000 = 471158
+        let ts = TimestampMicrosecondArray::from(vec![Some(1_696_171_800_000_000i64)])
+            .with_timezone("UTC");
+        let return_field = Arc::new(Field::new("hours_transform", DataType::Int32, true));
+        let args = ScalarFunctionArgs {
+            args: vec![ColumnarValue::Array(Arc::new(ts))],
+            number_rows: 1,
+            return_field,
+            config_options: Arc::new(ConfigOptions::default()),
+            arg_fields: vec![],
+        };
+        let result = udf.invoke_with_args(args).unwrap();
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
+                assert_eq!(int_arr.value(0), 471158);
+            }
+            _ => panic!("Expected array"),
+        }
+    }
+
+    #[test]
+    fn test_hours_transform_ntz() {
+        let udf = SparkHoursTransform::new("UTC".to_string());
+        // Same timestamp but NTZ (no timezone on array)
+        let ts = TimestampMicrosecondArray::from(vec![Some(1_696_171_800_000_000i64)]);
+        let return_field = Arc::new(Field::new("hours_transform", DataType::Int32, true));
+        let args = ScalarFunctionArgs {
+            args: vec![ColumnarValue::Array(Arc::new(ts))],
+            number_rows: 1,
+            return_field,
+            config_options: Arc::new(ConfigOptions::default()),
+            arg_fields: vec![],
+        };
+        let result = udf.invoke_with_args(args).unwrap();
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
+                assert_eq!(int_arr.value(0), 471158);
+            }
+            _ => panic!("Expected array"),
+        }
+    }
+
+    #[test]
+    fn test_hours_transform_negative_epoch() {
+        let udf = SparkHoursTransform::new("UTC".to_string());
+        // 1969-12-31 23:30:00 UTC = -1800 seconds = -1800000000 micros
+        // Expected: div_euclid(-1800000000, 3600000000) = -1
+        let ts =
+            TimestampMicrosecondArray::from(vec![Some(-1_800_000_000i64)]).with_timezone("UTC");
+        let return_field = Arc::new(Field::new("hours_transform", DataType::Int32, true));
+        let args = ScalarFunctionArgs {
+            args: vec![ColumnarValue::Array(Arc::new(ts))],
+            number_rows: 1,
+            return_field,
+            config_options: Arc::new(ConfigOptions::default()),
+            arg_fields: vec![],
+        };
+        let result = udf.invoke_with_args(args).unwrap();
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
+                assert_eq!(int_arr.value(0), -1);
+            }
+            _ => panic!("Expected array"),
+        }
+    }
+
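+    // Nulls need no special handling: the arity kernels carry the input's
+    // null buffer through to the output, so a null timestamp stays null.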
+    #[test]
+    fn test_hours_transform_null() {
+        let udf = SparkHoursTransform::new("UTC".to_string());
+        let ts = TimestampMicrosecondArray::from(vec![None as Option<i64>]).with_timezone("UTC");
+        let return_field = Arc::new(Field::new("hours_transform", DataType::Int32, true));
+        let args = ScalarFunctionArgs {
+            args: vec![ColumnarValue::Array(Arc::new(ts))],
+            number_rows: 1,
+            return_field,
+            config_options: Arc::new(ConfigOptions::default()),
+            arg_fields: vec![],
+        };
+        let result = udf.invoke_with_args(args).unwrap();
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
+                assert!(int_arr.is_null(0));
+            }
+            _ => panic!("Expected array"),
+        }
+    }
+
+    #[test]
+    fn test_hours_transform_epoch_zero() {
+        let udf = SparkHoursTransform::new("UTC".to_string());
+        let ts = TimestampMicrosecondArray::from(vec![Some(0i64)]).with_timezone("UTC");
+        let return_field = Arc::new(Field::new("hours_transform", DataType::Int32, true));
+        let args = ScalarFunctionArgs {
+            args: vec![ColumnarValue::Array(Arc::new(ts))],
+            number_rows: 1,
+            return_field,
+            config_options: Arc::new(ConfigOptions::default()),
+            arg_fields: vec![],
+        };
+        let result = udf.invoke_with_args(args).unwrap();
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
+                assert_eq!(int_arr.value(0), 0);
+            }
+            _ => panic!("Expected array"),
+        }
+    }
+
+    #[test]
+    fn test_hours_transform_non_utc_timezone() {
+        // Asia/Tokyo is UTC+9. For a UTC timestamp of 1970-01-01 00:00:00 UTC (micros=0),
+        // local time = 1970-01-01 09:00:00 JST, so local hours since epoch = 9.
+        let udf = SparkHoursTransform::new("Asia/Tokyo".to_string());
+        let ts = TimestampMicrosecondArray::from(vec![Some(0i64)]).with_timezone("Asia/Tokyo");
+        let return_field = Arc::new(Field::new("hours_transform", DataType::Int32, true));
+        let args = ScalarFunctionArgs {
+            args: vec![ColumnarValue::Array(Arc::new(ts))],
+            number_rows: 1,
+            return_field,
+            config_options: Arc::new(ConfigOptions::default()),
+            arg_fields: vec![],
+        };
+        let result = udf.invoke_with_args(args).unwrap();
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
+                assert_eq!(int_arr.value(0), 9);
+            }
+            _ => panic!("Expected array"),
+        }
+    }
+
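+    // Illustrative sketch, not part of the original change: a half-hour-offset
+    // zone exercises the flooring. Asia/Kolkata is UTC+05:30, so micros=0 maps to
+    // 1970-01-01 05:30:00 local time, which floors to hour bucket 5.
+    #[test]
+    fn test_hours_transform_half_hour_offset() {
+        let udf = SparkHoursTransform::new("Asia/Kolkata".to_string());
+        let ts = TimestampMicrosecondArray::from(vec![Some(0i64)]).with_timezone("Asia/Kolkata");
+        let return_field = Arc::new(Field::new("hours_transform", DataType::Int32, true));
+        let args = ScalarFunctionArgs {
+            args: vec![ColumnarValue::Array(Arc::new(ts))],
+            number_rows: 1,
+            return_field,
+            config_options: Arc::new(ConfigOptions::default()),
+            arg_fields: vec![],
+        };
+        let result = udf.invoke_with_args(args).unwrap();
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
+                assert_eq!(int_arr.value(0), 5); // floor(5.5 hours) = 5
+            }
+            _ => panic!("Expected array"),
+        }
+    }
+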
+    #[test]
+    fn test_hours_transform_ntz_ignores_timezone() {
+        // NTZ with micros=0 should always return 0, regardless of the timezone
+        // string stored in the UDF (proving the NTZ path ignores timezone).
+        let udf = SparkHoursTransform::new("Asia/Tokyo".to_string());
+        let ts = TimestampMicrosecondArray::from(vec![Some(0i64)]); // No timezone on array
+        let return_field = Arc::new(Field::new("hours_transform", DataType::Int32, true));
+        let args = ScalarFunctionArgs {
+            args: vec![ColumnarValue::Array(Arc::new(ts))],
+            number_rows: 1,
+            return_field,
+            config_options: Arc::new(ConfigOptions::default()),
+            arg_fields: vec![],
+        };
+        let result = udf.invoke_with_args(args).unwrap();
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
+                assert_eq!(int_arr.value(0), 0); // NOT 9, because NTZ ignores timezone
+            }
+            _ => panic!("Expected array"),
+        }
+    }
+}
diff --git a/native/spark-expr/src/datetime_funcs/mod.rs b/native/spark-expr/src/datetime_funcs/mod.rs
index 5bafc1d287..d8ed8abb98 100644
--- a/native/spark-expr/src/datetime_funcs/mod.rs
+++ b/native/spark-expr/src/datetime_funcs/mod.rs
@@ -18,6 +18,7 @@
 mod date_diff;
 mod date_trunc;
 mod extract_date_part;
+mod hours;
 mod make_date;
 mod timestamp_trunc;
 mod unix_timestamp;
@@ -27,6 +28,7 @@ pub use date_trunc::SparkDateTrunc;
 pub use extract_date_part::SparkHour;
 pub use extract_date_part::SparkMinute;
 pub use extract_date_part::SparkSecond;
+pub use hours::SparkHoursTransform;
 pub use make_date::SparkMakeDate;
 pub use timestamp_trunc::TimestampTruncExpr;
 pub use unix_timestamp::SparkUnixTimestamp;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index a7711d642d..74fcae8f4c 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -71,8 +71,8 @@ pub use comet_scalar_funcs::{
 };
 pub use csv_funcs::*;
 pub use datetime_funcs::{
-    SparkDateDiff, SparkDateTrunc, SparkHour, SparkMakeDate, SparkMinute, SparkSecond,
-    SparkUnixTimestamp, TimestampTruncExpr,
+    SparkDateDiff, SparkDateTrunc, SparkHour, SparkHoursTransform, SparkMakeDate, SparkMinute,
+    SparkSecond, SparkUnixTimestamp, TimestampTruncExpr,
 };
 pub use error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult};
 pub use hash_funcs::*;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 02a76f69f0..4800ea75ca 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -197,6 +197,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[DateDiff] -> CometDateDiff,
     classOf[DateFormatClass] -> CometDateFormat,
     classOf[Days] -> CometDays,
+    classOf[Hours] -> CometHours,
     classOf[DateSub] -> CometDateSub,
     classOf[UnixDate] -> CometUnixDate,
     classOf[FromUnixTime] -> CometFromUnixTime,
diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
index 8f3894c1ac..1be8d7ec0e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import java.util.Locale
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateDiff, DateFormatClass, DateSub, DayOfMonth, DayOfWeek, DayOfYear, Days, GetDateField, Hour, LastDay, Literal, MakeDate, Minute, Month, NextDay, Quarter, Second, TruncDate, TruncTimestamp, UnixDate, UnixTimestamp, WeekDay, WeekOfYear, Year}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateDiff, DateFormatClass, DateSub, DayOfMonth, DayOfWeek, DayOfYear, Days, GetDateField, Hour, Hours, LastDay, Literal, MakeDate, Minute, Month, NextDay, Quarter, Second, TruncDate, TruncTimestamp, UnixDate, UnixTimestamp, WeekDay, WeekOfYear, Year}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DateType, IntegerType, StringType, TimestampType}
 import org.apache.spark.unsafe.types.UTF8String
@@ -589,6 +589,39 @@
   }
 }
 
+/**
+ * Converts a timestamp to the number of hours since Unix epoch (1970-01-01 00:00:00 UTC). This is
+ * a V2 partition transform expression.
+ *
+ * For TimestampType: uses timezone-aware conversion to determine the local hour boundary. For
+ * TimestampNTZType: directly divides the raw microsecond value (wall-clock time).
+ */
+object CometHours extends CometExpressionSerde[Hours] {
+  override def convert(
+      expr: Hours,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(expr.child, inputs, binding)
+
+    if (childExpr.isDefined) {
+      val builder = ExprOuterClass.HoursTransform.newBuilder()
+      builder.setChild(childExpr.get)
+
+      val timeZone = SQLConf.get.sessionLocalTimeZone
+      builder.setTimezone(timeZone)
+
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setHoursTransform(builder)
+          .build())
+    } else {
+      withInfo(expr, expr.child)
+      None
+    }
+  }
+}
+
 /**
  * Converts a timestamp or date to the number of days since Unix epoch (1970-01-01). This is a V2
  * partition transform expression.