diff --git a/Cargo.lock b/Cargo.lock index 6bf31b33898..a21f2a03e23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10640,6 +10640,7 @@ dependencies = [ name = "vortex-error" version = "0.1.0" dependencies = [ + "arbitrary", "arrow-schema 57.2.0", "flatbuffers", "jiff", diff --git a/encodings/alp/src/alp/ops.rs b/encodings/alp/src/alp/ops.rs index d17b0dddf02..af1a3a72bb6 100644 --- a/encodings/alp/src/alp/ops.rs +++ b/encodings/alp/src/alp/ops.rs @@ -22,10 +22,8 @@ impl OperationsVTable for ALPVTable { let encoded_val = array.encoded().scalar_at(index)?; Ok(match_each_alp_float_ptype!(array.ptype(), |T| { - let encoded_val: ::ALPInt = encoded_val - .as_ref() - .try_into() - .vortex_expect("invalid ALPInt"); + let encoded_val: ::ALPInt = + (&encoded_val).try_into().vortex_expect("invalid ALPInt"); Scalar::primitive( ::decode_single(encoded_val, array.exponents()), array.dtype().nullability(), diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index b2a395c0813..7235a7a1393 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -9,7 +9,6 @@ use vortex_array::arrays::TakeExecute; use vortex_array::compute::fill_null; use vortex_error::VortexResult; use vortex_scalar::Scalar; -use vortex_scalar::ScalarValue; use crate::ALPRDArray; use crate::ALPRDVTable; @@ -36,7 +35,7 @@ impl TakeExecute for ALPRDVTable { .transpose()?; let right_parts = fill_null( &array.right_parts().take(indices.to_array())?, - &Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)), + &Scalar::zero_value(&array.right_parts().dtype().clone()), )?; Ok(Some( diff --git a/encodings/datetime-parts/src/compute/rules.rs b/encodings/datetime-parts/src/compute/rules.rs index 9702927c9ad..a7afbca64dc 100644 --- a/encodings/datetime-parts/src/compute/rules.rs +++ b/encodings/datetime-parts/src/compute/rules.rs @@ -171,7 +171,7 @@ fn try_extract_days_constant(array: &ArrayRef) -> Option { fn is_constant_zero(array: &ArrayRef) -> bool { array .as_opt::() - .is_some_and(|c| c.scalar().is_zero()) + .is_some_and(|c| c.scalar().value().is_some_and(|value| value.is_zero())) } #[cfg(test)] diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs index ca7ff934be7..2ff776c36b3 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs @@ -46,9 +46,9 @@ impl CompareKernel for DecimalBytePartsVTable { .vortex_expect("checked for null in entry func"); match decimal_value_wrapper_to_primitive(rhs_decimal, lhs.msp.as_primitive_typed().ptype()) - .map(|value| Scalar::new(scalar_type.clone(), value)) { - Ok(encoded_scalar) => { + Ok(value) => { + let encoded_scalar = Scalar::try_new(scalar_type, Some(value))?; let encoded_const = ConstantArray::new(encoded_scalar, rhs.len()); compare(&lhs.msp, &encoded_const.to_array(), operator).map(Some) } @@ -165,7 +165,10 @@ mod tests { ) .unwrap() .to_array(); - let rhs = ConstantArray::new(Scalar::new(dtype, DecimalValue::I64(400).into()), lhs.len()); + let rhs = ConstantArray::new( + Scalar::try_new(dtype, Some(DecimalValue::I64(400).into())).unwrap(), + lhs.len(), + ); let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); @@ -215,10 +218,11 @@ mod tests { .to_array(); // This cannot be converted to a i32. let rhs = ConstantArray::new( - Scalar::new( + Scalar::try_new( dtype.clone(), - DecimalValue::I128(-9999999999999965304).into(), - ), + Some(DecimalValue::I128(-9999999999999965304).into()), + ) + .unwrap(), lhs.len(), ); @@ -236,7 +240,7 @@ mod tests { // This cannot be converted to a i32. let rhs = ConstantArray::new( - Scalar::new(dtype, DecimalValue::I128(9999999999999965304).into()), + Scalar::try_new(dtype, Some(DecimalValue::I128(9999999999999965304).into())).unwrap(), lhs.len(), ); diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs index 0e4de81842d..4f6c61268ca 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs @@ -44,6 +44,7 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; +use vortex_scalar::ScalarValue; use vortex_session::VortexSession; use crate::decimal_byte_parts::compute::kernel::PARENT_KERNELS; @@ -285,10 +286,10 @@ impl OperationsVTable for DecimalBytePartsVTable { let primitive_scalar = scalar.as_primitive(); // TODO(joe): extend this to support multiple parts. let value = primitive_scalar.as_::().vortex_expect("non-null"); - Ok(Scalar::new( + Scalar::try_new( array.dtype.clone(), - DecimalValue::I64(value).into(), - )) + Some(ScalarValue::Decimal(DecimalValue::I64(value))), + ) } } @@ -319,6 +320,7 @@ mod tests { use vortex_dtype::Nullability; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; + use vortex_scalar::ScalarValue; use crate::DecimalBytePartsArray; @@ -339,11 +341,15 @@ mod tests { assert_eq!(Scalar::null(dtype.clone()), array.scalar_at(0).unwrap()); assert_eq!( - Scalar::new(dtype.clone(), DecimalValue::I64(200).into()), + Scalar::try_new( + dtype.clone(), + Some(ScalarValue::Decimal(DecimalValue::I64(200))) + ) + .unwrap(), array.scalar_at(1).unwrap() ); assert_eq!( - Scalar::new(dtype, DecimalValue::I64(400).into()), + Scalar::try_new(dtype, Some(ScalarValue::Decimal(DecimalValue::I64(400)))).unwrap(), array.scalar_at(2).unwrap() ); } diff --git a/encodings/fastlanes/src/for/array/for_compress.rs b/encodings/fastlanes/src/for/array/for_compress.rs index 4b20acc5b77..350842560be 100644 --- a/encodings/fastlanes/src/for/array/for_compress.rs +++ b/encodings/fastlanes/src/for/array/for_compress.rs @@ -175,10 +175,7 @@ mod test { .iter() .enumerate() .for_each(|(i, v)| { - assert_eq!( - *v, - i8::try_from(compressed.scalar_at(i).unwrap().as_ref()).unwrap() - ); + assert_eq!(*v, i8::try_from(compressed.scalar_at(i).unwrap()).unwrap()); }); assert_arrays_eq!(decompressed, array); Ok(()) diff --git a/encodings/fastlanes/src/for/vtable/mod.rs b/encodings/fastlanes/src/for/vtable/mod.rs index 8808b34589f..7c15e491d2a 100644 --- a/encodings/fastlanes/src/for/vtable/mod.rs +++ b/encodings/fastlanes/src/for/vtable/mod.rs @@ -2,13 +2,10 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::fmt::Debug; -use std::fmt::Formatter; use vortex_array::ArrayRef; -use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::SerializeMetadata; use vortex_array::buffer::BufferHandle; use vortex_array::serde::ArrayChildren; use vortex_array::vtable; @@ -41,7 +38,9 @@ vtable!(FoR); impl VTable for FoRVTable { type Array = FoRArray; - type Metadata = ScalarValueMetadata; + // TODO(connor): This should really be a `Scalar` but we need to deprecate `deserialize` for the + // `build` method. + type Metadata = Vec; type ArrayVTable = Self; type OperationsVTable = Self; @@ -67,14 +66,16 @@ impl VTable for FoRVTable { Ok(()) } + // TODO(connor): DON'T TOUCH THIS UNLESS YOU KNOW WHAT YOU ARE DOING!!! fn metadata(array: &FoRArray) -> VortexResult { - Ok(ScalarValueMetadata( - array.reference_scalar().value().clone(), + Ok(ScalarValue::to_proto_bytes( + array.reference_scalar().value(), )) } + // TODO(connor): DON'T TOUCH THIS UNLESS YOU KNOW WHAT YOU ARE DOING!!! fn serialize(metadata: Self::Metadata) -> VortexResult>> { - Ok(Some(metadata.serialize())) + Ok(Some(metadata)) } fn deserialize( @@ -83,7 +84,7 @@ impl VTable for FoRVTable { _len: usize, _session: &VortexSession, ) -> VortexResult { - ScalarValueMetadata::deserialize(bytes) + Ok(bytes.to_vec()) } fn build( @@ -101,7 +102,8 @@ impl VTable for FoRVTable { } let encoded = children.get(0, dtype, len)?; - let reference = Scalar::new(dtype.clone(), metadata.0.clone()); + let scalar_value = ScalarValue::from_proto_bytes(metadata, dtype)?; + let reference = Scalar::try_new(dtype.clone(), scalar_value)?; FoRArray::try_new(encoded, reference) } @@ -134,27 +136,3 @@ pub struct FoRVTable; impl FoRVTable { pub const ID: ArrayId = ArrayId::new_ref("fastlanes.for"); } - -#[derive(Clone)] -pub struct ScalarValueMetadata(pub ScalarValue); - -impl SerializeMetadata for ScalarValueMetadata { - fn serialize(self) -> Vec { - self.0.to_protobytes() - } -} - -impl DeserializeMetadata for ScalarValueMetadata { - type Output = ScalarValueMetadata; - - fn deserialize(metadata: &[u8]) -> VortexResult { - let scalar_value = ScalarValue::from_protobytes(metadata)?; - Ok(ScalarValueMetadata(scalar_value)) - } -} - -impl Debug for ScalarValueMetadata { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", &self.0) - } -} diff --git a/encodings/fastlanes/src/rle/vtable/operations.rs b/encodings/fastlanes/src/rle/vtable/operations.rs index 10af0cb6f00..2ea20a3b8a1 100644 --- a/encodings/fastlanes/src/rle/vtable/operations.rs +++ b/encodings/fastlanes/src/rle/vtable/operations.rs @@ -27,7 +27,7 @@ impl OperationsVTable for RLEVTable { .values() .scalar_at(value_idx_offset + chunk_relative_idx)?; - Ok(Scalar::new(array.dtype().clone(), scalar.into_value())) + Scalar::try_new(array.dtype().clone(), scalar.into_value()) } } diff --git a/encodings/fsst/src/compute/compare.rs b/encodings/fsst/src/compute/compare.rs index 9119725c655..a3852422761 100644 --- a/encodings/fsst/src/compute/compare.rs +++ b/encodings/fsst/src/compute/compare.rs @@ -104,16 +104,16 @@ fn compare_fsst_constant( DType::Binary(_) => { let value = right .as_binary() - .value() + .to_value() .vortex_expect("Expected non-null scalar"); ByteBuffer::from(compressor.compress(value.as_slice())) } _ => unreachable!("FSSTArray can only have string or binary data type"), }; - let encoded_scalar = Scalar::new( - DType::Binary(left.dtype().nullability() | right.dtype().nullability()), - encoded_buffer.into(), + let encoded_scalar = Scalar::binary( + encoded_buffer, + left.dtype().nullability() | right.dtype().nullability(), ); let rhs = ConstantArray::new(encoded_scalar, left.len()); diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index b2657a73bae..2d8bb58f48c 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -16,7 +16,6 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_err; use vortex_scalar::Scalar; -use vortex_scalar::ScalarValue; use crate::FSSTArray; use crate::FSSTVTable; @@ -41,10 +40,7 @@ impl TakeExecute for FSSTVTable { .map_err(|_| vortex_err!("take for codes must return varbin array"))?, fill_null( &array.uncompressed_lengths().take(indices.to_array())?, - &Scalar::new( - array.uncompressed_lengths_dtype().clone(), - ScalarValue::from(0), - ), + &Scalar::zero_value(&array.uncompressed_lengths_dtype().clone()), )?, )? .into_array(), diff --git a/encodings/fsst/src/ops.rs b/encodings/fsst/src/ops.rs index 9b69b743a36..d1cb618647f 100644 --- a/encodings/fsst/src/ops.rs +++ b/encodings/fsst/src/ops.rs @@ -14,7 +14,7 @@ use crate::FSSTVTable; impl OperationsVTable for FSSTVTable { fn scalar_at(array: &FSSTArray, index: usize) -> VortexResult { let compressed = array.codes().scalar_at(index)?; - let binary_datum = compressed.as_binary().value().vortex_expect("non-null"); + let binary_datum = compressed.as_binary().to_value().vortex_expect("non-null"); let decoded_buffer = ByteBuffer::from(array.decompressor().decompress(binary_datum.as_slice())); diff --git a/encodings/runend/src/array.rs b/encodings/runend/src/array.rs index 242a3d12c9b..f4bf8eaddac 100644 --- a/encodings/runend/src/array.rs +++ b/encodings/runend/src/array.rs @@ -222,13 +222,13 @@ impl RunEndArray { // Validate the offset and length are valid for the given ends and values if offset != 0 && length != 0 { - let first_run_end: usize = ends.scalar_at(0)?.as_ref().try_into()?; + let first_run_end: usize = usize::try_from(&ends.scalar_at(0)?)?; if first_run_end <= offset { vortex_bail!("First run end {first_run_end} must be bigger than offset {offset}"); } } - let last_run_end: usize = ends.scalar_at(ends.len() - 1)?.as_ref().try_into()?; + let last_run_end: usize = usize::try_from(&ends.scalar_at(ends.len() - 1)?)?; let min_required_end = offset + length; if last_run_end < min_required_end { vortex_bail!("Last run end {last_run_end} must be >= offset+length {min_required_end}"); @@ -302,7 +302,7 @@ impl RunEndArray { let length: usize = if ends.is_empty() { 0 } else { - ends.scalar_at(ends.len() - 1)?.as_ref().try_into()? + usize::try_from(&ends.scalar_at(ends.len() - 1)?)? }; Self::try_new_offset_length(ends, values, 0, length) diff --git a/encodings/runend/src/compute/cast.rs b/encodings/runend/src/compute/cast.rs index 35316316972..6b35c7b2183 100644 --- a/encodings/runend/src/compute/cast.rs +++ b/encodings/runend/src/compute/cast.rs @@ -76,27 +76,19 @@ mod tests { // RunEnd encoding should expand to [100, 100, 100, 200, 200, 100, 100, 100, 300, 300] assert_eq!(decoded.len(), 10); assert_eq!( - TryInto::::try_into(decoded.scalar_at(0).unwrap().as_ref()) - .ok() - .unwrap(), + TryInto::::try_into(decoded.scalar_at(0).unwrap()).unwrap(), 100i64 ); assert_eq!( - TryInto::::try_into(decoded.scalar_at(3).unwrap().as_ref()) - .ok() - .unwrap(), + TryInto::::try_into(decoded.scalar_at(3).unwrap()).unwrap(), 200i64 ); assert_eq!( - TryInto::::try_into(decoded.scalar_at(5).unwrap().as_ref()) - .ok() - .unwrap(), + TryInto::::try_into(decoded.scalar_at(5).unwrap()).unwrap(), 100i64 ); assert_eq!( - TryInto::::try_into(decoded.scalar_at(8).unwrap().as_ref()) - .ok() - .unwrap(), + TryInto::::try_into(decoded.scalar_at(8).unwrap()).unwrap(), 300i64 ); } diff --git a/encodings/sequence/src/array.rs b/encodings/sequence/src/array.rs index 2a9d0983c6f..28e83e10ccc 100644 --- a/encodings/sequence/src/array.rs +++ b/encodings/sequence/src/array.rs @@ -245,28 +245,26 @@ impl VTable for SequenceVTable { let ptype = dtype.as_ptype(); // We go via scalar to cast the scalar values into the correct PType - let base = Scalar::new( - DType::Primitive(ptype, NonNullable), + let base = Scalar::from_proto_value( metadata .0 .base .as_ref() - .ok_or_else(|| vortex_err!("base required"))? - .try_into()?, - ) + .ok_or_else(|| vortex_err!("base required"))?, + &DType::Primitive(ptype, NonNullable), + )? .as_primitive() .pvalue() .vortex_expect("non-nullable primitive"); - let multiplier = Scalar::new( - DType::Primitive(ptype, NonNullable), + let multiplier = Scalar::from_proto_value( metadata .0 .multiplier .as_ref() - .ok_or_else(|| vortex_err!("base required"))? - .try_into()?, - ) + .ok_or_else(|| vortex_err!("multiplier required"))?, + &DType::Primitive(ptype, NonNullable), + )? .as_primitive() .pvalue() .vortex_expect("non-nullable primitive"); @@ -355,10 +353,10 @@ impl BaseArrayVTable for SequenceVTable { impl OperationsVTable for SequenceVTable { fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult { - Ok(Scalar::new( + Scalar::try_new( array.dtype().clone(), - ScalarValue::from(array.index_value(index)), - )) + Some(ScalarValue::Primitive(array.index_value(index))), + ) } } @@ -423,7 +421,7 @@ mod tests { assert_eq!( scalar, - Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64)) + Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap() ) } diff --git a/encodings/sequence/src/compute/cast.rs b/encodings/sequence/src/compute/cast.rs index 68725d9e6a3..20cfebcf46f 100644 --- a/encodings/sequence/src/compute/cast.rs +++ b/encodings/sequence/src/compute/cast.rs @@ -48,14 +48,14 @@ impl CastKernel for SequenceVTable { // For type changes, we need to cast the base and multiplier if array.ptype() != *target_ptype { // Create scalars from PValues and cast them - let base_scalar = Scalar::new( + let base_scalar = Scalar::try_new( DType::Primitive(array.ptype(), Nullability::NonNullable), - ScalarValue::from(array.base()), - ); - let multiplier_scalar = Scalar::new( + Some(ScalarValue::Primitive(array.base())), + )?; + let multiplier_scalar = Scalar::try_new( DType::Primitive(array.ptype(), Nullability::NonNullable), - ScalarValue::from(array.multiplier()), - ); + Some(ScalarValue::Primitive(array.multiplier())), + )?; let new_base_scalar = base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?; diff --git a/encodings/sequence/src/compute/compare.rs b/encodings/sequence/src/compute/compare.rs index 6fa11d534a5..fc652216a15 100644 --- a/encodings/sequence/src/compute/compare.rs +++ b/encodings/sequence/src/compute/compare.rs @@ -9,7 +9,6 @@ use vortex_array::compute::CompareKernel; use vortex_array::compute::Operator; use vortex_array::validity::Validity; use vortex_buffer::BitBuffer; -use vortex_dtype::DType; use vortex_dtype::NativePType; use vortex_dtype::Nullability; use vortex_dtype::match_each_integer_ptype; @@ -58,11 +57,7 @@ impl CompareKernel for SequenceVTable { Ok(Some(BoolArray::new(buffer, validity).to_array())) } else { Ok(Some( - ConstantArray::new( - Scalar::new(DType::Bool(nullability), false.into()), - lhs.len(), - ) - .to_array(), + ConstantArray::new(Scalar::bool(false, nullability), lhs.len()).to_array(), )) } } diff --git a/encodings/sequence/src/kernel.rs b/encodings/sequence/src/kernel.rs index 982f33cf51b..6dacc4d01ca 100644 --- a/encodings/sequence/src/kernel.rs +++ b/encodings/sequence/src/kernel.rs @@ -129,22 +129,15 @@ fn compare_eq_neq( find_intersection_scalar(array.base(), array.multiplier(), array.len, constant) else { return Ok(Some( - ConstantArray::new( - Scalar::new(DType::Bool(nullability), not_match_val.into()), - array.len, - ) - .into_array(), + ConstantArray::new(Scalar::bool(not_match_val, nullability), array.len).into_array(), )); }; let idx = set_idx as u64; let len = array.len as u64; if len == 1 && set_idx == 0 { - let result_array = ConstantArray::new( - Scalar::new(DType::Bool(nullability), match_val.into()), - array.len, - ) - .to_array(); + let result_array = + ConstantArray::new(Scalar::bool(match_val, nullability), array.len).to_array(); return Ok(Some(result_array)); } @@ -186,16 +179,12 @@ fn compare_ordering( ); let result_array = match transition { - Transition::AllTrue => ConstantArray::new( - Scalar::new(DType::Bool(nullability), true.into()), - array.len, - ) - .to_array(), - Transition::AllFalse => ConstantArray::new( - Scalar::new(DType::Bool(nullability), false.into()), - array.len, - ) - .to_array(), + Transition::AllTrue => { + ConstantArray::new(Scalar::bool(true, nullability), array.len).to_array() + } + Transition::AllFalse => { + ConstantArray::new(Scalar::bool(false, nullability), array.len).to_array() + } Transition::FalseToTrue(idx) => { // [0..idx) is false, [idx..len) is true let ends = buffer![idx as u64, array.len as u64].into_array(); @@ -362,10 +351,11 @@ mod tests { fn test_sequence_gte_constant() -> VortexResult<()> { let seq = SequenceArray::typed_new(0i64, 1, NonNullable, 10)?.to_array(); let constant = ConstantArray::new( - Scalar::new( + Scalar::try_new( DType::Primitive(PType::I64, Nullability::Nullable), - 5i64.into(), - ), + Some(5i64.into()), + ) + .unwrap(), 10, ) .to_array(); diff --git a/encodings/sparse/src/canonical.rs b/encodings/sparse/src/canonical.rs index 222c4befeb0..b2c2245c5cc 100644 --- a/encodings/sparse/src/canonical.rs +++ b/encodings/sparse/src/canonical.rs @@ -97,12 +97,12 @@ pub(super) fn execute_sparse(array: &SparseArray) -> VortexResult { }) } dtype @ DType::Utf8(..) => { - let fill_value = array.fill_scalar().as_utf8().value(); + let fill_value = array.fill_scalar().as_utf8().value().cloned(); let fill_value = fill_value.map(BufferString::into_inner); execute_varbin(array, dtype.clone(), fill_value)? } dtype @ DType::Binary(..) => { - let fill_value = array.fill_scalar().as_binary().value(); + let fill_value = array.fill_scalar().as_binary().to_value(); execute_varbin(array, dtype.clone(), fill_value)? } DType::List(values_dtype, nullability) => { @@ -369,12 +369,12 @@ fn execute_sparse_struct( unresolved_patches: &Patches, len: usize, ) -> VortexResult { - let (fill_values, top_level_fill_validity) = match fill_struct.fields() { + let (fill_values, top_level_fill_validity) = match fill_struct.fields_iter() { Some(fill_values) => (fill_values.collect::>(), Validity::AllValid), None => ( struct_fields .fields() - .map(Scalar::default_value) + .map(|f| Scalar::default_value(&f)) .collect::>(), Validity::AllInvalid, ), diff --git a/encodings/sparse/src/compute/cast.rs b/encodings/sparse/src/compute/cast.rs index 2db305e80b9..c38ac614ded 100644 --- a/encodings/sparse/src/compute/cast.rs +++ b/encodings/sparse/src/compute/cast.rs @@ -77,7 +77,7 @@ mod tests { buffer![1u64, 3, 5].into_array(), PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(126)]).into_array(), 8, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap(); @@ -109,7 +109,7 @@ mod tests { buffer![1u64, 3, 7].into_array(), PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(), 10, - Scalar::null_typed::() + Scalar::null_native::() ).unwrap())] #[case(SparseArray::try_new( buffer![5u64].into_array(), diff --git a/encodings/sparse/src/compute/filter.rs b/encodings/sparse/src/compute/filter.rs index 4b0046631e4..47f8c805156 100644 --- a/encodings/sparse/src/compute/filter.rs +++ b/encodings/sparse/src/compute/filter.rs @@ -60,7 +60,7 @@ mod tests { buffer![2u64, 9, 15].into_array(), PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), 20, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap() .into_array() @@ -80,7 +80,7 @@ mod tests { buffer![0u64].into_array(), PrimitiveArray::new(buffer![33_i32], Validity::AllValid).into_array(), 1, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap(); @@ -94,7 +94,7 @@ mod tests { buffer![0_u64, 3, 6].into_array(), PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), 7, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap() .into_array(); @@ -109,7 +109,7 @@ mod tests { buffer![1u64, 3].into_array(), PrimitiveArray::new(buffer![44_i32, 55], Validity::AllValid).into_array(), 4, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap(); diff --git a/encodings/sparse/src/compute/mod.rs b/encodings/sparse/src/compute/mod.rs index 38bdb41b68a..c3b6868964b 100644 --- a/encodings/sparse/src/compute/mod.rs +++ b/encodings/sparse/src/compute/mod.rs @@ -14,6 +14,7 @@ mod test { use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::arrays::PrimitiveArray; + use vortex_array::assert_arrays_eq; use vortex_array::compute::cast; use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array; use vortex_array::compute::conformance::mask::test_mask_conformance; @@ -22,6 +23,7 @@ mod test { use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; + use vortex_mask::Mask; use vortex_scalar::Scalar; use crate::SparseArray; @@ -32,12 +34,62 @@ mod test { buffer![2u64, 9, 15].into_array(), PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), 20, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap() .into_array() } + #[rstest] + fn test_filter(array: ArrayRef) { + let mut predicate = vec![false, false, true]; + predicate.extend_from_slice(&[false; 17]); + let mask = Mask::from_iter(predicate); + + let filtered_array = array.filter(mask).unwrap(); + + // Construct expected SparseArray: index 2 was kept, which had value 33. + // The new index is 0 (since it's the only element). + let expected = SparseArray::try_new( + buffer![0u64].into_array(), + PrimitiveArray::new(buffer![33_i32], Validity::AllValid).into_array(), + 1, + Scalar::null_native::(), + ) + .unwrap(); + + assert_arrays_eq!(filtered_array, expected); + } + + #[test] + fn true_fill_value() { + let mask = Mask::from_iter([false, true, false, true, false, true, true]); + let array = SparseArray::try_new( + buffer![0_u64, 3, 6].into_array(), + PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), + 7, + Scalar::null_native::(), + ) + .unwrap() + .into_array(); + + let filtered_array = array.filter(mask).unwrap(); + + // Original indices 0, 3, 6 with values 33, 44, 55. + // Mask keeps indices 1, 3, 5, 6 -> new indices 0, 1, 2, 3. + // Index 3 (value 44) maps to new index 1. + // Index 6 (value 55) maps to new index 3. + let expected = SparseArray::try_new( + buffer![1u64, 3].into_array(), + PrimitiveArray::new(buffer![44_i32, 55], Validity::AllValid).into_array(), + 4, + Scalar::null_native::(), + ) + .unwrap(); + + assert_arrays_eq!(filtered_array, expected); + } + #[rstest] fn test_sparse_binary_numeric(array: ArrayRef) { test_binary_numeric_array(array) @@ -97,7 +149,7 @@ mod tests { buffer![2u64, 5, 8].into_array(), PrimitiveArray::from_option_iter([Some(100i32), Some(200), Some(300)]).into_array(), 10, - Scalar::null_typed::() + Scalar::null_native::() ).unwrap())] #[case::sparse_i32_value_fill(SparseArray::try_new( buffer![1u64, 3, 7].into_array(), @@ -129,7 +181,7 @@ mod tests { buffer![0u64, 1, 2, 3, 4].into_array(), PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), Some(50)]).into_array(), 5, - Scalar::null_typed::() + Scalar::null_native::() ).unwrap())] // Large sparse arrays #[case::sparse_large(SparseArray::try_new( diff --git a/encodings/sparse/src/compute/take.rs b/encodings/sparse/src/compute/take.rs index 3ff7d74b0df..628dc345157 100644 --- a/encodings/sparse/src/compute/take.rs +++ b/encodings/sparse/src/compute/take.rs @@ -73,7 +73,7 @@ mod test { fn test_array_fill_value() -> Scalar { // making this const is annoying - Scalar::null_typed::() + Scalar::null_native::() } fn sparse_array() -> ArrayRef { @@ -175,7 +175,7 @@ mod test { buffer![0u64, 37, 47, 99].into_array(), PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(), 100, - Scalar::null_typed::(), + Scalar::null_native::(), ).unwrap())] #[case(SparseArray::try_new( buffer![1u32, 3, 7, 8, 9].into_array(), @@ -189,7 +189,7 @@ mod test { buffer![2u64, 4, 6].into_array(), nullable_values.into_array(), 10, - Scalar::null_typed::(), + Scalar::null_native::(), ).unwrap() })] #[case(SparseArray::try_new( diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index fbfd003c3d1..bca47636966 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -128,10 +128,7 @@ impl VTable for SparseVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let fill_value = Scalar::new( - dtype.clone(), - ScalarValue::from_protobytes(&buffers[0].clone().try_to_host_sync()?)?, - ); + let fill_value = Scalar::from_proto_bytes(&buffers[0].clone().try_to_host_sync()?, dtype)?; SparseArray::try_new(patch_indices, patch_values, len, fill_value) } @@ -418,11 +415,8 @@ impl ValidityVTable for SparseVTable { impl VisitorVTable for SparseVTable { fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) { - let fill_value_buffer = array - .fill_value - .value() - .to_protobytes::() - .freeze(); + let fill_value_buffer = + ScalarValue::to_proto_bytes::(array.fill_value.value()).freeze(); visitor.visit_buffer_handle("fill_value", &BufferHandle::new_host(fill_value_buffer)); } @@ -619,7 +613,8 @@ mod test { let indices = buffer![0u8, 2, 4, 6, 8].into_array(); let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)]) .into_array(); - let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::()).unwrap(); + let array = + SparseArray::try_new(indices, values, 10, Scalar::null_native::()).unwrap(); let actual = array.validity_mask().unwrap(); let expected = Mask::from_iter([ true, false, true, false, false, false, false, false, true, false, diff --git a/encodings/zstd/src/test.rs b/encodings/zstd/src/test.rs index f498218e2e1..ca38ef3d36c 100644 --- a/encodings/zstd/src/test.rs +++ b/encodings/zstd/src/test.rs @@ -83,12 +83,7 @@ fn test_zstd_with_validity_and_multi_frame() { // check slicing works let slice = compressed.slice(176..179).unwrap(); let primitive = slice.to_primitive(); - assert_eq!( - TryInto::::try_into(primitive.scalar_at(1).unwrap().as_ref()) - .ok() - .unwrap(), - 177 - ); + assert_eq!(i32::try_from(primitive.scalar_at(1).unwrap()).unwrap(), 177); assert_eq!( primitive.validity(), &Validity::Array(BoolArray::from_iter(vec![false, true, false]).to_array()) diff --git a/fuzz/src/array/compare.rs b/fuzz/src/array/compare.rs index d5fde3bfd52..2447a0f98d4 100644 --- a/fuzz/src/array/compare.rs +++ b/fuzz/src/array/compare.rs @@ -107,27 +107,23 @@ pub fn compare_canonical_array(array: &dyn Array, value: &Scalar, operator: Oper }) } DType::Utf8(_) => array.to_varbinview().with_iterator(|iter| { - let utf8_value = value - .as_utf8() - .value() - .vortex_expect("nulls handled before"); + let utf8_value = value.as_utf8(); compare_to( iter.map(|v| v.map(|b| unsafe { str::from_utf8_unchecked(b) })), - &utf8_value, + utf8_value.value().vortex_expect("nulls handled before"), operator, result_nullability, ) }), DType::Binary(_) => array.to_varbinview().with_iterator(|iter| { - let binary_value = value - .as_binary() - .value() - .vortex_expect("nulls handled before"); + let binary_value = value.as_binary(); compare_to( // Don't understand the lifetime problem here but identity map makes it go away #[allow(clippy::map_identity)] iter.map(|v| v), - &binary_value, + binary_value + .value_ref() + .vortex_expect("nulls handled before"), operator, result_nullability, ) diff --git a/fuzz/src/array/fill_null.rs b/fuzz/src/array/fill_null.rs index 29c90c61c32..84ff49e2c3e 100644 --- a/fuzz/src/array/fill_null.rs +++ b/fuzz/src/array/fill_null.rs @@ -210,7 +210,7 @@ fn fill_varbinview_array( DType::Binary(_) => { let fill_bytes = fill_value .as_binary() - .value() + .to_value() .vortex_expect("cannot have null fill value"); let binaries: Vec> = (0..array.len()) .map(|i| { @@ -219,7 +219,7 @@ fn fill_varbinview_array( .scalar_at(i) .vortex_expect("scalar_at") .as_binary() - .value() + .to_value() .vortex_expect("cannot have null valid value") .to_vec() } else { diff --git a/vortex-array/src/array/mod.rs b/vortex-array/src/array/mod.rs index 7af207736ad..95a66056102 100644 --- a/vortex-array/src/array/mod.rs +++ b/vortex-array/src/array/mod.rs @@ -452,7 +452,8 @@ impl Array for ArrayAdapter { stat, Stat::IsConstant | Stat::IsSorted | Stat::IsStrictSorted ) && value.as_ref().as_exact().is_some_and(|v| { - Scalar::new(DType::Bool(Nullability::NonNullable), v.clone()) + Scalar::try_new(DType::Bool(Nullability::NonNullable), Some(v.clone())) + .vortex_expect("A stat that was expected to be a boolean stat was not") .as_bool() .value() .unwrap_or_default() diff --git a/vortex-array/src/arrays/bool/compute/sum.rs b/vortex-array/src/arrays/bool/compute/sum.rs index abcb89f1733..19c435c9d2f 100644 --- a/vortex-array/src/arrays/bool/compute/sum.rs +++ b/vortex-array/src/arrays/bool/compute/sum.rs @@ -3,6 +3,7 @@ use std::ops::BitAnd; +use vortex_dtype::Nullability; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_mask::AllOr; @@ -30,13 +31,15 @@ impl SumKernel for BoolVTable { } }; - let accumulator = accumulator + let acc_value = accumulator .as_primitive() .as_::() .vortex_expect("cannot be null"); - Ok(Scalar::from( - true_count.and_then(|tc| accumulator.checked_add(tc)), - )) + let result = true_count.and_then(|tc| acc_value.checked_add(tc)); + Ok(match result { + Some(v) => Scalar::primitive(v, Nullability::Nullable), + None => Scalar::null_native::(), + }) } } diff --git a/vortex-array/src/arrays/chunked/compute/sum.rs b/vortex-array/src/arrays/chunked/compute/sum.rs index c9bb9389781..6b80c2e2fa8 100644 --- a/vortex-array/src/arrays/chunked/compute/sum.rs +++ b/vortex-array/src/arrays/chunked/compute/sum.rs @@ -30,9 +30,9 @@ mod tests { use vortex_dtype::DType; use vortex_dtype::DecimalDType; use vortex_dtype::Nullability; + use vortex_dtype::i256; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; - use vortex_scalar::i256; use crate::array::IntoArray; use crate::arrays::ChunkedArray; diff --git a/vortex-array/src/arrays/constant/compute/cast.rs b/vortex-array/src/arrays/constant/compute/cast.rs index 36401af8a36..1875c52f39d 100644 --- a/vortex-array/src/arrays/constant/compute/cast.rs +++ b/vortex-array/src/arrays/constant/compute/cast.rs @@ -37,7 +37,7 @@ mod tests { #[case(ConstantArray::new(Scalar::from(-100i32), 10).into_array())] #[case(ConstantArray::new(Scalar::from(3.5f32), 3).into_array())] #[case(ConstantArray::new(Scalar::from(true), 7).into_array())] - #[case(ConstantArray::new(Scalar::null_typed::(), 4).into_array())] + #[case(ConstantArray::new(Scalar::null_native::(), 4).into_array())] #[case(ConstantArray::new(Scalar::from(255u8), 1).into_array())] fn test_cast_constant_conformance(#[case] array: crate::ArrayRef) { test_cast_conformance(array.as_ref()); diff --git a/vortex-array/src/arrays/constant/compute/fill_null.rs b/vortex-array/src/arrays/constant/compute/fill_null.rs index 9e6fd16cda6..cde536b7a35 100644 --- a/vortex-array/src/arrays/constant/compute/fill_null.rs +++ b/vortex-array/src/arrays/constant/compute/fill_null.rs @@ -37,7 +37,7 @@ mod test { #[test] fn test_null() { let actual = fill_null( - &ConstantArray::new(Scalar::from(None::), 3).into_array(), + &ConstantArray::new(Scalar::null_native::(), 3).into_array(), &Scalar::from(1), ) .unwrap(); diff --git a/vortex-array/src/arrays/constant/compute/mod.rs b/vortex-array/src/arrays/constant/compute/mod.rs index 7e31b9f326a..f9e3eb3e6b7 100644 --- a/vortex-array/src/arrays/constant/compute/mod.rs +++ b/vortex-array/src/arrays/constant/compute/mod.rs @@ -29,7 +29,7 @@ mod test { #[test] fn test_mask_constant() { - test_mask_conformance(&ConstantArray::new(Scalar::null_typed::(), 5).into_array()); + test_mask_conformance(&ConstantArray::new(Scalar::null_native::(), 5).into_array()); test_mask_conformance(&ConstantArray::new(Scalar::from(3u16), 5).into_array()); test_mask_conformance(&ConstantArray::new(Scalar::from(1.0f32 / 0.0f32), 5).into_array()); test_mask_conformance( @@ -39,7 +39,7 @@ mod test { #[test] fn test_filter_constant() { - test_filter_conformance(&ConstantArray::new(Scalar::null_typed::(), 5).into_array()); + test_filter_conformance(&ConstantArray::new(Scalar::null_native::(), 5).into_array()); test_filter_conformance(&ConstantArray::new(Scalar::from(3u16), 5).into_array()); test_filter_conformance(&ConstantArray::new(Scalar::from(1.0f32 / 0.0f32), 5).into_array()); test_filter_conformance( diff --git a/vortex-array/src/arrays/constant/compute/sum.rs b/vortex-array/src/arrays/constant/compute/sum.rs index 75a412d50c3..feaaa5d4778 100644 --- a/vortex-array/src/arrays/constant/compute/sum.rs +++ b/vortex-array/src/arrays/constant/compute/sum.rs @@ -35,11 +35,15 @@ impl SumKernel for ConstantVTable { .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?; let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?; - Ok(Scalar::new(sum_dtype, sum_value)) + Scalar::try_new(sum_dtype, sum_value) } } -fn sum_scalar(scalar: &Scalar, len: usize, accumulator: &Scalar) -> VortexResult { +fn sum_scalar( + scalar: &Scalar, + len: usize, + accumulator: &Scalar, +) -> VortexResult> { match scalar.dtype() { DType::Bool(_) => { let count = match scalar.as_bool().value() { @@ -51,14 +55,16 @@ fn sum_scalar(scalar: &Scalar, len: usize, accumulator: &Scalar) -> VortexResult .as_primitive() .as_::() .vortex_expect("cannot be null"); - Ok(ScalarValue::from(accumulator.checked_add(count))) + Ok(accumulator + .checked_add(count) + .map(|v| ScalarValue::Primitive(v.into()))) } DType::Primitive(ptype, _) => { let result = match_each_native_ptype!( ptype, - unsigned: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.into() }, - signed: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.into() }, - floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.into() } + unsigned: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }, + signed: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }, + floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) } ); Ok(result) } @@ -75,7 +81,7 @@ fn sum_decimal( array_len: usize, decimal_dtype: DecimalDType, accumulator: &Scalar, -) -> VortexResult { +) -> VortexResult> { let result_dtype = Stat::Sum .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable)) .vortex_expect("decimal supports sum"); @@ -85,22 +91,22 @@ fn sum_decimal( let Some(value) = decimal_scalar.decimal_value() else { // Null value: return null - return Ok(ScalarValue::null()); + return Ok(None); }; - // Convert array_len to DecimalValue for multiplication + // Convert array_len to DecimalValue for multiplication. let len_value = DecimalValue::I256(i256::from_i128(array_len as i128)); - // Multiply value * len + // Multiply value * len. let array_sum = value.checked_mul(&len_value).and_then(|result| { - // Check if result fits in the precision + // Check if result fits in the precision. result .fits_in_precision(*result_decimal_type) .unwrap_or(false) .then_some(result) }); - // Add accumulator to array_sum + // Add accumulator to array_sum. let initial_decimal = DecimalScalar::try_from(accumulator)?; let initial_dec_value = initial_decimal .decimal_value() @@ -117,11 +123,11 @@ fn sum_decimal( .then_some(result) }); match total { - Some(result_value) => Ok(ScalarValue::from(result_value)), - None => Ok(ScalarValue::null()), // Overflow + Some(result_value) => Ok(Some(ScalarValue::from(result_value))), + None => Ok(None), // Overflow } } - None => Ok(ScalarValue::null()), // Overflow + None => Ok(None), // Overflow } } @@ -132,7 +138,6 @@ fn sum_integral( ) -> VortexResult> where T: NativePType + CheckedMul + CheckedAdd, - Scalar: From>, { let v = primitive_scalar.as_::(); let array_len = diff --git a/vortex-array/src/arrays/constant/compute/take.rs b/vortex-array/src/arrays/constant/compute/take.rs index 44a00f76ca4..3fe851332e7 100644 --- a/vortex-array/src/arrays/constant/compute/take.rs +++ b/vortex-array/src/arrays/constant/compute/take.rs @@ -20,13 +20,13 @@ impl TakeReduce for ConstantVTable { fn take(array: &ConstantArray, indices: &dyn Array) -> VortexResult> { let result = match indices.validity_mask()?.bit_buffer() { AllOr::All => { - let scalar = Scalar::new( + let scalar = Scalar::try_new( array .scalar() .dtype() .union_nullability(indices.dtype().nullability()), - array.scalar().value().clone(), - ); + array.scalar().value().cloned(), + )?; ConstantArray::new(scalar, indices.len()).into_array() } AllOr::None => ConstantArray::new( @@ -128,7 +128,7 @@ mod tests { #[case(ConstantArray::new(42i32, 5))] #[case(ConstantArray::new(std::f64::consts::PI, 10))] #[case(ConstantArray::new(Scalar::from("hello"), 3))] - #[case(ConstantArray::new(Scalar::null_typed::(), 5))] + #[case(ConstantArray::new(Scalar::null_native::(), 5))] #[case(ConstantArray::new(true, 1))] fn test_take_constant_conformance(#[case] array: ConstantArray) { test_take_conformance(array.as_ref()); diff --git a/vortex-array/src/arrays/constant/vtable/canonical.rs b/vortex-array/src/arrays/constant/vtable/canonical.rs index 42c3391f6fc..9657076fc3d 100644 --- a/vortex-array/src/arrays/constant/vtable/canonical.rs +++ b/vortex-array/src/arrays/constant/vtable/canonical.rs @@ -124,7 +124,7 @@ pub(crate) fn constant_canonicalize(array: &ConstantArray) -> VortexResult { let value = BinaryScalar::try_from(scalar) .vortex_expect("must be a binary scalar") - .value(); + .to_value(); let const_value = value.as_ref().map(|v| v.as_slice()); Canonical::VarBinView(constant_canonical_byte_view( const_value, @@ -134,17 +134,20 @@ pub(crate) fn constant_canonicalize(array: &ConstantArray) -> VortexResult { let value = StructScalar::try_from(scalar).vortex_expect("must be struct"); - let fields: Vec<_> = match value.fields() { + let fields: Vec<_> = match value.fields_iter() { Some(fields) => fields .into_iter() .map(|s| ConstantArray::new(s, array.len()).into_array()) .collect(), None => { assert!(validity.all_invalid(array.len())?); + // The struct is entirely null, so fields just need placeholder values with the + // correct dtype. We use `default_value` which returns a zero for non-nullable + // dtypes and null for nullable dtypes, preserving each field's nullability. struct_dtype .fields() .map(|dt| { - let scalar = Scalar::default_value(dt); + let scalar = Scalar::default_value(&dt); ConstantArray::new(scalar, array.len()).into_array() }) .collect() diff --git a/vortex-array/src/arrays/constant/vtable/mod.rs b/vortex-array/src/arrays/constant/vtable/mod.rs index 2b0e8a84eec..504de94c06b 100644 --- a/vortex-array/src/arrays/constant/vtable/mod.rs +++ b/vortex-array/src/arrays/constant/vtable/mod.rs @@ -8,7 +8,6 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_scalar::Scalar; -use vortex_scalar::ScalarValue; use vortex_session::VortexSession; use crate::ArrayRef; @@ -81,8 +80,9 @@ impl VTable for ConstantVTable { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } let buffer = buffers[0].clone().try_to_host_sync()?; - let sv = ScalarValue::from_protobytes(&buffer)?; - let scalar = Scalar::new(dtype.clone(), sv); + + let scalar = Scalar::from_proto_bytes(buffer.as_ref(), dtype)?; + Ok(ConstantArray::new(scalar, len)) } diff --git a/vortex-array/src/arrays/constant/vtable/visitor.rs b/vortex-array/src/arrays/constant/vtable/visitor.rs index 28e613ba121..9f1e74553ca 100644 --- a/vortex-array/src/arrays/constant/vtable/visitor.rs +++ b/vortex-array/src/arrays/constant/vtable/visitor.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_buffer::ByteBufferMut; +use vortex_scalar::ScalarValue; use crate::ArrayBufferVisitor; use crate::ArrayChildVisitor; @@ -12,11 +13,7 @@ use crate::vtable::VisitorVTable; impl VisitorVTable for ConstantVTable { fn visit_buffers(array: &ConstantArray, visitor: &mut dyn ArrayBufferVisitor) { - let buffer = array - .scalar - .value() - .to_protobytes::() - .freeze(); + let buffer = ScalarValue::to_proto_bytes::(array.scalar.value()).freeze(); visitor.visit_buffer_handle("scalar", &BufferHandle::new_host(buffer)); } diff --git a/vortex-array/src/arrays/decimal/compute/min_max.rs b/vortex-array/src/arrays/decimal/compute/min_max.rs index da4483a3a69..68450cf7f9d 100644 --- a/vortex-array/src/arrays/decimal/compute/min_max.rs +++ b/vortex-array/src/arrays/decimal/compute/min_max.rs @@ -95,14 +95,16 @@ mod tests { let non_nullable_dtype = decimal.dtype().as_nonnullable(); let expected = MinMaxResult { - min: Scalar::new( + min: Scalar::try_new( non_nullable_dtype.clone(), - ScalarValue::from(DecimalValue::from(100i32)), - ), - max: Scalar::new( + Some(ScalarValue::from(DecimalValue::from(100i32))), + ) + .unwrap(), + max: Scalar::try_new( non_nullable_dtype, - ScalarValue::from(DecimalValue::from(200i32)), - ), + Some(ScalarValue::from(DecimalValue::from(200i32))), + ) + .unwrap(), }; assert_eq!(Some(expected), min_max) diff --git a/vortex-array/src/arrays/decimal/compute/sum.rs b/vortex-array/src/arrays/decimal/compute/sum.rs index 5b5a192fc58..8fcf4a19050 100644 --- a/vortex-array/src/arrays/decimal/compute/sum.rs +++ b/vortex-array/src/arrays/decimal/compute/sum.rs @@ -129,11 +129,11 @@ mod tests { use vortex_dtype::DType; use vortex_dtype::DecimalDType; use vortex_dtype::Nullability; + use vortex_dtype::i256; use vortex_error::VortexExpect; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; - use vortex_scalar::i256; use crate::arrays::DecimalArray; use crate::compute::sum; @@ -149,10 +149,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(600i32)), - ); + Some(ScalarValue::from(DecimalValue::from(600i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -167,10 +168,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable), - ScalarValue::from(DecimalValue::from(800i32)), - ); + Some(ScalarValue::from(DecimalValue::from(800i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -185,10 +187,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(150i32)), - ); + Some(ScalarValue::from(DecimalValue::from(150i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -207,10 +210,11 @@ mod tests { // Should use i64 for accumulation since precision increases let expected_sum = near_max as i64 + 500 + 400; - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(20, 2), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(expected_sum)), - ); + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -228,17 +232,18 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); let expected_sum = (large_val as i128) * 4 + 1; - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(29, 0), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(expected_sum)), - ); + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); assert_eq!(result, expected); } #[test] fn test_sum_overflow_detection() { - use vortex_scalar::i256; + use vortex_dtype::i256; // Create values that will overflow when summed // Use maximum i128 values that will overflow when added @@ -254,10 +259,11 @@ mod tests { // Should use i256 for accumulation let expected_sum = i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(expected_sum)), - ); + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -276,10 +282,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); let expected_sum = (large_pos as i128) + (large_neg as i128) + (large_pos as i128) + 1000; - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(29, 3), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(expected_sum)), - ); + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -295,10 +302,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); // Scale should be preserved, precision increased by 10 - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(91346i32)), - ); + Some(ScalarValue::from(DecimalValue::from(91346i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -310,10 +318,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(42i32)), - ); + Some(ScalarValue::from(DecimalValue::from(42i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -328,10 +337,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable), - ScalarValue::from(DecimalValue::from(300i32)), - ); + Some(ScalarValue::from(DecimalValue::from(300i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -353,10 +363,11 @@ mod tests { // Should use i256 for accumulation since 9 * (i128::MAX / 10) fits in i128 but we increase precision let expected_sum = i256::from_i128(large_i128).wrapping_pow(1) * i256::from_i128(9); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(expected_sum)), - ); + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); assert_eq!(result, expected); } diff --git a/vortex-array/src/arrays/decimal/utils.rs b/vortex-array/src/arrays/decimal/utils.rs index 00bd6a86fb3..1b148b130a9 100644 --- a/vortex-array/src/arrays/decimal/utils.rs +++ b/vortex-array/src/arrays/decimal/utils.rs @@ -3,9 +3,9 @@ use itertools::Itertools; use itertools::MinMaxResult; +use vortex_dtype::DecimalType; +use vortex_dtype::i256; use vortex_error::VortexExpect; -use vortex_scalar::DecimalType; -use vortex_scalar::i256; use crate::arrays::DecimalArray; use crate::vtable::ValidityHelper; diff --git a/vortex-array/src/arrays/decimal/vtable/array.rs b/vortex-array/src/arrays/decimal/vtable/array.rs index 23013e066ec..f5997dd0b95 100644 --- a/vortex-array/src/arrays/decimal/vtable/array.rs +++ b/vortex-array/src/arrays/decimal/vtable/array.rs @@ -4,7 +4,7 @@ use std::hash::Hash; use vortex_dtype::DType; -use vortex_scalar::DecimalType; +use vortex_dtype::DecimalType; use crate::Precision; use crate::arrays::DecimalArray; diff --git a/vortex-array/src/arrays/decimal/vtable/mod.rs b/vortex-array/src/arrays/decimal/vtable/mod.rs index 96683f051bd..70ac0c3fdad 100644 --- a/vortex-array/src/arrays/decimal/vtable/mod.rs +++ b/vortex-array/src/arrays/decimal/vtable/mod.rs @@ -4,13 +4,13 @@ use kernel::PARENT_KERNELS; use vortex_buffer::Alignment; use vortex_dtype::DType; +use vortex_dtype::DecimalType; use vortex_dtype::NativeDecimalType; use vortex_dtype::match_each_decimal_value_type; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; -use vortex_scalar::DecimalType; use vortex_session::VortexSession; use crate::ArrayRef; diff --git a/vortex-array/src/arrays/dict/compute/fill_null.rs b/vortex-array/src/arrays/dict/compute/fill_null.rs index 4a7ab763d60..4e9428326f8 100644 --- a/vortex-array/src/arrays/dict/compute/fill_null.rs +++ b/vortex-array/src/arrays/dict/compute/fill_null.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex_dtype::match_each_unsigned_integer_ptype; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -21,8 +22,8 @@ use crate::register_kernel; impl FillNullKernel for DictVTable { fn fill_null(&self, array: &DictArray, fill_value: &Scalar) -> VortexResult { - // If the fill value exists in the dictionary, we can simply rewrite the null codes to - // point to the value. + // If the fill value already exists in the dictionary, we can simply rewrite the null codes + // to point to the value. let found_fill_values = compare( array.values(), ConstantArray::new(fill_value.clone(), array.values().len()).as_ref(), @@ -30,7 +31,10 @@ impl FillNullKernel for DictVTable { )? .to_bool(); - let Some(first_fill_value) = found_fill_values.to_bit_buffer().set_indices().next() else { + // We found the fill value already in the values at this given index. + let Some(existing_fill_value_index) = + found_fill_values.to_bit_buffer().set_indices().next() + else { // No fill values found, so we must canonicalize and fill_null. // TODO(ngates): compute kernels should all return Option to support this // fall back. @@ -38,20 +42,27 @@ impl FillNullKernel for DictVTable { }; // Now we rewrite the nullable codes to point at the fill value. + let codes = array.codes(); + + // Cast the index to the correct unsigned integer type matching the codes' ptype. + let codes_ptype = codes.dtype().as_ptype(); + + #[expect( + clippy::cast_possible_truncation, + reason = "The existing index must be representable by the existing ptype" + )] + let fill_scalar_value = match_each_unsigned_integer_ptype!(codes_ptype, |P| { + ScalarValue::from(existing_fill_value_index as P) + }); + + // Fill nulls in both the codes and the values. let codes = fill_null( - array.codes(), - &Scalar::new( - array - .codes() - .dtype() - .with_nullability(fill_value.dtype().nullability()), - ScalarValue::from(first_fill_value), - ), + codes, + &Scalar::try_new(codes.dtype().as_nonnullable(), Some(fill_scalar_value))?, )?; - // And fill nulls in the values let values = fill_null(array.values(), fill_value)?; - // SAFETY: invariants are still satisfied after patching nulls + // SAFETY: invariants are still satisfied after patching nulls. unsafe { Ok(DictArray::new_unchecked(codes, values) .set_all_values_referenced(array.has_all_values_referenced()) diff --git a/vortex-array/src/arrays/dict/take.rs b/vortex-array/src/arrays/dict/take.rs index 86c3700952d..e83cabc5d43 100644 --- a/vortex-array/src/arrays/dict/take.rs +++ b/vortex-array/src/arrays/dict/take.rs @@ -155,7 +155,8 @@ pub(crate) fn propagate_take_stats( source .statistics() .get(stat) - .map(|v| (stat, v.map(|s| s.into_value()).into_inexact())) + .and_then(|v| v.map(|s| s.into_value()).into_inexact().transpose()) + .map(|sv| (stat, sv)) }) .collect::>(); st.combine_sets( diff --git a/vortex-array/src/arrays/masked/compute/take.rs b/vortex-array/src/arrays/masked/compute/take.rs index 162d6e5699d..c9026c43ea9 100644 --- a/vortex-array/src/arrays/masked/compute/take.rs +++ b/vortex-array/src/arrays/masked/compute/take.rs @@ -21,12 +21,14 @@ impl TakeExecute for MaskedVTable { _ctx: &mut ExecutionCtx, ) -> VortexResult> { let taken_child = if !indices.all_valid()? { - // This is safe because we'll mask out these positions in the validity - let filled_take = fill_null( - indices, - &Scalar::default_value(indices.dtype().clone().as_nonnullable()), - )?; - array.child.take(filled_take)?.to_canonical()?.into_array() + // This is safe because we'll mask out these positions in the validity. + let fill_scalar = Scalar::zero_value(indices.dtype()); + let filled_take_indices = fill_null(indices, &fill_scalar)?; + array + .child + .take(filled_take_indices)? + .to_canonical()? + .into_array() } else { array .child diff --git a/vortex-array/src/arrays/null/compute/cast.rs b/vortex-array/src/arrays/null/compute/cast.rs index ceb35504e63..32bc1db0107 100644 --- a/vortex-array/src/arrays/null/compute/cast.rs +++ b/vortex-array/src/arrays/null/compute/cast.rs @@ -5,7 +5,6 @@ use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_scalar::Scalar; -use vortex_scalar::ScalarValue; use crate::ArrayRef; use crate::IntoArray; @@ -25,7 +24,7 @@ impl CastKernel for NullVTable { return Ok(Some(array.to_array())); } - let scalar = Scalar::new(dtype.clone(), ScalarValue::null()); + let scalar = Scalar::null(dtype.clone()); Ok(Some(ConstantArray::new(scalar, array.len()).into_array())) } } diff --git a/vortex-array/src/arrays/primitive/compute/sum.rs b/vortex-array/src/arrays/primitive/compute/sum.rs index bd3d861d4eb..2c02f3e308c 100644 --- a/vortex-array/src/arrays/primitive/compute/sum.rs +++ b/vortex-array/src/arrays/primitive/compute/sum.rs @@ -7,10 +7,12 @@ use num_traits::Float; use num_traits::ToPrimitive; use vortex_buffer::BitBuffer; use vortex_dtype::NativePType; +use vortex_dtype::Nullability; use vortex_dtype::match_each_native_ptype; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_mask::AllOr; +use vortex_scalar::PValue; use vortex_scalar::Scalar; use crate::arrays::PrimitiveArray; @@ -19,6 +21,15 @@ use crate::compute::SumKernel; use crate::compute::SumKernelAdapter; use crate::register_kernel; +// TODO(connor): This should be public and a `From` implementation. +/// Helper to convert an Option to a Scalar. None represents overflow (null result). +fn option_to_scalar>(opt: Option) -> Scalar { + match opt { + Some(v) => Scalar::primitive(v, Nullability::Nullable), + None => Scalar::null(T::PTYPE.into()), + } +} + impl SumKernel for PrimitiveVTable { fn sum(&self, array: &PrimitiveArray, accumulator: &Scalar) -> VortexResult { let array_sum_scalar = match array.validity_mask()?.bit_buffer() { @@ -26,9 +37,27 @@ impl SumKernel for PrimitiveVTable { // All-valid match_each_native_ptype!( array.ptype(), - unsigned: |T| { sum_integer::<_, u64>(array.as_slice::(), accumulator.as_primitive().as_::().vortex_expect("cannot be null")).into() }, - signed: |T| { sum_integer::<_, i64>(array.as_slice::(), accumulator.as_primitive().as_::().vortex_expect("cannot be null")).into() }, - floating: |T| { Some(sum_float(array.as_slice::(), accumulator.as_primitive().as_::().vortex_expect("cannot be null"))).into() } + unsigned: |T| { + option_to_scalar(sum_integer::<_, u64>( + array.as_slice::(), + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + )) + }, + signed: |T| { + option_to_scalar(sum_integer::<_, i64>( + array.as_slice::(), + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + )) + }, + floating: |T| { + Scalar::primitive( + sum_float( + array.as_slice::(), + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + ), + Nullability::Nullable, + ) + } ) } AllOr::None => { @@ -40,13 +69,28 @@ impl SumKernel for PrimitiveVTable { match_each_native_ptype!( array.ptype(), unsigned: |T| { - sum_integer_with_validity::<_, u64>(array.as_slice::(), validity_mask, accumulator.as_primitive().as_::().vortex_expect("cannot be null")).into() + option_to_scalar(sum_integer_with_validity::<_, u64>( + array.as_slice::(), + validity_mask, + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + )) }, signed: |T| { - sum_integer_with_validity::<_, i64>(array.as_slice::(), validity_mask, accumulator.as_primitive().as_::().vortex_expect("cannot be null")).into() + option_to_scalar(sum_integer_with_validity::<_, i64>( + array.as_slice::(), + validity_mask, + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + )) }, floating: |T| { - Some(sum_float_with_validity(array.as_slice::(), validity_mask, accumulator.as_primitive().as_::().vortex_expect("cannot be null"))).into() + Scalar::primitive( + sum_float_with_validity( + array.as_slice::(), + validity_mask, + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + ), + Nullability::Nullable, + ) } ) } diff --git a/vortex-array/src/arrays/primitive/compute/take/mod.rs b/vortex-array/src/arrays/primitive/compute/take/mod.rs index be832f50cb6..a75db90b01c 100644 --- a/vortex-array/src/arrays/primitive/compute/take/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/take/mod.rs @@ -154,12 +154,12 @@ mod test { // position 3 is null assert_eq!( actual.scalar_at(1).vortex_expect("no fail"), - Scalar::null_typed::() + Scalar::null_native::() ); // the third index is null assert_eq!( actual.scalar_at(2).vortex_expect("no fail"), - Scalar::null_typed::() + Scalar::null_native::() ); } diff --git a/vortex-array/src/arrays/struct_/compute/take.rs b/vortex-array/src/arrays/struct_/compute/take.rs index c11ff7dd3a9..0827a0e772f 100644 --- a/vortex-array/src/arrays/struct_/compute/take.rs +++ b/vortex-array/src/arrays/struct_/compute/take.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_dtype::Nullability; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -34,11 +33,15 @@ impl TakeExecute for StructVTable { .map(StructArray::into_array) .map(Some); } - // The validity is applied to the struct validity, - let inner_indices = &compute::fill_null( - indices, - &Scalar::default_value(indices.dtype().with_nullability(Nullability::NonNullable)), - )?; + + // TODO(connor): This could be bad for cache locality... + + // Fill null indices with zero so they point at a valid row. + // Note that we strip nullability so that `Take::return_dtype` doesn't union nullable into + // each field's dtype (the struct-level validity already captures which rows are null). + let fill_scalar = Scalar::zero_value(&indices.dtype().as_nonnullable()); + let inner_indices = &compute::fill_null(indices, &fill_scalar)?; + StructArray::try_new_with_dtype( array .unmasked_fields() diff --git a/vortex-array/src/arrays/varbin/array.rs b/vortex-array/src/arrays/varbin/array.rs index bb4d2250ecb..90d51367d19 100644 --- a/vortex-array/src/arrays/varbin/array.rs +++ b/vortex-array/src/arrays/varbin/array.rs @@ -355,10 +355,11 @@ impl VarBinArray { self.len() ); - self.offsets() + // TODO(connor): Fix the `TryFrom` implementation here. + (&self + .offsets() .scalar_at(index) - .vortex_expect("offsets must support scalar_at") - .as_ref() + .vortex_expect("offsets must support scalar_at")) .try_into() .vortex_expect("Failed to convert offset to usize") } diff --git a/vortex-array/src/arrays/varbin/compute/compare.rs b/vortex-array/src/arrays/varbin/compute/compare.rs index 7dd9b9d48e1..1dc2a50ae45 100644 --- a/vortex-array/src/arrays/varbin/compute/compare.rs +++ b/vortex-array/src/arrays/varbin/compute/compare.rs @@ -91,7 +91,7 @@ impl CompareKernel for VarBinVTable { .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))), DType::Binary(_) => &rhs_const .as_binary() - .value() + .to_value() .map(BinaryArray::new_scalar) .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))), _ => vortex_bail!( diff --git a/vortex-array/src/arrays/varbin/compute/min_max.rs b/vortex-array/src/arrays/varbin/compute/min_max.rs index 5cb603f380f..0b89c4f817a 100644 --- a/vortex-array/src/arrays/varbin/compute/min_max.rs +++ b/vortex-array/src/arrays/varbin/compute/min_max.rs @@ -88,17 +88,19 @@ mod tests { assert_eq!( min, - Scalar::new( + Scalar::try_new( Utf8(NonNullable), - BufferString::from("hello world".to_string()).into(), + Some(BufferString::from("hello world".to_string()).into()), ) + .unwrap() ); assert_eq!( max, - Scalar::new( + Scalar::try_new( Utf8(NonNullable), - BufferString::from("hello world this is a long string".to_string()).into() + Some(BufferString::from("hello world this is a long string".to_string()).into()), ) + .unwrap() ); } diff --git a/vortex-array/src/arrow/convert.rs b/vortex-array/src/arrow/convert.rs index f57bf91def8..12849414ddc 100644 --- a/vortex-array/src/arrow/convert.rs +++ b/vortex-array/src/arrow/convert.rs @@ -68,11 +68,11 @@ use vortex_dtype::IntegerPType; use vortex_dtype::NativePType; use vortex_dtype::PType; use vortex_dtype::datetime::TimeUnit; +use vortex_dtype::i256; use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_panic; -use vortex_scalar::i256; use crate::ArrayRef; use crate::IntoArray; diff --git a/vortex-array/src/arrow/executor/decimal.rs b/vortex-array/src/arrow/executor/decimal.rs index f88362813e8..9c06292911f 100644 --- a/vortex-array/src/arrow/executor/decimal.rs +++ b/vortex-array/src/arrow/executor/decimal.rs @@ -67,7 +67,7 @@ fn to_arrow_decimal32(array: DecimalArray) -> VortexResult { }) .process_results(|iter| Buffer::from_trusted_len_iter(iter))?, DecimalType::I256 => array - .buffer::() + .buffer::() .into_iter() .map(|x| { x.to_i32() @@ -106,7 +106,7 @@ fn to_arrow_decimal64(array: DecimalArray) -> VortexResult { }) .process_results(|iter| Buffer::from_trusted_len_iter(iter))?, DecimalType::I256 => array - .buffer::() + .buffer::() .into_iter() .map(|x| { x.to_i64() @@ -140,7 +140,7 @@ fn to_arrow_decimal128(array: DecimalArray) -> VortexResult { } DecimalType::I128 => array.buffer::(), DecimalType::I256 => array - .buffer::() + .buffer::() .into_iter() .map(|x| { x.to_i128() @@ -176,7 +176,7 @@ fn to_arrow_decimal256(array: DecimalArray) -> VortexResult { array .buffer::() .into_iter() - .map(|x| vortex_scalar::i256::from_i128(x).into()), + .map(|x| vortex_dtype::i256::from_i128(x).into()), ), DecimalType::I256 => { Buffer::::from_byte_buffer(array.buffer_handle().clone().into_host_sync()) @@ -241,7 +241,7 @@ mod tests { #[case(0i32)] #[case(0i64)] #[case(0i128)] - #[case(vortex_scalar::i256::ZERO)] + #[case(vortex_dtype::i256::ZERO)] fn test_to_arrow_decimal128( #[case] _decimal_type: T, ) -> VortexResult<()> { @@ -268,7 +268,7 @@ mod tests { #[case(0i32)] #[case(0i64)] #[case(0i128)] - #[case(vortex_scalar::i256::ZERO)] + #[case(vortex_dtype::i256::ZERO)] fn test_to_arrow_decimal32(#[case] _decimal_type: T) -> VortexResult<()> { use arrow_array::Decimal32Array; @@ -295,7 +295,7 @@ mod tests { #[case(0i32)] #[case(0i64)] #[case(0i128)] - #[case(vortex_scalar::i256::ZERO)] + #[case(vortex_dtype::i256::ZERO)] fn test_to_arrow_decimal64(#[case] _decimal_type: T) -> VortexResult<()> { use arrow_array::Decimal64Array; @@ -322,7 +322,7 @@ mod tests { #[case(0i32)] #[case(0i64)] #[case(0i128)] - #[case(vortex_scalar::i256::ZERO)] + #[case(vortex_dtype::i256::ZERO)] fn test_to_arrow_decimal256( #[case] _decimal_type: T, ) -> VortexResult<()> { diff --git a/vortex-array/src/builders/decimal.rs b/vortex-array/src/builders/decimal.rs index 24767753538..0c9331ca211 100644 --- a/vortex-array/src/builders/decimal.rs +++ b/vortex-array/src/builders/decimal.rs @@ -9,6 +9,7 @@ use vortex_dtype::DType; use vortex_dtype::DecimalDType; use vortex_dtype::NativeDecimalType; use vortex_dtype::Nullability; +use vortex_dtype::i256; use vortex_dtype::match_each_decimal_value; use vortex_dtype::match_each_decimal_value_type; use vortex_error::VortexExpect; @@ -19,7 +20,6 @@ use vortex_error::vortex_panic; use vortex_mask::Mask; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; -use vortex_scalar::i256; use crate::Array; use crate::ArrayRef; diff --git a/vortex-array/src/builders/struct_.rs b/vortex-array/src/builders/struct_.rs index ff4d5f527cf..4a8c34057ae 100644 --- a/vortex-array/src/builders/struct_.rs +++ b/vortex-array/src/builders/struct_.rs @@ -73,7 +73,7 @@ impl StructBuilder { ); } - if let Some(fields) = struct_scalar.fields() { + if let Some(fields) = struct_scalar.fields_iter() { for (builder, field) in self.builders.iter_mut().zip_eq(fields) { builder.append_scalar(&field)?; } diff --git a/vortex-array/src/builders/tests.rs b/vortex-array/src/builders/tests.rs index c6cb8b30e91..31b785f575b 100644 --- a/vortex-array/src/builders/tests.rs +++ b/vortex-array/src/builders/tests.rs @@ -78,7 +78,7 @@ fn test_append_zeros_matches_default_value(#[case] dtype: DType) { // Builder 2: Manually append default values. let mut builder_manual = builder_with_capacity(&dtype, num_elements); - let default_scalar = Scalar::default_value(dtype.clone()); + let default_scalar = Scalar::zero_value(&dtype); for _ in 0..num_elements { builder_manual.append_scalar(&default_scalar).unwrap(); } @@ -198,7 +198,7 @@ fn test_append_defaults_behavior(#[case] dtype: DType, #[case] should_be_null: b i ); // For non-nullable, it should match the default value. - let expected = Scalar::default_value(dtype.clone()); + let expected = Scalar::default_value(&dtype); // Skip list comparison due to known bug. if !matches!(dtype, DType::List(..)) { assert_eq!( @@ -359,7 +359,7 @@ fn test_to_canonical_struct() { ); compare_to_canonical_methods(&dtype, |builder| { for _ in 0..3 { - let value = Scalar::default_value(dtype.clone()); + let value = Scalar::default_value(&dtype); builder.append_scalar(&value).unwrap(); } }); @@ -395,7 +395,7 @@ fn test_to_canonical_decimal() { let dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable); compare_to_canonical_methods(&dtype, |builder| { for _ in 0..5 { - let value = Scalar::default_value(dtype.clone()); + let value = Scalar::default_value(&dtype); builder.append_scalar(&value).unwrap(); } }); @@ -592,7 +592,7 @@ fn create_test_scalars_for_dtype(dtype: &DType, count: usize) -> Vec { Scalar::primitive((i + j) as f64, *n) } DType::Utf8(n) => Scalar::utf8(format!("field_{}", i + j), *n), - _ => Scalar::default_value(field_dtype), + _ => Scalar::default_value(&field_dtype), } }) .collect(); @@ -605,7 +605,7 @@ fn create_test_scalars_for_dtype(dtype: &DType, count: usize) -> Vec { DType::Primitive(PType::I32, n) => { Scalar::primitive(j.min(i32::MAX as usize) as i32, *n) } - _ => Scalar::default_value(element_dtype.as_ref().clone()), + _ => Scalar::default_value(element_dtype.as_ref()), }) .collect(); Scalar::list(element_dtype.clone(), elements, *n) @@ -617,7 +617,7 @@ fn create_test_scalars_for_dtype(dtype: &DType, count: usize) -> Vec { DType::Primitive(PType::I32, n) => { Scalar::primitive((i as i32).saturating_add(j as i32), *n) } - _ => Scalar::default_value(element_dtype.as_ref().clone()), + _ => Scalar::default_value(element_dtype.as_ref()), }) .collect(); Scalar::fixed_size_list(element_dtype.clone(), elements, *n) @@ -626,7 +626,7 @@ fn create_test_scalars_for_dtype(dtype: &DType, count: usize) -> Vec { // Create extension scalars with storage values. let storage_scalar = match ext_dtype.storage_dtype() { DType::Primitive(PType::I64, n) => Scalar::primitive(i as i64, *n), - _ => Scalar::default_value(ext_dtype.storage_dtype().clone()), + _ => Scalar::default_value(ext_dtype.storage_dtype()), }; Scalar::extension_ref(ext_dtype.clone(), storage_scalar) } diff --git a/vortex-array/src/builders/varbinview.rs b/vortex-array/src/builders/varbinview.rs index 4525da8ec92..8d53bff2e8d 100644 --- a/vortex-array/src/builders/varbinview.rs +++ b/vortex-array/src/builders/varbinview.rs @@ -261,7 +261,7 @@ impl ArrayBuilder for VarBinViewBuilder { } DType::Binary(_) => { let binary_scalar = BinaryScalar::try_from(scalar)?; - match binary_scalar.value() { + match binary_scalar.to_value() { Some(value) => self.append_value(value), None => self.append_null(), } @@ -1028,7 +1028,7 @@ mod tests { assert_eq!(array.len(), 1); // Verify the value was stored correctly - let retrieved = array.scalar_at(0).unwrap().as_binary().value().unwrap(); + let retrieved = array.scalar_at(0).unwrap().as_binary().to_value().unwrap(); assert_eq!(retrieved.len(), 8192); assert_eq!(retrieved.as_slice(), &large_value); } diff --git a/vortex-array/src/compute/is_constant.rs b/vortex-array/src/compute/is_constant.rs index 31bf507c907..5df78d83471 100644 --- a/vortex-array/src/compute/is_constant.rs +++ b/vortex-array/src/compute/is_constant.rs @@ -85,7 +85,8 @@ impl ComputeFnVTable for IsConstant { // We try and rely on some easy-to-get stats if let Some(Precision::Exact(value)) = array.statistics().get_as::(Stat::IsConstant) { - return Ok(Scalar::from(Some(value)).into()); + let scalar: Scalar = Some(value).into(); + return Ok(scalar.into()); } let value = is_constant_impl(array, options, kernels)?; @@ -105,7 +106,8 @@ impl ComputeFnVTable for IsConstant { .set(Stat::IsConstant, Precision::Exact(value.into())); } - Ok(Scalar::from(value).into()) + let scalar: Scalar = value.into(); + Ok(scalar.into()) } fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult { @@ -227,7 +229,8 @@ impl Kernel for IsConstantKernelAdapter { return Ok(None); }; let is_constant = V::is_constant(&self.0, array, args.options)?; - Ok(Some(Scalar::from(is_constant).into())) + let scalar: Scalar = is_constant.into(); + Ok(Some(scalar.into())) } } diff --git a/vortex-array/src/compute/is_sorted.rs b/vortex-array/src/compute/is_sorted.rs index 2a8818b5d99..9189fa53b96 100644 --- a/vortex-array/src/compute/is_sorted.rs +++ b/vortex-array/src/compute/is_sorted.rs @@ -69,14 +69,16 @@ impl ComputeFnVTable for IsSorted { // We currently don't support sorting struct arrays. if array.dtype().is_struct() { - return Ok(Scalar::from(Some(false)).into()); + let scalar: Scalar = Some(false).into(); + return Ok(scalar.into()); } let is_sorted = if strict { if let Some(Precision::Exact(value)) = array.statistics().get_as::(Stat::IsStrictSorted) { - return Ok(Scalar::from(Some(value)).into()); + let scalar: Scalar = Some(value).into(); + return Ok(scalar.into()); } let is_strict_sorted = is_sorted_impl(array, kernels, true)?; @@ -95,7 +97,8 @@ impl ComputeFnVTable for IsSorted { } else { if let Some(Precision::Exact(value)) = array.statistics().get_as::(Stat::IsSorted) { - return Ok(Scalar::from(Some(value)).into()); + let scalar: Scalar = Some(value).into(); + return Ok(scalar.into()); } let is_sorted = is_sorted_impl(array, kernels, false)?; @@ -113,7 +116,8 @@ impl ComputeFnVTable for IsSorted { is_sorted }; - Ok(Scalar::from(is_sorted).into()) + let scalar: Scalar = is_sorted.into(); + Ok(scalar.into()) } fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult { @@ -198,7 +202,8 @@ impl Kernel for IsSortedKernelAdapter { V::is_sorted(&self.0, array)? }; - Ok(Some(Scalar::from(is_sorted).into())) + let scalar: Scalar = is_sorted.into(); + Ok(Some(scalar.into())) } } diff --git a/vortex-array/src/compute/min_max.rs b/vortex-array/src/compute/min_max.rs index ced0fcfeadf..931c2880cf7 100644 --- a/vortex-array/src/compute/min_max.rs +++ b/vortex-array/src/compute/min_max.rs @@ -102,13 +102,17 @@ impl ComputeFnVTable for MinMax { array.encoding_id() ); - // Update the stats set with the computed min/max - array - .statistics() - .set(Stat::Min, Precision::Exact(min.value().clone())); - array - .statistics() - .set(Stat::Max, Precision::Exact(max.value().clone())); + // Update the stats set with the computed min/max. + if let Some(min_value) = min.value() { + array + .statistics() + .set(Stat::Min, Precision::Exact(min_value.clone())); + } + if let Some(max_value) = max.value() { + array + .statistics() + .set(Stat::Max, Precision::Exact(max_value.clone())); + } // Return the min/max as a struct scalar Ok(Scalar::struct_(return_dtype, vec![min, max]).into()) diff --git a/vortex-array/src/compute/sum.rs b/vortex-array/src/compute/sum.rs index d1a2f3eff22..c10396540ed 100644 --- a/vortex-array/src/compute/sum.rs +++ b/vortex-array/src/compute/sum.rs @@ -66,7 +66,7 @@ pub fn sum(array: &dyn Array) -> VortexResult { let sum_dtype = Stat::Sum .dtype(array.dtype()) .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?; - let zero = Scalar::zero_value(sum_dtype); + let zero = Scalar::zero_value(&sum_dtype); sum_with_accumulator(array, &zero) } @@ -113,16 +113,17 @@ impl ComputeFnVTable for Sum { ); // Short-circuit using array statistics. - if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) { - // For floats only use stats if accumulator is zero. otherwise we might have numerical stability issues. - match sum_dtype { + if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) { + // For floats only use stats if accumulator is zero. otherwise we might have numerical + // stability issues. + match &sum_dtype { DType::Primitive(p, _) => { - if p.is_float() && accumulator.is_zero() { - return Ok(sum.into()); + if p.is_float() && accumulator.is_null() { + return Ok(sum_scalar.into()); } else if p.is_int() { let sum_from_stat = accumulator .as_primitive() - .checked_add(&sum.as_primitive()) + .checked_add(&sum_scalar.as_primitive()) .map(Scalar::from); return Ok(sum_from_stat .unwrap_or_else(|| Scalar::null(sum_dtype)) @@ -132,7 +133,7 @@ impl ComputeFnVTable for Sum { DType::Decimal(..) => { let sum_from_stat = accumulator .as_decimal() - .checked_binary_numeric(&sum.as_decimal(), NumericOperator::Add) + .checked_binary_numeric(&sum_scalar.as_decimal(), NumericOperator::Add) .map(Scalar::from); return Ok(sum_from_stat .unwrap_or_else(|| Scalar::null(sum_dtype)) @@ -147,30 +148,29 @@ impl ComputeFnVTable for Sum { // Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator. match sum_dtype { DType::Primitive(p, _) => { - if p.is_float() && accumulator.is_zero() { + if p.is_float() + && accumulator.value().is_some_and(|value| value.is_zero()) + && let Some(sum_value) = sum_scalar.value().cloned() + { array .statistics() - .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone())); + .set(Stat::Sum, Precision::Exact(sum_value)); } else if p.is_int() && let Some(less_accumulator) = sum_scalar .as_primitive() .checked_sub(&accumulator.as_primitive()) + && let Some(val) = Scalar::from(less_accumulator).into_value() { - array.statistics().set( - Stat::Sum, - Precision::Exact(Scalar::from(less_accumulator).value().clone()), - ); + array.statistics().set(Stat::Sum, Precision::Exact(val)); } } DType::Decimal(..) => { if let Some(less_accumulator) = sum_scalar .as_decimal() .checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub) + && let Some(val) = Scalar::from(less_accumulator).into_value() { - array.statistics().set( - Stat::Sum, - Precision::Exact(Scalar::from(less_accumulator).value().clone()), - ) + array.statistics().set(Stat::Sum, Precision::Exact(val)); } } _ => unreachable!("Sum will always be a decimal or a primitive dtype"), diff --git a/vortex-array/src/expr/exprs/dynamic.rs b/vortex-array/src/expr/exprs/dynamic.rs index a1717483a0f..7636e79098e 100644 --- a/vortex-array/src/expr/exprs/dynamic.rs +++ b/vortex-array/src/expr/exprs/dynamic.rs @@ -107,10 +107,11 @@ impl VTable for DynamicComparison { let ret_dtype = DType::Bool(args.inputs[0].dtype().nullability() | data.rhs.dtype.nullability()); - Ok( - ConstantArray::new(Scalar::new(ret_dtype, data.default.into()), args.row_count) - .into_array(), + Ok(ConstantArray::new( + Scalar::try_new(ret_dtype, Some(data.default.into()))?, + args.row_count, ) + .into_array()) } fn stat_falsification( @@ -193,7 +194,10 @@ pub struct DynamicComparisonExpr { impl DynamicComparisonExpr { pub fn scalar(&self) -> Option { - (self.rhs.value)().map(|v| Scalar::new(self.rhs.dtype.clone(), v)) + (self.rhs.value)().map(|v| { + Scalar::try_new(self.rhs.dtype.clone(), Some(v)) + .vortex_expect("`DynamicComparisonExpr` was invalid") + }) } } @@ -237,7 +241,9 @@ struct Rhs { impl Rhs { pub fn scalar(&self) -> Option { - (self.value)().map(|v| Scalar::new(self.dtype.clone(), v)) + (self.value)().map(|v| { + Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid") + }) } } @@ -283,7 +289,12 @@ impl DynamicExprUpdates { let exprs = visitor.0.into_boxed_slice(); let prev_versions = exprs .iter() - .map(|expr| (expr.rhs.value)().map(|v| Scalar::new(expr.rhs.dtype.clone(), v))) + .map(|expr| { + (expr.rhs.value)().map(|v| { + Scalar::try_new(expr.rhs.dtype.clone(), Some(v)) + .vortex_expect("`DynamicExprUpdates` was invalid") + }) + }) .collect(); Some(Self { diff --git a/vortex-array/src/expr/exprs/like.rs b/vortex-array/src/expr/exprs/like.rs index 8ac15b9bce4..e4cc85a3314 100644 --- a/vortex-array/src/expr/exprs/like.rs +++ b/vortex-array/src/expr/exprs/like.rs @@ -159,7 +159,7 @@ impl VTable for Like { let src_min = src.stat_min(catalog)?; let src_max = src.stat_max(catalog)?; - match LikeVariant::from_str(&pat_str)? { + match LikeVariant::from_str(pat_str)? { LikeVariant::Exact(text) => { // col LIKE 'exact' ==> col.min > 'exact' || col.max < 'exact' Some(or(gt(src_min, lit(text)), lt(src_max, lit(text)))) diff --git a/vortex-array/src/expr/exprs/literal.rs b/vortex-array/src/expr/exprs/literal.rs index 1465ec23ddd..b610359b699 100644 --- a/vortex-array/src/expr/exprs/literal.rs +++ b/vortex-array/src/expr/exprs/literal.rs @@ -38,7 +38,7 @@ impl VTable for Literal { fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::LiteralOpts { - value: Some(instance.as_ref().into()), + value: Some(instance.into()), } .encode_to_vec(), )) diff --git a/vortex-array/src/expr/stats/precision.rs b/vortex-array/src/expr/stats/precision.rs index 3c2ed9332d3..7f3e13b3c48 100644 --- a/vortex-array/src/expr/stats/precision.rs +++ b/vortex-array/src/expr/stats/precision.rs @@ -6,6 +6,7 @@ use std::fmt::Display; use std::fmt::Formatter; use vortex_dtype::DType; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -20,26 +21,14 @@ use crate::expr::stats::precision::Precision::Inexact; /// This is statistic specific, for max this will be an upper bound. Meaning that the actual max /// in an array is guaranteed to be less than or equal to the inexact value, but equal to the exact /// value. -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum Precision { Exact(T), Inexact(T), } -impl Clone for Precision -where - T: Clone, -{ - fn clone(&self) -> Self { - match self { - Exact(e) => Exact(e.clone()), - Inexact(ie) => Inexact(ie.clone()), - } - } -} - impl Precision> { - /// Transpose the `Option>` into `Option>`. + /// Transpose the `Precision>` into `Option>`. pub fn transpose(self) -> Option> { match self { Exact(Some(x)) => Some(Exact(x)), @@ -167,13 +156,22 @@ impl PartialEq for Precision { } impl Precision { + /// Convert this [`Precision`] into a [`Precision`] with the given + /// [`DType`]. pub fn into_scalar(self, dtype: DType) -> Precision { - self.map(|v| Scalar::new(dtype, v)) + self.map(|v| { + Scalar::try_new(dtype, Some(v)).vortex_expect("`Precision` was invalid") + }) } } impl Precision<&ScalarValue> { + /// Convert this [`Precision<&ScalarValue>`] into a [`Precision`] with the given + /// [`DType`]. pub fn into_scalar(self, dtype: DType) -> Precision { - self.map(|v| Scalar::new(dtype, v.clone())) + self.map(|v| { + Scalar::try_new(dtype, Some(v.clone())) + .vortex_expect("`Precision` was invalid") + }) } } diff --git a/vortex-array/src/serde.rs b/vortex-array/src/serde.rs index be064f9652e..373e174360e 100644 --- a/vortex-array/src/serde.rs +++ b/vortex-array/src/serde.rs @@ -22,7 +22,6 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; use vortex_flatbuffers::FlatBuffer; -use vortex_flatbuffers::ReadFlatBuffer; use vortex_flatbuffers::WriteFlatBuffer; use vortex_flatbuffers::array as fba; use vortex_flatbuffers::array::Compression; @@ -378,7 +377,7 @@ impl ArrayParts { // Populate statistics from the serialized array. if let Some(stats) = self.flatbuffer().stats() { let decoded_statistics = decoded.statistics(); - StatsSet::read_flatbuffer(&stats)? + StatsSet::from_flatbuffer(&stats, dtype)? .into_iter() .for_each(|(stat, val)| decoded_statistics.set(stat, val)); } diff --git a/vortex-array/src/stats/array.rs b/vortex-array/src/stats/array.rs index 9df19542f15..873d5ee635f 100644 --- a/vortex-array/src/stats/array.rs +++ b/vortex-array/src/stats/array.rs @@ -197,8 +197,10 @@ impl StatsSetRef<'_> { pub fn compute_all(&self, stats: &[Stat]) -> VortexResult { let mut stats_set = StatsSet::default(); for &stat in stats { - if let Some(s) = self.compute_stat(stat)? { - stats_set.set(stat, Precision::exact(s.into_value())) + if let Some(s) = self.compute_stat(stat)? + && let Some(value) = s.into_value() + { + stats_set.set(stat, Precision::exact(value)); } } Ok(stats_set) diff --git a/vortex-array/src/stats/flatbuffers.rs b/vortex-array/src/stats/flatbuffers.rs index d7f3b23172b..67f82fcb7c3 100644 --- a/vortex-array/src/stats/flatbuffers.rs +++ b/vortex-array/src/stats/flatbuffers.rs @@ -2,15 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use flatbuffers::FlatBufferBuilder; -use flatbuffers::Follow; use flatbuffers::WIPOffset; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; -use vortex_error::VortexError; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_flatbuffers::ReadFlatBuffer; use vortex_flatbuffers::WriteFlatBuffer; use vortex_flatbuffers::array as fba; use vortex_scalar::ScalarValue; @@ -49,7 +46,11 @@ impl WriteFlatBuffer for StatsSet { } else { fba::Precision::Inexact }, - Some(fbb.create_vector(&min.into_inner().to_protobytes::>())), + Some( + fbb.create_vector(&ScalarValue::to_proto_bytes::>(Some( + &min.into_inner(), + ))), + ), ) }) .unwrap_or_else(|| (fba::Precision::Inexact, None)); @@ -63,7 +64,11 @@ impl WriteFlatBuffer for StatsSet { } else { fba::Precision::Inexact }, - Some(fbb.create_vector(&max.into_inner().to_protobytes::>())), + Some( + fbb.create_vector(&ScalarValue::to_proto_bytes::>(Some( + &max.into_inner(), + ))), + ), ) }) .unwrap_or_else(|| (fba::Precision::Inexact, None)); @@ -71,7 +76,7 @@ impl WriteFlatBuffer for StatsSet { let sum = self .get(Stat::Sum) .and_then(Precision::as_exact) - .map(|sum| fbb.create_vector(&sum.to_protobytes::>())); + .map(|sum| fbb.create_vector(&ScalarValue::to_proto_bytes::>(Some(&sum)))); let stat_args = &fba::ArrayStatsArgs { min, @@ -103,16 +108,17 @@ impl WriteFlatBuffer for StatsSet { } } -impl ReadFlatBuffer for StatsSet { - type Source<'a> = fba::ArrayStats<'a>; - type Error = VortexError; - - fn read_flatbuffer<'buf>( - fb: & as Follow<'buf>>::Inner, - ) -> Result { +impl StatsSet { + /// Creates a [`StatsSet`] from a flatbuffers array [`fba::ArrayStats<'a>`]. + pub fn from_flatbuffer<'a>( + fb: &fba::ArrayStats<'a>, + array_dtype: &DType, + ) -> VortexResult { let mut stats_set = StatsSet::default(); for stat in Stat::all() { + let stat_dtype = stat.dtype(array_dtype); + match stat { Stat::IsConstant => { if let Some(is_constant) = fb.is_constant() { @@ -133,8 +139,14 @@ impl ReadFlatBuffer for StatsSet { } } Stat::Max => { - if let Some(max) = fb.max() { - let value = ScalarValue::from_protobytes(max.bytes())?; + if let Some(max) = fb.max() + && let Some(stat_dtype) = stat_dtype + { + let value = ScalarValue::from_proto_bytes(max.bytes(), &stat_dtype)?; + let Some(value) = value else { + continue; + }; + stats_set.set( Stat::Max, match fb.max_precision() { @@ -146,8 +158,14 @@ impl ReadFlatBuffer for StatsSet { } } Stat::Min => { - if let Some(min) = fb.min() { - let value = ScalarValue::from_protobytes(min.bytes())?; + if let Some(min) = fb.min() + && let Some(stat_dtype) = stat_dtype + { + let value = ScalarValue::from_proto_bytes(min.bytes(), &stat_dtype)?; + let Some(value) = value else { + continue; + }; + stats_set.set( Stat::Min, match fb.min_precision() { @@ -172,11 +190,15 @@ impl ReadFlatBuffer for StatsSet { } } Stat::Sum => { - if let Some(sum) = fb.sum() { - stats_set.set( - Stat::Sum, - Precision::Exact(ScalarValue::from_protobytes(sum.bytes())?), - ); + if let Some(sum) = fb.sum() + && let Some(stat_dtype) = stat_dtype + { + let value = ScalarValue::from_proto_bytes(sum.bytes(), &stat_dtype)?; + let Some(value) = value else { + continue; + }; + + stats_set.set(Stat::Sum, Precision::Exact(value)); } } Stat::NaNCount => { diff --git a/vortex-array/src/stats/stats_set.rs b/vortex-array/src/stats/stats_set.rs index af977e6753c..aa0ed782b90 100644 --- a/vortex-array/src/stats/stats_set.rs +++ b/vortex-array/src/stats/stats_set.rs @@ -132,7 +132,11 @@ impl StatsSet { ) -> Option> { self.get(stat).map(|v| { v.map(|v| { - T::try_from(&Scalar::new(dtype.clone(), v)).unwrap_or_else(|err| { + T::try_from( + &Scalar::try_new(dtype.clone(), Some(v)) + .vortex_expect("failed to construct a scalar statistic"), + ) + .unwrap_or_else(|err| { vortex_panic!( err, "Failed to get stat {} as {}", @@ -225,11 +229,12 @@ impl StatsProvider for TypedStatsSetRef<'_, '_> { fn get(&self, stat: Stat) -> Option> { self.values.get(stat).map(|p| { p.map(|sv| { - Scalar::new( + Scalar::try_new( stat.dtype(self.dtype) .vortex_expect("Must have valid dtype if value is present"), - sv, + Some(sv), ) + .vortex_expect("failed to construct a scalar statistic") }) }) } @@ -260,11 +265,12 @@ impl StatsProvider for MutTypedStatsSetRef<'_, '_> { fn get(&self, stat: Stat) -> Option> { self.values.get(stat).map(|p| { p.map(|sv| { - Scalar::new( + Scalar::try_new( stat.dtype(self.dtype) .vortex_expect("Must have valid dtype if value is present"), - sv, + Some(sv), ) + .vortex_expect("failed to construct a scalar statistic") }) }) } @@ -356,10 +362,22 @@ impl MutTypedStatsSetRef<'_, '_> { vortex_err!("{:?} bounds ({m1:?}, {m2:?}) do not overlap", S::STAT) })?; if meet != m1 { - self.set(S::STAT, meet.into_value().map(Scalar::into_value)); + self.set( + S::STAT, + meet.into_value().map(|s| { + s.into_value() + .vortex_expect("stat scalar value cannot be null") + }), + ); } } - (None, Some(m)) => self.set(S::STAT, m.into_value().map(Scalar::into_value)), + (None, Some(m)) => self.set( + S::STAT, + m.into_value().map(|s| { + s.into_value() + .vortex_expect("stat scalar value cannot be null") + }), + ), (Some(_), _) => (), (None, None) => self.clear(S::STAT), } @@ -400,7 +418,13 @@ impl MutTypedStatsSetRef<'_, '_> { (Some(m1), Some(m2)) => { let meet = m1.union(&m2).vortex_expect("can compare scalar"); if meet != m1 { - self.set(Stat::Min, meet.into_value().map(Scalar::into_value)); + self.set( + Stat::Min, + meet.into_value().map(|s| { + s.into_value() + .vortex_expect("stat scalar value cannot be null") + }), + ); } } _ => self.clear(Stat::Min), @@ -415,7 +439,13 @@ impl MutTypedStatsSetRef<'_, '_> { (Some(m1), Some(m2)) => { let meet = m1.union(&m2).vortex_expect("can compare scalar"); if meet != m1 { - self.set(Stat::Max, meet.into_value().map(Scalar::into_value)); + self.set( + Stat::Max, + meet.into_value().map(|s| { + s.into_value() + .vortex_expect("stat scalar value cannot be null") + }), + ); } } _ => self.clear(Stat::Max), @@ -432,19 +462,7 @@ impl MutTypedStatsSetRef<'_, '_> { if let Some(scalar_value) = m1.zip(m2).as_exact().and_then(|(s1, s2)| { s1.as_primitive() .checked_add(&s2.as_primitive()) - .map(|pscalar| { - pscalar - .pvalue() - .map(|pvalue| { - Scalar::primitive_value( - pvalue, - pscalar.ptype(), - pscalar.dtype().nullability(), - ) - .into_value() - }) - .unwrap_or_else(ScalarValue::null) - }) + .and_then(|pscalar| pscalar.pvalue().map(ScalarValue::Primitive)) }) { self.set(Stat::Sum, Precision::Exact(scalar_value)); } @@ -565,17 +583,18 @@ mod test { let first = iter.next().unwrap().clone(); assert_eq!(first.0, Stat::Max); assert_eq!( - first - .1 - .map(|f| i32::try_from(&Scalar::new(PType::I32.into(), f)).unwrap()), + first.1.map( + |f| i32::try_from(&Scalar::try_new(PType::I32.into(), Some(f)).unwrap()).unwrap() + ), Precision::exact(100) ); let snd = iter.next().unwrap().clone(); assert_eq!(snd.0, Stat::Min); assert_eq!( - snd.1 - .map(|s| i32::try_from(&Scalar::new(PType::I32.into(), s)).unwrap()), - 42 + snd.1.map( + |s| i32::try_from(&Scalar::try_new(PType::I32.into(), Some(s)).unwrap()).unwrap() + ), + Precision::exact(42) ); } @@ -592,14 +611,17 @@ mod test { let (stat, first) = set.next().unwrap(); assert_eq!(stat, Stat::Max); assert_eq!( - first.map(|f| i32::try_from(&Scalar::new(PType::I32.into(), f)).unwrap()), + first.map( + |f| i32::try_from(&Scalar::try_new(PType::I32.into(), Some(f)).unwrap()).unwrap() + ), Precision::exact(100) ); let snd = set.next().unwrap(); assert_eq!(snd.0, Stat::Min); assert_eq!( - snd.1 - .map(|s| i32::try_from(&Scalar::new(PType::I32.into(), s)).unwrap()), + snd.1.map( + |s| i32::try_from(&Scalar::try_new(PType::I32.into(), Some(s)).unwrap()).unwrap() + ), Precision::exact(42) ); } @@ -710,7 +732,9 @@ mod test { #[test] fn merge_into_scalar() { - let first = StatsSet::of(Stat::Sum, Precision::exact(42)).merge_ordered( + // Sum stats for primitive types are always the 64-bit version (i64 for signed, u64 + // for unsigned, f64 for floats). + let first = StatsSet::of(Stat::Sum, Precision::exact(42i64)).merge_ordered( &StatsSet::default(), &DType::Primitive(PType::I32, Nullability::NonNullable), ); @@ -720,8 +744,10 @@ mod test { #[test] fn merge_from_scalar() { + // Sum stats for primitive types are always the 64-bit version (i64 for signed, u64 + // for unsigned, f64 for floats). let first = StatsSet::default().merge_ordered( - &StatsSet::of(Stat::Sum, Precision::exact(42)), + &StatsSet::of(Stat::Sum, Precision::exact(42i64)), &DType::Primitive(PType::I32, Nullability::NonNullable), ); let first_ref = first.as_typed_ref(&DType::Primitive(PType::I32, Nullability::NonNullable)); @@ -730,14 +756,16 @@ mod test { #[test] fn merge_scalars() { - let first = StatsSet::of(Stat::Sum, Precision::exact(37)).merge_ordered( - &StatsSet::of(Stat::Sum, Precision::exact(42)), + // Sum stats for primitive types are always the 64-bit version (i64 for signed, u64 + // for unsigned, f64 for floats). + let first = StatsSet::of(Stat::Sum, Precision::exact(37i64)).merge_ordered( + &StatsSet::of(Stat::Sum, Precision::exact(42i64)), &DType::Primitive(PType::I32, Nullability::NonNullable), ); let first_ref = first.as_typed_ref(&DType::Primitive(PType::I32, Nullability::NonNullable)); assert_eq!( - first_ref.get_as::(Stat::Sum), - Some(Precision::exact(79usize)) + first_ref.get_as::(Stat::Sum), + Some(Precision::exact(79i64)) ); } diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index 1ee704ff777..7a3e248dadf 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -122,7 +122,7 @@ impl PrimitiveTyped<'_> { .scalar_at(idx)? .as_primitive() .pvalue() - .unwrap_or_else(|| PValue::zero(self.ptype()))) + .unwrap_or_else(|| PValue::zero(&self.ptype()))) } } diff --git a/vortex-btrblocks/src/compressor/decimal.rs b/vortex-btrblocks/src/compressor/decimal.rs index 5170405d10c..4a3f6e5475a 100644 --- a/vortex-btrblocks/src/compressor/decimal.rs +++ b/vortex-btrblocks/src/compressor/decimal.rs @@ -8,8 +8,8 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::narrowed_decimal; use vortex_array::vtable::ValidityHelper; use vortex_decimal_byte_parts::DecimalBytePartsArray; +use vortex_dtype::DecimalType; use vortex_error::VortexResult; -use vortex_scalar::DecimalType; use crate::BtrBlocksCompressor; use crate::CanonicalCompressor; diff --git a/vortex-buffer/src/string.rs b/vortex-buffer/src/string.rs index 182b2fe6989..bcaf8d7c3d2 100644 --- a/vortex-buffer/src/string.rs +++ b/vortex-buffer/src/string.rs @@ -84,10 +84,29 @@ impl TryFrom for BufferString { let err = simdutf8::compat::from_utf8(value.as_ref()).unwrap_err(); vortex_err!("invalid utf-8: {err}") })?; + Ok(Self(value)) } } +impl TryFrom<&[u8]> for BufferString { + type Error = VortexError; + + fn try_from(value: &[u8]) -> Result { + simdutf8::basic::from_utf8(value).map_err(|_| { + #[expect( + clippy::unwrap_used, + reason = "unwrap is intentional - the error was already detected" + )] + // run validation using `compat` package to get more detailed error message + let err = simdutf8::compat::from_utf8(value).unwrap_err(); + vortex_err!("invalid utf-8: {err}") + })?; + + Ok(Self(ByteBuffer::from(value.to_vec()))) + } +} + impl Deref for BufferString { type Target = str; diff --git a/vortex-datafusion/src/convert/scalars.rs b/vortex-datafusion/src/convert/scalars.rs index 5031614b0b7..4606dfdb90d 100644 --- a/vortex-datafusion/src/convert/scalars.rs +++ b/vortex-datafusion/src/convert/scalars.rs @@ -13,11 +13,12 @@ use vortex::dtype::datetime::AnyTemporal; use vortex::dtype::datetime::TemporalMetadata; use vortex::dtype::datetime::TimeUnit; use vortex::dtype::half::f16; +use vortex::dtype::i256; +use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_bail; use vortex::scalar::DecimalValue; use vortex::scalar::Scalar; -use vortex::scalar::i256; use crate::convert::FromDataFusion; use crate::convert::TryToDataFusion; @@ -101,12 +102,12 @@ impl TryToDataFusion for Scalar { } } // SAFETY: By construction Utf8 scalar values are utf8 - DType::Utf8(_) => ScalarValue::Utf8(self.as_utf8().value().map(|s| unsafe { + DType::Utf8(_) => ScalarValue::Utf8(self.as_utf8().value().cloned().map(|s| unsafe { String::from_utf8_unchecked(Vec::::from(s.into_inner().into_inner())) })), DType::Binary(_) => ScalarValue::Binary( self.as_binary() - .value() + .to_value() .map(|b| Vec::::from(b.into_inner())), ), DType::Struct(..) => todo!("struct scalar conversion"), @@ -217,11 +218,8 @@ impl FromDataFusion for Scalar { | ScalarValue::Time32Second(v) | ScalarValue::Time32Millisecond(v) => { let dtype = DType::from_arrow((&value.data_type(), Nullability::Nullable)); - Scalar::new( - dtype, - v.map(vortex::scalar::ScalarValue::from) - .unwrap_or_else(vortex::scalar::ScalarValue::null), - ) + Scalar::try_new(dtype, v.map(vortex::scalar::ScalarValue::from)) + .vortex_expect("unable to create a time `Scalar`") } ScalarValue::Date64(v) | ScalarValue::Time64Microsecond(v) @@ -231,11 +229,8 @@ impl FromDataFusion for Scalar { | ScalarValue::TimestampMicrosecond(v, _) | ScalarValue::TimestampNanosecond(v, _) => { let dtype = DType::from_arrow((&value.data_type(), Nullability::Nullable)); - Scalar::new( - dtype, - v.map(vortex::scalar::ScalarValue::from) - .unwrap_or_else(vortex::scalar::ScalarValue::null), - ) + Scalar::try_new(dtype, v.map(vortex::scalar::ScalarValue::from)) + .vortex_expect("unable to create a time `Scalar`") } ScalarValue::Decimal32(decimal, precision, scale) => { let decimal_dtype = DecimalDType::new(*precision, *scale); @@ -305,9 +300,9 @@ mod tests { use vortex::dtype::DecimalDType; use vortex::dtype::Nullability; use vortex::dtype::PType; + use vortex::dtype::i256; use vortex::scalar::DecimalValue; use vortex::scalar::Scalar; - use vortex::scalar::i256; use super::*; @@ -684,7 +679,7 @@ mod tests { #[case::fixed_size_binary(ScalarValue::FixedSizeBinary(5, Some(vec![1u8, 2, 3, 4, 5])))] fn test_binary_variants(#[case] variant: ScalarValue) { let result = Scalar::from_df(&variant); - let result_bytes: Vec = result.as_binary().value().unwrap().into_inner().into(); + let result_bytes: Vec = result.as_binary().to_value().unwrap().into_inner().into(); assert_eq!(result_bytes, vec![1u8, 2, 3, 4, 5]); } } diff --git a/vortex-datafusion/src/persistent/cache.rs b/vortex-datafusion/src/persistent/cache.rs index b28cf972c2b..567b8c8869b 100644 --- a/vortex-datafusion/src/persistent/cache.rs +++ b/vortex-datafusion/src/persistent/cache.rs @@ -58,6 +58,7 @@ fn estimate_footer_size(footer: &Footer) -> usize { .statistics() .map(|stats| { stats + .stats() .iter() .map(|s| { s.iter().count() * (size_of::() + size_of::>()) diff --git a/vortex-datafusion/src/persistent/format.rs b/vortex-datafusion/src/persistent/format.rs index 51c7667561d..a447ee9092b 100644 --- a/vortex-datafusion/src/persistent/format.rs +++ b/vortex-datafusion/src/persistent/format.rs @@ -39,11 +39,9 @@ use futures::FutureExt; use futures::StreamExt as _; use futures::TryStreamExt as _; use futures::stream; -use itertools::Itertools; use object_store::ObjectMeta; use object_store::ObjectStore; use vortex::VortexSessionDefault; -use vortex::array::stats::StatsSet; use vortex::dtype::DType; use vortex::dtype::Nullability; use vortex::dtype::PType; @@ -367,81 +365,86 @@ impl FileFormat for VortexFormat { }); }; - let stats = table_schema - .fields() - .iter() - .map(|field| struct_dtype.find(field.name())) - .map(|idx| match idx { - None => StatsSet::default(), - Some(id) => file_stats[id].clone(), - }) - .collect_vec(); - - let total_byte_size = stats - .iter() - .map(|stats_set| { - stats_set - .get_as::(Stat::UncompressedSizeInBytes, &PType::U64.into()) - .unwrap_or_else(|| stats::Precision::inexact(0_usize)) - }) - .fold(stats::Precision::exact(0_usize), |acc, stats_set| { - acc.zip(stats_set).map(|(acc, stats_set)| acc + stats_set) - }); - - // Sum up the total byte size across all the columns. - let total_byte_size = total_byte_size.to_df(); - - let column_statistics = stats - .into_iter() - .zip(table_schema.fields().iter()) - .map(|(stats_set, field)| { - let null_count = stats_set.get_as::(Stat::NullCount, &PType::U64.into()); - let min = stats_set.get(Stat::Min).and_then(|n| { - n.map(|n| { - Scalar::new( + let mut sum_of_column_byte_sizes = stats::Precision::exact(0_usize); + let mut column_statistics = Vec::with_capacity(table_schema.fields().len()); + + for field in table_schema.fields().iter() { + // TODO(connor): Is this actually true? + // If the column does not exist, continue. This can happen if the schema has evolved + // but we have not yet updated the Vortex file. + let Some(col_idx) = struct_dtype.find(field.name()) else { + // The default sets all statistics to `Precision`. + column_statistics.push(ColumnStatistics::default()); + continue; + }; + let (stats_set, stats_dtype) = file_stats.get(col_idx); + + // Update the total size in bytes. + let column_size = stats_set + .get_as::(Stat::UncompressedSizeInBytes, &PType::U64.into()) + .unwrap_or_else(|| stats::Precision::inexact(0_usize)); + sum_of_column_byte_sizes = sum_of_column_byte_sizes + .zip(column_size) + .map(|(acc, size)| acc + size); + + // Find the min statistic. + let min = stats_set.get(Stat::Min).and_then(|pstat_val| { + pstat_val + .map(|stat_val| { + // Because of DataFusion's Schema evolution, it is possible that the + // type of the min/max stat has changed. Thus we construct the stat as + // the file datatype first and only then do we cast accordingly. + Scalar::try_new( Stat::Min - .dtype(&DType::from_arrow(field.as_ref())) + .dtype(stats_dtype) .vortex_expect("must have a valid dtype"), - n, + Some(stat_val), ) + .vortex_expect("`Stat::Min` somehow had an incompatible `DType`") + .cast(&DType::from_arrow(field.as_ref())) + .vortex_expect("Unable to cast to target type that DataFusion wants") .try_to_df() .ok() }) .transpose() - }); + }); - let max = stats_set.get(Stat::Max).and_then(|n| { - n.map(|n| { - Scalar::new( + // Find the max statistic. + let max = stats_set.get(Stat::Max).and_then(|pstat_val| { + pstat_val + .map(|stat_val| { + Scalar::try_new( Stat::Max - .dtype(&DType::from_arrow(field.as_ref())) + .dtype(stats_dtype) .vortex_expect("must have a valid dtype"), - n, + Some(stat_val), ) + .vortex_expect("`Stat::Max` somehow had an incompatible `DType`") + .cast(&DType::from_arrow(field.as_ref())) + .vortex_expect("Unable to cast to target type that DataFusion wants") .try_to_df() .ok() }) .transpose() - }); - - ColumnStatistics { - null_count: null_count.to_df(), - max_value: max.to_df(), - min_value: min.to_df(), - sum_value: Precision::Absent, - distinct_count: stats_set - .get_as::( - Stat::IsConstant, - &DType::Bool(Nullability::NonNullable), - ) - .and_then(|is_constant| { - is_constant.as_exact().map(|_| Precision::Exact(1)) - }) - .unwrap_or(Precision::Absent), - byte_size: Precision::Absent, - } + }); + + let null_count = stats_set.get_as::(Stat::NullCount, &PType::U64.into()); + + column_statistics.push(ColumnStatistics { + null_count: null_count.to_df(), + min_value: min.to_df(), + max_value: max.to_df(), + sum_value: Precision::Absent, + distinct_count: stats_set + .get_as::(Stat::IsConstant, &DType::Bool(Nullability::NonNullable)) + .and_then(|is_constant| is_constant.as_exact().map(|_| Precision::Exact(1))) + .unwrap_or(Precision::Absent), + // TODO(connor): Is this correct? + byte_size: column_size.to_df(), }) - .collect::>(); + } + + let total_byte_size = sum_of_column_byte_sizes.to_df(); Ok(Statistics { num_rows: Precision::Exact( diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index 07a54a2bcdd..7120c7a5a37 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -252,7 +252,7 @@ impl DType { if let Primitive(ptype, _) = self { *ptype } else { - vortex_panic!("DType is not a primitive type") + vortex_panic!("DType {self} is not a primitive type") } } diff --git a/vortex-duckdb/src/convert/scalar.rs b/vortex-duckdb/src/convert/scalar.rs index ac82f21ffd4..78f43dab13a 100644 --- a/vortex-duckdb/src/convert/scalar.rs +++ b/vortex-duckdb/src/convert/scalar.rs @@ -156,7 +156,7 @@ impl ToDuckDBScalar for Utf8Scalar<'_> { impl ToDuckDBScalar for BinaryScalar<'_> { /// Converts a binary scalar to a DuckDB BLOB value. fn try_to_duckdb_scalar(&self) -> VortexResult { - Ok(match self.value() { + Ok(match self.value_ref() { Some(value) => Value::from(value.as_slice()), None => Value::null(&LogicalType::blob()), }) @@ -261,39 +261,57 @@ impl<'a> TryFrom> for Scalar { ExtractedValue::Blob(b) => Ok(Scalar::binary(b, Nullable)), ExtractedValue::Date(days) => Ok(Scalar::extension::( TimeUnit::Days, - Scalar::new(DType::Primitive(I32, Nullable), ScalarValue::from(days)), + Scalar::try_new( + DType::Primitive(I32, Nullable), + Some(ScalarValue::from(days)), + )?, )), ExtractedValue::Time(micros) => Ok(Scalar::extension::