diff --git a/vortex-cuda/src/arrow/canonical.rs b/vortex-cuda/src/arrow/canonical.rs index f8feacc6110..a74efcd3bf7 100644 --- a/vortex-cuda/src/arrow/canonical.rs +++ b/vortex-cuda/src/arrow/canonical.rs @@ -253,8 +253,6 @@ fn export_fixed_size( Ok((arrow_array, sync_event)) } -// export some nested data - unsafe extern "C" fn release_array(array: *mut ArrowArray) { // SAFETY: this is only safe if we're dropping an ArrowArray that was created from Rust // code. This is necessary to ensure that the fields inside the CudaPrivateData diff --git a/vortex-test/e2e-cuda/src/lib.rs b/vortex-test/e2e-cuda/src/lib.rs index d2fd9728f5b..a9cd16605d9 100644 --- a/vortex-test/e2e-cuda/src/lib.rs +++ b/vortex-test/e2e-cuda/src/lib.rs @@ -19,6 +19,7 @@ use std::sync::LazyLock; use arrow_array::Array; use arrow_array::ArrayRef; +use arrow_array::Date32Array; use arrow_array::Decimal128Array; use arrow_array::StringArray; use arrow_array::UInt32Array; @@ -35,11 +36,13 @@ use vortex::array::IntoArray; use vortex::array::arrays::DecimalArray; use vortex::array::arrays::PrimitiveArray; use vortex::array::arrays::StructArray; +use vortex::array::arrays::TemporalArray; use vortex::array::arrays::VarBinViewArray; use vortex::array::session::ArraySession; use vortex::array::validity::Validity; use vortex::dtype::DecimalDType; use vortex::dtype::FieldNames; +use vortex::dtype::datetime::TimeUnit; use vortex::expr::session::ExprSession; use vortex::io::session::RuntimeSession; use vortex::layout::session::LayoutSession; @@ -75,13 +78,18 @@ pub unsafe extern "C" fn export_array( "four", "this string is long five", ]); + let dates = TemporalArray::new_date( + PrimitiveArray::from_iter([100i32, 200, 300, 400, 500]).into_array(), + TimeUnit::Days, + ); let array = StructArray::new( - FieldNames::from_iter(["prims", "decimals", "strings"]), + FieldNames::from_iter(["prims", "decimals", "strings", "dates"]), vec![ primitive.into_array(), decimal.into_array(), strings.into_array(), + dates.into_array(), ], 5, Validity::NonNullable, @@ -92,6 +100,7 @@ pub unsafe extern "C" fn export_array( Field::new("prims", DataType::UInt32, false), Field::new("decimals", DataType::Decimal128(38, 2), false), Field::new("strings", DataType::Utf8, false), + Field::new("dates", DataType::Date32, false), ])); *schema_ptr = FFI_ArrowSchema::try_from(data_type).expect("data_type to FFI_ArrowSchema"); @@ -135,11 +144,13 @@ pub unsafe extern "C" fn validate_array( "four", "this string is long five", ]); + let date = Date32Array::from(vec![100i32, 200, 300, 400, 500]); let expected_fields = Fields::from_iter([ Field::new("prims", primitive.data_type().clone(), false), Field::new("decimals", decimal.data_type().clone(), false), Field::new("strings", string.data_type().clone(), false), + Field::new("dates", date.data_type().clone(), false), ]); assert_eq!( @@ -149,7 +160,12 @@ pub unsafe extern "C" fn validate_array( struct_array.fields() ); - let expected_fields: [ArrayRef; _] = [Arc::new(primitive), Arc::new(decimal), Arc::new(string)]; + let expected_fields: [ArrayRef; _] = [ + Arc::new(primitive), + Arc::new(decimal), + Arc::new(string), + Arc::new(date), + ]; for (expected, actual) in expected_fields.iter().zip(struct_array.columns()) { assert_eq!(expected.as_ref(), actual.as_ref());