From 89676d621a0f8671ffbb3819ac2cf1dff93fd160 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 5 Feb 2026 11:24:45 +0000 Subject: [PATCH 01/21] wip Signed-off-by: Joe Isaacs --- encodings/alp/src/alp/array.rs | 8 ++ encodings/alp/src/alp_rd/array.rs | 8 +- encodings/alp/src/alp_rd/mod.rs | 1 + encodings/alp/src/alp_rd/rules.rs | 10 ++ .../src/bitpacking/compute/filter.rs | 2 +- encodings/runend/src/compute/filter.rs | 50 +++++++ encodings/sequence/src/array.rs | 1 - .../src/arrays/chunked/vtable/kernel.rs | 9 ++ vortex-array/src/arrays/chunked/vtable/mod.rs | 6 +- .../src/arrays/list/compute/filter.rs | 122 +++++++++--------- vortex-array/src/arrays/varbin/vtable/mod.rs | 12 +- 11 files changed, 157 insertions(+), 72 deletions(-) create mode 100644 encodings/alp/src/alp_rd/rules.rs create mode 100644 vortex-array/src/arrays/chunked/vtable/kernel.rs diff --git a/encodings/alp/src/alp/array.rs b/encodings/alp/src/alp/array.rs index 95a9cc4166a..9299c52aa2d 100644 --- a/encodings/alp/src/alp/array.rs +++ b/encodings/alp/src/alp/array.rs @@ -168,6 +168,14 @@ impl VTable for ALPVTable { Ok(execute_decompress(array.clone(), ctx)?.into_array()) } + fn reduce_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + crate::alp::rules::PARENT_RULES.evaluate(array, parent, child_idx) + } + fn execute_parent( array: &Self::Array, parent: &ArrayRef, diff --git a/encodings/alp/src/alp_rd/array.rs b/encodings/alp/src/alp_rd/array.rs index 42a5e253716..70e885f4e17 100644 --- a/encodings/alp/src/alp_rd/array.rs +++ b/encodings/alp/src/alp_rd/array.rs @@ -11,9 +11,9 @@ use vortex_array::ArrayChildVisitor; use vortex_array::ArrayEq; use vortex_array::ArrayHash; use vortex_array::ArrayRef; +use vortex_array::Canonical; use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; @@ -28,6 +28,7 @@ use vortex_array::validity::Validity; use vortex_array::vtable; use vortex_array::vtable::ArrayId; use vortex_array::vtable::BaseArrayVTable; +use vortex_array::vtable::NotSupported; use vortex_array::vtable::VTable; use vortex_array::vtable::ValidityChild; use vortex_array::vtable::ValidityVTableFromChild; @@ -71,6 +72,7 @@ impl VTable for ALPRDVTable { type OperationsVTable = Self; type ValidityVTable = ValidityVTableFromChild; type VisitorVTable = Self; + type ComputeVTable = NotSupported; fn id(_array: &Self::Array) -> ArrayId { Self::ID @@ -216,7 +218,7 @@ impl VTable for ALPRDVTable { Ok(()) } - fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { let left_parts = array.left_parts().clone().execute::(ctx)?; let right_parts = array.right_parts().clone().execute::(ctx)?; @@ -255,7 +257,7 @@ impl VTable for ALPRDVTable { ) }; - Ok(decoded_array.into_array()) + Ok(Canonical::Primitive(decoded_array)) } fn execute_parent( diff --git a/encodings/alp/src/alp_rd/mod.rs b/encodings/alp/src/alp_rd/mod.rs index 8514b0d9576..77ccfe4872a 100644 --- a/encodings/alp/src/alp_rd/mod.rs +++ b/encodings/alp/src/alp_rd/mod.rs @@ -14,6 +14,7 @@ mod array; mod compute; mod kernel; mod ops; +mod rules; mod slice; use std::ops::Shl; diff --git a/encodings/alp/src/alp_rd/rules.rs b/encodings/alp/src/alp_rd/rules.rs new file mode 100644 index 00000000000..74909d4a329 --- /dev/null +++ b/encodings/alp/src/alp_rd/rules.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::FilterReduceAdaptor; +use vortex_array::optimizer::rules::ParentRuleSet; + +use crate::ALPRDVTable; + +pub(super) const PARENT_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&FilterReduceAdaptor(ALPRDVTable))]); diff --git a/encodings/fastlanes/src/bitpacking/compute/filter.rs b/encodings/fastlanes/src/bitpacking/compute/filter.rs index 8b6d05f53f9..e8d1385e4d8 100644 --- a/encodings/fastlanes/src/bitpacking/compute/filter.rs +++ b/encodings/fastlanes/src/bitpacking/compute/filter.rs @@ -28,7 +28,7 @@ use crate::BitPackedVTable; /// The threshold over which it is faster to fully unpack the entire [`BitPackedArray`] and then /// filter the result than to unpack only specific bitpacked values into the output buffer. -const fn unpack_then_filter_threshold(ptype: PType) -> f64 { +pub const fn unpack_then_filter_threshold(ptype: PType) -> f64 { // TODO(connor): Where did these numbers come from? Add a public link after validating them. // These numbers probably don't work for in-place filtering either. match ptype.byte_width() { diff --git a/encodings/runend/src/compute/filter.rs b/encodings/runend/src/compute/filter.rs index 269d64b97f4..e22232ee867 100644 --- a/encodings/runend/src/compute/filter.rs +++ b/encodings/runend/src/compute/filter.rs @@ -73,6 +73,31 @@ impl FilterKernel for RunEndVTable { } } +// We expose this function to our benchmarks. +pub fn filter_run_end(array: &RunEndArray, mask: &Mask) -> VortexResult { + let primitive_run_ends = array.ends().to_primitive(); + let (run_ends, values_mask) = + match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |P| { + filter_run_end_primitive( + primitive_run_ends.as_slice::

(), + array.offset() as u64, + array.len() as u64, + mask.values() + .vortex_expect("AllTrue and AllFalse handled by filter fn") + .bit_buffer(), + )? + }); + let values = array.values().filter(values_mask)?; + + // SAFETY: enforced by filter_run_end_primitive + unsafe { + Ok( + RunEndArray::new_unchecked(run_ends.into_array(), values, 0, mask.true_count()) + .into_array(), + ) + } +} + // Code adapted from apache arrow-rs https://github.com/apache/arrow-rs/blob/b1f5c250ebb6c1252b4e7c51d15b8e77f4c361fa/arrow-select/src/filter.rs#L425 fn filter_run_end_primitive + AsPrimitive>( run_ends: &[R], @@ -117,12 +142,15 @@ fn filter_run_end_primitive + AsPrimitiv mod tests { use vortex_array::Array; use vortex_array::IntoArray; + use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; use vortex_error::VortexResult; use vortex_mask::Mask; + use super::filter_run_end; use crate::RunEndArray; + use crate::RunEndVTable; fn ree_array() -> RunEndArray { RunEndArray::encode( @@ -131,6 +159,28 @@ mod tests { .unwrap() } + #[test] + fn run_end_filter() { + let arr = ree_array(); + let filtered = filter_run_end( + &arr, + &Mask::from_iter([ + true, true, false, false, false, false, false, false, false, false, true, true, + ]), + ) + .unwrap(); + let filtered_run_end = filtered.as_::(); + + assert_arrays_eq!( + filtered_run_end.ends().to_primitive(), + PrimitiveArray::from_iter([2u8, 4]) + ); + assert_arrays_eq!( + filtered_run_end.values().to_primitive(), + PrimitiveArray::from_iter([1i32, 5]) + ); + } + #[test] fn filter_sliced_run_end() -> VortexResult<()> { let arr = ree_array().slice(2..7).unwrap(); diff --git a/encodings/sequence/src/array.rs b/encodings/sequence/src/array.rs index 22e5e078fec..4ded8b71b42 100644 --- a/encodings/sequence/src/array.rs +++ b/encodings/sequence/src/array.rs @@ -9,7 +9,6 @@ use vortex_array::ArrayChildVisitor; use vortex_array::ArrayRef; use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; diff --git a/vortex-array/src/arrays/chunked/vtable/kernel.rs b/vortex-array/src/arrays/chunked/vtable/kernel.rs new file mode 100644 index 00000000000..534775434f2 --- /dev/null +++ b/vortex-array/src/arrays/chunked/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::ChunkedVTable; +use crate::arrays::filter::FilterExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&FilterExecuteAdaptor(ChunkedVTable))]); diff --git a/vortex-array/src/arrays/chunked/vtable/mod.rs b/vortex-array/src/arrays/chunked/vtable/mod.rs index eea3f954032..96b39afe5f5 100644 --- a/vortex-array/src/arrays/chunked/vtable/mod.rs +++ b/vortex-array/src/arrays/chunked/vtable/mod.rs @@ -31,6 +31,7 @@ use crate::vtable::VTable; mod array; mod canonical; +mod compute; mod operations; mod validity; mod visitor; @@ -53,6 +54,7 @@ impl VTable for ChunkedVTable { type OperationsVTable = Self; type ValidityVTable = Self; type VisitorVTable = Self; + type ComputeVTable = Self; fn id(_array: &Self::Array) -> ArrayId { Self::ID @@ -167,8 +169,8 @@ impl VTable for ChunkedVTable { Ok(()) } - fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(_canonicalize(array, ctx)?.into_array()) + fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + _canonicalize(array, ctx) } fn reduce(array: &Self::Array) -> VortexResult> { diff --git a/vortex-array/src/arrays/list/compute/filter.rs b/vortex-array/src/arrays/list/compute/filter.rs index 4650b37ee37..0393c65a817 100644 --- a/vortex-array/src/arrays/list/compute/filter.rs +++ b/vortex-array/src/arrays/list/compute/filter.rs @@ -14,7 +14,6 @@ use vortex_mask::Mask; use vortex_mask::MaskIter; use vortex_mask::MaskValues; -use crate::Array; use crate::ArrayRef; use crate::Canonical; use crate::ExecutionCtx; @@ -26,64 +25,6 @@ use crate::arrays::ListVTable; use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl FilterKernel for ListVTable { - fn filter( - array: &ListArray, - mask: &Mask, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - let selection = match mask { - Mask::AllTrue(_) | Mask::AllFalse(_) => return Ok(None), - Mask::Values(v) => v, - }; - - let new_validity = match array.validity() { - Validity::NonNullable => Validity::NonNullable, - Validity::AllValid => Validity::AllValid, - Validity::AllInvalid => { - let elements = Canonical::empty(array.element_dtype()).into_array(); - let offsets = ConstantArray::new(0u64, selection.true_count() + 1).into_array(); - return Ok(Some(unsafe { - ListArray::new_unchecked(elements, offsets, Validity::AllInvalid).into_array() - })); - } - Validity::Array(a) => Validity::Array(a.filter(mask.clone())?), - }; - - // TODO(ngates): for ultra-sparse masks, we don't need to optimize the entire offsets. - let offsets = array.offsets().clone(); - - let (new_offsets, element_mask) = - match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { - let offsets_buffer = offsets.execute::>(ctx)?; - let offsets = offsets_buffer.as_slice(); - let mut new_offsets = BufferMut::::with_capacity(selection.true_count() + 1); - - let mut offset = O::zero(); - unsafe { new_offsets.push_unchecked(offset) }; - for idx in selection.indices() { - let size = offsets[idx + 1] - offsets[*idx]; - offset += size; - unsafe { new_offsets.push_unchecked(offset) }; - } - - // TODO(ngates): for very dense masks, there may be no point in filtering the elements, - // and instead we should construct a view against the unfiltered elements. - let element_mask = element_mask_from_offsets::(offsets, selection); - - (new_offsets.freeze().into_array(), element_mask) - }); - - let new_elements = array.sliced_elements()?.filter(element_mask)?; - - // SAFETY: new_offsets are monotonically increasing starting from 0 with length - // true_count + 1, and the elements have been filtered to match. - Ok(Some(unsafe { - ListArray::new_unchecked(new_elements, new_offsets, new_validity).into_array() - })) - } -} - /// Density threshold for choosing between indices and slices representation when expanding masks. /// /// When the mask density is below this threshold, we use indices. Otherwise, we use slices. @@ -92,7 +33,10 @@ impl FilterKernel for ListVTable { const MASK_EXPANSION_DENSITY_THRESHOLD: f64 = 0.05; /// Construct an element mask from contiguous list offsets and a selection mask. -fn element_mask_from_offsets(offsets: &[O], selection: &Arc) -> Mask { +pub fn element_mask_from_offsets( + offsets: &[O], + selection: &Arc, +) -> Mask { let first_offset = offsets.first().map_or(0, |first_offset| first_offset.as_()); let last_offset = offsets.last().map_or(0, |last_offset| last_offset.as_()); let len = last_offset - first_offset; @@ -147,3 +91,61 @@ fn process_element_range( new_mask_builder.append_n(true, elems_len); } } + +impl FilterKernel for ListVTable { + fn filter( + array: &ListArray, + mask: &Mask, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let selection = match mask { + Mask::AllTrue(_) | Mask::AllFalse(_) => return Ok(None), + Mask::Values(v) => v, + }; + + let new_validity = match array.validity() { + Validity::NonNullable => Validity::NonNullable, + Validity::AllValid => Validity::AllValid, + Validity::AllInvalid => { + let elements = Canonical::empty(array.element_dtype()).into_array(); + let offsets = ConstantArray::new(0u64, selection.true_count() + 1).into_array(); + return Ok(Some(unsafe { + ListArray::new_unchecked(elements, offsets, Validity::AllInvalid).into_array() + })); + } + Validity::Array(a) => Validity::Array(a.filter(mask.clone())?), + }; + + // TODO(ngates): for ultra-sparse masks, we don't need to optimize the entire offsets. + let offsets = array.offsets().clone(); + + let (new_offsets, element_mask) = + match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { + let offsets_buffer = offsets.execute::>(ctx)?; + let offsets = offsets_buffer.as_slice(); + let mut new_offsets = BufferMut::::with_capacity(selection.true_count() + 1); + + let mut offset = O::zero(); + unsafe { new_offsets.push_unchecked(offset) }; + for idx in selection.indices() { + let size = offsets[idx + 1] - offsets[*idx]; + offset += size; + unsafe { new_offsets.push_unchecked(offset) }; + } + + // TODO(ngates): for very dense masks, there may be no point in filtering the elements, + // and instead we should construct a view against the unfiltered elements. + let element_mask = element_mask_from_offsets::(offsets, selection); + + (new_offsets.freeze().into_array(), element_mask) + }); + + let new_elements = array.sliced_elements()?.filter(element_mask)?; + + // SAFETY: new_offsets are monotonically increasing starting from 0 with length + // true_count + 1, and the elements have been filtered to match. + Ok(Some(unsafe { + ListArray::new_unchecked(new_elements, new_offsets, new_validity).into_array() + })) + } +} diff --git a/vortex-array/src/arrays/varbin/vtable/mod.rs b/vortex-array/src/arrays/varbin/vtable/mod.rs index eae82d990c4..776c961b36c 100644 --- a/vortex-array/src/arrays/varbin/vtable/mod.rs +++ b/vortex-array/src/arrays/varbin/vtable/mod.rs @@ -10,9 +10,9 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use crate::ArrayRef; +use crate::Canonical; use crate::DeserializeMetadata; use crate::ExecutionCtx; -use crate::IntoArray; use crate::ProstMetadata; use crate::SerializeMetadata; use crate::arrays::varbin::VarBinArray; @@ -21,6 +21,7 @@ use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; use crate::vtable::ArrayId; +use crate::vtable::NotSupported; use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; @@ -52,6 +53,7 @@ impl VTable for VarBinVTable { type OperationsVTable = Self; type ValidityVTable = ValidityVTableFromValidityHelper; type VisitorVTable = Self; + type ComputeVTable = NotSupported; fn id(_array: &Self::Array) -> ArrayId { Self::ID @@ -99,9 +101,9 @@ impl VTable for VarBinVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let bytes = buffers[0].clone(); + let bytes = buffers[0].clone().try_to_host_sync()?; - VarBinArray::try_new_from_handle(offsets, bytes, dtype.clone(), validity) + VarBinArray::try_new(offsets, bytes, dtype.clone(), validity) } fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { @@ -144,8 +146,8 @@ impl VTable for VarBinVTable { kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) } - fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(varbin_to_canonical(array, ctx)?.into_array()) + fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + varbin_to_canonical(array, ctx) } } From 2bf9f327d69f8087a0593c045ae6fa01a07dd721 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 5 Feb 2026 14:24:47 +0000 Subject: [PATCH 02/21] wip Signed-off-by: Joe Isaacs --- encodings/alp/src/alp/array.rs | 20 ++++++++------------ encodings/alp/src/alp_rd/mod.rs | 1 - encodings/alp/src/alp_rd/rules.rs | 10 ---------- 3 files changed, 8 insertions(+), 23 deletions(-) delete mode 100644 encodings/alp/src/alp_rd/rules.rs diff --git a/encodings/alp/src/alp/array.rs b/encodings/alp/src/alp/array.rs index 9299c52aa2d..89568f02ddd 100644 --- a/encodings/alp/src/alp/array.rs +++ b/encodings/alp/src/alp/array.rs @@ -10,9 +10,9 @@ use vortex_array::ArrayChildVisitor; use vortex_array::ArrayEq; use vortex_array::ArrayHash; use vortex_array::ArrayRef; +use vortex_array::Canonical; use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; @@ -25,6 +25,7 @@ use vortex_array::stats::StatsSetRef; use vortex_array::vtable; use vortex_array::vtable::ArrayId; use vortex_array::vtable::BaseArrayVTable; +use vortex_array::vtable::NotSupported; use vortex_array::vtable::VTable; use vortex_array::vtable::ValidityChild; use vortex_array::vtable::ValidityVTableFromChild; @@ -53,6 +54,7 @@ impl VTable for ALPVTable { type OperationsVTable = Self; type ValidityVTable = ValidityVTableFromChild; type VisitorVTable = Self; + type ComputeVTable = NotSupported; fn id(_array: &Self::Array) -> ArrayId { Self::ID @@ -163,17 +165,12 @@ impl VTable for ALPVTable { Ok(()) } - fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { // TODO(joe): take by value - Ok(execute_decompress(array.clone(), ctx)?.into_array()) - } - - fn reduce_parent( - array: &Self::Array, - parent: &ArrayRef, - child_idx: usize, - ) -> VortexResult> { - crate::alp::rules::PARENT_RULES.evaluate(array, parent, child_idx) + Ok(Canonical::Primitive(execute_decompress( + array.clone(), + ctx, + )?)) } fn execute_parent( @@ -460,7 +457,6 @@ mod tests { use std::sync::LazyLock; use rstest::rstest; - use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::VortexSessionExecute; diff --git a/encodings/alp/src/alp_rd/mod.rs b/encodings/alp/src/alp_rd/mod.rs index 77ccfe4872a..8514b0d9576 100644 --- a/encodings/alp/src/alp_rd/mod.rs +++ b/encodings/alp/src/alp_rd/mod.rs @@ -14,7 +14,6 @@ mod array; mod compute; mod kernel; mod ops; -mod rules; mod slice; use std::ops::Shl; diff --git a/encodings/alp/src/alp_rd/rules.rs b/encodings/alp/src/alp_rd/rules.rs deleted file mode 100644 index 74909d4a329..00000000000 --- a/encodings/alp/src/alp_rd/rules.rs +++ /dev/null @@ -1,10 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::arrays::FilterReduceAdaptor; -use vortex_array::optimizer::rules::ParentRuleSet; - -use crate::ALPRDVTable; - -pub(super) const PARENT_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&FilterReduceAdaptor(ALPRDVTable))]); From 6b408a876c2e27b9b314568b0122129c51632b8c Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 5 Feb 2026 14:45:30 +0000 Subject: [PATCH 03/21] wip Signed-off-by: Joe Isaacs --- encodings/fastlanes/src/bitpacking/compute/slice.rs | 1 - encodings/fastlanes/src/bitpacking/vtable/mod.rs | 8 +++++--- vortex-array/src/arrays/chunked/vtable/kernel.rs | 9 --------- 3 files changed, 5 insertions(+), 13 deletions(-) delete mode 100644 vortex-array/src/arrays/chunked/vtable/kernel.rs diff --git a/encodings/fastlanes/src/bitpacking/compute/slice.rs b/encodings/fastlanes/src/bitpacking/compute/slice.rs index 55cd12b3975..24f99ea17aa 100644 --- a/encodings/fastlanes/src/bitpacking/compute/slice.rs +++ b/encodings/fastlanes/src/bitpacking/compute/slice.rs @@ -6,7 +6,6 @@ use std::ops::Range; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; use vortex_array::arrays::SliceKernel; use vortex_error::VortexResult; diff --git a/encodings/fastlanes/src/bitpacking/vtable/mod.rs b/encodings/fastlanes/src/bitpacking/vtable/mod.rs index 1dfec786da7..3f42fee7f50 100644 --- a/encodings/fastlanes/src/bitpacking/vtable/mod.rs +++ b/encodings/fastlanes/src/bitpacking/vtable/mod.rs @@ -2,9 +2,9 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::ArrayRef; +use vortex_array::Canonical; use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; use vortex_array::buffer::BufferHandle; @@ -15,6 +15,7 @@ use vortex_array::serde::ArrayChildren; use vortex_array::validity::Validity; use vortex_array::vtable; use vortex_array::vtable::ArrayId; +use vortex_array::vtable::NotSupported; use vortex_array::vtable::VTable; use vortex_array::vtable::ValidityVTableFromValidityHelper; use vortex_dtype::DType; @@ -57,6 +58,7 @@ impl VTable for BitPackedVTable { type OperationsVTable = Self; type ValidityVTable = ValidityVTableFromValidityHelper; type VisitorVTable = Self; + type ComputeVTable = NotSupported; fn id(_array: &Self::Array) -> ArrayId { Self::ID @@ -244,8 +246,8 @@ impl VTable for BitPackedVTable { }) } - fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(unpack_array(array, ctx)?.into_array()) + fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + Ok(Canonical::Primitive(unpack_array(array, ctx)?)) } fn execute_parent( diff --git a/vortex-array/src/arrays/chunked/vtable/kernel.rs b/vortex-array/src/arrays/chunked/vtable/kernel.rs deleted file mode 100644 index 534775434f2..00000000000 --- a/vortex-array/src/arrays/chunked/vtable/kernel.rs +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use crate::arrays::ChunkedVTable; -use crate::arrays::filter::FilterExecuteAdaptor; -use crate::kernel::ParentKernelSet; - -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&FilterExecuteAdaptor(ChunkedVTable))]); From d99e5429440af64ee5bb3cd730aeb0766c372fb3 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 5 Feb 2026 15:05:23 +0000 Subject: [PATCH 04/21] wip Signed-off-by: Joe Isaacs --- encodings/fastlanes/src/bitpacking/compute/slice.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/encodings/fastlanes/src/bitpacking/compute/slice.rs b/encodings/fastlanes/src/bitpacking/compute/slice.rs index 24f99ea17aa..55cd12b3975 100644 --- a/encodings/fastlanes/src/bitpacking/compute/slice.rs +++ b/encodings/fastlanes/src/bitpacking/compute/slice.rs @@ -6,6 +6,7 @@ use std::ops::Range; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; use vortex_array::arrays::SliceKernel; use vortex_error::VortexResult; From 5993ca9d360257c176b9a633fe35ab039f1e7354 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 5 Feb 2026 10:44:20 +0000 Subject: [PATCH 05/21] wip Signed-off-by: Joe Isaacs --- encodings/alp/src/alp/compute/take.rs | 49 ++-- encodings/alp/src/alp_rd/compute/take.rs | 77 ++--- encodings/bytebool/src/compute.rs | 59 ++-- encodings/datetime-parts/src/compute/take.rs | 137 ++++----- .../src/decimal_byte_parts/compute/take.rs | 26 +- .../fastlanes/src/bitpacking/compute/take.rs | 58 ++-- encodings/fastlanes/src/for/compute/mod.rs | 29 +- encodings/fsst/src/compute/mod.rs | 58 ++-- encodings/runend/src/compute/take.rs | 66 +++-- encodings/sequence/src/compute/take.rs | 62 +++-- encodings/sparse/src/compute/take.rs | 77 ++--- encodings/zigzag/src/compute/mod.rs | 22 +- vortex-array/src/arrays/bool/compute/take.rs | 62 +++-- .../src/arrays/chunked/compute/take.rs | 124 +++++---- .../src/arrays/constant/compute/take.rs | 74 ++--- .../src/arrays/decimal/compute/take.rs | 56 ++-- vortex-array/src/arrays/dict/compute/mod.rs | 29 +- vortex-array/src/arrays/dict/execute.rs | 30 +- .../src/arrays/extension/compute/take.rs | 35 ++- .../arrays/fixed_size_list/compute/take.rs | 30 +- vortex-array/src/arrays/list/compute/take.rs | 44 +-- .../src/arrays/listview/compute/take.rs | 102 ++++--- .../src/arrays/masked/compute/take.rs | 47 ++-- vortex-array/src/arrays/mod.rs | 2 + vortex-array/src/arrays/null/compute/take.rs | 39 +-- .../src/arrays/primitive/compute/take/mod.rs | 50 ++-- .../src/arrays/struct_/compute/take.rs | 67 +++-- vortex-array/src/arrays/take/array.rs | 70 +++++ vortex-array/src/arrays/take/execute.rs | 59 ++++ vortex-array/src/arrays/take/kernel.rs | 105 +++++++ vortex-array/src/arrays/take/mod.rs | 19 ++ vortex-array/src/arrays/take/rules.rs | 83 ++++++ vortex-array/src/arrays/take/vtable.rs | 213 ++++++++++++++ .../src/arrays/varbin/compute/take.rs | 190 +++++++------ .../src/arrays/varbinview/compute/take.rs | 65 +++-- vortex-array/src/compute/take.rs | 263 +----------------- 36 files changed, 1562 insertions(+), 1016 deletions(-) create mode 100644 vortex-array/src/arrays/take/array.rs create mode 100644 vortex-array/src/arrays/take/execute.rs create mode 100644 vortex-array/src/arrays/take/kernel.rs create mode 100644 vortex-array/src/arrays/take/mod.rs create mode 100644 vortex-array/src/arrays/take/rules.rs create mode 100644 vortex-array/src/arrays/take/vtable.rs diff --git a/encodings/alp/src/alp/compute/take.rs b/encodings/alp/src/alp/compute/take.rs index 6965f0350a3..2a890177355 100644 --- a/encodings/alp/src/alp/compute/take.rs +++ b/encodings/alp/src/alp/compute/take.rs @@ -4,36 +4,43 @@ use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::IntoArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; +use vortex_array::arrays::TakeReduce; +use vortex_array::arrays::TakeReduceAdaptor; use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::optimizer::rules::ParentRuleSet; use vortex_error::VortexResult; use crate::ALPArray; use crate::ALPVTable; -impl TakeKernel for ALPVTable { - fn take(&self, array: &ALPArray, indices: &dyn Array) -> VortexResult { - let taken_encoded = take(array.encoded(), indices)?; - let taken_patches = array - .patches() - .map(|p| p.take(indices)) - .transpose()? - .flatten() - .map(|patches| { - patches.cast_values( - &array - .dtype() - .with_nullability(taken_encoded.dtype().nullability()), - ) - }) - .transpose()?; - Ok(ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array()) +fn take_alp(array: &ALPArray, indices: &dyn Array) -> VortexResult { + let taken_encoded = take(array.encoded(), indices)?; + let taken_patches = array + .patches() + .map(|p| p.take(indices)) + .transpose()? + .flatten() + .map(|patches| { + patches.cast_values( + &array + .dtype() + .with_nullability(taken_encoded.dtype().nullability()), + ) + }) + .transpose()?; + Ok(ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array()) +} + +impl TakeReduce for ALPVTable { + fn take(array: &ALPArray, indices: &dyn Array) -> VortexResult> { + take_alp(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(ALPVTable).lift()); +impl ALPVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} #[cfg(test)] mod test { diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index 1b43c0c423f..4983a9a9b2c 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -4,11 +4,11 @@ use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::IntoArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; +use vortex_array::arrays::TakeReduce; +use vortex_array::arrays::TakeReduceAdaptor; use vortex_array::compute::fill_null; use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::optimizer::rules::ParentRuleSet; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -16,42 +16,49 @@ use vortex_scalar::ScalarValue; use crate::ALPRDArray; use crate::ALPRDVTable; -impl TakeKernel for ALPRDVTable { - fn take(&self, array: &ALPRDArray, indices: &dyn Array) -> VortexResult { - let taken_left_parts = take(array.left_parts(), indices)?; - let left_parts_exceptions = array - .left_parts_patches() - .map(|patches| patches.take(indices)) - .transpose()? - .flatten() - .map(|p| { - let values_dtype = p - .values() - .dtype() - .with_nullability(taken_left_parts.dtype().nullability()); - p.cast_values(&values_dtype) - }) - .transpose()?; - let right_parts = fill_null( - &take(array.right_parts(), indices)?, - &Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)), - )?; - - Ok(ALPRDArray::try_new( - array +fn take_alprd(array: &ALPRDArray, indices: &dyn Array) -> VortexResult { + let taken_left_parts = take(array.left_parts(), indices)?; + let left_parts_exceptions = array + .left_parts_patches() + .map(|patches| patches.take(indices)) + .transpose()? + .flatten() + .map(|p| { + let values_dtype = p + .values() .dtype() - .with_nullability(taken_left_parts.dtype().nullability()), - taken_left_parts, - array.left_parts_dictionary().clone(), - right_parts, - array.right_bit_width(), - left_parts_exceptions, - )? - .into_array()) + .with_nullability(taken_left_parts.dtype().nullability()); + p.cast_values(&values_dtype) + }) + .transpose()?; + let right_parts = fill_null( + &take(array.right_parts(), indices)?, + &Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)), + )?; + + Ok(ALPRDArray::try_new( + array + .dtype() + .with_nullability(taken_left_parts.dtype().nullability()), + taken_left_parts, + array.left_parts_dictionary().clone(), + right_parts, + array.right_bit_width(), + left_parts_exceptions, + )? + .into_array()) +} + +impl TakeReduce for ALPRDVTable { + fn take(array: &ALPRDArray, indices: &dyn Array) -> VortexResult> { + take_alprd(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(ALPRDVTable).lift()); +impl ALPRDVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} #[cfg(test)] mod test { diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index 4bda0c343ef..f3ebb91f168 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -4,14 +4,16 @@ use num_traits::AsPrimitive; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::CastKernel; use vortex_array::compute::CastKernelAdapter; use vortex_array::compute::MaskKernel; use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; +use vortex_array::kernel::ParentKernelSet; use vortex_array::register_kernel; use vortex_array::vtable::ValidityHelper; use vortex_dtype::DType; @@ -55,30 +57,41 @@ impl MaskKernel for ByteBoolVTable { register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift()); -impl TakeKernel for ByteBoolVTable { - fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult { - let indices = indices.to_primitive(); - let bools = array.as_slice(); - - // This handles combining validity from both source array and nullable indices - let validity = array.validity().take(indices.as_ref())?; - - let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| { - indices - .as_slice::() - .iter() - .map(|&idx| { - let idx: usize = idx.as_(); - bools[idx] - }) - .collect::>() - }); - - Ok(ByteBoolArray::from_vec(taken_bools, validity).into_array()) +fn take_bytebool(array: &ByteBoolArray, indices: &dyn Array) -> VortexResult { + let indices = indices.to_primitive(); + let bools = array.as_slice(); + + // This handles combining validity from both source array and nullable indices + let validity = array.validity().take(indices.as_ref())?; + + let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| { + indices + .as_slice::() + .iter() + .map(|&idx| { + let idx: usize = idx.as_(); + bools[idx] + }) + .collect::>() + }); + + Ok(ByteBoolArray::from_vec(taken_bools, validity).into_array()) +} + +impl TakeExecute for ByteBoolVTable { + fn take( + array: &ByteBoolArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_bytebool(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift()); +impl ByteBoolVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} #[cfg(test)] mod tests { diff --git a/encodings/datetime-parts/src/compute/take.rs b/encodings/datetime-parts/src/compute/take.rs index 23ccc46af61..770d173d521 100644 --- a/encodings/datetime-parts/src/compute/take.rs +++ b/encodings/datetime-parts/src/compute/take.rs @@ -5,13 +5,13 @@ use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::ToCanonical; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; +use vortex_array::arrays::TakeReduce; +use vortex_array::arrays::TakeReduceAdaptor; use vortex_array::compute::fill_null; use vortex_array::compute::take; use vortex_array::expr::stats::Stat; use vortex_array::expr::stats::StatsProvider; -use vortex_array::register_kernel; +use vortex_array::optimizer::rules::ParentRuleSet; use vortex_dtype::Nullability; use vortex_error::VortexResult; use vortex_error::vortex_panic; @@ -20,80 +20,83 @@ use vortex_scalar::Scalar; use crate::DateTimePartsArray; use crate::DateTimePartsVTable; -impl TakeKernel for DateTimePartsVTable { - fn take(&self, array: &DateTimePartsArray, indices: &dyn Array) -> VortexResult { - // we go ahead and canonicalize here to avoid worst-case canonicalizing 3 separate times - let indices = indices.to_primitive(); +fn take_datetime_parts(array: &DateTimePartsArray, indices: &dyn Array) -> VortexResult { + // we go ahead and canonicalize here to avoid worst-case canonicalizing 3 separate times + let indices = indices.to_primitive(); - let taken_days = take(array.days(), indices.as_ref())?; - let taken_seconds = take(array.seconds(), indices.as_ref())?; - let taken_subseconds = take(array.subseconds(), indices.as_ref())?; + let taken_days = take(array.days(), indices.as_ref())?; + let taken_seconds = take(array.seconds(), indices.as_ref())?; + let taken_subseconds = take(array.subseconds(), indices.as_ref())?; - // Update the dtype if the nullability changed due to nullable indices - let dtype = if taken_days.dtype().is_nullable() != array.dtype().is_nullable() { - array - .dtype() - .with_nullability(taken_days.dtype().nullability()) - } else { - array.dtype().clone() - }; + // Update the dtype if the nullability changed due to nullable indices + let dtype = if taken_days.dtype().is_nullable() != array.dtype().is_nullable() { + array + .dtype() + .with_nullability(taken_days.dtype().nullability()) + } else { + array.dtype().clone() + }; - if !taken_seconds.dtype().is_nullable() && !taken_subseconds.dtype().is_nullable() { - return Ok(DateTimePartsArray::try_new( - dtype, - taken_days, - taken_seconds, - taken_subseconds, - )? - .into_array()); - } + if !taken_seconds.dtype().is_nullable() && !taken_subseconds.dtype().is_nullable() { + return Ok(DateTimePartsArray::try_new( + dtype, + taken_days, + taken_seconds, + taken_subseconds, + )? + .into_array()); + } + + // DateTimePartsArray requires seconds and subseconds to be non-nullable. + // If they became nullable due to nullable indices, we need to fill nulls. + // But first, we need to check that the types are consistent. + if !taken_days.dtype().is_nullable() { + vortex_panic!("Mismatched types: days is not nullable, seconds is nullable"); + } + if !taken_seconds.dtype().is_nullable() { + vortex_panic!("Mismatched types: seconds is not nullable, days is nullable"); + } + if !taken_subseconds.dtype().is_nullable() { + vortex_panic!("Mismatched types: subseconds is not nullable, days & seconds are nullable"); + } + if !indices.dtype().is_nullable() { + vortex_panic!("Mismatched types: indices are not nullable, days & seconds are nullable"); + } - // DateTimePartsArray requires seconds and subseconds to be non-nullable. - // If they became nullable due to nullable indices, we need to fill nulls. - // But first, we need to check that the types are consistent. - if !taken_days.dtype().is_nullable() { - vortex_panic!("Mismatched types: days is not nullable, seconds is nullable"); - } - if !taken_seconds.dtype().is_nullable() { - vortex_panic!("Mismatched types: seconds is not nullable, days is nullable"); - } - if !taken_subseconds.dtype().is_nullable() { - vortex_panic!( - "Mismatched types: subseconds is not nullable, days & seconds are nullable" - ); - } - if !indices.dtype().is_nullable() { - vortex_panic!( - "Mismatched types: indices are not nullable, days & seconds are nullable" - ); - } + let seconds_fill = array + .seconds() + .statistics() + .get(Stat::Min) + .map(|s| s.into_inner()) + .unwrap_or_else(|| Scalar::primitive(0i64, Nullability::NonNullable)) + .cast(array.seconds().dtype())?; + let taken_seconds = fill_null(taken_seconds.as_ref(), &seconds_fill)?; - let seconds_fill = array - .seconds() - .statistics() - .get(Stat::Min) - .map(|s| s.into_inner()) - .unwrap_or_else(|| Scalar::primitive(0i64, Nullability::NonNullable)) - .cast(array.seconds().dtype())?; - let taken_seconds = fill_null(taken_seconds.as_ref(), &seconds_fill)?; + let subseconds_fill = array + .subseconds() + .statistics() + .get(Stat::Min) + .map(|s| s.into_inner()) + .unwrap_or_else(|| Scalar::primitive(0i64, Nullability::NonNullable)) + .cast(array.subseconds().dtype())?; + let taken_subseconds = fill_null(taken_subseconds.as_ref(), &subseconds_fill)?; - let subseconds_fill = array - .subseconds() - .statistics() - .get(Stat::Min) - .map(|s| s.into_inner()) - .unwrap_or_else(|| Scalar::primitive(0i64, Nullability::NonNullable)) - .cast(array.subseconds().dtype())?; - let taken_subseconds = fill_null(taken_subseconds.as_ref(), &subseconds_fill)?; + Ok( + DateTimePartsArray::try_new(dtype, taken_days, taken_seconds, taken_subseconds)? + .into_array(), + ) +} - Ok( - DateTimePartsArray::try_new(dtype, taken_days, taken_seconds, taken_subseconds)? - .into_array(), - ) +impl TakeReduce for DateTimePartsVTable { + fn take(array: &DateTimePartsArray, indices: &dyn Array) -> VortexResult> { + take_datetime_parts(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(DateTimePartsVTable).lift()); +impl DateTimePartsVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} #[cfg(test)] mod tests { diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs index 0368d3c1f7f..407734d5c24 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs @@ -3,20 +3,30 @@ use vortex_array::Array; use vortex_array::ArrayRef; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; +use vortex_array::arrays::TakeReduce; +use vortex_array::arrays::TakeReduceAdaptor; use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::optimizer::rules::ParentRuleSet; use vortex_error::VortexResult; use crate::DecimalBytePartsArray; use crate::DecimalBytePartsVTable; -impl TakeKernel for DecimalBytePartsVTable { - fn take(&self, array: &DecimalBytePartsArray, indices: &dyn Array) -> VortexResult { - DecimalBytePartsArray::try_new(take(&array.msp, indices)?, *array.decimal_dtype()) - .map(|a| a.to_array()) +fn take_decimal_byte_parts( + array: &DecimalBytePartsArray, + indices: &dyn Array, +) -> VortexResult { + DecimalBytePartsArray::try_new(take(&array.msp, indices)?, *array.decimal_dtype()) + .map(|a| a.to_array()) +} + +impl TakeReduce for DecimalBytePartsVTable { + fn take(array: &DecimalBytePartsArray, indices: &dyn Array) -> VortexResult> { + take_decimal_byte_parts(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(DecimalBytePartsVTable).lift()); +impl DecimalBytePartsVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} diff --git a/encodings/fastlanes/src/bitpacking/compute/take.rs b/encodings/fastlanes/src/bitpacking/compute/take.rs index 65dd4032cd6..e2b8bf87b6f 100644 --- a/encodings/fastlanes/src/bitpacking/compute/take.rs +++ b/encodings/fastlanes/src/bitpacking/compute/take.rs @@ -7,13 +7,14 @@ use std::mem::MaybeUninit; use fastlanes::BitPacking; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::kernel::ParentKernelSet; use vortex_array::validity::Validity; use vortex_array::vtable::ValidityHelper; use vortex_buffer::Buffer; @@ -37,30 +38,41 @@ use crate::bitpack_decompress; /// see https://github.com/vortex-data/vortex/pull/190#issue-2223752833 pub(super) const UNPACK_CHUNK_THRESHOLD: usize = 8; -impl TakeKernel for BitPackedVTable { - fn take(&self, array: &BitPackedArray, indices: &dyn Array) -> VortexResult { - // If the indices are large enough, it's faster to flatten and take the primitive array. - if indices.len() * UNPACK_CHUNK_THRESHOLD > array.len() { - return take(array.to_primitive().as_ref(), indices); - } +fn take_bitpacked(array: &BitPackedArray, indices: &dyn Array) -> VortexResult { + // If the indices are large enough, it's faster to flatten and take the primitive array. + if indices.len() * UNPACK_CHUNK_THRESHOLD > array.len() { + return take(array.to_primitive().as_ref(), indices); + } + + // NOTE: we use the unsigned PType because all values in the BitPackedArray must + // be non-negative (pre-condition of creating the BitPackedArray). + let ptype: PType = PType::try_from(array.dtype())?; + let validity = array.validity(); + let taken_validity = validity.take(indices)?; + + let indices = indices.to_primitive(); + let taken = match_each_unsigned_integer_ptype!(ptype.to_unsigned(), |T| { + match_each_integer_ptype!(indices.ptype(), |I| { + take_primitive::(array, &indices, taken_validity)? + }) + }); + Ok(taken.reinterpret_cast(ptype).into_array()) +} - // NOTE: we use the unsigned PType because all values in the BitPackedArray must - // be non-negative (pre-condition of creating the BitPackedArray). - let ptype: PType = PType::try_from(array.dtype())?; - let validity = array.validity(); - let taken_validity = validity.take(indices)?; - - let indices = indices.to_primitive(); - let taken = match_each_unsigned_integer_ptype!(ptype.to_unsigned(), |T| { - match_each_integer_ptype!(indices.ptype(), |I| { - take_primitive::(array, &indices, taken_validity)? - }) - }); - Ok(taken.reinterpret_cast(ptype).into_array()) +impl TakeExecute for BitPackedVTable { + fn take( + array: &BitPackedArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_bitpacked(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(BitPackedVTable).lift()); +impl BitPackedVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} fn take_primitive( array: &BitPackedArray, diff --git a/encodings/fastlanes/src/for/compute/mod.rs b/encodings/fastlanes/src/for/compute/mod.rs index f7bbe4a6c64..b2b4a590c2c 100644 --- a/encodings/fastlanes/src/for/compute/mod.rs +++ b/encodings/fastlanes/src/for/compute/mod.rs @@ -10,27 +10,34 @@ use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; +use vortex_array::arrays::TakeReduce; +use vortex_array::arrays::TakeReduceAdaptor; use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::optimizer::rules::ParentRuleSet; use vortex_error::VortexResult; use vortex_mask::Mask; use crate::FoRArray; use crate::FoRVTable; -impl TakeKernel for FoRVTable { - fn take(&self, array: &FoRArray, indices: &dyn Array) -> VortexResult { - FoRArray::try_new( - take(array.encoded(), indices)?, - array.reference_scalar().clone(), - ) - .map(|a| a.into_array()) +fn take_for(array: &FoRArray, indices: &dyn Array) -> VortexResult { + FoRArray::try_new( + take(array.encoded(), indices)?, + array.reference_scalar().clone(), + ) + .map(|a| a.into_array()) +} + +impl TakeReduce for FoRVTable { + fn take(array: &FoRArray, indices: &dyn Array) -> VortexResult> { + take_for(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(FoRVTable).lift()); +impl FoRVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} impl FilterReduce for FoRVTable { fn filter(array: &FoRArray, mask: &Mask) -> VortexResult> { diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 4dde1cc19c0..8146851f942 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -8,12 +8,12 @@ mod filter; use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::IntoArray; +use vortex_array::arrays::TakeReduce; +use vortex_array::arrays::TakeReduceAdaptor; use vortex_array::arrays::VarBinVTable; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; use vortex_array::compute::fill_null; use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::optimizer::rules::ParentRuleSet; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -21,32 +21,38 @@ use vortex_scalar::ScalarValue; use crate::FSSTArray; use crate::FSSTVTable; -impl TakeKernel for FSSTVTable { - // Take on an FSSTArray is a simple take on the codes array. - fn take(&self, array: &FSSTArray, indices: &dyn Array) -> VortexResult { - Ok(FSSTArray::try_new( - array - .dtype() - .clone() - .union_nullability(indices.dtype().nullability()), - array.symbols().clone(), - array.symbol_lengths().clone(), - take(array.codes().as_ref(), indices)? - .as_::() - .clone(), - fill_null( - &take(array.uncompressed_lengths(), indices)?, - &Scalar::new( - array.uncompressed_lengths_dtype().clone(), - ScalarValue::from(0), - ), - )?, - )? - .into_array()) +fn take_fsst(array: &FSSTArray, indices: &dyn Array) -> VortexResult { + Ok(FSSTArray::try_new( + array + .dtype() + .clone() + .union_nullability(indices.dtype().nullability()), + array.symbols().clone(), + array.symbol_lengths().clone(), + take(array.codes().as_ref(), indices)? + .as_::() + .clone(), + fill_null( + &take(array.uncompressed_lengths(), indices)?, + &Scalar::new( + array.uncompressed_lengths_dtype().clone(), + ScalarValue::from(0), + ), + )?, + )? + .into_array()) +} + +impl TakeReduce for FSSTVTable { + fn take(array: &FSSTArray, indices: &dyn Array) -> VortexResult> { + take_fsst(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(FSSTVTable).lift()); +impl FSSTVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} #[cfg(test)] mod tests { diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index 2b6a2915bb3..d77da4b1a8d 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -5,12 +5,13 @@ use num_traits::AsPrimitive; use num_traits::NumCast; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::kernel::ParentKernelSet; use vortex_array::search_sorted::SearchResult; use vortex_array::search_sorted::SearchSorted; use vortex_array::search_sorted::SearchSortedSide; @@ -24,34 +25,45 @@ use vortex_error::vortex_bail; use crate::RunEndArray; use crate::RunEndVTable; -impl TakeKernel for RunEndVTable { - #[expect( - clippy::cast_possible_truncation, - reason = "index cast to usize inside macro" - )] - fn take(&self, array: &RunEndArray, indices: &dyn Array) -> VortexResult { - let primitive_indices = indices.to_primitive(); - - let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| { - primitive_indices - .as_slice::

() - .iter() - .copied() - .map(|idx| { - let usize_idx = idx as usize; - if usize_idx >= array.len() { - vortex_bail!(OutOfBounds: usize_idx, 0, array.len()); - } - Ok(usize_idx) - }) - .collect::>>()? - }); +#[expect( + clippy::cast_possible_truncation, + reason = "index cast to usize inside macro" +)] +fn take_runend(array: &RunEndArray, indices: &dyn Array) -> VortexResult { + let primitive_indices = indices.to_primitive(); + + let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| { + primitive_indices + .as_slice::

() + .iter() + .copied() + .map(|idx| { + let usize_idx = idx as usize; + if usize_idx >= array.len() { + vortex_bail!(OutOfBounds: usize_idx, 0, array.len()); + } + Ok(usize_idx) + }) + .collect::>>()? + }); - take_indices_unchecked(array, &checked_indices, primitive_indices.validity()) + take_indices_unchecked(array, &checked_indices, primitive_indices.validity()) +} + +impl TakeExecute for RunEndVTable { + fn take( + array: &RunEndArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_runend(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(RunEndVTable).lift()); +impl RunEndVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} /// Perform a take operation on a RunEndArray by binary searching for each of the indices. pub fn take_indices_unchecked>( diff --git a/encodings/sequence/src/compute/take.rs b/encodings/sequence/src/compute/take.rs index eec5191d301..dfe7cecd282 100644 --- a/encodings/sequence/src/compute/take.rs +++ b/encodings/sequence/src/compute/take.rs @@ -4,13 +4,14 @@ use num_traits::cast::NumCast; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; -use vortex_array::register_kernel; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_dtype::DType; @@ -29,31 +30,29 @@ use vortex_scalar::Scalar; use crate::SequenceArray; use crate::SequenceVTable; -impl TakeKernel for SequenceVTable { - fn take(&self, array: &SequenceArray, indices: &dyn Array) -> VortexResult { - let mask = indices.validity_mask()?; - let indices = indices.to_primitive(); - let result_nullability = array.dtype().nullability() | indices.dtype().nullability(); +fn take_sequence(array: &SequenceArray, indices: &dyn Array) -> VortexResult { + let mask = indices.validity_mask()?; + let indices = indices.to_primitive(); + let result_nullability = array.dtype().nullability() | indices.dtype().nullability(); - match_each_integer_ptype!(indices.ptype(), |T| { - let indices = indices.as_slice::(); - match_each_native_ptype!(array.ptype(), |S| { - let mul = array.multiplier().cast::(); - let base = array.base().cast::(); - Ok(take( - mul, - base, - indices, - mask, - result_nullability, - array.len(), - )) - }) + match_each_integer_ptype!(indices.ptype(), |T| { + let indices = indices.as_slice::(); + match_each_native_ptype!(array.ptype(), |S| { + let mul = array.multiplier().cast::(); + let base = array.base().cast::(); + Ok(take_inner( + mul, + base, + indices, + mask, + result_nullability, + array.len(), + )) }) - } + }) } -fn take( +fn take_inner( mul: S, base: S, indices: &[T], @@ -98,7 +97,20 @@ fn take( } } -register_kernel!(TakeKernelAdapter(SequenceVTable).lift()); +impl TakeExecute for SequenceVTable { + fn take( + array: &SequenceArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_sequence(array, indices).map(Some) + } +} + +impl SequenceVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} #[cfg(test)] mod test { diff --git a/encodings/sparse/src/compute/take.rs b/encodings/sparse/src/compute/take.rs index 41e4c5f0eed..c63bedc4915 100644 --- a/encodings/sparse/src/compute/take.rs +++ b/encodings/sparse/src/compute/take.rs @@ -5,49 +5,56 @@ use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; -use vortex_array::register_kernel; +use vortex_array::arrays::TakeReduce; +use vortex_array::arrays::TakeReduceAdaptor; +use vortex_array::optimizer::rules::ParentRuleSet; use vortex_error::VortexResult; use crate::SparseArray; use crate::SparseVTable; -impl TakeKernel for SparseVTable { - fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult { - let patches_take = if array.fill_scalar().is_null() { - array.patches().take(take_indices)? - } else { - array.patches().take_with_nulls(take_indices)? - }; - - let Some(new_patches) = patches_take else { - let result_fill_scalar = array.fill_scalar().cast( - &array - .dtype() - .union_nullability(take_indices.dtype().nullability()), - )?; - return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array()); - }; - - // See `SparseEncoding::slice`. - if new_patches.array_len() == new_patches.values().len() { - return Ok(new_patches.into_values()); - } - - Ok(SparseArray::try_new_from_patches( - new_patches, - array.fill_scalar().cast( - &array - .dtype() - .union_nullability(take_indices.dtype().nullability()), - )?, - )? - .into_array()) +fn take_sparse(array: &SparseArray, take_indices: &dyn Array) -> VortexResult { + let patches_take = if array.fill_scalar().is_null() { + array.patches().take(take_indices)? + } else { + array.patches().take_with_nulls(take_indices)? + }; + + let Some(new_patches) = patches_take else { + let result_fill_scalar = array.fill_scalar().cast( + &array + .dtype() + .union_nullability(take_indices.dtype().nullability()), + )?; + return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array()); + }; + + // See `SparseEncoding::slice`. + if new_patches.array_len() == new_patches.values().len() { + return Ok(new_patches.into_values()); } + + Ok(SparseArray::try_new_from_patches( + new_patches, + array.fill_scalar().cast( + &array + .dtype() + .union_nullability(take_indices.dtype().nullability()), + )?, + )? + .into_array()) } -register_kernel!(TakeKernelAdapter(SparseVTable).lift()); +impl TakeReduce for SparseVTable { + fn take(array: &SparseArray, indices: &dyn Array) -> VortexResult> { + take_sparse(array, indices).map(Some) + } +} + +impl SparseVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} #[cfg(test)] mod test { diff --git a/encodings/zigzag/src/compute/mod.rs b/encodings/zigzag/src/compute/mod.rs index 7d97d94edaa..40a4cbb05e5 100644 --- a/encodings/zigzag/src/compute/mod.rs +++ b/encodings/zigzag/src/compute/mod.rs @@ -7,12 +7,13 @@ use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; +use vortex_array::arrays::TakeReduce; +use vortex_array::arrays::TakeReduceAdaptor; use vortex_array::compute::MaskKernel; use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::TakeKernel; -use vortex_array::compute::TakeKernelAdapter; use vortex_array::compute::mask; use vortex_array::compute::take; +use vortex_array::optimizer::rules::ParentRuleSet; use vortex_array::register_kernel; use vortex_error::VortexResult; use vortex_mask::Mask; @@ -27,14 +28,21 @@ impl FilterReduce for ZigZagVTable { } } -impl TakeKernel for ZigZagVTable { - fn take(&self, array: &ZigZagArray, indices: &dyn Array) -> VortexResult { - let encoded = take(array.encoded(), indices)?; - Ok(ZigZagArray::try_new(encoded)?.into_array()) +fn take_zigzag(array: &ZigZagArray, indices: &dyn Array) -> VortexResult { + let encoded = take(array.encoded(), indices)?; + Ok(ZigZagArray::try_new(encoded)?.into_array()) +} + +impl TakeReduce for ZigZagVTable { + fn take(array: &ZigZagArray, indices: &dyn Array) -> VortexResult> { + take_zigzag(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(ZigZagVTable).lift()); +impl ZigZagVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} impl MaskKernel for ZigZagVTable { fn mask(&self, array: &ZigZagArray, filter_mask: &Mask) -> VortexResult { diff --git a/vortex-array/src/arrays/bool/compute/take.rs b/vortex-array/src/arrays/bool/compute/take.rs index def1490cb69..57c2084f3dc 100644 --- a/vortex-array/src/arrays/bool/compute/take.rs +++ b/vortex-array/src/arrays/bool/compute/take.rs @@ -17,35 +17,47 @@ use crate::ToCanonical; use crate::arrays::BoolArray; use crate::arrays::BoolVTable; use crate::arrays::ConstantArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::compute::fill_null; -use crate::register_kernel; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; -impl TakeKernel for BoolVTable { - fn take(&self, array: &BoolArray, indices: &dyn Array) -> VortexResult { - let indices_nulls_zeroed = match indices.validity_mask()? { - Mask::AllTrue(_) => indices.to_array(), - Mask::AllFalse(_) => { - return Ok(ConstantArray::new( - Scalar::null(array.dtype().as_nullable()), - indices.len(), - ) - .into_array()); - } - Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?, - }; - let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive(); - let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| { - take_valid_indices(&array.to_bit_buffer(), indices_nulls_zeroed.as_slice::()) - }); - - Ok(BoolArray::new(buffer, array.validity().take(indices)?).to_array()) +fn take_bool(array: &BoolArray, indices: &dyn Array) -> VortexResult { + let indices_nulls_zeroed = match indices.validity_mask()? { + Mask::AllTrue(_) => indices.to_array(), + Mask::AllFalse(_) => { + return Ok(ConstantArray::new( + Scalar::null(array.dtype().as_nullable()), + indices.len(), + ) + .into_array()); + } + Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?, + }; + let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive(); + let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| { + take_valid_indices(&array.to_bit_buffer(), indices_nulls_zeroed.as_slice::()) + }); + + Ok(BoolArray::new(buffer, array.validity().take(indices)?).to_array()) +} + +impl TakeExecute for BoolVTable { + fn take( + array: &BoolArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_bool(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(BoolVTable).lift()); +impl BoolVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} fn take_valid_indices>(bools: &BitBuffer, indices: &[I]) -> BitBuffer { // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth @@ -54,7 +66,7 @@ fn take_valid_indices>(bools: &BitBuffer, indices: &[I]) - let bools = bools.iter().collect_vec(); take_byte_bool(bools, indices) } else { - take_bool(bools, indices) + take_bool_impl(bools, indices) } } @@ -64,7 +76,7 @@ fn take_byte_bool>(bools: Vec, indices: &[I]) -> Bit }) } -fn take_bool>(bools: &BitBuffer, indices: &[I]) -> BitBuffer { +fn take_bool_impl>(bools: &BitBuffer, indices: &[I]) -> BitBuffer { // We dereference to underlying buffer to avoid access cost on every index. let buffer = bools.inner().as_ref(); BitBuffer::collect_bool(indices.len(), |idx| { diff --git a/vortex-array/src/arrays/chunked/compute/take.rs b/vortex-array/src/arrays/chunked/compute/take.rs index 5facc27e503..47f9288f004 100644 --- a/vortex-array/src/arrays/chunked/compute/take.rs +++ b/vortex-array/src/arrays/chunked/compute/take.rs @@ -12,79 +12,91 @@ use crate::IntoArray; use crate::ToCanonical; use crate::arrays::ChunkedVTable; use crate::arrays::PrimitiveArray; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::arrays::chunked::ChunkedArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; use crate::compute::cast; use crate::compute::take; -use crate::register_kernel; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; use crate::validity::Validity; -impl TakeKernel for ChunkedVTable { - fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult { - let indices = cast( - indices, - &DType::Primitive(PType::U64, indices.dtype().nullability()), - )? - .to_primitive(); - - // TODO(joe): Should we split this implementation based on indices nullability? - let nullability = indices.dtype().nullability(); - let indices_mask = indices.validity_mask()?; - let indices = indices.as_slice::(); - - let mut chunks = Vec::new(); - let mut indices_in_chunk = BufferMut::::empty(); - let mut start = 0; - let mut stop = 0; - // We assume indices are non-empty as it's handled in the top-level `take` function - let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?)?.0; - for idx in indices { - let idx = usize::try_from(*idx)?; - let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx)?; - - if chunk_idx != prev_chunk_idx { - // Start a new chunk - let indices_in_chunk_array = PrimitiveArray::new( - indices_in_chunk.clone().freeze(), - Validity::from_mask(indices_mask.slice(start..stop), nullability), - ); - chunks.push(take( - array.chunk(prev_chunk_idx), - indices_in_chunk_array.as_ref(), - )?); - indices_in_chunk.clear(); - start = stop; - } - - indices_in_chunk.push(idx_in_chunk as u64); - stop += 1; - prev_chunk_idx = chunk_idx; - } - - if !indices_in_chunk.is_empty() { +fn take_chunked(array: &ChunkedArray, indices: &dyn Array) -> VortexResult { + let indices = cast( + indices, + &DType::Primitive(PType::U64, indices.dtype().nullability()), + )? + .to_primitive(); + + // TODO(joe): Should we split this implementation based on indices nullability? + let nullability = indices.dtype().nullability(); + let indices_mask = indices.validity_mask()?; + let indices = indices.as_slice::(); + + let mut chunks = Vec::new(); + let mut indices_in_chunk = BufferMut::::empty(); + let mut start = 0; + let mut stop = 0; + // We assume indices are non-empty as it's handled in the top-level `take` function + let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?)?.0; + for idx in indices { + let idx = usize::try_from(*idx)?; + let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx)?; + + if chunk_idx != prev_chunk_idx { + // Start a new chunk let indices_in_chunk_array = PrimitiveArray::new( - indices_in_chunk.freeze(), + indices_in_chunk.clone().freeze(), Validity::from_mask(indices_mask.slice(start..stop), nullability), ); chunks.push(take( array.chunk(prev_chunk_idx), indices_in_chunk_array.as_ref(), )?); + indices_in_chunk.clear(); + start = stop; } - // SAFETY: take on chunks that all have same DType retains same DType - unsafe { - Ok(ChunkedArray::new_unchecked( - chunks, - array.dtype().clone().union_nullability(nullability), - ) - .into_array()) - } + indices_in_chunk.push(idx_in_chunk as u64); + stop += 1; + prev_chunk_idx = chunk_idx; + } + + if !indices_in_chunk.is_empty() { + let indices_in_chunk_array = PrimitiveArray::new( + indices_in_chunk.freeze(), + Validity::from_mask(indices_mask.slice(start..stop), nullability), + ); + chunks.push(take( + array.chunk(prev_chunk_idx), + indices_in_chunk_array.as_ref(), + )?); + } + + // SAFETY: take on chunks that all have same DType retains same DType + unsafe { + Ok(ChunkedArray::new_unchecked( + chunks, + array.dtype().clone().union_nullability(nullability), + ) + .into_array()) + } +} + +impl TakeExecute for ChunkedVTable { + fn take( + array: &ChunkedArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_chunked(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(ChunkedVTable).lift()); +impl ChunkedVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} #[cfg(test)] mod test { diff --git a/vortex-array/src/arrays/constant/compute/take.rs b/vortex-array/src/arrays/constant/compute/take.rs index 24b60dd4fb9..352f153ff5e 100644 --- a/vortex-array/src/arrays/constant/compute/take.rs +++ b/vortex-array/src/arrays/constant/compute/take.rs @@ -11,47 +11,55 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::arrays::ConstantVTable; use crate::arrays::MaskedArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::register_kernel; +use crate::arrays::TakeReduce; +use crate::arrays::TakeReduceAdaptor; +use crate::optimizer::rules::ParentRuleSet; use crate::validity::Validity; -impl TakeKernel for ConstantVTable { - fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult { - match indices.validity_mask()?.bit_buffer() { - AllOr::All => { - let scalar = Scalar::new( - array - .scalar() - .dtype() - .union_nullability(indices.dtype().nullability()), - array.scalar().value().clone(), - ); - Ok(ConstantArray::new(scalar, indices.len()).into_array()) - } - AllOr::None => Ok(ConstantArray::new( - Scalar::null( - array - .dtype() - .union_nullability(indices.dtype().nullability()), - ), - indices.len(), - ) - .into_array()), - AllOr::Some(v) => { - let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array(); - - if array.scalar().is_null() { - return Ok(arr); - } +fn take_constant(array: &ConstantArray, indices: &dyn Array) -> VortexResult { + let result = match indices.validity_mask()?.bit_buffer() { + AllOr::All => { + let scalar = Scalar::new( + array + .scalar() + .dtype() + .union_nullability(indices.dtype().nullability()), + array.scalar().value().clone(), + ); + ConstantArray::new(scalar, indices.len()).into_array() + } + AllOr::None => ConstantArray::new( + Scalar::null( + array + .dtype() + .union_nullability(indices.dtype().nullability()), + ), + indices.len(), + ) + .into_array(), + AllOr::Some(v) => { + let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array(); - Ok(MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array()) + if array.scalar().is_null() { + return Ok(arr); } + + MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array() } + }; + Ok(result) +} + +impl TakeReduce for ConstantVTable { + fn take(array: &ConstantArray, indices: &dyn Array) -> VortexResult> { + take_constant(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(ConstantVTable).lift()); +impl ConstantVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} #[cfg(test)] mod tests { diff --git a/vortex-array/src/arrays/decimal/compute/take.rs b/vortex-array/src/arrays/decimal/compute/take.rs index ff9f5097a7e..18298946cb0 100644 --- a/vortex-array/src/arrays/decimal/compute/take.rs +++ b/vortex-array/src/arrays/decimal/compute/take.rs @@ -13,33 +13,45 @@ use crate::ArrayRef; use crate::ToCanonical; use crate::arrays::DecimalArray; use crate::arrays::DecimalVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::register_kernel; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; -impl TakeKernel for DecimalVTable { - fn take(&self, array: &DecimalArray, indices: &dyn Array) -> VortexResult { - let indices = indices.to_primitive(); - let validity = array.validity().take(indices.as_ref())?; - - // TODO(joe): if the true count of take indices validity is low, only take array values with - // valid indices. - let decimal = match_each_decimal_value_type!(array.values_type(), |D| { - match_each_integer_ptype!(indices.ptype(), |I| { - let buffer = - take_to_buffer::(indices.as_slice::(), array.buffer::().as_slice()); - // SAFETY: Take operation preserves decimal dtype and creates valid buffer. - // Validity is computed correctly from the parent array and indices. - unsafe { DecimalArray::new_unchecked(buffer, array.decimal_dtype(), validity) } - }) - }); - - Ok(decimal.to_array()) +fn take_decimal(array: &DecimalArray, indices: &dyn Array) -> VortexResult { + let indices = indices.to_primitive(); + let validity = array.validity().take(indices.as_ref())?; + + // TODO(joe): if the true count of take indices validity is low, only take array values with + // valid indices. + let decimal = match_each_decimal_value_type!(array.values_type(), |D| { + match_each_integer_ptype!(indices.ptype(), |I| { + let buffer = + take_to_buffer::(indices.as_slice::(), array.buffer::().as_slice()); + // SAFETY: Take operation preserves decimal dtype and creates valid buffer. + // Validity is computed correctly from the parent array and indices. + unsafe { DecimalArray::new_unchecked(buffer, array.decimal_dtype(), validity) } + }) + }); + + Ok(decimal.to_array()) +} + +impl TakeExecute for DecimalVTable { + fn take( + array: &DecimalArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_decimal(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(DecimalVTable).lift()); +impl DecimalVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} #[inline] fn take_to_buffer(indices: &[I], values: &[T]) -> Buffer { diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index f3dce73932b..0dbc99e38e0 100644 --- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -17,25 +17,32 @@ use vortex_mask::Mask; use super::DictArray; use super::DictVTable; +use super::TakeReduce; +use super::TakeReduceAdaptor; use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::filter::FilterReduce; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; use crate::compute::take; -use crate::register_kernel; - -impl TakeKernel for DictVTable { - fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult { - let codes = take(array.codes(), indices)?; - // SAFETY: selecting codes doesn't change the invariants of DictArray - // Preserve all_values_referenced since taking codes doesn't affect which values are referenced - Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()).into_array() }) +use crate::optimizer::rules::ParentRuleSet; + +fn take_dict(array: &DictArray, indices: &dyn Array) -> VortexResult { + let codes = take(array.codes(), indices)?; + // SAFETY: selecting codes doesn't change the invariants of DictArray + // Preserve all_values_referenced since taking codes doesn't affect which values are referenced + Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()).into_array() }) +} + +impl TakeReduce for DictVTable { + fn take(array: &DictArray, indices: &dyn Array) -> VortexResult> { + take_dict(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(DictVTable).lift()); +impl DictVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} impl FilterReduce for DictVTable { fn filter(array: &DictArray, mask: &Mask) -> VortexResult> { diff --git a/vortex-array/src/arrays/dict/execute.rs b/vortex-array/src/arrays/dict/execute.rs index d3ffa4a1b6c..ae136cc99ac 100644 --- a/vortex-array/src/arrays/dict/execute.rs +++ b/vortex-array/src/arrays/dict/execute.rs @@ -24,7 +24,7 @@ use crate::arrays::StructArray; use crate::arrays::StructVTable; use crate::arrays::VarBinViewArray; use crate::arrays::VarBinViewVTable; -use crate::compute::TakeKernel; +use crate::compute::take; /// TODO: replace usage of compute fn. /// Take from a canonical array using indices (codes), returning a new canonical array. @@ -45,9 +45,8 @@ pub fn take_canonical(values: Canonical, codes: &PrimitiveArray) -> VortexResult }) } -fn take_null(_array: &NullArray, codes: &PrimitiveArray) -> NullArray { - NullVTable - .take(_array, codes.as_ref()) +fn take_null(array: &NullArray, codes: &PrimitiveArray) -> NullArray { + take(array.as_ref(), codes.as_ref()) .vortex_expect("take null array") .as_::() .clone() @@ -138,63 +137,54 @@ fn take_null(_array: &NullArray, codes: &PrimitiveArray) -> NullArray { // TODO(joe): use dict_bool_take fn take_bool(array: &BoolArray, codes: &PrimitiveArray) -> VortexResult { - Ok(BoolVTable - .take(array, codes.as_ref())? + Ok(take(array.as_ref(), codes.as_ref())? .as_::() .clone()) } fn take_primitive(array: &PrimitiveArray, codes: &PrimitiveArray) -> PrimitiveArray { - PrimitiveVTable - .take(array, codes.as_ref()) + take(array.as_ref(), codes.as_ref()) .vortex_expect("take primitive array") .as_::() .clone() } fn take_decimal(array: &DecimalArray, codes: &PrimitiveArray) -> DecimalArray { - DecimalVTable - .take(array, codes.as_ref()) + take(array.as_ref(), codes.as_ref()) .vortex_expect("take decimal array") .as_::() .clone() } fn take_varbinview(array: &VarBinViewArray, codes: &PrimitiveArray) -> VarBinViewArray { - VarBinViewVTable - .take(array, codes.as_ref()) + take(array.as_ref(), codes.as_ref()) .vortex_expect("take varbinview array") .as_::() .clone() } fn take_listview(array: &ListViewArray, codes: &PrimitiveArray) -> ListViewArray { - ListViewVTable - .take(array, codes.as_ref()) + take(array.as_ref(), codes.as_ref()) .vortex_expect("take listview array") .as_::() .clone() } fn take_fixed_size_list(array: &FixedSizeListArray, codes: &PrimitiveArray) -> FixedSizeListArray { - FixedSizeListVTable - .take(array, codes.as_ref()) + take(array.as_ref(), codes.as_ref()) .vortex_expect("take fixed size list array") .as_::() .clone() } fn take_struct(array: &StructArray, codes: &PrimitiveArray) -> StructArray { - StructVTable - .take(array, codes.as_ref()) + take(array.as_ref(), codes.as_ref()) .vortex_expect("take struct array") .as_::() .clone() } fn take_extension(array: &ExtensionArray, codes: &PrimitiveArray) -> ExtensionArray { - use crate::compute::take; - let taken_storage = take(array.storage(), codes.as_ref()).vortex_expect("take extension storage"); ExtensionArray::new(array.ext_dtype().clone(), taken_storage) diff --git a/vortex-array/src/arrays/extension/compute/take.rs b/vortex-array/src/arrays/extension/compute/take.rs index e66886508c6..dbe4555e61d 100644 --- a/vortex-array/src/arrays/extension/compute/take.rs +++ b/vortex-array/src/arrays/extension/compute/take.rs @@ -8,22 +8,29 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ExtensionArray; use crate::arrays::ExtensionVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; +use crate::arrays::TakeReduce; +use crate::arrays::TakeReduceAdaptor; use crate::compute::{self}; -use crate::register_kernel; +use crate::optimizer::rules::ParentRuleSet; -impl TakeKernel for ExtensionVTable { - fn take(&self, array: &ExtensionArray, indices: &dyn Array) -> VortexResult { - let taken_storage = compute::take(array.storage(), indices)?; - Ok(ExtensionArray::new( - array - .ext_dtype() - .with_nullability(taken_storage.dtype().nullability()), - taken_storage, - ) - .into_array()) +fn take_extension(array: &ExtensionArray, indices: &dyn Array) -> VortexResult { + let taken_storage = compute::take(array.storage(), indices)?; + Ok(ExtensionArray::new( + array + .ext_dtype() + .with_nullability(taken_storage.dtype().nullability()), + taken_storage, + ) + .into_array()) +} + +impl TakeReduce for ExtensionVTable { + fn take(array: &ExtensionArray, indices: &dyn Array) -> VortexResult> { + take_extension(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(ExtensionVTable).lift()); +impl ExtensionVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} diff --git a/vortex-array/src/arrays/fixed_size_list/compute/take.rs b/vortex-array/src/arrays/fixed_size_list/compute/take.rs index f4ae9ad9b21..632c4e9850f 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/take.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/take.rs @@ -16,10 +16,11 @@ use crate::ToCanonical; use crate::arrays::FixedSizeListArray; use crate::arrays::FixedSizeListVTable; use crate::arrays::PrimitiveArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::compute::{self}; -use crate::register_kernel; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -28,15 +29,26 @@ use crate::vtable::ValidityHelper; /// Unlike `ListView`, `FixedSizeListArray` must rebuild the elements array because it requires /// that elements start at offset 0 and be perfectly packed without gaps. We expand list indices /// into element indices and push them down to the child elements array. -impl TakeKernel for FixedSizeListVTable { - fn take(&self, array: &FixedSizeListArray, indices: &dyn Array) -> VortexResult { - match_each_integer_ptype!(indices.dtype().as_ptype(), |I| { - take_with_indices::(array, indices) - }) +fn take_fixed_size_list(array: &FixedSizeListArray, indices: &dyn Array) -> VortexResult { + match_each_integer_ptype!(indices.dtype().as_ptype(), |I| { + take_with_indices::(array, indices) + }) +} + +impl TakeExecute for FixedSizeListVTable { + fn take( + array: &FixedSizeListArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_fixed_size_list(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(FixedSizeListVTable).lift()); +impl FixedSizeListVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} /// Dispatches to the appropriate take implementation based on list size and nullability. fn take_with_indices( diff --git a/vortex-array/src/arrays/list/compute/take.rs b/vortex-array/src/arrays/list/compute/take.rs index 5bc6c309671..19d64cd5c5d 100644 --- a/vortex-array/src/arrays/list/compute/take.rs +++ b/vortex-array/src/arrays/list/compute/take.rs @@ -14,12 +14,13 @@ use crate::ToCanonical; use crate::arrays::ListArray; use crate::arrays::ListVTable; use crate::arrays::PrimitiveArray; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::builders::ArrayBuilder; use crate::builders::PrimitiveBuilder; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; use crate::compute::take; -use crate::register_kernel; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; // TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call @@ -30,24 +31,35 @@ use crate::vtable::ValidityHelper; /// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant /// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking /// non-contiguous indices would violate this requirement. -impl TakeKernel for ListVTable { - #[expect(clippy::cognitive_complexity)] - fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult { - let indices = indices.to_primitive(); - // This is an over-approximation of the total number of elements in the resulting array. - let total_approx = array.elements().len().saturating_mul(indices.len()); - - match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| { - match_each_integer_ptype!(indices.ptype(), |I| { - match_smallest_offset_type!(total_approx, |OutputOffsetType| { - _take::(array, &indices) - }) +#[expect(clippy::cognitive_complexity)] +fn take_list(array: &ListArray, indices: &dyn Array) -> VortexResult { + let indices = indices.to_primitive(); + // This is an over-approximation of the total number of elements in the resulting array. + let total_approx = array.elements().len().saturating_mul(indices.len()); + + match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| { + match_each_integer_ptype!(indices.ptype(), |I| { + match_smallest_offset_type!(total_approx, |OutputOffsetType| { + _take::(array, &indices) }) }) + }) +} + +impl TakeExecute for ListVTable { + fn take( + array: &ListArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_list(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(ListVTable).lift()); +impl ListVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} fn _take( array: &ListArray, diff --git a/vortex-array/src/arrays/listview/compute/take.rs b/vortex-array/src/arrays/listview/compute/take.rs index ac1d75169f5..699768eef0b 100644 --- a/vortex-array/src/arrays/listview/compute/take.rs +++ b/vortex-array/src/arrays/listview/compute/take.rs @@ -13,10 +13,11 @@ use crate::IntoArray; use crate::arrays::ListViewArray; use crate::arrays::ListViewRebuildMode; use crate::arrays::ListViewVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::compute::{self}; -use crate::register_kernel; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; // TODO(connor)[ListView]: Make use of this threshold after we start migrating operators. @@ -43,53 +44,64 @@ const REBUILD_DENSITY_THRESHOLD: f64 = 0.1; /// /// The trade-off is that we may keep unreferenced elements in memory, but this is acceptable since /// we're optimizing for read performance and the data isn't being copied. -impl TakeKernel for ListViewVTable { - fn take(&self, array: &ListViewArray, indices: &dyn Array) -> VortexResult { - let elements = array.elements(); - let offsets = array.offsets(); - let sizes = array.sizes(); +fn take_listview(array: &ListViewArray, indices: &dyn Array) -> VortexResult { + let elements = array.elements(); + let offsets = array.offsets(); + let sizes = array.sizes(); - // Compute the new validity by combining the array's validity with the indices' validity. - let new_validity = array.validity().take(indices)?; + // Compute the new validity by combining the array's validity with the indices' validity. + let new_validity = array.validity().take(indices)?; - // Take the offsets and sizes arrays at the requested indices. - // Take can reorder offsets, create gaps, and may introduce overlaps if the `indices` - // contain duplicates. - let nullable_new_offsets = compute::take(offsets.as_ref(), indices)?; - let nullable_new_sizes = compute::take(sizes.as_ref(), indices)?; + // Take the offsets and sizes arrays at the requested indices. + // Take can reorder offsets, create gaps, and may introduce overlaps if the `indices` + // contain duplicates. + let nullable_new_offsets = compute::take(offsets.as_ref(), indices)?; + let nullable_new_sizes = compute::take(sizes.as_ref(), indices)?; - // Since `take` returns nullable arrays, we simply cast it back to non-nullable (filled with - // zeros to represent null lists). - let new_offsets = match_each_integer_ptype!(nullable_new_offsets.dtype().as_ptype(), |O| { - compute::fill_null( - &nullable_new_offsets, - &Scalar::primitive(O::zero(), Nullability::NonNullable), - )? - }); - let new_sizes = match_each_integer_ptype!(nullable_new_sizes.dtype().as_ptype(), |S| { - compute::fill_null( - &nullable_new_sizes, - &Scalar::primitive(S::zero(), Nullability::NonNullable), - )? - }); - // SAFETY: Take operation maintains all `ListViewArray` invariants: - // - `new_offsets` and `new_sizes` are derived from existing valid child arrays. - // - `new_offsets` and `new_sizes` are non-nullable. - // - `new_offsets` and `new_sizes` have the same length (both taken with the same - // `indices`). - // - Validity correctly reflects the combination of array and indices validity. - let new_array = unsafe { - ListViewArray::new_unchecked(elements.clone(), new_offsets, new_sizes, new_validity) - }; + // Since `take` returns nullable arrays, we simply cast it back to non-nullable (filled with + // zeros to represent null lists). + let new_offsets = match_each_integer_ptype!(nullable_new_offsets.dtype().as_ptype(), |O| { + compute::fill_null( + &nullable_new_offsets, + &Scalar::primitive(O::zero(), Nullability::NonNullable), + )? + }); + let new_sizes = match_each_integer_ptype!(nullable_new_sizes.dtype().as_ptype(), |S| { + compute::fill_null( + &nullable_new_sizes, + &Scalar::primitive(S::zero(), Nullability::NonNullable), + )? + }); + // SAFETY: Take operation maintains all `ListViewArray` invariants: + // - `new_offsets` and `new_sizes` are derived from existing valid child arrays. + // - `new_offsets` and `new_sizes` are non-nullable. + // - `new_offsets` and `new_sizes` have the same length (both taken with the same + // `indices`). + // - Validity correctly reflects the combination of array and indices validity. + let new_array = unsafe { + ListViewArray::new_unchecked(elements.clone(), new_offsets, new_sizes, new_validity) + }; - // TODO(connor)[ListView]: Ideally, we would only rebuild after all `take`s and `filter` - // compute functions have run, at the "top" of the operator tree. However, we cannot do this - // right now, so we will just rebuild every time (similar to `ListArray`). + // TODO(connor)[ListView]: Ideally, we would only rebuild after all `take`s and `filter` + // compute functions have run, at the "top" of the operator tree. However, we cannot do this + // right now, so we will just rebuild every time (similar to `ListArray`). - Ok(new_array - .rebuild(ListViewRebuildMode::MakeZeroCopyToList)? - .into_array()) + Ok(new_array + .rebuild(ListViewRebuildMode::MakeZeroCopyToList)? + .into_array()) +} + +impl TakeExecute for ListViewVTable { + fn take( + array: &ListViewArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_listview(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(ListViewVTable).lift()); +impl ListViewVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} diff --git a/vortex-array/src/arrays/masked/compute/take.rs b/vortex-array/src/arrays/masked/compute/take.rs index 945a1a047ed..10ce1609046 100644 --- a/vortex-array/src/arrays/masked/compute/take.rs +++ b/vortex-array/src/arrays/masked/compute/take.rs @@ -9,35 +9,42 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::MaskedArray; use crate::arrays::MaskedVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; +use crate::arrays::TakeReduce; +use crate::arrays::TakeReduceAdaptor; use crate::compute::fill_null; use crate::compute::take; -use crate::register_kernel; +use crate::optimizer::rules::ParentRuleSet; use crate::vtable::ValidityHelper; -impl TakeKernel for MaskedVTable { - fn take(&self, array: &MaskedArray, indices: &dyn Array) -> VortexResult { - let taken_child = if !indices.all_valid()? { - // This is safe because we'll mask out these positions in the validity - let filled_take = fill_null( - indices, - &Scalar::default_value(indices.dtype().clone().as_nonnullable()), - )?; - take(&array.child, &filled_take)? - } else { - take(&array.child, indices)? - }; +fn take_masked(array: &MaskedArray, indices: &dyn Array) -> VortexResult { + let taken_child = if !indices.all_valid()? { + // This is safe because we'll mask out these positions in the validity + let filled_take = fill_null( + indices, + &Scalar::default_value(indices.dtype().clone().as_nonnullable()), + )?; + take(&array.child, &filled_take)? + } else { + take(&array.child, indices)? + }; - // Compute the new validity by taking from array's validity and merging with indices validity - let taken_validity = array.validity().take(indices)?; + // Compute the new validity by taking from array's validity and merging with indices validity + let taken_validity = array.validity().take(indices)?; - // Construct new MaskedArray - Ok(MaskedArray::try_new(taken_child, taken_validity)?.into_array()) + // Construct new MaskedArray + Ok(MaskedArray::try_new(taken_child, taken_validity)?.into_array()) +} + +impl TakeReduce for MaskedVTable { + fn take(array: &MaskedArray, indices: &dyn Array) -> VortexResult> { + take_masked(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(MaskedVTable).lift()); +impl MaskedVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} #[cfg(test)] mod tests { diff --git a/vortex-array/src/arrays/mod.rs b/vortex-array/src/arrays/mod.rs index 0d5fa96c258..d6914891a6e 100644 --- a/vortex-array/src/arrays/mod.rs +++ b/vortex-array/src/arrays/mod.rs @@ -33,6 +33,7 @@ mod scalar_fn; mod shared; mod slice; mod struct_; +mod take; mod varbin; mod varbinview; @@ -59,5 +60,6 @@ pub use scalar_fn::*; pub use shared::*; pub use slice::*; pub use struct_::*; +pub use take::*; pub use varbin::*; pub use varbinview::*; diff --git a/vortex-array/src/arrays/null/compute/take.rs b/vortex-array/src/arrays/null/compute/take.rs index 4d3d595bd7d..fc9f24adab7 100644 --- a/vortex-array/src/arrays/null/compute/take.rs +++ b/vortex-array/src/arrays/null/compute/take.rs @@ -11,26 +11,33 @@ use crate::IntoArray; use crate::ToCanonical; use crate::arrays::NullArray; use crate::arrays::NullVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::register_kernel; +use crate::arrays::TakeReduce; +use crate::arrays::TakeReduceAdaptor; +use crate::optimizer::rules::ParentRuleSet; -impl TakeKernel for NullVTable { - #[allow(clippy::cast_possible_truncation)] - fn take(&self, array: &NullArray, indices: &dyn Array) -> VortexResult { - let indices = indices.to_primitive(); +#[allow(clippy::cast_possible_truncation)] +fn take_null(array: &NullArray, indices: &dyn Array) -> VortexResult { + let indices = indices.to_primitive(); - // Enforce all indices are valid - match_each_integer_ptype!(indices.ptype(), |T| { - for index in indices.as_slice::() { - if (*index as usize) >= array.len() { - vortex_bail!(OutOfBounds: *index as usize, 0, array.len()); - } + // Enforce all indices are valid + match_each_integer_ptype!(indices.ptype(), |T| { + for index in indices.as_slice::() { + if (*index as usize) >= array.len() { + vortex_bail!(OutOfBounds: *index as usize, 0, array.len()); } - }); + } + }); - Ok(NullArray::new(indices.len()).into_array()) + Ok(NullArray::new(indices.len()).into_array()) +} + +impl TakeReduce for NullVTable { + fn take(array: &NullArray, indices: &dyn Array) -> VortexResult> { + take_null(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(NullVTable).lift()); +impl NullVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} diff --git a/vortex-array/src/arrays/primitive/compute/take/mod.rs b/vortex-array/src/arrays/primitive/compute/take/mod.rs index fd2ddb63cf3..e82055363c0 100644 --- a/vortex-array/src/arrays/primitive/compute/take/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/take/mod.rs @@ -23,11 +23,12 @@ use crate::ArrayRef; use crate::IntoArray; use crate::ToCanonical; use crate::arrays::PrimitiveVTable; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::arrays::primitive::PrimitiveArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; use crate::compute::cast; -use crate::register_kernel; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -81,26 +82,37 @@ impl TakeImpl for TakeKernelScalar { } } -impl TakeKernel for PrimitiveVTable { - fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult { - let DType::Primitive(ptype, null) = indices.dtype() else { - vortex_bail!("Invalid indices dtype: {}", indices.dtype()) - }; - - let unsigned_indices = if ptype.is_unsigned_int() { - indices.to_primitive() - } else { - // This will fail if all values cannot be converted to unsigned - cast(indices, &DType::Primitive(ptype.to_unsigned(), *null))?.to_primitive() - }; +fn take_primitive(array: &PrimitiveArray, indices: &dyn Array) -> VortexResult { + let DType::Primitive(ptype, null) = indices.dtype() else { + vortex_bail!("Invalid indices dtype: {}", indices.dtype()) + }; + + let unsigned_indices = if ptype.is_unsigned_int() { + indices.to_primitive() + } else { + // This will fail if all values cannot be converted to unsigned + cast(indices, &DType::Primitive(ptype.to_unsigned(), *null))?.to_primitive() + }; + + let validity = array.validity().take(unsigned_indices.as_ref())?; + // Delegate to the best kernel based on the target CPU + PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity) +} - let validity = array.validity().take(unsigned_indices.as_ref())?; - // Delegate to the best kernel based on the target CPU - PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity) +impl TakeExecute for PrimitiveVTable { + fn take( + array: &PrimitiveArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_primitive(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift()); +impl PrimitiveVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} // Compiler may see this as unused based on enabled features #[allow(unused)] diff --git a/vortex-array/src/arrays/struct_/compute/take.rs b/vortex-array/src/arrays/struct_/compute/take.rs index 39bb432d1dd..3f6fcfddd2e 100644 --- a/vortex-array/src/arrays/struct_/compute/take.rs +++ b/vortex-array/src/arrays/struct_/compute/take.rs @@ -10,43 +10,50 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::StructArray; use crate::arrays::StructVTable; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; +use crate::arrays::TakeReduce; +use crate::arrays::TakeReduceAdaptor; use crate::compute::{self}; -use crate::register_kernel; +use crate::optimizer::rules::ParentRuleSet; use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl TakeKernel for StructVTable { - fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult { - // If the struct array is empty then the indices must be all null, otherwise it will access - // an out of bounds element - if array.is_empty() { - return StructArray::try_new_with_dtype( - array.unmasked_fields().clone(), - array.struct_fields().clone(), - indices.len(), - Validity::AllInvalid, - ) - .map(StructArray::into_array); - } - // The validity is applied to the struct validity, - let inner_indices = &compute::fill_null( - indices, - &Scalar::default_value(indices.dtype().with_nullability(Nullability::NonNullable)), - )?; - StructArray::try_new_with_dtype( - array - .unmasked_fields() - .iter() - .map(|field| compute::take(field, inner_indices)) - .collect::, _>>()?, +fn take_struct(array: &StructArray, indices: &dyn Array) -> VortexResult { + // If the struct array is empty then the indices must be all null, otherwise it will access + // an out of bounds element + if array.is_empty() { + return StructArray::try_new_with_dtype( + array.unmasked_fields().clone(), array.struct_fields().clone(), indices.len(), - array.validity().take(indices)?, + Validity::AllInvalid, ) - .map(|a| a.into_array()) + .map(StructArray::into_array); } + // The validity is applied to the struct validity, + let inner_indices = &compute::fill_null( + indices, + &Scalar::default_value(indices.dtype().with_nullability(Nullability::NonNullable)), + )?; + StructArray::try_new_with_dtype( + array + .unmasked_fields() + .iter() + .map(|field| compute::take(field, inner_indices)) + .collect::, _>>()?, + array.struct_fields().clone(), + indices.len(), + array.validity().take(indices)?, + ) + .map(|a| a.into_array()) } -register_kernel!(TakeKernelAdapter(StructVTable).lift()); +impl TakeReduce for StructVTable { + fn take(array: &StructArray, indices: &dyn Array) -> VortexResult> { + take_struct(array, indices).map(Some) + } +} + +impl StructVTable { + pub const TAKE_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); +} diff --git a/vortex-array/src/arrays/take/array.rs b/vortex-array/src/arrays/take/array.rs new file mode 100644 index 00000000000..acadee679e7 --- /dev/null +++ b/vortex-array/src/arrays/take/array.rs @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +use crate::ArrayRef; +use crate::stats::ArrayStats; + +/// Decomposed parts of the take array. +pub struct TakeArrayParts { + /// Child array to take elements from + pub child: ArrayRef, + /// Indices specifying which elements to take from the child + pub indices: ArrayRef, +} + +/// A lazy array that represents taking elements from a child array at specified indices. +/// +/// The resulting array contains the elements at the positions specified by the indices. +#[derive(Clone, Debug)] +pub struct TakeArray { + /// The source array to take elements from. + pub(super) child: ArrayRef, + + /// The indices specifying which elements to take. + pub(super) indices: ArrayRef, + + /// The stats for this array. + pub(super) stats: ArrayStats, +} + +impl TakeArray { + pub fn new(array: ArrayRef, indices: ArrayRef) -> Self { + Self::try_new(array, indices).vortex_expect("TakeArray construction failed") + } + + pub fn try_new(array: ArrayRef, indices: ArrayRef) -> VortexResult { + vortex_ensure!( + indices.dtype().is_int(), + "TakeArray indices must have integer dtype, got {}", + indices.dtype() + ); + + Ok(Self { + child: array, + indices, + stats: ArrayStats::default(), + }) + } + + /// The child array to take elements from. + pub fn child(&self) -> &ArrayRef { + &self.child + } + + /// The indices specifying which elements to take. + pub fn indices(&self) -> &ArrayRef { + &self.indices + } + + /// Consume the array and return its individual components. + pub fn into_parts(self) -> TakeArrayParts { + TakeArrayParts { + child: self.child, + indices: self.indices, + } + } +} diff --git a/vortex-array/src/arrays/take/execute.rs b/vortex-array/src/arrays/take/execute.rs new file mode 100644 index 00000000000..c649413f086 --- /dev/null +++ b/vortex-array/src/arrays/take/execute.rs @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Execution logic for [`TakeArray`]. +//! +//! The main entrypoint is [`execute_take`] which takes elements from any [`Canonical`] array. + +use vortex_error::VortexResult; +use vortex_scalar::Scalar; + +use crate::Array; +use crate::ArrayRef; +use crate::Canonical; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::arrays::ConstantArray; +use crate::arrays::TakeArray; +use crate::compute::take; + +/// Check for some fast-path execution conditions before calling [`execute_take`]. +pub(super) fn execute_take_fast_paths( + array: &TakeArray, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + // If the indices are empty, the output is empty. + if array.indices.is_empty() { + return Ok(Some(Canonical::empty(array.dtype()))); + } + + // If all indices are invalid (null), return an array of nulls + if array.indices.all_invalid()? { + return Ok(Some( + ConstantArray::new( + Scalar::null(array.child.dtype().as_nullable()), + array.indices.len(), + ) + .into_array() + .execute(ctx)?, + )); + } + + // Also check if the source array itself is completely null + if array.child.validity_mask()?.true_count() == 0 { + return Ok(Some( + ConstantArray::new(Scalar::null(array.dtype().clone()), array.indices.len()) + .into_array() + .execute(ctx)?, + )); + } + + Ok(None) +} + +/// Take elements from a canonical array at the given indices, returning a new canonical array. +pub(super) fn execute_take(canonical: Canonical, indices: ArrayRef) -> VortexResult { + // For now, delegate to the compute take function and canonicalize the result + let taken = take(canonical.as_ref(), indices.as_ref())?; + taken.to_canonical() +} diff --git a/vortex-array/src/arrays/take/kernel.rs b/vortex-array/src/arrays/take/kernel.rs new file mode 100644 index 00000000000..c37ff5463fa --- /dev/null +++ b/vortex-array/src/arrays/take/kernel.rs @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::Array; +use crate::ArrayRef; +use crate::Canonical; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::arrays::TakeArray; +use crate::arrays::TakeVTable; +use crate::kernel::ExecuteParentKernel; +use crate::matcher::Matcher; +use crate::optimizer::rules::ArrayParentReduceRule; +use crate::vtable::VTable; + +pub trait TakeReduce: VTable { + /// Take elements from an array at the given indices without reading buffers. + /// + /// This trait is for take implementations that can operate purely on array metadata and + /// structure without needing to read or execute on the underlying buffers. Implementations + /// should return `None` if taking requires buffer access. + /// + /// # Preconditions + /// + /// The indices are guaranteed to be non-empty. + fn take(array: &Self::Array, indices: &dyn Array) -> VortexResult>; +} + +pub trait TakeExecute: VTable { + /// Take elements from an array at the given indices, potentially reading buffers. + /// + /// Unlike [`TakeReduce`], this trait is for take implementations that may need to read + /// and execute on the underlying buffers to produce the result. + /// + /// # Preconditions + /// + /// The indices are guaranteed to be non-empty. + fn take( + array: &Self::Array, + indices: &dyn Array, + ctx: &mut ExecutionCtx, + ) -> VortexResult>; +} + +/// Common preconditions for take operations that apply to all arrays. +/// +/// Returns `Some(result)` if the precondition short-circuits the take operation, +/// or `None` if the take should proceed normally. +pub fn precondition(array: &V::Array, indices: &dyn Array) -> Option { + // Fast-path for empty indices. + if indices.is_empty() { + return Some(Canonical::empty(array.dtype()).into_array()); + } + + None +} + +#[derive(Default, Debug)] +pub struct TakeReduceAdaptor(pub V); + +impl ArrayParentReduceRule for TakeReduceAdaptor +where + V: TakeReduce, +{ + type Parent = TakeVTable; + + fn reduce_parent( + &self, + array: &V::Array, + parent: &TakeArray, + child_idx: usize, + ) -> VortexResult> { + assert_eq!(child_idx, 0); + if let Some(result) = precondition::(array, parent.indices()) { + return Ok(Some(result)); + } + ::take(array, parent.indices()) + } +} + +#[derive(Default, Debug)] +pub struct TakeExecuteAdaptor(pub V); + +impl ExecuteParentKernel for TakeExecuteAdaptor +where + V: TakeExecute, +{ + type Parent = TakeVTable; + + fn execute_parent( + &self, + array: &V::Array, + parent: ::Match<'_>, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + assert_eq!(child_idx, 0); + if let Some(result) = precondition::(array, parent.indices()) { + return Ok(Some(result)); + } + ::take(array, parent.indices(), ctx) + } +} diff --git a/vortex-array/src/arrays/take/mod.rs b/vortex-array/src/arrays/take/mod.rs new file mode 100644 index 00000000000..44438a71d79 --- /dev/null +++ b/vortex-array/src/arrays/take/mod.rs @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod array; +pub use array::TakeArray; +pub use array::TakeArrayParts; + +mod execute; + +mod kernel; +pub use kernel::TakeExecute; +pub use kernel::TakeExecuteAdaptor; +pub use kernel::TakeReduce; +pub use kernel::TakeReduceAdaptor; + +mod rules; + +mod vtable; +pub use vtable::TakeVTable; diff --git a/vortex-array/src/arrays/take/rules.rs b/vortex-array/src/arrays/take/rules.rs new file mode 100644 index 00000000000..75d8c4e1cf4 --- /dev/null +++ b/vortex-array/src/arrays/take/rules.rs @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::Array; +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::StructArray; +use crate::arrays::StructArrayParts; +use crate::arrays::StructVTable; +use crate::arrays::TakeArray; +use crate::arrays::TakeVTable; +use crate::optimizer::rules::ArrayParentReduceRule; +use crate::optimizer::rules::ArrayReduceRule; +use crate::optimizer::rules::ParentRuleSet; +use crate::optimizer::rules::ReduceRuleSet; + +pub(super) const PARENT_RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&TakeTakeRule)]); + +pub(super) const RULES: ReduceRuleSet = ReduceRuleSet::new(&[&TakeStructRule]); + +/// A simple reduction rule that simplifies a [`TakeArray`] whose child is also a +/// [`TakeArray`]. +#[derive(Debug)] +struct TakeTakeRule; + +impl ArrayParentReduceRule for TakeTakeRule { + type Parent = TakeVTable; + + fn reduce_parent( + &self, + child: &TakeArray, + parent: &TakeArray, + _child_idx: usize, + ) -> VortexResult> { + // Take(Take(arr, indices1), indices2) = Take(arr, Take(indices1, indices2)) + // We need to take from the inner indices using the outer indices + let new_indices = child.indices.take(parent.indices.clone())?; + let new_array = child.child.take(new_indices)?; + + Ok(Some(new_array.into_array())) + } +} + +/// A reduce rule that pushes a take down into the fields of a StructArray. +#[derive(Debug)] +struct TakeStructRule; + +impl ArrayReduceRule for TakeStructRule { + fn reduce(&self, array: &TakeArray) -> VortexResult> { + let indices = array.indices(); + let Some(struct_array) = array.child().as_opt::() else { + return Ok(None); + }; + + let len = indices.len(); + let StructArrayParts { + fields, + struct_fields, + validity, + .. + } = struct_array.clone().into_parts(); + + let taken_validity = validity.take(indices)?; + + let taken_fields = fields + .iter() + .map(|field| field.take(indices.clone())) + .collect::>>()?; + + Ok(Some( + StructArray::new( + struct_fields.names().clone(), + taken_fields, + len, + taken_validity, + ) + .into_array(), + )) + } +} diff --git a/vortex-array/src/arrays/take/vtable.rs b/vortex-array/src/arrays/take/vtable.rs new file mode 100644 index 00000000000..8d89b64aa7a --- /dev/null +++ b/vortex-array/src/arrays/take/vtable.rs @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt::Debug; +use std::fmt::Formatter; +use std::hash::Hasher; + +use vortex_dtype::DType; +use vortex_dtype::Nullability; +use vortex_dtype::PType; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_scalar::Scalar; + +use crate::Array; +use crate::ArrayBufferVisitor; +use crate::ArrayChildVisitor; +use crate::ArrayEq; +use crate::ArrayHash; +use crate::ArrayRef; +use crate::Canonical; +use crate::IntoArray; +use crate::Precision; +use crate::arrays::take::array::TakeArray; +use crate::arrays::take::execute::execute_take; +use crate::arrays::take::execute::execute_take_fast_paths; +use crate::arrays::take::rules::PARENT_RULES; +use crate::arrays::take::rules::RULES; +use crate::buffer::BufferHandle; +use crate::executor::ExecutionCtx; +use crate::serde::ArrayChildren; +use crate::stats::StatsSetRef; +use crate::validity::Validity; +use crate::vtable; +use crate::vtable::ArrayId; +use crate::vtable::BaseArrayVTable; +use crate::vtable::NotSupported; +use crate::vtable::OperationsVTable; +use crate::vtable::VTable; +use crate::vtable::ValidityVTable; +use crate::vtable::VisitorVTable; + +vtable!(Take); + +#[derive(Debug)] +pub struct TakeVTable; + +impl TakeVTable { + pub const ID: ArrayId = ArrayId::new_ref("vortex.take"); +} + +impl VTable for TakeVTable { + type Array = TakeArray; + type Metadata = TakeMetadata; + type ArrayVTable = Self; + type OperationsVTable = Self; + type ValidityVTable = Self; + type VisitorVTable = Self; + type ComputeVTable = NotSupported; + + fn id(_array: &Self::Array) -> ArrayId { + Self::ID + } + + fn metadata(_array: &Self::Array) -> VortexResult { + Ok(TakeMetadata) + } + + fn serialize(_metadata: Self::Metadata) -> VortexResult>> { + vortex_bail!("Take array is not serializable") + } + + fn deserialize(_bytes: &[u8]) -> VortexResult { + vortex_bail!("Take array is not serializable") + } + + fn build( + dtype: &DType, + len: usize, + _metadata: &TakeMetadata, + _buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + // The indices child determines the length - use u64 as the index type + let indices_dtype = DType::Primitive(PType::U64, Nullability::Nullable); + let indices = children.get(1, &indices_dtype, len)?; + let child = children.get(0, dtype, 0)?; // child length is unknown from metadata + Ok(TakeArray { + child, + indices, + stats: Default::default(), + }) + } + + fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { + vortex_ensure!( + children.len() == 2, + "TakeArray expects exactly 2 children, got {}", + children.len() + ); + let mut iter = children.into_iter(); + array.child = iter + .next() + .vortex_expect("children length already validated"); + array.indices = iter + .next() + .vortex_expect("children length already validated"); + Ok(()) + } + + fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + if let Some(canonical) = execute_take_fast_paths(array, ctx)? { + return Ok(canonical); + } + + // Execute both child and indices to canonical form + let child_canonical = array.child.clone().execute::(ctx)?; + let indices_canonical = array.indices.clone().execute::(ctx)?; + + // Execute the take operation + let canonical = execute_take(child_canonical, indices_canonical.into_array())?; + + // Verify the resulting length and type + let result_len = array.indices.len(); + vortex_ensure!( + canonical.as_ref().len() == result_len, + "Take result length mismatch: expected {}, got {}", + result_len, + canonical.as_ref().len() + ); + + Ok(canonical) + } + + fn reduce_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + PARENT_RULES.evaluate(array, parent, child_idx) + } + + fn reduce(array: &Self::Array) -> VortexResult> { + RULES.evaluate(array) + } +} + +impl BaseArrayVTable for TakeVTable { + fn len(array: &TakeArray) -> usize { + array.indices.len() + } + + fn dtype(array: &TakeArray) -> &DType { + // The dtype is the child's dtype with potentially nullable from indices + // For now, return child's dtype - nullability adjustment happens during execution + array.child.dtype() + } + + fn stats(array: &TakeArray) -> StatsSetRef<'_> { + array.stats.to_ref(array.as_ref()) + } + + fn array_hash(array: &TakeArray, state: &mut H, precision: Precision) { + array.child.array_hash(state, precision); + array.indices.array_hash(state, precision); + } + + fn array_eq(array: &TakeArray, other: &TakeArray, precision: Precision) -> bool { + array.child.array_eq(&other.child, precision) + && array.indices.array_eq(&other.indices, precision) + } +} + +impl OperationsVTable for TakeVTable { + fn scalar_at(array: &TakeArray, index: usize) -> VortexResult { + // Get the index value at position `index` from the indices array + let idx_scalar = array.indices.scalar_at(index)?; + if idx_scalar.is_null() { + return Ok(Scalar::null(array.child.dtype().as_nullable())); + } + let idx: usize = idx_scalar.as_ref().try_into()?; + array.child.scalar_at(idx) + } +} + +impl ValidityVTable for TakeVTable { + fn validity(array: &TakeArray) -> VortexResult { + // The validity of a take array depends on both: + // 1. The validity of the indices (null indices produce null values) + // 2. The validity of the child at the taken positions + // We return the child's validity taken at the indices positions + array.child.validity()?.take(&array.indices) + } +} + +impl VisitorVTable for TakeVTable { + fn visit_buffers(_array: &TakeArray, _visitor: &mut dyn ArrayBufferVisitor) {} + + fn visit_children(array: &TakeArray, visitor: &mut dyn ArrayChildVisitor) { + visitor.visit_child("child", &array.child); + visitor.visit_child("indices", &array.indices); + } +} + +pub struct TakeMetadata; + +impl Debug for TakeMetadata { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "TakeMetadata") + } +} diff --git a/vortex-array/src/arrays/varbin/compute/take.rs b/vortex-array/src/arrays/varbin/compute/take.rs index 0562a20209e..66536ed7abc 100644 --- a/vortex-array/src/arrays/varbin/compute/take.rs +++ b/vortex-array/src/arrays/varbin/compute/take.rs @@ -17,101 +17,117 @@ use crate::ArrayRef; use crate::IntoArray; use crate::ToCanonical; use crate::arrays::PrimitiveArray; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::arrays::VarBinVTable; use crate::arrays::varbin::VarBinArray; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::register_kernel; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; use crate::validity::Validity; -impl TakeKernel for VarBinVTable { - fn take(&self, array: &VarBinArray, indices: &dyn Array) -> VortexResult { - let offsets = array.offsets().to_primitive(); - let data = array.bytes(); - let indices = indices.to_primitive(); - let dtype = array - .dtype() - .clone() - .union_nullability(indices.dtype().nullability()); - let array = match_each_integer_ptype!(indices.ptype(), |I| { - // On take, offsets get widened to either 32- or 64-bit based on the original type, - // to avoid overflow issues. - match offsets.ptype() { - PType::U8 => take::( - dtype, - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U16 => take::( - dtype, - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U32 => take::( - dtype, - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U64 => take::( - dtype, - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I8 => take::( - dtype, - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I16 => take::( - dtype, - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I32 => take::( - dtype, - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I64 => take::( - dtype, - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - _ => unreachable!("invalid PType for offsets"), - } - }); - - Ok(array?.into_array()) +#[expect( + clippy::redundant_clone, + reason = "macro expansion causes false positive - only one match arm executes" +)] +fn take_varbin(array: &VarBinArray, indices: &dyn Array) -> VortexResult { + let offsets = array.offsets().to_primitive(); + let data = array.bytes(); + let indices = indices.to_primitive(); + let dtype = array + .dtype() + .clone() + .union_nullability(indices.dtype().nullability()); + let result = match_each_integer_ptype!(indices.ptype(), |I| { + // On take, offsets get widened to either 32- or 64-bit based on the original type, + // to avoid overflow issues. + match offsets.ptype() { + PType::U8 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::U16 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::U32 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::U64 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I8 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I16 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I32 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I64 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + _ => unreachable!("invalid PType for offsets"), + } + }); + + Ok(result?.into_array()) +} + +impl TakeExecute for VarBinVTable { + fn take( + array: &VarBinArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_varbin(array, indices).map(Some) } } -register_kernel!(TakeKernelAdapter(VarBinVTable).lift()); +impl VarBinVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} -fn take( +fn take_impl( dtype: DType, offsets: &[Offset], data: &[u8], diff --git a/vortex-array/src/arrays/varbinview/compute/take.rs b/vortex-array/src/arrays/varbinview/compute/take.rs index b937c904438..d5abe48f102 100644 --- a/vortex-array/src/arrays/varbinview/compute/take.rs +++ b/vortex-array/src/arrays/varbinview/compute/take.rs @@ -15,42 +15,53 @@ use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::ToCanonical; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::arrays::VarBinViewArray; use crate::arrays::VarBinViewVTable; use crate::buffer::BufferHandle; -use crate::compute::TakeKernel; -use crate::compute::TakeKernelAdapter; -use crate::register_kernel; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; /// Take involves creating a new array that references the old array, just with the given set of views. -impl TakeKernel for VarBinViewVTable { - fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult { - // Compute the new validity. - let validity = array.validity().take(indices)?; - let indices = indices.to_primitive(); - - let indices_mask = indices.validity_mask()?; - let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| { - take_views(array.views(), indices.as_slice::(), &indices_mask) - }); - - // SAFETY: taking all components at same indices maintains invariants - unsafe { - Ok(VarBinViewArray::new_handle_unchecked( - BufferHandle::new_host(views_buffer.into_byte_buffer()), - array.buffers().clone(), - array - .dtype() - .union_nullability(indices.dtype().nullability()), - validity, - ) - .into_array()) - } +fn take_varbinview(array: &VarBinViewArray, indices: &dyn Array) -> VortexResult { + let validity = array.validity().take(indices)?; + let indices = indices.to_primitive(); + + let indices_mask = indices.validity_mask()?; + let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| { + take_views(array.views(), indices.as_slice::(), &indices_mask) + }); + + // SAFETY: taking all components at same indices maintains invariants + unsafe { + Ok(VarBinViewArray::new_handle_unchecked( + BufferHandle::new_host(views_buffer.into_byte_buffer()), + array.buffers().clone(), + array + .dtype() + .union_nullability(indices.dtype().nullability()), + validity, + ) + .into_array()) } } -register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift()); +impl TakeExecute for VarBinViewVTable { + fn take( + array: &VarBinViewArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + take_varbinview(array, indices).map(Some) + } +} + +impl VarBinViewVTable { + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +} fn take_views>( views_ref: &[BinaryView], diff --git a/vortex-array/src/compute/take.rs b/vortex-array/src/compute/take.rs index 391c2924ce9..965ef86cb68 100644 --- a/vortex-array/src/compute/take.rs +++ b/vortex-array/src/compute/take.rs @@ -1,45 +1,16 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::LazyLock; - -use arcref::ArcRef; -use vortex_dtype::DType; -use vortex_error::VortexError; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_scalar::Scalar; use crate::Array; use crate::ArrayRef; -use crate::Canonical; use crate::IntoArray; -use crate::arrays::ConstantArray; -use crate::arrays::ConstantVTable; -use crate::compute::ComputeFn; -use crate::compute::ComputeFnVTable; -use crate::compute::InvocationArgs; -use crate::compute::Kernel; -use crate::compute::Output; use crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; use crate::expr::stats::StatsProviderExt; use crate::stats::StatsSet; -use crate::vtable::VTable; - -static TAKE_FN: LazyLock = LazyLock::new(|| { - let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take)); - for kernel in inventory::iter:: { - compute.register_kernel(kernel.0.clone()); - } - compute -}); - -pub(crate) fn warm_up_vtable() -> usize { - TAKE_FN.kernels().len() + TAKE_FROM_FN.kernels().len() -} /// Creates a new array using the elements from the input `array` indexed by `indices`. /// @@ -48,85 +19,13 @@ pub(crate) fn warm_up_vtable() -> usize { /// /// The output array will have the same length as the `indices` array. pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult { - if indices.is_empty() { - return Ok(Canonical::empty( - &array - .dtype() - .union_nullability(indices.dtype().nullability()), - ) - .into_array()); - } - - TAKE_FN - .invoke(&InvocationArgs { - inputs: &[array.into(), indices.into()], - options: &(), - })? - .unwrap_array() + array + .take(indices.to_array())? + .to_canonical() + .map(|c| c.into_array()) } -#[doc(hidden)] -pub struct Take; - -impl ComputeFnVTable for Take { - fn invoke( - &self, - args: &InvocationArgs, - kernels: &[ArcRef], - ) -> VortexResult { - let TakeArgs { array, indices } = TakeArgs::try_from(args)?; - - // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to - // the filter function since they're typically optimised for this case. - // TODO(ngates): if indices min is quite high, we could slice self and offset the indices - // such that canonicalize does less work. - - if indices.all_invalid()? { - return Ok(ConstantArray::new( - Scalar::null(array.dtype().as_nullable()), - indices.len(), - ) - .into_array() - .into()); - } - - let taken_array = take_impl(array, indices, kernels)?; - - // We know that constant array don't need stats propagation, so we can avoid the overhead of - // computing derived stats and merging them in. - if !taken_array.is::() { - propagate_take_stats(array, &taken_array, indices)?; - } - - Ok(taken_array.into()) - } - - fn return_dtype(&self, args: &InvocationArgs) -> VortexResult { - let TakeArgs { array, indices } = TakeArgs::try_from(args)?; - - if !indices.dtype().is_int() { - vortex_bail!( - "Take indices must be an integer type, got {}", - indices.dtype() - ); - } - - Ok(array - .dtype() - .union_nullability(indices.dtype().nullability())) - } - - fn return_len(&self, args: &InvocationArgs) -> VortexResult { - let TakeArgs { indices, .. } = TakeArgs::try_from(args)?; - Ok(indices.len()) - } - - fn is_elementwise(&self) -> bool { - false - } -} - -fn propagate_take_stats( +pub(crate) fn propagate_take_stats( source: &dyn Array, target: &dyn Array, indices: &dyn Array, @@ -153,155 +52,3 @@ fn propagate_take_stats( ) }) } - -fn take_impl( - array: &dyn Array, - indices: &dyn Array, - kernels: &[ArcRef], -) -> VortexResult { - let args = InvocationArgs { - inputs: &[array.into(), indices.into()], - options: &(), - }; - - // First look for a TakeFrom specialized on the indices. - for kernel in TAKE_FROM_FN.kernels() { - if let Some(output) = kernel.invoke(&args)? { - return output.unwrap_array(); - } - } - - // Then look for a Take kernel - for kernel in kernels { - if let Some(output) = kernel.invoke(&args)? { - return output.unwrap_array(); - } - } - - // Otherwise, canonicalize and try again. - if !array.is_canonical() { - tracing::debug!("No take implementation found for {}", array.encoding_id()); - let canonical = array.to_canonical()?; - return take(canonical.as_ref(), indices); - } - - vortex_bail!("No take implementation found for {}", array.encoding_id()); -} - -struct TakeArgs<'a> { - array: &'a dyn Array, - indices: &'a dyn Array, -} - -impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> { - type Error = VortexError; - - fn try_from(value: &InvocationArgs<'a>) -> Result { - if value.inputs.len() != 2 { - vortex_bail!("Expected 2 inputs, found {}", value.inputs.len()); - } - let array = value.inputs[0] - .array() - .ok_or_else(|| vortex_err!("Expected first input to be an array"))?; - let indices = value.inputs[1] - .array() - .ok_or_else(|| vortex_err!("Expected second input to be an array"))?; - Ok(Self { array, indices }) - } -} - -pub trait TakeKernel: VTable { - /// Create a new array by taking the values from the `array` at the - /// given `indices`. - /// - /// # Panics - /// - /// Using `indices` that are invalid for the given `array` will cause a panic. - fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult; -} - -/// A kernel that implements the filter function. -pub struct TakeKernelRef(pub ArcRef); -inventory::collect!(TakeKernelRef); - -#[derive(Debug)] -pub struct TakeKernelAdapter(pub V); - -impl TakeKernelAdapter { - pub const fn lift(&'static self) -> TakeKernelRef { - TakeKernelRef(ArcRef::new_ref(self)) - } -} - -impl Kernel for TakeKernelAdapter { - fn invoke(&self, args: &InvocationArgs) -> VortexResult> { - let inputs = TakeArgs::try_from(args)?; - let Some(array) = inputs.array.as_opt::() else { - return Ok(None); - }; - Ok(Some(V::take(&self.0, array, inputs.indices)?.into())) - } -} - -static TAKE_FROM_FN: LazyLock = LazyLock::new(|| { - let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom)); - for kernel in inventory::iter:: { - compute.register_kernel(kernel.0.clone()); - } - compute -}); - -pub struct TakeFrom; - -impl ComputeFnVTable for TakeFrom { - fn invoke( - &self, - _args: &InvocationArgs, - _kernels: &[ArcRef], - ) -> VortexResult { - vortex_bail!( - "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function" - ) - } - - fn return_dtype(&self, args: &InvocationArgs) -> VortexResult { - Take.return_dtype(args) - } - - fn return_len(&self, args: &InvocationArgs) -> VortexResult { - Take.return_len(args) - } - - fn is_elementwise(&self) -> bool { - Take.is_elementwise() - } -} - -pub trait TakeFromKernel: VTable { - /// Create a new array by taking the values from the `array` at the - /// given `indices`. - fn take_from(&self, indices: &Self::Array, array: &dyn Array) - -> VortexResult>; -} - -pub struct TakeFromKernelRef(pub ArcRef); -inventory::collect!(TakeFromKernelRef); - -#[derive(Debug)] -pub struct TakeFromKernelAdapter(pub V); - -impl TakeFromKernelAdapter { - pub const fn lift(&'static self) -> TakeFromKernelRef { - TakeFromKernelRef(ArcRef::new_ref(self)) - } -} - -impl Kernel for TakeFromKernelAdapter { - fn invoke(&self, args: &InvocationArgs) -> VortexResult> { - let inputs = TakeArgs::try_from(args)?; - let Some(indices) = inputs.indices.as_opt::() else { - return Ok(None); - }; - Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from)) - } -} From 5d9fa4ae0abe5b899e67512eb447ae38142a9713 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 5 Feb 2026 12:45:53 +0000 Subject: [PATCH 06/21] wip Signed-off-by: Joe Isaacs --- encodings/alp/src/alp/compute/take.rs | 19 +- encodings/alp/src/alp_rd/compute/take.rs | 19 +- encodings/datetime-parts/src/compute/take.rs | 19 +- .../src/decimal_byte_parts/compute/take.rs | 19 +- encodings/fastlanes/src/for/compute/mod.rs | 19 +- encodings/fsst/src/compute/mod.rs | 19 +- encodings/sparse/src/compute/take.rs | 19 +- encodings/zigzag/src/compute/mod.rs | 19 +- vortex-array/src/arrays/dict/compute/mod.rs | 19 +- vortex-array/src/arrays/dict/mod.rs | 3 + .../arrays/{take/kernel.rs => dict/take.rs} | 40 +++- .../src/arrays/extension/compute/take.rs | 19 +- .../src/arrays/masked/compute/take.rs | 19 +- vortex-array/src/arrays/mod.rs | 2 - .../src/arrays/struct_/compute/take.rs | 19 +- vortex-array/src/arrays/take/array.rs | 70 ------ vortex-array/src/arrays/take/execute.rs | 59 ----- vortex-array/src/arrays/take/mod.rs | 19 -- vortex-array/src/arrays/take/rules.rs | 83 ------- vortex-array/src/arrays/take/vtable.rs | 213 ------------------ 20 files changed, 176 insertions(+), 541 deletions(-) rename vortex-array/src/arrays/{take/kernel.rs => dict/take.rs} (70%) delete mode 100644 vortex-array/src/arrays/take/array.rs delete mode 100644 vortex-array/src/arrays/take/execute.rs delete mode 100644 vortex-array/src/arrays/take/mod.rs delete mode 100644 vortex-array/src/arrays/take/rules.rs delete mode 100644 vortex-array/src/arrays/take/vtable.rs diff --git a/encodings/alp/src/alp/compute/take.rs b/encodings/alp/src/alp/compute/take.rs index 2a890177355..f9bff38e486 100644 --- a/encodings/alp/src/alp/compute/take.rs +++ b/encodings/alp/src/alp/compute/take.rs @@ -3,11 +3,12 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::arrays::TakeReduce; -use vortex_array::arrays::TakeReduceAdaptor; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::take; -use vortex_array::optimizer::rules::ParentRuleSet; +use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use crate::ALPArray; @@ -31,15 +32,19 @@ fn take_alp(array: &ALPArray, indices: &dyn Array) -> VortexResult { Ok(ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array()) } -impl TakeReduce for ALPVTable { - fn take(array: &ALPArray, indices: &dyn Array) -> VortexResult> { +impl TakeExecute for ALPVTable { + fn take( + array: &ALPArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_alp(array, indices).map(Some) } } impl ALPVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } #[cfg(test)] diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index 4983a9a9b2c..aaa3cd2b909 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -3,12 +3,13 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::arrays::TakeReduce; -use vortex_array::arrays::TakeReduceAdaptor; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::fill_null; use vortex_array::compute::take; -use vortex_array::optimizer::rules::ParentRuleSet; +use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -49,15 +50,19 @@ fn take_alprd(array: &ALPRDArray, indices: &dyn Array) -> VortexResult .into_array()) } -impl TakeReduce for ALPRDVTable { - fn take(array: &ALPRDArray, indices: &dyn Array) -> VortexResult> { +impl TakeExecute for ALPRDVTable { + fn take( + array: &ALPRDArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_alprd(array, indices).map(Some) } } impl ALPRDVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } #[cfg(test)] diff --git a/encodings/datetime-parts/src/compute/take.rs b/encodings/datetime-parts/src/compute/take.rs index 770d173d521..21943948142 100644 --- a/encodings/datetime-parts/src/compute/take.rs +++ b/encodings/datetime-parts/src/compute/take.rs @@ -3,15 +3,16 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; -use vortex_array::arrays::TakeReduce; -use vortex_array::arrays::TakeReduceAdaptor; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::fill_null; use vortex_array::compute::take; use vortex_array::expr::stats::Stat; use vortex_array::expr::stats::StatsProvider; -use vortex_array::optimizer::rules::ParentRuleSet; +use vortex_array::kernel::ParentKernelSet; use vortex_dtype::Nullability; use vortex_error::VortexResult; use vortex_error::vortex_panic; @@ -87,15 +88,19 @@ fn take_datetime_parts(array: &DateTimePartsArray, indices: &dyn Array) -> Vorte ) } -impl TakeReduce for DateTimePartsVTable { - fn take(array: &DateTimePartsArray, indices: &dyn Array) -> VortexResult> { +impl TakeExecute for DateTimePartsVTable { + fn take( + array: &DateTimePartsArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_datetime_parts(array, indices).map(Some) } } impl DateTimePartsVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } #[cfg(test)] diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs index 407734d5c24..f29aa4866ff 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs @@ -3,10 +3,11 @@ use vortex_array::Array; use vortex_array::ArrayRef; -use vortex_array::arrays::TakeReduce; -use vortex_array::arrays::TakeReduceAdaptor; +use vortex_array::ExecutionCtx; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::take; -use vortex_array::optimizer::rules::ParentRuleSet; +use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use crate::DecimalBytePartsArray; @@ -20,13 +21,17 @@ fn take_decimal_byte_parts( .map(|a| a.to_array()) } -impl TakeReduce for DecimalBytePartsVTable { - fn take(array: &DecimalBytePartsArray, indices: &dyn Array) -> VortexResult> { +impl TakeExecute for DecimalBytePartsVTable { + fn take( + array: &DecimalBytePartsArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_decimal_byte_parts(array, indices).map(Some) } } impl DecimalBytePartsVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } diff --git a/encodings/fastlanes/src/for/compute/mod.rs b/encodings/fastlanes/src/for/compute/mod.rs index b2b4a590c2c..7b24564f019 100644 --- a/encodings/fastlanes/src/for/compute/mod.rs +++ b/encodings/fastlanes/src/for/compute/mod.rs @@ -8,12 +8,13 @@ mod is_sorted; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; -use vortex_array::arrays::TakeReduce; -use vortex_array::arrays::TakeReduceAdaptor; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::take; -use vortex_array::optimizer::rules::ParentRuleSet; +use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use vortex_mask::Mask; @@ -28,15 +29,19 @@ fn take_for(array: &FoRArray, indices: &dyn Array) -> VortexResult { .map(|a| a.into_array()) } -impl TakeReduce for FoRVTable { - fn take(array: &FoRArray, indices: &dyn Array) -> VortexResult> { +impl TakeExecute for FoRVTable { + fn take( + array: &FoRArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_for(array, indices).map(Some) } } impl FoRVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } impl FilterReduce for FoRVTable { diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 8146851f942..62230683e9f 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -7,13 +7,14 @@ mod filter; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::arrays::TakeReduce; -use vortex_array::arrays::TakeReduceAdaptor; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::arrays::VarBinVTable; use vortex_array::compute::fill_null; use vortex_array::compute::take; -use vortex_array::optimizer::rules::ParentRuleSet; +use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -43,15 +44,19 @@ fn take_fsst(array: &FSSTArray, indices: &dyn Array) -> VortexResult { .into_array()) } -impl TakeReduce for FSSTVTable { - fn take(array: &FSSTArray, indices: &dyn Array) -> VortexResult> { +impl TakeExecute for FSSTVTable { + fn take( + array: &FSSTArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_fsst(array, indices).map(Some) } } impl FSSTVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } #[cfg(test)] diff --git a/encodings/sparse/src/compute/take.rs b/encodings/sparse/src/compute/take.rs index c63bedc4915..a84226629d9 100644 --- a/encodings/sparse/src/compute/take.rs +++ b/encodings/sparse/src/compute/take.rs @@ -3,11 +3,12 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::TakeReduce; -use vortex_array::arrays::TakeReduceAdaptor; -use vortex_array::optimizer::rules::ParentRuleSet; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use crate::SparseArray; @@ -45,15 +46,19 @@ fn take_sparse(array: &SparseArray, take_indices: &dyn Array) -> VortexResult VortexResult> { +impl TakeExecute for SparseVTable { + fn take( + array: &SparseArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_sparse(array, indices).map(Some) } } impl SparseVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } #[cfg(test)] diff --git a/encodings/zigzag/src/compute/mod.rs b/encodings/zigzag/src/compute/mod.rs index 40a4cbb05e5..059ebae4001 100644 --- a/encodings/zigzag/src/compute/mod.rs +++ b/encodings/zigzag/src/compute/mod.rs @@ -5,15 +5,16 @@ mod cast; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; -use vortex_array::arrays::TakeReduce; -use vortex_array::arrays::TakeReduceAdaptor; +use vortex_array::arrays::TakeExecute; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::MaskKernel; use vortex_array::compute::MaskKernelAdapter; use vortex_array::compute::mask; use vortex_array::compute::take; -use vortex_array::optimizer::rules::ParentRuleSet; +use vortex_array::kernel::ParentKernelSet; use vortex_array::register_kernel; use vortex_error::VortexResult; use vortex_mask::Mask; @@ -33,15 +34,19 @@ fn take_zigzag(array: &ZigZagArray, indices: &dyn Array) -> VortexResult VortexResult> { +impl TakeExecute for ZigZagVTable { + fn take( + array: &ZigZagArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_zigzag(array, indices).map(Some) } } impl ZigZagVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } impl MaskKernel for ZigZagVTable { diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index 0dbc99e38e0..82cedd0576e 100644 --- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -17,14 +17,15 @@ use vortex_mask::Mask; use super::DictArray; use super::DictVTable; -use super::TakeReduce; -use super::TakeReduceAdaptor; +use super::TakeExecute; +use super::TakeExecuteAdaptor; use crate::Array; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::filter::FilterReduce; use crate::compute::take; -use crate::optimizer::rules::ParentRuleSet; +use crate::kernel::ParentKernelSet; fn take_dict(array: &DictArray, indices: &dyn Array) -> VortexResult { let codes = take(array.codes(), indices)?; @@ -33,15 +34,19 @@ fn take_dict(array: &DictArray, indices: &dyn Array) -> VortexResult { Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()).into_array() }) } -impl TakeReduce for DictVTable { - fn take(array: &DictArray, indices: &dyn Array) -> VortexResult> { +impl TakeExecute for DictVTable { + fn take( + array: &DictArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_dict(array, indices).map(Some) } } impl DictVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } impl FilterReduce for DictVTable { diff --git a/vortex-array/src/arrays/dict/mod.rs b/vortex-array/src/arrays/dict/mod.rs index cf2c44febf3..7ba509e09e5 100644 --- a/vortex-array/src/arrays/dict/mod.rs +++ b/vortex-array/src/arrays/dict/mod.rs @@ -19,6 +19,9 @@ mod execute; pub use execute::take_canonical; +mod take; +pub use take::*; + pub mod vtable; pub use vtable::*; diff --git a/vortex-array/src/arrays/take/kernel.rs b/vortex-array/src/arrays/dict/take.rs similarity index 70% rename from vortex-array/src/arrays/take/kernel.rs rename to vortex-array/src/arrays/dict/take.rs index c37ff5463fa..d90f67afda7 100644 --- a/vortex-array/src/arrays/take/kernel.rs +++ b/vortex-array/src/arrays/dict/take.rs @@ -3,13 +3,14 @@ use vortex_error::VortexResult; +use super::DictArray; +use super::DictVTable; use crate::Array; use crate::ArrayRef; use crate::Canonical; use crate::ExecutionCtx; use crate::IntoArray; -use crate::arrays::TakeArray; -use crate::arrays::TakeVTable; +use crate::compute::propagate_take_stats; use crate::kernel::ExecuteParentKernel; use crate::matcher::Matcher; use crate::optimizer::rules::ArrayParentReduceRule; @@ -54,6 +55,15 @@ pub fn precondition(array: &V::Array, indices: &dyn Array) -> Option< return Some(Canonical::empty(array.dtype()).into_array()); } + // TODO(joe): shall we enable this seems expensive. + // if indices.all_invalid()? { + // return Ok( + // ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len()) + // .into_array() + // .into(), + // ); + // } + None } @@ -64,19 +74,23 @@ impl ArrayParentReduceRule for TakeReduceAdaptor where V: TakeReduce, { - type Parent = TakeVTable; + type Parent = DictVTable; fn reduce_parent( &self, array: &V::Array, - parent: &TakeArray, + parent: &DictArray, child_idx: usize, ) -> VortexResult> { - assert_eq!(child_idx, 0); - if let Some(result) = precondition::(array, parent.indices()) { + assert_eq!(child_idx, 1); + if let Some(result) = precondition::(array, parent.codes()) { return Ok(Some(result)); } - ::take(array, parent.indices()) + let result = ::take(array, parent.codes())?; + if let Some(ref taken) = result { + propagate_take_stats(&**array, taken.as_ref(), parent.codes())?; + } + Ok(result) } } @@ -87,7 +101,7 @@ impl ExecuteParentKernel for TakeExecuteAdaptor where V: TakeExecute, { - type Parent = TakeVTable; + type Parent = DictVTable; fn execute_parent( &self, @@ -96,10 +110,14 @@ where child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - assert_eq!(child_idx, 0); - if let Some(result) = precondition::(array, parent.indices()) { + assert_eq!(child_idx, 1); + if let Some(result) = precondition::(array, parent.codes()) { return Ok(Some(result)); } - ::take(array, parent.indices(), ctx) + let result = ::take(array, parent.codes(), ctx)?; + if let Some(ref taken) = result { + propagate_take_stats(&**array, taken.as_ref(), parent.codes())?; + } + Ok(result) } } diff --git a/vortex-array/src/arrays/extension/compute/take.rs b/vortex-array/src/arrays/extension/compute/take.rs index dbe4555e61d..b301ebfcd9f 100644 --- a/vortex-array/src/arrays/extension/compute/take.rs +++ b/vortex-array/src/arrays/extension/compute/take.rs @@ -5,13 +5,14 @@ use vortex_error::VortexResult; use crate::Array; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::ExtensionArray; use crate::arrays::ExtensionVTable; -use crate::arrays::TakeReduce; -use crate::arrays::TakeReduceAdaptor; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::compute::{self}; -use crate::optimizer::rules::ParentRuleSet; +use crate::kernel::ParentKernelSet; fn take_extension(array: &ExtensionArray, indices: &dyn Array) -> VortexResult { let taken_storage = compute::take(array.storage(), indices)?; @@ -24,13 +25,17 @@ fn take_extension(array: &ExtensionArray, indices: &dyn Array) -> VortexResult VortexResult> { +impl TakeExecute for ExtensionVTable { + fn take( + array: &ExtensionArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_extension(array, indices).map(Some) } } impl ExtensionVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } diff --git a/vortex-array/src/arrays/masked/compute/take.rs b/vortex-array/src/arrays/masked/compute/take.rs index 10ce1609046..42a7a41d5d5 100644 --- a/vortex-array/src/arrays/masked/compute/take.rs +++ b/vortex-array/src/arrays/masked/compute/take.rs @@ -6,14 +6,15 @@ use vortex_scalar::Scalar; use crate::Array; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::MaskedArray; use crate::arrays::MaskedVTable; -use crate::arrays::TakeReduce; -use crate::arrays::TakeReduceAdaptor; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::compute::fill_null; use crate::compute::take; -use crate::optimizer::rules::ParentRuleSet; +use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; fn take_masked(array: &MaskedArray, indices: &dyn Array) -> VortexResult { @@ -35,15 +36,19 @@ fn take_masked(array: &MaskedArray, indices: &dyn Array) -> VortexResult VortexResult> { +impl TakeExecute for MaskedVTable { + fn take( + array: &MaskedArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_masked(array, indices).map(Some) } } impl MaskedVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } #[cfg(test)] diff --git a/vortex-array/src/arrays/mod.rs b/vortex-array/src/arrays/mod.rs index d6914891a6e..0d5fa96c258 100644 --- a/vortex-array/src/arrays/mod.rs +++ b/vortex-array/src/arrays/mod.rs @@ -33,7 +33,6 @@ mod scalar_fn; mod shared; mod slice; mod struct_; -mod take; mod varbin; mod varbinview; @@ -60,6 +59,5 @@ pub use scalar_fn::*; pub use shared::*; pub use slice::*; pub use struct_::*; -pub use take::*; pub use varbin::*; pub use varbinview::*; diff --git a/vortex-array/src/arrays/struct_/compute/take.rs b/vortex-array/src/arrays/struct_/compute/take.rs index 3f6fcfddd2e..4367118e5fc 100644 --- a/vortex-array/src/arrays/struct_/compute/take.rs +++ b/vortex-array/src/arrays/struct_/compute/take.rs @@ -7,13 +7,14 @@ use vortex_scalar::Scalar; use crate::Array; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::StructArray; use crate::arrays::StructVTable; -use crate::arrays::TakeReduce; -use crate::arrays::TakeReduceAdaptor; +use crate::arrays::TakeExecute; +use crate::arrays::TakeExecuteAdaptor; use crate::compute::{self}; -use crate::optimizer::rules::ParentRuleSet; +use crate::kernel::ParentKernelSet; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -47,13 +48,17 @@ fn take_struct(array: &StructArray, indices: &dyn Array) -> VortexResult VortexResult> { +impl TakeExecute for StructVTable { + fn take( + array: &StructArray, + indices: &dyn Array, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { take_struct(array, indices).map(Some) } } impl StructVTable { - pub const TAKE_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::(Self))]); + pub const TAKE_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); } diff --git a/vortex-array/src/arrays/take/array.rs b/vortex-array/src/arrays/take/array.rs deleted file mode 100644 index acadee679e7..00000000000 --- a/vortex-array/src/arrays/take/array.rs +++ /dev/null @@ -1,70 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; - -use crate::ArrayRef; -use crate::stats::ArrayStats; - -/// Decomposed parts of the take array. -pub struct TakeArrayParts { - /// Child array to take elements from - pub child: ArrayRef, - /// Indices specifying which elements to take from the child - pub indices: ArrayRef, -} - -/// A lazy array that represents taking elements from a child array at specified indices. -/// -/// The resulting array contains the elements at the positions specified by the indices. -#[derive(Clone, Debug)] -pub struct TakeArray { - /// The source array to take elements from. - pub(super) child: ArrayRef, - - /// The indices specifying which elements to take. - pub(super) indices: ArrayRef, - - /// The stats for this array. - pub(super) stats: ArrayStats, -} - -impl TakeArray { - pub fn new(array: ArrayRef, indices: ArrayRef) -> Self { - Self::try_new(array, indices).vortex_expect("TakeArray construction failed") - } - - pub fn try_new(array: ArrayRef, indices: ArrayRef) -> VortexResult { - vortex_ensure!( - indices.dtype().is_int(), - "TakeArray indices must have integer dtype, got {}", - indices.dtype() - ); - - Ok(Self { - child: array, - indices, - stats: ArrayStats::default(), - }) - } - - /// The child array to take elements from. - pub fn child(&self) -> &ArrayRef { - &self.child - } - - /// The indices specifying which elements to take. - pub fn indices(&self) -> &ArrayRef { - &self.indices - } - - /// Consume the array and return its individual components. - pub fn into_parts(self) -> TakeArrayParts { - TakeArrayParts { - child: self.child, - indices: self.indices, - } - } -} diff --git a/vortex-array/src/arrays/take/execute.rs b/vortex-array/src/arrays/take/execute.rs deleted file mode 100644 index c649413f086..00000000000 --- a/vortex-array/src/arrays/take/execute.rs +++ /dev/null @@ -1,59 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Execution logic for [`TakeArray`]. -//! -//! The main entrypoint is [`execute_take`] which takes elements from any [`Canonical`] array. - -use vortex_error::VortexResult; -use vortex_scalar::Scalar; - -use crate::Array; -use crate::ArrayRef; -use crate::Canonical; -use crate::ExecutionCtx; -use crate::IntoArray; -use crate::arrays::ConstantArray; -use crate::arrays::TakeArray; -use crate::compute::take; - -/// Check for some fast-path execution conditions before calling [`execute_take`]. -pub(super) fn execute_take_fast_paths( - array: &TakeArray, - ctx: &mut ExecutionCtx, -) -> VortexResult> { - // If the indices are empty, the output is empty. - if array.indices.is_empty() { - return Ok(Some(Canonical::empty(array.dtype()))); - } - - // If all indices are invalid (null), return an array of nulls - if array.indices.all_invalid()? { - return Ok(Some( - ConstantArray::new( - Scalar::null(array.child.dtype().as_nullable()), - array.indices.len(), - ) - .into_array() - .execute(ctx)?, - )); - } - - // Also check if the source array itself is completely null - if array.child.validity_mask()?.true_count() == 0 { - return Ok(Some( - ConstantArray::new(Scalar::null(array.dtype().clone()), array.indices.len()) - .into_array() - .execute(ctx)?, - )); - } - - Ok(None) -} - -/// Take elements from a canonical array at the given indices, returning a new canonical array. -pub(super) fn execute_take(canonical: Canonical, indices: ArrayRef) -> VortexResult { - // For now, delegate to the compute take function and canonicalize the result - let taken = take(canonical.as_ref(), indices.as_ref())?; - taken.to_canonical() -} diff --git a/vortex-array/src/arrays/take/mod.rs b/vortex-array/src/arrays/take/mod.rs deleted file mode 100644 index 44438a71d79..00000000000 --- a/vortex-array/src/arrays/take/mod.rs +++ /dev/null @@ -1,19 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -mod array; -pub use array::TakeArray; -pub use array::TakeArrayParts; - -mod execute; - -mod kernel; -pub use kernel::TakeExecute; -pub use kernel::TakeExecuteAdaptor; -pub use kernel::TakeReduce; -pub use kernel::TakeReduceAdaptor; - -mod rules; - -mod vtable; -pub use vtable::TakeVTable; diff --git a/vortex-array/src/arrays/take/rules.rs b/vortex-array/src/arrays/take/rules.rs deleted file mode 100644 index 75d8c4e1cf4..00000000000 --- a/vortex-array/src/arrays/take/rules.rs +++ /dev/null @@ -1,83 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; - -use crate::Array; -use crate::ArrayRef; -use crate::IntoArray; -use crate::arrays::StructArray; -use crate::arrays::StructArrayParts; -use crate::arrays::StructVTable; -use crate::arrays::TakeArray; -use crate::arrays::TakeVTable; -use crate::optimizer::rules::ArrayParentReduceRule; -use crate::optimizer::rules::ArrayReduceRule; -use crate::optimizer::rules::ParentRuleSet; -use crate::optimizer::rules::ReduceRuleSet; - -pub(super) const PARENT_RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&TakeTakeRule)]); - -pub(super) const RULES: ReduceRuleSet = ReduceRuleSet::new(&[&TakeStructRule]); - -/// A simple reduction rule that simplifies a [`TakeArray`] whose child is also a -/// [`TakeArray`]. -#[derive(Debug)] -struct TakeTakeRule; - -impl ArrayParentReduceRule for TakeTakeRule { - type Parent = TakeVTable; - - fn reduce_parent( - &self, - child: &TakeArray, - parent: &TakeArray, - _child_idx: usize, - ) -> VortexResult> { - // Take(Take(arr, indices1), indices2) = Take(arr, Take(indices1, indices2)) - // We need to take from the inner indices using the outer indices - let new_indices = child.indices.take(parent.indices.clone())?; - let new_array = child.child.take(new_indices)?; - - Ok(Some(new_array.into_array())) - } -} - -/// A reduce rule that pushes a take down into the fields of a StructArray. -#[derive(Debug)] -struct TakeStructRule; - -impl ArrayReduceRule for TakeStructRule { - fn reduce(&self, array: &TakeArray) -> VortexResult> { - let indices = array.indices(); - let Some(struct_array) = array.child().as_opt::() else { - return Ok(None); - }; - - let len = indices.len(); - let StructArrayParts { - fields, - struct_fields, - validity, - .. - } = struct_array.clone().into_parts(); - - let taken_validity = validity.take(indices)?; - - let taken_fields = fields - .iter() - .map(|field| field.take(indices.clone())) - .collect::>>()?; - - Ok(Some( - StructArray::new( - struct_fields.names().clone(), - taken_fields, - len, - taken_validity, - ) - .into_array(), - )) - } -} diff --git a/vortex-array/src/arrays/take/vtable.rs b/vortex-array/src/arrays/take/vtable.rs deleted file mode 100644 index 8d89b64aa7a..00000000000 --- a/vortex-array/src/arrays/take/vtable.rs +++ /dev/null @@ -1,213 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::fmt::Debug; -use std::fmt::Formatter; -use std::hash::Hasher; - -use vortex_dtype::DType; -use vortex_dtype::Nullability; -use vortex_dtype::PType; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_ensure; -use vortex_scalar::Scalar; - -use crate::Array; -use crate::ArrayBufferVisitor; -use crate::ArrayChildVisitor; -use crate::ArrayEq; -use crate::ArrayHash; -use crate::ArrayRef; -use crate::Canonical; -use crate::IntoArray; -use crate::Precision; -use crate::arrays::take::array::TakeArray; -use crate::arrays::take::execute::execute_take; -use crate::arrays::take::execute::execute_take_fast_paths; -use crate::arrays::take::rules::PARENT_RULES; -use crate::arrays::take::rules::RULES; -use crate::buffer::BufferHandle; -use crate::executor::ExecutionCtx; -use crate::serde::ArrayChildren; -use crate::stats::StatsSetRef; -use crate::validity::Validity; -use crate::vtable; -use crate::vtable::ArrayId; -use crate::vtable::BaseArrayVTable; -use crate::vtable::NotSupported; -use crate::vtable::OperationsVTable; -use crate::vtable::VTable; -use crate::vtable::ValidityVTable; -use crate::vtable::VisitorVTable; - -vtable!(Take); - -#[derive(Debug)] -pub struct TakeVTable; - -impl TakeVTable { - pub const ID: ArrayId = ArrayId::new_ref("vortex.take"); -} - -impl VTable for TakeVTable { - type Array = TakeArray; - type Metadata = TakeMetadata; - type ArrayVTable = Self; - type OperationsVTable = Self; - type ValidityVTable = Self; - type VisitorVTable = Self; - type ComputeVTable = NotSupported; - - fn id(_array: &Self::Array) -> ArrayId { - Self::ID - } - - fn metadata(_array: &Self::Array) -> VortexResult { - Ok(TakeMetadata) - } - - fn serialize(_metadata: Self::Metadata) -> VortexResult>> { - vortex_bail!("Take array is not serializable") - } - - fn deserialize(_bytes: &[u8]) -> VortexResult { - vortex_bail!("Take array is not serializable") - } - - fn build( - dtype: &DType, - len: usize, - _metadata: &TakeMetadata, - _buffers: &[BufferHandle], - children: &dyn ArrayChildren, - ) -> VortexResult { - // The indices child determines the length - use u64 as the index type - let indices_dtype = DType::Primitive(PType::U64, Nullability::Nullable); - let indices = children.get(1, &indices_dtype, len)?; - let child = children.get(0, dtype, 0)?; // child length is unknown from metadata - Ok(TakeArray { - child, - indices, - stats: Default::default(), - }) - } - - fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { - vortex_ensure!( - children.len() == 2, - "TakeArray expects exactly 2 children, got {}", - children.len() - ); - let mut iter = children.into_iter(); - array.child = iter - .next() - .vortex_expect("children length already validated"); - array.indices = iter - .next() - .vortex_expect("children length already validated"); - Ok(()) - } - - fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - if let Some(canonical) = execute_take_fast_paths(array, ctx)? { - return Ok(canonical); - } - - // Execute both child and indices to canonical form - let child_canonical = array.child.clone().execute::(ctx)?; - let indices_canonical = array.indices.clone().execute::(ctx)?; - - // Execute the take operation - let canonical = execute_take(child_canonical, indices_canonical.into_array())?; - - // Verify the resulting length and type - let result_len = array.indices.len(); - vortex_ensure!( - canonical.as_ref().len() == result_len, - "Take result length mismatch: expected {}, got {}", - result_len, - canonical.as_ref().len() - ); - - Ok(canonical) - } - - fn reduce_parent( - array: &Self::Array, - parent: &ArrayRef, - child_idx: usize, - ) -> VortexResult> { - PARENT_RULES.evaluate(array, parent, child_idx) - } - - fn reduce(array: &Self::Array) -> VortexResult> { - RULES.evaluate(array) - } -} - -impl BaseArrayVTable for TakeVTable { - fn len(array: &TakeArray) -> usize { - array.indices.len() - } - - fn dtype(array: &TakeArray) -> &DType { - // The dtype is the child's dtype with potentially nullable from indices - // For now, return child's dtype - nullability adjustment happens during execution - array.child.dtype() - } - - fn stats(array: &TakeArray) -> StatsSetRef<'_> { - array.stats.to_ref(array.as_ref()) - } - - fn array_hash(array: &TakeArray, state: &mut H, precision: Precision) { - array.child.array_hash(state, precision); - array.indices.array_hash(state, precision); - } - - fn array_eq(array: &TakeArray, other: &TakeArray, precision: Precision) -> bool { - array.child.array_eq(&other.child, precision) - && array.indices.array_eq(&other.indices, precision) - } -} - -impl OperationsVTable for TakeVTable { - fn scalar_at(array: &TakeArray, index: usize) -> VortexResult { - // Get the index value at position `index` from the indices array - let idx_scalar = array.indices.scalar_at(index)?; - if idx_scalar.is_null() { - return Ok(Scalar::null(array.child.dtype().as_nullable())); - } - let idx: usize = idx_scalar.as_ref().try_into()?; - array.child.scalar_at(idx) - } -} - -impl ValidityVTable for TakeVTable { - fn validity(array: &TakeArray) -> VortexResult { - // The validity of a take array depends on both: - // 1. The validity of the indices (null indices produce null values) - // 2. The validity of the child at the taken positions - // We return the child's validity taken at the indices positions - array.child.validity()?.take(&array.indices) - } -} - -impl VisitorVTable for TakeVTable { - fn visit_buffers(_array: &TakeArray, _visitor: &mut dyn ArrayBufferVisitor) {} - - fn visit_children(array: &TakeArray, visitor: &mut dyn ArrayChildVisitor) { - visitor.visit_child("child", &array.child); - visitor.visit_child("indices", &array.indices); - } -} - -pub struct TakeMetadata; - -impl Debug for TakeMetadata { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "TakeMetadata") - } -} From f924dd0ab475fc4cd170934378140732f44480c5 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 10:33:37 +0000 Subject: [PATCH 07/21] wip Signed-off-by: Joe Isaacs --- encodings/alp/src/alp/compute/take.rs | 3 +- encodings/alp/src/alp_rd/compute/take.rs | 5 +- encodings/datetime-parts/src/compute/take.rs | 7 +- .../src/decimal_byte_parts/compute/take.rs | 3 +- .../fastlanes/src/bitpacking/compute/take.rs | 3 +- encodings/fastlanes/src/for/compute/mod.rs | 3 +- encodings/fsst/src/compute/mod.rs | 7 +- encodings/runend/src/compute/take.rs | 3 +- encodings/runend/src/compute/take_from.rs | 211 +++++++++++++++--- encodings/zigzag/src/compute/mod.rs | 3 +- vortex-array/src/arrays/dict/take.rs | 2 +- .../src/arrays/extension/compute/take.rs | 3 +- vortex-array/src/arrays/filter/kernel.rs | 2 +- .../arrays/fixed_size_list/compute/take.rs | 5 +- .../src/arrays/listview/compute/take.rs | 6 +- vortex-array/src/arrays/slice/mod.rs | 18 ++ .../src/arrays/struct_/compute/take.rs | 4 +- vortex-array/src/compute/mod.rs | 1 - 18 files changed, 227 insertions(+), 62 deletions(-) diff --git a/encodings/alp/src/alp/compute/take.rs b/encodings/alp/src/alp/compute/take.rs index f9bff38e486..daeb613871f 100644 --- a/encodings/alp/src/alp/compute/take.rs +++ b/encodings/alp/src/alp/compute/take.rs @@ -7,7 +7,6 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::TakeExecute; use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::compute::take; use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; @@ -15,7 +14,7 @@ use crate::ALPArray; use crate::ALPVTable; fn take_alp(array: &ALPArray, indices: &dyn Array) -> VortexResult { - let taken_encoded = take(array.encoded(), indices)?; + let taken_encoded = array.encoded().take(indices.to_array())?; let taken_patches = array .patches() .map(|p| p.take(indices)) diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index aaa3cd2b909..7a4acb7dcc3 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -8,7 +8,6 @@ use vortex_array::IntoArray; use vortex_array::arrays::TakeExecute; use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::fill_null; -use vortex_array::compute::take; use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -18,7 +17,7 @@ use crate::ALPRDArray; use crate::ALPRDVTable; fn take_alprd(array: &ALPRDArray, indices: &dyn Array) -> VortexResult { - let taken_left_parts = take(array.left_parts(), indices)?; + let taken_left_parts = array.left_parts().take(indices.to_array())?; let left_parts_exceptions = array .left_parts_patches() .map(|patches| patches.take(indices)) @@ -33,7 +32,7 @@ fn take_alprd(array: &ALPRDArray, indices: &dyn Array) -> VortexResult }) .transpose()?; let right_parts = fill_null( - &take(array.right_parts(), indices)?, + &array.right_parts().take(indices.to_array())?, &Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)), )?; diff --git a/encodings/datetime-parts/src/compute/take.rs b/encodings/datetime-parts/src/compute/take.rs index 21943948142..1a833ddf967 100644 --- a/encodings/datetime-parts/src/compute/take.rs +++ b/encodings/datetime-parts/src/compute/take.rs @@ -9,7 +9,6 @@ use vortex_array::ToCanonical; use vortex_array::arrays::TakeExecute; use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::fill_null; -use vortex_array::compute::take; use vortex_array::expr::stats::Stat; use vortex_array::expr::stats::StatsProvider; use vortex_array::kernel::ParentKernelSet; @@ -25,9 +24,9 @@ fn take_datetime_parts(array: &DateTimePartsArray, indices: &dyn Array) -> Vorte // we go ahead and canonicalize here to avoid worst-case canonicalizing 3 separate times let indices = indices.to_primitive(); - let taken_days = take(array.days(), indices.as_ref())?; - let taken_seconds = take(array.seconds(), indices.as_ref())?; - let taken_subseconds = take(array.subseconds(), indices.as_ref())?; + let taken_days = array.days().take(indices.to_array())?; + let taken_seconds = array.seconds().take(indices.to_array())?; + let taken_subseconds = array.subseconds().take(indices.to_array())?; // Update the dtype if the nullability changed due to nullable indices let dtype = if taken_days.dtype().is_nullable() != array.dtype().is_nullable() { diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs index f29aa4866ff..3f202972b7b 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs @@ -6,7 +6,6 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::arrays::TakeExecute; use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::compute::take; use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; @@ -17,7 +16,7 @@ fn take_decimal_byte_parts( array: &DecimalBytePartsArray, indices: &dyn Array, ) -> VortexResult { - DecimalBytePartsArray::try_new(take(&array.msp, indices)?, *array.decimal_dtype()) + DecimalBytePartsArray::try_new(array.msp.take(indices.to_array())?, *array.decimal_dtype()) .map(|a| a.to_array()) } diff --git a/encodings/fastlanes/src/bitpacking/compute/take.rs b/encodings/fastlanes/src/bitpacking/compute/take.rs index e2b8bf87b6f..4794286ab5f 100644 --- a/encodings/fastlanes/src/bitpacking/compute/take.rs +++ b/encodings/fastlanes/src/bitpacking/compute/take.rs @@ -13,7 +13,6 @@ use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::TakeExecute; use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::compute::take; use vortex_array::kernel::ParentKernelSet; use vortex_array::validity::Validity; use vortex_array::vtable::ValidityHelper; @@ -41,7 +40,7 @@ pub(super) const UNPACK_CHUNK_THRESHOLD: usize = 8; fn take_bitpacked(array: &BitPackedArray, indices: &dyn Array) -> VortexResult { // If the indices are large enough, it's faster to flatten and take the primitive array. if indices.len() * UNPACK_CHUNK_THRESHOLD > array.len() { - return take(array.to_primitive().as_ref(), indices); + return array.to_primitive().take(indices.to_array()); } // NOTE: we use the unsigned PType because all values in the BitPackedArray must diff --git a/encodings/fastlanes/src/for/compute/mod.rs b/encodings/fastlanes/src/for/compute/mod.rs index 7b24564f019..9ac8e97e761 100644 --- a/encodings/fastlanes/src/for/compute/mod.rs +++ b/encodings/fastlanes/src/for/compute/mod.rs @@ -13,7 +13,6 @@ use vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; use vortex_array::arrays::TakeExecute; use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::compute::take; use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use vortex_mask::Mask; @@ -23,7 +22,7 @@ use crate::FoRVTable; fn take_for(array: &FoRArray, indices: &dyn Array) -> VortexResult { FoRArray::try_new( - take(array.encoded(), indices)?, + array.encoded().take(indices.to_array())?, array.reference_scalar().clone(), ) .map(|a| a.into_array()) diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 62230683e9f..fb07d7e197e 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -13,7 +13,6 @@ use vortex_array::arrays::TakeExecute; use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::arrays::VarBinVTable; use vortex_array::compute::fill_null; -use vortex_array::compute::take; use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -30,11 +29,13 @@ fn take_fsst(array: &FSSTArray, indices: &dyn Array) -> VortexResult { .union_nullability(indices.dtype().nullability()), array.symbols().clone(), array.symbol_lengths().clone(), - take(array.codes().as_ref(), indices)? + array + .codes() + .take(indices.to_array())? .as_::() .clone(), fill_null( - &take(array.uncompressed_lengths(), indices)?, + &array.uncompressed_lengths().take(indices.to_array())?, &Scalar::new( array.uncompressed_lengths_dtype().clone(), ScalarValue::from(0), diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index d77da4b1a8d..9bb34b776e4 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -10,7 +10,6 @@ use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::TakeExecute; use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::compute::take; use vortex_array::kernel::ParentKernelSet; use vortex_array::search_sorted::SearchResult; use vortex_array::search_sorted::SearchSorted; @@ -96,7 +95,7 @@ pub fn take_indices_unchecked>( PrimitiveArray::new(buffer, validity.clone()) }); - take(array.values(), physical_indices.as_ref()) + array.values().take(physical_indices.to_array()) } #[cfg(test)] diff --git a/encodings/runend/src/compute/take_from.rs b/encodings/runend/src/compute/take_from.rs index 8bda288c3a2..f1127eb3e4d 100644 --- a/encodings/runend/src/compute/take_from.rs +++ b/encodings/runend/src/compute/take_from.rs @@ -1,54 +1,66 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::fmt::Debug; + use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::compute::TakeFromKernel; -use vortex_array::compute::TakeFromKernelAdapter; -use vortex_array::compute::take; -use vortex_array::register_kernel; +use vortex_array::arrays::DictArray; +use vortex_array::arrays::DictVTable; +use vortex_array::kernel::ExecuteParentKernel; +use vortex_array::matcher::Matcher; use vortex_dtype::DType; use vortex_error::VortexResult; use crate::RunEndArray; use crate::RunEndVTable; -impl TakeFromKernel for RunEndVTable { - /// Takes values from the source array using run-end encoded indices. - /// - /// # Arguments - /// - /// * `indices` - Run-end encoded indices - /// * `source` - Array to take values from - /// - /// # Returns - /// - /// * `Ok(Some(source))` - If successful - /// * `Ok(None)` - If the source array has an unsupported dtype - /// - fn take_from( +// impl ParentKernelSet + +#[derive(Debug)] +struct RunEndVTableTakeFrom; + +impl ExecuteParentKernel for RunEndVTableTakeFrom { + type Parent = DictVTable; + + fn execute_parent( &self, - indices: &RunEndArray, - source: &dyn Array, + array: &RunEndArray, + dict: &DictArray, + child_idx: usize, + ctx: &mut ExecutionCtx, ) -> VortexResult> { + if child_idx != 0 { + return Ok(None); + } // Only `Primitive` and `Bool` are valid run-end value types. - TODO: Support additional DTypes - if !matches!(source.dtype(), DType::Primitive(_, _) | DType::Bool(_)) { + if !matches!(dict.dtype(), DType::Primitive(_, _) | DType::Bool(_)) { return Ok(None); } + println!("offset {}", array.offset()); + println!("len {}", array.len()); + + if true { + panic!("run end dict take") + } + + /// TODO: eager take and also slice offset + len + /// + /// // Transform the run-end encoding from storing indices to storing values // by taking values from `source` at positions specified by `indices.values()`. - let values = take(source, indices.values())?; // Create a new run-end array containing values as values, instead of indices as values. // SAFETY: we are copying ends from an existing valid RunEndArray let ree_array = unsafe { RunEndArray::new_unchecked( - indices.ends().clone(), - values, - indices.offset(), - indices.len(), + array.ends().clone(), + dict.values().take(array.values().clone())?, + array.offset(), + array.len(), ) }; @@ -56,4 +68,149 @@ impl TakeFromKernel for RunEndVTable { } } -register_kernel!(TakeFromKernelAdapter(RunEndVTable).lift()); +// impl TakeFromKernel for RunEndVTable { +// /// Takes values from the source array using run-end encoded indices. +// /// +// /// # Arguments +// /// +// /// * `indices` - Run-end encoded indices +// /// * `source` - Array to take values from +// /// +// /// # Returns +// /// +// /// * `Ok(Some(source))` - If successful +// /// * `Ok(None)` - If the source array has an unsupported dtype +// /// +// fn take_from( +// &self, +// indices: &RunEndArray, +// source: &dyn Array, +// ) -> VortexResult> { +// +// } +// } +// +// register_kernel!(TakeFromKernelAdapter(RunEndVTable).lift()); + +#[cfg(test)] +mod tests { + use vortex_array::Array; + use vortex_array::ExecutionCtx; + use vortex_array::IntoArray; + use vortex_array::arrays::DictArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::assert_arrays_eq; + use vortex_array::kernel::ExecuteParentKernel; + use vortex_buffer::buffer; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + + use super::RunEndVTableTakeFrom; + use crate::RunEndArray; + + /// Build a DictArray whose codes are run-end encoded. + /// + /// Input: `[2, 2, 2, 3, 3, 2, 2]` + /// Dict values: `[2, 3]` + /// Codes: `[0, 0, 0, 1, 1, 0, 0]` + /// RunEnd encoded codes: ends=`[3, 5, 7]`, values=`[0, 1, 0]` + fn make_dict_with_runend_codes() -> (RunEndArray, DictArray) { + let codes = RunEndArray::encode(buffer![0u32, 0, 0, 1, 1, 0, 0].into_array()).unwrap(); + let values = buffer![2i32, 3].into_array(); + let dict = DictArray::try_new(codes.clone().into_array(), values).unwrap(); + (codes, dict) + } + + #[test] + fn test_execute_parent_no_offset() -> VortexResult<()> { + let (codes, dict) = make_dict_with_runend_codes(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = RunEndVTableTakeFrom + .execute_parent(&codes, &dict, 0, &mut ctx)? + .expect("kernel should return Some"); + + let expected = PrimitiveArray::from_iter([2i32, 2, 2, 3, 3, 2, 2]); + assert_arrays_eq!(result.to_canonical()?.into_array(), expected); + Ok(()) + } + + #[test] + fn test_execute_parent_with_offset() -> VortexResult<()> { + let (codes, dict) = make_dict_with_runend_codes(); + // Slice codes to positions 2..5 → logical codes [0, 1, 1] → values [2, 3, 3] + let sliced_codes = unsafe { + RunEndArray::new_unchecked( + codes.ends().clone(), + codes.values().clone(), + 2, // offset + 3, // len + ) + }; + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = RunEndVTableTakeFrom + .execute_parent(&sliced_codes, &dict, 0, &mut ctx)? + .expect("kernel should return Some"); + + let expected = PrimitiveArray::from_iter([2i32, 3, 3]); + assert_arrays_eq!(result.to_canonical()?.into_array(), expected); + Ok(()) + } + + #[test] + fn test_execute_parent_offset_at_run_boundary() -> VortexResult<()> { + let (codes, dict) = make_dict_with_runend_codes(); + // Slice codes to positions 3..7 → logical codes [1, 1, 0, 0] → values [3, 3, 2, 2] + let sliced_codes = unsafe { + RunEndArray::new_unchecked( + codes.ends().clone(), + codes.values().clone(), + 3, // offset at exact run boundary + 4, // len + ) + }; + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = RunEndVTableTakeFrom + .execute_parent(&sliced_codes, &dict, 0, &mut ctx)? + .expect("kernel should return Some"); + + let expected = PrimitiveArray::from_iter([3i32, 3, 2, 2]); + assert_arrays_eq!(result.to_canonical()?.into_array(), expected); + Ok(()) + } + + #[test] + fn test_execute_parent_single_element_offset() -> VortexResult<()> { + let (codes, dict) = make_dict_with_runend_codes(); + // Slice to single element at position 4 → code=1 → value=3 + let sliced_codes = unsafe { + RunEndArray::new_unchecked( + codes.ends().slice(1..3)?, + codes.values().slice(1..3)?, + 4, // offset + 1, // len + ) + }; + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = RunEndVTableTakeFrom + .execute_parent(&sliced_codes, &dict, 0, &mut ctx)? + .expect("kernel should return Some"); + + let expected = PrimitiveArray::from_iter([3i32]); + assert_arrays_eq!(result.to_canonical()?.into_array(), expected); + Ok(()) + } + + #[test] + fn test_execute_parent_returns_none_for_non_codes_child() -> VortexResult<()> { + let (codes, dict) = make_dict_with_runend_codes(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = RunEndVTableTakeFrom.execute_parent(&codes, &dict, 1, &mut ctx)?; + assert!(result.is_none()); + Ok(()) + } +} diff --git a/encodings/zigzag/src/compute/mod.rs b/encodings/zigzag/src/compute/mod.rs index 059ebae4001..14f7b1aa978 100644 --- a/encodings/zigzag/src/compute/mod.rs +++ b/encodings/zigzag/src/compute/mod.rs @@ -13,7 +13,6 @@ use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::MaskKernel; use vortex_array::compute::MaskKernelAdapter; use vortex_array::compute::mask; -use vortex_array::compute::take; use vortex_array::kernel::ParentKernelSet; use vortex_array::register_kernel; use vortex_error::VortexResult; @@ -30,7 +29,7 @@ impl FilterReduce for ZigZagVTable { } fn take_zigzag(array: &ZigZagArray, indices: &dyn Array) -> VortexResult { - let encoded = take(array.encoded(), indices)?; + let encoded = array.encoded().take(indices.to_array())?; Ok(ZigZagArray::try_new(encoded)?.into_array()) } diff --git a/vortex-array/src/arrays/dict/take.rs b/vortex-array/src/arrays/dict/take.rs index d90f67afda7..13af8f75de0 100644 --- a/vortex-array/src/arrays/dict/take.rs +++ b/vortex-array/src/arrays/dict/take.rs @@ -49,7 +49,7 @@ pub trait TakeExecute: VTable { /// /// Returns `Some(result)` if the precondition short-circuits the take operation, /// or `None` if the take should proceed normally. -pub fn precondition(array: &V::Array, indices: &dyn Array) -> Option { +fn precondition(array: &V::Array, indices: &dyn Array) -> Option { // Fast-path for empty indices. if indices.is_empty() { return Some(Canonical::empty(array.dtype()).into_array()); diff --git a/vortex-array/src/arrays/extension/compute/take.rs b/vortex-array/src/arrays/extension/compute/take.rs index b301ebfcd9f..b6dd7c7e394 100644 --- a/vortex-array/src/arrays/extension/compute/take.rs +++ b/vortex-array/src/arrays/extension/compute/take.rs @@ -11,11 +11,10 @@ use crate::arrays::ExtensionArray; use crate::arrays::ExtensionVTable; use crate::arrays::TakeExecute; use crate::arrays::TakeExecuteAdaptor; -use crate::compute::{self}; use crate::kernel::ParentKernelSet; fn take_extension(array: &ExtensionArray, indices: &dyn Array) -> VortexResult { - let taken_storage = compute::take(array.storage(), indices)?; + let taken_storage = array.storage().take(indices.to_array())?; Ok(ExtensionArray::new( array .ext_dtype() diff --git a/vortex-array/src/arrays/filter/kernel.rs b/vortex-array/src/arrays/filter/kernel.rs index 174530caab8..fb1e6ab9fa4 100644 --- a/vortex-array/src/arrays/filter/kernel.rs +++ b/vortex-array/src/arrays/filter/kernel.rs @@ -54,7 +54,7 @@ pub trait FilterKernel: VTable { /// /// Returns `Some(result)` if the precondition short-circuits the filter operation, /// or `None` if the filter should proceed normally. -pub fn precondition(array: &V::Array, mask: &Mask) -> Option { +fn precondition(array: &V::Array, mask: &Mask) -> Option { let true_count = mask.true_count(); // Fast-path for empty mask (all false). diff --git a/vortex-array/src/arrays/fixed_size_list/compute/take.rs b/vortex-array/src/arrays/fixed_size_list/compute/take.rs index 632c4e9850f..3288e1bbfae 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/take.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/take.rs @@ -18,7 +18,6 @@ use crate::arrays::FixedSizeListVTable; use crate::arrays::PrimitiveArray; use crate::arrays::TakeExecute; use crate::arrays::TakeExecuteAdaptor; -use crate::compute::{self}; use crate::executor::ExecutionCtx; use crate::kernel::ParentKernelSet; use crate::validity::Validity; @@ -126,7 +125,7 @@ fn take_non_nullable_fsl( debug_assert_eq!(elements_indices.len(), new_len * list_size); let elements_indices_array = PrimitiveArray::new(elements_indices, Validity::NonNullable); - let new_elements = compute::take(array.elements(), elements_indices_array.as_ref())?; + let new_elements = array.elements().take(elements_indices_array.to_array())?; debug_assert_eq!(new_elements.len(), new_len * list_size); // Both inputs are non-nullable, so the result is non-nullable. @@ -193,7 +192,7 @@ fn take_nullable_fsl( debug_assert_eq!(elements_indices.len(), new_len * list_size); let elements_indices_array = PrimitiveArray::new(elements_indices, Validity::NonNullable); - let new_elements = compute::take(array.elements(), elements_indices_array.as_ref())?; + let new_elements = array.elements().take(elements_indices_array.to_array())?; debug_assert_eq!(new_elements.len(), new_len * list_size); // At least one input was nullable, so the result is nullable. diff --git a/vortex-array/src/arrays/listview/compute/take.rs b/vortex-array/src/arrays/listview/compute/take.rs index 699768eef0b..090bb15536a 100644 --- a/vortex-array/src/arrays/listview/compute/take.rs +++ b/vortex-array/src/arrays/listview/compute/take.rs @@ -15,7 +15,7 @@ use crate::arrays::ListViewRebuildMode; use crate::arrays::ListViewVTable; use crate::arrays::TakeExecute; use crate::arrays::TakeExecuteAdaptor; -use crate::compute::{self}; +use crate::compute; use crate::executor::ExecutionCtx; use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; @@ -55,8 +55,8 @@ fn take_listview(array: &ListViewArray, indices: &dyn Array) -> VortexResult VortexResult>; } +fn precondition(array: &V::Array, range: &Range) -> Option { + if range.start == 0 && range.end == array.len() { + return Some(array.to_array()); + }; + if range.start == range.end { + return Some(Canonical::empty(array.dtype()).into_array()); + } + None +} + #[derive(Default, Debug)] pub struct SliceReduceAdaptor(pub V); @@ -68,6 +80,9 @@ where child_idx: usize, ) -> VortexResult> { assert_eq!(child_idx, 0); + if let Some(result) = precondition::(array, &parent.range) { + return Ok(Some(result)); + } ::slice(array, parent.range.clone()) } } @@ -89,6 +104,9 @@ where ctx: &mut ExecutionCtx, ) -> VortexResult> { assert_eq!(child_idx, 0); + if let Some(result) = precondition::(array, &parent.range) { + return Ok(Some(result)); + } ::slice(array, parent.range.clone(), ctx) } } diff --git a/vortex-array/src/arrays/struct_/compute/take.rs b/vortex-array/src/arrays/struct_/compute/take.rs index 4367118e5fc..db694e0da73 100644 --- a/vortex-array/src/arrays/struct_/compute/take.rs +++ b/vortex-array/src/arrays/struct_/compute/take.rs @@ -13,7 +13,7 @@ use crate::arrays::StructArray; use crate::arrays::StructVTable; use crate::arrays::TakeExecute; use crate::arrays::TakeExecuteAdaptor; -use crate::compute::{self}; +use crate::compute; use crate::kernel::ParentKernelSet; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -39,7 +39,7 @@ fn take_struct(array: &StructArray, indices: &dyn Array) -> VortexResult, _>>()?, array.struct_fields().clone(), indices.len(), diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 9c6047db699..b88b00413fa 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -98,7 +98,6 @@ pub fn warm_up_vtables() { nan_count::warm_up_vtable(); numeric::warm_up_vtable(); sum::warm_up_vtable(); - take::warm_up_vtable(); zip::warm_up_vtable(); } From 49c0e4cfbae773cdc1ab796f3465669ec013d6fa Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 12:20:24 +0000 Subject: [PATCH 08/21] wip Signed-off-by: Joe Isaacs --- encodings/alp/src/alp/array.rs | 10 +- encodings/alp/src/alp_rd/array.rs | 8 +- .../fastlanes/src/bitpacking/vtable/mod.rs | 8 +- encodings/runend/src/compute/take_from.rs | 51 ++--- encodings/sequence/src/array.rs | 1 + vortex-array/src/arrays/chunked/vtable/mod.rs | 6 +- vortex-array/src/arrays/dict/execute.rs | 200 +++++++----------- vortex-array/src/arrays/dict/vtable/mod.rs | 2 +- vortex-array/src/arrays/varbin/vtable/mod.rs | 8 +- 9 files changed, 119 insertions(+), 175 deletions(-) diff --git a/encodings/alp/src/alp/array.rs b/encodings/alp/src/alp/array.rs index 89568f02ddd..6712ba93e01 100644 --- a/encodings/alp/src/alp/array.rs +++ b/encodings/alp/src/alp/array.rs @@ -13,6 +13,7 @@ use vortex_array::ArrayRef; use vortex_array::Canonical; use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; @@ -25,7 +26,6 @@ use vortex_array::stats::StatsSetRef; use vortex_array::vtable; use vortex_array::vtable::ArrayId; use vortex_array::vtable::BaseArrayVTable; -use vortex_array::vtable::NotSupported; use vortex_array::vtable::VTable; use vortex_array::vtable::ValidityChild; use vortex_array::vtable::ValidityVTableFromChild; @@ -54,7 +54,6 @@ impl VTable for ALPVTable { type OperationsVTable = Self; type ValidityVTable = ValidityVTableFromChild; type VisitorVTable = Self; - type ComputeVTable = NotSupported; fn id(_array: &Self::Array) -> ArrayId { Self::ID @@ -165,12 +164,9 @@ impl VTable for ALPVTable { Ok(()) } - fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { // TODO(joe): take by value - Ok(Canonical::Primitive(execute_decompress( - array.clone(), - ctx, - )?)) + Ok(execute_decompress(array.clone(), ctx)?.into_array()) } fn execute_parent( diff --git a/encodings/alp/src/alp_rd/array.rs b/encodings/alp/src/alp_rd/array.rs index 70e885f4e17..42a5e253716 100644 --- a/encodings/alp/src/alp_rd/array.rs +++ b/encodings/alp/src/alp_rd/array.rs @@ -11,9 +11,9 @@ use vortex_array::ArrayChildVisitor; use vortex_array::ArrayEq; use vortex_array::ArrayHash; use vortex_array::ArrayRef; -use vortex_array::Canonical; use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; @@ -28,7 +28,6 @@ use vortex_array::validity::Validity; use vortex_array::vtable; use vortex_array::vtable::ArrayId; use vortex_array::vtable::BaseArrayVTable; -use vortex_array::vtable::NotSupported; use vortex_array::vtable::VTable; use vortex_array::vtable::ValidityChild; use vortex_array::vtable::ValidityVTableFromChild; @@ -72,7 +71,6 @@ impl VTable for ALPRDVTable { type OperationsVTable = Self; type ValidityVTable = ValidityVTableFromChild; type VisitorVTable = Self; - type ComputeVTable = NotSupported; fn id(_array: &Self::Array) -> ArrayId { Self::ID @@ -218,7 +216,7 @@ impl VTable for ALPRDVTable { Ok(()) } - fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { let left_parts = array.left_parts().clone().execute::(ctx)?; let right_parts = array.right_parts().clone().execute::(ctx)?; @@ -257,7 +255,7 @@ impl VTable for ALPRDVTable { ) }; - Ok(Canonical::Primitive(decoded_array)) + Ok(decoded_array.into_array()) } fn execute_parent( diff --git a/encodings/fastlanes/src/bitpacking/vtable/mod.rs b/encodings/fastlanes/src/bitpacking/vtable/mod.rs index 3f42fee7f50..1dfec786da7 100644 --- a/encodings/fastlanes/src/bitpacking/vtable/mod.rs +++ b/encodings/fastlanes/src/bitpacking/vtable/mod.rs @@ -2,9 +2,9 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::ArrayRef; -use vortex_array::Canonical; use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; use vortex_array::buffer::BufferHandle; @@ -15,7 +15,6 @@ use vortex_array::serde::ArrayChildren; use vortex_array::validity::Validity; use vortex_array::vtable; use vortex_array::vtable::ArrayId; -use vortex_array::vtable::NotSupported; use vortex_array::vtable::VTable; use vortex_array::vtable::ValidityVTableFromValidityHelper; use vortex_dtype::DType; @@ -58,7 +57,6 @@ impl VTable for BitPackedVTable { type OperationsVTable = Self; type ValidityVTable = ValidityVTableFromValidityHelper; type VisitorVTable = Self; - type ComputeVTable = NotSupported; fn id(_array: &Self::Array) -> ArrayId { Self::ID @@ -246,8 +244,8 @@ impl VTable for BitPackedVTable { }) } - fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(Canonical::Primitive(unpack_array(array, ctx)?)) + fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + Ok(unpack_array(array, ctx)?.into_array()) } fn execute_parent( diff --git a/encodings/runend/src/compute/take_from.rs b/encodings/runend/src/compute/take_from.rs index f1127eb3e4d..bcabac5db5a 100644 --- a/encodings/runend/src/compute/take_from.rs +++ b/encodings/runend/src/compute/take_from.rs @@ -27,44 +27,39 @@ impl ExecuteParentKernel for RunEndVTableTakeFrom { fn execute_parent( &self, - array: &RunEndArray, + _array: &RunEndArray, dict: &DictArray, child_idx: usize, - ctx: &mut ExecutionCtx, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { if child_idx != 0 { return Ok(None); } - // Only `Primitive` and `Bool` are valid run-end value types. - TODO: Support additional DTypes + // Only `Primitive` and `Bool` are valid run-end value types. + // TODO: Support additional DTypes if !matches!(dict.dtype(), DType::Primitive(_, _) | DType::Bool(_)) { return Ok(None); } - println!("offset {}", array.offset()); - println!("len {}", array.len()); - - if true { - panic!("run end dict take") - } - - /// TODO: eager take and also slice offset + len - /// - /// - // Transform the run-end encoding from storing indices to storing values - // by taking values from `source` at positions specified by `indices.values()`. - - // Create a new run-end array containing values as values, instead of indices as values. - // SAFETY: we are copying ends from an existing valid RunEndArray - let ree_array = unsafe { - RunEndArray::new_unchecked( - array.ends().clone(), - dict.values().take(array.values().clone())?, - array.offset(), - array.len(), - ) - }; - - Ok(Some(ree_array.into_array())) + // // Transform the run-end encoding from storing indices to storing values + // // by taking values from `source` at positions specified by `indices.values()`. + // + // // Create a new run-end array containing values as values, instead of indices as values. + // // SAFETY: we are copying ends from an existing valid RunEndArray + // let ree_array = unsafe { + // RunEndArray::new_unchecked( + // array.ends().clone(), + // dict.values().take(array.values().clone())?, + // array.offset(), + // array.len(), + // ) + // }; + // + // Ok(Some(ree_array.into_array())) + + // TODO: implement run-end take from optimization + // For now, skip this optimization and fall back to default take + Ok(None) } } diff --git a/encodings/sequence/src/array.rs b/encodings/sequence/src/array.rs index 4ded8b71b42..22e5e078fec 100644 --- a/encodings/sequence/src/array.rs +++ b/encodings/sequence/src/array.rs @@ -9,6 +9,7 @@ use vortex_array::ArrayChildVisitor; use vortex_array::ArrayRef; use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; diff --git a/vortex-array/src/arrays/chunked/vtable/mod.rs b/vortex-array/src/arrays/chunked/vtable/mod.rs index 96b39afe5f5..eea3f954032 100644 --- a/vortex-array/src/arrays/chunked/vtable/mod.rs +++ b/vortex-array/src/arrays/chunked/vtable/mod.rs @@ -31,7 +31,6 @@ use crate::vtable::VTable; mod array; mod canonical; -mod compute; mod operations; mod validity; mod visitor; @@ -54,7 +53,6 @@ impl VTable for ChunkedVTable { type OperationsVTable = Self; type ValidityVTable = Self; type VisitorVTable = Self; - type ComputeVTable = Self; fn id(_array: &Self::Array) -> ArrayId { Self::ID @@ -169,8 +167,8 @@ impl VTable for ChunkedVTable { Ok(()) } - fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - _canonicalize(array, ctx) + fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + Ok(_canonicalize(array, ctx)?.into_array()) } fn reduce(array: &Self::Array) -> VortexResult> { diff --git a/vortex-array/src/arrays/dict/execute.rs b/vortex-array/src/arrays/dict/execute.rs index ae136cc99ac..0e1dfd506bb 100644 --- a/vortex-array/src/arrays/dict/execute.rs +++ b/vortex-array/src/arrays/dict/execute.rs @@ -7,185 +7,145 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::Canonical; +use crate::ExecutionCtx; use crate::arrays::BoolArray; use crate::arrays::BoolVTable; use crate::arrays::DecimalArray; use crate::arrays::DecimalVTable; use crate::arrays::ExtensionArray; +use crate::arrays::ExtensionVTable; use crate::arrays::FixedSizeListArray; use crate::arrays::FixedSizeListVTable; use crate::arrays::ListViewArray; use crate::arrays::ListViewVTable; use crate::arrays::NullArray; -use crate::arrays::NullVTable; use crate::arrays::PrimitiveArray; use crate::arrays::PrimitiveVTable; use crate::arrays::StructArray; use crate::arrays::StructVTable; +use crate::arrays::TakeExecute; use crate::arrays::VarBinViewArray; use crate::arrays::VarBinViewVTable; -use crate::compute::take; -/// TODO: replace usage of compute fn. /// Take from a canonical array using indices (codes), returning a new canonical array. /// /// This is the core operation for dictionary decoding - it expands the dictionary /// by looking up each code in the values array. -pub fn take_canonical(values: Canonical, codes: &PrimitiveArray) -> VortexResult { +pub fn take_canonical( + values: Canonical, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VortexResult { Ok(match values { Canonical::Null(a) => Canonical::Null(take_null(&a, codes)), - Canonical::Bool(a) => Canonical::Bool(take_bool(&a, codes)?), - Canonical::Primitive(a) => Canonical::Primitive(take_primitive(&a, codes)), - Canonical::Decimal(a) => Canonical::Decimal(take_decimal(&a, codes)), - Canonical::VarBinView(a) => Canonical::VarBinView(take_varbinview(&a, codes)), - Canonical::List(a) => Canonical::List(take_listview(&a, codes)), - Canonical::FixedSizeList(a) => Canonical::FixedSizeList(take_fixed_size_list(&a, codes)), - Canonical::Struct(a) => Canonical::Struct(take_struct(&a, codes)), - Canonical::Extension(a) => Canonical::Extension(take_extension(&a, codes)), + Canonical::Bool(a) => Canonical::Bool(take_bool(&a, codes, ctx)?), + Canonical::Primitive(a) => Canonical::Primitive(take_primitive(&a, codes, ctx)), + Canonical::Decimal(a) => Canonical::Decimal(take_decimal(&a, codes, ctx)), + Canonical::VarBinView(a) => Canonical::VarBinView(take_varbinview(&a, codes, ctx)), + Canonical::List(a) => Canonical::List(take_listview(&a, codes, ctx)), + Canonical::FixedSizeList(a) => { + Canonical::FixedSizeList(take_fixed_size_list(&a, codes, ctx)) + } + Canonical::Struct(a) => Canonical::Struct(take_struct(&a, codes, ctx)), + Canonical::Extension(a) => Canonical::Extension(take_extension(&a, codes, ctx)), }) } -fn take_null(array: &NullArray, codes: &PrimitiveArray) -> NullArray { - take(array.as_ref(), codes.as_ref()) - .vortex_expect("take null array") - .as_::() - .clone() +/// Take for NullArray is trivial - just create a new NullArray with the new length. +fn take_null(_array: &NullArray, codes: &PrimitiveArray) -> NullArray { + NullArray::new(codes.len()) } -// pub(super) fn dict_bool_take(dict_array: &DictArray) -> VortexResult { -// let values = dict_array.values(); -// let codes = dict_array.codes(); -// let result_nullability = dict_array.dtype().nullability(); -// -// let bool_values = values.to_bool(); -// let result_validity = bool_values.validity_mask(); -// let bool_buffer = bool_values.bit_buffer(); -// let (first_match, second_match) = match result_validity.bit_buffer() { -// AllOr::All => { -// let mut indices_iter = bool_buffer.set_indices(); -// (indices_iter.next(), indices_iter.next()) -// } -// AllOr::None => (None, None), -// AllOr::Some(v) => { -// let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i)); -// (indices_iter.next(), indices_iter.next()) -// } -// }; -// -// Ok(match (first_match, second_match) { -// // Couldn't find a value match, so the result is all false. -// (None, _) => match result_validity { -// Mask::AllTrue(_) => BoolArray::new( -// BitBuffer::new_unset(codes.len()), -// Validity::copy_from_array(codes).union_nullability(result_nullability), -// ) -// .to_canonical()?, -// Mask::AllFalse(_) => ConstantArray::new( -// Scalar::null(DType::Bool(Nullability::Nullable)), -// codes.len(), -// ) -// .to_canonical()?, -// Mask::Values(_) => BoolArray::new( -// BitBuffer::new_unset(codes.len()), -// Validity::from_mask(result_validity, result_nullability).take(codes)?, -// ) -// .to_canonical()?, -// }, -// // We found a single matching value so we can compare the codes directly. -// (Some(code), None) => match result_validity { -// Mask::AllTrue(_) => cast( -// &compare( -// codes, -// &cast( -// ConstantArray::new(code, codes.len()).as_ref(), -// codes.dtype(), -// )?, -// Operator::Eq, -// )?, -// &DType::Bool(result_nullability), -// )? -// .to_canonical()?, -// Mask::AllFalse(_) => ConstantArray::new( -// Scalar::null(DType::Bool(Nullability::Nullable)), -// codes.len(), -// ) -// .to_canonical()?, -// Mask::Values(rv) => mask( -// &compare( -// codes, -// &cast( -// ConstantArray::new(code, codes.len()).as_ref(), -// codes.dtype(), -// )?, -// Operator::Eq, -// )?, -// &Mask::from_buffer( -// take(BoolArray::from(rv.bit_buffer().clone()).as_ref(), codes)? -// .to_bool() -// .bit_buffer() -// .not(), -// ), -// )? -// .to_canonical()?, -// }, -// // More than one value matches. -// _ => take(bool_values.as_ref(), codes) -// .vortex_expect("taking codes from dictionary values shouldn't fail") -// .to_canonical()?, -// }) -// } - // TODO(joe): use dict_bool_take -fn take_bool(array: &BoolArray, codes: &PrimitiveArray) -> VortexResult { - Ok(take(array.as_ref(), codes.as_ref())? - .as_::() - .clone()) +fn take_bool( + array: &BoolArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VortexResult { + Ok( + ::take(array, codes.as_ref(), ctx)? + .vortex_expect("take bool should not return None") + .as_::() + .clone(), + ) } -fn take_primitive(array: &PrimitiveArray, codes: &PrimitiveArray) -> PrimitiveArray { - take(array.as_ref(), codes.as_ref()) +fn take_primitive( + array: &PrimitiveArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> PrimitiveArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take primitive array") + .vortex_expect("take primitive should not return None") .as_::() .clone() } -fn take_decimal(array: &DecimalArray, codes: &PrimitiveArray) -> DecimalArray { - take(array.as_ref(), codes.as_ref()) +fn take_decimal( + array: &DecimalArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> DecimalArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take decimal array") + .vortex_expect("take decimal should not return None") .as_::() .clone() } -fn take_varbinview(array: &VarBinViewArray, codes: &PrimitiveArray) -> VarBinViewArray { - take(array.as_ref(), codes.as_ref()) +fn take_varbinview( + array: &VarBinViewArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VarBinViewArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take varbinview array") + .vortex_expect("take varbinview should not return None") .as_::() .clone() } -fn take_listview(array: &ListViewArray, codes: &PrimitiveArray) -> ListViewArray { - take(array.as_ref(), codes.as_ref()) +fn take_listview( + array: &ListViewArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> ListViewArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take listview array") + .vortex_expect("take listview should not return None") .as_::() .clone() } -fn take_fixed_size_list(array: &FixedSizeListArray, codes: &PrimitiveArray) -> FixedSizeListArray { - take(array.as_ref(), codes.as_ref()) +fn take_fixed_size_list( + array: &FixedSizeListArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> FixedSizeListArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take fixed size list array") + .vortex_expect("take fixed size list should not return None") .as_::() .clone() } -fn take_struct(array: &StructArray, codes: &PrimitiveArray) -> StructArray { - take(array.as_ref(), codes.as_ref()) +fn take_struct(array: &StructArray, codes: &PrimitiveArray, ctx: &mut ExecutionCtx) -> StructArray { + ::take(array, codes.as_ref(), ctx) .vortex_expect("take struct array") + .vortex_expect("take struct should not return None") .as_::() .clone() } -fn take_extension(array: &ExtensionArray, codes: &PrimitiveArray) -> ExtensionArray { - let taken_storage = - take(array.storage(), codes.as_ref()).vortex_expect("take extension storage"); - ExtensionArray::new(array.ext_dtype().clone(), taken_storage) +fn take_extension( + array: &ExtensionArray, + codes: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> ExtensionArray { + ::take(array, codes.as_ref(), ctx) + .vortex_expect("take extension storage") + .vortex_expect("take extension should not return None") + .as_::() + .clone() } diff --git a/vortex-array/src/arrays/dict/vtable/mod.rs b/vortex-array/src/arrays/dict/vtable/mod.rs index 2e656bc6960..c7f32954d09 100644 --- a/vortex-array/src/arrays/dict/vtable/mod.rs +++ b/vortex-array/src/arrays/dict/vtable/mod.rs @@ -142,7 +142,7 @@ impl VTable for DictVTable { // TODO(ngates): if indices min is quite high, we could slice self and offset the indices // such that canonicalize does less work. - Ok(take_canonical(values, &codes)?.into_array()) + Ok(take_canonical(values, &codes, ctx)?.into_array()) } fn reduce_parent( diff --git a/vortex-array/src/arrays/varbin/vtable/mod.rs b/vortex-array/src/arrays/varbin/vtable/mod.rs index 776c961b36c..fa92e528f27 100644 --- a/vortex-array/src/arrays/varbin/vtable/mod.rs +++ b/vortex-array/src/arrays/varbin/vtable/mod.rs @@ -10,9 +10,9 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use crate::ArrayRef; -use crate::Canonical; use crate::DeserializeMetadata; use crate::ExecutionCtx; +use crate::IntoArray; use crate::ProstMetadata; use crate::SerializeMetadata; use crate::arrays::varbin::VarBinArray; @@ -21,7 +21,6 @@ use crate::serde::ArrayChildren; use crate::validity::Validity; use crate::vtable; use crate::vtable::ArrayId; -use crate::vtable::NotSupported; use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; @@ -53,7 +52,6 @@ impl VTable for VarBinVTable { type OperationsVTable = Self; type ValidityVTable = ValidityVTableFromValidityHelper; type VisitorVTable = Self; - type ComputeVTable = NotSupported; fn id(_array: &Self::Array) -> ArrayId { Self::ID @@ -146,8 +144,8 @@ impl VTable for VarBinVTable { kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) } - fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - varbin_to_canonical(array, ctx) + fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + Ok(varbin_to_canonical(array, ctx)?.into_array()) } } From 011a0710c4d95ee3c62a6413343babdcd0d380fd Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 12:21:04 +0000 Subject: [PATCH 09/21] wip Signed-off-by: Joe Isaacs --- vortex-array/src/arrays/bool/compute/take.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vortex-array/src/arrays/bool/compute/take.rs b/vortex-array/src/arrays/bool/compute/take.rs index 57c2084f3dc..258e507a910 100644 --- a/vortex-array/src/arrays/bool/compute/take.rs +++ b/vortex-array/src/arrays/bool/compute/take.rs @@ -119,7 +119,7 @@ mod test { BoolArray::from_iter([Some(false), None, Some(false)]).to_bit_buffer() ); - let all_invalid_indices = PrimitiveArray::from_option_iter([None::, None, None]); + let all_invalid_indices = PrimitiveArray::from_option_iter([None::, None, None]); let b = take(reference.as_ref(), all_invalid_indices.as_ref()).unwrap(); assert_arrays_eq!(b, BoolArray::from_iter([None, None, None])); } From a4c9fa453b3dea393dbdfe46e06ec3b976b5add9 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 12:37:48 +0000 Subject: [PATCH 10/21] wip Signed-off-by: Joe Isaacs --- vortex-array/src/arrays/bool/vtable/kernel.rs | 7 +- .../src/arrays/chunked/compute/kernel.rs | 2 + .../src/arrays/constant/compute/rules.rs | 2 + .../src/arrays/constant/compute/take.rs | 64 +++++++++---------- .../src/arrays/decimal/vtable/kernel.rs | 9 +++ vortex-array/src/arrays/decimal/vtable/mod.rs | 10 +++ vortex-array/src/arrays/dict/compute/mod.rs | 14 ++-- vortex-array/src/arrays/dict/vtable/kernel.rs | 9 +++ vortex-array/src/arrays/dict/vtable/mod.rs | 10 +++ .../src/arrays/extension/compute/take.rs | 22 +++---- .../src/arrays/extension/vtable/kernel.rs | 9 +++ .../src/arrays/extension/vtable/mod.rs | 10 +++ .../arrays/fixed_size_list/vtable/kernel.rs | 11 ++++ .../src/arrays/fixed_size_list/vtable/mod.rs | 10 +++ .../src/arrays/list/compute/kernels.rs | 7 +- .../src/arrays/listview/vtable/kernel.rs | 9 +++ .../src/arrays/listview/vtable/mod.rs | 10 +++ .../src/arrays/masked/compute/take.rs | 38 ++++++----- .../src/arrays/masked/vtable/kernel.rs | 9 +++ vortex-array/src/arrays/masked/vtable/mod.rs | 12 ++++ vortex-array/src/arrays/null/compute/rules.rs | 2 + vortex-array/src/arrays/null/compute/take.rs | 28 ++++---- .../src/arrays/primitive/vtable/kernel.rs | 9 +++ .../src/arrays/primitive/vtable/mod.rs | 11 ++++ .../src/arrays/struct_/compute/rules.rs | 1 + .../src/arrays/struct_/vtable/kernel.rs | 9 +++ vortex-array/src/arrays/struct_/vtable/mod.rs | 11 ++++ .../src/arrays/varbin/vtable/kernel.rs | 7 +- .../src/arrays/varbinview/vtable/kernel.rs | 9 +++ .../src/arrays/varbinview/vtable/mod.rs | 11 ++++ 30 files changed, 276 insertions(+), 96 deletions(-) create mode 100644 vortex-array/src/arrays/decimal/vtable/kernel.rs create mode 100644 vortex-array/src/arrays/dict/vtable/kernel.rs create mode 100644 vortex-array/src/arrays/extension/vtable/kernel.rs create mode 100644 vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs create mode 100644 vortex-array/src/arrays/listview/vtable/kernel.rs create mode 100644 vortex-array/src/arrays/masked/vtable/kernel.rs create mode 100644 vortex-array/src/arrays/primitive/vtable/kernel.rs create mode 100644 vortex-array/src/arrays/struct_/vtable/kernel.rs create mode 100644 vortex-array/src/arrays/varbinview/vtable/kernel.rs diff --git a/vortex-array/src/arrays/bool/vtable/kernel.rs b/vortex-array/src/arrays/bool/vtable/kernel.rs index be7d92058de..3afa0dc3755 100644 --- a/vortex-array/src/arrays/bool/vtable/kernel.rs +++ b/vortex-array/src/arrays/bool/vtable/kernel.rs @@ -2,8 +2,11 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use crate::arrays::BoolVTable; +use crate::arrays::TakeExecuteAdaptor; use crate::arrays::filter::FilterExecuteAdaptor; use crate::kernel::ParentKernelSet; -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&FilterExecuteAdaptor(BoolVTable))]); +pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&FilterExecuteAdaptor(BoolVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(BoolVTable)), +]); diff --git a/vortex-array/src/arrays/chunked/compute/kernel.rs b/vortex-array/src/arrays/chunked/compute/kernel.rs index 6db192e4bdf..e933961d816 100644 --- a/vortex-array/src/arrays/chunked/compute/kernel.rs +++ b/vortex-array/src/arrays/chunked/compute/kernel.rs @@ -4,9 +4,11 @@ use crate::arrays::ChunkedVTable; use crate::arrays::FilterExecuteAdaptor; use crate::arrays::SliceExecuteAdaptor; +use crate::arrays::TakeExecuteAdaptor; use crate::kernel::ParentKernelSet; pub(crate) static PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&SliceExecuteAdaptor(ChunkedVTable)), ParentKernelSet::lift(&FilterExecuteAdaptor(ChunkedVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(ChunkedVTable)), ]); diff --git a/vortex-array/src/arrays/constant/compute/rules.rs b/vortex-array/src/arrays/constant/compute/rules.rs index ddc7b349754..76d390be4c1 100644 --- a/vortex-array/src/arrays/constant/compute/rules.rs +++ b/vortex-array/src/arrays/constant/compute/rules.rs @@ -11,6 +11,7 @@ use crate::arrays::FilterArray; use crate::arrays::FilterReduceAdaptor; use crate::arrays::FilterVTable; use crate::arrays::SliceReduceAdaptor; +use crate::arrays::TakeReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; @@ -18,6 +19,7 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::ne ParentRuleSet::lift(&ConstantFilterRule), ParentRuleSet::lift(&FilterReduceAdaptor(ConstantVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ConstantVTable)), + ParentRuleSet::lift(&TakeReduceAdaptor(ConstantVTable)), ]); #[derive(Debug)] diff --git a/vortex-array/src/arrays/constant/compute/take.rs b/vortex-array/src/arrays/constant/compute/take.rs index 352f153ff5e..44a00f76ca4 100644 --- a/vortex-array/src/arrays/constant/compute/take.rs +++ b/vortex-array/src/arrays/constant/compute/take.rs @@ -16,43 +16,39 @@ use crate::arrays::TakeReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; use crate::validity::Validity; -fn take_constant(array: &ConstantArray, indices: &dyn Array) -> VortexResult { - let result = match indices.validity_mask()?.bit_buffer() { - AllOr::All => { - let scalar = Scalar::new( - array - .scalar() - .dtype() - .union_nullability(indices.dtype().nullability()), - array.scalar().value().clone(), - ); - ConstantArray::new(scalar, indices.len()).into_array() - } - AllOr::None => ConstantArray::new( - Scalar::null( - array - .dtype() - .union_nullability(indices.dtype().nullability()), - ), - indices.len(), - ) - .into_array(), - AllOr::Some(v) => { - let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array(); - - if array.scalar().is_null() { - return Ok(arr); +impl TakeReduce for ConstantVTable { + fn take(array: &ConstantArray, indices: &dyn Array) -> VortexResult> { + let result = match indices.validity_mask()?.bit_buffer() { + AllOr::All => { + let scalar = Scalar::new( + array + .scalar() + .dtype() + .union_nullability(indices.dtype().nullability()), + array.scalar().value().clone(), + ); + ConstantArray::new(scalar, indices.len()).into_array() } + AllOr::None => ConstantArray::new( + Scalar::null( + array + .dtype() + .union_nullability(indices.dtype().nullability()), + ), + indices.len(), + ) + .into_array(), + AllOr::Some(v) => { + let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array(); - MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array() - } - }; - Ok(result) -} + if array.scalar().is_null() { + return Ok(Some(arr)); + } -impl TakeReduce for ConstantVTable { - fn take(array: &ConstantArray, indices: &dyn Array) -> VortexResult> { - take_constant(array, indices).map(Some) + MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array() + } + }; + Ok(Some(result)) } } diff --git a/vortex-array/src/arrays/decimal/vtable/kernel.rs b/vortex-array/src/arrays/decimal/vtable/kernel.rs new file mode 100644 index 00000000000..561f4073026 --- /dev/null +++ b/vortex-array/src/arrays/decimal/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::DecimalVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(DecimalVTable))]); diff --git a/vortex-array/src/arrays/decimal/vtable/mod.rs b/vortex-array/src/arrays/decimal/vtable/mod.rs index 02aba7ff9cf..b2fd347cd78 100644 --- a/vortex-array/src/arrays/decimal/vtable/mod.rs +++ b/vortex-array/src/arrays/decimal/vtable/mod.rs @@ -25,6 +25,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -137,6 +138,15 @@ impl VTable for DecimalVTable { ) -> VortexResult> { RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Debug)] diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index 82cedd0576e..94316a730f5 100644 --- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -27,20 +27,18 @@ use crate::arrays::filter::FilterReduce; use crate::compute::take; use crate::kernel::ParentKernelSet; -fn take_dict(array: &DictArray, indices: &dyn Array) -> VortexResult { - let codes = take(array.codes(), indices)?; - // SAFETY: selecting codes doesn't change the invariants of DictArray - // Preserve all_values_referenced since taking codes doesn't affect which values are referenced - Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()).into_array() }) -} - impl TakeExecute for DictVTable { fn take( array: &DictArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_dict(array, indices).map(Some) + let codes = take(array.codes(), indices)?; + // SAFETY: selecting codes doesn't change the invariants of DictArray + // Preserve all_values_referenced since taking codes doesn't affect which values are referenced + Ok(Some(unsafe { + DictArray::new_unchecked(codes, array.values().clone()).into_array() + })) } } diff --git a/vortex-array/src/arrays/dict/vtable/kernel.rs b/vortex-array/src/arrays/dict/vtable/kernel.rs new file mode 100644 index 00000000000..3050f9a5b85 --- /dev/null +++ b/vortex-array/src/arrays/dict/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::DictVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(DictVTable))]); diff --git a/vortex-array/src/arrays/dict/vtable/mod.rs b/vortex-array/src/arrays/dict/vtable/mod.rs index c7f32954d09..2fd80bb4904 100644 --- a/vortex-array/src/arrays/dict/vtable/mod.rs +++ b/vortex-array/src/arrays/dict/vtable/mod.rs @@ -30,6 +30,7 @@ use crate::vtable::ArrayId; use crate::vtable::VTable; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -152,6 +153,15 @@ impl VTable for DictVTable { ) -> VortexResult> { PARENT_RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } /// Check for fast-path execution conditions. diff --git a/vortex-array/src/arrays/extension/compute/take.rs b/vortex-array/src/arrays/extension/compute/take.rs index b6dd7c7e394..04cf64b277c 100644 --- a/vortex-array/src/arrays/extension/compute/take.rs +++ b/vortex-array/src/arrays/extension/compute/take.rs @@ -13,24 +13,22 @@ use crate::arrays::TakeExecute; use crate::arrays::TakeExecuteAdaptor; use crate::kernel::ParentKernelSet; -fn take_extension(array: &ExtensionArray, indices: &dyn Array) -> VortexResult { - let taken_storage = array.storage().take(indices.to_array())?; - Ok(ExtensionArray::new( - array - .ext_dtype() - .with_nullability(taken_storage.dtype().nullability()), - taken_storage, - ) - .into_array()) -} - impl TakeExecute for ExtensionVTable { fn take( array: &ExtensionArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_extension(array, indices).map(Some) + let taken_storage = array.storage().take(indices.to_array())?; + Ok(Some( + ExtensionArray::new( + array + .ext_dtype() + .with_nullability(taken_storage.dtype().nullability()), + taken_storage, + ) + .into_array(), + )) } } diff --git a/vortex-array/src/arrays/extension/vtable/kernel.rs b/vortex-array/src/arrays/extension/vtable/kernel.rs new file mode 100644 index 00000000000..41f4f8a0f9a --- /dev/null +++ b/vortex-array/src/arrays/extension/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::ExtensionVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(ExtensionVTable))]); diff --git a/vortex-array/src/arrays/extension/vtable/mod.rs b/vortex-array/src/arrays/extension/vtable/mod.rs index bcbe02602d8..0d6d2672bf5 100644 --- a/vortex-array/src/arrays/extension/vtable/mod.rs +++ b/vortex-array/src/arrays/extension/vtable/mod.rs @@ -3,6 +3,7 @@ mod array; mod canonical; +mod kernel; mod operations; mod validity; mod visitor; @@ -94,6 +95,15 @@ impl VTable for ExtensionVTable { ) -> VortexResult> { PARENT_RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Debug)] diff --git a/vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs b/vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs new file mode 100644 index 00000000000..e82c90cd4ba --- /dev/null +++ b/vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::FixedSizeListVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor( + FixedSizeListVTable, + ))]); diff --git a/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs b/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs index 8b82ea71c3b..1c3bb11d623 100644 --- a/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs +++ b/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs @@ -21,6 +21,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -56,6 +57,15 @@ impl VTable for FixedSizeListVTable { PARENT_RULES.evaluate(array, parent, child_idx) } + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } + fn metadata(_array: &FixedSizeListArray) -> VortexResult { Ok(EmptyMetadata) } diff --git a/vortex-array/src/arrays/list/compute/kernels.rs b/vortex-array/src/arrays/list/compute/kernels.rs index 38e065c4ebf..6f3ea7c4579 100644 --- a/vortex-array/src/arrays/list/compute/kernels.rs +++ b/vortex-array/src/arrays/list/compute/kernels.rs @@ -3,7 +3,10 @@ use crate::arrays::FilterExecuteAdaptor; use crate::arrays::ListVTable; +use crate::arrays::TakeExecuteAdaptor; use crate::kernel::ParentKernelSet; -pub(crate) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&FilterExecuteAdaptor(ListVTable))]); +pub(crate) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&FilterExecuteAdaptor(ListVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(ListVTable)), +]); diff --git a/vortex-array/src/arrays/listview/vtable/kernel.rs b/vortex-array/src/arrays/listview/vtable/kernel.rs new file mode 100644 index 00000000000..4403eaa46a4 --- /dev/null +++ b/vortex-array/src/arrays/listview/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::ListViewVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(ListViewVTable))]); diff --git a/vortex-array/src/arrays/listview/vtable/mod.rs b/vortex-array/src/arrays/listview/vtable/mod.rs index 916414409cd..6ba12435077 100644 --- a/vortex-array/src/arrays/listview/vtable/mod.rs +++ b/vortex-array/src/arrays/listview/vtable/mod.rs @@ -25,6 +25,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -170,4 +171,13 @@ impl VTable for ListViewVTable { ) -> VortexResult> { PARENT_RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } diff --git a/vortex-array/src/arrays/masked/compute/take.rs b/vortex-array/src/arrays/masked/compute/take.rs index 42a7a41d5d5..8914e307391 100644 --- a/vortex-array/src/arrays/masked/compute/take.rs +++ b/vortex-array/src/arrays/masked/compute/take.rs @@ -17,32 +17,30 @@ use crate::compute::take; use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; -fn take_masked(array: &MaskedArray, indices: &dyn Array) -> VortexResult { - let taken_child = if !indices.all_valid()? { - // This is safe because we'll mask out these positions in the validity - let filled_take = fill_null( - indices, - &Scalar::default_value(indices.dtype().clone().as_nonnullable()), - )?; - take(&array.child, &filled_take)? - } else { - take(&array.child, indices)? - }; - - // Compute the new validity by taking from array's validity and merging with indices validity - let taken_validity = array.validity().take(indices)?; - - // Construct new MaskedArray - Ok(MaskedArray::try_new(taken_child, taken_validity)?.into_array()) -} - impl TakeExecute for MaskedVTable { fn take( array: &MaskedArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_masked(array, indices).map(Some) + let taken_child = if !indices.all_valid()? { + // This is safe because we'll mask out these positions in the validity + let filled_take = fill_null( + indices, + &Scalar::default_value(indices.dtype().clone().as_nonnullable()), + )?; + take(&array.child, &filled_take)? + } else { + take(&array.child, indices)? + }; + + // Compute the new validity by taking from array's validity and merging with indices validity + let taken_validity = array.validity().take(indices)?; + + // Construct new MaskedArray + Ok(Some( + MaskedArray::try_new(taken_child, taken_validity)?.into_array(), + )) } } diff --git a/vortex-array/src/arrays/masked/vtable/kernel.rs b/vortex-array/src/arrays/masked/vtable/kernel.rs new file mode 100644 index 00000000000..869e8601aba --- /dev/null +++ b/vortex-array/src/arrays/masked/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::MaskedVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(MaskedVTable))]); diff --git a/vortex-array/src/arrays/masked/vtable/mod.rs b/vortex-array/src/arrays/masked/vtable/mod.rs index d567a82a234..477aadbf5ec 100644 --- a/vortex-array/src/arrays/masked/vtable/mod.rs +++ b/vortex-array/src/arrays/masked/vtable/mod.rs @@ -6,6 +6,7 @@ mod canonical; mod operations; mod validity; +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -33,6 +34,8 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; use crate::vtable::VisitorVTable; +mod kernel; + vtable!(Masked); #[derive(Debug)] @@ -134,6 +137,15 @@ impl VTable for MaskedVTable { PARENT_RULES.evaluate(array, parent, child_idx) } + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } + fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { vortex_ensure!( children.len() == 1 || children.len() == 2, diff --git a/vortex-array/src/arrays/null/compute/rules.rs b/vortex-array/src/arrays/null/compute/rules.rs index 33a554e90bd..62aed40e66c 100644 --- a/vortex-array/src/arrays/null/compute/rules.rs +++ b/vortex-array/src/arrays/null/compute/rules.rs @@ -4,9 +4,11 @@ use crate::arrays::FilterReduceAdaptor; use crate::arrays::NullVTable; use crate::arrays::SliceReduceAdaptor; +use crate::arrays::TakeReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&FilterReduceAdaptor(NullVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(NullVTable)), + ParentRuleSet::lift(&TakeReduceAdaptor(NullVTable)), ]); diff --git a/vortex-array/src/arrays/null/compute/take.rs b/vortex-array/src/arrays/null/compute/take.rs index fc9f24adab7..be266175656 100644 --- a/vortex-array/src/arrays/null/compute/take.rs +++ b/vortex-array/src/arrays/null/compute/take.rs @@ -15,25 +15,21 @@ use crate::arrays::TakeReduce; use crate::arrays::TakeReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; -#[allow(clippy::cast_possible_truncation)] -fn take_null(array: &NullArray, indices: &dyn Array) -> VortexResult { - let indices = indices.to_primitive(); +impl TakeReduce for NullVTable { + #[allow(clippy::cast_possible_truncation)] + fn take(array: &NullArray, indices: &dyn Array) -> VortexResult> { + let indices = indices.to_primitive(); - // Enforce all indices are valid - match_each_integer_ptype!(indices.ptype(), |T| { - for index in indices.as_slice::() { - if (*index as usize) >= array.len() { - vortex_bail!(OutOfBounds: *index as usize, 0, array.len()); + // Enforce all indices are valid + match_each_integer_ptype!(indices.ptype(), |T| { + for index in indices.as_slice::() { + if (*index as usize) >= array.len() { + vortex_bail!(OutOfBounds: *index as usize, 0, array.len()); + } } - } - }); - - Ok(NullArray::new(indices.len()).into_array()) -} + }); -impl TakeReduce for NullVTable { - fn take(array: &NullArray, indices: &dyn Array) -> VortexResult> { - take_null(array, indices).map(Some) + Ok(Some(NullArray::new(indices.len()).into_array())) } } diff --git a/vortex-array/src/arrays/primitive/vtable/kernel.rs b/vortex-array/src/arrays/primitive/vtable/kernel.rs new file mode 100644 index 00000000000..7882fde292e --- /dev/null +++ b/vortex-array/src/arrays/primitive/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::PrimitiveVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(PrimitiveVTable))]); diff --git a/vortex-array/src/arrays/primitive/vtable/mod.rs b/vortex-array/src/arrays/primitive/vtable/mod.rs index 05b2fa10fbf..283a76848b3 100644 --- a/vortex-array/src/arrays/primitive/vtable/mod.rs +++ b/vortex-array/src/arrays/primitive/vtable/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_dtype::PType; use vortex_error::VortexExpect; @@ -20,6 +21,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -132,6 +134,15 @@ impl VTable for PrimitiveVTable { ) -> VortexResult> { RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Debug)] diff --git a/vortex-array/src/arrays/struct_/compute/rules.rs b/vortex-array/src/arrays/struct_/compute/rules.rs index ff2df148c23..49401f1589a 100644 --- a/vortex-array/src/arrays/struct_/compute/rules.rs +++ b/vortex-array/src/arrays/struct_/compute/rules.rs @@ -9,6 +9,7 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ConstantArray; use crate::arrays::ExactScalarFn; +use crate::arrays::FilterReduceAdaptor; use crate::arrays::ScalarFnArrayExt; use crate::arrays::ScalarFnArrayView; use crate::arrays::SliceReduceAdaptor; diff --git a/vortex-array/src/arrays/struct_/vtable/kernel.rs b/vortex-array/src/arrays/struct_/vtable/kernel.rs new file mode 100644 index 00000000000..74a22f2ded4 --- /dev/null +++ b/vortex-array/src/arrays/struct_/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::StructVTable; +use crate::arrays::TakeExecuteAdaptor; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(StructVTable))]); diff --git a/vortex-array/src/arrays/struct_/vtable/mod.rs b/vortex-array/src/arrays/struct_/vtable/mod.rs index 267d9574440..02525bc2ff7 100644 --- a/vortex-array/src/arrays/struct_/vtable/mod.rs +++ b/vortex-array/src/arrays/struct_/vtable/mod.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use itertools::Itertools; +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -23,6 +24,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -138,6 +140,15 @@ impl VTable for StructVTable { ) -> VortexResult> { PARENT_RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Debug)] diff --git a/vortex-array/src/arrays/varbin/vtable/kernel.rs b/vortex-array/src/arrays/varbin/vtable/kernel.rs index 09dd04b7557..6a94dffb2f8 100644 --- a/vortex-array/src/arrays/varbin/vtable/kernel.rs +++ b/vortex-array/src/arrays/varbin/vtable/kernel.rs @@ -1,9 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use crate::arrays::TakeExecuteAdaptor; use crate::arrays::VarBinVTable; use crate::arrays::filter::FilterExecuteAdaptor; use crate::kernel::ParentKernelSet; -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&FilterExecuteAdaptor(VarBinVTable))]); +pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&FilterExecuteAdaptor(VarBinVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(VarBinVTable)), +]); diff --git a/vortex-array/src/arrays/varbinview/vtable/kernel.rs b/vortex-array/src/arrays/varbinview/vtable/kernel.rs new file mode 100644 index 00000000000..4fc842eec54 --- /dev/null +++ b/vortex-array/src/arrays/varbinview/vtable/kernel.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::arrays::TakeExecuteAdaptor; +use crate::arrays::VarBinViewVTable; +use crate::kernel::ParentKernelSet; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(VarBinViewVTable))]); diff --git a/vortex-array/src/arrays/varbinview/vtable/mod.rs b/vortex-array/src/arrays/varbinview/vtable/mod.rs index 078feb8deb7..061c9f7a37e 100644 --- a/vortex-array/src/arrays/varbinview/vtable/mod.rs +++ b/vortex-array/src/arrays/varbinview/vtable/mod.rs @@ -3,6 +3,7 @@ use std::sync::Arc; +use kernel::PARENT_KERNELS; use vortex_buffer::Buffer; use vortex_buffer::ByteBuffer; use vortex_dtype::DType; @@ -26,6 +27,7 @@ use crate::vtable::VTable; use crate::vtable::ValidityVTableFromValidityHelper; mod array; +mod kernel; mod operations; mod validity; mod visitor; @@ -124,6 +126,15 @@ impl VTable for VarBinViewVTable { PARENT_RULES.evaluate(array, parent, child_idx) } + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } + fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { Ok(array.to_array()) } From 44ac66f50d5863849d3079c31578d79ca6460bd6 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 12:38:47 +0000 Subject: [PATCH 11/21] wip Signed-off-by: Joe Isaacs --- encodings/alp/src/alp/compute/take.rs | 36 +++++++++++++-------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/encodings/alp/src/alp/compute/take.rs b/encodings/alp/src/alp/compute/take.rs index daeb613871f..44b24641266 100644 --- a/encodings/alp/src/alp/compute/take.rs +++ b/encodings/alp/src/alp/compute/take.rs @@ -13,31 +13,29 @@ use vortex_error::VortexResult; use crate::ALPArray; use crate::ALPVTable; -fn take_alp(array: &ALPArray, indices: &dyn Array) -> VortexResult { - let taken_encoded = array.encoded().take(indices.to_array())?; - let taken_patches = array - .patches() - .map(|p| p.take(indices)) - .transpose()? - .flatten() - .map(|patches| { - patches.cast_values( - &array - .dtype() - .with_nullability(taken_encoded.dtype().nullability()), - ) - }) - .transpose()?; - Ok(ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array()) -} - impl TakeExecute for ALPVTable { fn take( array: &ALPArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_alp(array, indices).map(Some) + let taken_encoded = array.encoded().take(indices.to_array())?; + let taken_patches = array + .patches() + .map(|p| p.take(indices)) + .transpose()? + .flatten() + .map(|patches| { + patches.cast_values( + &array + .dtype() + .with_nullability(taken_encoded.dtype().nullability()), + ) + }) + .transpose()?; + Ok(Some( + ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array(), + )) } } From 66db51297e4087cf0ef2199421e68ed3a5b24bc3 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 14:09:33 +0000 Subject: [PATCH 12/21] wip Signed-off-by: Joe Isaacs --- encodings/alp/src/alp/array.rs | 1 - encodings/alp/src/alp_rd/compute/take.rs | 66 ++++--- encodings/bytebool/src/compute.rs | 42 ++--- .../src/decimal_byte_parts/compute/take.rs | 11 +- .../fastlanes/src/bitpacking/compute/take.rs | 40 ++-- encodings/fastlanes/src/for/compute/mod.rs | 16 +- encodings/fsst/src/compute/mod.rs | 48 +++-- encodings/runend/src/compute/filter.rs | 24 --- encodings/runend/src/compute/take.rs | 48 +++-- encodings/runend/src/compute/take_from.rs | 3 - encodings/sequence/src/compute/take.rs | 42 ++--- encodings/sparse/src/compute/take.rs | 66 +++---- encodings/zigzag/src/compute/mod.rs | 8 +- vortex-array/src/arrays/bool/compute/take.rs | 39 ++-- .../src/arrays/decimal/compute/take.rs | 36 ++-- .../arrays/fixed_size_list/compute/take.rs | 11 +- vortex-array/src/arrays/list/compute/take.rs | 38 ++-- .../src/arrays/listview/compute/take.rs | 94 +++++----- .../src/arrays/primitive/compute/take/mod.rs | 34 ++-- .../src/arrays/struct_/compute/rules.rs | 1 - .../src/arrays/struct_/compute/take.rs | 60 +++--- .../src/arrays/varbin/compute/take.rs | 172 +++++++++--------- .../src/arrays/varbinview/compute/take.rs | 48 +++-- 23 files changed, 431 insertions(+), 517 deletions(-) diff --git a/encodings/alp/src/alp/array.rs b/encodings/alp/src/alp/array.rs index 6712ba93e01..157ea88abbb 100644 --- a/encodings/alp/src/alp/array.rs +++ b/encodings/alp/src/alp/array.rs @@ -10,7 +10,6 @@ use vortex_array::ArrayChildVisitor; use vortex_array::ArrayEq; use vortex_array::ArrayHash; use vortex_array::ArrayRef; -use vortex_array::Canonical; use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index 7a4acb7dcc3..10f236de877 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -16,46 +16,44 @@ use vortex_scalar::ScalarValue; use crate::ALPRDArray; use crate::ALPRDVTable; -fn take_alprd(array: &ALPRDArray, indices: &dyn Array) -> VortexResult { - let taken_left_parts = array.left_parts().take(indices.to_array())?; - let left_parts_exceptions = array - .left_parts_patches() - .map(|patches| patches.take(indices)) - .transpose()? - .flatten() - .map(|p| { - let values_dtype = p - .values() - .dtype() - .with_nullability(taken_left_parts.dtype().nullability()); - p.cast_values(&values_dtype) - }) - .transpose()?; - let right_parts = fill_null( - &array.right_parts().take(indices.to_array())?, - &Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)), - )?; - - Ok(ALPRDArray::try_new( - array - .dtype() - .with_nullability(taken_left_parts.dtype().nullability()), - taken_left_parts, - array.left_parts_dictionary().clone(), - right_parts, - array.right_bit_width(), - left_parts_exceptions, - )? - .into_array()) -} - impl TakeExecute for ALPRDVTable { fn take( array: &ALPRDArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_alprd(array, indices).map(Some) + let taken_left_parts = array.left_parts().take(indices.to_array())?; + let left_parts_exceptions = array + .left_parts_patches() + .map(|patches| patches.take(indices)) + .transpose()? + .flatten() + .map(|p| { + let values_dtype = p + .values() + .dtype() + .with_nullability(taken_left_parts.dtype().nullability()); + p.cast_values(&values_dtype) + }) + .transpose()?; + let right_parts = fill_null( + &array.right_parts().take(indices.to_array())?, + &Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)), + )?; + + Ok(Some( + ALPRDArray::try_new( + array + .dtype() + .with_nullability(taken_left_parts.dtype().nullability()), + taken_left_parts, + array.left_parts_dictionary().clone(), + right_parts, + array.right_bit_width(), + left_parts_exceptions, + )? + .into_array(), + )) } } diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index f3ebb91f168..3fc5428dc40 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -57,34 +57,32 @@ impl MaskKernel for ByteBoolVTable { register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift()); -fn take_bytebool(array: &ByteBoolArray, indices: &dyn Array) -> VortexResult { - let indices = indices.to_primitive(); - let bools = array.as_slice(); - - // This handles combining validity from both source array and nullable indices - let validity = array.validity().take(indices.as_ref())?; - - let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| { - indices - .as_slice::() - .iter() - .map(|&idx| { - let idx: usize = idx.as_(); - bools[idx] - }) - .collect::>() - }); - - Ok(ByteBoolArray::from_vec(taken_bools, validity).into_array()) -} - impl TakeExecute for ByteBoolVTable { fn take( array: &ByteBoolArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_bytebool(array, indices).map(Some) + let indices = indices.to_primitive(); + let bools = array.as_slice(); + + // This handles combining validity from both source array and nullable indices + let validity = array.validity().take(indices.as_ref())?; + + let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| { + indices + .as_slice::() + .iter() + .map(|&idx| { + let idx: usize = idx.as_(); + bools[idx] + }) + .collect::>() + }); + + Ok(Some( + ByteBoolArray::from_vec(taken_bools, validity).into_array(), + )) } } diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs index 3f202972b7b..88d1bbfb2e9 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs @@ -12,21 +12,14 @@ use vortex_error::VortexResult; use crate::DecimalBytePartsArray; use crate::DecimalBytePartsVTable; -fn take_decimal_byte_parts( - array: &DecimalBytePartsArray, - indices: &dyn Array, -) -> VortexResult { - DecimalBytePartsArray::try_new(array.msp.take(indices.to_array())?, *array.decimal_dtype()) - .map(|a| a.to_array()) -} - impl TakeExecute for DecimalBytePartsVTable { fn take( array: &DecimalBytePartsArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_decimal_byte_parts(array, indices).map(Some) + DecimalBytePartsArray::try_new(array.msp.take(indices.to_array())?, *array.decimal_dtype()) + .map(|a| Some(a.to_array())) } } diff --git a/encodings/fastlanes/src/bitpacking/compute/take.rs b/encodings/fastlanes/src/bitpacking/compute/take.rs index 4794286ab5f..84cbf8536af 100644 --- a/encodings/fastlanes/src/bitpacking/compute/take.rs +++ b/encodings/fastlanes/src/bitpacking/compute/take.rs @@ -37,34 +37,30 @@ use crate::bitpack_decompress; /// see https://github.com/vortex-data/vortex/pull/190#issue-2223752833 pub(super) const UNPACK_CHUNK_THRESHOLD: usize = 8; -fn take_bitpacked(array: &BitPackedArray, indices: &dyn Array) -> VortexResult { - // If the indices are large enough, it's faster to flatten and take the primitive array. - if indices.len() * UNPACK_CHUNK_THRESHOLD > array.len() { - return array.to_primitive().take(indices.to_array()); - } - - // NOTE: we use the unsigned PType because all values in the BitPackedArray must - // be non-negative (pre-condition of creating the BitPackedArray). - let ptype: PType = PType::try_from(array.dtype())?; - let validity = array.validity(); - let taken_validity = validity.take(indices)?; - - let indices = indices.to_primitive(); - let taken = match_each_unsigned_integer_ptype!(ptype.to_unsigned(), |T| { - match_each_integer_ptype!(indices.ptype(), |I| { - take_primitive::(array, &indices, taken_validity)? - }) - }); - Ok(taken.reinterpret_cast(ptype).into_array()) -} - impl TakeExecute for BitPackedVTable { fn take( array: &BitPackedArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_bitpacked(array, indices).map(Some) + // If the indices are large enough, it's faster to flatten and take the primitive array. + if indices.len() * UNPACK_CHUNK_THRESHOLD > array.len() { + return array.to_primitive().take(indices.to_array()).map(Some); + } + + // NOTE: we use the unsigned PType because all values in the BitPackedArray must + // be non-negative (pre-condition of creating the BitPackedArray). + let ptype: PType = PType::try_from(array.dtype())?; + let validity = array.validity(); + let taken_validity = validity.take(indices)?; + + let indices = indices.to_primitive(); + let taken = match_each_unsigned_integer_ptype!(ptype.to_unsigned(), |T| { + match_each_integer_ptype!(indices.ptype(), |I| { + take_primitive::(array, &indices, taken_validity)? + }) + }); + Ok(Some(taken.reinterpret_cast(ptype).into_array())) } } diff --git a/encodings/fastlanes/src/for/compute/mod.rs b/encodings/fastlanes/src/for/compute/mod.rs index 9ac8e97e761..35ed1d715f2 100644 --- a/encodings/fastlanes/src/for/compute/mod.rs +++ b/encodings/fastlanes/src/for/compute/mod.rs @@ -20,21 +20,19 @@ use vortex_mask::Mask; use crate::FoRArray; use crate::FoRVTable; -fn take_for(array: &FoRArray, indices: &dyn Array) -> VortexResult { - FoRArray::try_new( - array.encoded().take(indices.to_array())?, - array.reference_scalar().clone(), - ) - .map(|a| a.into_array()) -} - impl TakeExecute for FoRVTable { fn take( array: &FoRArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_for(array, indices).map(Some) + Ok(Some( + FoRArray::try_new( + array.encoded().take(indices.to_array())?, + array.reference_scalar().clone(), + )? + .into_array(), + )) } } diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index fb07d7e197e..1e11609e86c 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -21,37 +21,35 @@ use vortex_scalar::ScalarValue; use crate::FSSTArray; use crate::FSSTVTable; -fn take_fsst(array: &FSSTArray, indices: &dyn Array) -> VortexResult { - Ok(FSSTArray::try_new( - array - .dtype() - .clone() - .union_nullability(indices.dtype().nullability()), - array.symbols().clone(), - array.symbol_lengths().clone(), - array - .codes() - .take(indices.to_array())? - .as_::() - .clone(), - fill_null( - &array.uncompressed_lengths().take(indices.to_array())?, - &Scalar::new( - array.uncompressed_lengths_dtype().clone(), - ScalarValue::from(0), - ), - )?, - )? - .into_array()) -} - impl TakeExecute for FSSTVTable { fn take( array: &FSSTArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_fsst(array, indices).map(Some) + Ok(Some( + FSSTArray::try_new( + array + .dtype() + .clone() + .union_nullability(indices.dtype().nullability()), + array.symbols().clone(), + array.symbol_lengths().clone(), + array + .codes() + .take(indices.to_array())? + .as_::() + .clone(), + fill_null( + &array.uncompressed_lengths().take(indices.to_array())?, + &Scalar::new( + array.uncompressed_lengths_dtype().clone(), + ScalarValue::from(0), + ), + )?, + )? + .into_array(), + )) } } diff --git a/encodings/runend/src/compute/filter.rs b/encodings/runend/src/compute/filter.rs index e22232ee867..7b43f399558 100644 --- a/encodings/runend/src/compute/filter.rs +++ b/encodings/runend/src/compute/filter.rs @@ -73,30 +73,6 @@ impl FilterKernel for RunEndVTable { } } -// We expose this function to our benchmarks. -pub fn filter_run_end(array: &RunEndArray, mask: &Mask) -> VortexResult { - let primitive_run_ends = array.ends().to_primitive(); - let (run_ends, values_mask) = - match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |P| { - filter_run_end_primitive( - primitive_run_ends.as_slice::

(), - array.offset() as u64, - array.len() as u64, - mask.values() - .vortex_expect("AllTrue and AllFalse handled by filter fn") - .bit_buffer(), - )? - }); - let values = array.values().filter(values_mask)?; - - // SAFETY: enforced by filter_run_end_primitive - unsafe { - Ok( - RunEndArray::new_unchecked(run_ends.into_array(), values, 0, mask.true_count()) - .into_array(), - ) - } -} // Code adapted from apache arrow-rs https://github.com/apache/arrow-rs/blob/b1f5c250ebb6c1252b4e7c51d15b8e77f4c361fa/arrow-select/src/filter.rs#L425 fn filter_run_end_primitive + AsPrimitive>( diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index 9bb34b776e4..7a44270d971 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -24,38 +24,34 @@ use vortex_error::vortex_bail; use crate::RunEndArray; use crate::RunEndVTable; -#[expect( - clippy::cast_possible_truncation, - reason = "index cast to usize inside macro" -)] -fn take_runend(array: &RunEndArray, indices: &dyn Array) -> VortexResult { - let primitive_indices = indices.to_primitive(); - - let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| { - primitive_indices - .as_slice::

() - .iter() - .copied() - .map(|idx| { - let usize_idx = idx as usize; - if usize_idx >= array.len() { - vortex_bail!(OutOfBounds: usize_idx, 0, array.len()); - } - Ok(usize_idx) - }) - .collect::>>()? - }); - - take_indices_unchecked(array, &checked_indices, primitive_indices.validity()) -} - impl TakeExecute for RunEndVTable { + #[expect( + clippy::cast_possible_truncation, + reason = "index cast to usize inside macro" + )] fn take( array: &RunEndArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_runend(array, indices).map(Some) + let primitive_indices = indices.to_primitive(); + + let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| { + primitive_indices + .as_slice::

() + .iter() + .copied() + .map(|idx| { + let usize_idx = idx as usize; + if usize_idx >= array.len() { + vortex_bail!(OutOfBounds: usize_idx, 0, array.len()); + } + Ok(usize_idx) + }) + .collect::>>()? + }); + + take_indices_unchecked(array, &checked_indices, primitive_indices.validity()).map(Some) } } diff --git a/encodings/runend/src/compute/take_from.rs b/encodings/runend/src/compute/take_from.rs index bcabac5db5a..96dd35a34f5 100644 --- a/encodings/runend/src/compute/take_from.rs +++ b/encodings/runend/src/compute/take_from.rs @@ -3,14 +3,11 @@ use std::fmt::Debug; -use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; use vortex_array::arrays::DictArray; use vortex_array::arrays::DictVTable; use vortex_array::kernel::ExecuteParentKernel; -use vortex_array::matcher::Matcher; use vortex_dtype::DType; use vortex_error::VortexResult; diff --git a/encodings/sequence/src/compute/take.rs b/encodings/sequence/src/compute/take.rs index dfe7cecd282..d6e2740ae06 100644 --- a/encodings/sequence/src/compute/take.rs +++ b/encodings/sequence/src/compute/take.rs @@ -30,28 +30,6 @@ use vortex_scalar::Scalar; use crate::SequenceArray; use crate::SequenceVTable; -fn take_sequence(array: &SequenceArray, indices: &dyn Array) -> VortexResult { - let mask = indices.validity_mask()?; - let indices = indices.to_primitive(); - let result_nullability = array.dtype().nullability() | indices.dtype().nullability(); - - match_each_integer_ptype!(indices.ptype(), |T| { - let indices = indices.as_slice::(); - match_each_native_ptype!(array.ptype(), |S| { - let mul = array.multiplier().cast::(); - let base = array.base().cast::(); - Ok(take_inner( - mul, - base, - indices, - mask, - result_nullability, - array.len(), - )) - }) - }) -} - fn take_inner( mul: S, base: S, @@ -103,7 +81,25 @@ impl TakeExecute for SequenceVTable { indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_sequence(array, indices).map(Some) + let mask = indices.validity_mask()?; + let indices = indices.to_primitive(); + let result_nullability = array.dtype().nullability() | indices.dtype().nullability(); + + match_each_integer_ptype!(indices.ptype(), |T| { + let indices = indices.as_slice::(); + match_each_native_ptype!(array.ptype(), |S| { + let mul = array.multiplier().cast::(); + let base = array.base().cast::(); + Ok(Some(take_inner( + mul, + base, + indices, + mask, + result_nullability, + array.len(), + ))) + }) + }) } } diff --git a/encodings/sparse/src/compute/take.rs b/encodings/sparse/src/compute/take.rs index a84226629d9..06c494d6747 100644 --- a/encodings/sparse/src/compute/take.rs +++ b/encodings/sparse/src/compute/take.rs @@ -14,45 +14,45 @@ use vortex_error::VortexResult; use crate::SparseArray; use crate::SparseVTable; -fn take_sparse(array: &SparseArray, take_indices: &dyn Array) -> VortexResult { - let patches_take = if array.fill_scalar().is_null() { - array.patches().take(take_indices)? - } else { - array.patches().take_with_nulls(take_indices)? - }; - - let Some(new_patches) = patches_take else { - let result_fill_scalar = array.fill_scalar().cast( - &array - .dtype() - .union_nullability(take_indices.dtype().nullability()), - )?; - return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array()); - }; - - // See `SparseEncoding::slice`. - if new_patches.array_len() == new_patches.values().len() { - return Ok(new_patches.into_values()); - } - - Ok(SparseArray::try_new_from_patches( - new_patches, - array.fill_scalar().cast( - &array - .dtype() - .union_nullability(take_indices.dtype().nullability()), - )?, - )? - .into_array()) -} - impl TakeExecute for SparseVTable { fn take( array: &SparseArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_sparse(array, indices).map(Some) + let patches_take = if array.fill_scalar().is_null() { + array.patches().take(indices)? + } else { + array.patches().take_with_nulls(indices)? + }; + + let Some(new_patches) = patches_take else { + let result_fill_scalar = array.fill_scalar().cast( + &array + .dtype() + .union_nullability(indices.dtype().nullability()), + )?; + return Ok(Some( + ConstantArray::new(result_fill_scalar, indices.len()).into_array(), + )); + }; + + // See `SparseEncoding::slice`. + if new_patches.array_len() == new_patches.values().len() { + return Ok(Some(new_patches.into_values())); + } + + Ok(Some( + SparseArray::try_new_from_patches( + new_patches, + array.fill_scalar().cast( + &array + .dtype() + .union_nullability(indices.dtype().nullability()), + )?, + )? + .into_array(), + )) } } diff --git a/encodings/zigzag/src/compute/mod.rs b/encodings/zigzag/src/compute/mod.rs index 14f7b1aa978..88d9419d8df 100644 --- a/encodings/zigzag/src/compute/mod.rs +++ b/encodings/zigzag/src/compute/mod.rs @@ -28,18 +28,14 @@ impl FilterReduce for ZigZagVTable { } } -fn take_zigzag(array: &ZigZagArray, indices: &dyn Array) -> VortexResult { - let encoded = array.encoded().take(indices.to_array())?; - Ok(ZigZagArray::try_new(encoded)?.into_array()) -} - impl TakeExecute for ZigZagVTable { fn take( array: &ZigZagArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_zigzag(array, indices).map(Some) + let encoded = array.encoded().take(indices.to_array())?; + Ok(Some(ZigZagArray::try_new(encoded)?.into_array())) } } diff --git a/vortex-array/src/arrays/bool/compute/take.rs b/vortex-array/src/arrays/bool/compute/take.rs index 258e507a910..d7dfea29807 100644 --- a/vortex-array/src/arrays/bool/compute/take.rs +++ b/vortex-array/src/arrays/bool/compute/take.rs @@ -24,33 +24,30 @@ use crate::executor::ExecutionCtx; use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; -fn take_bool(array: &BoolArray, indices: &dyn Array) -> VortexResult { - let indices_nulls_zeroed = match indices.validity_mask()? { - Mask::AllTrue(_) => indices.to_array(), - Mask::AllFalse(_) => { - return Ok(ConstantArray::new( - Scalar::null(array.dtype().as_nullable()), - indices.len(), - ) - .into_array()); - } - Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?, - }; - let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive(); - let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| { - take_valid_indices(&array.to_bit_buffer(), indices_nulls_zeroed.as_slice::()) - }); - - Ok(BoolArray::new(buffer, array.validity().take(indices)?).to_array()) -} - impl TakeExecute for BoolVTable { fn take( array: &BoolArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_bool(array, indices).map(Some) + let indices_nulls_zeroed = match indices.validity_mask()? { + Mask::AllTrue(_) => indices.to_array(), + Mask::AllFalse(_) => { + return Ok(Some( + ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len()) + .into_array(), + )); + } + Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?, + }; + let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive(); + let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| { + take_valid_indices(&array.to_bit_buffer(), indices_nulls_zeroed.as_slice::()) + }); + + Ok(Some( + BoolArray::new(buffer, array.validity().take(indices)?).to_array(), + )) } } diff --git a/vortex-array/src/arrays/decimal/compute/take.rs b/vortex-array/src/arrays/decimal/compute/take.rs index 18298946cb0..c931ee03512 100644 --- a/vortex-array/src/arrays/decimal/compute/take.rs +++ b/vortex-array/src/arrays/decimal/compute/take.rs @@ -19,32 +19,28 @@ use crate::executor::ExecutionCtx; use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; -fn take_decimal(array: &DecimalArray, indices: &dyn Array) -> VortexResult { - let indices = indices.to_primitive(); - let validity = array.validity().take(indices.as_ref())?; - - // TODO(joe): if the true count of take indices validity is low, only take array values with - // valid indices. - let decimal = match_each_decimal_value_type!(array.values_type(), |D| { - match_each_integer_ptype!(indices.ptype(), |I| { - let buffer = - take_to_buffer::(indices.as_slice::(), array.buffer::().as_slice()); - // SAFETY: Take operation preserves decimal dtype and creates valid buffer. - // Validity is computed correctly from the parent array and indices. - unsafe { DecimalArray::new_unchecked(buffer, array.decimal_dtype(), validity) } - }) - }); - - Ok(decimal.to_array()) -} - impl TakeExecute for DecimalVTable { fn take( array: &DecimalArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_decimal(array, indices).map(Some) + let indices = indices.to_primitive(); + let validity = array.validity().take(indices.as_ref())?; + + // TODO(joe): if the true count of take indices validity is low, only take array values with + // valid indices. + let decimal = match_each_decimal_value_type!(array.values_type(), |D| { + match_each_integer_ptype!(indices.ptype(), |I| { + let buffer = + take_to_buffer::(indices.as_slice::(), array.buffer::().as_slice()); + // SAFETY: Take operation preserves decimal dtype and creates valid buffer. + // Validity is computed correctly from the parent array and indices. + unsafe { DecimalArray::new_unchecked(buffer, array.decimal_dtype(), validity) } + }) + }); + + Ok(Some(decimal.to_array())) } } diff --git a/vortex-array/src/arrays/fixed_size_list/compute/take.rs b/vortex-array/src/arrays/fixed_size_list/compute/take.rs index 3288e1bbfae..784e9d19ed3 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/take.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/take.rs @@ -28,19 +28,16 @@ use crate::vtable::ValidityHelper; /// Unlike `ListView`, `FixedSizeListArray` must rebuild the elements array because it requires /// that elements start at offset 0 and be perfectly packed without gaps. We expand list indices /// into element indices and push them down to the child elements array. -fn take_fixed_size_list(array: &FixedSizeListArray, indices: &dyn Array) -> VortexResult { - match_each_integer_ptype!(indices.dtype().as_ptype(), |I| { - take_with_indices::(array, indices) - }) -} - impl TakeExecute for FixedSizeListVTable { fn take( array: &FixedSizeListArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_fixed_size_list(array, indices).map(Some) + match_each_integer_ptype!(indices.dtype().as_ptype(), |I| { + take_with_indices::(array, indices) + }) + .map(Some) } } diff --git a/vortex-array/src/arrays/list/compute/take.rs b/vortex-array/src/arrays/list/compute/take.rs index 19d64cd5c5d..9011b4dc8c6 100644 --- a/vortex-array/src/arrays/list/compute/take.rs +++ b/vortex-array/src/arrays/list/compute/take.rs @@ -26,33 +26,29 @@ use crate::vtable::ValidityHelper; // TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call // the `ListView::take` compute function once `ListView` is more stable. -/// Take implementation for [`ListArray`]. -/// -/// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant -/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking -/// non-contiguous indices would violate this requirement. -#[expect(clippy::cognitive_complexity)] -fn take_list(array: &ListArray, indices: &dyn Array) -> VortexResult { - let indices = indices.to_primitive(); - // This is an over-approximation of the total number of elements in the resulting array. - let total_approx = array.elements().len().saturating_mul(indices.len()); - - match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| { - match_each_integer_ptype!(indices.ptype(), |I| { - match_smallest_offset_type!(total_approx, |OutputOffsetType| { - _take::(array, &indices) - }) - }) - }) -} - impl TakeExecute for ListVTable { + /// Take implementation for [`ListArray`]. + /// + /// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant + /// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking + /// non-contiguous indices would violate this requirement. + #[expect(clippy::cognitive_complexity)] fn take( array: &ListArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_list(array, indices).map(Some) + let indices = indices.to_primitive(); + // This is an over-approximation of the total number of elements in the resulting array. + let total_approx = array.elements().len().saturating_mul(indices.len()); + + match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| { + match_each_integer_ptype!(indices.ptype(), |I| { + match_smallest_offset_type!(total_approx, |OutputOffsetType| { + _take::(array, &indices).map(Some) + }) + }) + }) } } diff --git a/vortex-array/src/arrays/listview/compute/take.rs b/vortex-array/src/arrays/listview/compute/take.rs index 090bb15536a..72f278789e2 100644 --- a/vortex-array/src/arrays/listview/compute/take.rs +++ b/vortex-array/src/arrays/listview/compute/take.rs @@ -44,60 +44,58 @@ const REBUILD_DENSITY_THRESHOLD: f64 = 0.1; /// /// The trade-off is that we may keep unreferenced elements in memory, but this is acceptable since /// we're optimizing for read performance and the data isn't being copied. -fn take_listview(array: &ListViewArray, indices: &dyn Array) -> VortexResult { - let elements = array.elements(); - let offsets = array.offsets(); - let sizes = array.sizes(); - - // Compute the new validity by combining the array's validity with the indices' validity. - let new_validity = array.validity().take(indices)?; - - // Take the offsets and sizes arrays at the requested indices. - // Take can reorder offsets, create gaps, and may introduce overlaps if the `indices` - // contain duplicates. - let nullable_new_offsets = offsets.take(indices.to_array())?; - let nullable_new_sizes = sizes.take(indices.to_array())?; - - // Since `take` returns nullable arrays, we simply cast it back to non-nullable (filled with - // zeros to represent null lists). - let new_offsets = match_each_integer_ptype!(nullable_new_offsets.dtype().as_ptype(), |O| { - compute::fill_null( - &nullable_new_offsets, - &Scalar::primitive(O::zero(), Nullability::NonNullable), - )? - }); - let new_sizes = match_each_integer_ptype!(nullable_new_sizes.dtype().as_ptype(), |S| { - compute::fill_null( - &nullable_new_sizes, - &Scalar::primitive(S::zero(), Nullability::NonNullable), - )? - }); - // SAFETY: Take operation maintains all `ListViewArray` invariants: - // - `new_offsets` and `new_sizes` are derived from existing valid child arrays. - // - `new_offsets` and `new_sizes` are non-nullable. - // - `new_offsets` and `new_sizes` have the same length (both taken with the same - // `indices`). - // - Validity correctly reflects the combination of array and indices validity. - let new_array = unsafe { - ListViewArray::new_unchecked(elements.clone(), new_offsets, new_sizes, new_validity) - }; - - // TODO(connor)[ListView]: Ideally, we would only rebuild after all `take`s and `filter` - // compute functions have run, at the "top" of the operator tree. However, we cannot do this - // right now, so we will just rebuild every time (similar to `ListArray`). - - Ok(new_array - .rebuild(ListViewRebuildMode::MakeZeroCopyToList)? - .into_array()) -} - impl TakeExecute for ListViewVTable { fn take( array: &ListViewArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_listview(array, indices).map(Some) + let elements = array.elements(); + let offsets = array.offsets(); + let sizes = array.sizes(); + + // Compute the new validity by combining the array's validity with the indices' validity. + let new_validity = array.validity().take(indices)?; + + // Take the offsets and sizes arrays at the requested indices. + // Take can reorder offsets, create gaps, and may introduce overlaps if the `indices` + // contain duplicates. + let nullable_new_offsets = offsets.take(indices.to_array())?; + let nullable_new_sizes = sizes.take(indices.to_array())?; + + // Since `take` returns nullable arrays, we simply cast it back to non-nullable (filled with + // zeros to represent null lists). + let new_offsets = match_each_integer_ptype!(nullable_new_offsets.dtype().as_ptype(), |O| { + compute::fill_null( + &nullable_new_offsets, + &Scalar::primitive(O::zero(), Nullability::NonNullable), + )? + }); + let new_sizes = match_each_integer_ptype!(nullable_new_sizes.dtype().as_ptype(), |S| { + compute::fill_null( + &nullable_new_sizes, + &Scalar::primitive(S::zero(), Nullability::NonNullable), + )? + }); + // SAFETY: Take operation maintains all `ListViewArray` invariants: + // - `new_offsets` and `new_sizes` are derived from existing valid child arrays. + // - `new_offsets` and `new_sizes` are non-nullable. + // - `new_offsets` and `new_sizes` have the same length (both taken with the same + // `indices`). + // - Validity correctly reflects the combination of array and indices validity. + let new_array = unsafe { + ListViewArray::new_unchecked(elements.clone(), new_offsets, new_sizes, new_validity) + }; + + // TODO(connor)[ListView]: Ideally, we would only rebuild after all `take`s and `filter` + // compute functions have run, at the "top" of the operator tree. However, we cannot do this + // right now, so we will just rebuild every time (similar to `ListArray`). + + Ok(Some( + new_array + .rebuild(ListViewRebuildMode::MakeZeroCopyToList)? + .into_array(), + )) } } diff --git a/vortex-array/src/arrays/primitive/compute/take/mod.rs b/vortex-array/src/arrays/primitive/compute/take/mod.rs index e82055363c0..23793ed9e78 100644 --- a/vortex-array/src/arrays/primitive/compute/take/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/take/mod.rs @@ -82,30 +82,28 @@ impl TakeImpl for TakeKernelScalar { } } -fn take_primitive(array: &PrimitiveArray, indices: &dyn Array) -> VortexResult { - let DType::Primitive(ptype, null) = indices.dtype() else { - vortex_bail!("Invalid indices dtype: {}", indices.dtype()) - }; - - let unsigned_indices = if ptype.is_unsigned_int() { - indices.to_primitive() - } else { - // This will fail if all values cannot be converted to unsigned - cast(indices, &DType::Primitive(ptype.to_unsigned(), *null))?.to_primitive() - }; - - let validity = array.validity().take(unsigned_indices.as_ref())?; - // Delegate to the best kernel based on the target CPU - PRIMITIVE_TAKE_KERNEL.take(array, &unsigned_indices, validity) -} - impl TakeExecute for PrimitiveVTable { fn take( array: &PrimitiveArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_primitive(array, indices).map(Some) + let DType::Primitive(ptype, null) = indices.dtype() else { + vortex_bail!("Invalid indices dtype: {}", indices.dtype()) + }; + + let unsigned_indices = if ptype.is_unsigned_int() { + indices.to_primitive() + } else { + // This will fail if all values cannot be converted to unsigned + cast(indices, &DType::Primitive(ptype.to_unsigned(), *null))?.to_primitive() + }; + + let validity = array.validity().take(unsigned_indices.as_ref())?; + // Delegate to the best kernel based on the target CPU + PRIMITIVE_TAKE_KERNEL + .take(array, &unsigned_indices, validity) + .map(Some) } } diff --git a/vortex-array/src/arrays/struct_/compute/rules.rs b/vortex-array/src/arrays/struct_/compute/rules.rs index 49401f1589a..ff2df148c23 100644 --- a/vortex-array/src/arrays/struct_/compute/rules.rs +++ b/vortex-array/src/arrays/struct_/compute/rules.rs @@ -9,7 +9,6 @@ use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ConstantArray; use crate::arrays::ExactScalarFn; -use crate::arrays::FilterReduceAdaptor; use crate::arrays::ScalarFnArrayExt; use crate::arrays::ScalarFnArrayView; use crate::arrays::SliceReduceAdaptor; diff --git a/vortex-array/src/arrays/struct_/compute/take.rs b/vortex-array/src/arrays/struct_/compute/take.rs index db694e0da73..a152da021ee 100644 --- a/vortex-array/src/arrays/struct_/compute/take.rs +++ b/vortex-array/src/arrays/struct_/compute/take.rs @@ -18,43 +18,41 @@ use crate::kernel::ParentKernelSet; use crate::validity::Validity; use crate::vtable::ValidityHelper; -fn take_struct(array: &StructArray, indices: &dyn Array) -> VortexResult { - // If the struct array is empty then the indices must be all null, otherwise it will access - // an out of bounds element - if array.is_empty() { - return StructArray::try_new_with_dtype( - array.unmasked_fields().clone(), - array.struct_fields().clone(), - indices.len(), - Validity::AllInvalid, - ) - .map(StructArray::into_array); - } - // The validity is applied to the struct validity, - let inner_indices = &compute::fill_null( - indices, - &Scalar::default_value(indices.dtype().with_nullability(Nullability::NonNullable)), - )?; - StructArray::try_new_with_dtype( - array - .unmasked_fields() - .iter() - .map(|field| field.take(inner_indices.to_array())) - .collect::, _>>()?, - array.struct_fields().clone(), - indices.len(), - array.validity().take(indices)?, - ) - .map(|a| a.into_array()) -} - impl TakeExecute for StructVTable { fn take( array: &StructArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_struct(array, indices).map(Some) + // If the struct array is empty then the indices must be all null, otherwise it will access + // an out of bounds element + if array.is_empty() { + return StructArray::try_new_with_dtype( + array.unmasked_fields().clone(), + array.struct_fields().clone(), + indices.len(), + Validity::AllInvalid, + ) + .map(StructArray::into_array) + .map(Some); + } + // The validity is applied to the struct validity, + let inner_indices = &compute::fill_null( + indices, + &Scalar::default_value(indices.dtype().with_nullability(Nullability::NonNullable)), + )?; + StructArray::try_new_with_dtype( + array + .unmasked_fields() + .iter() + .map(|field| field.take(inner_indices.to_array())) + .collect::, _>>()?, + array.struct_fields().clone(), + indices.len(), + array.validity().take(indices)?, + ) + .map(|a| a.into_array()) + .map(Some) } } diff --git a/vortex-array/src/arrays/varbin/compute/take.rs b/vortex-array/src/arrays/varbin/compute/take.rs index 66536ed7abc..d643a9c6064 100644 --- a/vortex-array/src/arrays/varbin/compute/take.rs +++ b/vortex-array/src/arrays/varbin/compute/take.rs @@ -25,100 +25,96 @@ use crate::executor::ExecutionCtx; use crate::kernel::ParentKernelSet; use crate::validity::Validity; -#[expect( - clippy::redundant_clone, - reason = "macro expansion causes false positive - only one match arm executes" -)] -fn take_varbin(array: &VarBinArray, indices: &dyn Array) -> VortexResult { - let offsets = array.offsets().to_primitive(); - let data = array.bytes(); - let indices = indices.to_primitive(); - let dtype = array - .dtype() - .clone() - .union_nullability(indices.dtype().nullability()); - let result = match_each_integer_ptype!(indices.ptype(), |I| { - // On take, offsets get widened to either 32- or 64-bit based on the original type, - // to avoid overflow issues. - match offsets.ptype() { - PType::U8 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U16 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U32 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U64 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I8 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I16 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I32 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I64 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - _ => unreachable!("invalid PType for offsets"), - } - }); - - Ok(result?.into_array()) -} - impl TakeExecute for VarBinVTable { + #[expect( + clippy::redundant_clone, + reason = "macro expansion causes false positive - only one match arm executes" + )] fn take( array: &VarBinArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_varbin(array, indices).map(Some) + let offsets = array.offsets().to_primitive(); + let data = array.bytes(); + let indices = indices.to_primitive(); + let dtype = array + .dtype() + .clone() + .union_nullability(indices.dtype().nullability()); + let result = match_each_integer_ptype!(indices.ptype(), |I| { + // On take, offsets get widened to either 32- or 64-bit based on the original type, + // to avoid overflow issues. + match offsets.ptype() { + PType::U8 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::U16 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::U32 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::U64 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I8 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I16 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I32 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I64 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + _ => unreachable!("invalid PType for offsets"), + } + }); + + Ok(Some(result?.into_array())) } } diff --git a/vortex-array/src/arrays/varbinview/compute/take.rs b/vortex-array/src/arrays/varbinview/compute/take.rs index d5abe48f102..ee027d3a5b3 100644 --- a/vortex-array/src/arrays/varbinview/compute/take.rs +++ b/vortex-array/src/arrays/varbinview/compute/take.rs @@ -24,37 +24,35 @@ use crate::executor::ExecutionCtx; use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; -/// Take involves creating a new array that references the old array, just with the given set of views. -fn take_varbinview(array: &VarBinViewArray, indices: &dyn Array) -> VortexResult { - let validity = array.validity().take(indices)?; - let indices = indices.to_primitive(); - - let indices_mask = indices.validity_mask()?; - let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| { - take_views(array.views(), indices.as_slice::(), &indices_mask) - }); - - // SAFETY: taking all components at same indices maintains invariants - unsafe { - Ok(VarBinViewArray::new_handle_unchecked( - BufferHandle::new_host(views_buffer.into_byte_buffer()), - array.buffers().clone(), - array - .dtype() - .union_nullability(indices.dtype().nullability()), - validity, - ) - .into_array()) - } -} - impl TakeExecute for VarBinViewVTable { + /// Take involves creating a new array that references the old array, just with the given set of views. fn take( array: &VarBinViewArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - take_varbinview(array, indices).map(Some) + let validity = array.validity().take(indices)?; + let indices = indices.to_primitive(); + + let indices_mask = indices.validity_mask()?; + let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| { + take_views(array.views(), indices.as_slice::(), &indices_mask) + }); + + // SAFETY: taking all components at same indices maintains invariants + unsafe { + Ok(Some( + VarBinViewArray::new_handle_unchecked( + BufferHandle::new_host(views_buffer.into_byte_buffer()), + array.buffers().clone(), + array + .dtype() + .union_nullability(indices.dtype().nullability()), + validity, + ) + .into_array(), + )) + } } } From c0fe4d828d063a8dd818defc26090835cf868be5 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 14:22:19 +0000 Subject: [PATCH 13/21] wip Signed-off-by: Joe Isaacs --- encodings/alp/src/alp/array.rs | 1 + encodings/runend/src/compute/filter.rs | 25 ---- encodings/runend/src/compute/take_from.rs | 134 ++++++++-------------- 3 files changed, 51 insertions(+), 109 deletions(-) diff --git a/encodings/alp/src/alp/array.rs b/encodings/alp/src/alp/array.rs index 157ea88abbb..95a9cc4166a 100644 --- a/encodings/alp/src/alp/array.rs +++ b/encodings/alp/src/alp/array.rs @@ -452,6 +452,7 @@ mod tests { use std::sync::LazyLock; use rstest::rstest; + use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::VortexSessionExecute; diff --git a/encodings/runend/src/compute/filter.rs b/encodings/runend/src/compute/filter.rs index 7b43f399558..141ff38373e 100644 --- a/encodings/runend/src/compute/filter.rs +++ b/encodings/runend/src/compute/filter.rs @@ -73,7 +73,6 @@ impl FilterKernel for RunEndVTable { } } - // Code adapted from apache arrow-rs https://github.com/apache/arrow-rs/blob/b1f5c250ebb6c1252b4e7c51d15b8e77f4c361fa/arrow-select/src/filter.rs#L425 fn filter_run_end_primitive + AsPrimitive>( run_ends: &[R], @@ -124,9 +123,7 @@ mod tests { use vortex_error::VortexResult; use vortex_mask::Mask; - use super::filter_run_end; use crate::RunEndArray; - use crate::RunEndVTable; fn ree_array() -> RunEndArray { RunEndArray::encode( @@ -135,28 +132,6 @@ mod tests { .unwrap() } - #[test] - fn run_end_filter() { - let arr = ree_array(); - let filtered = filter_run_end( - &arr, - &Mask::from_iter([ - true, true, false, false, false, false, false, false, false, false, true, true, - ]), - ) - .unwrap(); - let filtered_run_end = filtered.as_::(); - - assert_arrays_eq!( - filtered_run_end.ends().to_primitive(), - PrimitiveArray::from_iter([2u8, 4]) - ); - assert_arrays_eq!( - filtered_run_end.values().to_primitive(), - PrimitiveArray::from_iter([1i32, 5]) - ); - } - #[test] fn filter_sliced_run_end() -> VortexResult<()> { let arr = ree_array().slice(2..7).unwrap(); diff --git a/encodings/runend/src/compute/take_from.rs b/encodings/runend/src/compute/take_from.rs index 96dd35a34f5..92cb0856096 100644 --- a/encodings/runend/src/compute/take_from.rs +++ b/encodings/runend/src/compute/take_from.rs @@ -1,104 +1,70 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::fmt::Debug; - -use vortex_array::ArrayRef; -use vortex_array::ExecutionCtx; -use vortex_array::arrays::DictArray; -use vortex_array::arrays::DictVTable; -use vortex_array::kernel::ExecuteParentKernel; -use vortex_dtype::DType; -use vortex_error::VortexResult; - -use crate::RunEndArray; -use crate::RunEndVTable; - -// impl ParentKernelSet - -#[derive(Debug)] -struct RunEndVTableTakeFrom; - -impl ExecuteParentKernel for RunEndVTableTakeFrom { - type Parent = DictVTable; - - fn execute_parent( - &self, - _array: &RunEndArray, - dict: &DictArray, - child_idx: usize, - _ctx: &mut ExecutionCtx, - ) -> VortexResult> { - if child_idx != 0 { - return Ok(None); - } - // Only `Primitive` and `Bool` are valid run-end value types. - // TODO: Support additional DTypes - if !matches!(dict.dtype(), DType::Primitive(_, _) | DType::Bool(_)) { - return Ok(None); - } - - // // Transform the run-end encoding from storing indices to storing values - // // by taking values from `source` at positions specified by `indices.values()`. - // - // // Create a new run-end array containing values as values, instead of indices as values. - // // SAFETY: we are copying ends from an existing valid RunEndArray - // let ree_array = unsafe { - // RunEndArray::new_unchecked( - // array.ends().clone(), - // dict.values().take(array.values().clone())?, - // array.offset(), - // array.len(), - // ) - // }; - // - // Ok(Some(ree_array.into_array())) - - // TODO: implement run-end take from optimization - // For now, skip this optimization and fall back to default take - Ok(None) - } -} - -// impl TakeFromKernel for RunEndVTable { -// /// Takes values from the source array using run-end encoded indices. -// /// -// /// # Arguments -// /// -// /// * `indices` - Run-end encoded indices -// /// * `source` - Array to take values from -// /// -// /// # Returns -// /// -// /// * `Ok(Some(source))` - If successful -// /// * `Ok(None)` - If the source array has an unsupported dtype -// /// -// fn take_from( -// &self, -// indices: &RunEndArray, -// source: &dyn Array, -// ) -> VortexResult> { -// -// } -// } -// -// register_kernel!(TakeFromKernelAdapter(RunEndVTable).lift()); - #[cfg(test)] mod tests { + use std::fmt::Debug; + use vortex_array::Array; + use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::DictArray; + use vortex_array::arrays::DictVTable; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; use vortex_array::kernel::ExecuteParentKernel; use vortex_buffer::buffer; + use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_session::VortexSession; - use super::RunEndVTableTakeFrom; use crate::RunEndArray; + use crate::RunEndVTable; + + #[derive(Debug)] + struct RunEndVTableTakeFrom; + + impl ExecuteParentKernel for RunEndVTableTakeFrom { + type Parent = DictVTable; + + fn execute_parent( + &self, + array: &RunEndArray, + dict: &DictArray, + child_idx: usize, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + if child_idx != 0 { + return Ok(None); + } + // Only `Primitive` and `Bool` are valid run-end value types. + // TODO: Support additional DTypes + if !matches!(dict.dtype(), DType::Primitive(_, _) | DType::Bool(_)) { + return Ok(None); + } + + // // Transform the run-end encoding from storing indices to storing values + // // by taking values from `source` at positions specified by `indices.values()`. + // + // // Create a new run-end array containing values as values, instead of indices as values. + // // SAFETY: we are copying ends from an existing valid RunEndArray + let ree_array = unsafe { + RunEndArray::new_unchecked( + array.ends().clone(), + dict.values().take(array.values().clone())?, + array.offset(), + array.len(), + ) + }; + // + Ok(Some(ree_array.into_array())) + + // TODO: implement run-end take from optimization + // For now, skip this optimization and fall back to default take + // Ok(None) + } + } /// Build a DictArray whose codes are run-end encoded. /// From 53b9d17c1e5ec77909f06e0a83703879861b6ed1 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 14:32:10 +0000 Subject: [PATCH 14/21] wip Signed-off-by: Joe Isaacs --- vortex-array/src/arrays/dict/take.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vortex-array/src/arrays/dict/take.rs b/vortex-array/src/arrays/dict/take.rs index 13af8f75de0..9531e6d56b6 100644 --- a/vortex-array/src/arrays/dict/take.rs +++ b/vortex-array/src/arrays/dict/take.rs @@ -82,7 +82,10 @@ where parent: &DictArray, child_idx: usize, ) -> VortexResult> { - assert_eq!(child_idx, 1); + // Only handle the values child (index 1), not the codes child (index 0). + if child_idx != 1 { + return Ok(None); + } if let Some(result) = precondition::(array, parent.codes()) { return Ok(Some(result)); } @@ -110,7 +113,10 @@ where child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - assert_eq!(child_idx, 1); + // Only handle the values child (index 1), not the codes child (index 0). + if child_idx != 1 { + return Ok(None); + } if let Some(result) = precondition::(array, parent.codes()) { return Ok(Some(result)); } From f05ebb9e4ec486632cea8f6017b0aeb8fe981c65 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 15:39:09 +0000 Subject: [PATCH 15/21] wip Signed-off-by: Joe Isaacs --- encodings/fsst/src/tests.rs | 3 +-- encodings/runend/src/compute/filter.rs | 1 - encodings/sequence/src/compute/take.rs | 2 +- encodings/sparse/src/compute/take.rs | 19 +++++-------------- vortex-array/src/arrays/dict/take.rs | 6 +++++- .../src/arrays/fixed_size_list/array.rs | 5 +++-- vortex-array/src/arrays/masked/execute.rs | 2 +- .../src/arrays/varbin/compute/take.rs | 12 ++++++------ 8 files changed, 22 insertions(+), 28 deletions(-) diff --git a/encodings/fsst/src/tests.rs b/encodings/fsst/src/tests.rs index 9163d96250c..7490a07746d 100644 --- a/encodings/fsst/src/tests.rs +++ b/encodings/fsst/src/tests.rs @@ -69,8 +69,7 @@ fn test_fsst_array_ops() { // test take let indices = buffer![0, 2].into_array(); - let fsst_taken = take(&fsst_array, &indices).unwrap(); - assert!(fsst_taken.is::()); + let fsst_taken = fsst_array.take(indices).unwrap(); assert_eq!(fsst_taken.len(), 2); assert_nth_scalar!( fsst_taken, diff --git a/encodings/runend/src/compute/filter.rs b/encodings/runend/src/compute/filter.rs index 141ff38373e..269d64b97f4 100644 --- a/encodings/runend/src/compute/filter.rs +++ b/encodings/runend/src/compute/filter.rs @@ -117,7 +117,6 @@ fn filter_run_end_primitive + AsPrimitiv mod tests { use vortex_array::Array; use vortex_array::IntoArray; - use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; use vortex_error::VortexResult; diff --git a/encodings/sequence/src/compute/take.rs b/encodings/sequence/src/compute/take.rs index d6e2740ae06..b3725389ed3 100644 --- a/encodings/sequence/src/compute/take.rs +++ b/encodings/sequence/src/compute/take.rs @@ -171,7 +171,7 @@ mod test { } #[test] - #[should_panic(expected = "index 20 out of bounds")] + #[should_panic(expected = "index out of bounds")] fn test_bounds_check() { let array = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap(); let indices = vortex_array::arrays::PrimitiveArray::from_iter([0i32, 20]); diff --git a/encodings/sparse/src/compute/take.rs b/encodings/sparse/src/compute/take.rs index 06c494d6747..39f0b76ef1e 100644 --- a/encodings/sparse/src/compute/take.rs +++ b/encodings/sparse/src/compute/take.rs @@ -67,7 +67,6 @@ mod test { use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::IntoArray; - use vortex_array::ToCanonical; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; @@ -78,7 +77,6 @@ mod test { use vortex_scalar::Scalar; use crate::SparseArray; - use crate::SparseVTable; fn test_array_fill_value() -> Scalar { // making this const is annoying @@ -130,18 +128,11 @@ mod test { #[test] fn ordered_take() { let sparse = sparse_array(); - let taken_arr = take(&sparse, &buffer![69, 37].into_array()).unwrap(); - let taken = taken_arr.as_::(); - - assert_arrays_eq!( - taken.patches().indices().to_primitive(), - PrimitiveArray::from_iter([1u64]) - ); - assert_arrays_eq!( - taken.patches().values().to_primitive(), - PrimitiveArray::from_option_iter([Some(0.47f64)]) - ); - assert_eq!(taken.len(), 2); + // Note: take returns a canonical array, not SparseArray + let taken = take(&sparse, &buffer![69, 37].into_array()).unwrap(); + // Index 69 is not in sparse array (fill value is null), index 37 has value 0.47 + let expected = PrimitiveArray::from_option_iter([Option::::None, Some(0.47f64)]); + assert_arrays_eq!(taken, expected.to_array()); } #[test] diff --git a/vortex-array/src/arrays/dict/take.rs b/vortex-array/src/arrays/dict/take.rs index 9531e6d56b6..8dc4a98f54b 100644 --- a/vortex-array/src/arrays/dict/take.rs +++ b/vortex-array/src/arrays/dict/take.rs @@ -52,7 +52,11 @@ pub trait TakeExecute: VTable { fn precondition(array: &V::Array, indices: &dyn Array) -> Option { // Fast-path for empty indices. if indices.is_empty() { - return Some(Canonical::empty(array.dtype()).into_array()); + let result_dtype = array + .dtype() + .clone() + .union_nullability(indices.dtype().nullability()); + return Some(Canonical::empty(&result_dtype).into_array()); } // TODO(joe): shall we enable this seems expensive. diff --git a/vortex-array/src/arrays/fixed_size_list/array.rs b/vortex-array/src/arrays/fixed_size_list/array.rs index 65692a8f495..e512b77518c 100644 --- a/vortex-array/src/arrays/fixed_size_list/array.rs +++ b/vortex-array/src/arrays/fixed_size_list/array.rs @@ -227,8 +227,9 @@ impl FixedSizeListArray { pub fn fixed_size_list_elements_at(&self, index: usize) -> VortexResult { debug_assert!( index < self.len, - "index out of bounds: the len is {} but the index is {index}", - self.len + "index {} out of bounds: the len is {}", + index, + self.len, ); debug_assert!(self.validity.is_valid(index).unwrap_or(false)); diff --git a/vortex-array/src/arrays/masked/execute.rs b/vortex-array/src/arrays/masked/execute.rs index 1dd9523ed4e..00e803de64e 100644 --- a/vortex-array/src/arrays/masked/execute.rs +++ b/vortex-array/src/arrays/masked/execute.rs @@ -99,7 +99,7 @@ fn mask_validity_decimal(array: DecimalArray, mask: &Mask) -> DecimalArray { /// Mask validity for VarBinViewArray. fn mask_validity_varbinview(array: VarBinViewArray, mask: &Mask) -> VarBinViewArray { let len = array.len(); - let dtype = array.dtype().clone(); + let dtype = array.dtype().as_nullable(); let new_validity = combine_validity(array.validity(), mask, len); // SAFETY: We're only changing validity, not the data structure unsafe { diff --git a/vortex-array/src/arrays/varbin/compute/take.rs b/vortex-array/src/arrays/varbin/compute/take.rs index d643a9c6064..634eb22b800 100644 --- a/vortex-array/src/arrays/varbin/compute/take.rs +++ b/vortex-array/src/arrays/varbin/compute/take.rs @@ -265,7 +265,7 @@ mod tests { use crate::IntoArray; use crate::arrays::PrimitiveArray; use crate::arrays::VarBinArray; - use crate::arrays::VarBinVTable; + use crate::arrays::VarBinViewVTable; use crate::compute::conformance::take::test_take_conformance; use crate::compute::take; use crate::validity::Validity; @@ -323,10 +323,10 @@ mod tests { let indices = buffer![0u32, 0u32, 0u32].into_array(); let taken = take(array.as_ref(), indices.as_ref()).unwrap(); - let taken_str = taken.as_::(); - assert_eq!(taken_str.len(), 3); - assert_eq!(taken_str.bytes_at(0).as_bytes(), scream.as_bytes()); - assert_eq!(taken_str.bytes_at(1).as_bytes(), scream.as_bytes()); - assert_eq!(taken_str.bytes_at(2).as_bytes(), scream.as_bytes()); + let taken_view = taken.as_::(); + assert_eq!(taken_view.len(), 3); + assert_eq!(taken_view.bytes_at(0).as_slice(), scream.as_bytes()); + assert_eq!(taken_view.bytes_at(1).as_slice(), scream.as_bytes()); + assert_eq!(taken_view.bytes_at(2).as_slice(), scream.as_bytes()); } } From 5269799d31d3fc74f86b1df63d71a512d938c980 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 15:50:02 +0000 Subject: [PATCH 16/21] wip Signed-off-by: Joe Isaacs --- encodings/fsst/src/tests.rs | 1 - .../src/arrays/chunked/compute/filter.rs | 16 ++++----- vortex-array/src/arrays/dict/compute/mod.rs | 3 +- vortex-array/src/arrays/dict/take.rs | 34 ++++++++++++++++++- vortex-array/src/arrays/list/compute/take.rs | 14 ++++++-- .../src/arrays/masked/compute/take.rs | 5 ++- vortex-array/src/compute/take.rs | 34 +------------------ vortex-array/src/validity.rs | 3 +- 8 files changed, 55 insertions(+), 55 deletions(-) diff --git a/encodings/fsst/src/tests.rs b/encodings/fsst/src/tests.rs index 7490a07746d..d57ccd79e8e 100644 --- a/encodings/fsst/src/tests.rs +++ b/encodings/fsst/src/tests.rs @@ -8,7 +8,6 @@ use vortex_array::ToCanonical; use vortex_array::arrays::builder::VarBinBuilder; use vortex_array::assert_arrays_eq; use vortex_array::assert_nth_scalar; -use vortex_array::compute::take; use vortex_buffer::buffer; use vortex_dtype::DType; use vortex_dtype::Nullability; diff --git a/vortex-array/src/arrays/chunked/compute/filter.rs b/vortex-array/src/arrays/chunked/compute/filter.rs index 7eacbd731a2..2d47615dfa4 100644 --- a/vortex-array/src/arrays/chunked/compute/filter.rs +++ b/vortex-array/src/arrays/chunked/compute/filter.rs @@ -15,7 +15,6 @@ use crate::arrays::ChunkedArray; use crate::arrays::ChunkedVTable; use crate::arrays::PrimitiveArray; use crate::arrays::filter::FilterKernel; -use crate::compute::take; use crate::search_sorted::SearchSorted; use crate::search_sorted::SearchSortedSide; use crate::validity::Validity; @@ -162,11 +161,9 @@ fn filter_indices( // Push the chunk we've accumulated. if !chunk_indices.is_empty() { let chunk = array.chunk(current_chunk_id); - let filtered_chunk = take( - chunk, - PrimitiveArray::new(chunk_indices.clone().freeze(), Validity::NonNullable) - .as_ref(), - )?; + let indices = + PrimitiveArray::new(chunk_indices.clone().freeze(), Validity::NonNullable); + let filtered_chunk = chunk.take(indices.to_array())?.to_canonical()?.into_array(); result.push(filtered_chunk); } @@ -180,10 +177,9 @@ fn filter_indices( if !chunk_indices.is_empty() { let chunk = array.chunk(current_chunk_id); - let filtered_chunk = take( - chunk, - PrimitiveArray::new(chunk_indices.clone().freeze(), Validity::NonNullable).as_ref(), - )?; + let indices = + PrimitiveArray::new(chunk_indices.clone().freeze(), Validity::NonNullable); + let filtered_chunk = chunk.take(indices.to_array())?.to_canonical()?.into_array(); result.push(filtered_chunk); } diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index 94316a730f5..5015726b378 100644 --- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -24,7 +24,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::filter::FilterReduce; -use crate::compute::take; use crate::kernel::ParentKernelSet; impl TakeExecute for DictVTable { @@ -33,7 +32,7 @@ impl TakeExecute for DictVTable { indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - let codes = take(array.codes(), indices)?; + let codes = array.codes().take(indices.to_array())?.to_canonical()?.into_array(); // SAFETY: selecting codes doesn't change the invariants of DictArray // Preserve all_values_referenced since taking codes doesn't affect which values are referenced Ok(Some(unsafe { diff --git a/vortex-array/src/arrays/dict/take.rs b/vortex-array/src/arrays/dict/take.rs index 8dc4a98f54b..86c3700952d 100644 --- a/vortex-array/src/arrays/dict/take.rs +++ b/vortex-array/src/arrays/dict/take.rs @@ -10,10 +10,14 @@ use crate::ArrayRef; use crate::Canonical; use crate::ExecutionCtx; use crate::IntoArray; -use crate::compute::propagate_take_stats; +use crate::expr::stats::Precision; +use crate::expr::stats::Stat; +use crate::expr::stats::StatsProvider; +use crate::expr::stats::StatsProviderExt; use crate::kernel::ExecuteParentKernel; use crate::matcher::Matcher; use crate::optimizer::rules::ArrayParentReduceRule; +use crate::stats::StatsSet; use crate::vtable::VTable; pub trait TakeReduce: VTable { @@ -131,3 +135,31 @@ where Ok(result) } } + +pub(crate) fn propagate_take_stats( + source: &dyn Array, + target: &dyn Array, + indices: &dyn Array, +) -> VortexResult<()> { + target.statistics().with_mut_typed_stats_set(|mut st| { + if indices.all_valid().unwrap_or(false) { + let is_constant = source.statistics().get_as::(Stat::IsConstant); + if is_constant == Some(Precision::Exact(true)) { + // Any combination of elements from a constant array is still const + st.set(Stat::IsConstant, Precision::exact(true)); + } + } + let inexact_min_max = [Stat::Min, Stat::Max] + .into_iter() + .filter_map(|stat| { + source + .statistics() + .get(stat) + .map(|v| (stat, v.map(|s| s.into_value()).into_inexact())) + }) + .collect::>(); + st.combine_sets( + &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()), + ) + }) +} diff --git a/vortex-array/src/arrays/list/compute/take.rs b/vortex-array/src/arrays/list/compute/take.rs index 9011b4dc8c6..01864d6dd75 100644 --- a/vortex-array/src/arrays/list/compute/take.rs +++ b/vortex-array/src/arrays/list/compute/take.rs @@ -10,6 +10,7 @@ use vortex_error::VortexResult; use crate::Array; use crate::ArrayRef; +use crate::IntoArray; use crate::ToCanonical; use crate::arrays::ListArray; use crate::arrays::ListVTable; @@ -18,7 +19,6 @@ use crate::arrays::TakeExecute; use crate::arrays::TakeExecuteAdaptor; use crate::builders::ArrayBuilder; use crate::builders::PrimitiveBuilder; -use crate::compute::take; use crate::executor::ExecutionCtx; use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; @@ -108,7 +108,11 @@ fn _take( let elements_to_take = elements_to_take.finish(); let new_offsets = new_offsets.finish(); - let new_elements = take(array.elements(), elements_to_take.as_ref())?; + let new_elements = array + .elements() + .take(elements_to_take.to_array())? + .to_canonical()? + .into_array(); Ok(ListArray::try_new( new_elements, @@ -176,7 +180,11 @@ fn _take_nullable VortexResult { + // TODO(joe): inline usage and remove to_canonical(). array .take(indices.to_array())? .to_canonical() .map(|c| c.into_array()) } - -pub(crate) fn propagate_take_stats( - source: &dyn Array, - target: &dyn Array, - indices: &dyn Array, -) -> VortexResult<()> { - target.statistics().with_mut_typed_stats_set(|mut st| { - if indices.all_valid().unwrap_or(false) { - let is_constant = source.statistics().get_as::(Stat::IsConstant); - if is_constant == Some(Precision::Exact(true)) { - // Any combination of elements from a constant array is still const - st.set(Stat::IsConstant, Precision::exact(true)); - } - } - let inexact_min_max = [Stat::Min, Stat::Max] - .into_iter() - .filter_map(|stat| { - source - .statistics() - .get(stat) - .map(|v| (stat, v.map(|s| s.into_value()).into_inexact())) - }) - .collect::>(); - st.combine_sets( - &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()), - ) - }) -} diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 7dc0a0f7aaa..fb2ed1c2d27 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -28,7 +28,6 @@ use crate::arrays::BoolArray; use crate::arrays::ConstantArray; use crate::compute::fill_null; use crate::compute::sum; -use crate::compute::take; use crate::patches::Patches; /// Validity information for an array @@ -173,7 +172,7 @@ impl Validity { }, Self::AllInvalid => Ok(Self::AllInvalid), Self::Array(is_valid) => { - let maybe_is_valid = take(is_valid, indices)?; + let maybe_is_valid = is_valid.take(indices.to_array())?.to_canonical()?.into_array(); // Null indices invalidate that position. let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?; Ok(Self::Array(is_valid)) From 962bfae4c0f8ba3aa5d751f8755a0e5b0e94caaf Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 16:43:21 +0000 Subject: [PATCH 17/21] wip Signed-off-by: Joe Isaacs --- vortex-array/src/arrays/chunked/compute/filter.rs | 3 +-- vortex-array/src/arrays/dict/compute/mod.rs | 6 +++++- vortex-array/src/arrays/masked/compute/take.rs | 6 +++++- vortex-array/src/validity.rs | 5 ++++- vortex-python/src/arrays/mod.rs | 9 ++++----- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/vortex-array/src/arrays/chunked/compute/filter.rs b/vortex-array/src/arrays/chunked/compute/filter.rs index 2d47615dfa4..2e31adcc0aa 100644 --- a/vortex-array/src/arrays/chunked/compute/filter.rs +++ b/vortex-array/src/arrays/chunked/compute/filter.rs @@ -177,8 +177,7 @@ fn filter_indices( if !chunk_indices.is_empty() { let chunk = array.chunk(current_chunk_id); - let indices = - PrimitiveArray::new(chunk_indices.clone().freeze(), Validity::NonNullable); + let indices = PrimitiveArray::new(chunk_indices.clone().freeze(), Validity::NonNullable); let filtered_chunk = chunk.take(indices.to_array())?.to_canonical()?.into_array(); result.push(filtered_chunk); } diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index 5015726b378..16cb6fe7c33 100644 --- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -32,7 +32,11 @@ impl TakeExecute for DictVTable { indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - let codes = array.codes().take(indices.to_array())?.to_canonical()?.into_array(); + let codes = array + .codes() + .take(indices.to_array())? + .to_canonical()? + .into_array(); // SAFETY: selecting codes doesn't change the invariants of DictArray // Preserve all_values_referenced since taking codes doesn't affect which values are referenced Ok(Some(unsafe { diff --git a/vortex-array/src/arrays/masked/compute/take.rs b/vortex-array/src/arrays/masked/compute/take.rs index ed6a21e354e..6c544db7f7b 100644 --- a/vortex-array/src/arrays/masked/compute/take.rs +++ b/vortex-array/src/arrays/masked/compute/take.rs @@ -30,7 +30,11 @@ impl TakeExecute for MaskedVTable { )?; array.child.take(filled_take)?.to_canonical()?.into_array() } else { - array.child.take(indices.to_array())?.to_canonical()?.into_array() + array + .child + .take(indices.to_array())? + .to_canonical()? + .into_array() }; // Compute the new validity by taking from array's validity and merging with indices validity diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index fb2ed1c2d27..6dab74af78b 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -172,7 +172,10 @@ impl Validity { }, Self::AllInvalid => Ok(Self::AllInvalid), Self::Array(is_valid) => { - let maybe_is_valid = is_valid.take(indices.to_array())?.to_canonical()?.into_array(); + let maybe_is_valid = is_valid + .take(indices.to_array())? + .to_canonical()? + .into_array(); // Null indices invalidate that position. let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?; Ok(Self::Array(is_valid)) diff --git a/vortex-python/src/arrays/mod.rs b/vortex-python/src/arrays/mod.rs index 24344c652f3..d372d5a86a4 100644 --- a/vortex-python/src/arrays/mod.rs +++ b/vortex-python/src/arrays/mod.rs @@ -27,7 +27,6 @@ use vortex::array::ArrayRef; use vortex::array::ToCanonical; use vortex::array::arrays::ChunkedVTable; use vortex::array::arrow::IntoArrowArray; -use vortex::array::compute::take; use vortex::compute::Operator; use vortex::compute::compare; use vortex::dtype::DType; @@ -603,7 +602,7 @@ impl PyArray { /// >>> a = vx.array(['a', 'b', 'c', 'd']) /// >>> indices = vx.array([0, 2]) /// >>> a.take(indices).to_arrow_array() - /// + /// /// [ /// "a", /// "c" @@ -616,7 +615,7 @@ impl PyArray { /// >>> a = vx.array(['a', 'b', 'c', 'd']) /// >>> indices = vx.array([0, 1, 1, 0]) /// >>> a.take(indices).to_arrow_array() - /// + /// /// [ /// "a", /// "b", @@ -629,13 +628,13 @@ impl PyArray { if !indices.dtype().is_int() { return Err(PyValueError::new_err(format!( - "indices: expected int or uint array, but found: {}", + "indices: expected int or uint arra sy, but found: {}", indices.dtype().python_repr() )) .into()); } - let inner = take(&slf, &*indices)?; + let inner = slf.take(indices.clone())?; Ok(PyArrayRef::from(inner)) } From 04e9feada86e4c4757fae3660557bf09c62d0916 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 17:30:06 +0000 Subject: [PATCH 18/21] wip Signed-off-by: Joe Isaacs --- encodings/alp/src/alp/compute/take.rs | 7 - encodings/alp/src/alp/rules.rs | 2 + encodings/alp/src/alp_rd/compute/take.rs | 7 - encodings/alp/src/alp_rd/kernel.rs | 2 + encodings/bytebool/src/array.rs | 9 + encodings/bytebool/src/compute.rs | 7 - encodings/bytebool/src/kernel.rs | 10 + encodings/bytebool/src/lib.rs | 1 + encodings/datetime-parts/src/array.rs | 10 + .../datetime-parts/src/compute/kernel.rs | 12 ++ encodings/datetime-parts/src/compute/mod.rs | 1 + encodings/datetime-parts/src/compute/take.rs | 7 - .../src/decimal_byte_parts/compute/kernel.rs | 12 ++ .../src/decimal_byte_parts/compute/mod.rs | 1 + .../src/decimal_byte_parts/compute/take.rs | 7 - .../src/decimal_byte_parts/mod.rs | 10 + .../fastlanes/src/bitpacking/compute/take.rs | 7 - .../src/bitpacking/vtable/kernels.rs | 2 + encodings/fastlanes/src/for/compute/mod.rs | 7 - encodings/fastlanes/src/for/vtable/kernels.rs | 10 + encodings/fastlanes/src/for/vtable/mod.rs | 11 ++ encodings/fsst/src/compute/mod.rs | 15 +- encodings/fsst/src/kernel.rs | 7 +- encodings/runend/src/compute/mod.rs | 2 +- encodings/runend/src/compute/take.rs | 7 - encodings/runend/src/compute/take_from.rs | 100 +++++----- encodings/runend/src/kernel.rs | 4 + encodings/sequence/src/compute/take.rs | 9 +- encodings/sequence/src/kernel.rs | 2 + encodings/sparse/src/compute/take.rs | 7 - encodings/sparse/src/kernel.rs | 2 + encodings/zigzag/src/array.rs | 10 + encodings/zigzag/src/compute/mod.rs | 7 - encodings/zigzag/src/kernel.rs | 10 + encodings/zigzag/src/lib.rs | 1 + vortex-array/src/arrays/bool/compute/take.rs | 7 - .../src/arrays/chunked/compute/take.rs | 7 - .../src/arrays/decimal/compute/take.rs | 7 - vortex-array/src/arrays/dict/compute/mod.rs | 7 - .../src/arrays/extension/compute/take.rs | 7 - .../arrays/fixed_size_list/compute/take.rs | 7 - vortex-array/src/arrays/list/compute/take.rs | 7 - .../src/arrays/listview/compute/take.rs | 7 - .../src/arrays/masked/compute/take.rs | 7 - .../src/arrays/primitive/compute/take/mod.rs | 7 - .../src/arrays/struct_/compute/take.rs | 7 - vortex-array/src/arrays/varbin/compute/mod.rs | 1 + .../src/arrays/varbin/compute/take.rs | 177 +++++++++--------- vortex-array/src/arrays/varbin/mod.rs | 2 +- .../src/arrays/varbinview/compute/take.rs | 7 - 50 files changed, 271 insertions(+), 318 deletions(-) create mode 100644 encodings/bytebool/src/kernel.rs create mode 100644 encodings/datetime-parts/src/compute/kernel.rs create mode 100644 encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs create mode 100644 encodings/fastlanes/src/for/vtable/kernels.rs create mode 100644 encodings/zigzag/src/kernel.rs diff --git a/encodings/alp/src/alp/compute/take.rs b/encodings/alp/src/alp/compute/take.rs index 44b24641266..9d97c598e28 100644 --- a/encodings/alp/src/alp/compute/take.rs +++ b/encodings/alp/src/alp/compute/take.rs @@ -6,8 +6,6 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use crate::ALPArray; @@ -39,11 +37,6 @@ impl TakeExecute for ALPVTable { } } -impl ALPVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - #[cfg(test)] mod test { use rstest::rstest; diff --git a/encodings/alp/src/alp/rules.rs b/encodings/alp/src/alp/rules.rs index 69b88759fbc..eb25c3fc872 100644 --- a/encodings/alp/src/alp/rules.rs +++ b/encodings/alp/src/alp/rules.rs @@ -3,6 +3,7 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::ALPVTable; @@ -10,4 +11,5 @@ use crate::ALPVTable; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&FilterExecuteAdaptor(ALPVTable)), ParentKernelSet::lift(&SliceExecuteAdaptor(ALPVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(ALPVTable)), ]); diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index 10f236de877..b2a395c0813 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -6,9 +6,7 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::fill_null; -use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -57,11 +55,6 @@ impl TakeExecute for ALPRDVTable { } } -impl ALPRDVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - #[cfg(test)] mod test { use rstest::rstest; diff --git a/encodings/alp/src/alp_rd/kernel.rs b/encodings/alp/src/alp_rd/kernel.rs index ad9dd1c6a56..c2f37e5495e 100644 --- a/encodings/alp/src/alp_rd/kernel.rs +++ b/encodings/alp/src/alp_rd/kernel.rs @@ -3,6 +3,7 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::alp_rd::ALPRDVTable; @@ -10,4 +11,5 @@ use crate::alp_rd::ALPRDVTable; pub(crate) static PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&SliceExecuteAdaptor(ALPRDVTable)), ParentKernelSet::lift(&FilterExecuteAdaptor(ALPRDVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(ALPRDVTable)), ]); diff --git a/encodings/bytebool/src/array.rs b/encodings/bytebool/src/array.rs index dd078089e6b..1301efea03b 100644 --- a/encodings/bytebool/src/array.rs +++ b/encodings/bytebool/src/array.rs @@ -118,6 +118,15 @@ impl VTable for ByteBoolVTable { let validity = array.validity().clone(); Ok(BoolArray::new(boolean_buffer, validity).into_array()) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + crate::kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Clone, Debug)] diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index 3fc5428dc40..ab9624db01a 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -8,12 +8,10 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::CastKernel; use vortex_array::compute::CastKernelAdapter; use vortex_array::compute::MaskKernel; use vortex_array::compute::MaskKernelAdapter; -use vortex_array::kernel::ParentKernelSet; use vortex_array::register_kernel; use vortex_array::vtable::ValidityHelper; use vortex_dtype::DType; @@ -86,11 +84,6 @@ impl TakeExecute for ByteBoolVTable { } } -impl ByteBoolVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/encodings/bytebool/src/kernel.rs b/encodings/bytebool/src/kernel.rs new file mode 100644 index 00000000000..91eb3add3a6 --- /dev/null +++ b/encodings/bytebool/src/kernel.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; + +use crate::ByteBoolVTable; + +pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(ByteBoolVTable))]); diff --git a/encodings/bytebool/src/lib.rs b/encodings/bytebool/src/lib.rs index c195e618317..35579a89dc8 100644 --- a/encodings/bytebool/src/lib.rs +++ b/encodings/bytebool/src/lib.rs @@ -5,5 +5,6 @@ pub use array::*; mod array; mod compute; +mod kernel; mod rules; mod slice; diff --git a/encodings/datetime-parts/src/array.rs b/encodings/datetime-parts/src/array.rs index 04489dd5129..7036500e476 100644 --- a/encodings/datetime-parts/src/array.rs +++ b/encodings/datetime-parts/src/array.rs @@ -37,6 +37,7 @@ use vortex_error::vortex_ensure; use vortex_error::vortex_err; use crate::canonical::decode_to_temporal; +use crate::compute::kernel::PARENT_KERNELS; use crate::compute::rules::PARENT_RULES; vtable!(DateTimeParts); @@ -162,6 +163,15 @@ impl VTable for DateTimePartsVTable { ) -> VortexResult> { PARENT_RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Clone, Debug)] diff --git a/encodings/datetime-parts/src/compute/kernel.rs b/encodings/datetime-parts/src/compute/kernel.rs new file mode 100644 index 00000000000..9c95c3439ca --- /dev/null +++ b/encodings/datetime-parts/src/compute/kernel.rs @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; + +use crate::DateTimePartsVTable; + +pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor( + DateTimePartsVTable, + ))]); diff --git a/encodings/datetime-parts/src/compute/mod.rs b/encodings/datetime-parts/src/compute/mod.rs index caf7df8ffbf..d606daccb59 100644 --- a/encodings/datetime-parts/src/compute/mod.rs +++ b/encodings/datetime-parts/src/compute/mod.rs @@ -5,6 +5,7 @@ mod cast; mod compare; mod filter; mod is_constant; +pub(crate) mod kernel; mod mask; pub(super) mod rules; mod slice; diff --git a/encodings/datetime-parts/src/compute/take.rs b/encodings/datetime-parts/src/compute/take.rs index 1a833ddf967..8014773c33a 100644 --- a/encodings/datetime-parts/src/compute/take.rs +++ b/encodings/datetime-parts/src/compute/take.rs @@ -7,11 +7,9 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::fill_null; use vortex_array::expr::stats::Stat; use vortex_array::expr::stats::StatsProvider; -use vortex_array::kernel::ParentKernelSet; use vortex_dtype::Nullability; use vortex_error::VortexResult; use vortex_error::vortex_panic; @@ -97,11 +95,6 @@ impl TakeExecute for DateTimePartsVTable { } } -impl DateTimePartsVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs new file mode 100644 index 00000000000..a802fad5db1 --- /dev/null +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; + +use crate::DecimalBytePartsVTable; + +pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor( + DecimalBytePartsVTable, + ))]); diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs index 262c6cd5b72..2e798106a7e 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs @@ -5,6 +5,7 @@ mod cast; mod compare; mod filter; mod is_constant; +pub(crate) mod kernel; mod mask; mod take; diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs index 88d1bbfb2e9..ba5a04c7d3b 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/take.rs @@ -5,8 +5,6 @@ use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use crate::DecimalBytePartsArray; @@ -22,8 +20,3 @@ impl TakeExecute for DecimalBytePartsVTable { .map(|a| Some(a.to_array())) } } - -impl DecimalBytePartsVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs index 001f10dfd0d..91d6944004f 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs @@ -45,6 +45,7 @@ use vortex_error::vortex_ensure; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; +use crate::decimal_byte_parts::compute::kernel::PARENT_KERNELS; use crate::decimal_byte_parts::rules::PARENT_RULES; vtable!(DecimalByteParts); @@ -130,6 +131,15 @@ impl VTable for DecimalBytePartsVTable { fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { to_canonical_decimal(array, ctx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } /// This array encodes decimals as between 1-4 columns of primitive typed children. diff --git a/encodings/fastlanes/src/bitpacking/compute/take.rs b/encodings/fastlanes/src/bitpacking/compute/take.rs index 84cbf8536af..6a5b33495dc 100644 --- a/encodings/fastlanes/src/bitpacking/compute/take.rs +++ b/encodings/fastlanes/src/bitpacking/compute/take.rs @@ -12,8 +12,6 @@ use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::kernel::ParentKernelSet; use vortex_array::validity::Validity; use vortex_array::vtable::ValidityHelper; use vortex_buffer::Buffer; @@ -64,11 +62,6 @@ impl TakeExecute for BitPackedVTable { } } -impl BitPackedVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - fn take_primitive( array: &BitPackedArray, indices: &PrimitiveArray, diff --git a/encodings/fastlanes/src/bitpacking/vtable/kernels.rs b/encodings/fastlanes/src/bitpacking/vtable/kernels.rs index ffbba0c7bbb..ca9b98c09db 100644 --- a/encodings/fastlanes/src/bitpacking/vtable/kernels.rs +++ b/encodings/fastlanes/src/bitpacking/vtable/kernels.rs @@ -3,6 +3,7 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::BitPackedVTable; @@ -10,4 +11,5 @@ use crate::BitPackedVTable; pub(crate) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&FilterExecuteAdaptor(BitPackedVTable)), ParentKernelSet::lift(&SliceExecuteAdaptor(BitPackedVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(BitPackedVTable)), ]); diff --git a/encodings/fastlanes/src/for/compute/mod.rs b/encodings/fastlanes/src/for/compute/mod.rs index 35ed1d715f2..a8efc731793 100644 --- a/encodings/fastlanes/src/for/compute/mod.rs +++ b/encodings/fastlanes/src/for/compute/mod.rs @@ -12,8 +12,6 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use vortex_mask::Mask; @@ -36,11 +34,6 @@ impl TakeExecute for FoRVTable { } } -impl FoRVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - impl FilterReduce for FoRVTable { fn filter(array: &FoRArray, mask: &Mask) -> VortexResult> { FoRArray::try_new( diff --git a/encodings/fastlanes/src/for/vtable/kernels.rs b/encodings/fastlanes/src/for/vtable/kernels.rs new file mode 100644 index 00000000000..60a009afe15 --- /dev/null +++ b/encodings/fastlanes/src/for/vtable/kernels.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; + +use crate::FoRVTable; + +pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(FoRVTable))]); diff --git a/encodings/fastlanes/src/for/vtable/mod.rs b/encodings/fastlanes/src/for/vtable/mod.rs index a4ade0feaf4..6b6969496af 100644 --- a/encodings/fastlanes/src/for/vtable/mod.rs +++ b/encodings/fastlanes/src/for/vtable/mod.rs @@ -24,9 +24,11 @@ use vortex_scalar::ScalarValue; use crate::FoRArray; use crate::r#for::array::for_decompress::decompress; +use crate::r#for::vtable::kernels::PARENT_KERNELS; use crate::r#for::vtable::rules::PARENT_RULES; mod array; +mod kernels; mod operations; mod rules; mod slice; @@ -109,6 +111,15 @@ impl VTable for FoRVTable { fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { Ok(decompress(array, ctx)?.into_array()) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Debug)] diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 1e11609e86c..7256cf31122 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -10,10 +10,8 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::arrays::VarBinVTable; +use vortex_array::arrays::take_into_varbin; use vortex_array::compute::fill_null; -use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -35,11 +33,7 @@ impl TakeExecute for FSSTVTable { .union_nullability(indices.dtype().nullability()), array.symbols().clone(), array.symbol_lengths().clone(), - array - .codes() - .take(indices.to_array())? - .as_::() - .clone(), + take_into_varbin(array.codes(), indices)?, fill_null( &array.uncompressed_lengths().take(indices.to_array())?, &Scalar::new( @@ -53,11 +47,6 @@ impl TakeExecute for FSSTVTable { } } -impl FSSTVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/encodings/fsst/src/kernel.rs b/encodings/fsst/src/kernel.rs index 710f0dd5ba2..e3e11a1ed5e 100644 --- a/encodings/fsst/src/kernel.rs +++ b/encodings/fsst/src/kernel.rs @@ -2,12 +2,15 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::arrays::FilterExecuteAdaptor; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::FSSTVTable; -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&FilterExecuteAdaptor(FSSTVTable))]); +pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&FilterExecuteAdaptor(FSSTVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(FSSTVTable)), +]); #[cfg(test)] mod tests { diff --git a/encodings/runend/src/compute/mod.rs b/encodings/runend/src/compute/mod.rs index 4827436dc39..a4c80d0a5ae 100644 --- a/encodings/runend/src/compute/mod.rs +++ b/encodings/runend/src/compute/mod.rs @@ -11,7 +11,7 @@ mod is_constant; mod is_sorted; mod min_max; pub(crate) mod take; -mod take_from; +pub(crate) mod take_from; #[cfg(test)] mod tests { diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index 7a44270d971..009dbd86783 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -9,8 +9,6 @@ use vortex_array::ExecutionCtx; use vortex_array::ToCanonical; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::kernel::ParentKernelSet; use vortex_array::search_sorted::SearchResult; use vortex_array::search_sorted::SearchSorted; use vortex_array::search_sorted::SearchSortedSide; @@ -55,11 +53,6 @@ impl TakeExecute for RunEndVTable { } } -impl RunEndVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - /// Perform a take operation on a RunEndArray by binary searching for each of the indices. pub fn take_indices_unchecked>( array: &RunEndArray, diff --git a/encodings/runend/src/compute/take_from.rs b/encodings/runend/src/compute/take_from.rs index 92cb0856096..dbf7ed93465 100644 --- a/encodings/runend/src/compute/take_from.rs +++ b/encodings/runend/src/compute/take_from.rs @@ -1,70 +1,70 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::DictArray; +use vortex_array::arrays::DictVTable; +use vortex_array::kernel::ExecuteParentKernel; +use vortex_dtype::DType; +use vortex_error::VortexResult; + +use crate::RunEndArray; +use crate::RunEndVTable; + +#[derive(Debug)] +pub(crate) struct RunEndVTableTakeFrom; + +impl ExecuteParentKernel for RunEndVTableTakeFrom { + type Parent = DictVTable; + + fn execute_parent( + &self, + array: &RunEndArray, + dict: &DictArray, + child_idx: usize, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + if child_idx != 0 { + return Ok(None); + } + // Only `Primitive` and `Bool` are valid run-end value types. + // TODO: Support additional DTypes + if !matches!(dict.dtype(), DType::Primitive(_, _) | DType::Bool(_)) { + return Ok(None); + } + + // Create a new run-end array containing values as values, instead of indices as values. + // SAFETY: we are copying ends from an existing valid RunEndArray + let ree_array = unsafe { + RunEndArray::new_unchecked( + array.ends().clone(), + dict.values().take(array.values().clone())?, + array.offset(), + array.len(), + ) + }; + // + Ok(Some(ree_array.into_array())) + } +} + #[cfg(test)] mod tests { - use std::fmt::Debug; - use vortex_array::Array; - use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::DictArray; - use vortex_array::arrays::DictVTable; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; use vortex_array::kernel::ExecuteParentKernel; use vortex_buffer::buffer; - use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_session::VortexSession; use crate::RunEndArray; - use crate::RunEndVTable; - - #[derive(Debug)] - struct RunEndVTableTakeFrom; - - impl ExecuteParentKernel for RunEndVTableTakeFrom { - type Parent = DictVTable; - - fn execute_parent( - &self, - array: &RunEndArray, - dict: &DictArray, - child_idx: usize, - _ctx: &mut ExecutionCtx, - ) -> VortexResult> { - if child_idx != 0 { - return Ok(None); - } - // Only `Primitive` and `Bool` are valid run-end value types. - // TODO: Support additional DTypes - if !matches!(dict.dtype(), DType::Primitive(_, _) | DType::Bool(_)) { - return Ok(None); - } - - // // Transform the run-end encoding from storing indices to storing values - // // by taking values from `source` at positions specified by `indices.values()`. - // - // // Create a new run-end array containing values as values, instead of indices as values. - // // SAFETY: we are copying ends from an existing valid RunEndArray - let ree_array = unsafe { - RunEndArray::new_unchecked( - array.ends().clone(), - dict.values().take(array.values().clone())?, - array.offset(), - array.len(), - ) - }; - // - Ok(Some(ree_array.into_array())) - - // TODO: implement run-end take from optimization - // For now, skip this optimization and fall back to default take - // Ok(None) - } - } + use crate::compute::take_from::RunEndVTableTakeFrom; /// Build a DictArray whose codes are run-end encoded. /// diff --git a/encodings/runend/src/kernel.rs b/encodings/runend/src/kernel.rs index 3c4e41ab296..4873d9e15b1 100644 --- a/encodings/runend/src/kernel.rs +++ b/encodings/runend/src/kernel.rs @@ -10,16 +10,20 @@ use vortex_array::arrays::ConstantArray; use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceArray; use vortex_array::arrays::SliceVTable; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ExecuteParentKernel; use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use crate::RunEndArray; use crate::RunEndVTable; +use crate::compute::take_from::RunEndVTableTakeFrom; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&RunEndSliceKernel), ParentKernelSet::lift(&FilterExecuteAdaptor(RunEndVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(RunEndVTable)), + ParentKernelSet::lift(&RunEndVTableTakeFrom), ]); /// Kernel to execute slicing on a RunEnd array. diff --git a/encodings/sequence/src/compute/take.rs b/encodings/sequence/src/compute/take.rs index b3725389ed3..448c900c68c 100644 --- a/encodings/sequence/src/compute/take.rs +++ b/encodings/sequence/src/compute/take.rs @@ -10,8 +10,6 @@ use vortex_array::ToCanonical; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::kernel::ParentKernelSet; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_dtype::DType; @@ -103,11 +101,6 @@ impl TakeExecute for SequenceVTable { } } -impl SequenceVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - #[cfg(test)] mod test { use rstest::rstest; @@ -171,7 +164,7 @@ mod test { } #[test] - #[should_panic(expected = "index out of bounds")] + #[should_panic(expected = "out of bounds")] fn test_bounds_check() { let array = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap(); let indices = vortex_array::arrays::PrimitiveArray::from_iter([0i32, 20]); diff --git a/encodings/sequence/src/kernel.rs b/encodings/sequence/src/kernel.rs index 3ce712d137f..982f33cf51b 100644 --- a/encodings/sequence/src/kernel.rs +++ b/encodings/sequence/src/kernel.rs @@ -11,6 +11,7 @@ use vortex_array::arrays::ExactScalarFn; use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::ScalarFnArrayView; use vortex_array::arrays::ScalarFnVTable; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::Operator; use vortex_array::expr::Binary; use vortex_array::kernel::ExecuteParentKernel; @@ -34,6 +35,7 @@ use crate::compute::compare::find_intersection_scalar; pub(crate) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&SequenceCompareKernel), ParentKernelSet::lift(&FilterExecuteAdaptor(SequenceVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(SequenceVTable)), ]); /// Kernel to execute comparison operations directly on a sequence array. diff --git a/encodings/sparse/src/compute/take.rs b/encodings/sparse/src/compute/take.rs index 39f0b76ef1e..3ff7d74b0df 100644 --- a/encodings/sparse/src/compute/take.rs +++ b/encodings/sparse/src/compute/take.rs @@ -7,8 +7,6 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; use crate::SparseArray; @@ -56,11 +54,6 @@ impl TakeExecute for SparseVTable { } } -impl SparseVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - #[cfg(test)] mod test { use rstest::rstest; diff --git a/encodings/sparse/src/kernel.rs b/encodings/sparse/src/kernel.rs index c5cf765c5ac..5b6795b407b 100644 --- a/encodings/sparse/src/kernel.rs +++ b/encodings/sparse/src/kernel.rs @@ -3,6 +3,7 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; +use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::SparseVTable; @@ -10,4 +11,5 @@ use crate::SparseVTable; pub(crate) static PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&FilterExecuteAdaptor(SparseVTable)), ParentKernelSet::lift(&SliceExecuteAdaptor(SparseVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(SparseVTable)), ]); diff --git a/encodings/zigzag/src/array.rs b/encodings/zigzag/src/array.rs index 89e30e3f1c8..19fb3f9ac73 100644 --- a/encodings/zigzag/src/array.rs +++ b/encodings/zigzag/src/array.rs @@ -36,6 +36,7 @@ use vortex_scalar::Scalar; use zigzag::ZigZag as ExternalZigZag; use crate::compute::ZigZagEncoded; +use crate::kernel::PARENT_KERNELS; use crate::rules::RULES; use crate::zigzag_decode; @@ -106,6 +107,15 @@ impl VTable for ZigZagVTable { ) -> VortexResult> { RULES.evaluate(array, parent, child_idx) } + + fn execute_parent( + array: &Self::Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } } #[derive(Clone, Debug)] diff --git a/encodings/zigzag/src/compute/mod.rs b/encodings/zigzag/src/compute/mod.rs index 88d9419d8df..42b9ea2d120 100644 --- a/encodings/zigzag/src/compute/mod.rs +++ b/encodings/zigzag/src/compute/mod.rs @@ -9,11 +9,9 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::MaskKernel; use vortex_array::compute::MaskKernelAdapter; use vortex_array::compute::mask; -use vortex_array::kernel::ParentKernelSet; use vortex_array::register_kernel; use vortex_error::VortexResult; use vortex_mask::Mask; @@ -39,11 +37,6 @@ impl TakeExecute for ZigZagVTable { } } -impl ZigZagVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - impl MaskKernel for ZigZagVTable { fn mask(&self, array: &ZigZagArray, filter_mask: &Mask) -> VortexResult { let encoded = mask(array.encoded(), filter_mask)?; diff --git a/encodings/zigzag/src/kernel.rs b/encodings/zigzag/src/kernel.rs new file mode 100644 index 00000000000..6b93005d0f2 --- /dev/null +++ b/encodings/zigzag/src/kernel.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::kernel::ParentKernelSet; + +use crate::ZigZagVTable; + +pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(ZigZagVTable))]); diff --git a/encodings/zigzag/src/lib.rs b/encodings/zigzag/src/lib.rs index 733139a1a00..89da8bd6069 100644 --- a/encodings/zigzag/src/lib.rs +++ b/encodings/zigzag/src/lib.rs @@ -7,5 +7,6 @@ pub use compress::*; mod array; mod compress; mod compute; +mod kernel; mod rules; mod slice; diff --git a/vortex-array/src/arrays/bool/compute/take.rs b/vortex-array/src/arrays/bool/compute/take.rs index d7dfea29807..e9fb6bd44dc 100644 --- a/vortex-array/src/arrays/bool/compute/take.rs +++ b/vortex-array/src/arrays/bool/compute/take.rs @@ -18,10 +18,8 @@ use crate::arrays::BoolArray; use crate::arrays::BoolVTable; use crate::arrays::ConstantArray; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::compute::fill_null; use crate::executor::ExecutionCtx; -use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; impl TakeExecute for BoolVTable { @@ -51,11 +49,6 @@ impl TakeExecute for BoolVTable { } } -impl BoolVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - fn take_valid_indices>(bools: &BitBuffer, indices: &[I]) -> BitBuffer { // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth // the overhead to convert to a Vec. diff --git a/vortex-array/src/arrays/chunked/compute/take.rs b/vortex-array/src/arrays/chunked/compute/take.rs index 47f9288f004..fc20dcd3b36 100644 --- a/vortex-array/src/arrays/chunked/compute/take.rs +++ b/vortex-array/src/arrays/chunked/compute/take.rs @@ -13,12 +13,10 @@ use crate::ToCanonical; use crate::arrays::ChunkedVTable; use crate::arrays::PrimitiveArray; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::arrays::chunked::ChunkedArray; use crate::compute::cast; use crate::compute::take; use crate::executor::ExecutionCtx; -use crate::kernel::ParentKernelSet; use crate::validity::Validity; fn take_chunked(array: &ChunkedArray, indices: &dyn Array) -> VortexResult { @@ -93,11 +91,6 @@ impl TakeExecute for ChunkedVTable { } } -impl ChunkedVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - #[cfg(test)] mod test { use vortex_buffer::buffer; diff --git a/vortex-array/src/arrays/decimal/compute/take.rs b/vortex-array/src/arrays/decimal/compute/take.rs index c931ee03512..15af95a3ce1 100644 --- a/vortex-array/src/arrays/decimal/compute/take.rs +++ b/vortex-array/src/arrays/decimal/compute/take.rs @@ -14,9 +14,7 @@ use crate::ToCanonical; use crate::arrays::DecimalArray; use crate::arrays::DecimalVTable; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::executor::ExecutionCtx; -use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; impl TakeExecute for DecimalVTable { @@ -44,11 +42,6 @@ impl TakeExecute for DecimalVTable { } } -impl DecimalVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - #[inline] fn take_to_buffer(indices: &[I], values: &[T]) -> Buffer { indices.iter().map(|idx| values[idx.as_()]).collect() diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index 16cb6fe7c33..fe8b053a53f 100644 --- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -18,13 +18,11 @@ use vortex_mask::Mask; use super::DictArray; use super::DictVTable; use super::TakeExecute; -use super::TakeExecuteAdaptor; use crate::Array; use crate::ArrayRef; use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::filter::FilterReduce; -use crate::kernel::ParentKernelSet; impl TakeExecute for DictVTable { fn take( @@ -45,11 +43,6 @@ impl TakeExecute for DictVTable { } } -impl DictVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - impl FilterReduce for DictVTable { fn filter(array: &DictArray, mask: &Mask) -> VortexResult> { let codes = array.codes().filter(mask.clone())?; diff --git a/vortex-array/src/arrays/extension/compute/take.rs b/vortex-array/src/arrays/extension/compute/take.rs index 04cf64b277c..aa53205a5b1 100644 --- a/vortex-array/src/arrays/extension/compute/take.rs +++ b/vortex-array/src/arrays/extension/compute/take.rs @@ -10,8 +10,6 @@ use crate::IntoArray; use crate::arrays::ExtensionArray; use crate::arrays::ExtensionVTable; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; -use crate::kernel::ParentKernelSet; impl TakeExecute for ExtensionVTable { fn take( @@ -31,8 +29,3 @@ impl TakeExecute for ExtensionVTable { )) } } - -impl ExtensionVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} diff --git a/vortex-array/src/arrays/fixed_size_list/compute/take.rs b/vortex-array/src/arrays/fixed_size_list/compute/take.rs index 784e9d19ed3..3d507b19490 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/take.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/take.rs @@ -17,9 +17,7 @@ use crate::arrays::FixedSizeListArray; use crate::arrays::FixedSizeListVTable; use crate::arrays::PrimitiveArray; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::executor::ExecutionCtx; -use crate::kernel::ParentKernelSet; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -41,11 +39,6 @@ impl TakeExecute for FixedSizeListVTable { } } -impl FixedSizeListVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - /// Dispatches to the appropriate take implementation based on list size and nullability. fn take_with_indices( array: &FixedSizeListArray, diff --git a/vortex-array/src/arrays/list/compute/take.rs b/vortex-array/src/arrays/list/compute/take.rs index 01864d6dd75..29235df9d25 100644 --- a/vortex-array/src/arrays/list/compute/take.rs +++ b/vortex-array/src/arrays/list/compute/take.rs @@ -16,11 +16,9 @@ use crate::arrays::ListArray; use crate::arrays::ListVTable; use crate::arrays::PrimitiveArray; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::builders::ArrayBuilder; use crate::builders::PrimitiveBuilder; use crate::executor::ExecutionCtx; -use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; // TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call @@ -52,11 +50,6 @@ impl TakeExecute for ListVTable { } } -impl ListVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - fn _take( array: &ListArray, indices_array: &PrimitiveArray, diff --git a/vortex-array/src/arrays/listview/compute/take.rs b/vortex-array/src/arrays/listview/compute/take.rs index 72f278789e2..faf6340e5f9 100644 --- a/vortex-array/src/arrays/listview/compute/take.rs +++ b/vortex-array/src/arrays/listview/compute/take.rs @@ -14,10 +14,8 @@ use crate::arrays::ListViewArray; use crate::arrays::ListViewRebuildMode; use crate::arrays::ListViewVTable; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::compute; use crate::executor::ExecutionCtx; -use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; // TODO(connor)[ListView]: Make use of this threshold after we start migrating operators. @@ -98,8 +96,3 @@ impl TakeExecute for ListViewVTable { )) } } - -impl ListViewVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} diff --git a/vortex-array/src/arrays/masked/compute/take.rs b/vortex-array/src/arrays/masked/compute/take.rs index 6c544db7f7b..162d6e5699d 100644 --- a/vortex-array/src/arrays/masked/compute/take.rs +++ b/vortex-array/src/arrays/masked/compute/take.rs @@ -11,9 +11,7 @@ use crate::IntoArray; use crate::arrays::MaskedArray; use crate::arrays::MaskedVTable; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::compute::fill_null; -use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; impl TakeExecute for MaskedVTable { @@ -47,11 +45,6 @@ impl TakeExecute for MaskedVTable { } } -impl MaskedVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/vortex-array/src/arrays/primitive/compute/take/mod.rs b/vortex-array/src/arrays/primitive/compute/take/mod.rs index 23793ed9e78..be832f50cb6 100644 --- a/vortex-array/src/arrays/primitive/compute/take/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/take/mod.rs @@ -24,11 +24,9 @@ use crate::IntoArray; use crate::ToCanonical; use crate::arrays::PrimitiveVTable; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::arrays::primitive::PrimitiveArray; use crate::compute::cast; use crate::executor::ExecutionCtx; -use crate::kernel::ParentKernelSet; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -107,11 +105,6 @@ impl TakeExecute for PrimitiveVTable { } } -impl PrimitiveVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - // Compiler may see this as unused based on enabled features #[allow(unused)] #[inline(always)] diff --git a/vortex-array/src/arrays/struct_/compute/take.rs b/vortex-array/src/arrays/struct_/compute/take.rs index a152da021ee..c11ff7dd3a9 100644 --- a/vortex-array/src/arrays/struct_/compute/take.rs +++ b/vortex-array/src/arrays/struct_/compute/take.rs @@ -12,9 +12,7 @@ use crate::IntoArray; use crate::arrays::StructArray; use crate::arrays::StructVTable; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::compute; -use crate::kernel::ParentKernelSet; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -55,8 +53,3 @@ impl TakeExecute for StructVTable { .map(Some) } } - -impl StructVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} diff --git a/vortex-array/src/arrays/varbin/compute/mod.rs b/vortex-array/src/arrays/varbin/compute/mod.rs index 4af8bd61d13..ab5b09ed818 100644 --- a/vortex-array/src/arrays/varbin/compute/mod.rs +++ b/vortex-array/src/arrays/varbin/compute/mod.rs @@ -4,6 +4,7 @@ pub(crate) mod rules; mod slice; pub(crate) use min_max::varbin_compute_min_max; +pub use take::take_into_varbin; mod cast; mod compare; diff --git a/vortex-array/src/arrays/varbin/compute/take.rs b/vortex-array/src/arrays/varbin/compute/take.rs index 634eb22b800..e4bbe1f7813 100644 --- a/vortex-array/src/arrays/varbin/compute/take.rs +++ b/vortex-array/src/arrays/varbin/compute/take.rs @@ -18,109 +18,108 @@ use crate::IntoArray; use crate::ToCanonical; use crate::arrays::PrimitiveArray; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::arrays::VarBinVTable; use crate::arrays::varbin::VarBinArray; use crate::executor::ExecutionCtx; -use crate::kernel::ParentKernelSet; use crate::validity::Validity; impl TakeExecute for VarBinVTable { - #[expect( - clippy::redundant_clone, - reason = "macro expansion causes false positive - only one match arm executes" - )] fn take( array: &VarBinArray, indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - let offsets = array.offsets().to_primitive(); - let data = array.bytes(); - let indices = indices.to_primitive(); - let dtype = array - .dtype() - .clone() - .union_nullability(indices.dtype().nullability()); - let result = match_each_integer_ptype!(indices.ptype(), |I| { - // On take, offsets get widened to either 32- or 64-bit based on the original type, - // to avoid overflow issues. - match offsets.ptype() { - PType::U8 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U16 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U32 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U64 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I8 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I16 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I32 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I64 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - _ => unreachable!("invalid PType for offsets"), - } - }); - - Ok(Some(result?.into_array())) + Ok(Some(take_into_varbin(array, indices)?.into_array())) } } -impl VarBinVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); +/// Take elements from a VarBinArray and return a new VarBinArray. +/// +/// Unlike `Array::take` which may canonicalize to VarBinView, this function +/// guarantees the result is a VarBinArray. +#[expect( + clippy::redundant_clone, + reason = "macro expansion causes false positive - only one match arm executes" +)] +pub fn take_into_varbin(array: &VarBinArray, indices: &dyn Array) -> VortexResult { + let offsets = array.offsets().to_primitive(); + let data = array.bytes(); + let indices = indices.to_primitive(); + let dtype = array + .dtype() + .clone() + .union_nullability(indices.dtype().nullability()); + match_each_integer_ptype!(indices.ptype(), |I| { + // On take, offsets get widened to either 32- or 64-bit based on the original type, + // to avoid overflow issues. + match offsets.ptype() { + PType::U8 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::U16 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::U32 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::U64 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I8 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I16 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I32 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + PType::I64 => take_impl::( + dtype.clone(), + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array.validity_mask()?, + indices.validity_mask()?, + ), + _ => unreachable!("invalid PType for offsets"), + } + }) } fn take_impl( diff --git a/vortex-array/src/arrays/varbin/mod.rs b/vortex-array/src/arrays/varbin/mod.rs index 127f938fe2c..fbfd7cf16d7 100644 --- a/vortex-array/src/arrays/varbin/mod.rs +++ b/vortex-array/src/arrays/varbin/mod.rs @@ -5,8 +5,8 @@ mod array; pub use array::VarBinArray; pub(crate) mod compute; +pub use compute::take_into_varbin; pub(crate) use compute::varbin_compute_min_max; -// For use in `varbinview`. mod vtable; pub use vtable::VarBinVTable; diff --git a/vortex-array/src/arrays/varbinview/compute/take.rs b/vortex-array/src/arrays/varbinview/compute/take.rs index ee027d3a5b3..48dc08db32f 100644 --- a/vortex-array/src/arrays/varbinview/compute/take.rs +++ b/vortex-array/src/arrays/varbinview/compute/take.rs @@ -16,12 +16,10 @@ use crate::ArrayRef; use crate::IntoArray; use crate::ToCanonical; use crate::arrays::TakeExecute; -use crate::arrays::TakeExecuteAdaptor; use crate::arrays::VarBinViewArray; use crate::arrays::VarBinViewVTable; use crate::buffer::BufferHandle; use crate::executor::ExecutionCtx; -use crate::kernel::ParentKernelSet; use crate::vtable::ValidityHelper; impl TakeExecute for VarBinViewVTable { @@ -56,11 +54,6 @@ impl TakeExecute for VarBinViewVTable { } } -impl VarBinViewVTable { - pub const TAKE_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor::(Self))]); -} - fn take_views>( views_ref: &[BinaryView], indices: &[I], From f455669bab3cefa147f647589ae08b03070669d1 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 17:42:12 +0000 Subject: [PATCH 19/21] wip Signed-off-by: Joe Isaacs --- encodings/fsst/src/compute/mod.rs | 9 +- vortex-array/src/arrays/varbin/compute/mod.rs | 1 - .../src/arrays/varbin/compute/take.rs | 176 +++++++++--------- vortex-array/src/arrays/varbin/mod.rs | 1 - 4 files changed, 92 insertions(+), 95 deletions(-) diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 7256cf31122..b2657a73bae 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -10,9 +10,11 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::TakeExecute; -use vortex_array::arrays::take_into_varbin; +use vortex_array::arrays::VarBinVTable; use vortex_array::compute::fill_null; +use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_err; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -33,7 +35,10 @@ impl TakeExecute for FSSTVTable { .union_nullability(indices.dtype().nullability()), array.symbols().clone(), array.symbol_lengths().clone(), - take_into_varbin(array.codes(), indices)?, + VarBinVTable::take(array.codes(), indices, _ctx)? + .vortex_expect("cannot fail") + .try_into::() + .map_err(|_| vortex_err!("take for codes must return varbin array"))?, fill_null( &array.uncompressed_lengths().take(indices.to_array())?, &Scalar::new( diff --git a/vortex-array/src/arrays/varbin/compute/mod.rs b/vortex-array/src/arrays/varbin/compute/mod.rs index ab5b09ed818..4af8bd61d13 100644 --- a/vortex-array/src/arrays/varbin/compute/mod.rs +++ b/vortex-array/src/arrays/varbin/compute/mod.rs @@ -4,7 +4,6 @@ pub(crate) mod rules; mod slice; pub(crate) use min_max::varbin_compute_min_max; -pub use take::take_into_varbin; mod cast; mod compare; diff --git a/vortex-array/src/arrays/varbin/compute/take.rs b/vortex-array/src/arrays/varbin/compute/take.rs index e4bbe1f7813..a1719f2dfb8 100644 --- a/vortex-array/src/arrays/varbin/compute/take.rs +++ b/vortex-array/src/arrays/varbin/compute/take.rs @@ -29,100 +29,94 @@ impl TakeExecute for VarBinVTable { indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - Ok(Some(take_into_varbin(array, indices)?.into_array())) + // TODO(joe): Be lazy with execute + let offsets = array.offsets().to_primitive(); + let data = array.bytes(); + let indices = indices.to_primitive(); + let dtype = array + .dtype() + .clone() + .union_nullability(indices.dtype().nullability()); + let array_validity = array.validity_mask()?; + let indices_validity = indices.validity_mask()?; + + let array = match_each_integer_ptype!(indices.ptype(), |I| { + // On take, offsets get widened to either 32- or 64-bit based on the original type, + // to avoid overflow issues. + match offsets.ptype() { + PType::U8 => take::( + dtype, + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array_validity, + indices_validity, + ), + PType::U16 => take::( + dtype, + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array_validity, + indices_validity, + ), + PType::U32 => take::( + dtype, + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array_validity, + indices_validity, + ), + PType::U64 => take::( + dtype, + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array_validity, + indices_validity, + ), + PType::I8 => take::( + dtype, + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array_validity, + indices_validity, + ), + PType::I16 => take::( + dtype, + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array_validity, + indices_validity, + ), + PType::I32 => take::( + dtype, + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array_validity, + indices_validity, + ), + PType::I64 => take::( + dtype, + offsets.as_slice::(), + data.as_slice(), + indices.as_slice::(), + array_validity, + indices_validity, + ), + _ => unreachable!("invalid PType for offsets"), + } + }); + + Ok(Some(array?.into_array())) } } -/// Take elements from a VarBinArray and return a new VarBinArray. -/// -/// Unlike `Array::take` which may canonicalize to VarBinView, this function -/// guarantees the result is a VarBinArray. -#[expect( - clippy::redundant_clone, - reason = "macro expansion causes false positive - only one match arm executes" -)] -pub fn take_into_varbin(array: &VarBinArray, indices: &dyn Array) -> VortexResult { - let offsets = array.offsets().to_primitive(); - let data = array.bytes(); - let indices = indices.to_primitive(); - let dtype = array - .dtype() - .clone() - .union_nullability(indices.dtype().nullability()); - match_each_integer_ptype!(indices.ptype(), |I| { - // On take, offsets get widened to either 32- or 64-bit based on the original type, - // to avoid overflow issues. - match offsets.ptype() { - PType::U8 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U16 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U32 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::U64 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I8 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I16 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I32 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - PType::I64 => take_impl::( - dtype.clone(), - offsets.as_slice::(), - data.as_slice(), - indices.as_slice::(), - array.validity_mask()?, - indices.validity_mask()?, - ), - _ => unreachable!("invalid PType for offsets"), - } - }) -} - -fn take_impl( +fn take( dtype: DType, offsets: &[Offset], data: &[u8], diff --git a/vortex-array/src/arrays/varbin/mod.rs b/vortex-array/src/arrays/varbin/mod.rs index fbfd7cf16d7..fd4806b3206 100644 --- a/vortex-array/src/arrays/varbin/mod.rs +++ b/vortex-array/src/arrays/varbin/mod.rs @@ -5,7 +5,6 @@ mod array; pub use array::VarBinArray; pub(crate) mod compute; -pub use compute::take_into_varbin; pub(crate) use compute::varbin_compute_min_max; mod vtable; From 0d8cca3b64de499094201b5b6d8c01a816300dd0 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 17:45:19 +0000 Subject: [PATCH 20/21] wip Signed-off-by: Joe Isaacs --- encodings/bytebool/src/array.rs | 4 +++- encodings/sparse/src/lib.rs | 3 ++- vortex-array/src/arrays/bool/vtable/mod.rs | 3 ++- vortex-array/src/arrays/decimal/vtable/mod.rs | 3 ++- vortex-array/src/arrays/dict/vtable/mod.rs | 3 ++- vortex-array/src/arrays/extension/vtable/mod.rs | 3 ++- .../src/arrays/fixed_size_list/vtable/kernel.rs | 10 ++++++---- vortex-array/src/arrays/fixed_size_list/vtable/mod.rs | 2 +- vortex-array/src/arrays/listview/vtable/mod.rs | 3 ++- vortex-array/src/arrays/varbin/vtable/mod.rs | 3 ++- 10 files changed, 24 insertions(+), 13 deletions(-) diff --git a/encodings/bytebool/src/array.rs b/encodings/bytebool/src/array.rs index 1301efea03b..0e7fd5c2a75 100644 --- a/encodings/bytebool/src/array.rs +++ b/encodings/bytebool/src/array.rs @@ -37,6 +37,8 @@ use vortex_error::vortex_ensure; use vortex_error::vortex_panic; use vortex_scalar::Scalar; +use crate::kernel::PARENT_KERNELS; + vtable!(ByteBool); impl VTable for ByteBoolVTable { @@ -125,7 +127,7 @@ impl VTable for ByteBoolVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - crate::kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } } diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index 0c2cdac3dfa..ce998c7bd33 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -4,6 +4,7 @@ use std::fmt::Debug; use std::hash::Hash; +use kernel::PARENT_KERNELS; use prost::Message as _; use vortex_array::Array; use vortex_array::ArrayBufferVisitor; @@ -157,7 +158,7 @@ impl VTable for SparseVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { diff --git a/vortex-array/src/arrays/bool/vtable/mod.rs b/vortex-array/src/arrays/bool/vtable/mod.rs index 518efe8e645..a883cc265a6 100644 --- a/vortex-array/src/arrays/bool/vtable/mod.rs +++ b/vortex-array/src/arrays/bool/vtable/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -128,7 +129,7 @@ impl VTable for BoolVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } } diff --git a/vortex-array/src/arrays/decimal/vtable/mod.rs b/vortex-array/src/arrays/decimal/vtable/mod.rs index b2fd347cd78..445776c4bdc 100644 --- a/vortex-array/src/arrays/decimal/vtable/mod.rs +++ b/vortex-array/src/arrays/decimal/vtable/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use kernel::PARENT_KERNELS; use vortex_buffer::Alignment; use vortex_dtype::DType; use vortex_dtype::NativeDecimalType; @@ -145,7 +146,7 @@ impl VTable for DecimalVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } } diff --git a/vortex-array/src/arrays/dict/vtable/mod.rs b/vortex-array/src/arrays/dict/vtable/mod.rs index 2fd80bb4904..7fca42b18c3 100644 --- a/vortex-array/src/arrays/dict/vtable/mod.rs +++ b/vortex-array/src/arrays/dict/vtable/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; @@ -160,7 +161,7 @@ impl VTable for DictVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } } diff --git a/vortex-array/src/arrays/extension/vtable/mod.rs b/vortex-array/src/arrays/extension/vtable/mod.rs index 0d6d2672bf5..056909b94ba 100644 --- a/vortex-array/src/arrays/extension/vtable/mod.rs +++ b/vortex-array/src/arrays/extension/vtable/mod.rs @@ -8,6 +8,7 @@ mod operations; mod validity; mod visitor; +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -102,7 +103,7 @@ impl VTable for ExtensionVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } } diff --git a/vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs b/vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs index e82c90cd4ba..6d38dccf2fa 100644 --- a/vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs +++ b/vortex-array/src/arrays/fixed_size_list/vtable/kernel.rs @@ -5,7 +5,9 @@ use crate::arrays::FixedSizeListVTable; use crate::arrays::TakeExecuteAdaptor; use crate::kernel::ParentKernelSet; -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor( - FixedSizeListVTable, - ))]); +impl FixedSizeListVTable { + pub(crate) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor( + FixedSizeListVTable, + ))]); +} diff --git a/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs b/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs index 1c3bb11d623..17fe2d8c537 100644 --- a/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs +++ b/vortex-array/src/arrays/fixed_size_list/vtable/mod.rs @@ -63,7 +63,7 @@ impl VTable for FixedSizeListVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + Self::PARENT_KERNELS.execute(array, parent, child_idx, ctx) } fn metadata(_array: &FixedSizeListArray) -> VortexResult { diff --git a/vortex-array/src/arrays/listview/vtable/mod.rs b/vortex-array/src/arrays/listview/vtable/mod.rs index 6ba12435077..d1984153fd0 100644 --- a/vortex-array/src/arrays/listview/vtable/mod.rs +++ b/vortex-array/src/arrays/listview/vtable/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use kernel::PARENT_KERNELS; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; @@ -178,6 +179,6 @@ impl VTable for ListViewVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } } diff --git a/vortex-array/src/arrays/varbin/vtable/mod.rs b/vortex-array/src/arrays/varbin/vtable/mod.rs index fa92e528f27..106c0305fdb 100644 --- a/vortex-array/src/arrays/varbin/vtable/mod.rs +++ b/vortex-array/src/arrays/varbin/vtable/mod.rs @@ -32,6 +32,7 @@ mod validity; mod visitor; use canonical::varbin_to_canonical; +use kernel::PARENT_KERNELS; use crate::arrays::varbin::compute::rules::PARENT_RULES; @@ -141,7 +142,7 @@ impl VTable for VarBinVTable { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - kernel::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { From 0e68b7f8c47e5133cba627d245b727ffeb323c7f Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 6 Feb 2026 18:58:30 +0000 Subject: [PATCH 21/21] wip Signed-off-by: Joe Isaacs --- vortex-array/src/patches.rs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index 96778bceda2..06d35912074 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -1044,9 +1044,11 @@ where AllOr::None => true, AllOr::Some(buf) => !buf.value(idx_in_take), }; - if include_nulls && is_null { - new_sparse_indices.push(idx_in_take as u64); - value_indices.push(0); + if is_null { + if include_nulls { + new_sparse_indices.push(idx_in_take as u64); + value_indices.push(0); + } } else if ti >= min_index && ti <= max_index { let ti_as_i = I::try_from(ti) .map_err(|_| vortex_err!("take index does not fit in index type"))?; @@ -1185,10 +1187,13 @@ fn take_indices_with_search_fn< let mut new_indices = BufferMut::with_capacity(take_indices.len()); for (new_patch_idx, &take_idx) in take_indices.iter().enumerate() { - if include_nulls && !take_validity.value(new_patch_idx) { - // For nulls, patch index doesn't matter - use 0 for consistency - values_indices.push(0u64); - new_indices.push(new_patch_idx as u64); + if !take_validity.value(new_patch_idx) { + if include_nulls { + // For nulls, patch index doesn't matter - use 0 for consistency + values_indices.push(0u64); + new_indices.push(new_patch_idx as u64); + } + continue; } else { let search_result = match I::from(take_idx) { Some(idx) => search_fn(idx)?,