Skip to content

Commit 6b8fb88

Browse files
authored
chore[cuda]: add date32 exporter test (#6345)
last one needed for full TPCH type coverage Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent dcc89b9 commit 6b8fb88

2 files changed

Lines changed: 18 additions & 4 deletions

File tree

vortex-cuda/src/arrow/canonical.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,6 @@ fn export_fixed_size(
253253
Ok((arrow_array, sync_event))
254254
}
255255

256-
// export some nested data
257-
258256
unsafe extern "C" fn release_array(array: *mut ArrowArray) {
259257
// SAFETY: this is only safe if we're dropping an ArrowArray that was created from Rust
260258
// code. This is necessary to ensure that the fields inside the CudaPrivateData

vortex-test/e2e-cuda/src/lib.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use std::sync::LazyLock;
1919

2020
use arrow_array::Array;
2121
use arrow_array::ArrayRef;
22+
use arrow_array::Date32Array;
2223
use arrow_array::Decimal128Array;
2324
use arrow_array::StringArray;
2425
use arrow_array::UInt32Array;
@@ -35,11 +36,13 @@ use vortex::array::IntoArray;
3536
use vortex::array::arrays::DecimalArray;
3637
use vortex::array::arrays::PrimitiveArray;
3738
use vortex::array::arrays::StructArray;
39+
use vortex::array::arrays::TemporalArray;
3840
use vortex::array::arrays::VarBinViewArray;
3941
use vortex::array::session::ArraySession;
4042
use vortex::array::validity::Validity;
4143
use vortex::dtype::DecimalDType;
4244
use vortex::dtype::FieldNames;
45+
use vortex::dtype::datetime::TimeUnit;
4346
use vortex::expr::session::ExprSession;
4447
use vortex::io::session::RuntimeSession;
4548
use vortex::layout::session::LayoutSession;
@@ -75,13 +78,18 @@ pub unsafe extern "C" fn export_array(
7578
"four",
7679
"this string is long five",
7780
]);
81+
let dates = TemporalArray::new_date(
82+
PrimitiveArray::from_iter([100i32, 200, 300, 400, 500]).into_array(),
83+
TimeUnit::Days,
84+
);
7885

7986
let array = StructArray::new(
80-
FieldNames::from_iter(["prims", "decimals", "strings"]),
87+
FieldNames::from_iter(["prims", "decimals", "strings", "dates"]),
8188
vec![
8289
primitive.into_array(),
8390
decimal.into_array(),
8491
strings.into_array(),
92+
dates.into_array(),
8593
],
8694
5,
8795
Validity::NonNullable,
@@ -92,6 +100,7 @@ pub unsafe extern "C" fn export_array(
92100
Field::new("prims", DataType::UInt32, false),
93101
Field::new("decimals", DataType::Decimal128(38, 2), false),
94102
Field::new("strings", DataType::Utf8, false),
103+
Field::new("dates", DataType::Date32, false),
95104
]));
96105

97106
*schema_ptr = FFI_ArrowSchema::try_from(data_type).expect("data_type to FFI_ArrowSchema");
@@ -135,11 +144,13 @@ pub unsafe extern "C" fn validate_array(
135144
"four",
136145
"this string is long five",
137146
]);
147+
let date = Date32Array::from(vec![100i32, 200, 300, 400, 500]);
138148

139149
let expected_fields = Fields::from_iter([
140150
Field::new("prims", primitive.data_type().clone(), false),
141151
Field::new("decimals", decimal.data_type().clone(), false),
142152
Field::new("strings", string.data_type().clone(), false),
153+
Field::new("dates", date.data_type().clone(), false),
143154
]);
144155

145156
assert_eq!(
@@ -149,7 +160,12 @@ pub unsafe extern "C" fn validate_array(
149160
struct_array.fields()
150161
);
151162

152-
let expected_fields: [ArrayRef; _] = [Arc::new(primitive), Arc::new(decimal), Arc::new(string)];
163+
let expected_fields: [ArrayRef; _] = [
164+
Arc::new(primitive),
165+
Arc::new(decimal),
166+
Arc::new(string),
167+
Arc::new(date),
168+
];
153169

154170
for (expected, actual) in expected_fields.iter().zip(struct_array.columns()) {
155171
assert_eq!(expected.as_ref(), actual.as_ref());

0 commit comments

Comments
 (0)