Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions datafusion/expr-common/src/casts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,35 @@ fn is_lossy_temporal_cast(from_type: &DataType, to_type: &DataType) -> bool {
|| (is_date_type(to_type) && from_type.is_temporal())
}

/// Returns true when casting a timestamp from `from_type` to `to_type` loses

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this specific to timestamps? It seems like any narrowing cast would have the same problem; for example, `cast(x as int) = 5` can't be unwrapped to `x = 5.0` for floats.

@discord9 discord9 Jun 11, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this specific to timestamps? It seems like any narrowing cast would have the same problem for example cast(x as int) = 5 can't be unwrapped to x = 5.0 for floats

I was just trying to keep this PR small, since the last allow-list one was too large; I'm considering adding a longer block list.

/// timestamp precision.
///
/// This is used by comparison cast unwrapping to avoid rewrites such as
/// `CAST(ts_ns AS timestamp(ms)) = lit_ms` -> `ts_ns = lit_ns`. The original
/// predicate can match any nanosecond value in the same millisecond, while the
/// rewritten predicate only matches the exact millisecond boundary.
pub fn is_timestamp_precision_narrowing_cast(
    from_type: &DataType,
    to_type: &DataType,
) -> bool {
    match (from_type, to_type) {
        // Timestamp -> timestamp: narrowing exactly when the source unit has a
        // strictly larger scale than the target unit. Time zones are ignored.
        (DataType::Timestamp(source_unit, _), DataType::Timestamp(target_unit, _)) => {
            timestamp_unit_scale(target_unit) < timestamp_unit_scale(source_unit)
        }
        // Any pair that is not timestamp -> timestamp is out of scope here.
        _ => false,
    }
}

/// Maps a timestamp [`TimeUnit`] to an integer scale so that two units can be
/// compared by precision.
///
/// `Second` maps to `1`; the remaining units map to the `MILLISECONDS`,
/// `MICROSECONDS` and `NANOSECONDS` constants, widened to `i128`.
// NOTE(review): those constants are defined outside this view; presumably they
// are ticks per second, so a larger scale means a finer unit. Confirm.
fn timestamp_unit_scale(unit: &TimeUnit) -> i128 {
    match *unit {
        TimeUnit::Nanosecond => NANOSECONDS as i128,
        TimeUnit::Microsecond => MICROSECONDS as i128,
        TimeUnit::Millisecond => MILLISECONDS as i128,
        TimeUnit::Second => 1,
    }
}

/// Returns true if unwrap_cast_in_comparison supports this numeric type
fn is_supported_numeric_type(data_type: &DataType) -> bool {
matches!(
Expand Down Expand Up @@ -784,6 +813,23 @@ mod tests {
);
}

/// Checks which type pairs `is_timestamp_precision_narrowing_cast` reports as
/// precision-narrowing.
#[test]
fn test_timestamp_precision_narrowing_cast() {
    // All fixtures are timestamps without a time zone.
    let timestamp_of = |unit: TimeUnit| DataType::Timestamp(unit, None);
    let nanos = timestamp_of(TimeUnit::Nanosecond);
    let micros = timestamp_of(TimeUnit::Microsecond);
    let millis = timestamp_of(TimeUnit::Millisecond);
    let secs = timestamp_of(TimeUnit::Second);

    // Finer source unit, coarser target unit: reported as narrowing.
    assert!(is_timestamp_precision_narrowing_cast(&nanos, &millis));
    assert!(is_timestamp_precision_narrowing_cast(&micros, &secs));

    // Coarser-to-finer and same-unit casts are not narrowing.
    assert!(!is_timestamp_precision_narrowing_cast(&millis, &nanos));
    assert!(!is_timestamp_precision_narrowing_cast(&millis, &millis));

    // A non-timestamp source is never reported.
    let int64 = DataType::Int64;
    assert!(!is_timestamp_precision_narrowing_cast(&int64, &millis));
}

#[test]
fn test_try_cast_to_type_unsupported() {
// int64 to list
Expand Down
56 changes: 51 additions & 5 deletions datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ use datafusion_common::{Result, ScalarValue};
use datafusion_common::{internal_err, tree_node::Transformed};
use datafusion_expr::{BinaryExpr, lit};
use datafusion_expr::{Cast, Expr, Operator, TryCast, simplify::SimplifyContext};
use datafusion_expr_common::casts::{is_supported_type, try_cast_literal_to_type};
use datafusion_expr_common::casts::{
is_supported_type, is_timestamp_precision_narrowing_cast, try_cast_literal_to_type,
};

pub(super) fn unwrap_cast_in_comparison_for_binary(
info: &SimplifyContext,
Expand Down Expand Up @@ -113,10 +115,14 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
match (expr, literal) {
(
Expr::TryCast(TryCast {
expr: left_expr, ..
expr: left_expr,
field,
..
})
| Expr::Cast(Cast {
expr: left_expr, ..
expr: left_expr,
field,
..
}),
Expr::Literal(lit_val, _),
) => {
Expand All @@ -128,6 +134,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
return false;
};

if is_timestamp_precision_narrowing_cast(&expr_type, field.data_type()) {
return false;
}

if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() {
return true;
}
Expand All @@ -146,10 +156,14 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
list: &[Expr],
) -> bool {
let (Expr::TryCast(TryCast {
expr: left_expr, ..
expr: left_expr,
field,
..
})
| Expr::Cast(Cast {
expr: left_expr, ..
expr: left_expr,
field,
..
})) = expr
else {
return false;
Expand All @@ -163,6 +177,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
return false;
}

if is_timestamp_precision_narrowing_cast(&expr_type, field.data_type()) {
return false;
}

for right in list {
let Ok(right_type) = info.get_data_type(right) else {
return false;
Expand Down Expand Up @@ -586,6 +604,25 @@ mod tests {
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

/// `CAST(ts_ns AS timestamp(ms)) = <ms literal>` must come out of the
/// simplifier unchanged.
#[test]
fn test_not_unwrap_cast_timestamp_precision_narrowing() {
    let narrowed_column = cast(col("ts_nano_none"), timestamp_millis_none_type());
    let predicate = narrowed_column.eq(lit_timestamp_millis_none(1));

    let schema = expr_test_schema();
    let optimized = optimize_test(predicate.clone(), &schema);

    assert_eq!(optimized, predicate);
}

/// `CAST(ts_ms AS timestamp(ns)) = 1_000_000ns` is expected to be unwrapped
/// to `ts_ms = 1ms`.
#[test]
fn test_unwrap_cast_timestamp_precision_widening() {
    let widened_column = cast(col("ts_millis_none"), timestamp_nano_none_type());
    let predicate = widened_column.eq(lit_timestamp_nano_none(1_000_000));

    // The cast is gone and the literal is re-expressed in milliseconds.
    let unwrapped = col("ts_millis_none").eq(lit_timestamp_millis_none(1));

    let schema = expr_test_schema();
    assert_eq!(optimize_test(predicate, &schema), unwrapped);
}

fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
let simplifier = ExprSimplifier::new(
SimplifyContext::builder()
Expand All @@ -607,6 +644,7 @@ mod tests {
Field::new("c5", DataType::Float32, false),
Field::new("c6", DataType::UInt32, false),
Field::new("ts_nano_none", timestamp_nano_none_type(), false),
Field::new("ts_millis_none", timestamp_millis_none_type(), false),
Field::new("ts_nano_utf", timestamp_nano_utc_type(), false),
Field::new("str1", DataType::Utf8, false),
Field::new("largestr", DataType::LargeUtf8, false),
Expand Down Expand Up @@ -643,6 +681,10 @@ mod tests {
lit(ScalarValue::TimestampNanosecond(Some(ts), None))
}

/// Builds a literal expression holding millisecond timestamp `ts` with no
/// time zone.
fn lit_timestamp_millis_none(ts: i64) -> Expr {
    let millis = ScalarValue::TimestampMillisecond(Some(ts), None);
    lit(millis)
}

fn lit_timestamp_nano_utc(ts: i64) -> Expr {
let utc = Some("+0:00".into());
lit(ScalarValue::TimestampNanosecond(Some(ts), utc))
Expand All @@ -652,6 +694,10 @@ mod tests {
DataType::Timestamp(TimeUnit::Nanosecond, None)
}

/// The millisecond-precision timestamp type with no time zone.
fn timestamp_millis_none_type() -> DataType {
    let unit = TimeUnit::Millisecond;
    DataType::Timestamp(unit, None)
}

// this is the type that now() returns
fn timestamp_nano_utc_type() -> DataType {
let utc = Some("+0:00".into());
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ fn extension_node_does_not_block_projection_pruning() -> Result<()> {
Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts
Filter: __common_expr_3 > TimestampMillisecond(1000, Some("UTC")) AND __common_expr_3 < TimestampMillisecond(2000, Some("UTC"))
Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_3, t.a, t.ts
TableScan: t projection=[a, ts], partial_filters=[t.ts > TimestampNanosecond(1000000000, None), t.ts < TimestampNanosecond(2000000000, None), CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))]
TableScan: t projection=[a, ts], partial_filters=[CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the inequality still holds for all possible ns values of t.ts, so this seems like a significantly worse plan, as now the column must be cast rather than compared directly to a constant 🤔

Specifically I think this predicate

CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC"))

Evaluates to true for exactly the same values of t.ts as this predicate (which is more efficient to evaluate):

t.ts > TimestampNanosecond(1000000000, None)

I do see that this doesn't hold for equality

CAST(t.ts AS Timestamp(ms, "UTC")) = TimestampMillisecond(1000, Some("UTC"))

is NOT the same as

t.ts = TimestampNanosecond(1000000000, None)

A counter example being 1000000001

@discord9 discord9 Jun 11, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the 1_000_000_001ns example shows that basically only >= and < can safely unwrap the cast (FOR POSITIVE timestamps); I have removed them from the block list now.

| Op | Same-op rewrite safe? | Example value | Why results match / differ |
| --- | --- | --- | --- |
| `=` | ❌ No | 1_000_000_001ns | Before: value truncates to 1000ms, so 1000 = 1000 is true. After: 1_000_000_001 = 1_000_000_000 is false. |
| `!=` | ❌ No | 1_000_000_001ns | Before: both are in the 1000ms bucket, so 1000 != 1000 is false. After: exact ns values differ, so true. |
| `>` | ❌ No | 1_000_000_001ns | Before: value still truncates to 1000ms, so 1000 > 1000 is false. After: 1_000_000_001 > 1_000_000_000 is true. |
| `>=` | ✅ Yes for this positive bucket boundary | N/A | It matches the lower edge of the bucket: values below 1_000_000_000ns cast below 1000ms; values at/above it cast to 1000ms or higher. |
| `<` | ✅ Yes for this positive bucket boundary | N/A | It is the inverse of the lower-bound check: values before the bucket cast below 1000ms; values inside or after the bucket do not. |
| `<=` | ❌ No | 1_000_000_001ns | Before: value truncates to 1000ms, so 1000 <= 1000 is true. After: 1_000_000_001 <= 1_000_000_000 is false. |
| `IS DISTINCT FROM` | ❌ No | 1_000_000_001ns | Same as != for non-null values: before false because both are 1000ms; after true because exact ns values differ. |
| `IS NOT DISTINCT FROM` | ❌ No | 1_000_000_001ns | Same as = for non-null values: before true because both are 1000ms; after false because exact ns values differ. |

"#,
);

Expand Down
70 changes: 66 additions & 4 deletions datafusion/physical-expr/src/simplifier/unwrap_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ use std::sync::Arc;
use arrow::datatypes::{DataType, Schema};
use datafusion_common::{Result, ScalarValue, tree_node::Transformed};
use datafusion_expr::Operator;
use datafusion_expr_common::casts::try_cast_literal_to_type;
use datafusion_expr_common::casts::{
is_timestamp_precision_narrowing_cast, try_cast_literal_to_type,
};

use crate::PhysicalExpr;
use crate::expressions::{BinaryExpr, CastExpr, Literal, TryCastExpr, lit};
Expand All @@ -60,13 +62,14 @@ fn try_unwrap_cast_binary(
schema: &Schema,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
// Case 1: cast(left_expr) op literal
if let (Some((inner_expr, _cast_type)), Some(literal)) = (
if let (Some((inner_expr, cast_type)), Some(literal)) = (
extract_cast_info(binary.left()),
binary.right().downcast_ref::<Literal>(),
) && binary.op().supports_propagation()
&& let Some(unwrapped) = try_unwrap_cast_comparison(
Arc::clone(inner_expr),
literal.value(),
cast_type,
*binary.op(),
schema,
)?
Expand All @@ -75,7 +78,7 @@ fn try_unwrap_cast_binary(
}

// Case 2: literal op cast(right_expr)
if let (Some(literal), Some((inner_expr, _cast_type))) = (
if let (Some(literal), Some((inner_expr, cast_type))) = (
binary.left().downcast_ref::<Literal>(),
extract_cast_info(binary.right()),
) {
Expand All @@ -85,6 +88,7 @@ fn try_unwrap_cast_binary(
&& let Some(unwrapped) = try_unwrap_cast_comparison(
Arc::clone(inner_expr),
literal.value(),
cast_type,
swapped_op,
schema,
)?
Expand Down Expand Up @@ -118,12 +122,17 @@ fn extract_cast_info(
fn try_unwrap_cast_comparison(
inner_expr: Arc<dyn PhysicalExpr>,
literal_value: &ScalarValue,
cast_type: &DataType,
op: Operator,
schema: &Schema,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
// Get the data type of the inner expression
let inner_type = inner_expr.data_type(schema)?;

if is_timestamp_precision_narrowing_cast(&inner_type, cast_type) {
return Ok(None);
}

// Try to cast the literal to the inner expression's type
if let Some(casted_literal) = try_cast_literal_to_type(literal_value, &inner_type) {
let literal_expr = lit(casted_literal);
Expand All @@ -138,7 +147,7 @@ fn try_unwrap_cast_comparison(
mod tests {
use super::*;
use crate::expressions::col;
use arrow::datatypes::Field;
use arrow::datatypes::{Field, TimeUnit};
use datafusion_common::tree_node::TreeNode;

/// Check if an expression is a cast expression
Expand Down Expand Up @@ -548,6 +557,59 @@ mod tests {
assert!(!result.transformed);
}

/// A nanosecond column narrowed to milliseconds and compared for equality
/// must not be rewritten by `unwrap_cast_in_comparison`.
#[test]
fn test_not_unwrap_timestamp_precision_narrowing() {
    let nanos = DataType::Timestamp(TimeUnit::Nanosecond, None);
    let millis = DataType::Timestamp(TimeUnit::Millisecond, None);
    let schema = Schema::new(vec![Field::new("ts", nanos, false)]);

    // CAST(ts AS timestamp(ms)) = 1ms
    let narrowed = Arc::new(CastExpr::new(col("ts", &schema).unwrap(), millis, None));
    let one_milli = lit(ScalarValue::TimestampMillisecond(Some(1), None));
    let predicate = Arc::new(BinaryExpr::new(narrowed, Operator::Eq, one_milli));

    let outcome = unwrap_cast_in_comparison(predicate, &schema).unwrap();

    assert!(!outcome.transformed);
}

/// A millisecond column widened to nanoseconds and compared for equality is
/// expected to be unwrapped, with the literal converted to milliseconds.
#[test]
fn test_unwrap_timestamp_precision_widening() {
    let millis = DataType::Timestamp(TimeUnit::Millisecond, None);
    let nanos = DataType::Timestamp(TimeUnit::Nanosecond, None);
    let schema = Schema::new(vec![Field::new("ts", millis, false)]);

    // CAST(ts AS timestamp(ns)) = 1_000_000ns
    let widened = Arc::new(CastExpr::new(col("ts", &schema).unwrap(), nanos, None));
    let one_milli_in_nanos =
        lit(ScalarValue::TimestampNanosecond(Some(1_000_000), None));
    let predicate =
        Arc::new(BinaryExpr::new(widened, Operator::Eq, one_milli_in_nanos));

    let outcome = unwrap_cast_in_comparison(predicate, &schema).unwrap();
    assert!(outcome.transformed);

    // The left side no longer carries the cast.
    let rewritten = outcome.data.downcast_ref::<BinaryExpr>().unwrap();
    assert!(!is_cast_expr(rewritten.left()));

    // The right side is the literal in the column's own unit.
    let rhs = rewritten.right().downcast_ref::<Literal>().unwrap();
    let expected_rhs = ScalarValue::TimestampMillisecond(Some(1), None);
    assert_eq!(rhs.value(), &expected_rhs);
}

#[test]
fn test_complex_nested_expression() {
let schema = test_schema();
Expand Down
Loading