Skip to content

Commit b0855a5

Browse files
committed
[AURON #2130] Implement native function of weekofyear
Signed-off-by: weimingdiit <weimingdiit@gmail.com>
1 parent 54bd43c commit b0855a5

2 files changed

Lines changed: 44 additions & 6 deletions

File tree

native-engine/datafusion-ext-functions/src/spark_dates.rs

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use arrow::{
2323
use chrono::{Duration, TimeZone, Utc, prelude::*};
2424
use chrono_tz::Tz;
2525
use datafusion::{
26-
common::{Result, ScalarValue},
26+
common::{DataFusionError, Result, ScalarValue},
2727
physical_plan::ColumnarValue,
2828
};
2929
use datafusion_ext_commons::arrow::cast::cast;
@@ -80,9 +80,11 @@ pub fn spark_dayofweek(args: &[ColumnarValue]) -> Result<ColumnarValue> {
8080
///
8181
/// For `Timestamp` inputs, this function interprets epoch milliseconds in the
8282
/// provided timezone (if any) before deriving the calendar date and ISO week.
83-
/// If no timezone is provided, `UTC` is used by default. For `Date` and
84-
/// compatible string inputs, the behavior is unchanged: the value is cast to
85-
/// `Date32` and the ISO week is computed from the resulting date.
83+
/// If no timezone is provided, `UTC` is used by default. If an invalid
84+
/// timezone string is provided, the function returns an execution error.
85+
/// For `Date` and compatible string inputs, the behavior is unchanged: the
86+
/// value is cast to `Date32` and the ISO week is computed from the resulting
87+
/// date.
8688
pub fn spark_weekofyear(args: &[ColumnarValue]) -> Result<ColumnarValue> {
8789
// First argument as an Arrow array (date/timestamp/string, etc.)
8890
let array = args[0].clone().into_array(1)?;
@@ -94,7 +96,9 @@ pub fn spark_weekofyear(args: &[ColumnarValue]) -> Result<ColumnarValue> {
9496
match &args[1] {
9597
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
9698
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
97-
s.parse::<Tz>().unwrap_or(default_tz)
99+
s.parse::<Tz>().map_err(|_| {
100+
DataFusionError::Execution(format!("spark_weekofyear invalid timezone: {s}"))
101+
})?
98102
}
99103
_ => default_tz,
100104
}
@@ -401,6 +405,40 @@ mod tests {
401405
Ok(())
402406
}
403407

408+
#[test]
409+
fn test_spark_weekofyear_with_timezone() -> Result<()> {
410+
// In America/New_York:
411+
// 2021-01-04 04:30:00 UTC -> 2021-01-03 23:30:00 local -> ISO week 53
412+
// 2021-01-04 05:30:00 UTC -> 2021-01-04 00:30:00 local -> ISO week 1
413+
let input = Arc::new(TimestampMillisecondArray::from(vec![
414+
Some(utc_ms(2021, 1, 4, 4, 30, 0)),
415+
Some(utc_ms(2021, 1, 4, 5, 30, 0)),
416+
None,
417+
]));
418+
let args = vec![
419+
ColumnarValue::Array(input),
420+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("America/New_York".to_string()))),
421+
];
422+
let expected_ret: ArrayRef = Arc::new(Int32Array::from(vec![Some(53), Some(1), None]));
423+
assert_eq!(&spark_weekofyear(&args)?.into_array(3)?, &expected_ret);
424+
Ok(())
425+
}
426+
427+
#[test]
428+
fn test_spark_weekofyear_invalid_timezone() {
429+
let input = Arc::new(TimestampMillisecondArray::from(vec![Some(utc_ms(
430+
2021, 1, 4, 5, 30, 0,
431+
))]));
432+
let args = vec![
433+
ColumnarValue::Array(input),
434+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("Mars/Olympus".to_string()))),
435+
];
436+
437+
let err =
438+
spark_weekofyear(&args).expect_err("spark_weekofyear should fail for invalid timezone");
439+
assert!(err.to_string().contains("invalid timezone"));
440+
}
441+
404442
#[test]
405443
fn test_spark_quarter_basic() -> Result<()> {
406444
// Date32 days relative to 1970-01-01:

spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ object NativeConverters extends Logging {
941941
case DayOfWeek(child) =>
942942
buildExtScalarFunction("Spark_DayOfWeek", child :: Nil, IntegerType)
943943
case WeekOfYear(child) =>
944-
buildExtScalarFunction("Spark_WeekOfYear", child :: Nil, IntegerType)
944+
buildTimePartExt("Spark_WeekOfYear", child, isPruningExpr, fallback)
945945

946946
case Quarter(child) => buildExtScalarFunction("Spark_Quarter", child :: Nil, IntegerType)
947947

0 commit comments

Comments
 (0)