diff --git a/encodings/alp/src/alp/compute/take.rs b/encodings/alp/src/alp/compute/take.rs index 6965f0350a3..9d97c598e28 100644 --- a/encodings/alp/src/alp/compute/take.rs +++ b/encodings/alp/src/alp/compute/take.rs @@ -3,19 +3,21 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; -use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::arrays::TakeExecute; use vortex_error::VortexResult; use crate::ALPArray; use crate::ALPVTable; -impl TakeKernel for ALPVTable { - fn take(&self, array: &ALPArray, indices: &dyn Array) -> VortexResult { - let taken_encoded = take(array.encoded(), indices)?; +impl TakeExecute for ALPVTable { + fn take( + array: &ALPArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let taken_encoded = array.encoded().take(indices.to_array())?; let taken_patches = array .patches() .map(|p| p.take(indices)) @@ -29,12 +31,12 @@ impl TakeKernel for ALPVTable { ) }) .transpose()?; - Ok(ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array()) + Ok(Some( + ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array(), + )) } } -register_kernel!(TakeKernelAdapter(ALPVTable).lift()); - #[cfg(test)] mod test { use rstest::rstest; diff --git a/encodings/alp/src/alp/rules.rs b/encodings/alp/src/alp/rules.rs index 69b88759fbc..eb25c3fc872 100644 --- a/encodings/alp/src/alp/rules.rs +++ b/encodings/alp/src/alp/rules.rs @@ -3,6 +3,7 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::ALPVTable; @@ -10,4 +11,5 @@ use crate::ALPVTable; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&FilterExecuteAdaptor(ALPVTable)), 
ParentKernelSet::lift(&SliceExecuteAdaptor(ALPVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(ALPVTable)), ]); diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index 1b43c0c423f..b2a395c0813 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -3,12 +3,10 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; +use vortex_array::arrays::TakeExecute; use vortex_array::compute::fill_null; -use vortex_array::compute::take; -use vortex_array::register_kernel; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -16,9 +14,13 @@ use vortex_scalar::ScalarValue; use crate::ALPRDArray; use crate::ALPRDVTable; -impl TakeKernel for ALPRDVTable { - fn take(&self, array: &ALPRDArray, indices: &dyn Array) -> VortexResult { - let taken_left_parts = take(array.left_parts(), indices)?; +impl TakeExecute for ALPRDVTable { + fn take( + array: &ALPRDArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let taken_left_parts = array.left_parts().take(indices.to_array())?; let left_parts_exceptions = array .left_parts_patches() .map(|patches| patches.take(indices)) @@ -33,26 +35,26 @@ impl TakeKernel for ALPRDVTable { }) .transpose()?; let right_parts = fill_null( - &take(array.right_parts(), indices)?, + &array.right_parts().take(indices.to_array())?, &Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)), )?; - Ok(ALPRDArray::try_new( - array - .dtype() - .with_nullability(taken_left_parts.dtype().nullability()), - taken_left_parts, - array.left_parts_dictionary().clone(), - right_parts, - array.right_bit_width(), - left_parts_exceptions, - )? 
- .into_array()) + Ok(Some( + ALPRDArray::try_new( + array + .dtype() + .with_nullability(taken_left_parts.dtype().nullability()), + taken_left_parts, + array.left_parts_dictionary().clone(), + right_parts, + array.right_bit_width(), + left_parts_exceptions, + )? + .into_array(), + )) } } -register_kernel!(TakeKernelAdapter(ALPRDVTable).lift()); - #[cfg(test)] mod test { use rstest::rstest; diff --git a/encodings/alp/src/alp_rd/kernel.rs b/encodings/alp/src/alp_rd/kernel.rs index ad9dd1c6a56..c2f37e5495e 100644 --- a/encodings/alp/src/alp_rd/kernel.rs +++ b/encodings/alp/src/alp_rd/kernel.rs @@ -3,6 +3,7 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::alp_rd::ALPRDVTable; @@ -10,4 +11,5 @@ use crate::alp_rd::ALPRDVTable; pub(crate) static PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&SliceExecuteAdaptor(ALPRDVTable)), ParentKernelSet::lift(&FilterExecuteAdaptor(ALPRDVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(ALPRDVTable)), ]); diff --git a/encodings/bytebool/src/array.rs b/encodings/bytebool/src/array.rs index 8462e7513aa..214227c7ec8 100644 --- a/encodings/bytebool/src/array.rs +++ b/encodings/bytebool/src/array.rs @@ -38,6 +38,8 @@ use vortex_error::vortex_panic; use vortex_scalar::Scalar; use vortex_session::VortexSession; +use crate::kernel::PARENT_KERNELS; + vtable!(ByteBool); impl VTable for ByteBoolVTable { @@ -124,6 +126,15 @@ impl VTable for ByteBoolVTable { let validity = array.validity().clone(); Ok(BoolArray::new(boolean_buffer, validity).into_array()) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Clone, Debug)] diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index 
4bda0c343ef..ab9624db01a 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -4,14 +4,14 @@ use num_traits::AsPrimitive; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; +use vortex_array::arrays::TakeExecute; use vortex_array::compute::CastKernel; use vortex_array::compute::CastKernelAdapter; use vortex_array::compute::MaskKernel; use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; use vortex_array::register_kernel; use vortex_array::vtable::ValidityHelper; use vortex_dtype::DType; @@ -55,8 +55,12 @@ impl MaskKernel for ByteBoolVTable { register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift()); -impl TakeKernel for ByteBoolVTable { - fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult { +impl TakeExecute for ByteBoolVTable { + fn take( + array: &ByteBoolArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { let indices = indices.to_primitive(); let bools = array.as_slice(); @@ -74,12 +78,12 @@ impl TakeKernel for ByteBoolVTable { .collect::>() }); - Ok(ByteBoolArray::from_vec(taken_bools, validity).into_array()) + Ok(Some( + ByteBoolArray::from_vec(taken_bools, validity).into_array(), + )) } } -register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift()); - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/encodings/bytebool/src/kernel.rs b/encodings/bytebool/src/kernel.rs new file mode 100644 index 00000000000..91eb3add3a6 --- /dev/null +++ b/encodings/bytebool/src/kernel.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; + +use crate::ByteBoolVTable; + +pub(crate) const PARENT_KERNELS: ParentKernelSet = + 
ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(ByteBoolVTable))]); diff --git a/encodings/bytebool/src/lib.rs b/encodings/bytebool/src/lib.rs index c195e618317..35579a89dc8 100644 --- a/encodings/bytebool/src/lib.rs +++ b/encodings/bytebool/src/lib.rs @@ -5,5 +5,6 @@ pub use array::*; mod array; mod compute; +mod kernel; mod rules; mod slice; diff --git a/encodings/datetime-parts/src/array.rs b/encodings/datetime-parts/src/array.rs index 86f640d4b83..986c452674e 100644 --- a/encodings/datetime-parts/src/array.rs +++ b/encodings/datetime-parts/src/array.rs @@ -38,6 +38,7 @@ use vortex_error::vortex_err; use vortex_session::VortexSession; use crate::canonical::decode_to_temporal; +use crate::compute::kernel::PARENT_KERNELS; use crate::compute::rules::PARENT_RULES; vtable!(DateTimeParts); @@ -168,6 +169,15 @@ impl VTable for DateTimePartsVTable { ) -> VortexResult> { PARENT_RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Clone, Debug)] diff --git a/encodings/datetime-parts/src/compute/kernel.rs b/encodings/datetime-parts/src/compute/kernel.rs new file mode 100644 index 00000000000..9c95c3439ca --- /dev/null +++ b/encodings/datetime-parts/src/compute/kernel.rs @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; + +use crate::DateTimePartsVTable; + +pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor( + DateTimePartsVTable, + ))]); diff --git a/encodings/datetime-parts/src/compute/mod.rs b/encodings/datetime-parts/src/compute/mod.rs index caf7df8ffbf..d606daccb59 100644 --- a/encodings/datetime-parts/src/compute/mod.rs +++ 
b/encodings/datetime-parts/src/compute/mod.rs @@ -5,6 +5,7 @@ mod cast; mod compare; mod filter; mod is_constant; +pub(crate) mod kernel; mod mask; pub(super) mod rules; mod slice; diff --git a/encodings/datetime-parts/src/compute/take.rs b/encodings/datetime-parts/src/compute/take.rs index 23ccc46af61..8014773c33a 100644 --- a/encodings/datetime-parts/src/compute/take.rs +++ b/encodings/datetime-parts/src/compute/take.rs @@ -3,15 +3,13 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; +use vortex_array::arrays::TakeExecute; use vortex_array::compute::fill_null; -use vortex_array::compute::take; use vortex_array::expr::stats::Stat; use vortex_array::expr::stats::StatsProvider; -use vortex_array::register_kernel; use vortex_dtype::Nullability; use vortex_error::VortexResult; use vortex_error::vortex_panic; @@ -20,81 +18,83 @@ use vortex_scalar::Scalar; use crate::DateTimePartsArray; use crate::DateTimePartsVTable; -impl TakeKernel for DateTimePartsVTable { - fn take(&self, array: &DateTimePartsArray, indices: &dyn Array) -> VortexResult { - // we go ahead and canonicalize here to avoid worst-case canonicalizing 3 separate times - let indices = indices.to_primitive(); +fn take_datetime_parts(array: &DateTimePartsArray, indices: &dyn Array) -> VortexResult { + // we go ahead and canonicalize here to avoid worst-case canonicalizing 3 separate times + let indices = indices.to_primitive(); - let taken_days = take(array.days(), indices.as_ref())?; - let taken_seconds = take(array.seconds(), indices.as_ref())?; - let taken_subseconds = take(array.subseconds(), indices.as_ref())?; + let taken_days = array.days().take(indices.to_array())?; + let taken_seconds = array.seconds().take(indices.to_array())?; + let taken_subseconds = array.subseconds().take(indices.to_array())?; - // Update the 
dtype if the nullability changed due to nullable indices - let dtype = if taken_days.dtype().is_nullable() != array.dtype().is_nullable() { - array - .dtype() - .with_nullability(taken_days.dtype().nullability()) - } else { - array.dtype().clone() - }; + // Update the dtype if the nullability changed due to nullable indices + let dtype = if taken_days.dtype().is_nullable() != array.dtype().is_nullable() { + array + .dtype() + .with_nullability(taken_days.dtype().nullability()) + } else { + array.dtype().clone() + }; - if !taken_seconds.dtype().is_nullable() && !taken_subseconds.dtype().is_nullable() { - return Ok(DateTimePartsArray::try_new( - dtype, - taken_days, - taken_seconds, - taken_subseconds, - )? - .into_array()); - } + if !taken_seconds.dtype().is_nullable() && !taken_subseconds.dtype().is_nullable() { + return Ok(DateTimePartsArray::try_new( + dtype, + taken_days, + taken_seconds, + taken_subseconds, + )? + .into_array()); + } + + // DateTimePartsArray requires seconds and subseconds to be non-nullable. + // If they became nullable due to nullable indices, we need to fill nulls. + // But first, we need to check that the types are consistent. + if !taken_days.dtype().is_nullable() { + vortex_panic!("Mismatched types: days is not nullable, seconds is nullable"); + } + if !taken_seconds.dtype().is_nullable() { + vortex_panic!("Mismatched types: seconds is not nullable, days is nullable"); + } + if !taken_subseconds.dtype().is_nullable() { + vortex_panic!("Mismatched types: subseconds is not nullable, days & seconds are nullable"); + } + if !indices.dtype().is_nullable() { + vortex_panic!("Mismatched types: indices are not nullable, days & seconds are nullable"); + } - // DateTimePartsArray requires seconds and subseconds to be non-nullable. - // If they became nullable due to nullable indices, we need to fill nulls. - // But first, we need to check that the types are consistent. 
- if !taken_days.dtype().is_nullable() { - vortex_panic!("Mismatched types: days is not nullable, seconds is nullable"); - } - if !taken_seconds.dtype().is_nullable() { - vortex_panic!("Mismatched types: seconds is not nullable, days is nullable"); - } - if !taken_subseconds.dtype().is_nullable() { - vortex_panic!( - "Mismatched types: subseconds is not nullable, days & seconds are nullable" - ); - } - if !indices.dtype().is_nullable() { - vortex_panic!( - "Mismatched types: indices are not nullable, days & seconds are nullable" - ); - } + let seconds_fill = array + .seconds() + .statistics() + .get(Stat::Min) + .map(|s| s.into_inner()) + .unwrap_or_else(|| Scalar::primitive(0i64, Nullability::NonNullable)) + .cast(array.seconds().dtype())?; + let taken_seconds = fill_null(taken_seconds.as_ref(), &seconds_fill)?; - let seconds_fill = array - .seconds() - .statistics() - .get(Stat::Min) - .map(|s| s.into_inner()) - .unwrap_or_else(|| Scalar::primitive(0i64, Nullability::NonNullable)) - .cast(array.seconds().dtype())?; - let taken_seconds = fill_null(taken_seconds.as_ref(), &seconds_fill)?; + let subseconds_fill = array + .subseconds() + .statistics() + .get(Stat::Min) + .map(|s| s.into_inner()) + .unwrap_or_else(|| Scalar::primitive(0i64, Nullability::NonNullable)) + .cast(array.subseconds().dtype())?; + let taken_subseconds = fill_null(taken_subseconds.as_ref(), &subseconds_fill)?; - let subseconds_fill = array - .subseconds() - .statistics() - .get(Stat::Min) - .map(|s| s.into_inner()) - .unwrap_or_else(|| Scalar::primitive(0i64, Nullability::NonNullable)) - .cast(array.subseconds().dtype())?; - let taken_subseconds = fill_null(taken_subseconds.as_ref(), &subseconds_fill)?; + Ok( + DateTimePartsArray::try_new(dtype, taken_days, taken_seconds, taken_subseconds)? + .into_array(), + ) +} - Ok( - DateTimePartsArray::try_new(dtype, taken_days, taken_seconds, taken_subseconds)? 
- .into_array(), - ) +impl TakeExecute for DateTimePartsVTable { + fn take( + array: &DateTimePartsArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_datetime_parts(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(DateTimePartsVTable).lift()); - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs new file mode 100644 index 00000000000..a802fad5db1 --- /dev/null +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; + +use crate::DecimalBytePartsVTable; + +pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor( + DecimalBytePartsVTable, + ))]); diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs index 262c6cd5b72..2e798106a7e 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs @@ -5,6 +5,7 @@ mod cast; mod compare; mod filter; mod is_constant; +pub(crate) mod kernel; mod mask; mod take; diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs index 0368d3c1f7f..ba5a04c7d3b 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs @@ -3,20 +3,20 @@ use vortex_array::Array; use vortex_array::ArrayRef; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; -use 
vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::ExecutionCtx; +use vortex_array::arrays::TakeExecute; use vortex_error::VortexResult; use crate::DecimalBytePartsArray; use crate::DecimalBytePartsVTable; -impl TakeKernel for DecimalBytePartsVTable { - fn take(&self, array: &DecimalBytePartsArray, indices: &dyn Array) -> VortexResult { - DecimalBytePartsArray::try_new(take(&array.msp, indices)?, *array.decimal_dtype()) - .map(|a| a.to_array()) +impl TakeExecute for DecimalBytePartsVTable { + fn take( + array: &DecimalBytePartsArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + DecimalBytePartsArray::try_new(array.msp.take(indices.to_array())?, *array.decimal_dtype()) + .map(|a| Some(a.to_array())) } } - -register_kernel!(TakeKernelAdapter(DecimalBytePartsVTable).lift()); diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs index 151c38a1d59..0e4de81842d 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs @@ -46,6 +46,7 @@ use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; use vortex_session::VortexSession; +use crate::decimal_byte_parts::compute::kernel::PARENT_KERNELS; use crate::decimal_byte_parts::rules::PARENT_RULES; vtable!(DecimalByteParts); @@ -136,6 +137,15 @@ impl VTable for DecimalBytePartsVTable { fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { to_canonical_decimal(array, ctx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } /// This array encodes decimals as between 1-4 columns of primitive typed children. 
diff --git a/encodings/fastlanes/src/bitpacking/compute/filter.rs b/encodings/fastlanes/src/bitpacking/compute/filter.rs index 8b6d05f53f9..e8d1385e4d8 100644 --- a/encodings/fastlanes/src/bitpacking/compute/filter.rs +++ b/encodings/fastlanes/src/bitpacking/compute/filter.rs @@ -28,7 +28,7 @@ use crate::BitPackedVTable; /// The threshold over which it is faster to fully unpack the entire [`BitPackedArray`] and then /// filter the result than to unpack only specific bitpacked values into the output buffer. -const fn unpack_then_filter_threshold(ptype: PType) -> f64 { +pub const fn unpack_then_filter_threshold(ptype: PType) -> f64 { // TODO(connor): Where did these numbers come from? Add a public link after validating them. // These numbers probably don't work for in-place filtering either. match ptype.byte_width() { diff --git a/encodings/fastlanes/src/bitpacking/compute/take.rs b/encodings/fastlanes/src/bitpacking/compute/take.rs index 65dd4032cd6..6a5b33495dc 100644 --- a/encodings/fastlanes/src/bitpacking/compute/take.rs +++ b/encodings/fastlanes/src/bitpacking/compute/take.rs @@ -7,13 +7,11 @@ use std::mem::MaybeUninit; use fastlanes::BitPacking; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; -use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::arrays::TakeExecute; use vortex_array::validity::Validity; use vortex_array::vtable::ValidityHelper; use vortex_buffer::Buffer; @@ -37,11 +35,15 @@ use crate::bitpack_decompress; /// see https://github.com/vortex-data/vortex/pull/190#issue-2223752833 pub(super) const UNPACK_CHUNK_THRESHOLD: usize = 8; -impl TakeKernel for BitPackedVTable { - fn take(&self, array: &BitPackedArray, indices: &dyn Array) -> VortexResult { +impl TakeExecute for BitPackedVTable { + fn 
take( + array: &BitPackedArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { // If the indices are large enough, it's faster to flatten and take the primitive array. if indices.len() * UNPACK_CHUNK_THRESHOLD > array.len() { - return take(array.to_primitive().as_ref(), indices); + return array.to_primitive().take(indices.to_array()).map(Some); } // NOTE: we use the unsigned PType because all values in the BitPackedArray must @@ -56,12 +58,10 @@ impl TakeKernel for BitPackedVTable { take_primitive::(array, &indices, taken_validity)? }) }); - Ok(taken.reinterpret_cast(ptype).into_array()) + Ok(Some(taken.reinterpret_cast(ptype).into_array())) } } -register_kernel!(TakeKernelAdapter(BitPackedVTable).lift()); - fn take_primitive( array: &BitPackedArray, indices: &PrimitiveArray, diff --git a/encodings/fastlanes/src/bitpacking/vtable/kernels.rs b/encodings/fastlanes/src/bitpacking/vtable/kernels.rs index ffbba0c7bbb..ca9b98c09db 100644 --- a/encodings/fastlanes/src/bitpacking/vtable/kernels.rs +++ b/encodings/fastlanes/src/bitpacking/vtable/kernels.rs @@ -3,6 +3,7 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::BitPackedVTable; @@ -10,4 +11,5 @@ use crate::BitPackedVTable; pub(crate) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&FilterExecuteAdaptor(BitPackedVTable)), ParentKernelSet::lift(&SliceExecuteAdaptor(BitPackedVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(BitPackedVTable)), ]); diff --git a/encodings/fastlanes/src/for/compute/mod.rs b/encodings/fastlanes/src/for/compute/mod.rs index f7bbe4a6c64..a8efc731793 100644 --- a/encodings/fastlanes/src/for/compute/mod.rs +++ b/encodings/fastlanes/src/for/compute/mod.rs @@ -8,30 +8,32 @@ mod is_sorted; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use 
vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; -use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::arrays::TakeExecute; use vortex_error::VortexResult; use vortex_mask::Mask; use crate::FoRArray; use crate::FoRVTable; -impl TakeKernel for FoRVTable { - fn take(&self, array: &FoRArray, indices: &dyn Array) -> VortexResult { - FoRArray::try_new( - take(array.encoded(), indices)?, - array.reference_scalar().clone(), - ) - .map(|a| a.into_array()) +impl TakeExecute for FoRVTable { + fn take( + array: &FoRArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + Ok(Some( + FoRArray::try_new( + array.encoded().take(indices.to_array())?, + array.reference_scalar().clone(), + )? + .into_array(), + )) } } -register_kernel!(TakeKernelAdapter(FoRVTable).lift()); - impl FilterReduce for FoRVTable { fn filter(array: &FoRArray, mask: &Mask) -> VortexResult> { FoRArray::try_new( diff --git a/encodings/fastlanes/src/for/vtable/kernels.rs b/encodings/fastlanes/src/for/vtable/kernels.rs new file mode 100644 index 00000000000..60a009afe15 --- /dev/null +++ b/encodings/fastlanes/src/for/vtable/kernels.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; + +use crate::FoRVTable; + +pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(FoRVTable))]); diff --git a/encodings/fastlanes/src/for/vtable/mod.rs b/encodings/fastlanes/src/for/vtable/mod.rs index 9e5923ea2fc..8808b34589f 100644 --- a/encodings/fastlanes/src/for/vtable/mod.rs +++ b/encodings/fastlanes/src/for/vtable/mod.rs @@ -25,9 +25,11 @@ use vortex_session::VortexSession; use crate::FoRArray; use 
crate::r#for::array::for_decompress::decompress; +use crate::r#for::vtable::kernels::PARENT_KERNELS; use crate::r#for::vtable::rules::PARENT_RULES; mod array; +mod kernels; mod operations; mod rules; mod slice; @@ -115,6 +117,15 @@ impl VTable for FoRVTable { fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { Ok(decompress(array, ctx)?.into_array()) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Debug)] diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 4dde1cc19c0..b2657a73bae 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -7,47 +7,51 @@ mod filter; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; +use vortex_array::arrays::TakeExecute; use vortex_array::arrays::VarBinVTable; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; use vortex_array::compute::fill_null; -use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_err; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; use crate::FSSTArray; use crate::FSSTVTable; -impl TakeKernel for FSSTVTable { - // Take on an FSSTArray is a simple take on the codes array. - fn take(&self, array: &FSSTArray, indices: &dyn Array) -> VortexResult { - Ok(FSSTArray::try_new( - array - .dtype() - .clone() - .union_nullability(indices.dtype().nullability()), - array.symbols().clone(), - array.symbol_lengths().clone(), - take(array.codes().as_ref(), indices)? - .as_::() - .clone(), - fill_null( - &take(array.uncompressed_lengths(), indices)?, - &Scalar::new( - array.uncompressed_lengths_dtype().clone(), - ScalarValue::from(0), - ), - )?, - )? 
- .into_array()) +impl TakeExecute for FSSTVTable { + fn take( + array: &FSSTArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + Ok(Some( + FSSTArray::try_new( + array + .dtype() + .clone() + .union_nullability(indices.dtype().nullability()), + array.symbols().clone(), + array.symbol_lengths().clone(), + VarBinVTable::take(array.codes(), indices, _ctx)? + .vortex_expect("cannot fail") + .try_into::() + .map_err(|_| vortex_err!("take for codes must return varbin array"))?, + fill_null( + &array.uncompressed_lengths().take(indices.to_array())?, + &Scalar::new( + array.uncompressed_lengths_dtype().clone(), + ScalarValue::from(0), + ), + )?, + )? + .into_array(), + )) } } -register_kernel!(TakeKernelAdapter(FSSTVTable).lift()); - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/encodings/fsst/src/kernel.rs b/encodings/fsst/src/kernel.rs index 710f0dd5ba2..e3e11a1ed5e 100644 --- a/encodings/fsst/src/kernel.rs +++ b/encodings/fsst/src/kernel.rs @@ -2,12 +2,15 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::arrays::FilterExecuteAdaptor; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::FSSTVTable; -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&FilterExecuteAdaptor(FSSTVTable))]); +pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&FilterExecuteAdaptor(FSSTVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(FSSTVTable)), +]); #[cfg(test)] mod tests { diff --git a/encodings/fsst/src/tests.rs b/encodings/fsst/src/tests.rs index 9163d96250c..d57ccd79e8e 100644 --- a/encodings/fsst/src/tests.rs +++ b/encodings/fsst/src/tests.rs @@ -8,7 +8,6 @@ use vortex_array::ToCanonical; use vortex_array::arrays::builder::VarBinBuilder; use vortex_array::assert_arrays_eq; use vortex_array::assert_nth_scalar; -use vortex_array::compute::take; use 
vortex_buffer::buffer; use vortex_dtype::DType; use vortex_dtype::Nullability; @@ -69,8 +68,7 @@ fn test_fsst_array_ops() { // test take let indices = buffer![0, 2].into_array(); - let fsst_taken = take(&fsst_array, &indices).unwrap(); - assert!(fsst_taken.is::()); + let fsst_taken = fsst_array.take(indices).unwrap(); assert_eq!(fsst_taken.len(), 2); assert_nth_scalar!( fsst_taken, diff --git a/encodings/runend/src/compute/mod.rs b/encodings/runend/src/compute/mod.rs index 4827436dc39..a4c80d0a5ae 100644 --- a/encodings/runend/src/compute/mod.rs +++ b/encodings/runend/src/compute/mod.rs @@ -11,7 +11,7 @@ mod is_constant; mod is_sorted; mod min_max; pub(crate) mod take; -mod take_from; +pub(crate) mod take_from; #[cfg(test)] mod tests { diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index 2b6a2915bb3..009dbd86783 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -5,12 +5,10 @@ use num_traits::AsPrimitive; use num_traits::NumCast; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; -use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::arrays::TakeExecute; use vortex_array::search_sorted::SearchResult; use vortex_array::search_sorted::SearchSorted; use vortex_array::search_sorted::SearchSortedSide; @@ -24,12 +22,16 @@ use vortex_error::vortex_bail; use crate::RunEndArray; use crate::RunEndVTable; -impl TakeKernel for RunEndVTable { +impl TakeExecute for RunEndVTable { #[expect( clippy::cast_possible_truncation, reason = "index cast to usize inside macro" )] - fn take(&self, array: &RunEndArray, indices: &dyn Array) -> VortexResult { + fn take( + array: &RunEndArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { let 
primitive_indices = indices.to_primitive(); let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| { @@ -47,12 +49,10 @@ impl TakeKernel for RunEndVTable { .collect::>>()? }); - take_indices_unchecked(array, &checked_indices, primitive_indices.validity()) + take_indices_unchecked(array, &checked_indices, primitive_indices.validity()).map(Some) } } -register_kernel!(TakeKernelAdapter(RunEndVTable).lift()); - /// Perform a take operation on a RunEndArray by binary searching for each of the indices. pub fn take_indices_unchecked>( array: &RunEndArray, @@ -84,7 +84,7 @@ pub fn take_indices_unchecked>( PrimitiveArray::new(buffer, validity.clone()) }); - take(array.values(), physical_indices.as_ref()) + array.values().take(physical_indices.to_array()) } #[cfg(test)] diff --git a/encodings/runend/src/compute/take_from.rs b/encodings/runend/src/compute/take_from.rs index 8bda288c3a2..dbf7ed93465 100644 --- a/encodings/runend/src/compute/take_from.rs +++ b/encodings/runend/src/compute/take_from.rs @@ -1,59 +1,174 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::compute::TakeFromKernel; -use vortex_array::compute::TakeFromKernelAdapter; -use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::arrays::DictArray; +use vortex_array::arrays::DictVTable; +use vortex_array::kernel::ExecuteParentKernel; use vortex_dtype::DType; use vortex_error::VortexResult; use crate::RunEndArray; use crate::RunEndVTable; -impl TakeFromKernel for RunEndVTable { - /// Takes values from the source array using run-end encoded indices. 
- /// - /// # Arguments - /// - /// * `indices` - Run-end encoded indices - /// * `source` - Array to take values from - /// - /// # Returns - /// - /// * `Ok(Some(source))` - If successful - /// * `Ok(None)` - If the source array has an unsupported dtype - /// - fn take_from( +#[derive(Debug)] +pub(crate) struct RunEndVTableTakeFrom; + +impl ExecuteParentKernel for RunEndVTableTakeFrom { + type Parent = DictVTable; + + fn execute_parent( &self, - indices: &RunEndArray, - source: &dyn Array, + array: &RunEndArray, + dict: &DictArray, + child_idx: usize, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { - // Only `Primitive` and `Bool` are valid run-end value types. - TODO: Support additional DTypes - if !matches!(source.dtype(), DType::Primitive(_, _) | DType::Bool(_)) { + if child_idx != 0 { + return Ok(None); + } + // Only `Primitive` and `Bool` are valid run-end value types. + // TODO: Support additional DTypes + if !matches!(dict.dtype(), DType::Primitive(_, _) | DType::Bool(_)) { return Ok(None); } - - // Transform the run-end encoding from storing indices to storing values - // by taking values from `source` at positions specified by `indices.values()`. - let values = take(source, indices.values())?; // Create a new run-end array containing values as values, instead of indices as values. 
// SAFETY: we are copying ends from an existing valid RunEndArray let ree_array = unsafe { RunEndArray::new_unchecked( - indices.ends().clone(), - values, - indices.offset(), - indices.len(), + array.ends().clone(), + dict.values().take(array.values().clone())?, + array.offset(), + array.len(), ) }; - + // Ok(Some(ree_array.into_array())) } } -register_kernel!(TakeFromKernelAdapter(RunEndVTable).lift()); +#[cfg(test)] +mod tests { + use vortex_array::Array; + use vortex_array::ExecutionCtx; + use vortex_array::IntoArray; + use vortex_array::arrays::DictArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::assert_arrays_eq; + use vortex_array::kernel::ExecuteParentKernel; + use vortex_buffer::buffer; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + + use crate::RunEndArray; + use crate::compute::take_from::RunEndVTableTakeFrom; + + /// Build a DictArray whose codes are run-end encoded. + /// + /// Input: `[2, 2, 2, 3, 3, 2, 2]` + /// Dict values: `[2, 3]` + /// Codes: `[0, 0, 0, 1, 1, 0, 0]` + /// RunEnd encoded codes: ends=`[3, 5, 7]`, values=`[0, 1, 0]` + fn make_dict_with_runend_codes() -> (RunEndArray, DictArray) { + let codes = RunEndArray::encode(buffer![0u32, 0, 0, 1, 1, 0, 0].into_array()).unwrap(); + let values = buffer![2i32, 3].into_array(); + let dict = DictArray::try_new(codes.clone().into_array(), values).unwrap(); + (codes, dict) + } + + #[test] + fn test_execute_parent_no_offset() -> VortexResult<()> { + let (codes, dict) = make_dict_with_runend_codes(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = RunEndVTableTakeFrom + .execute_parent(&codes, &dict, 0, &mut ctx)? 
+ .expect("kernel should return Some"); + + let expected = PrimitiveArray::from_iter([2i32, 2, 2, 3, 3, 2, 2]); + assert_arrays_eq!(result.to_canonical()?.into_array(), expected); + Ok(()) + } + + #[test] + fn test_execute_parent_with_offset() -> VortexResult<()> { + let (codes, dict) = make_dict_with_runend_codes(); + // Slice codes to positions 2..5 → logical codes [0, 1, 1] → values [2, 3, 3] + let sliced_codes = unsafe { + RunEndArray::new_unchecked( + codes.ends().clone(), + codes.values().clone(), + 2, // offset + 3, // len + ) + }; + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = RunEndVTableTakeFrom + .execute_parent(&sliced_codes, &dict, 0, &mut ctx)? + .expect("kernel should return Some"); + + let expected = PrimitiveArray::from_iter([2i32, 3, 3]); + assert_arrays_eq!(result.to_canonical()?.into_array(), expected); + Ok(()) + } + + #[test] + fn test_execute_parent_offset_at_run_boundary() -> VortexResult<()> { + let (codes, dict) = make_dict_with_runend_codes(); + // Slice codes to positions 3..7 → logical codes [1, 1, 0, 0] → values [3, 3, 2, 2] + let sliced_codes = unsafe { + RunEndArray::new_unchecked( + codes.ends().clone(), + codes.values().clone(), + 3, // offset at exact run boundary + 4, // len + ) + }; + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = RunEndVTableTakeFrom + .execute_parent(&sliced_codes, &dict, 0, &mut ctx)? 
+ .expect("kernel should return Some"); + + let expected = PrimitiveArray::from_iter([3i32, 3, 2, 2]); + assert_arrays_eq!(result.to_canonical()?.into_array(), expected); + Ok(()) + } + + #[test] + fn test_execute_parent_single_element_offset() -> VortexResult<()> { + let (codes, dict) = make_dict_with_runend_codes(); + // Slice to single element at position 4 → code=1 → value=3 + let sliced_codes = unsafe { + RunEndArray::new_unchecked( + codes.ends().slice(1..3)?, + codes.values().slice(1..3)?, + 4, // offset + 1, // len + ) + }; + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = RunEndVTableTakeFrom + .execute_parent(&sliced_codes, &dict, 0, &mut ctx)? + .expect("kernel should return Some"); + + let expected = PrimitiveArray::from_iter([3i32]); + assert_arrays_eq!(result.to_canonical()?.into_array(), expected); + Ok(()) + } + + #[test] + fn test_execute_parent_returns_none_for_non_codes_child() -> VortexResult<()> { + let (codes, dict) = make_dict_with_runend_codes(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = RunEndVTableTakeFrom.execute_parent(&codes, &dict, 1, &mut ctx)?; + assert!(result.is_none()); + Ok(()) + } +} diff --git a/encodings/runend/src/kernel.rs b/encodings/runend/src/kernel.rs index 3c4e41ab296..4873d9e15b1 100644 --- a/encodings/runend/src/kernel.rs +++ b/encodings/runend/src/kernel.rs @@ -10,16 +10,20 @@ use vortex_array::arrays::ConstantArray; use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceArray; use vortex_array::arrays::SliceVTable; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ExecuteParentKernel; use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use crate::RunEndArray; use crate::RunEndVTable; +use crate::compute::take_from::RunEndVTableTakeFrom; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&RunEndSliceKernel), 
ParentKernelSet::lift(&FilterExecuteAdaptor(RunEndVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(RunEndVTable)), + ParentKernelSet::lift(&RunEndVTableTakeFrom), ]); /// Kernel to execute slicing on a RunEnd array. diff --git a/encodings/sequence/src/compute/take.rs b/encodings/sequence/src/compute/take.rs index eec5191d301..448c900c68c 100644 --- a/encodings/sequence/src/compute/take.rs +++ b/encodings/sequence/src/compute/take.rs @@ -4,13 +4,12 @@ use num_traits::cast::NumCast; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; -use vortex_array::register_kernel; +use vortex_array::arrays::TakeExecute; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_dtype::DType; @@ -29,31 +28,7 @@ use vortex_scalar::Scalar; use crate::SequenceArray; use crate::SequenceVTable; -impl TakeKernel for SequenceVTable { - fn take(&self, array: &SequenceArray, indices: &dyn Array) -> VortexResult { - let mask = indices.validity_mask()?; - let indices = indices.to_primitive(); - let result_nullability = array.dtype().nullability() | indices.dtype().nullability(); - - match_each_integer_ptype!(indices.ptype(), |T| { - let indices = indices.as_slice::(); - match_each_native_ptype!(array.ptype(), |S| { - let mul = array.multiplier().cast::(); - let base = array.base().cast::(); - Ok(take( - mul, - base, - indices, - mask, - result_nullability, - array.len(), - )) - }) - }) - } -} - -fn take( +fn take_inner( mul: S, base: S, indices: &[T], @@ -98,7 +73,33 @@ fn take( } } -register_kernel!(TakeKernelAdapter(SequenceVTable).lift()); +impl TakeExecute for SequenceVTable { + fn take( + array: &SequenceArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let mask = 
indices.validity_mask()?; + let indices = indices.to_primitive(); + let result_nullability = array.dtype().nullability() | indices.dtype().nullability(); + + match_each_integer_ptype!(indices.ptype(), |T| { + let indices = indices.as_slice::(); + match_each_native_ptype!(array.ptype(), |S| { + let mul = array.multiplier().cast::(); + let base = array.base().cast::(); + Ok(Some(take_inner( + mul, + base, + indices, + mask, + result_nullability, + array.len(), + ))) + }) + }) + } +} #[cfg(test)] mod test { @@ -163,7 +164,7 @@ mod test { } #[test] - #[should_panic(expected = "index 20 out of bounds")] + #[should_panic(expected = "out of bounds")] fn test_bounds_check() { let array = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap(); let indices = vortex_array::arrays::PrimitiveArray::from_iter([0i32, 20]); diff --git a/encodings/sequence/src/kernel.rs b/encodings/sequence/src/kernel.rs index 3ce712d137f..982f33cf51b 100644 --- a/encodings/sequence/src/kernel.rs +++ b/encodings/sequence/src/kernel.rs @@ -11,6 +11,7 @@ use vortex_array::arrays::ExactScalarFn; use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::ScalarFnArrayView; use vortex_array::arrays::ScalarFnVTable; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::Operator; use vortex_array::expr::Binary; use vortex_array::kernel::ExecuteParentKernel; @@ -34,6 +35,7 @@ use crate::compute::compare::find_intersection_scalar; pub(crate) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&SequenceCompareKernel), ParentKernelSet::lift(&FilterExecuteAdaptor(SequenceVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(SequenceVTable)), ]); /// Kernel to execute comparison operations directly on a sequence array. 
diff --git a/encodings/sparse/src/compute/take.rs b/encodings/sparse/src/compute/take.rs index 41e4c5f0eed..3ff7d74b0df 100644 --- a/encodings/sparse/src/compute/take.rs +++ b/encodings/sparse/src/compute/take.rs @@ -3,59 +3,63 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; -use vortex_array::register_kernel; +use vortex_array::arrays::TakeExecute; use vortex_error::VortexResult; use crate::SparseArray; use crate::SparseVTable; -impl TakeKernel for SparseVTable { - fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult { +impl TakeExecute for SparseVTable { + fn take( + array: &SparseArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { let patches_take = if array.fill_scalar().is_null() { - array.patches().take(take_indices)? + array.patches().take(indices)? } else { - array.patches().take_with_nulls(take_indices)? + array.patches().take_with_nulls(indices)? }; let Some(new_patches) = patches_take else { let result_fill_scalar = array.fill_scalar().cast( &array .dtype() - .union_nullability(take_indices.dtype().nullability()), + .union_nullability(indices.dtype().nullability()), )?; - return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array()); + return Ok(Some( + ConstantArray::new(result_fill_scalar, indices.len()).into_array(), + )); }; // See `SparseEncoding::slice`. if new_patches.array_len() == new_patches.values().len() { - return Ok(new_patches.into_values()); + return Ok(Some(new_patches.into_values())); } - Ok(SparseArray::try_new_from_patches( - new_patches, - array.fill_scalar().cast( - &array - .dtype() - .union_nullability(take_indices.dtype().nullability()), - )?, - )? 
- .into_array()) + Ok(Some( + SparseArray::try_new_from_patches( + new_patches, + array.fill_scalar().cast( + &array + .dtype() + .union_nullability(indices.dtype().nullability()), + )?, + )? + .into_array(), + )) } } -register_kernel!(TakeKernelAdapter(SparseVTable).lift()); - #[cfg(test)] mod test { use rstest::rstest; use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::IntoArray; - use vortex_array::ToCanonical; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; @@ -66,7 +70,6 @@ mod test { use vortex_scalar::Scalar; use crate::SparseArray; - use crate::SparseVTable; fn test_array_fill_value() -> Scalar { // making this const is annoying @@ -118,18 +121,11 @@ mod test { #[test] fn ordered_take() { let sparse = sparse_array(); - let taken_arr = take(&sparse, &buffer![69, 37].into_array()).unwrap(); - let taken = taken_arr.as_::(); - - assert_arrays_eq!( - taken.patches().indices().to_primitive(), - PrimitiveArray::from_iter([1u64]) - ); - assert_arrays_eq!( - taken.patches().values().to_primitive(), - PrimitiveArray::from_option_iter([Some(0.47f64)]) - ); - assert_eq!(taken.len(), 2); + // Note: take returns a canonical array, not SparseArray + let taken = take(&sparse, &buffer![69, 37].into_array()).unwrap(); + // Index 69 is not in sparse array (fill value is null), index 37 has value 0.47 + let expected = PrimitiveArray::from_option_iter([Option::::None, Some(0.47f64)]); + assert_arrays_eq!(taken, expected.to_array()); } #[test] diff --git a/encodings/sparse/src/kernel.rs b/encodings/sparse/src/kernel.rs index c5cf765c5ac..5b6795b407b 100644 --- a/encodings/sparse/src/kernel.rs +++ b/encodings/sparse/src/kernel.rs @@ -3,6 +3,7 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::SparseVTable; @@ -10,4 +11,5 @@ use 
crate::SparseVTable; pub(crate) static PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&FilterExecuteAdaptor(SparseVTable)), ParentKernelSet::lift(&SliceExecuteAdaptor(SparseVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(SparseVTable)), ]); diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index 71494d5d5f6..fbfd003c3d1 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -4,6 +4,7 @@ use std::fmt::Debug; use std::hash::Hash; +use kernel::PARENT_KERNELS; use prost::Message as _; use vortex_array::Array; use vortex_array::ArrayBufferVisitor; @@ -163,7 +164,7 @@ impl VTable for SparseVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { diff --git a/encodings/zigzag/src/array.rs b/encodings/zigzag/src/array.rs index efc5c29e2e9..9b28770c1aa 100644 --- a/encodings/zigzag/src/array.rs +++ b/encodings/zigzag/src/array.rs @@ -37,6 +37,7 @@ use vortex_session::VortexSession; use zigzag::ZigZag as ExternalZigZag; use crate::compute::ZigZagEncoded; +use crate::kernel::PARENT_KERNELS; use crate::rules::RULES; use crate::zigzag_decode; @@ -112,6 +113,15 @@ impl VTable for ZigZagVTable { ) -> VortexResult> { RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Clone, Debug)] diff --git a/encodings/zigzag/src/compute/mod.rs b/encodings/zigzag/src/compute/mod.rs index 7d97d94edaa..42b9ea2d120 100644 --- a/encodings/zigzag/src/compute/mod.rs +++ b/encodings/zigzag/src/compute/mod.rs @@ -5,14 +5,13 @@ mod cast; use vortex_array::Array; use vortex_array::ArrayRef; +use 
vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; +use vortex_array::arrays::TakeExecute; use vortex_array::compute::MaskKernel; use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; use vortex_array::compute::mask; -use vortex_array::compute::take; use vortex_array::register_kernel; use vortex_error::VortexResult; use vortex_mask::Mask; @@ -27,15 +26,17 @@ impl FilterReduce for ZigZagVTable { } } -impl TakeKernel for ZigZagVTable { - fn take(&self, array: &ZigZagArray, indices: &dyn Array) -> VortexResult { - let encoded = take(array.encoded(), indices)?; - Ok(ZigZagArray::try_new(encoded)?.into_array()) +impl TakeExecute for ZigZagVTable { + fn take( + array: &ZigZagArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let encoded = array.encoded().take(indices.to_array())?; + Ok(Some(ZigZagArray::try_new(encoded)?.into_array())) } } -register_kernel!(TakeKernelAdapter(ZigZagVTable).lift()); - impl MaskKernel for ZigZagVTable { fn mask(&self, array: &ZigZagArray, filter_mask: &Mask) -> VortexResult { let encoded = mask(array.encoded(), filter_mask)?; diff --git a/encodings/zigzag/src/kernel.rs b/encodings/zigzag/src/kernel.rs new file mode 100644 index 00000000000..6b93005d0f2 --- /dev/null +++ b/encodings/zigzag/src/kernel.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; + +use crate::ZigZagVTable; + +pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(ZigZagVTable))]); diff --git a/encodings/zigzag/src/lib.rs b/encodings/zigzag/src/lib.rs index 733139a1a00..89da8bd6069 100644 --- a/encodings/zigzag/src/lib.rs +++ b/encodings/zigzag/src/lib.rs @@ -7,5 +7,6 @@ pub use compress::*; mod 
array; mod compress; mod compute; +mod kernel; mod rules; mod slice; diff --git a/vortex-array/src/arrays/bool/compute/take.rs b/vortex-array/src/arrays/bool/compute/take.rs index def1490cb69..e9fb6bd44dc 100644 --- a/vortex-array/src/arrays/bool/compute/take.rs +++ b/vortex-array/src/arrays/bool/compute/take.rs @@ -17,22 +17,24 @@ use crate::ToCanonical; use crate::arrays::BoolArray; use crate::arrays::BoolVTable; use crate::arrays::ConstantArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; +use crate::arrays::TakeExecute; use crate::compute::fill_null; -use crate::register_kernel; +use crate::executor::ExecutionCtx; use crate::vtable::ValidityHelper; -impl TakeKernel for BoolVTable { - fn take(&self, array: &BoolArray, indices: &dyn Array) -> VortexResult { +impl TakeExecute for BoolVTable { + fn take( + array: &BoolArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { let indices_nulls_zeroed = match indices.validity_mask()? { Mask::AllTrue(_) => indices.to_array(), Mask::AllFalse(_) => { - return Ok(ConstantArray::new( - Scalar::null(array.dtype().as_nullable()), - indices.len(), - ) - .into_array()); + return Ok(Some( + ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len()) + .into_array(), + )); } Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?, }; @@ -41,12 +43,12 @@ impl TakeKernel for BoolVTable { take_valid_indices(&array.to_bit_buffer(), indices_nulls_zeroed.as_slice::()) }); - Ok(BoolArray::new(buffer, array.validity().take(indices)?).to_array()) + Ok(Some( + BoolArray::new(buffer, array.validity().take(indices)?).to_array(), + )) } } -register_kernel!(TakeKernelAdapter(BoolVTable).lift()); - fn take_valid_indices>(bools: &BitBuffer, indices: &[I]) -> BitBuffer { // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth // the overhead to convert to a Vec. 
@@ -54,7 +56,7 @@ fn take_valid_indices>(bools: &BitBuffer, indices: &[I]) - let bools = bools.iter().collect_vec(); take_byte_bool(bools, indices) } else { - take_bool(bools, indices) + take_bool_impl(bools, indices) } } @@ -64,7 +66,7 @@ fn take_byte_bool>(bools: Vec, indices: &[I]) -> Bit }) } -fn take_bool>(bools: &BitBuffer, indices: &[I]) -> BitBuffer { +fn take_bool_impl>(bools: &BitBuffer, indices: &[I]) -> BitBuffer { // We dereference to underlying buffer to avoid access cost on every index. let buffer = bools.inner().as_ref(); BitBuffer::collect_bool(indices.len(), |idx| { @@ -107,7 +109,7 @@ mod test { BoolArray::from_iter([Some(false), None, Some(false)]).to_bit_buffer() ); - let all_invalid_indices = PrimitiveArray::from_option_iter([None::, None, None]); + let all_invalid_indices = PrimitiveArray::from_option_iter([None::, None, None]); let b = take(reference.as_ref(), all_invalid_indices.as_ref()).unwrap(); assert_arrays_eq!(b, BoolArray::from_iter([None, None, None])); } diff --git a/vortex-array/src/arrays/bool/vtable/kernel.rs b/vortex-array/src/arrays/bool/vtable/kernel.rs index be7d92058de..3afa0dc3755 100644 --- a/vortex-array/src/arrays/bool/vtable/kernel.rs +++ b/vortex-array/src/arrays/bool/vtable/kernel.rs @@ -2,8 +2,11 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use crate::arrays::BoolVTable; +use crate::arrays::TakeExecuteAdaptor; use crate::arrays::filter::FilterExecuteAdaptor; use crate::kernel::ParentKernelSet; -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&FilterExecuteAdaptor(BoolVTable))]); +pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&FilterExecuteAdaptor(BoolVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(BoolVTable)), +]); diff --git a/vortex-array/src/arrays/bool/vtable/mod.rs b/vortex-array/src/arrays/bool/vtable/mod.rs index ecbb061fa20..83d4c0c7e01 100644 --- 
a/vortex-array/src/arrays/bool/vtable/mod.rs +++ b/vortex-array/src/arrays/bool/vtable/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -134,7 +135,7 @@ impl VTable for BoolVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } } diff --git a/vortex-array/src/arrays/chunked/compute/filter.rs b/vortex-array/src/arrays/chunked/compute/filter.rs index 7eacbd731a2..2e31adcc0aa 100644 --- a/vortex-array/src/arrays/chunked/compute/filter.rs +++ b/vortex-array/src/arrays/chunked/compute/filter.rs @@ -15,7 +15,6 @@ use crate::arrays::ChunkedArray; use crate::arrays::ChunkedVTable; use crate::arrays::PrimitiveArray; use crate::arrays::filter::FilterKernel; -use crate::compute::take; use crate::search_sorted::SearchSorted; use crate::search_sorted::SearchSortedSide; use crate::validity::Validity; @@ -162,11 +161,9 @@ fn filter_indices( // Push the chunk we've accumulated. 
if !chunk_indices.is_empty() { let chunk = array.chunk(current_chunk_id); - let filtered_chunk = take( - chunk, - PrimitiveArray::new(chunk_indices.clone().freeze(), Validity::NonNullable) - .as_ref(), - )?; + let indices = + PrimitiveArray::new(chunk_indices.clone().freeze(), Validity::NonNullable); + let filtered_chunk = chunk.take(indices.to_array())?.to_canonical()?.into_array(); result.push(filtered_chunk); } @@ -180,10 +177,8 @@ fn filter_indices( if !chunk_indices.is_empty() { let chunk = array.chunk(current_chunk_id); - let filtered_chunk = take( - chunk, - PrimitiveArray::new(chunk_indices.clone().freeze(), Validity::NonNullable).as_ref(), - )?; + let indices = PrimitiveArray::new(chunk_indices.clone().freeze(), Validity::NonNullable); + let filtered_chunk = chunk.take(indices.to_array())?.to_canonical()?.into_array(); result.push(filtered_chunk); } diff --git a/vortex-array/src/arrays/chunked/compute/kernel.rs b/vortex-array/src/arrays/chunked/compute/kernel.rs index 6db192e4bdf..e933961d816 100644 --- a/vortex-array/src/arrays/chunked/compute/kernel.rs +++ b/vortex-array/src/arrays/chunked/compute/kernel.rs @@ -4,9 +4,11 @@ use crate::arrays::ChunkedVTable; use crate::arrays::FilterExecuteAdaptor; use crate::arrays::SliceExecuteAdaptor; +use crate::arrays::TakeExecuteAdaptor; use crate::kernel::ParentKernelSet; pub(crate) static PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&SliceExecuteAdaptor(ChunkedVTable)), ParentKernelSet::lift(&FilterExecuteAdaptor(ChunkedVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(ChunkedVTable)), ]); diff --git a/vortex-array/src/arrays/chunked/compute/take.rs b/vortex-array/src/arrays/chunked/compute/take.rs index 5facc27e503..fc20dcd3b36 100644 --- a/vortex-array/src/arrays/chunked/compute/take.rs +++ b/vortex-array/src/arrays/chunked/compute/take.rs @@ -12,79 +12,84 @@ use crate::IntoArray; use crate::ToCanonical; use crate::arrays::ChunkedVTable; use crate::arrays::PrimitiveArray; 
+use crate::arrays::TakeExecute; use crate::arrays::chunked::ChunkedArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; use crate::compute::cast; use crate::compute::take; -use crate::register_kernel; +use crate::executor::ExecutionCtx; use crate::validity::Validity; -impl TakeKernel for ChunkedVTable { - fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult { - let indices = cast( - indices, - &DType::Primitive(PType::U64, indices.dtype().nullability()), - )? - .to_primitive(); - - // TODO(joe): Should we split this implementation based on indices nullability? - let nullability = indices.dtype().nullability(); - let indices_mask = indices.validity_mask()?; - let indices = indices.as_slice::(); - - let mut chunks = Vec::new(); - let mut indices_in_chunk = BufferMut::::empty(); - let mut start = 0; - let mut stop = 0; - // We assume indices are non-empty as it's handled in the top-level `take` function - let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?)?.0; - for idx in indices { - let idx = usize::try_from(*idx)?; - let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx)?; - - if chunk_idx != prev_chunk_idx { - // Start a new chunk - let indices_in_chunk_array = PrimitiveArray::new( - indices_in_chunk.clone().freeze(), - Validity::from_mask(indices_mask.slice(start..stop), nullability), - ); - chunks.push(take( - array.chunk(prev_chunk_idx), - indices_in_chunk_array.as_ref(), - )?); - indices_in_chunk.clear(); - start = stop; - } - - indices_in_chunk.push(idx_in_chunk as u64); - stop += 1; - prev_chunk_idx = chunk_idx; - } - - if !indices_in_chunk.is_empty() { +fn take_chunked(array: &ChunkedArray, indices: &dyn Array) -> VortexResult { + let indices = cast( + indices, + &DType::Primitive(PType::U64, indices.dtype().nullability()), + )? + .to_primitive(); + + // TODO(joe): Should we split this implementation based on indices nullability? 
+ let nullability = indices.dtype().nullability(); + let indices_mask = indices.validity_mask()?; + let indices = indices.as_slice::(); + + let mut chunks = Vec::new(); + let mut indices_in_chunk = BufferMut::::empty(); + let mut start = 0; + let mut stop = 0; + // We assume indices are non-empty as it's handled in the top-level `take` function + let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?)?.0; + for idx in indices { + let idx = usize::try_from(*idx)?; + let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx)?; + + if chunk_idx != prev_chunk_idx { + // Start a new chunk let indices_in_chunk_array = PrimitiveArray::new( - indices_in_chunk.freeze(), + indices_in_chunk.clone().freeze(), Validity::from_mask(indices_mask.slice(start..stop), nullability), ); chunks.push(take( array.chunk(prev_chunk_idx), indices_in_chunk_array.as_ref(), )?); + indices_in_chunk.clear(); + start = stop; } - // SAFETY: take on chunks that all have same DType retains same DType - unsafe { - Ok(ChunkedArray::new_unchecked( - chunks, - array.dtype().clone().union_nullability(nullability), - ) - .into_array()) - } + indices_in_chunk.push(idx_in_chunk as u64); + stop += 1; + prev_chunk_idx = chunk_idx; + } + + if !indices_in_chunk.is_empty() { + let indices_in_chunk_array = PrimitiveArray::new( + indices_in_chunk.freeze(), + Validity::from_mask(indices_mask.slice(start..stop), nullability), + ); + chunks.push(take( + array.chunk(prev_chunk_idx), + indices_in_chunk_array.as_ref(), + )?); + } + + // SAFETY: take on chunks that all have same DType retains same DType + unsafe { + Ok(ChunkedArray::new_unchecked( + chunks, + array.dtype().clone().union_nullability(nullability), + ) + .into_array()) } } -register_kernel!(TakeKernelAdapter(ChunkedVTable).lift()); +impl TakeExecute for ChunkedVTable { + fn take( + array: &ChunkedArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_chunked(array, indices).map(Some) + } +} #[cfg(test)] mod test { 
diff --git a/vortex-array/src/arrays/constant/compute/rules.rs b/vortex-array/src/arrays/constant/compute/rules.rs index ddc7b349754..76d390be4c1 100644 --- a/vortex-array/src/arrays/constant/compute/rules.rs +++ b/vortex-array/src/arrays/constant/compute/rules.rs @@ -11,6 +11,7 @@ use crate::arrays::FilterArray; use crate::arrays::FilterReduceAdaptor; use crate::arrays::FilterVTable; use crate::arrays::SliceReduceAdaptor; +use crate::arrays::TakeReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; @@ -18,6 +19,7 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::ne ParentRuleSet::lift(&ConstantFilterRule), ParentRuleSet::lift(&FilterReduceAdaptor(ConstantVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ConstantVTable)), + ParentRuleSet::lift(&TakeReduceAdaptor(ConstantVTable)), ]); #[derive(Debug)] diff --git a/vortex-array/src/arrays/constant/compute/take.rs b/vortex-array/src/arrays/constant/compute/take.rs index 24b60dd4fb9..44a00f76ca4 100644 --- a/vortex-array/src/arrays/constant/compute/take.rs +++ b/vortex-array/src/arrays/constant/compute/take.rs @@ -11,14 +11,14 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::arrays::ConstantVTable; use crate::arrays::MaskedArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::register_kernel; +use crate::arrays::TakeReduce; +use crate::arrays::TakeReduceAdaptor; +use crate::optimizer::rules::ParentRuleSet; use crate::validity::Validity; -impl TakeKernel for ConstantVTable { - fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult { - match indices.validity_mask()?.bit_buffer() { +impl TakeReduce for ConstantVTable { + fn take(array: &ConstantArray, indices: &dyn Array) -> VortexResult> { + let result = match indices.validity_mask()?.bit_buffer() { AllOr::All => { let scalar = Scalar::new( array @@ -27,9 +27,9 @@ impl TakeKernel for ConstantVTable { 
.union_nullability(indices.dtype().nullability()), array.scalar().value().clone(), ); - Ok(ConstantArray::new(scalar, indices.len()).into_array()) + ConstantArray::new(scalar, indices.len()).into_array() } - AllOr::None => Ok(ConstantArray::new( + AllOr::None => ConstantArray::new( Scalar::null( array .dtype() @@ -37,21 +37,25 @@ impl TakeKernel for ConstantVTable { ), indices.len(), ) - .into_array()), + .into_array(), AllOr::Some(v) => { let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array(); if array.scalar().is_null() { - return Ok(arr); + return Ok(Some(arr)); } - Ok(MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array()) + MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array() } - } + }; + Ok(Some(result)) } } -register_kernel!(TakeKernelAdapter(ConstantVTable).lift()); +impl ConstantVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} #[cfg(test)] mod tests { diff --git a/vortex-array/src/arrays/decimal/compute/take.rs b/vortex-array/src/arrays/decimal/compute/take.rs index ff9f5097a7e..15af95a3ce1 100644 --- a/vortex-array/src/arrays/decimal/compute/take.rs +++ b/vortex-array/src/arrays/decimal/compute/take.rs @@ -13,13 +13,16 @@ use crate::ArrayRef; use crate::ToCanonical; use crate::arrays::DecimalArray; use crate::arrays::DecimalVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::register_kernel; +use crate::arrays::TakeExecute; +use crate::executor::ExecutionCtx; use crate::vtable::ValidityHelper; -impl TakeKernel for DecimalVTable { - fn take(&self, array: &DecimalArray, indices: &dyn Array) -> VortexResult { +impl TakeExecute for DecimalVTable { + fn take( + array: &DecimalArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { let indices = indices.to_primitive(); let validity = array.validity().take(indices.as_ref())?; @@ -35,12 +38,10 @@ impl TakeKernel for 
DecimalVTable { }) }); - Ok(decimal.to_array()) + Ok(Some(decimal.to_array())) } } -register_kernel!(TakeKernelAdapter(DecimalVTable).lift()); - #[inline] fn take_to_buffer(indices: &[I], values: &[T]) -> Buffer { indices.iter().map(|idx| values[idx.as_()]).collect() diff --git a/vortex-array/src/arrays/decimal/vtable/kernel.rs b/vortex-array/src/arrays/decimal/vtable/kernel.rs new file mode 100644 index 00000000000..561f4073026 --- /dev/null +++ b/vortex-array/src/arrays/decimal/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::DecimalVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(DecimalVTable))]); diff --git a/vortex-array/src/arrays/decimal/vtable/mod.rs b/vortex-array/src/arrays/decimal/vtable/mod.rs index 534369f1169..96683f051bd 100644 --- a/vortex-array/src/arrays/decimal/vtable/mod.rs +++ b/vortex-array/src/arrays/decimal/vtable/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use kernel::PARENT_KERNELS; use vortex_buffer::Alignment; use vortex_dtype::DType; use vortex_dtype::NativeDecimalType; @@ -26,6 +27,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -143,6 +145,15 @@ impl VTable for DecimalVTable { ) -> VortexResult> { RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Debug)] diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index f3dce73932b..fe8b053a53f 100644 
--- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -17,26 +17,32 @@ use vortex_mask::Mask; use super::DictArray; use super::DictVTable; +use super::TakeExecute; use crate::Array; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::filter::FilterReduce; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::compute::take; -use crate::register_kernel; - -impl TakeKernel for DictVTable { - fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult { - let codes = take(array.codes(), indices)?; + +impl TakeExecute for DictVTable { + fn take( + array: &DictArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let codes = array + .codes() + .take(indices.to_array())? + .to_canonical()? + .into_array(); // SAFETY: selecting codes doesn't change the invariants of DictArray // Preserve all_values_referenced since taking codes doesn't affect which values are referenced - Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()).into_array() }) + Ok(Some(unsafe { + DictArray::new_unchecked(codes, array.values().clone()).into_array() + })) } } -register_kernel!(TakeKernelAdapter(DictVTable).lift()); - impl FilterReduce for DictVTable { fn filter(array: &DictArray, mask: &Mask) -> VortexResult> { let codes = array.codes().filter(mask.clone())?; diff --git a/vortex-array/src/arrays/dict/execute.rs b/vortex-array/src/arrays/dict/execute.rs index d3ffa4a1b6c..0e1dfd506bb 100644 --- a/vortex-array/src/arrays/dict/execute.rs +++ b/vortex-array/src/arrays/dict/execute.rs @@ -7,195 +7,145 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::Canonical; +use crate::ExecutionCtx; use crate::arrays::BoolArray; use crate::arrays::BoolVTable; use crate::arrays::DecimalArray; use crate::arrays::DecimalVTable; use crate::arrays::ExtensionArray; +use crate::arrays::ExtensionVTable; use 
crate::arrays::FixedSizeListArray; use crate::arrays::FixedSizeListVTable; use crate::arrays::ListViewArray; use crate::arrays::ListViewVTable; use crate::arrays::NullArray; -use crate::arrays::NullVTable; use crate::arrays::PrimitiveArray; use crate::arrays::PrimitiveVTable; use crate::arrays::StructArray; use crate::arrays::StructVTable; +use crate::arrays::TakeExecute; use crate::arrays::VarBinViewArray; use crate::arrays::VarBinViewVTable; -use crate::compute::TakeKernel; -/// TODO: replace usage of compute fn. /// Take from a canonical array using indices (codes), returning a new canonical array. /// /// This is the core operation for dictionary decoding - it expands the dictionary /// by looking up each code in the values array. -pub fn take_canonical(values: Canonical, codes: &PrimitiveArray) -> VortexResult { +pub fn take_canonical( + values: Canonical, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VortexResult { Ok(match values { Canonical::Null(a) => Canonical::Null(take_null(&a, codes)), - Canonical::Bool(a) => Canonical::Bool(take_bool(&a, codes)?), - Canonical::Primitive(a) => Canonical::Primitive(take_primitive(&a, codes)), - Canonical::Decimal(a) => Canonical::Decimal(take_decimal(&a, codes)), - Canonical::VarBinView(a) => Canonical::VarBinView(take_varbinview(&a, codes)), - Canonical::List(a) => Canonical::List(take_listview(&a, codes)), - Canonical::FixedSizeList(a) => Canonical::FixedSizeList(take_fixed_size_list(&a, codes)), - Canonical::Struct(a) => Canonical::Struct(take_struct(&a, codes)), - Canonical::Extension(a) => Canonical::Extension(take_extension(&a, codes)), + Canonical::Bool(a) => Canonical::Bool(take_bool(&a, codes, ctx)?), + Canonical::Primitive(a) => Canonical::Primitive(take_primitive(&a, codes, ctx)), + Canonical::Decimal(a) => Canonical::Decimal(take_decimal(&a, codes, ctx)), + Canonical::VarBinView(a) => Canonical::VarBinView(take_varbinview(&a, codes, ctx)), + Canonical::List(a) => Canonical::List(take_listview(&a, 
codes, ctx)), + Canonical::FixedSizeList(a) => { + Canonical::FixedSizeList(take_fixed_size_list(&a, codes, ctx)) + } + Canonical::Struct(a) => Canonical::Struct(take_struct(&a, codes, ctx)), + Canonical::Extension(a) => Canonical::Extension(take_extension(&a, codes, ctx)), }) } +/// Take for NullArray is trivial - just create a new NullArray with the new length. fn take_null(_array: &NullArray, codes: &PrimitiveArray) -> NullArray { - NullVTable - .take(_array, codes.as_ref()) - .vortex_expect("take null array") - .as_::() - .clone() + NullArray::new(codes.len()) } -// pub(super) fn dict_bool_take(dict_array: &DictArray) -> VortexResult { -// let values = dict_array.values(); -// let codes = dict_array.codes(); -// let result_nullability = dict_array.dtype().nullability(); -// -// let bool_values = values.to_bool(); -// let result_validity = bool_values.validity_mask(); -// let bool_buffer = bool_values.bit_buffer(); -// let (first_match, second_match) = match result_validity.bit_buffer() { -// AllOr::All => { -// let mut indices_iter = bool_buffer.set_indices(); -// (indices_iter.next(), indices_iter.next()) -// } -// AllOr::None => (None, None), -// AllOr::Some(v) => { -// let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i)); -// (indices_iter.next(), indices_iter.next()) -// } -// }; -// -// Ok(match (first_match, second_match) { -// // Couldn't find a value match, so the result is all false. 
-// (None, _) => match result_validity { -// Mask::AllTrue(_) => BoolArray::new( -// BitBuffer::new_unset(codes.len()), -// Validity::copy_from_array(codes).union_nullability(result_nullability), -// ) -// .to_canonical()?, -// Mask::AllFalse(_) => ConstantArray::new( -// Scalar::null(DType::Bool(Nullability::Nullable)), -// codes.len(), -// ) -// .to_canonical()?, -// Mask::Values(_) => BoolArray::new( -// BitBuffer::new_unset(codes.len()), -// Validity::from_mask(result_validity, result_nullability).take(codes)?, -// ) -// .to_canonical()?, -// }, -// // We found a single matching value so we can compare the codes directly. -// (Some(code), None) => match result_validity { -// Mask::AllTrue(_) => cast( -// &compare( -// codes, -// &cast( -// ConstantArray::new(code, codes.len()).as_ref(), -// codes.dtype(), -// )?, -// Operator::Eq, -// )?, -// &DType::Bool(result_nullability), -// )? -// .to_canonical()?, -// Mask::AllFalse(_) => ConstantArray::new( -// Scalar::null(DType::Bool(Nullability::Nullable)), -// codes.len(), -// ) -// .to_canonical()?, -// Mask::Values(rv) => mask( -// &compare( -// codes, -// &cast( -// ConstantArray::new(code, codes.len()).as_ref(), -// codes.dtype(), -// )?, -// Operator::Eq, -// )?, -// &Mask::from_buffer( -// take(BoolArray::from(rv.bit_buffer().clone()).as_ref(), codes)? -// .to_bool() -// .bit_buffer() -// .not(), -// ), -// )? -// .to_canonical()?, -// }, -// // More than one value matches. -// _ => take(bool_values.as_ref(), codes) -// .vortex_expect("taking codes from dictionary values shouldn't fail") -// .to_canonical()?, -// }) -// } - // TODO(joe): use dict_bool_take -fn take_bool(array: &BoolArray, codes: &PrimitiveArray) -> VortexResult { - Ok(BoolVTable - .take(array, codes.as_ref())? - .as_::() - .clone()) +fn take_bool( + array: &BoolArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VortexResult { + Ok( + ::take(array, codes.as_ref(), ctx)? 
+ .vortex_expect("take bool should not return None") + .as_::() + .clone(), + ) } -fn take_primitive(array: &PrimitiveArray, codes: &PrimitiveArray) -> PrimitiveArray { - PrimitiveVTable - .take(array, codes.as_ref()) +fn take_primitive( + array: &PrimitiveArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> PrimitiveArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take primitive array") + .vortex_expect("take primitive should not return None") .as_::() .clone() } -fn take_decimal(array: &DecimalArray, codes: &PrimitiveArray) -> DecimalArray { - DecimalVTable - .take(array, codes.as_ref()) +fn take_decimal( + array: &DecimalArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> DecimalArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take decimal array") + .vortex_expect("take decimal should not return None") .as_::() .clone() } -fn take_varbinview(array: &VarBinViewArray, codes: &PrimitiveArray) -> VarBinViewArray { - VarBinViewVTable - .take(array, codes.as_ref()) +fn take_varbinview( + array: &VarBinViewArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VarBinViewArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take varbinview array") + .vortex_expect("take varbinview should not return None") .as_::() .clone() } -fn take_listview(array: &ListViewArray, codes: &PrimitiveArray) -> ListViewArray { - ListViewVTable - .take(array, codes.as_ref()) +fn take_listview( + array: &ListViewArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> ListViewArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take listview array") + .vortex_expect("take listview should not return None") .as_::() .clone() } -fn take_fixed_size_list(array: &FixedSizeListArray, codes: &PrimitiveArray) -> FixedSizeListArray { - FixedSizeListVTable - .take(array, codes.as_ref()) +fn take_fixed_size_list( + array: &FixedSizeListArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> FixedSizeListArray { + 
::take(array, codes.as_ref(), ctx) .vortex_expect("take fixed size list array") + .vortex_expect("take fixed size list should not return None") .as_::() .clone() } -fn take_struct(array: &StructArray, codes: &PrimitiveArray) -> StructArray { - StructVTable - .take(array, codes.as_ref()) +fn take_struct(array: &StructArray, codes: &PrimitiveArray, ctx: &mut ExecutionCtx) -> StructArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take struct array") + .vortex_expect("take struct should not return None") .as_::() .clone() } -fn take_extension(array: &ExtensionArray, codes: &PrimitiveArray) -> ExtensionArray { - use crate::compute::take; - - let taken_storage = - take(array.storage(), codes.as_ref()).vortex_expect("take extension storage"); - ExtensionArray::new(array.ext_dtype().clone(), taken_storage) +fn take_extension( + array: &ExtensionArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> ExtensionArray { + ::take(array, codes.as_ref(), ctx) + .vortex_expect("take extension storage") + .vortex_expect("take extension should not return None") + .as_::() + .clone() } diff --git a/vortex-array/src/arrays/dict/mod.rs b/vortex-array/src/arrays/dict/mod.rs index cf2c44febf3..7ba509e09e5 100644 --- a/vortex-array/src/arrays/dict/mod.rs +++ b/vortex-array/src/arrays/dict/mod.rs @@ -19,6 +19,9 @@ mod execute; pub use execute::take_canonical; +mod take; +pub use take::*; + pub mod vtable; pub use vtable::*; diff --git a/vortex-array/src/arrays/dict/take.rs b/vortex-array/src/arrays/dict/take.rs new file mode 100644 index 00000000000..86c3700952d --- /dev/null +++ b/vortex-array/src/arrays/dict/take.rs @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use super::DictArray; +use super::DictVTable; +use crate::Array; +use crate::ArrayRef; +use crate::Canonical; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::expr::stats::Precision; +use 
crate::expr::stats::Stat; +use crate::expr::stats::StatsProvider; +use crate::expr::stats::StatsProviderExt; +use crate::kernel::ExecuteParentKernel; +use crate::matcher::Matcher; +use crate::optimizer::rules::ArrayParentReduceRule; +use crate::stats::StatsSet; +use crate::vtable::VTable; + +pub trait TakeReduce: VTable { + /// Take elements from an array at the given indices without reading buffers. + /// + /// This trait is for take implementations that can operate purely on array metadata and + /// structure without needing to read or execute on the underlying buffers. Implementations + /// should return `None` if taking requires buffer access. + /// + /// # Preconditions + /// + /// The indices are guaranteed to be non-empty. + fn take(array: &Self::Array, indices: &dyn Array) -> VortexResult>; +} + +pub trait TakeExecute: VTable { + /// Take elements from an array at the given indices, potentially reading buffers. + /// + /// Unlike [`TakeReduce`], this trait is for take implementations that may need to read + /// and execute on the underlying buffers to produce the result. + /// + /// # Preconditions + /// + /// The indices are guaranteed to be non-empty. + fn take( + array: &Self::Array, + indices: &dyn Array, + ctx: &mut ExecutionCtx, + ) -> VortexResult>; +} + +/// Common preconditions for take operations that apply to all arrays. +/// +/// Returns `Some(result)` if the precondition short-circuits the take operation, +/// or `None` if the take should proceed normally. +fn precondition(array: &V::Array, indices: &dyn Array) -> Option { + // Fast-path for empty indices. + if indices.is_empty() { + let result_dtype = array + .dtype() + .clone() + .union_nullability(indices.dtype().nullability()); + return Some(Canonical::empty(&result_dtype).into_array()); + } + + // TODO(joe): shall we enable this seems expensive. + // if indices.all_invalid()? 
{ + // return Ok( + // ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len()) + // .into_array() + // .into(), + // ); + // } + + None +} + +#[derive(Default, Debug)] +pub struct TakeReduceAdaptor(pub V); + +impl ArrayParentReduceRule for TakeReduceAdaptor +where + V: TakeReduce, +{ + type Parent = DictVTable; + + fn reduce_parent( + &self, + array: &V::Array, + parent: &DictArray, + child_idx: usize, + ) -> VortexResult> { + // Only handle the values child (index 1), not the codes child (index 0). + if child_idx != 1 { + return Ok(None); + } + if let Some(result) = precondition::(array, parent.codes()) { + return Ok(Some(result)); + } + let result = ::take(array, parent.codes())?; + if let Some(ref taken) = result { + propagate_take_stats(&**array, taken.as_ref(), parent.codes())?; + } + Ok(result) + } +} + +#[derive(Default, Debug)] +pub struct TakeExecuteAdaptor(pub V); + +impl ExecuteParentKernel for TakeExecuteAdaptor +where + V: TakeExecute, +{ + type Parent = DictVTable; + + fn execute_parent( + &self, + array: &V::Array, + parent: ::Match<'_>, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // Only handle the values child (index 1), not the codes child (index 0). 
+ if child_idx != 1 { + return Ok(None); + } + if let Some(result) = precondition::(array, parent.codes()) { + return Ok(Some(result)); + } + let result = ::take(array, parent.codes(), ctx)?; + if let Some(ref taken) = result { + propagate_take_stats(&**array, taken.as_ref(), parent.codes())?; + } + Ok(result) + } +} + +pub(crate) fn propagate_take_stats( + source: &dyn Array, + target: &dyn Array, + indices: &dyn Array, +) -> VortexResult<()> { + target.statistics().with_mut_typed_stats_set(|mut st| { + if indices.all_valid().unwrap_or(false) { + let is_constant = source.statistics().get_as::(Stat::IsConstant); + if is_constant == Some(Precision::Exact(true)) { + // Any combination of elements from a constant array is still const + st.set(Stat::IsConstant, Precision::exact(true)); + } + } + let inexact_min_max = [Stat::Min, Stat::Max] + .into_iter() + .filter_map(|stat| { + source + .statistics() + .get(stat) + .map(|v| (stat, v.map(|s| s.into_value()).into_inexact())) + }) + .collect::>(); + st.combine_sets( + &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()), + ) + }) +} diff --git a/vortex-array/src/arrays/dict/vtable/kernel.rs b/vortex-array/src/arrays/dict/vtable/kernel.rs new file mode 100644 index 00000000000..3050f9a5b85 --- /dev/null +++ b/vortex-array/src/arrays/dict/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::DictVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(DictVTable))]); diff --git a/vortex-array/src/arrays/dict/vtable/mod.rs b/vortex-array/src/arrays/dict/vtable/mod.rs index f504229a970..3e80fd42e32 100644 --- a/vortex-array/src/arrays/dict/vtable/mod.rs +++ b/vortex-array/src/arrays/dict/vtable/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: 
Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; @@ -31,6 +32,7 @@ use crate::vtable::ArrayId; use crate::vtable::VTable; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -148,7 +150,7 @@ impl VTable for DictVTable { // TODO(ngates): if indices min is quite high, we could slice self and offset the indices // such that canonicalize does less work. - Ok(take_canonical(values, &codes)?.into_array()) + Ok(take_canonical(values, &codes, ctx)?.into_array()) } fn reduce_parent( @@ -158,6 +160,15 @@ impl VTable for DictVTable { ) -> VortexResult> { PARENT_RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } /// Check for fast-path execution conditions. diff --git a/vortex-array/src/arrays/extension/compute/take.rs b/vortex-array/src/arrays/extension/compute/take.rs index e66886508c6..aa53205a5b1 100644 --- a/vortex-array/src/arrays/extension/compute/take.rs +++ b/vortex-array/src/arrays/extension/compute/take.rs @@ -5,25 +5,27 @@ use vortex_error::VortexResult; use crate::Array; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::ExtensionArray; use crate::arrays::ExtensionVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::compute::{self}; -use crate::register_kernel; +use crate::arrays::TakeExecute; -impl TakeKernel for ExtensionVTable { - fn take(&self, array: &ExtensionArray, indices: &dyn Array) -> VortexResult { - let taken_storage = compute::take(array.storage(), indices)?; - Ok(ExtensionArray::new( - array - .ext_dtype() - .with_nullability(taken_storage.dtype().nullability()), - taken_storage, - ) - .into_array()) +impl TakeExecute for ExtensionVTable { 
+ fn take( + array: &ExtensionArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let taken_storage = array.storage().take(indices.to_array())?; + Ok(Some( + ExtensionArray::new( + array + .ext_dtype() + .with_nullability(taken_storage.dtype().nullability()), + taken_storage, + ) + .into_array(), + )) } } - -register_kernel!(TakeKernelAdapter(ExtensionVTable).lift()); diff --git a/vortex-array/src/arrays/extension/vtable/kernel.rs b/vortex-array/src/arrays/extension/vtable/kernel.rs new file mode 100644 index 00000000000..41f4f8a0f9a --- /dev/null +++ b/vortex-array/src/arrays/extension/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::ExtensionVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(ExtensionVTable))]); diff --git a/vortex-array/src/arrays/extension/vtable/mod.rs b/vortex-array/src/arrays/extension/vtable/mod.rs index a1118a5bdd5..c83e6b65e65 100644 --- a/vortex-array/src/arrays/extension/vtable/mod.rs +++ b/vortex-array/src/arrays/extension/vtable/mod.rs @@ -3,10 +3,12 @@ mod array; mod canonical; +mod kernel; mod operations; mod validity; mod visitor; +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -100,6 +102,15 @@ impl VTable for ExtensionVTable { ) -> VortexResult> { PARENT_RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Debug)] diff --git a/vortex-array/src/arrays/filter/kernel.rs b/vortex-array/src/arrays/filter/kernel.rs index 174530caab8..fb1e6ab9fa4 100644 --- 
a/vortex-array/src/arrays/filter/kernel.rs +++ b/vortex-array/src/arrays/filter/kernel.rs @@ -54,7 +54,7 @@ pub trait FilterKernel: VTable { /// /// Returns `Some(result)` if the precondition short-circuits the filter operation, /// or `None` if the filter should proceed normally. -pub fn precondition(array: &V::Array, mask: &Mask) -> Option { +fn precondition(array: &V::Array, mask: &Mask) -> Option { let true_count = mask.true_count(); // Fast-path for empty mask (all false). diff --git a/vortex-array/src/arrays/fixed_size_list/array.rs b/vortex-array/src/arrays/fixed_size_list/array.rs index 65692a8f495..e512b77518c 100644 --- a/vortex-array/src/arrays/fixed_size_list/array.rs +++ b/vortex-array/src/arrays/fixed_size_list/array.rs @@ -227,8 +227,9 @@ impl FixedSizeListArray { pub fn fixed_size_list_elements_at(&self, index: usize) -> VortexResult { debug_assert!( index < self.len, - "index out of bounds: the len is {} but the index is {index}", - self.len + "index {} out of bounds: the len is {}", + index, + self.len, ); debug_assert!(self.validity.is_valid(index).unwrap_or(false)); diff --git a/vortex-array/src/arrays/fixed_size_list/compute/take.rs b/vortex-array/src/arrays/fixed_size_list/compute/take.rs index f4ae9ad9b21..3d507b19490 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/take.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/take.rs @@ -16,10 +16,8 @@ use crate::ToCanonical; use crate::arrays::FixedSizeListArray; use crate::arrays::FixedSizeListVTable; use crate::arrays::PrimitiveArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::compute::{self}; -use crate::register_kernel; +use crate::arrays::TakeExecute; +use crate::executor::ExecutionCtx; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -28,16 +26,19 @@ use crate::vtable::ValidityHelper; /// Unlike `ListView`, `FixedSizeListArray` must rebuild the elements array because it requires /// that elements start at offset 
0 and be perfectly packed without gaps. We expand list indices /// into element indices and push them down to the child elements array. -impl TakeKernel for FixedSizeListVTable { - fn take(&self, array: &FixedSizeListArray, indices: &dyn Array) -> VortexResult { +impl TakeExecute for FixedSizeListVTable { + fn take( + array: &FixedSizeListArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { match_each_integer_ptype!(indices.dtype().as_ptype(), |I| { take_with_indices::(array, indices) }) + .map(Some) } } -register_kernel!(TakeKernelAdapter(FixedSizeListVTable).lift()); - /// Dispatches to the appropriate take implementation based on list size and nullability. fn take_with_indices( array: &FixedSizeListArray, @@ -114,7 +115,7 @@ fn take_non_nullable_fsl( debug_assert_eq!(elements_indices.len(), new_len * list_size); let elements_indices_array = PrimitiveArray::new(elements_indices, Validity::NonNullable); - let new_elements = compute::take(array.elements(), elements_indices_array.as_ref())?; + let new_elements = array.elements().take(elements_indices_array.to_array())?; debug_assert_eq!(new_elements.len(), new_len * list_size); // Both inputs are non-nullable, so the result is non-nullable. @@ -181,7 +182,7 @@ fn take_nullable_fsl( debug_assert_eq!(elements_indices.len(), new_len * list_size); let elements_indices_array = PrimitiveArray::new(elements_indices, Validity::NonNullable); - let new_elements = compute::take(array.elements(), elements_indices_array.as_ref())?; + let new_elements = array.elements().take(elements_indices_array.to_array())?; debug_assert_eq!(new_elements.len(), new_len * list_size); // At least one input was nullable, so the result is nullable. 
diff --git a/vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs b/vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs new file mode 100644 index 00000000000..6d38dccf2fa --- /dev/null +++ b/vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::FixedSizeListVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +impl FixedSizeListVTable { + pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor( + FixedSizeListVTable, + ))]); +} diff --git a/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs b/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs index 7299c7aba6b..ba5146195c2 100644 --- a/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs +++ b/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs @@ -22,6 +22,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -57,6 +58,15 @@ impl VTable for FixedSizeListVTable { PARENT_RULES.evaluate(array, parent, child_idx) } + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + Self::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } + fn metadata(_array: &FixedSizeListArray) -> VortexResult { Ok(EmptyMetadata) } diff --git a/vortex-array/src/arrays/list/compute/filter.rs b/vortex-array/src/arrays/list/compute/filter.rs index 4650b37ee37..0393c65a817 100644 --- a/vortex-array/src/arrays/list/compute/filter.rs +++ b/vortex-array/src/arrays/list/compute/filter.rs @@ -14,7 +14,6 @@ use vortex_mask::Mask; use vortex_mask::MaskIter; use vortex_mask::MaskValues; -use crate::Array; use crate::ArrayRef; use crate::Canonical; use crate::ExecutionCtx; @@ -26,64 +25,6 @@ use 
crate::arrays::ListVTable; use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl FilterKernel for ListVTable { - fn filter( - array: &ListArray, - mask: &Mask, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - let selection = match mask { - Mask::AllTrue(_) | Mask::AllFalse(_) => return Ok(None), - Mask::Values(v) => v, - }; - - let new_validity = match array.validity() { - Validity::NonNullable => Validity::NonNullable, - Validity::AllValid => Validity::AllValid, - Validity::AllInvalid => { - let elements = Canonical::empty(array.element_dtype()).into_array(); - let offsets = ConstantArray::new(0u64, selection.true_count() + 1).into_array(); - return Ok(Some(unsafe { - ListArray::new_unchecked(elements, offsets, Validity::AllInvalid).into_array() - })); - } - Validity::Array(a) => Validity::Array(a.filter(mask.clone())?), - }; - - // TODO(ngates): for ultra-sparse masks, we don't need to optimize the entire offsets. - let offsets = array.offsets().clone(); - - let (new_offsets, element_mask) = - match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { - let offsets_buffer = offsets.execute::>(ctx)?; - let offsets = offsets_buffer.as_slice(); - let mut new_offsets = BufferMut::::with_capacity(selection.true_count() + 1); - - let mut offset = O::zero(); - unsafe { new_offsets.push_unchecked(offset) }; - for idx in selection.indices() { - let size = offsets[idx + 1] - offsets[*idx]; - offset += size; - unsafe { new_offsets.push_unchecked(offset) }; - } - - // TODO(ngates): for very dense masks, there may be no point in filtering the elements, - // and instead we should construct a view against the unfiltered elements. 
- let element_mask = element_mask_from_offsets::(offsets, selection); - - (new_offsets.freeze().into_array(), element_mask) - }); - - let new_elements = array.sliced_elements()?.filter(element_mask)?; - - // SAFETY: new_offsets are monotonically increasing starting from 0 with length - // true_count + 1, and the elements have been filtered to match. - Ok(Some(unsafe { - ListArray::new_unchecked(new_elements, new_offsets, new_validity).into_array() - })) - } -} - /// Density threshold for choosing between indices and slices representation when expanding masks. /// /// When the mask density is below this threshold, we use indices. Otherwise, we use slices. @@ -92,7 +33,10 @@ impl FilterKernel for ListVTable { const MASK_EXPANSION_DENSITY_THRESHOLD: f64 = 0.05; /// Construct an element mask from contiguous list offsets and a selection mask. -fn element_mask_from_offsets(offsets: &[O], selection: &Arc) -> Mask { +pub fn element_mask_from_offsets( + offsets: &[O], + selection: &Arc, +) -> Mask { let first_offset = offsets.first().map_or(0, |first_offset| first_offset.as_()); let last_offset = offsets.last().map_or(0, |last_offset| last_offset.as_()); let len = last_offset - first_offset; @@ -147,3 +91,61 @@ fn process_element_range( new_mask_builder.append_n(true, elems_len); } } + +impl FilterKernel for ListVTable { + fn filter( + array: &ListArray, + mask: &Mask, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let selection = match mask { + Mask::AllTrue(_) | Mask::AllFalse(_) => return Ok(None), + Mask::Values(v) => v, + }; + + let new_validity = match array.validity() { + Validity::NonNullable => Validity::NonNullable, + Validity::AllValid => Validity::AllValid, + Validity::AllInvalid => { + let elements = Canonical::empty(array.element_dtype()).into_array(); + let offsets = ConstantArray::new(0u64, selection.true_count() + 1).into_array(); + return Ok(Some(unsafe { + ListArray::new_unchecked(elements, offsets, Validity::AllInvalid).into_array() + })); + } + 
Validity::Array(a) => Validity::Array(a.filter(mask.clone())?), + }; + + // TODO(ngates): for ultra-sparse masks, we don't need to optimize the entire offsets. + let offsets = array.offsets().clone(); + + let (new_offsets, element_mask) = + match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { + let offsets_buffer = offsets.execute::>(ctx)?; + let offsets = offsets_buffer.as_slice(); + let mut new_offsets = BufferMut::::with_capacity(selection.true_count() + 1); + + let mut offset = O::zero(); + unsafe { new_offsets.push_unchecked(offset) }; + for idx in selection.indices() { + let size = offsets[idx + 1] - offsets[*idx]; + offset += size; + unsafe { new_offsets.push_unchecked(offset) }; + } + + // TODO(ngates): for very dense masks, there may be no point in filtering the elements, + // and instead we should construct a view against the unfiltered elements. + let element_mask = element_mask_from_offsets::(offsets, selection); + + (new_offsets.freeze().into_array(), element_mask) + }); + + let new_elements = array.sliced_elements()?.filter(element_mask)?; + + // SAFETY: new_offsets are monotonically increasing starting from 0 with length + // true_count + 1, and the elements have been filtered to match. 
+ Ok(Some(unsafe { + ListArray::new_unchecked(new_elements, new_offsets, new_validity).into_array() + })) + } +} diff --git a/vortex-array/src/arrays/list/compute/kernels.rs b/vortex-array/src/arrays/list/compute/kernels.rs index 38e065c4ebf..6f3ea7c4579 100644 --- a/vortex-array/src/arrays/list/compute/kernels.rs +++ b/vortex-array/src/arrays/list/compute/kernels.rs @@ -3,7 +3,10 @@ use crate::arrays::FilterExecuteAdaptor; use crate::arrays::ListVTable; +use crate::arrays::TakeExecuteAdaptor; use crate::kernel::ParentKernelSet; -pub(crate) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&FilterExecuteAdaptor(ListVTable))]); +pub(crate) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&FilterExecuteAdaptor(ListVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(ListVTable)), +]); diff --git a/vortex-array/src/arrays/list/compute/take.rs b/vortex-array/src/arrays/list/compute/take.rs index 5bc6c309671..29235df9d25 100644 --- a/vortex-array/src/arrays/list/compute/take.rs +++ b/vortex-array/src/arrays/list/compute/take.rs @@ -10,29 +10,32 @@ use vortex_error::VortexResult; use crate::Array; use crate::ArrayRef; +use crate::IntoArray; use crate::ToCanonical; use crate::arrays::ListArray; use crate::arrays::ListVTable; use crate::arrays::PrimitiveArray; +use crate::arrays::TakeExecute; use crate::builders::ArrayBuilder; use crate::builders::PrimitiveBuilder; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::compute::take; -use crate::register_kernel; +use crate::executor::ExecutionCtx; use crate::vtable::ValidityHelper; // TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call // the `ListView::take` compute function once `ListView` is more stable. -/// Take implementation for [`ListArray`]. 
-/// -/// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant -/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking -/// non-contiguous indices would violate this requirement. -impl TakeKernel for ListVTable { +impl TakeExecute for ListVTable { + /// Take implementation for [`ListArray`]. + /// + /// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant + /// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking + /// non-contiguous indices would violate this requirement. #[expect(clippy::cognitive_complexity)] - fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult { + fn take( + array: &ListArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { let indices = indices.to_primitive(); // This is an over-approximation of the total number of elements in the resulting array. let total_approx = array.elements().len().saturating_mul(indices.len()); @@ -40,15 +43,13 @@ impl TakeKernel for ListVTable { match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| { match_each_integer_ptype!(indices.ptype(), |I| { match_smallest_offset_type!(total_approx, |OutputOffsetType| { - _take::(array, &indices) + _take::(array, &indices).map(Some) }) }) }) } } -register_kernel!(TakeKernelAdapter(ListVTable).lift()); - fn _take( array: &ListArray, indices_array: &PrimitiveArray, @@ -100,7 +101,11 @@ fn _take( let elements_to_take = elements_to_take.finish(); let new_offsets = new_offsets.finish(); - let new_elements = take(array.elements(), elements_to_take.as_ref())?; + let new_elements = array + .elements() + .take(elements_to_take.to_array())? + .to_canonical()? 
+ .into_array(); Ok(ListArray::try_new( new_elements, @@ -168,7 +173,11 @@ fn _take_nullable VortexResult { +impl TakeExecute for ListViewVTable { + fn take( + array: &ListViewArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { let elements = array.elements(); let offsets = array.offsets(); let sizes = array.sizes(); @@ -55,8 +58,8 @@ impl TakeKernel for ListViewVTable { // Take the offsets and sizes arrays at the requested indices. // Take can reorder offsets, create gaps, and may introduce overlaps if the `indices` // contain duplicates. - let nullable_new_offsets = compute::take(offsets.as_ref(), indices)?; - let nullable_new_sizes = compute::take(sizes.as_ref(), indices)?; + let nullable_new_offsets = offsets.take(indices.to_array())?; + let nullable_new_sizes = sizes.take(indices.to_array())?; // Since `take` returns nullable arrays, we simply cast it back to non-nullable (filled with // zeros to represent null lists). @@ -86,10 +89,10 @@ impl TakeKernel for ListViewVTable { // compute functions have run, at the "top" of the operator tree. However, we cannot do this // right now, so we will just rebuild every time (similar to `ListArray`). - Ok(new_array - .rebuild(ListViewRebuildMode::MakeZeroCopyToList)? - .into_array()) + Ok(Some( + new_array + .rebuild(ListViewRebuildMode::MakeZeroCopyToList)? 
+ .into_array(), + )) } } - -register_kernel!(TakeKernelAdapter(ListViewVTable).lift()); diff --git a/vortex-array/src/arrays/listview/vtable/kernel.rs b/vortex-array/src/arrays/listview/vtable/kernel.rs new file mode 100644 index 00000000000..4403eaa46a4 --- /dev/null +++ b/vortex-array/src/arrays/listview/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::ListViewVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(ListViewVTable))]); diff --git a/vortex-array/src/arrays/listview/vtable/mod.rs b/vortex-array/src/arrays/listview/vtable/mod.rs index ada1eeca207..a4207e91d3d 100644 --- a/vortex-array/src/arrays/listview/vtable/mod.rs +++ b/vortex-array/src/arrays/listview/vtable/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; @@ -26,6 +27,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -176,4 +178,13 @@ impl VTable for ListViewVTable { ) -> VortexResult> { PARENT_RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } diff --git a/vortex-array/src/arrays/masked/compute/take.rs b/vortex-array/src/arrays/masked/compute/take.rs index 945a1a047ed..162d6e5699d 100644 --- a/vortex-array/src/arrays/masked/compute/take.rs +++ b/vortex-array/src/arrays/masked/compute/take.rs @@ -6,39 +6,45 @@ use vortex_scalar::Scalar; use crate::Array; use 
crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::MaskedArray; use crate::arrays::MaskedVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; +use crate::arrays::TakeExecute; use crate::compute::fill_null; -use crate::compute::take; -use crate::register_kernel; use crate::vtable::ValidityHelper; -impl TakeKernel for MaskedVTable { - fn take(&self, array: &MaskedArray, indices: &dyn Array) -> VortexResult { +impl TakeExecute for MaskedVTable { + fn take( + array: &MaskedArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { let taken_child = if !indices.all_valid()? { // This is safe because we'll mask out these positions in the validity let filled_take = fill_null( indices, &Scalar::default_value(indices.dtype().clone().as_nonnullable()), )?; - take(&array.child, &filled_take)? + array.child.take(filled_take)?.to_canonical()?.into_array() } else { - take(&array.child, indices)? + array + .child + .take(indices.to_array())? + .to_canonical()? + .into_array() }; // Compute the new validity by taking from array's validity and merging with indices validity let taken_validity = array.validity().take(indices)?; // Construct new MaskedArray - Ok(MaskedArray::try_new(taken_child, taken_validity)?.into_array()) + Ok(Some( + MaskedArray::try_new(taken_child, taken_validity)?.into_array(), + )) } } -register_kernel!(TakeKernelAdapter(MaskedVTable).lift()); - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/vortex-array/src/arrays/masked/execute.rs b/vortex-array/src/arrays/masked/execute.rs index 1dd9523ed4e..00e803de64e 100644 --- a/vortex-array/src/arrays/masked/execute.rs +++ b/vortex-array/src/arrays/masked/execute.rs @@ -99,7 +99,7 @@ fn mask_validity_decimal(array: DecimalArray, mask: &Mask) -> DecimalArray { /// Mask validity for VarBinViewArray. 
fn mask_validity_varbinview(array: VarBinViewArray, mask: &Mask) -> VarBinViewArray { let len = array.len(); - let dtype = array.dtype().clone(); + let dtype = array.dtype().as_nullable(); let new_validity = combine_validity(array.validity(), mask, len); // SAFETY: We're only changing validity, not the data structure unsafe { diff --git a/vortex-array/src/arrays/masked/vtable/kernel.rs b/vortex-array/src/arrays/masked/vtable/kernel.rs new file mode 100644 index 00000000000..869e8601aba --- /dev/null +++ b/vortex-array/src/arrays/masked/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::MaskedVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(MaskedVTable))]); diff --git a/vortex-array/src/arrays/masked/vtable/mod.rs b/vortex-array/src/arrays/masked/vtable/mod.rs index 9b0fa70c6d5..480d16b4341 100644 --- a/vortex-array/src/arrays/masked/vtable/mod.rs +++ b/vortex-array/src/arrays/masked/vtable/mod.rs @@ -6,6 +6,7 @@ mod canonical; mod operations; mod validity; +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -34,6 +35,8 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; use crate::vtable::VisitorVTable; +mod kernel; + vtable!(Masked); #[derive(Debug)] @@ -140,6 +143,15 @@ impl VTable for MaskedVTable { PARENT_RULES.evaluate(array, parent, child_idx) } + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } + fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { vortex_ensure!( children.len() == 1 || children.len() == 2, diff --git 
a/vortex-array/src/arrays/null/compute/rules.rs b/vortex-array/src/arrays/null/compute/rules.rs index 33a554e90bd..62aed40e66c 100644 --- a/vortex-array/src/arrays/null/compute/rules.rs +++ b/vortex-array/src/arrays/null/compute/rules.rs @@ -4,9 +4,11 @@ use crate::arrays::FilterReduceAdaptor; use crate::arrays::NullVTable; use crate::arrays::SliceReduceAdaptor; +use crate::arrays::TakeReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&FilterReduceAdaptor(NullVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(NullVTable)), + ParentRuleSet::lift(&TakeReduceAdaptor(NullVTable)), ]); diff --git a/vortex-array/src/arrays/null/compute/take.rs b/vortex-array/src/arrays/null/compute/take.rs index 4d3d595bd7d..be266175656 100644 --- a/vortex-array/src/arrays/null/compute/take.rs +++ b/vortex-array/src/arrays/null/compute/take.rs @@ -11,13 +11,13 @@ use crate::IntoArray; use crate::ToCanonical; use crate::arrays::NullArray; use crate::arrays::NullVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::register_kernel; +use crate::arrays::TakeReduce; +use crate::arrays::TakeReduceAdaptor; +use crate::optimizer::rules::ParentRuleSet; -impl TakeKernel for NullVTable { +impl TakeReduce for NullVTable { #[allow(clippy::cast_possible_truncation)] - fn take(&self, array: &NullArray, indices: &dyn Array) -> VortexResult { + fn take(array: &NullArray, indices: &dyn Array) -> VortexResult> { let indices = indices.to_primitive(); // Enforce all indices are valid @@ -29,8 +29,11 @@ impl TakeKernel for NullVTable { } }); - Ok(NullArray::new(indices.len()).into_array()) + Ok(Some(NullArray::new(indices.len()).into_array())) } } -register_kernel!(TakeKernelAdapter(NullVTable).lift()); +impl NullVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} diff --git 
a/vortex-array/src/arrays/primitive/compute/take/mod.rs b/vortex-array/src/arrays/primitive/compute/take/mod.rs index fd2ddb63cf3..be832f50cb6 100644 --- a/vortex-array/src/arrays/primitive/compute/take/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/take/mod.rs @@ -23,11 +23,10 @@ use crate::ArrayRef; use crate::IntoArray; use crate::ToCanonical; use crate::arrays::PrimitiveVTable; +use crate::arrays::TakeExecute; use crate::arrays::primitive::PrimitiveArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; use crate::compute::cast; -use crate::register_kernel; +use crate::executor::ExecutionCtx; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -81,8 +80,12 @@ impl TakeImpl for TakeKernelScalar { } } -impl TakeKernel for PrimitiveVTable { - fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult { +impl TakeExecute for PrimitiveVTable { + fn take( + array: &PrimitiveArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { let DType::Primitive(ptype, null) = indices.dtype() else { vortex_bail!("Invalid indices dtype: {}", indices.dtype()) }; @@ -96,12 +99,12 @@ impl TakeKernel for PrimitiveVTable { let validity = array.validity().take(unsigned_indices.as_ref())?; // Delegate to the best kernel based on the target CPU - PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity) + PRIMITIVE_TAKE_KERNEL + .take(array, &unsigned_indices, validity) + .map(Some) } } -register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift()); - // Compiler may see this as unused based on enabled features #[allow(unused)] #[inline(always)] diff --git a/vortex-array/src/arrays/primitive/vtable/kernel.rs b/vortex-array/src/arrays/primitive/vtable/kernel.rs new file mode 100644 index 00000000000..7882fde292e --- /dev/null +++ b/vortex-array/src/arrays/primitive/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors 
+ +use crate::arrays::PrimitiveVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(PrimitiveVTable))]); diff --git a/vortex-array/src/arrays/primitive/vtable/mod.rs b/vortex-array/src/arrays/primitive/vtable/mod.rs index 4c218d97ae7..270c08de8af 100644 --- a/vortex-array/src/arrays/primitive/vtable/mod.rs +++ b/vortex-array/src/arrays/primitive/vtable/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_dtype::PType; use vortex_error::VortexExpect; @@ -20,6 +21,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -138,6 +140,15 @@ impl VTable for PrimitiveVTable { ) -> VortexResult> { RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Debug)] diff --git a/vortex-array/src/arrays/slice/mod.rs b/vortex-array/src/arrays/slice/mod.rs index acfba57c52a..2338fb7048e 100644 --- a/vortex-array/src/arrays/slice/mod.rs +++ b/vortex-array/src/arrays/slice/mod.rs @@ -13,7 +13,9 @@ use vortex_error::VortexResult; pub use vtable::*; use crate::ArrayRef; +use crate::Canonical; use crate::ExecutionCtx; +use crate::IntoArray; use crate::kernel::ExecuteParentKernel; use crate::matcher::Matcher; use crate::optimizer::rules::ArrayParentReduceRule; @@ -52,6 +54,16 @@ pub trait SliceKernel: VTable { ) -> VortexResult>; } +fn precondition(array: &V::Array, range: &Range) -> Option { + if range.start == 0 && range.end == array.len() { + return Some(array.to_array()); + }; + if range.start == range.end { + 
return Some(Canonical::empty(array.dtype()).into_array()); + } + None +} + #[derive(Default, Debug)] pub struct SliceReduceAdaptor(pub V); @@ -68,6 +80,9 @@ where child_idx: usize, ) -> VortexResult> { assert_eq!(child_idx, 0); + if let Some(result) = precondition::(array, &parent.range) { + return Ok(Some(result)); + } ::slice(array, parent.range.clone()) } } @@ -89,6 +104,9 @@ where ctx: &mut ExecutionCtx, ) -> VortexResult> { assert_eq!(child_idx, 0); + if let Some(result) = precondition::(array, &parent.range) { + return Ok(Some(result)); + } ::slice(array, parent.range.clone(), ctx) } } diff --git a/vortex-array/src/arrays/struct_/compute/take.rs b/vortex-array/src/arrays/struct_/compute/take.rs index 39bb432d1dd..c11ff7dd3a9 100644 --- a/vortex-array/src/arrays/struct_/compute/take.rs +++ b/vortex-array/src/arrays/struct_/compute/take.rs @@ -7,18 +7,21 @@ use vortex_scalar::Scalar; use crate::Array; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::StructArray; use crate::arrays::StructVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::compute::{self}; -use crate::register_kernel; +use crate::arrays::TakeExecute; +use crate::compute; use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl TakeKernel for StructVTable { - fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult { +impl TakeExecute for StructVTable { + fn take( + array: &StructArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { // If the struct array is empty then the indices must be all null, otherwise it will access // an out of bounds element if array.is_empty() { @@ -28,7 +31,8 @@ impl TakeKernel for StructVTable { indices.len(), Validity::AllInvalid, ) - .map(StructArray::into_array); + .map(StructArray::into_array) + .map(Some); } // The validity is applied to the struct validity, let inner_indices = &compute::fill_null( @@ -39,14 +43,13 @@ impl 
TakeKernel for StructVTable { array .unmasked_fields() .iter() - .map(|field| compute::take(field, inner_indices)) + .map(|field| field.take(inner_indices.to_array())) .collect::, _>>()?, array.struct_fields().clone(), indices.len(), array.validity().take(indices)?, ) .map(|a| a.into_array()) + .map(Some) } } - -register_kernel!(TakeKernelAdapter(StructVTable).lift()); diff --git a/vortex-array/src/arrays/struct_/vtable/kernel.rs b/vortex-array/src/arrays/struct_/vtable/kernel.rs new file mode 100644 index 00000000000..74a22f2ded4 --- /dev/null +++ b/vortex-array/src/arrays/struct_/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::StructVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(StructVTable))]); diff --git a/vortex-array/src/arrays/struct_/vtable/mod.rs b/vortex-array/src/arrays/struct_/vtable/mod.rs index e26949eeb6c..44aa55b951d 100644 --- a/vortex-array/src/arrays/struct_/vtable/mod.rs +++ b/vortex-array/src/arrays/struct_/vtable/mod.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use itertools::Itertools; +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -24,6 +25,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -144,6 +146,15 @@ impl VTable for StructVTable { ) -> VortexResult> { PARENT_RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Debug)] diff --git a/vortex-array/src/arrays/varbin/compute/take.rs 
b/vortex-array/src/arrays/varbin/compute/take.rs index 0562a20209e..a1719f2dfb8 100644 --- a/vortex-array/src/arrays/varbin/compute/take.rs +++ b/vortex-array/src/arrays/varbin/compute/take.rs @@ -17,15 +17,19 @@ use crate::ArrayRef; use crate::IntoArray; use crate::ToCanonical; use crate::arrays::PrimitiveArray; +use crate::arrays::TakeExecute; use crate::arrays::VarBinVTable; use crate::arrays::varbin::VarBinArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::register_kernel; +use crate::executor::ExecutionCtx; use crate::validity::Validity; -impl TakeKernel for VarBinVTable { - fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult { +impl TakeExecute for VarBinVTable { + fn take( + array: &VarBinArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // TODO(joe): Be lazy with execute let offsets = array.offsets().to_primitive(); let data = array.bytes(); let indices = indices.to_primitive(); @@ -33,6 +37,9 @@ impl TakeKernel for VarBinVTable { .dtype() .clone() .union_nullability(indices.dtype().nullability()); + let array_validity = array.validity_mask()?; + let indices_validity = indices.validity_mask()?; + let array = match_each_integer_ptype!(indices.ptype(), |I| { // On take, offsets get widened to either 32- or 64-bit based on the original type, // to avoid overflow issues. 
@@ -42,75 +49,73 @@ impl TakeKernel for VarBinVTable { offsets.as_slice::(), data.as_slice(), indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, + array_validity, + indices_validity, ), PType::U16 => take::( dtype, offsets.as_slice::(), data.as_slice(), indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, + array_validity, + indices_validity, ), PType::U32 => take::( dtype, offsets.as_slice::(), data.as_slice(), indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, + array_validity, + indices_validity, ), PType::U64 => take::( dtype, offsets.as_slice::(), data.as_slice(), indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, + array_validity, + indices_validity, ), PType::I8 => take::( dtype, offsets.as_slice::(), data.as_slice(), indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, + array_validity, + indices_validity, ), PType::I16 => take::( dtype, offsets.as_slice::(), data.as_slice(), indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, + array_validity, + indices_validity, ), PType::I32 => take::( dtype, offsets.as_slice::(), data.as_slice(), indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, + array_validity, + indices_validity, ), PType::I64 => take::( dtype, offsets.as_slice::(), data.as_slice(), indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, + array_validity, + indices_validity, ), _ => unreachable!("invalid PType for offsets"), } }); - Ok(array?.into_array()) + Ok(Some(array?.into_array())) } } -register_kernel!(TakeKernelAdapter(VarBinVTable).lift()); - fn take( dtype: DType, offsets: &[Offset], @@ -253,7 +258,7 @@ mod tests { use crate::IntoArray; use crate::arrays::PrimitiveArray; use crate::arrays::VarBinArray; - use crate::arrays::VarBinVTable; + use crate::arrays::VarBinViewVTable; use crate::compute::conformance::take::test_take_conformance; use 
crate::compute::take; use crate::validity::Validity; @@ -311,10 +316,10 @@ mod tests { let indices = buffer![0u32, 0u32, 0u32].into_array(); let taken = take(array.as_ref(), indices.as_ref()).unwrap(); - let taken_str = taken.as_::(); - assert_eq!(taken_str.len(), 3); - assert_eq!(taken_str.bytes_at(0).as_bytes(), scream.as_bytes()); - assert_eq!(taken_str.bytes_at(1).as_bytes(), scream.as_bytes()); - assert_eq!(taken_str.bytes_at(2).as_bytes(), scream.as_bytes()); + let taken_view = taken.as_::(); + assert_eq!(taken_view.len(), 3); + assert_eq!(taken_view.bytes_at(0).as_slice(), scream.as_bytes()); + assert_eq!(taken_view.bytes_at(1).as_slice(), scream.as_bytes()); + assert_eq!(taken_view.bytes_at(2).as_slice(), scream.as_bytes()); } } diff --git a/vortex-array/src/arrays/varbin/mod.rs b/vortex-array/src/arrays/varbin/mod.rs index 127f938fe2c..fd4806b3206 100644 --- a/vortex-array/src/arrays/varbin/mod.rs +++ b/vortex-array/src/arrays/varbin/mod.rs @@ -6,7 +6,6 @@ pub use array::VarBinArray; pub(crate) mod compute; pub(crate) use compute::varbin_compute_min_max; -// For use in `varbinview`. 
mod vtable; pub use vtable::VarBinVTable; diff --git a/vortex-array/src/arrays/varbin/vtable/kernel.rs b/vortex-array/src/arrays/varbin/vtable/kernel.rs index 09dd04b7557..6a94dffb2f8 100644 --- a/vortex-array/src/arrays/varbin/vtable/kernel.rs +++ b/vortex-array/src/arrays/varbin/vtable/kernel.rs @@ -1,9 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use crate::arrays::TakeExecuteAdaptor; use crate::arrays::VarBinVTable; use crate::arrays::filter::FilterExecuteAdaptor; use crate::kernel::ParentKernelSet; -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&FilterExecuteAdaptor(VarBinVTable))]); +pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&FilterExecuteAdaptor(VarBinVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(VarBinVTable)), +]); diff --git a/vortex-array/src/arrays/varbin/vtable/mod.rs b/vortex-array/src/arrays/varbin/vtable/mod.rs index 0e02032a441..1450e210564 100644 --- a/vortex-array/src/arrays/varbin/vtable/mod.rs +++ b/vortex-array/src/arrays/varbin/vtable/mod.rs @@ -32,6 +32,7 @@ mod validity; mod visitor; use canonical::varbin_to_canonical; +use kernel::PARENT_KERNELS; use vortex_session::VortexSession; use crate::arrays::varbin::compute::rules::PARENT_RULES; @@ -105,9 +106,9 @@ impl VTable for VarBinVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let bytes = buffers[0].clone(); + let bytes = buffers[0].clone().try_to_host_sync()?; - VarBinArray::try_new_from_handle(offsets, bytes, dtype.clone(), validity) + VarBinArray::try_new(offsets, bytes, dtype.clone(), validity) } fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { @@ -147,7 +148,7 @@ impl VTable for VarBinVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + 
PARENT_KERNELS.execute(array, parent, child_idx, ctx) } fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { diff --git a/vortex-array/src/arrays/varbinview/compute/take.rs b/vortex-array/src/arrays/varbinview/compute/take.rs index b937c904438..48dc08db32f 100644 --- a/vortex-array/src/arrays/varbinview/compute/take.rs +++ b/vortex-array/src/arrays/varbinview/compute/take.rs @@ -15,18 +15,20 @@ use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::ToCanonical; +use crate::arrays::TakeExecute; use crate::arrays::VarBinViewArray; use crate::arrays::VarBinViewVTable; use crate::buffer::BufferHandle; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::register_kernel; +use crate::executor::ExecutionCtx; use crate::vtable::ValidityHelper; -/// Take involves creating a new array that references the old array, just with the given set of views. -impl TakeKernel for VarBinViewVTable { - fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult { - // Compute the new validity. +impl TakeExecute for VarBinViewVTable { + /// Take involves creating a new array that references the old array, just with the given set of views. 
+ fn take( + array: &VarBinViewArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { let validity = array.validity().take(indices)?; let indices = indices.to_primitive(); @@ -37,21 +39,21 @@ impl TakeKernel for VarBinViewVTable { // SAFETY: taking all components at same indices maintains invariants unsafe { - Ok(VarBinViewArray::new_handle_unchecked( - BufferHandle::new_host(views_buffer.into_byte_buffer()), - array.buffers().clone(), - array - .dtype() - .union_nullability(indices.dtype().nullability()), - validity, - ) - .into_array()) + Ok(Some( + VarBinViewArray::new_handle_unchecked( + BufferHandle::new_host(views_buffer.into_byte_buffer()), + array.buffers().clone(), + array + .dtype() + .union_nullability(indices.dtype().nullability()), + validity, + ) + .into_array(), + )) } } } -register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift()); - fn take_views>( views_ref: &[BinaryView], indices: &[I], diff --git a/vortex-array/src/arrays/varbinview/vtable/kernel.rs b/vortex-array/src/arrays/varbinview/vtable/kernel.rs new file mode 100644 index 00000000000..4fc842eec54 --- /dev/null +++ b/vortex-array/src/arrays/varbinview/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::TakeExecuteAdaptor; +use crate::arrays::VarBinViewVTable; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(VarBinViewVTable))]); diff --git a/vortex-array/src/arrays/varbinview/vtable/mod.rs b/vortex-array/src/arrays/varbinview/vtable/mod.rs index 532d8a0ad7d..eddd0875b0f 100644 --- a/vortex-array/src/arrays/varbinview/vtable/mod.rs +++ b/vortex-array/src/arrays/varbinview/vtable/mod.rs @@ -3,6 +3,7 @@ use std::sync::Arc; +use kernel::PARENT_KERNELS; use vortex_buffer::Buffer; use vortex_buffer::ByteBuffer; use vortex_dtype::DType; @@ -27,6 +28,7 @@ use 
crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -130,6 +132,15 @@ impl VTable for VarBinViewVTable { PARENT_RULES.evaluate(array, parent, child_idx) } + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } + fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { Ok(array.to_array()) } diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 9c6047db699..b88b00413fa 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -98,7 +98,6 @@ pub fn warm_up_vtables() { nan_count::warm_up_vtable(); numeric::warm_up_vtable(); sum::warm_up_vtable(); - take::warm_up_vtable(); zip::warm_up_vtable(); } diff --git a/vortex-array/src/compute/take.rs b/vortex-array/src/compute/take.rs index 391c2924ce9..2a02298c326 100644 --- a/vortex-array/src/compute/take.rs +++ b/vortex-array/src/compute/take.rs @@ -1,45 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::LazyLock; - -use arcref::ArcRef; -use vortex_dtype::DType; -use vortex_error::VortexError; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_scalar::Scalar; use crate::Array; use crate::ArrayRef; -use crate::Canonical; use crate::IntoArray; -use crate::arrays::ConstantArray; -use crate::arrays::ConstantVTable; -use crate::compute::ComputeFn; -use crate::compute::ComputeFnVTable; -use crate::compute::InvocationArgs; -use crate::compute::Kernel; -use crate::compute::Output; -use crate::expr::stats::Precision; -use crate::expr::stats::Stat; -use crate::expr::stats::StatsProvider; -use crate::expr::stats::StatsProviderExt; -use crate::stats::StatsSet; -use crate::vtable::VTable; - -static 
TAKE_FN: LazyLock = LazyLock::new(|| { - let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take)); - for kernel in inventory::iter:: { - compute.register_kernel(kernel.0.clone()); - } - compute -}); - -pub(crate) fn warm_up_vtable() -> usize { - TAKE_FN.kernels().len() + TAKE_FROM_FN.kernels().len() -} /// Creates a new array using the elements from the input `array` indexed by `indices`. /// @@ -48,260 +14,9 @@ pub(crate) fn warm_up_vtable() -> usize { /// /// The output array will have the same length as the `indices` array. pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult { - if indices.is_empty() { - return Ok(Canonical::empty( - &array - .dtype() - .union_nullability(indices.dtype().nullability()), - ) - .into_array()); - } - - TAKE_FN - .invoke(&InvocationArgs { - inputs: &[array.into(), indices.into()], - options: &(), - })? - .unwrap_array() -} - -#[doc(hidden)] -pub struct Take; - -impl ComputeFnVTable for Take { - fn invoke( - &self, - args: &InvocationArgs, - kernels: &[ArcRef], - ) -> VortexResult { - let TakeArgs { array, indices } = TakeArgs::try_from(args)?; - - // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to - // the filter function since they're typically optimised for this case. - // TODO(ngates): if indices min is quite high, we could slice self and offset the indices - // such that canonicalize does less work. - - if indices.all_invalid()? { - return Ok(ConstantArray::new( - Scalar::null(array.dtype().as_nullable()), - indices.len(), - ) - .into_array() - .into()); - } - - let taken_array = take_impl(array, indices, kernels)?; - - // We know that constant array don't need stats propagation, so we can avoid the overhead of - // computing derived stats and merging them in. 
- if !taken_array.is::() { - propagate_take_stats(array, &taken_array, indices)?; - } - - Ok(taken_array.into()) - } - - fn return_dtype(&self, args: &InvocationArgs) -> VortexResult { - let TakeArgs { array, indices } = TakeArgs::try_from(args)?; - - if !indices.dtype().is_int() { - vortex_bail!( - "Take indices must be an integer type, got {}", - indices.dtype() - ); - } - - Ok(array - .dtype() - .union_nullability(indices.dtype().nullability())) - } - - fn return_len(&self, args: &InvocationArgs) -> VortexResult { - let TakeArgs { indices, .. } = TakeArgs::try_from(args)?; - Ok(indices.len()) - } - - fn is_elementwise(&self) -> bool { - false - } -} - -fn propagate_take_stats( - source: &dyn Array, - target: &dyn Array, - indices: &dyn Array, -) -> VortexResult<()> { - target.statistics().with_mut_typed_stats_set(|mut st| { - if indices.all_valid().unwrap_or(false) { - let is_constant = source.statistics().get_as::(Stat::IsConstant); - if is_constant == Some(Precision::Exact(true)) { - // Any combination of elements from a constant array is still const - st.set(Stat::IsConstant, Precision::exact(true)); - } - } - let inexact_min_max = [Stat::Min, Stat::Max] - .into_iter() - .filter_map(|stat| { - source - .statistics() - .get(stat) - .map(|v| (stat, v.map(|s| s.into_value()).into_inexact())) - }) - .collect::>(); - st.combine_sets( - &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()), - ) - }) -} - -fn take_impl( - array: &dyn Array, - indices: &dyn Array, - kernels: &[ArcRef], -) -> VortexResult { - let args = InvocationArgs { - inputs: &[array.into(), indices.into()], - options: &(), - }; - - // First look for a TakeFrom specialized on the indices. - for kernel in TAKE_FROM_FN.kernels() { - if let Some(output) = kernel.invoke(&args)? { - return output.unwrap_array(); - } - } - - // Then look for a Take kernel - for kernel in kernels { - if let Some(output) = kernel.invoke(&args)? 
{ - return output.unwrap_array(); - } - } - - // Otherwise, canonicalize and try again. - if !array.is_canonical() { - tracing::debug!("No take implementation found for {}", array.encoding_id()); - let canonical = array.to_canonical()?; - return take(canonical.as_ref(), indices); - } - - vortex_bail!("No take implementation found for {}", array.encoding_id()); -} - -struct TakeArgs<'a> { - array: &'a dyn Array, - indices: &'a dyn Array, -} - -impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> { - type Error = VortexError; - - fn try_from(value: &InvocationArgs<'a>) -> Result { - if value.inputs.len() != 2 { - vortex_bail!("Expected 2 inputs, found {}", value.inputs.len()); - } - let array = value.inputs[0] - .array() - .ok_or_else(|| vortex_err!("Expected first input to be an array"))?; - let indices = value.inputs[1] - .array() - .ok_or_else(|| vortex_err!("Expected second input to be an array"))?; - Ok(Self { array, indices }) - } -} - -pub trait TakeKernel: VTable { - /// Create a new array by taking the values from the `array` at the - /// given `indices`. - /// - /// # Panics - /// - /// Using `indices` that are invalid for the given `array` will cause a panic. - fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult; -} - -/// A kernel that implements the filter function. 
-pub struct TakeKernelRef(pub ArcRef); -inventory::collect!(TakeKernelRef); - -#[derive(Debug)] -pub struct TakeKernelAdapter(pub V); - -impl TakeKernelAdapter { - pub const fn lift(&'static self) -> TakeKernelRef { - TakeKernelRef(ArcRef::new_ref(self)) - } -} - -impl Kernel for TakeKernelAdapter { - fn invoke(&self, args: &InvocationArgs) -> VortexResult> { - let inputs = TakeArgs::try_from(args)?; - let Some(array) = inputs.array.as_opt::() else { - return Ok(None); - }; - Ok(Some(V::take(&self.0, array, inputs.indices)?.into())) - } -} - -static TAKE_FROM_FN: LazyLock = LazyLock::new(|| { - let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom)); - for kernel in inventory::iter:: { - compute.register_kernel(kernel.0.clone()); - } - compute -}); - -pub struct TakeFrom; - -impl ComputeFnVTable for TakeFrom { - fn invoke( - &self, - _args: &InvocationArgs, - _kernels: &[ArcRef], - ) -> VortexResult { - vortex_bail!( - "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function" - ) - } - - fn return_dtype(&self, args: &InvocationArgs) -> VortexResult { - Take.return_dtype(args) - } - - fn return_len(&self, args: &InvocationArgs) -> VortexResult { - Take.return_len(args) - } - - fn is_elementwise(&self) -> bool { - Take.is_elementwise() - } -} - -pub trait TakeFromKernel: VTable { - /// Create a new array by taking the values from the `array` at the - /// given `indices`. 
- fn take_from(&self, indices: &Self::Array, array: &dyn Array) - -> VortexResult>; -} - -pub struct TakeFromKernelRef(pub ArcRef); -inventory::collect!(TakeFromKernelRef); - -#[derive(Debug)] -pub struct TakeFromKernelAdapter(pub V); - -impl TakeFromKernelAdapter { - pub const fn lift(&'static self) -> TakeFromKernelRef { - TakeFromKernelRef(ArcRef::new_ref(self)) - } -} - -impl Kernel for TakeFromKernelAdapter { - fn invoke(&self, args: &InvocationArgs) -> VortexResult> { - let inputs = TakeArgs::try_from(args)?; - let Some(indices) = inputs.indices.as_opt::() else { - return Ok(None); - }; - Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from)) - } + // TODO(joe): inline usage and remove to_canonical(). + array + .take(indices.to_array())? + .to_canonical() + .map(|c| c.into_array()) } diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index 96778bceda2..06d35912074 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -1044,9 +1044,11 @@ where AllOr::None => true, AllOr::Some(buf) => !buf.value(idx_in_take), }; - if include_nulls && is_null { - new_sparse_indices.push(idx_in_take as u64); - value_indices.push(0); + if is_null { + if include_nulls { + new_sparse_indices.push(idx_in_take as u64); + value_indices.push(0); + } } else if ti >= min_index && ti <= max_index { let ti_as_i = I::try_from(ti) .map_err(|_| vortex_err!("take index does not fit in index type"))?; @@ -1185,10 +1187,13 @@ fn take_indices_with_search_fn< let mut new_indices = BufferMut::with_capacity(take_indices.len()); for (new_patch_idx, &take_idx) in take_indices.iter().enumerate() { - if include_nulls && !take_validity.value(new_patch_idx) { - // For nulls, patch index doesn't matter - use 0 for consistency - values_indices.push(0u64); - new_indices.push(new_patch_idx as u64); + if !take_validity.value(new_patch_idx) { + if include_nulls { + // For nulls, patch index doesn't matter - use 0 for consistency + 
values_indices.push(0u64); + new_indices.push(new_patch_idx as u64); + } + continue; } else { let search_result = match I::from(take_idx) { Some(idx) => search_fn(idx)?, diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 7dc0a0f7aaa..6dab74af78b 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -28,7 +28,6 @@ use crate::arrays::BoolArray; use crate::arrays::ConstantArray; use crate::compute::fill_null; use crate::compute::sum; -use crate::compute::take; use crate::patches::Patches; /// Validity information for an array @@ -173,7 +172,10 @@ impl Validity { }, Self::AllInvalid => Ok(Self::AllInvalid), Self::Array(is_valid) => { - let maybe_is_valid = take(is_valid, indices)?; + let maybe_is_valid = is_valid + .take(indices.to_array())? + .to_canonical()? + .into_array(); // Null indices invalidate that position. let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?; Ok(Self::Array(is_valid)) diff --git a/vortex-python/src/arrays/mod.rs b/vortex-python/src/arrays/mod.rs index 24344c652f3..d372d5a86a4 100644 --- a/vortex-python/src/arrays/mod.rs +++ b/vortex-python/src/arrays/mod.rs @@ -27,7 +27,6 @@ use vortex::array::ArrayRef; use vortex::array::ToCanonical; use vortex::array::arrays::ChunkedVTable; use vortex::array::arrow::IntoArrowArray; -use vortex::array::compute::take; use vortex::compute::Operator; use vortex::compute::compare; use vortex::dtype::DType; @@ -603,7 +602,7 @@ impl PyArray { /// >>> a = vx.array(['a', 'b', 'c', 'd']) /// >>> indices = vx.array([0, 2]) /// >>> a.take(indices).to_arrow_array() - /// + /// /// [ /// "a", /// "c" @@ -616,7 +615,7 @@ impl PyArray { /// >>> a = vx.array(['a', 'b', 'c', 'd']) /// >>> indices = vx.array([0, 1, 1, 0]) /// >>> a.take(indices).to_arrow_array() - /// + /// /// [ /// "a", /// "b", @@ -629,13 +628,13 @@ impl PyArray { if !indices.dtype().is_int() { return Err(PyValueError::new_err(format!( - "indices: expected int or uint array, but found: 
{}", + "indices: expected int or uint array, but found: {}", indices.dtype().python_repr() )) .into()); } - let inner = take(&slf, &*indices)?; + let inner = slf.take(indices.clone())?; Ok(PyArrayRef::from(inner)) }