Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions native-engine/datafusion-ext-functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ pub fn create_auron_ext_function(
"Spark_Month" => Arc::new(spark_dates::spark_month),
"Spark_Day" => Arc::new(spark_dates::spark_day),
"Spark_DayOfWeek" => Arc::new(spark_dates::spark_dayofweek),
"Spark_WeekOfYear" => Arc::new(spark_dates::spark_weekofyear),
"Spark_Quarter" => Arc::new(spark_dates::spark_quarter),
"Spark_Hour" => Arc::new(spark_dates::spark_hour),
"Spark_Minute" => Arc::new(spark_dates::spark_minute),
Expand Down
136 changes: 134 additions & 2 deletions native-engine/datafusion-ext-functions/src/spark_dates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ use arrow::{
compute::{DatePart, date_part},
datatypes::{DataType, TimeUnit},
};
use chrono::{TimeZone, Utc, prelude::*};
use chrono::{Duration, TimeZone, Utc, prelude::*};
use chrono_tz::Tz;
use datafusion::{
common::{Result, ScalarValue},
common::{DataFusionError, Result, ScalarValue},
physical_plan::ColumnarValue,
};
use datafusion_ext_commons::arrow::cast::cast;
Expand Down Expand Up @@ -72,6 +72,81 @@ pub fn spark_dayofweek(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(ColumnarValue::Array(Arc::new(dayofweek)))
}

/// `spark_weekofyear(date/timestamp/compatible-string[, timezone])`
///
/// Matches Spark's `weekofyear()` semantics:
/// ISO week numbering, with Monday as the first day of the week,
/// and week 1 defined as the first week with more than 3 days.
///
/// For `Timestamp` inputs, this function interprets epoch milliseconds in the
/// provided timezone (if any) before deriving the calendar date and ISO week.
/// If no timezone is provided, `UTC` is used by default. If an invalid
/// timezone string is provided, the function returns an execution error.
/// For `Date` and compatible string inputs, the behavior is unchanged: the
/// value is cast to `Date32` and the ISO week is computed from the resulting
/// date.
pub fn spark_weekofyear(args: &[ColumnarValue]) -> Result<ColumnarValue> {
// First argument as an Arrow array (date/timestamp/string, etc.)
let array = args[0].clone().into_array(1)?;

// Determine timezone (for timestamp inputs). Default to UTC to match
// existing behavior when no timezone is provided.
let default_tz = chrono_tz::UTC;
let tz: Tz = if args.len() > 1 {
match &args[1] {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
s.parse::<Tz>().map_err(|_| {
DataFusionError::Execution(format!("spark_weekofyear invalid timezone: {s}"))
})?
}
_ => default_tz,
}
} else {
default_tz
};

match array.data_type() {
// Timestamp inputs: localize epoch milliseconds before computing ISO week
DataType::Timestamp(TimeUnit::Millisecond, _) => {
let ts_arr = array
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.expect("internal cast to TimestampMillisecondArray must succeed");

let weekofyear = Int32Array::from_iter(ts_arr.iter().map(|opt_ms| {
opt_ms.and_then(|ms| {
tz.timestamp_millis_opt(ms)
.single()
.map(|dt| dt.date_naive().iso_week().week() as i32)
})
}));

Ok(ColumnarValue::Array(Arc::new(weekofyear)))
}
// Non-timestamp inputs: preserve existing Date32-based behavior
_ => {
let input = cast(&array, &DataType::Date32)?;
let input = input
.as_any()
.downcast_ref::<Date32Array>()
.expect("internal cast to Date32 must succeed");

let epoch =
NaiveDate::from_ymd_opt(1970, 1, 1).expect("1970-01-01 must be a valid date");
let weekofyear = Int32Array::from_iter(input.iter().map(|opt_days| {
opt_days.and_then(|days| {
epoch
.checked_add_signed(Duration::days(days as i64))
.map(|date| date.iso_week().week() as i32)
})
}));

Ok(ColumnarValue::Array(Arc::new(weekofyear)))
}
}
}

/// `spark_quarter(date/timestamp/compatible-string)`
///
/// Simulates Spark's `quarter()` function.
Expand Down Expand Up @@ -307,6 +382,63 @@ mod tests {
Ok(())
}

#[test]
fn test_spark_weekofyear() -> Result<()> {
    // Date32 values are days since 1970-01-01:
    //   0     -> 1970-01-01 (ISO week 1)
    //   4017  -> 1980-12-31 (ISO week 1, belongs to week 1 of 1981)
    //   16801 -> 2016-01-01 (ISO week 53, belongs to week 53 of 2015)
    //   17167 -> 2017-01-01 (ISO week 52, belongs to week 52 of 2016)
    //   14455 -> 2009-07-30 (ISO week 31)
    let days = vec![
        Some(0),
        Some(4017),
        Some(16801),
        Some(17167),
        Some(14455),
        None,
    ];
    let args = vec![ColumnarValue::Array(Arc::new(Date32Array::from(days)))];

    let expected_ret: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(1),
        Some(1),
        Some(53),
        Some(52),
        Some(31),
        None,
    ]));
    assert_eq!(&spark_weekofyear(&args)?.into_array(1)?, &expected_ret);
    Ok(())
}

#[test]
fn test_spark_weekofyear_with_timezone() -> Result<()> {
    // The same UTC instant can land on different local calendar dates in
    // America/New_York (UTC-5 at these instants):
    //   2021-01-04 04:30:00 UTC -> 2021-01-03 23:30:00 local -> ISO week 53
    //   2021-01-04 05:30:00 UTC -> 2021-01-04 00:30:00 local -> ISO week 1
    let timestamps = TimestampMillisecondArray::from(vec![
        Some(utc_ms(2021, 1, 4, 4, 30, 0)),
        Some(utc_ms(2021, 1, 4, 5, 30, 0)),
        None,
    ]);
    let tz_arg = ScalarValue::Utf8(Some("America/New_York".to_string()));
    let args = vec![
        ColumnarValue::Array(Arc::new(timestamps)),
        ColumnarValue::Scalar(tz_arg),
    ];

    let expected_ret: ArrayRef = Arc::new(Int32Array::from(vec![Some(53), Some(1), None]));
    assert_eq!(&spark_weekofyear(&args)?.into_array(3)?, &expected_ret);
    Ok(())
}

#[test]
fn test_spark_weekofyear_invalid_timezone() {
    // A timezone name that chrono_tz cannot parse must surface as an
    // execution error rather than silently falling back to UTC.
    let timestamps =
        TimestampMillisecondArray::from(vec![Some(utc_ms(2021, 1, 4, 5, 30, 0))]);
    let tz_arg = ScalarValue::Utf8(Some("Mars/Olympus".to_string()));
    let args = vec![
        ColumnarValue::Array(Arc::new(timestamps)),
        ColumnarValue::Scalar(tz_arg),
    ];

    let err =
        spark_weekofyear(&args).expect_err("spark_weekofyear should fail for invalid timezone");
    assert!(err.to_string().contains("invalid timezone"));
}

#[test]
fn test_spark_quarter_basic() -> Result<()> {
// Date32 days relative to 1970-01-01:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,35 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite {
}
}

// End-to-end check of weekofyear() against Spark over date, timestamp and
// date-compatible string columns. A non-UTC session timezone is set so the
// timestamp column exercises the timezone-aware native path (2016-01-03
// 23:30 local straddles an ISO week boundary once converted).
test("weekofyear function") {
withSQLConf("spark.sql.session.timeZone" -> "America/Los_Angeles") {
withTable("t1") {
sql(
"create table t1(c1 date, c2 date, c3 date, c4 date, c5 timestamp, c6 string) using parquet")
sql("""insert into t1 values (
| date'2009-07-30',
| date'1980-12-31',
| date'2016-01-01',
| date'2017-01-01',
| timestamp'2016-01-03 23:30:00',
| '2016-01-01'
|)""".stripMargin)

val query =
"""select
| weekofyear(c1),
| weekofyear(c2),
| weekofyear(c3),
| weekofyear(c4),
| weekofyear(c5),
| weekofyear(c6)
|from t1
|""".stripMargin
// Asserts both result equality with Spark and that the native operator
// (not a fallback) handled the expression.
checkSparkAnswerAndOperator(query)
}
}
}

test("round function with varying scales for intPi") {
withTable("t2") {
sql("CREATE TABLE t2 (c1 INT) USING parquet")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,9 @@ object NativeConverters extends Logging {
case DayOfMonth(child) => buildExtScalarFunction("Spark_Day", child :: Nil, IntegerType)
case DayOfWeek(child) =>
buildExtScalarFunction("Spark_DayOfWeek", child :: Nil, IntegerType)
case WeekOfYear(child) =>
buildTimePartExt("Spark_WeekOfYear", child, isPruningExpr, fallback)

case Quarter(child) => buildExtScalarFunction("Spark_Quarter", child :: Nil, IntegerType)

case e: Levenshtein =>
Expand Down
Loading