Skip to content

Commit a11336d

Browse files
chore[array]: move take kernel to execute/reduce rules (#6310)
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 6fc233d commit a11336d

94 files changed

Lines changed: 1418 additions & 1056 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

encodings/alp/src/alp/compute/take.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,21 @@
33

44
use vortex_array::Array;
55
use vortex_array::ArrayRef;
6+
use vortex_array::ExecutionCtx;
67
use vortex_array::IntoArray;
7-
use vortex_array::compute::TakeKernel;
8-
use vortex_array::compute::TakeKernelAdapter;
9-
use vortex_array::compute::take;
10-
use vortex_array::register_kernel;
8+
use vortex_array::arrays::TakeExecute;
119
use vortex_error::VortexResult;
1210

1311
use crate::ALPArray;
1412
use crate::ALPVTable;
1513

16-
impl TakeKernel for ALPVTable {
17-
fn take(&self, array: &ALPArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
18-
let taken_encoded = take(array.encoded(), indices)?;
14+
impl TakeExecute for ALPVTable {
15+
fn take(
16+
array: &ALPArray,
17+
indices: &dyn Array,
18+
_ctx: &mut ExecutionCtx,
19+
) -> VortexResult<Option<ArrayRef>> {
20+
let taken_encoded = array.encoded().take(indices.to_array())?;
1921
let taken_patches = array
2022
.patches()
2123
.map(|p| p.take(indices))
@@ -29,12 +31,12 @@ impl TakeKernel for ALPVTable {
2931
)
3032
})
3133
.transpose()?;
32-
Ok(ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array())
34+
Ok(Some(
35+
ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array(),
36+
))
3337
}
3438
}
3539

36-
register_kernel!(TakeKernelAdapter(ALPVTable).lift());
37-
3840
#[cfg(test)]
3941
mod test {
4042
use rstest::rstest;

encodings/alp/src/alp/rules.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
use vortex_array::arrays::FilterExecuteAdaptor;
55
use vortex_array::arrays::SliceExecuteAdaptor;
6+
use vortex_array::arrays::TakeExecuteAdaptor;
67
use vortex_array::kernel::ParentKernelSet;
78

89
use crate::ALPVTable;
910

1011
pub(super) const PARENT_KERNELS: ParentKernelSet<ALPVTable> = ParentKernelSet::new(&[
1112
ParentKernelSet::lift(&FilterExecuteAdaptor(ALPVTable)),
1213
ParentKernelSet::lift(&SliceExecuteAdaptor(ALPVTable)),
14+
ParentKernelSet::lift(&TakeExecuteAdaptor(ALPVTable)),
1315
]);

encodings/alp/src/alp_rd/compute/take.rs

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@
33

44
use vortex_array::Array;
55
use vortex_array::ArrayRef;
6+
use vortex_array::ExecutionCtx;
67
use vortex_array::IntoArray;
7-
use vortex_array::compute::TakeKernel;
8-
use vortex_array::compute::TakeKernelAdapter;
8+
use vortex_array::arrays::TakeExecute;
99
use vortex_array::compute::fill_null;
10-
use vortex_array::compute::take;
11-
use vortex_array::register_kernel;
1210
use vortex_error::VortexResult;
1311
use vortex_scalar::Scalar;
1412
use vortex_scalar::ScalarValue;
1513

1614
use crate::ALPRDArray;
1715
use crate::ALPRDVTable;
1816

19-
impl TakeKernel for ALPRDVTable {
20-
fn take(&self, array: &ALPRDArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
21-
let taken_left_parts = take(array.left_parts(), indices)?;
17+
impl TakeExecute for ALPRDVTable {
18+
fn take(
19+
array: &ALPRDArray,
20+
indices: &dyn Array,
21+
_ctx: &mut ExecutionCtx,
22+
) -> VortexResult<Option<ArrayRef>> {
23+
let taken_left_parts = array.left_parts().take(indices.to_array())?;
2224
let left_parts_exceptions = array
2325
.left_parts_patches()
2426
.map(|patches| patches.take(indices))
@@ -33,26 +35,26 @@ impl TakeKernel for ALPRDVTable {
3335
})
3436
.transpose()?;
3537
let right_parts = fill_null(
36-
&take(array.right_parts(), indices)?,
38+
&array.right_parts().take(indices.to_array())?,
3739
&Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
3840
)?;
3941

40-
Ok(ALPRDArray::try_new(
41-
array
42-
.dtype()
43-
.with_nullability(taken_left_parts.dtype().nullability()),
44-
taken_left_parts,
45-
array.left_parts_dictionary().clone(),
46-
right_parts,
47-
array.right_bit_width(),
48-
left_parts_exceptions,
49-
)?
50-
.into_array())
42+
Ok(Some(
43+
ALPRDArray::try_new(
44+
array
45+
.dtype()
46+
.with_nullability(taken_left_parts.dtype().nullability()),
47+
taken_left_parts,
48+
array.left_parts_dictionary().clone(),
49+
right_parts,
50+
array.right_bit_width(),
51+
left_parts_exceptions,
52+
)?
53+
.into_array(),
54+
))
5155
}
5256
}
5357

54-
register_kernel!(TakeKernelAdapter(ALPRDVTable).lift());
55-
5658
#[cfg(test)]
5759
mod test {
5860
use rstest::rstest;

encodings/alp/src/alp_rd/kernel.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
use vortex_array::arrays::FilterExecuteAdaptor;
55
use vortex_array::arrays::SliceExecuteAdaptor;
6+
use vortex_array::arrays::TakeExecuteAdaptor;
67
use vortex_array::kernel::ParentKernelSet;
78

89
use crate::alp_rd::ALPRDVTable;
910

1011
pub(crate) static PARENT_KERNELS: ParentKernelSet<ALPRDVTable> = ParentKernelSet::new(&[
1112
ParentKernelSet::lift(&SliceExecuteAdaptor(ALPRDVTable)),
1213
ParentKernelSet::lift(&FilterExecuteAdaptor(ALPRDVTable)),
14+
ParentKernelSet::lift(&TakeExecuteAdaptor(ALPRDVTable)),
1315
]);

encodings/bytebool/src/array.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ use vortex_error::vortex_panic;
3838
use vortex_scalar::Scalar;
3939
use vortex_session::VortexSession;
4040

41+
use crate::kernel::PARENT_KERNELS;
42+
4143
vtable!(ByteBool);
4244

4345
impl VTable for ByteBoolVTable {
@@ -124,6 +126,15 @@ impl VTable for ByteBoolVTable {
124126
let validity = array.validity().clone();
125127
Ok(BoolArray::new(boolean_buffer, validity).into_array())
126128
}
129+
130+
fn execute_parent(
131+
array: &Self::Array,
132+
parent: &ArrayRef,
133+
child_idx: usize,
134+
ctx: &mut ExecutionCtx,
135+
) -> VortexResult<Option<ArrayRef>> {
136+
PARENT_KERNELS.execute(array, parent, child_idx, ctx)
137+
}
127138
}
128139

129140
#[derive(Clone, Debug)]

encodings/bytebool/src/compute.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
use num_traits::AsPrimitive;
55
use vortex_array::Array;
66
use vortex_array::ArrayRef;
7+
use vortex_array::ExecutionCtx;
78
use vortex_array::IntoArray;
89
use vortex_array::ToCanonical;
10+
use vortex_array::arrays::TakeExecute;
911
use vortex_array::compute::CastKernel;
1012
use vortex_array::compute::CastKernelAdapter;
1113
use vortex_array::compute::MaskKernel;
1214
use vortex_array::compute::MaskKernelAdapter;
13-
use vortex_array::compute::TakeKernel;
14-
use vortex_array::compute::TakeKernelAdapter;
1515
use vortex_array::register_kernel;
1616
use vortex_array::vtable::ValidityHelper;
1717
use vortex_dtype::DType;
@@ -55,8 +55,12 @@ impl MaskKernel for ByteBoolVTable {
5555

5656
register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
5757

58-
impl TakeKernel for ByteBoolVTable {
59-
fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
58+
impl TakeExecute for ByteBoolVTable {
59+
fn take(
60+
array: &ByteBoolArray,
61+
indices: &dyn Array,
62+
_ctx: &mut ExecutionCtx,
63+
) -> VortexResult<Option<ArrayRef>> {
6064
let indices = indices.to_primitive();
6165
let bools = array.as_slice();
6266

@@ -74,12 +78,12 @@ impl TakeKernel for ByteBoolVTable {
7478
.collect::<Vec<bool>>()
7579
});
7680

77-
Ok(ByteBoolArray::from_vec(taken_bools, validity).into_array())
81+
Ok(Some(
82+
ByteBoolArray::from_vec(taken_bools, validity).into_array(),
83+
))
7884
}
7985
}
8086

81-
register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
82-
8387
#[cfg(test)]
8488
mod tests {
8589
use rstest::rstest;

encodings/bytebool/src/kernel.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_array::arrays::TakeExecuteAdaptor;
5+
use vortex_array::kernel::ParentKernelSet;
6+
7+
use crate::ByteBoolVTable;
8+
9+
pub(crate) const PARENT_KERNELS: ParentKernelSet<ByteBoolVTable> =
10+
ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(ByteBoolVTable))]);

encodings/bytebool/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ pub use array::*;
55

66
mod array;
77
mod compute;
8+
mod kernel;
89
mod rules;
910
mod slice;

encodings/datetime-parts/src/array.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ use vortex_error::vortex_err;
3838
use vortex_session::VortexSession;
3939

4040
use crate::canonical::decode_to_temporal;
41+
use crate::compute::kernel::PARENT_KERNELS;
4142
use crate::compute::rules::PARENT_RULES;
4243

4344
vtable!(DateTimeParts);
@@ -168,6 +169,15 @@ impl VTable for DateTimePartsVTable {
168169
) -> VortexResult<Option<ArrayRef>> {
169170
PARENT_RULES.evaluate(array, parent, child_idx)
170171
}
172+
173+
fn execute_parent(
174+
array: &Self::Array,
175+
parent: &ArrayRef,
176+
child_idx: usize,
177+
ctx: &mut ExecutionCtx,
178+
) -> VortexResult<Option<ArrayRef>> {
179+
PARENT_KERNELS.execute(array, parent, child_idx, ctx)
180+
}
171181
}
172182

173183
#[derive(Clone, Debug)]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_array::arrays::TakeExecuteAdaptor;
5+
use vortex_array::kernel::ParentKernelSet;
6+
7+
use crate::DateTimePartsVTable;
8+
9+
pub(crate) const PARENT_KERNELS: ParentKernelSet<DateTimePartsVTable> =
10+
ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(
11+
DateTimePartsVTable,
12+
))]);

0 commit comments

Comments
 (0)