diff --git a/native-engine/datafusion-ext-functions/src/spark_dates.rs b/native-engine/datafusion-ext-functions/src/spark_dates.rs index 6693d8d8e..3bc830d87 100644 --- a/native-engine/datafusion-ext-functions/src/spark_dates.rs +++ b/native-engine/datafusion-ext-functions/src/spark_dates.rs @@ -22,7 +22,7 @@ use arrow::{ compute::{DatePart, date_part}, datatypes::{DataType, TimeUnit}, }; -use chrono::{Duration, LocalResult, NaiveDate, TimeZone, Utc, prelude::*}; +use chrono::{Duration, LocalResult, NaiveDate, Offset, TimeZone, Utc, prelude::*}; use chrono_tz::Tz; use datafusion::{ common::{DataFusionError, Result, ScalarValue}, @@ -30,69 +30,12 @@ use datafusion::{ }; use datafusion_ext_commons::arrow::cast::cast; -// ---- date parts on Date32 via Arrow's date_part -// ----------------------------------------------- - -pub fn spark_year(args: &[ColumnarValue]) -> Result { - let input = cast(&args[0].clone().into_array(1)?, &DataType::Date32)?; - Ok(ColumnarValue::Array(date_part(&input, DatePart::Year)?)) -} - -pub fn spark_month(args: &[ColumnarValue]) -> Result { - let input = cast(&args[0].clone().into_array(1)?, &DataType::Date32)?; - Ok(ColumnarValue::Array(date_part(&input, DatePart::Month)?)) -} - -pub fn spark_day(args: &[ColumnarValue]) -> Result { - let input = cast(&args[0].clone().into_array(1)?, &DataType::Date32)?; - Ok(ColumnarValue::Array(date_part(&input, DatePart::Day)?)) -} - -/// `spark_dayofweek(date/timestamp/compatible-string)` -/// -/// Matches Spark's `dayofweek()` semantics: -/// Sunday = 1, Monday = 2, ..., Saturday = 7. -pub fn spark_dayofweek(args: &[ColumnarValue]) -> Result { - let input = cast(&args[0].clone().into_array(1)?, &DataType::Date32)?; - let input = input - .as_any() - .downcast_ref::() - .expect("internal cast to Date32 must succeed"); - - // Date32 is days since 1970-01-01. 1970-01-01 is a Thursday. - // If we number weekdays so that Sunday = 0, ..., Saturday = 6, - // then 1970-01-01 corresponds to 4. For an offset `days`, - // weekday_index = (days + 4) mod 7 gives 0 = Sunday, ..., 6 = Saturday. - // Spark wants Sunday = 1, ..., Saturday = 7, so we add 1. - let dayofweek = Int32Array::from_iter(input.iter().map(|opt_days| { - opt_days.map(|days| { - let weekday_index = (days as i64 + 4).rem_euclid(7); - weekday_index as i32 + 1 - }) - })); - - Ok(ColumnarValue::Array(Arc::new(dayofweek))) -} - -/// `spark_weekofyear(date/timestamp/compatible-string[, timezone])` -/// -/// Matches Spark's `weekofyear()` semantics: -/// ISO week numbering, with Monday as the first day of the week, -/// and week 1 defined as the first week with more than 3 days. -/// -/// For `Timestamp` inputs, this function interprets epoch milliseconds in the -/// provided timezone (if any) before deriving the calendar date and ISO week. -/// If no timezone is provided, `UTC` is used by default. If an invalid -/// timezone string is provided, the function returns an execution error. -/// For `Date` and compatible string inputs, the behavior is unchanged: the -/// value is cast to `Date32` and the ISO week is computed from the resulting -/// date. +/// Spark `weekofyear()`: ISO week number (Monday-based, week 1 has >3 days). +/// For timestamps, localizes to the given timezone before computing the week. +/// Defaults to UTC when no timezone is provided. pub fn spark_weekofyear(args: &[ColumnarValue]) -> Result { - // First argument as an Arrow array (date/timestamp/string, etc.) let array = args[0].clone().into_array(1)?; - // Determine timezone (for timestamp inputs). Default to UTC to match - // existing behavior when no timezone is provided. let default_tz = chrono_tz::UTC; let tz: Tz = if args.len() > 1 { match &args[1] { @@ -109,7 +52,6 @@ pub fn spark_weekofyear(args: &[ColumnarValue]) -> Result { }; match array.data_type() { - // Timestamp inputs: localize epoch milliseconds before computing ISO week DataType::Timestamp(TimeUnit::Millisecond, _) => { let ts_arr = array .as_any() @@ -126,7 +68,6 @@ pub fn spark_weekofyear(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(weekofyear))) } - // Non-timestamp inputs: preserve existing Date32-based behavior _ => { let input = cast(&array, &DataType::Date32)?; let input = input @@ -149,37 +90,6 @@ pub fn spark_weekofyear(args: &[ColumnarValue]) -> Result { } } -/// `spark_quarter(date/timestamp/compatible-string)` -/// -/// Simulates Spark's `quarter()` function. -/// Converts the input to `Date32`, extracts the month (1–12), -/// and computes the quarter as `((month - 1) / 3) + 1`. -/// Null values are propagated transparently. -pub fn spark_quarter(args: &[ColumnarValue]) -> Result { - // Cast input to Date32 for compatibility with date_part() - let input = cast(&args[0].clone().into_array(1)?, &DataType::Date32)?; - - // Extract month (1–12) using Arrow's date_part - let month_arr: ArrayRef = date_part(&input, DatePart::Month)?; - let month_arr = month_arr - .as_any() - .downcast_ref::() - .expect("date_part(Month) must return Int32Array"); - - // Compute quarter: ((month - 1) / 3) + 1, preserving NULLs - let quarter = Int32Array::from_iter( - month_arr - .iter() - .map(|opt_m| opt_m.map(|m| ((m - 1) / 3 + 1))), - ); - - Ok(ColumnarValue::Array(Arc::new(quarter))) -} - -// ---- timezone handling (custom, Spark-like) -// --------------------------------------------------- - -/// Parse optional timezone (2nd argument) into `Option`. fn parse_tz(args: &[ColumnarValue]) -> Option { parse_tz_value(args.get(1)) } @@ -209,8 +119,6 @@ fn start_of_local_day_ms(local_date: NaiveDate, tz_opt: Option) -> Option { - // Align with Java's LocalDate.atStartOfDay(zone): choose the first valid - // local time on that date if midnight itself falls in a gap. for minute in 1..=(24 * 60) { let candidate = local_midnight + chrono::Duration::minutes(minute); match tz.from_local_datetime(&candidate) { @@ -289,22 +197,119 @@ fn months_between_value( }) } -/// Return the UTC offset in **seconds** for `epoch_ms` at the given `tz` -/// (DST-aware). +/// DST-aware UTC offset in seconds for a given instant and timezone. fn offset_seconds_at(tz: Tz, epoch_ms: i64) -> i32 { - // Convert epoch_ms to UTC DateTime, then ask the tz for local offset. let dt_utc = Utc.timestamp_millis_opt(epoch_ms).single(); match dt_utc { Some(dt) => tz .offset_from_utc_datetime(&dt.naive_utc()) .fix() .local_minus_utc(), - None => 0, // Gracefully return 0 on invalid inputs to avoid panic. + None => 0, } } +/// Convert timestamp millis to local-timezone Date32 (days since epoch). +fn ts_ms_to_local_date32(ts: &TimestampMillisecondArray, tz: Tz) -> Date32Array { + const MS_PER_SEC: i64 = 1000; + const MS_PER_DAY: i64 = 86_400_000; + + Date32Array::from_iter(ts.iter().map(|opt_ms| { + opt_ms.map(|epoch_ms| { + let local_ms = epoch_ms + offset_seconds_at(tz, epoch_ms) as i64 * MS_PER_SEC; + let local_days = if local_ms >= 0 { + local_ms / MS_PER_DAY + } else { + (local_ms - MS_PER_DAY + 1) / MS_PER_DAY + }; + local_days as i32 + }) + })) +} + +/// Resolve input to a Date32Array, applying timezone adjustment for timestamps. +fn resolve_local_date32(args: &[ColumnarValue]) -> Result { + match parse_tz(args) { + Some(tz) => { + let arr = cast( + &args[0].clone().into_array(1)?, + &DataType::Timestamp(TimeUnit::Millisecond, None), + )?; + let ts = arr + .as_any() + .downcast_ref::() + .expect("cast to Timestamp(Millisecond, None) must succeed"); + Ok(ts_ms_to_local_date32(ts, tz)) + } + None => { + let arr = cast(&args[0].clone().into_array(1)?, &DataType::Date32)?; + Ok(arr + .as_any() + .downcast_ref::() + .expect("cast to Date32 must succeed") + .clone()) + } + } +} + +pub fn spark_year(args: &[ColumnarValue]) -> Result { + let local = resolve_local_date32(args)?; + Ok(ColumnarValue::Array(date_part( + &(Arc::new(local) as ArrayRef), + DatePart::Year, + )?)) +} + +pub fn spark_month(args: &[ColumnarValue]) -> Result { + let local = resolve_local_date32(args)?; + Ok(ColumnarValue::Array(date_part( + &(Arc::new(local) as ArrayRef), + DatePart::Month, + )?)) +} + +pub fn spark_day(args: &[ColumnarValue]) -> Result { + let local = resolve_local_date32(args)?; + Ok(ColumnarValue::Array(date_part( + &(Arc::new(local) as ArrayRef), + DatePart::Day, + )?)) +} + +/// Spark `dayofweek()`: Sunday = 1, Monday = 2, ..., Saturday = 7. +pub fn spark_dayofweek(args: &[ColumnarValue]) -> Result { + let input = resolve_local_date32(args)?; + + // Date32 days since epoch; epoch (1970-01-01) is Thursday (index 4). + // (days + 4) mod 7 → 0=Sun..6=Sat; Spark wants 1=Sun..7=Sat. + let dayofweek = Int32Array::from_iter(input.iter().map(|opt_days| { + opt_days.map(|days| { + let weekday_index = (days as i64 + 4).rem_euclid(7); + weekday_index as i32 + 1 + }) + })); + + Ok(ColumnarValue::Array(Arc::new(dayofweek))) +} + +pub fn spark_quarter(args: &[ColumnarValue]) -> Result { + let local = resolve_local_date32(args)?; + let month_arr: ArrayRef = date_part(&(Arc::new(local) as ArrayRef), DatePart::Month)?; + let month_arr = month_arr + .as_any() + .downcast_ref::() + .expect("date_part(Month) must return Int32Array"); + let quarter = Int32Array::from_iter( + month_arr + .iter() + .map(|opt_m| opt_m.map(|m| ((m - 1) / 3 + 1))), + ); + + Ok(ColumnarValue::Array(Arc::new(quarter))) +} + /// Extract hour/minute/second from a `TimestampMillisecondArray` with optional -/// timezone. `which`: "hour" | "minute" | "second" +/// timezone. fn extract_hms_with_tz( ts: &TimestampMillisecondArray, tz_opt: Option, @@ -317,15 +322,13 @@ fn extract_hms_with_tz( Int32Array::from_iter(ts.iter().map(|opt_ms| { opt_ms.map(|epoch_ms| { - // Localize by applying tz offset in seconds (if provided). let local_ms = if let Some(tz) = tz_opt { let off_sec = offset_seconds_at(tz, epoch_ms) as i64; epoch_ms + off_sec * MS_PER_SEC } else { - epoch_ms // Treat as UTC when tz is None. + epoch_ms }; - // Milliseconds within the day with positive modulo. let mut day_ms = local_ms % MS_PER_DAY; if day_ms < 0 { day_ms += MS_PER_DAY; @@ -341,12 +344,6 @@ fn extract_hms_with_tz( })) } -// ---- Spark-like hour/minute/second built on custom TZ logic -// ----------------------------------- - -/// Extract the HOUR component. We first cast any input to -/// `Timestamp(Millisecond, None)` (to get the physical milliseconds) and then -/// apply our own timezone/DST logic. pub fn spark_hour(args: &[ColumnarValue]) -> Result { let arr_ts_ms_none = cast( &args[0].clone().into_array(1)?, @@ -364,7 +361,6 @@ pub fn spark_hour(args: &[ColumnarValue]) -> Result { )))) } -/// Extract the MINUTE component (same approach as `spark_hour`). pub fn spark_minute(args: &[ColumnarValue]) -> Result { let arr_ts_ms_none = cast( &args[0].clone().into_array(1)?, @@ -382,7 +378,6 @@ pub fn spark_minute(args: &[ColumnarValue]) -> Result { )))) } -/// Extract the SECOND component (same approach as `spark_hour`). pub fn spark_second(args: &[ColumnarValue]) -> Result { let arr_ts_ms_none = cast( &args[0].clone().into_array(1)?, @@ -1077,4 +1072,106 @@ mod tests { assert_eq!(&out, &expected); Ok(()) } + + #[test] + fn test_year_month_day_with_tz_new_york() -> Result<()> { + // 04:30 UTC on Jan 4 is 23:30 on Jan 3 in New York + let epoch = utc_ms(2021, 1, 4, 4, 30, 0); + let ts = Arc::new(TimestampMillisecondArray::from(vec![Some(epoch)])); + let tz = ColumnarValue::Scalar(ScalarValue::Utf8(Some("America/New_York".to_string()))); + + let out_year = + spark_year(&[ColumnarValue::Array(ts.clone()), tz.clone()])?.into_array(1)?; + let out_month = + spark_month(&[ColumnarValue::Array(ts.clone()), tz.clone()])?.into_array(1)?; + let out_day = spark_day(&[ColumnarValue::Array(ts.clone()), tz.clone()])?.into_array(1)?; + + let expected_year: ArrayRef = Arc::new(Int32Array::from(vec![Some(2021)])); + let expected_month: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)])); + let expected_day: ArrayRef = Arc::new(Int32Array::from(vec![Some(3)])); + + assert_eq!(&out_year, &expected_year); + assert_eq!(&out_month, &expected_month); + assert_eq!(&out_day, &expected_day); + Ok(()) + } + + #[test] + fn test_dayofweek_with_tz_new_york() -> Result<()> { + // Same instant as above; Jan 3 2021 is a Sunday (=1 in Spark) + let epoch = utc_ms(2021, 1, 4, 4, 30, 0); + let ts = Arc::new(TimestampMillisecondArray::from(vec![Some(epoch)])); + let tz = ColumnarValue::Scalar(ScalarValue::Utf8(Some("America/New_York".to_string()))); + + let out = spark_dayofweek(&[ColumnarValue::Array(ts), tz])?.into_array(1)?; + let expected: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)])); + assert_eq!(&out, &expected); + Ok(()) + } + + #[test] + fn test_quarter_with_tz_boundary() -> Result<()> { + // 03:00 UTC on Apr 1 is 23:00 on Mar 31 in New York → Q1, not Q2 + let epoch = utc_ms(2021, 4, 1, 3, 0, 0); + let ts = Arc::new(TimestampMillisecondArray::from(vec![Some(epoch)])); + let tz = ColumnarValue::Scalar(ScalarValue::Utf8(Some("America/New_York".to_string()))); + + let out = spark_quarter(&[ColumnarValue::Array(ts), tz])?.into_array(1)?; + let expected: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)])); + assert_eq!(&out, &expected); + Ok(()) + } + + #[test] + fn test_date_parts_with_shanghai() -> Result<()> { + // 17:00 UTC on Dec 31 is 01:00 on Jan 1 in Shanghai → crosses year boundary + let epoch = utc_ms(2021, 12, 31, 17, 0, 0); + let ts = Arc::new(TimestampMillisecondArray::from(vec![Some(epoch)])); + let tz = ColumnarValue::Scalar(ScalarValue::Utf8(Some("Asia/Shanghai".to_string()))); + + let out_year = + spark_year(&[ColumnarValue::Array(ts.clone()), tz.clone()])?.into_array(1)?; + let out_month = + spark_month(&[ColumnarValue::Array(ts.clone()), tz.clone()])?.into_array(1)?; + let out_day = spark_day(&[ColumnarValue::Array(ts.clone()), tz.clone()])?.into_array(1)?; + let out_quarter = + spark_quarter(&[ColumnarValue::Array(ts.clone()), tz.clone()])?.into_array(1)?; + let out_dow = spark_dayofweek(&[ColumnarValue::Array(ts), tz])?.into_array(1)?; + + let expected_year: ArrayRef = Arc::new(Int32Array::from(vec![Some(2022)])); + let expected_month: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)])); + let expected_day: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)])); + let expected_quarter: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)])); + let expected_dow: ArrayRef = Arc::new(Int32Array::from(vec![Some(7)])); // Saturday + + assert_eq!(&out_year, &expected_year); + assert_eq!(&out_month, &expected_month); + assert_eq!(&out_day, &expected_day); + assert_eq!(&out_quarter, &expected_quarter); + assert_eq!(&out_dow, &expected_dow); + Ok(()) + } + + #[test] + fn test_date_parts_null_tz_unchanged() -> Result<()> { + let input = Arc::new(Date32Array::from(vec![Some(0), Some(100), None])); + let null_tz = ColumnarValue::Scalar(ScalarValue::Utf8(None)); + + let out_year = + spark_year(&[ColumnarValue::Array(input.clone()), null_tz.clone()])?.into_array(1)?; + let out_no_tz = spark_year(&[ColumnarValue::Array(input.clone())])?.into_array(1)?; + assert_eq!(&out_year, &out_no_tz); + + let out_month = + spark_month(&[ColumnarValue::Array(input.clone()), null_tz.clone()])?.into_array(1)?; + let out_no_tz = spark_month(&[ColumnarValue::Array(input.clone())])?.into_array(1)?; + assert_eq!(&out_month, &out_no_tz); + + let out_day = + spark_day(&[ColumnarValue::Array(input.clone()), null_tz.clone()])?.into_array(1)?; + let out_no_tz = spark_day(&[ColumnarValue::Array(input)])?.into_array(1)?; + assert_eq!(&out_day, &out_no_tz); + + Ok(()) + } } diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala index df0708815..921b6ebd2 100644 --- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala @@ -144,6 +144,30 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite { } } + test("date-part functions with non-UTC timezone") { + withTable("t1") { + sql("create table t1(c1 timestamp) using parquet") + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + sql("insert into t1 values(timestamp'2021-01-04 04:30:00')") + } + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/New_York") { + checkSparkAnswerAndOperator( + "select year(c1), month(c1), dayofmonth(c1), dayofweek(c1), quarter(c1) from t1") + } + } + } + + test("date-part functions with date input unchanged across timezones") { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Shanghai") { + withTable("t1") { + sql("create table t1(c1 date) using parquet") + sql("insert into t1 values(date'2021-01-04')") + checkSparkAnswerAndOperator( + "select year(c1), month(c1), dayofmonth(c1), dayofweek(c1), quarter(c1) from t1") + } + } + } + test("stddev_samp function with UDAF fallback") { withSQLConf("spark.auron.udafFallback.enable" -> "true") { withTable("t1") { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 7db4374bd..822381c3f 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -936,15 +936,18 @@ object NativeConverters extends Logging { case e: Nvl => buildScalarFunction(pb.ScalarFunction.Nvl, e.children, e.dataType) - case Year(child) => buildExtScalarFunction("Spark_Year", child :: Nil, IntegerType) - case Month(child) => buildExtScalarFunction("Spark_Month", child :: Nil, IntegerType) - case DayOfMonth(child) => buildExtScalarFunction("Spark_Day", child :: Nil, IntegerType) + case Year(child) => + buildTimePartExt("Spark_Year", child, isPruningExpr, fallback) + case Month(child) => + buildTimePartExt("Spark_Month", child, isPruningExpr, fallback) + case DayOfMonth(child) => + buildTimePartExt("Spark_Day", child, isPruningExpr, fallback) case DayOfWeek(child) => - buildExtScalarFunction("Spark_DayOfWeek", child :: Nil, IntegerType) + buildTimePartExt("Spark_DayOfWeek", child, isPruningExpr, fallback) case WeekOfYear(child) => buildTimePartExt("Spark_WeekOfYear", child, isPruningExpr, fallback) - - case Quarter(child) => buildExtScalarFunction("Spark_Quarter", child :: Nil, IntegerType) + case Quarter(child) => + buildTimePartExt("Spark_Quarter", child, isPruningExpr, fallback) case e: Levenshtein => buildScalarFunction(pb.ScalarFunction.Levenshtein, e.children, e.dataType)