From 19da46378cb95aacf3ca93fc19a01be853110a62 Mon Sep 17 00:00:00 2001 From: klion26 Date: Wed, 10 Jun 2026 14:19:22 +0800 Subject: [PATCH 1/2] [Variant] Align cast logic related to utf8 with arrow-cast kernel --- arrow-cast/src/display.rs | 3 +- arrow-cast/src/parse.rs | 9 +- .../src/type_conversion.rs | 332 ++++++++++++++- .../src/variant_to_arrow.rs | 25 +- parquet-variant/src/variant.rs | 391 ++++++++++++------ 5 files changed, 633 insertions(+), 127 deletions(-) diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs index 0460c0c96b55..7bbea9508ffb 100644 --- a/arrow-cast/src/display.rs +++ b/arrow-cast/src/display.rs @@ -733,7 +733,8 @@ macro_rules! decimal_display { decimal_display!(Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type); -fn write_timestamp( +/// Writes a timestamp value to the output using the given representation. +pub fn write_timestamp( f: &mut dyn Write, naive: NaiveDateTime, timezone: Option, diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs index e7c6f90e75a1..0d5381d3aa70 100644 --- a/arrow-cast/src/parse.rs +++ b/arrow-cast/src/parse.rs @@ -611,7 +611,14 @@ fn parse_extended_ymd(string: &str) -> Option<(i32, u32, u32)> { Some((year, month, day)) } -fn parse_date(string: &str) -> Option { +/// Parse a given date string into a `NaiveDate`. +/// +/// Supports: +/// - ISO date strings: `"2026-06-10"`, `"2026-6-9"`, `"2026-06-9"`, `"2026-6-09"` +/// - ISO extended (signed) year date strings: `"+2026-06-10"`, `"-2026-06-10"` +/// - No hyphen date strings: `"20260610"` +/// - Datetime strings:`"2026-06-10T14:23:45"` +pub fn parse_date(string: &str) -> Option { // If the date has an extended (signed) year such as "+10999-12-31" or "-0012-05-06" // // According to [ISO 8601], years have: diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs index 2255d4316b25..596b7eed6055 100644 --- a/parquet-variant-compute/src/type_conversion.rs +++ b/parquet-variant-compute/src/type_conversion.rs @@ -23,11 +23,13 @@ use arrow::compute::{ }; use arrow::datatypes::{ self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type, - DecimalType, + DecimalType, format_decimal_str, }; use arrow::error::{ArrowError, Result}; +use arrow::util::display::{lexical_to_string, write_timestamp}; use chrono::Timelike; use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16}; +use std::fmt::Write; /// Extension trait for Arrow primitive types that can extract their native value from a Variant pub(crate) trait PrimitiveFromVariant: ArrowPrimitiveType { @@ -287,6 +289,100 @@ where } } +pub(crate) fn variant_to_string(variant: &Variant<'_, '_>) -> Option { + match variant { + Variant::String(s) => Some(s.to_string()), + Variant::ShortString(s) => Some(s.to_string()), + Variant::BooleanTrue => Some("true".into()), + Variant::BooleanFalse => Some("false".into()), + Variant::Int8(i) => Some(lexical_to_string(*i)), + Variant::Int16(i) => Some(lexical_to_string(*i)), + Variant::Int32(i) => Some(lexical_to_string(*i)), + Variant::Int64(i) => Some(lexical_to_string(*i)), + Variant::Float(f) => Some(lexical_to_string(*f)), + Variant::Double(f) => Some(lexical_to_string(*f)), + Variant::Decimal4(d) => { + let value_str = d.integer().to_string(); + Some(format_decimal_str( + &value_str, + value_str.len(), + d.scale() as _, + )) + } + Variant::Decimal8(d) => { + let value_str = d.integer().to_string(); + Some(format_decimal_str( + &value_str, + value_str.len(), + d.scale() as _, + )) + } + Variant::Decimal16(d) => { + let value_str = d.integer().to_string(); + Some(format_decimal_str( + &value_str, + value_str.len(), + d.scale() as _, + )) + } + Variant::Date(d) => { + let mut ret_string = String::new(); + let _ = write!(ret_string, "{d:?}"); + Some(ret_string) + } + Variant::Time(t) => { + let mut ret_string = String::new(); + let _ = write!(ret_string, "{t:?}"); + Some(ret_string) + } + Variant::TimestampMicros(t) => { + let mut out = String::new(); + let _ = write_timestamp(&mut out, t.naive_utc(), "+00:00".parse().ok(), None); + Some(out) + } + Variant::TimestampNtzMicros(t) => { + let mut out = String::new(); + let _ = write_timestamp(&mut out, *t, None, None); + Some(out) + } + Variant::TimestampNanos(t) => { + let mut out = String::new(); + let _ = write_timestamp(&mut out, t.naive_utc(), "+00:00".parse().ok(), None); + Some(out) + } + Variant::TimestampNtzNanos(t) => { + let mut out = String::new(); + let _ = write_timestamp(&mut out, *t, None, None); + Some(out) + } + Variant::Uuid(u) => Some(u.to_string()), + Variant::Binary(v) => std::str::from_utf8(v).ok().map(|s| s.to_string()), + Variant::List(l) => Some(cast_list_to_string(l.iter())), + _ => None, + } +} + +fn cast_list_to_string<'m, 'v>(mut iter: impl Iterator>) -> String { + let mut ret_str = String::new(); + let _ = ret_str.write_char('['); + + if let Some(item) = iter.next() { + let _ = write!(ret_str, "{}", variant_to_string(&item).unwrap_or_default()); + } + + for item in iter { + let _ = write!( + ret_str, + ", {}", + variant_to_string(&item).unwrap_or_default() + ); + } + + let _ = ret_str.write_char(']'); + + ret_str +} + /// Convert the value at a specific index in the given array into a `Variant`. macro_rules! non_generic_conversion_single_value { ($array:expr, $cast_fn:expr, $index:expr) => {{ @@ -345,3 +441,237 @@ macro_rules! primitive_conversion_single_value { }}; } pub(crate) use primitive_conversion_single_value; + +#[cfg(test)] +mod tests { + use crate::type_conversion::variant_to_string; + use arrow::array::{ + Array, BooleanArray, Date32Array, Int32Builder, ListBuilder, StringArray, + Time64MicrosecondArray, TimestampMicrosecondArray, TimestampNanosecondArray, + }; + use arrow::compute::cast; + use arrow_schema::DataType; + use chrono::{DateTime, NaiveDate, NaiveTime}; + use parquet_variant::{Variant, VariantBuilder, VariantBuilderExt}; + use std::iter::zip; + + #[test] + fn test_compatible_cast_logic_with_cast_kernel() { + // boolean -> string + let boolean_array = BooleanArray::from(vec![Some(true), Some(false)]); + let cast_array = cast(&boolean_array, &DataType::Utf8).unwrap(); + let boolean_utf8_array = cast_array.as_any().downcast_ref::().unwrap(); + let expected_array = vec![ + variant_to_string(&Variant::BooleanTrue), + variant_to_string(&Variant::BooleanFalse), + ]; + for (a, b) in zip(boolean_utf8_array, expected_array) { + assert_eq!(a.unwrap(), b.unwrap()); + } + + // date -> string + let epoch_days = [-10, 0, 18628]; + let date_array = epoch_days + .iter() + .map(|d| Variant::Date(NaiveDate::from_epoch_days(*d).unwrap())) + .collect::>(); + let variant_as_string_array = date_array + .iter() + .map(|v| variant_to_string(v)) + .collect::>>(); + + let date32_array = Date32Array::from_iter_values(epoch_days); + let date32_cast_array = cast(&date32_array, &DataType::Utf8).unwrap(); + let date32_utf8_array = date32_cast_array + .as_any() + .downcast_ref::() + .unwrap(); + for (a, b) in zip(variant_as_string_array, date32_utf8_array) { + assert_eq!(a.unwrap(), b.unwrap()); + } + + // time -> string + let time_tuples = [(123, 0), (123, 456789000), (12345, 456789000)]; + let time_array = time_tuples + .iter() + .map(|tuple| { + Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(tuple.0, tuple.1).unwrap(), + ) + }) + .collect::>(); + let time_variant_as_string_array = time_array + .iter() + .map(|v| variant_to_string(v)) + .collect::>>(); + + let time_micro_array = Time64MicrosecondArray::from_iter( + time_tuples + .iter() + .map(|item| Some(item.0 as i64 * 1_000_000 + item.1 as i64 / 1000)), + ); + + let time_micro_cast_array = cast(&time_micro_array, &DataType::Utf8).unwrap(); + let time_micro_utf8_array = time_micro_cast_array + .as_any() + .downcast_ref::() + .unwrap(); + + for (a, b) in zip(time_variant_as_string_array, time_micro_utf8_array) { + assert_eq!(a.unwrap(), b.unwrap()); + } + + // timestamp(micro) -> string + let micros = [-123456, 123456, 45678]; + let timestamp_micro_array = micros + .iter() + .map(|m| Variant::TimestampMicros(DateTime::from_timestamp_micros(*m).unwrap())) + .collect::>(); + let timestamp_micro_as_string_array = timestamp_micro_array + .iter() + .map(|v| variant_to_string(v)) + .collect::>>(); + + let timestamp_micro_arrow_array = + TimestampMicrosecondArray::from_iter_values(micros).with_timezone("+00:00"); + let timestamp_micro_arrow_cast_array = + cast(×tamp_micro_arrow_array, &DataType::Utf8).unwrap(); + let timestamp_micro_utf8_array = timestamp_micro_arrow_cast_array + .as_any() + .downcast_ref::() + .unwrap(); + for (a, b) in zip(timestamp_micro_as_string_array, timestamp_micro_utf8_array) { + assert_eq!(a.unwrap(), b.unwrap()); + } + + // timestamp(micro) ntz -> string + let micros_ntz = [-123456, 123456, 45678]; + let timestamp_micro_ntz_variant_array = micros_ntz + .iter() + .map(|m| { + Variant::TimestampNtzMicros( + DateTime::from_timestamp_micros(*m).unwrap().naive_utc(), + ) + }) + .collect::>(); + let timestamp_micro_ntz_variant_as_string_array = timestamp_micro_ntz_variant_array + .iter() + .map(|v| variant_to_string(v)) + .collect::>>(); + + let timestamp_micro_ntz_arrow_array = + TimestampMicrosecondArray::from_iter_values(micros_ntz); + let timestamp_micro_ntz_arrow_cast_array = + cast(×tamp_micro_ntz_arrow_array, &DataType::Utf8).unwrap(); + let timestamp_micro_ntz_utf8_array = timestamp_micro_ntz_arrow_cast_array + .as_any() + .downcast_ref::() + .unwrap(); + + for (a, b) in zip( + timestamp_micro_ntz_variant_as_string_array, + timestamp_micro_ntz_utf8_array, + ) { + assert_eq!(a.unwrap(), b.unwrap()); + } + + // timestamp(nano) -> string + let nanos = [-2_208_936_075_000_000_000, 0, 1_662_921_288_000_000_000]; + let timestamp_nano_variant_array = nanos + .iter() + .map(|n| Variant::TimestampNanos(DateTime::from_timestamp_nanos(*n))) + .collect::>(); + let timestamp_nano_as_string_array = timestamp_nano_variant_array + .iter() + .map(|v| variant_to_string(v)) + .collect::>>(); + + let timestamp_nano_arrow_array = + TimestampNanosecondArray::from_iter_values(nanos).with_timezone("+00:00"); + let timestamp_nano_arrow_cast_array = + cast(×tamp_nano_arrow_array, &DataType::Utf8).unwrap(); + let timestamp_nano_cast_utf8_array = timestamp_nano_arrow_cast_array + .as_any() + .downcast_ref::() + .unwrap(); + for (a, b) in zip( + timestamp_nano_cast_utf8_array, + timestamp_nano_as_string_array, + ) { + assert_eq!(a.unwrap(), b.unwrap()); + } + + // timestamp(nano) ntz -> string + let nanos_ntz = [-2_208_936_075_000_000_000i64, 0, 1_662_921_288_000_000_000]; + let timestamp_nano_ntz_variant_array = nanos_ntz + .iter() + .map(|n| Variant::TimestampNtzNanos(DateTime::from_timestamp_nanos(*n).naive_utc())) + .collect::>(); + + let timestamp_nano_ntz_variant_as_string_array = timestamp_nano_ntz_variant_array + .iter() + .map(|v| variant_to_string(v)) + .collect::>>(); + + let timestamp_nano_ntz_arrow_array = TimestampNanosecondArray::from_iter_values(nanos_ntz); + + let timestamp_nano_ntz_arrow_cast_array = + cast(×tamp_nano_ntz_arrow_array, &DataType::Utf8).unwrap(); + let timestamp_nano_ntz_utf8_array = timestamp_nano_ntz_arrow_cast_array + .as_any() + .downcast_ref::() + .unwrap(); + + for (a, b) in zip( + timestamp_nano_ntz_variant_as_string_array, + timestamp_nano_ntz_utf8_array, + ) { + assert_eq!(a.unwrap(), b.unwrap()); + } + + // list -> string + let mut variant_builder = VariantBuilder::new(); + let mut list_builder = variant_builder.new_list(); + list_builder.append_value(123); + list_builder.append_value(234); + list_builder.append_null(); + list_builder.append_value(345); + list_builder.finish(); + let (metadata, value) = variant_builder.finish(); + let variant_list = Variant::new(&metadata, &value); + let variant_list_as_string = variant_to_string(&variant_list); + + let inner_builder = Int32Builder::new(); + let mut builder = ListBuilder::new(inner_builder); + builder.values().append_value(123); + builder.values().append_value(234); + builder.values().append_null(); + builder.values().append_value(345); + builder.append(true); + let list_arrow_array = builder.finish(); + let cast_array = cast(&list_arrow_array, &DataType::Utf8).unwrap(); + let arrow_list_cast_utf8_array = cast_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(arrow_list_cast_utf8_array.len(), 1); + assert_eq!( + variant_list_as_string.unwrap(), + arrow_list_cast_utf8_array.value(0) + ); + } + + #[test] + fn test_variant_to_string_list_mixed_types() { + // Test mixed types list + let mut variant_builder = VariantBuilder::new(); + let mut list_builder = variant_builder.new_list(); + list_builder.append_value(42i32); + list_builder.append_value("text"); + list_builder.append_value(true); + list_builder.finish(); + let (metadata, value) = variant_builder.finish(); + let variant_list = Variant::new(&metadata, &value); + + let result = variant_to_string(&variant_list).unwrap(); + assert_eq!(result, "[42, text, true]"); + } +} diff --git a/parquet-variant-compute/src/variant_to_arrow.rs b/parquet-variant-compute/src/variant_to_arrow.rs index 9841da555da0..f59a6a06a1b6 100644 --- a/parquet-variant-compute/src/variant_to_arrow.rs +++ b/parquet-variant-compute/src/variant_to_arrow.rs @@ -20,7 +20,7 @@ use crate::shred_variant::{ make_variant_to_shredded_variant_arrow_row_builder, }; use crate::type_conversion::{ - PrimitiveFromVariant, TimestampFromVariant, variant_cast_with_options, + PrimitiveFromVariant, TimestampFromVariant, variant_cast_with_options, variant_to_string, variant_to_unscaled_decimal, }; use crate::variant_array::ShreddedVariantFieldArray; @@ -762,6 +762,20 @@ macro_rules! define_variant_to_primitive_builder { |$array_param:ident $(, $field:ident: $field_type:ty)?| -> $builder_name:ident $(< $array_type:ty >)? { $init_expr: expr }, |$value: ident| $value_transform:expr, type_name: $type_name:expr) => { + define_variant_to_primitive_builder!( + struct $name<$lifetime $(, $generic: $bound )?> + |$array_param $(, $field: $field_type)?| -> $builder_name $(< $array_type >)? { $init_expr }, + |$value| $value_transform, + type_name: $type_name, + append_value: |builder, v| builder.append_value(v) + ); + }; + + (struct $name:ident<$lifetime:lifetime $(, $generic:ident: $bound:path )?> + |$array_param:ident $(, $field:ident: $field_type:ty)?| -> $builder_name:ident $(< $array_type:ty >)? { $init_expr: expr }, + |$value: ident| $value_transform:expr, + type_name: $type_name:expr, + append_value: |$builder:ident, $append_value:ident| $append_expr:expr) => { pub(crate) struct $name<$lifetime $(, $generic : $bound )?> { builder: $builder_name $(<$array_type>)?, @@ -793,7 +807,9 @@ macro_rules! define_variant_to_primitive_builder { |$value| $value_transform, ) { Ok(Some(v)) => { - self.builder.append_value(v); + let $builder = &mut self.builder; + let $append_value = v; + $append_expr; Ok(true) } Ok(None) => { @@ -824,8 +840,9 @@ macro_rules! define_variant_to_primitive_builder { define_variant_to_primitive_builder!( struct VariantToStringArrowBuilder<'a, B: StringLikeArrayBuilder> |capacity| -> B { B::with_capacity(capacity) }, - |value| value.as_string(), - type_name: B::type_name() + |value| variant_to_string(value), + type_name: B::type_name(), + append_value: |builder, v| builder.append_value(&v) ); define_variant_to_primitive_builder!( diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index c9f175c3a610..3768191ea7dd 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -29,14 +29,19 @@ use crate::decoder::{ }; use crate::path::{VariantPath, VariantPathElement}; use crate::utils::{first_byte_from_slice, slice_from_slice}; -use arrow::array::ArrowNativeTypeOp; +use arrow::array::{ArrowNativeTypeOp, ArrowPrimitiveType}; use arrow::compute::{ DecimalCast, cast_num_to_bool, cast_single_string_to_boolean_default, num_cast, parse_string_to_decimal_native, single_bool_to_numeric, single_decimal_to_float_lossy, single_float_to_decimal, }; -use arrow::datatypes::{Decimal32Type, Decimal64Type, Decimal128Type, DecimalType}; +use arrow::datatypes::{ + Decimal32Type, Decimal64Type, Decimal128Type, DecimalType, Float16Type, Float32Type, + Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, Time64MicrosecondType, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, +}; +use arrow::compute::kernels::cast_utils::{Parser, parse_date, string_to_datetime}; use arrow_schema::ArrowError; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; use num_traits::NumCast; @@ -582,8 +587,12 @@ impl<'m, 'v> Variant<'m, 'v> { /// Converts this variant to a `NaiveDate` if possible. /// - /// Returns `Some(NaiveDate)` for date variants, - /// `None` for non-date variants. + /// Returns `Some(NaiveDate)` for date variants and string variants + /// that can be parsed as dates. Supports ISO date strings (`"2025-04-12"`), + /// compact date strings (`"20250412"`), flexible formats (`"2025-4-2"`), + /// and datetime strings (`"2025-04-12T10:30:00Z"`, date part extracted). + /// + /// Returns `None` for non-date, non-string variants or unparseable strings. /// /// # Examples /// @@ -596,15 +605,28 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v1 = Variant::from(date); /// assert_eq!(v1.as_naive_date(), Some(date)); /// - /// // but not from other variants - /// let v2 = Variant::from("hello!"); - /// assert_eq!(v2.as_naive_date(), None); + /// // or from an ISO date string + /// let v2 = Variant::from("2025-04-12"); + /// assert_eq!(v2.as_naive_date(), Some(date)); + /// + /// // or from a compact date string + /// let v3 = Variant::from("20250412"); + /// assert_eq!(v3.as_naive_date(), Some(date)); + /// + /// // or from a datetime string (date part is extracted) + /// let v4 = Variant::from("2025-04-12T10:30:00Z"); + /// assert_eq!(v4.as_naive_date(), Some(date)); + /// + /// // but not from unparseable strings + /// let v5 = Variant::from("hello!"); + /// assert_eq!(v5.as_naive_date(), None); /// ``` pub fn as_naive_date(&self) -> Option { - if let Variant::Date(d) = self { - Some(*d) - } else { - None + match *self { + Variant::Date(d) => Some(d), + Variant::ShortString(s) => parse_date(s.as_ref()), + Variant::String(s) => parse_date(s), + _ => None, } } @@ -628,18 +650,29 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v1 = Variant::from(datetime); /// assert_eq!(v1.as_timestamp_micros(), Some(datetime)); /// + /// // or from string variant + /// let v2 = Variant::from("2026-06-10T12:34:56.780Z"); + /// let datetime = NaiveDate::from_ymd_opt(2026, 6, 10) + /// .unwrap() + /// .and_hms_milli_opt(12, 34, 56, 780) + /// .unwrap() + /// .and_utc(); + /// assert_eq!(v2.as_timestamp_micros(), Some(datetime)); + /// /// // but not for other variants. /// let datetime_nanos = NaiveDate::from_ymd_opt(2025, 8, 14) /// .unwrap() /// .and_hms_nano_opt(12, 33, 54, 123456789) /// .unwrap() /// .and_utc(); - /// let v2 = Variant::from(datetime_nanos); - /// assert_eq!(v2.as_timestamp_micros(), None); + /// let v3 = Variant::from(datetime_nanos); + /// assert_eq!(v3.as_timestamp_micros(), None); /// ``` pub fn as_timestamp_micros(&self) -> Option> { match *self { Variant::TimestampMicros(d) => Some(d), + Variant::ShortString(s) => string_to_datetime(&Utc, s.as_ref()).ok(), + Variant::String(s) => string_to_datetime(&Utc, s).ok(), _ => None, } } @@ -663,17 +696,29 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v1 = Variant::from(datetime); /// assert_eq!(v1.as_timestamp_ntz_micros(), Some(datetime)); /// + /// // or from string variant + /// let v2 = Variant::from("2026-06-10T12:34:56.780Z"); + /// let datetime = NaiveDate::from_ymd_opt(2026, 6, 10) + /// .unwrap() + /// .and_hms_milli_opt(12, 34, 56, 780) + /// .unwrap(); + /// assert_eq!(v2.as_timestamp_ntz_micros(), Some(datetime)); + /// /// // but not for other variants. /// let datetime_nanos = NaiveDate::from_ymd_opt(2025, 8, 14) /// .unwrap() /// .and_hms_nano_opt(12, 33, 54, 123456789) /// .unwrap(); - /// let v2 = Variant::from(datetime_nanos); - /// assert_eq!(v2.as_timestamp_micros(), None); + /// let v3 = Variant::from(datetime_nanos); + /// assert_eq!(v3.as_timestamp_micros(), None); /// ``` pub fn as_timestamp_ntz_micros(&self) -> Option { match *self { Variant::TimestampNtzMicros(d) => Some(d), + Variant::String(s) => string_to_datetime(&Utc, s).ok().map(|dt| dt.naive_utc()), + Variant::ShortString(s) => string_to_datetime(&Utc, s.as_ref()) + .ok() + .map(|dt| dt.naive_utc()), _ => None, } } @@ -708,13 +753,24 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v2 = Variant::from(datetime_micros); /// assert_eq!(v2.as_timestamp_nanos(), Some(datetime_micros)); /// + /// // or from string variant + /// let v3 = Variant::from("2026-06-10T12:34:56.123456789Z"); + /// let datetime = NaiveDate::from_ymd_opt(2026, 6, 10) + /// .unwrap() + /// .and_hms_nano_opt(12, 34, 56, 123456789) + /// .unwrap() + /// .and_utc(); + /// assert_eq!(v3.as_timestamp_nanos(), Some(datetime)); + /// /// // but not for other variants. - /// let v3 = Variant::from("hello!"); - /// assert_eq!(v3.as_timestamp_nanos(), None); + /// let v4 = Variant::from("hello!"); + /// assert_eq!(v4.as_timestamp_nanos(), None); /// ``` pub fn as_timestamp_nanos(&self) -> Option> { match *self { Variant::TimestampNanos(d) | Variant::TimestampMicros(d) => Some(d), + Variant::ShortString(s) => string_to_datetime(&Utc, s.as_ref()).ok(), + Variant::String(s) => string_to_datetime(&Utc, s).ok(), _ => None, } } @@ -747,13 +803,25 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v2 = Variant::from(datetime_micros); /// assert_eq!(v2.as_timestamp_ntz_nanos(), Some(datetime_micros)); /// + /// // or from string variant + /// let v3 = Variant::from("2026-06-10T12:34:56.123456789Z"); + /// let datetime = NaiveDate::from_ymd_opt(2026, 6, 10) + /// .unwrap() + /// .and_hms_nano_opt(12, 34, 56, 123456789) + /// .unwrap(); + /// assert_eq!(v3.as_timestamp_ntz_nanos(), Some(datetime)); + /// /// // but not for other variants. - /// let v3 = Variant::from("hello!"); - /// assert_eq!(v3.as_timestamp_ntz_nanos(), None); + /// let v4 = Variant::from("hello!"); + /// assert_eq!(v4.as_timestamp_ntz_nanos(), None); /// ``` pub fn as_timestamp_ntz_nanos(&self) -> Option { match *self { Variant::TimestampNtzNanos(d) | Variant::TimestampNtzMicros(d) => Some(d), + Variant::String(s) => string_to_datetime(&Utc, s).ok().map(|dt| dt.naive_utc()), + Variant::ShortString(s) => string_to_datetime(&Utc, s.as_ref()) + .ok() + .map(|dt| dt.naive_utc()), _ => None, } } @@ -773,15 +841,21 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v1 = Variant::Binary(data); /// assert_eq!(v1.as_u8_slice(), Some(data.as_slice())); /// + /// // or string variant + /// let data = b"world"; + /// let v2 = Variant::from("world"); + /// assert_eq!(v2.as_u8_slice(), Some(data.as_slice())); + /// /// // but not from other variant types - /// let v2 = Variant::from(123i64); - /// assert_eq!(v2.as_u8_slice(), None); + /// let v3 = Variant::from(123i64); + /// assert_eq!(v3.as_u8_slice(), None); /// ``` pub fn as_u8_slice(&'v self) -> Option<&'v [u8]> { - if let Variant::Binary(d) = self { - Some(d) - } else { - None + match self { + Variant::Binary(d) => Some(d), + Variant::String(s) => Some(s.as_bytes()), + Variant::ShortString(s) => Some(s.as_ref().as_bytes()), + _ => None, } } @@ -811,9 +885,10 @@ impl<'m, 'v> Variant<'m, 'v> { } } - /// Converts this variant to a `uuid hyphenated string` if possible. + /// Converts this variant to a `Uuid` if possible. /// - /// Returns `Some(String)` for UUID variants, `None` for non-UUID variants. + /// Returns `Some(Uuid)` for UUID variants and string variants that can be + /// parsed as UUIDs. /// /// # Examples /// @@ -824,15 +899,24 @@ impl<'m, 'v> Variant<'m, 'v> { /// let s = uuid::Uuid::parse_str("67e55044-10b1-426f-9247-bb680e5fe0c8").unwrap(); /// let v1 = Variant::Uuid(s); /// assert_eq!(s, v1.as_uuid().unwrap()); - /// assert_eq!("67e55044-10b1-426f-9247-bb680e5fe0c8", v1.as_uuid().unwrap().to_string()); /// - /// //but not from other variants - /// let v2 = Variant::from(1234); - /// assert_eq!(None, v2.as_uuid()) + /// // or from a UUID-format string + /// let v2 = Variant::from("67e55044-10b1-426f-9247-bb680e5fe0c8"); + /// assert_eq!(s, v2.as_uuid().unwrap()); + /// + /// // but not from other variants + /// let v3 = Variant::from(1234); + /// assert_eq!(None, v3.as_uuid()); + /// + /// // or non-UUID strings + /// let v4 = Variant::from("not-a-uuid"); + /// assert_eq!(None, v4.as_uuid()); /// ``` pub fn as_uuid(&self) -> Option { match self { Variant::Uuid(u) => Some(*u), + Variant::ShortString(s) => Uuid::parse_str(s.as_ref()).ok(), + Variant::String(s) => Uuid::parse_str(s).ok(), _ => None, } } @@ -860,14 +944,15 @@ impl<'m, 'v> Variant<'m, 'v> { } } - /// Converts a boolean or numeric variant(integers, floating-point, and decimals) + /// Converts a boolean, string or numeric variant(integers, floating-point, and decimals) /// to the specified numeric type `T`. /// /// Uses Arrow's casting logic to perform the conversion. Returns `Some(T)` if /// the conversion succeeds, `None` if the variant can't be casted to type `T`. - fn as_num(&self) -> Option + fn as_num(&self) -> Option where - T: DecimalCastTarget, + T: ArrowPrimitiveType + Parser, + T::Native: DecimalCastTarget, { match *self { Variant::BooleanFalse => single_bool_to_numeric(false), @@ -878,28 +963,30 @@ impl<'m, 'v> Variant<'m, 'v> { Variant::Int64(i) => num_cast(i), Variant::Float(f) => num_cast(f), Variant::Double(d) => num_cast(d), - Variant::Decimal4(d) => { - Self::cast_decimal_to_num::(d.integer(), d.scale(), |x| { - x as f64 - }) - } - Variant::Decimal8(d) => { - Self::cast_decimal_to_num::(d.integer(), d.scale(), |x| { - x as f64 - }) - } - Variant::Decimal16(d) => { - Self::cast_decimal_to_num::(d.integer(), d.scale(), |x| { - x as f64 - }) - } + Variant::Decimal4(d) => Self::cast_decimal_to_num::( + d.integer(), + d.scale(), + |x| x as f64, + ), + Variant::Decimal8(d) => Self::cast_decimal_to_num::( + d.integer(), + d.scale(), + |x| x as f64, + ), + Variant::Decimal16(d) => Self::cast_decimal_to_num::( + d.integer(), + d.scale(), + |x| x as f64, + ), + Variant::ShortString(s) => T::parse(s.as_ref()), + Variant::String(s) => T::parse(s), _ => None, } } /// Converts this variant to an `i8` if possible. /// - /// Returns `Some(i8)` for boolean and numeric variants(integers, floating-point, + /// Returns `Some(i8)` for boolean, string and numeric variants(integers, floating-point, /// and decimals with scale 0) that fit in `i8` range, /// `None` for other variants or values that would overflow. /// @@ -916,16 +1003,20 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v2 = Variant::BooleanFalse; /// assert_eq!(v2.as_int8(), Some(0)); /// + /// // or from string variant + /// let v3 = Variant::String("123"); + /// assert_eq!(v3.as_int8(), Some(123i8)); + /// /// // but not if it would overflow - /// let v3 = Variant::from(1234i64); - /// assert_eq!(v3.as_int8(), None); + /// let v4 = Variant::from(1234i64); + /// assert_eq!(v4.as_int8(), None); /// /// // or if the variant cannot be cast into an integer - /// let v4 = Variant::from("hello!"); - /// assert_eq!(v4.as_int8(), None); + /// let v5 = Variant::from("hello!"); + /// assert_eq!(v5.as_int8(), None); /// ``` pub fn as_int8(&self) -> Option { - self.as_num() + self.as_num::() } /// Converts this variant to an `i16` if possible. @@ -947,21 +1038,25 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v2 = Variant::BooleanFalse; /// assert_eq!(v2.as_int16(), Some(0)); /// + /// // or from string variant + /// let v3 = Variant::String("123"); + /// assert_eq!(v3.as_int16(), Some(123i16)); + /// /// // but not if it would overflow - /// let v3 = Variant::from(123456i64); - /// assert_eq!(v3.as_int16(), None); + /// let v4 = Variant::from(123456i64); + /// assert_eq!(v4.as_int16(), None); /// /// // or if the variant cannot be cast into an integer - /// let v4 = Variant::from("hello!"); - /// assert_eq!(v4.as_int16(), None); + /// let v5 = Variant::from("hello!"); + /// assert_eq!(v5.as_int16(), None); /// ``` pub fn as_int16(&self) -> Option { - self.as_num() + self.as_num::() } /// Converts this variant to an `i32` if possible. /// - /// Returns `Some(i32)` for boolean and numeric variants(integers, floating-point, + /// Returns `Some(i32)` for boolean, string and numeric variants(integers, floating-point, /// and decimals with scale 0) that fit in `i32` range /// `None` for other variants or values that would overflow. /// @@ -978,21 +1073,25 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v2 = Variant::BooleanFalse; /// assert_eq!(v2.as_int32(), Some(0)); /// + /// // or from string variant + /// let v3 = Variant::String("12345"); + /// assert_eq!(v3.as_int32(), Some(12345i32)); + /// /// // but not if it would overflow - /// let v3 = Variant::from(12345678901i64); - /// assert_eq!(v3.as_int32(), None); + /// let v4 = Variant::from(12345678901i64); + /// assert_eq!(v4.as_int32(), None); /// /// // or if the variant cannot be cast into an integer - /// let v4 = Variant::from("hello!"); - /// assert_eq!(v4.as_int32(), None); + /// let v5 = Variant::from("hello!"); + /// assert_eq!(v5.as_int32(), None); /// ``` pub fn as_int32(&self) -> Option { - self.as_num() + self.as_num::() } /// Converts this variant to an `i64` if possible. /// - /// Returns `Some(i64)` for boolean and numeric variants(integers, floating-point, + /// Returns `Some(i64)` for boolean, string and numeric variants(integers, floating-point, /// and decimals with scale 0) that fit in `i64` range /// `None` for other variants or values that would overflow. /// @@ -1009,17 +1108,21 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v2 = Variant::BooleanFalse; /// assert_eq!(v2.as_int64(), Some(0)); /// + /// // or from string variant + /// let v3 = Variant::String("123456"); + /// assert_eq!(v3.as_int64(), Some(123456i64)); + /// /// // but not a variant that cannot be cast into an integer - /// let v3 = Variant::from("hello!"); - /// assert_eq!(v3.as_int64(), None); + /// let v4 = Variant::from("hello!"); + /// assert_eq!(v4.as_int64(), None); /// ``` pub fn as_int64(&self) -> Option { - self.as_num() + self.as_num::() } /// Converts this variant to a `u8` if possible. /// - /// Returns `Some(u8)` for boolean and numeric variants(integers, floating-point, + /// Returns `Some(u8)` for boolean, string and numeric variants(integers, floating-point, /// and decimals with scale 0) that fit in `u8` range /// `None` for other variants or values that would overflow. /// @@ -1046,21 +1149,25 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v4 = Variant::BooleanFalse; /// assert_eq!(v4.as_u8(), Some(0)); /// + /// // or from string variant + /// let v5 = Variant::String("123"); + /// assert_eq!(v5.as_u8(), Some(123u8)); + /// /// // but not a variant that can't fit into the range - /// let v5 = Variant::from(-1); - /// assert_eq!(v5.as_u8(), None); + /// let v6 = Variant::from(-1); + /// assert_eq!(v6.as_u8(), None); /// /// // or not a variant that cannot be cast into an integer - /// let v6 = Variant::from("hello!"); - /// assert_eq!(v6.as_u8(), None); + /// let v7 = Variant::from("hello!"); + /// assert_eq!(v7.as_u8(), None); /// ``` pub fn as_u8(&self) -> Option { - self.as_num() + self.as_num::() } /// Converts this variant to an `u16` if possible. /// - /// Returns `Some(u16)` for boolean and numeric variants(integers, floating-point, + /// Returns `Some(u16)` for boolean, string and numeric variants(integers, floating-point, /// and decimals with scale 0) that fit in `u16` range /// `None` for other variants or values that would overflow. /// @@ -1087,21 +1194,25 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v4= Variant::BooleanFalse; /// assert_eq!(v4.as_u16(), Some(0)); /// + /// // or from string variant + /// let v5 = Variant::String("1234"); + /// assert_eq!(v5.as_u16(), Some(1234u16)); + /// /// // but not a variant that can't fit into the range - /// let v5 = Variant::from(-1); - /// assert_eq!(v5.as_u16(), None); + /// let v6 = Variant::from(-1); + /// assert_eq!(v6.as_u16(), None); /// /// // or not a variant that cannot be cast into an integer - /// let v6 = Variant::from("hello!"); - /// assert_eq!(v6.as_u16(), None); + /// let v7 = Variant::from("hello!"); + /// assert_eq!(v7.as_u16(), None); /// ``` pub fn as_u16(&self) -> Option { - self.as_num() + self.as_num::() } /// Converts this variant to an `u32` if possible. /// - /// Returns `Some(u32)` for boolean and numeric variants(integers, floating-point, + /// Returns `Some(u32)` for boolean, string and numeric variants(integers, floating-point, /// and decimals with scale 0) that fit in `u32` range /// `None` for other variants or values that would overflow. /// @@ -1128,21 +1239,25 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v4 = Variant::BooleanFalse; /// assert_eq!(v4.as_u32(), Some(0)); /// + /// // or from string variant + /// let v5 = Variant::String("12345"); + /// assert_eq!(v5.as_u32(), Some(12345u32)); + /// /// // but not a variant that can't fit into the range - /// let v5 = Variant::from(-1); - /// assert_eq!(v5.as_u32(), None); + /// let v6 = Variant::from(-1); + /// assert_eq!(v6.as_u32(), None); /// /// // or not a variant that cannot be cast into an integer - /// let v6 = Variant::from("hello!"); - /// assert_eq!(v6.as_u32(), None); + /// let v7 = Variant::from("hello!"); + /// assert_eq!(v7.as_u32(), None); /// ``` pub fn as_u32(&self) -> Option { - self.as_num() + self.as_num::() } /// Converts this variant to an `u64` if possible. /// - /// Returns `Some(u64)` for boolean and numeric variants(integers, floating-point, + /// Returns `Some(u64)` for boolean, string and numeric variants(integers, floating-point, /// and decimals with scale 0) that fit in `u64` range /// `None` for other variants or values that would overflow. /// @@ -1169,16 +1284,20 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v4 = Variant::BooleanFalse; /// assert_eq!(v4.as_u64(), Some(0)); /// + /// // or from string variant + /// let v5 = Variant::String("12345"); + /// assert_eq!(v5.as_u64(), Some(12345u64)); + /// /// // but not a variant that can't fit into the range - /// let v5 = Variant::from(-1); - /// assert_eq!(v5.as_u64(), None); + /// let v6 = Variant::from(-1); + /// assert_eq!(v6.as_u64(), None); /// /// // or not a variant that cannot be cast into an integer - /// let v6 = Variant::from("hello!"); - /// assert_eq!(v6.as_u64(), None); + /// let v7 = Variant::from("hello!"); + /// assert_eq!(v7.as_u64(), None); /// ``` pub fn as_u64(&self) -> Option { - self.as_num() + self.as_num::() } fn convert_string_to_decimal(input: &str) -> Option @@ -1230,7 +1349,7 @@ impl<'m, 'v> Variant<'m, 'v> { pub fn as_decimal4(&self) -> Option { match *self { Variant::Int8(_) | Variant::Int16(_) | Variant::Int32(_) | Variant::Int64(_) => { - self.as_num::().and_then(|x| x.try_into().ok()) + self.as_num::().and_then(|x| x.try_into().ok()) } Variant::Float(f) => single_float_to_decimal::(f as _, 1f64) .and_then(|x: i32| x.try_into().ok()), @@ -1281,7 +1400,7 @@ impl<'m, 'v> Variant<'m, 'v> { pub fn as_decimal8(&self) -> Option { match *self { Variant::Int8(_) | Variant::Int16(_) | Variant::Int32(_) | Variant::Int64(_) => { - self.as_num::().and_then(|x| x.try_into().ok()) + self.as_num::().and_then(|x| x.try_into().ok()) } Variant::Float(f) => single_float_to_decimal::(f as _, 1f64) .and_then(|x: i64| x.try_into().ok()), @@ -1324,7 +1443,7 @@ impl<'m, 'v> Variant<'m, 'v> { pub fn as_decimal16(&self) -> Option { match *self { Variant::Int8(_) | Variant::Int16(_) | Variant::Int32(_) | Variant::Int64(_) => { - let x = self.as_num::()?; + let x = self.as_num::()?; >::from(x).try_into().ok() } Variant::Float(f) => { @@ -1347,7 +1466,7 @@ impl<'m, 'v> Variant<'m, 'v> { /// Converts this variant to an `f16` if possible. /// - /// Returns `Some(f16)` for boolean and numeric variants(integers, floating-point, + /// Returns `Some(f16)` for boolean, string and numeric variants(integers, floating-point, /// and decimals with scale 0) that fit in `f16` range /// `None` otherwise. /// @@ -1365,24 +1484,28 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v2 = Variant::from(std::f64::consts::PI); /// assert_eq!(v2.as_f16(), Some(f16::from_f64(std::f64::consts::PI))); /// - /// // and from boolean + /// // and from boolean variant /// let v3 = Variant::BooleanTrue; /// assert_eq!(v3.as_f16(), Some(f16::from_f32(1.0))); /// + /// // and from string variant + /// let v4 = Variant::String("123.45"); + /// assert_eq!(v4.as_f16(), Some(f16::from_f32(123.45f32))); + /// /// // return inf if overflow - /// let v4 = Variant::from(123456); - /// assert_eq!(v4.as_f16(), Some(f16::INFINITY)); + /// let v5 = Variant::from(123456); + /// assert_eq!(v5.as_f16(), Some(f16::INFINITY)); /// /// // but not from other variants - /// let v5 = Variant::from("hello!"); - /// assert_eq!(v5.as_f16(), None); + /// let v6 = Variant::from("hello!"); + /// assert_eq!(v6.as_f16(), None); pub fn as_f16(&self) -> Option { - self.as_num() + self.as_num::() } /// Converts this variant to an `f32` if possible. /// - /// Returns `Some(f32)` for boolean and numeric variants(integers, floating-point, + /// Returns `Some(f32)` for boolean, string and numeric variants(integers, floating-point, /// and decimals with scale 0) that fit in `f32` range /// `None` otherwise. /// @@ -1403,21 +1526,25 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v3 = Variant::BooleanTrue; /// assert_eq!(v3.as_f32(), Some(1.0)); /// + /// // and from string variant + /// let v4 = Variant::String("123.45"); + /// assert_eq!(v4.as_f32(), Some(123.45f32)); + /// /// // and return inf if overflow - /// let v4 = Variant::from(f64::MAX); - /// assert_eq!(v4.as_f32(), Some(f32::INFINITY)); + /// let v5 = Variant::from(f64::MAX); + /// assert_eq!(v5.as_f32(), Some(f32::INFINITY)); /// /// // but not from other variants - /// let v5 = Variant::from("hello!"); - /// assert_eq!(v5.as_f32(), None); + /// let v6 = Variant::from("hello!"); + /// assert_eq!(v6.as_f32(), None); /// ``` pub fn as_f32(&self) -> Option { - self.as_num() + self.as_num::() } /// Converts this variant to an `f64` if possible. /// - /// Returns `Some(f64)` for boolean and numeric variants(integers, floating-point, + /// Returns `Some(f64)` for boolean, string and numeric variants(integers, floating-point, /// and decimals with scale 0) that fit in `f64` range /// `None` for other variants or can't be represented by an f64. /// @@ -1438,12 +1565,16 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v3 = Variant::BooleanTrue; /// assert_eq!(v3.as_f64(), Some(1.0f64)); /// + /// // and from string variant + /// let v4 = Variant::String("123.45"); + /// assert_eq!(v4.as_f64(), Some(123.45f64)); + /// /// // but not from other variants /// let v5 = Variant::from("hello!"); /// assert_eq!(v5.as_f64(), None); /// ``` pub fn as_f64(&self) -> Option { - self.as_num() + self.as_num::() } /// Converts this variant to an `Object` if it is an [`VariantObject`]. @@ -1541,8 +1672,8 @@ impl<'m, 'v> Variant<'m, 'v> { /// Converts this variant to a `NaiveTime` if possible. /// - /// Returns `Some(NaiveTime)` for `Variant::Time`, - /// `None` for non-Time variants. + /// Returns `Some(NaiveTime)` for a time and string variant. + /// `None` for the other variants. /// /// # Example /// @@ -1555,15 +1686,35 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v1 = Variant::from(time); /// assert_eq!(Some(time), v1.as_time_utc()); /// + /// // or from string variant + /// let v2 = Variant::String("1234567"); + /// let time = NaiveTime::from_hms_micro_opt(1, 2, 3, 4).unwrap(); + /// assert_eq!(Some(time), v2.as_time_utc()); + /// /// // but not from other variants. - /// let v2 = Variant::from("Hello"); - /// assert_eq!(None, v2.as_time_utc()); + /// let v3 = Variant::from("Hello"); + /// assert_eq!(None, v3.as_time_utc()); /// ``` pub fn as_time_utc(&'m self) -> Option { - if let Variant::Time(time) = self { - Some(*time) - } else { - None + match *self { + Variant::Time(time) => Some(time), + Variant::ShortString(s) => { + Time64MicrosecondType::parse(s.as_ref()).and_then(|nanos_since_midnight| { + NaiveTime::from_num_seconds_from_midnight_opt( + (nanos_since_midnight / 1_000_000_000) as u32, + (nanos_since_midnight % 1_000_000_000) as u32, + ) + }) + } + Variant::String(s) => { + Time64MicrosecondType::parse(s).and_then(|nanos_since_midnight| { + NaiveTime::from_num_seconds_from_midnight_opt( + (nanos_since_midnight / 1_000_000_000) as u32, + (nanos_since_midnight % 1_000_000_000) as u32, + ) + }) + } + _ => None, } } From 408a83f55c3cdb47745ee64a57e2507f82a4bd23 Mon Sep 17 00:00:00 2001 From: klion26 Date: Fri, 12 Jun 2026 17:55:52 +0800 Subject: [PATCH 2/2] fix tests --- parquet-variant-compute/src/shred_variant.rs | 8 +- .../src/type_conversion.rs | 9 + .../src/variant_to_arrow.rs | 174 +++++++++++++++--- parquet-variant/src/variant.rs | 11 +- 4 files changed, 163 insertions(+), 39 deletions(-) diff --git a/parquet-variant-compute/src/shred_variant.rs b/parquet-variant-compute/src/shred_variant.rs index 440f4b716521..c3acaf02d08d 100644 --- a/parquet-variant-compute/src/shred_variant.rs +++ b/parquet-variant-compute/src/shred_variant.rs @@ -92,6 +92,7 @@ pub(crate) fn shred_variant_with_options( cast_options, array.len(), NullValue::TopLevelVariant, + true, )?; for i in 0..array.len() { if array.is_null(i) { @@ -145,6 +146,7 @@ pub(crate) fn make_variant_to_shredded_variant_arrow_row_builder<'a>( cast_options: &'a CastOptions, capacity: usize, null_value: NullValue, + shred: bool, ) -> Result> { let builder = match data_type { DataType::Struct(fields) => { @@ -193,7 +195,7 @@ pub(crate) fn make_variant_to_shredded_variant_arrow_row_builder<'a>( | DataType::FixedSizeBinary(16) // UUID => { let builder = - make_primitive_variant_to_arrow_row_builder(data_type, cast_options, capacity)?; + make_primitive_variant_to_arrow_row_builder(data_type, cast_options, capacity, shred)?; let typed_value_builder = VariantToShreddedPrimitiveVariantRowBuilder::new(builder, capacity, null_value); VariantToShreddedVariantRowBuilder::Primitive(typed_value_builder) @@ -376,6 +378,7 @@ impl<'a> VariantToShreddedObjectVariantRowBuilder<'a> { cast_options, capacity, NullValue::ObjectField, + true, )?; Ok((field.name().as_str(), builder)) }); @@ -1046,6 +1049,7 @@ mod tests { &cast_options, 1, mode, + true, ) .unwrap(); primitive_builder.append_null().unwrap(); @@ -1076,6 +1080,7 @@ mod tests { &cast_options, 1, mode, + true, ) .unwrap(); array_builder.append_null().unwrap(); @@ -1104,6 +1109,7 @@ mod tests { &cast_options, 1, mode, + true, ) .unwrap(); object_builder.append_null().unwrap(); diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs index 596b7eed6055..cd0420a65c44 100644 --- a/parquet-variant-compute/src/type_conversion.rs +++ b/parquet-variant-compute/src/type_conversion.rs @@ -383,6 +383,15 @@ fn cast_list_to_string<'m, 'v>(mut iter: impl Iterator>) ret_str } +pub(crate) fn variant_to_binary<'v>(variant: &Variant<'_, 'v>) -> Option<&'v [u8]> { + match *variant { + Variant::Binary(d) => Some(d), + Variant::String(s) => Some(s.as_bytes()), + Variant::ShortString(s) => Some(s.as_str().as_bytes()), + _ => None, + } +} + /// Convert the value at a specific index in the given array into a `Variant`. macro_rules! non_generic_conversion_single_value { ($array:expr, $cast_fn:expr, $index:expr) => {{ diff --git a/parquet-variant-compute/src/variant_to_arrow.rs b/parquet-variant-compute/src/variant_to_arrow.rs index f59a6a06a1b6..bc80f5d5408e 100644 --- a/parquet-variant-compute/src/variant_to_arrow.rs +++ b/parquet-variant-compute/src/variant_to_arrow.rs @@ -20,8 +20,8 @@ use crate::shred_variant::{ make_variant_to_shredded_variant_arrow_row_builder, }; use crate::type_conversion::{ - PrimitiveFromVariant, TimestampFromVariant, variant_cast_with_options, variant_to_string, - variant_to_unscaled_decimal, + PrimitiveFromVariant, TimestampFromVariant, variant_cast_with_options, variant_to_binary, + variant_to_string, variant_to_unscaled_decimal, }; use crate::variant_array::ShreddedVariantFieldArray; use crate::{VariantArray, VariantValueArrayBuilder}; @@ -133,8 +133,12 @@ fn make_typed_variant_to_arrow_row_builder<'a>( Ok(Encoded(builder)) } data_type => { - let builder = - make_primitive_variant_to_arrow_row_builder(data_type, cast_options, capacity)?; + let builder = make_primitive_variant_to_arrow_row_builder( + data_type, + cast_options, + capacity, + false, + )?; Ok(Primitive(builder)) } } @@ -171,6 +175,61 @@ pub(crate) fn make_variant_to_arrow_row_builder<'a>( Ok(builder) } +pub(crate) enum VariantToStringArrowRowBuilder<'a, B: StringLikeArrayBuilder> { + Get(VariantToStringGetArrowBuilder<'a, B>), + Shred(VariantToStringShredArrowBuilder<'a, B>), +} + +impl<'a, B: StringLikeArrayBuilder> VariantToStringArrowRowBuilder<'a, B> { + fn append_null(&mut self) -> Result<()> { + match self { + Self::Get(b) => b.append_null(), + Self::Shred(b) => b.append_null(), + } + } + + fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { + match self { + Self::Get(b) => b.append_value(value), + Self::Shred(b) => b.append_value(value), + } + } + + fn finish(self) -> Result { + match self { + Self::Get(b) => b.finish(), + Self::Shred(b) => b.finish(), + } + } +} + +pub(crate) enum VariantToBinaryArrowRowBuilder<'a, B: BinaryLikeArrayBuilder> { + Get(VariantToBinaryGetArrowRowBuilder<'a, B>), + Shred(VariantToBinaryShredArrowRowBuilder<'a, B>), +} + +impl<'a, B: BinaryLikeArrayBuilder> VariantToBinaryArrowRowBuilder<'a, B> { + fn append_null(&mut self) -> Result<()> { + match self { + Self::Get(b) => b.append_null(), + Self::Shred(b) => b.append_null(), + } + } + + fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { + match self { + Self::Get(b) => b.append_value(value), + Self::Shred(b) => b.append_value(value), + } + } + + fn finish(self) -> Result { + match self { + Self::Get(b) => b.finish(), + Self::Shred(b) => b.finish(), + } + } +} /// Builder for converting primitive variant values to Arrow arrays. It is used by both /// `VariantToArrowRowBuilder` (below) and `VariantToShreddedPrimitiveVariantRowBuilder` (in /// `shred_variant.rs`). @@ -211,9 +270,9 @@ pub(crate) enum PrimitiveVariantToArrowRowBuilder<'a> { Date32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Date32Type>), Date64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Date64Type>), Uuid(VariantToUuidArrowRowBuilder<'a>), - String(VariantToStringArrowBuilder<'a, StringBuilder>), - LargeString(VariantToStringArrowBuilder<'a, LargeStringBuilder>), - StringView(VariantToStringArrowBuilder<'a, StringViewBuilder>), + String(VariantToStringArrowRowBuilder<'a, StringBuilder>), + LargeString(VariantToStringArrowRowBuilder<'a, LargeStringBuilder>), + StringView(VariantToStringArrowRowBuilder<'a, StringViewBuilder>), Binary(VariantToBinaryArrowRowBuilder<'a, BinaryBuilder>), LargeBinary(VariantToBinaryArrowRowBuilder<'a, LargeBinaryBuilder>), BinaryView(VariantToBinaryArrowRowBuilder<'a, BinaryViewBuilder>), @@ -397,6 +456,7 @@ pub(crate) fn make_primitive_variant_to_arrow_row_builder<'a>( data_type: &'a DataType, cast_options: &'a CastOptions, capacity: usize, + shred: bool, ) -> Result> { use PrimitiveVariantToArrowRowBuilder::*; @@ -523,13 +583,30 @@ pub(crate) fn make_primitive_variant_to_arrow_row_builder<'a>( .to_string(), )); } - DataType::Binary => Binary(VariantToBinaryArrowRowBuilder::new(cast_options, capacity)), - DataType::LargeBinary => { - LargeBinary(VariantToBinaryArrowRowBuilder::new(cast_options, capacity)) - } - DataType::BinaryView => { - BinaryView(VariantToBinaryArrowRowBuilder::new(cast_options, capacity)) - } + DataType::Binary => match shred { + true => Binary(VariantToBinaryArrowRowBuilder::Shred( + VariantToBinaryShredArrowRowBuilder::new(cast_options, capacity), + )), + false => Binary(VariantToBinaryArrowRowBuilder::Get( + VariantToBinaryGetArrowRowBuilder::new(cast_options, capacity), + )), + }, + DataType::LargeBinary => match shred { + true => LargeBinary(VariantToBinaryArrowRowBuilder::Shred( + VariantToBinaryShredArrowRowBuilder::new(cast_options, capacity), + )), + false => LargeBinary(VariantToBinaryArrowRowBuilder::Get( + VariantToBinaryGetArrowRowBuilder::new(cast_options, capacity), + )), + }, + DataType::BinaryView => match shred { + true => BinaryView(VariantToBinaryArrowRowBuilder::Shred( + VariantToBinaryShredArrowRowBuilder::new(cast_options, capacity), + )), + false => BinaryView(VariantToBinaryArrowRowBuilder::Get( + VariantToBinaryGetArrowRowBuilder::new(cast_options, capacity), + )), + }, DataType::FixedSizeBinary(16) => { Uuid(VariantToUuidArrowRowBuilder::new(cast_options, capacity)) } @@ -538,13 +615,30 @@ pub(crate) fn make_primitive_variant_to_arrow_row_builder<'a>( "DataType {data_type:?} not yet implemented" ))); } - DataType::Utf8 => String(VariantToStringArrowBuilder::new(cast_options, capacity)), - DataType::LargeUtf8 => { - LargeString(VariantToStringArrowBuilder::new(cast_options, capacity)) - } - DataType::Utf8View => { - StringView(VariantToStringArrowBuilder::new(cast_options, capacity)) - } + DataType::Utf8 => match shred { + true => String(VariantToStringArrowRowBuilder::Shred( + VariantToStringShredArrowBuilder::new(cast_options, capacity), + )), + false => String(VariantToStringArrowRowBuilder::Get( + VariantToStringGetArrowBuilder::new(cast_options, capacity), + )), + }, + DataType::LargeUtf8 => match shred { + true => LargeString(VariantToStringArrowRowBuilder::Shred( + VariantToStringShredArrowBuilder::new(cast_options, capacity), + )), + false => LargeString(VariantToStringArrowRowBuilder::Get( + VariantToStringGetArrowBuilder::new(cast_options, capacity), + )), + }, + DataType::Utf8View => match shred { + true => StringView(VariantToStringArrowRowBuilder::Shred( + VariantToStringShredArrowBuilder::new(cast_options, capacity), + )), + false => StringView(VariantToStringArrowRowBuilder::Get( + VariantToStringGetArrowBuilder::new(cast_options, capacity), + )), + }, DataType::List(_) | DataType::LargeList(_) | DataType::ListView(_) @@ -838,13 +932,20 @@ macro_rules! define_variant_to_primitive_builder { } define_variant_to_primitive_builder!( - struct VariantToStringArrowBuilder<'a, B: StringLikeArrayBuilder> + struct VariantToStringGetArrowBuilder<'a, B: StringLikeArrayBuilder> |capacity| -> B { B::with_capacity(capacity) }, |value| variant_to_string(value), type_name: B::type_name(), append_value: |builder, v| builder.append_value(&v) ); +define_variant_to_primitive_builder!( + struct VariantToStringShredArrowBuilder<'a, B: StringLikeArrayBuilder> + |capacity| -> B { B::with_capacity(capacity) }, + |value| value.as_string(), + type_name: B::type_name() +); + define_variant_to_primitive_builder!( struct VariantToBooleanArrowRowBuilder<'a> |capacity| -> BooleanBuilder { BooleanBuilder::with_capacity(capacity) }, @@ -876,12 +977,19 @@ define_variant_to_primitive_builder!( ); define_variant_to_primitive_builder!( - struct VariantToBinaryArrowRowBuilder<'a, B: BinaryLikeArrayBuilder> + struct VariantToBinaryShredArrowRowBuilder<'a, B: BinaryLikeArrayBuilder> |capacity| -> B { B::with_capacity(capacity) }, |value| value.as_u8_slice(), type_name: B::type_name() ); +define_variant_to_primitive_builder!( + struct VariantToBinaryGetArrowRowBuilder<'a, B: BinaryLikeArrayBuilder> + |capacity| -> B { B::with_capacity(capacity) }, + |value| variant_to_binary(value), + type_name: B::type_name() +); + /// Builder for converting variant values to arrow Decimal values pub(crate) struct VariantToDecimalArrowRowBuilder<'a, T> where @@ -1063,6 +1171,7 @@ where cast_options, capacity, NullValue::ArrayElement, + false, )?; ListElementBuilder::Shredded(Box::new(builder)) } else { @@ -1168,6 +1277,7 @@ impl<'a> VariantToFixedSizeListArrowRowBuilder<'a> { cast_options, capacity, NullValue::ArrayElement, + false, )?; ListElementBuilder::Shredded(Box::new(builder)) } else { @@ -1353,11 +1463,15 @@ mod tests { ]; for data_type in non_primitive_types { - let err = - match make_primitive_variant_to_arrow_row_builder(&data_type, &cast_options, 1) { - Ok(_) => panic!("non-primitive type {data_type:?} should be rejected"), - Err(err) => err, - }; + let err = match make_primitive_variant_to_arrow_row_builder( + &data_type, + &cast_options, + 1, + false, + ) { + Ok(_) => panic!("non-primitive type {data_type:?} should be rejected"), + Err(err) => err, + }; match err { ArrowError::InvalidArgumentError(msg) => { @@ -1375,7 +1489,7 @@ mod tests { ..Default::default() }; let mut builder = - make_primitive_variant_to_arrow_row_builder(&DataType::Int32, &cast_options, 2) + make_primitive_variant_to_arrow_row_builder(&DataType::Int32, &cast_options, 2, false) .unwrap(); assert!(!builder.append_value(&Variant::Null).unwrap()); @@ -1397,6 +1511,7 @@ mod tests { &DataType::Decimal32(9, 2), &cast_options, 2, + false, ) .unwrap(); let decimal_variant: Variant<'_, '_> = VariantDecimal4::try_new(1234, 2).unwrap().into(); @@ -1420,6 +1535,7 @@ mod tests { &DataType::FixedSizeBinary(16), &cast_options, 2, + false, ) .unwrap(); let uuid = Uuid::nil(); diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index 3768191ea7dd..8340ff7e678a 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -841,20 +841,13 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v1 = Variant::Binary(data); /// assert_eq!(v1.as_u8_slice(), Some(data.as_slice())); /// - /// // or string variant - /// let data = b"world"; - /// let v2 = Variant::from("world"); - /// assert_eq!(v2.as_u8_slice(), Some(data.as_slice())); - /// /// // but not from other variant types - /// let v3 = Variant::from(123i64); - /// assert_eq!(v3.as_u8_slice(), None); + /// let v2 = Variant::from(123i64); + /// assert_eq!(v2.as_u8_slice(), None); /// ``` pub fn as_u8_slice(&'v self) -> Option<&'v [u8]> { match self { Variant::Binary(d) => Some(d), - Variant::String(s) => Some(s.as_bytes()), - Variant::ShortString(s) => Some(s.as_ref().as_bytes()), _ => None, } }