From c2fc0e35135950700184382037cce0ad2cbf83d4 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 4 Feb 2026 10:04:29 -0500 Subject: [PATCH 01/22] change scalar and scalar value Signed-off-by: Connor Tsui --- vortex-scalar/src/scalar.rs | 750 ++++++++++-------------------- vortex-scalar/src/scalar_value.rs | 346 ++++---------- 2 files changed, 342 insertions(+), 754 deletions(-) diff --git a/vortex-scalar/src/scalar.rs b/vortex-scalar/src/scalar.rs index ec9d0ca4580..6c24cc80995 100644 --- a/vortex-scalar/src/scalar.rs +++ b/vortex-scalar/src/scalar.rs @@ -2,551 +2,335 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::cmp::Ordering; -use std::hash::Hash; -use std::sync::Arc; -use vortex_buffer::Buffer; use vortex_dtype::DType; -use vortex_dtype::NativeDType; -use vortex_dtype::NativeDecimalType; -use vortex_dtype::Nullability; -use vortex_dtype::i256; -use vortex_error::VortexError; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; - -use super::*; - -/// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`]. -/// -/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar wrappers -/// for example [`BoolScalar`], [`PrimitiveScalar`], etc. -/// -/// Note that [`PartialOrd`] is implemented only for an exact match of the scalar's dtype, -/// including nullability. When the DType does match, ordering is nulls first (lowest), then the -/// natural ordering of the scalar value. -#[derive(Debug, Clone)] +use vortex_error::vortex_ensure; + +use crate::BinaryScalar; +use crate::BoolScalar; +use crate::DecimalScalar; +use crate::FixedSizeListScalar; +use crate::ListScalar; +use crate::PrimitiveScalar; +use crate::ScalarValue; +use crate::StructScalar; +use crate::Utf8Scalar; +use crate::extension::ExtensionScalar; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Scalar { /// The type of the scalar. 
dtype: DType, - /// The value of the scalar. + /// The value of the scalar. This is `None` if the value is null, otherwise it is `Some`. /// - /// Invariant: If the `dtype` is non-nullable, then this value _cannot_ be equal to - /// [`ScalarValue::null()`](ScalarValue::null). - value: ScalarValue, + /// Invariant: If the `dtype` is non-nullable, then this value _cannot_ be `None`. + value: Option, } impl Scalar { - /// Creates a new scalar with the given data type and value. - pub fn new(dtype: DType, value: ScalarValue) -> Self { - if !dtype.is_nullable() { - assert!( - !value.is_null(), - "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}", - ); - } - + /// Create a new Scalar with the given DType and value without checking compatibility. + /// + /// # Safety + /// + /// The caller must ensure that the given DType and value are compatible per the rules defined + /// in `is_compatible`. + pub unsafe fn new_unchecked(dtype: DType, value: Option) -> Self { Self { dtype, value } } - /// Returns a reference to the scalar's data type. - #[inline] - pub fn dtype(&self) -> &DType { - &self.dtype - } - - /// Returns a reference to the scalar's underlying value. - #[inline] - pub fn value(&self) -> &ScalarValue { - &self.value + /// Create a new Scalar with the given DType and value. + pub fn try_new(dtype: DType, value: Option) -> VortexResult { + vortex_ensure!( + is_compatible(&dtype, value.as_ref()), + "Incompatible dtype {} with value {}", + dtype, + value.map(|v| format!("{}", v)).unwrap_or_default() + ); + Ok(Self { dtype, value }) } - /// Consumes the scalar and returns its data type and value as a tuple. - #[inline] - pub fn into_parts(self) -> (DType, ScalarValue) { + /// Returns the parts of the Scalar. + pub fn into_parts(self) -> (DType, Option) { (self.dtype, self.value) } - /// Consumes the scalar and returns its underlying [`DType`]. 
- #[inline] - pub fn into_dtype(self) -> DType { - self.dtype - } - - /// Consumes the scalar and returns its underlying [`ScalarValue`]. - #[inline] - pub fn into_value(self) -> ScalarValue { - self.value - } - - /// Returns true if the scalar is not null. - pub fn is_valid(&self) -> bool { - !self.value.is_null() + /// Returns the DType of the Scalar. + pub fn dtype(&self) -> &DType { + &self.dtype } - /// Returns true if the scalar is null. + /// Returns true if the Scalar is null. pub fn is_null(&self) -> bool { - self.value.is_null() + self.value.is_none() } - /// Creates a null scalar with the given nullable data type. - /// - /// # Panics - /// - /// Panics if the data type is not nullable. - pub fn null(dtype: DType) -> Self { - assert!( - dtype.is_nullable(), - "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}" - ); - - Self { - dtype, - value: ScalarValue(InnerScalarValue::Null), - } + /// Returns the scalar value. + pub fn value(&self) -> Option<&ScalarValue> { + self.value.as_ref() } - /// Creates a null scalar for the given scalar type. - /// - /// The resulting scalar will have a nullable version of the type's data type. - pub fn null_typed() -> Self { - Self { - dtype: T::dtype().as_nullable(), - value: ScalarValue(InnerScalarValue::Null), - } - } - - /// Casts the scalar to the target data type. - /// - /// Returns an error if the cast is not supported or if the value cannot be represented - /// in the target type. - pub fn cast(&self, target: &DType) -> VortexResult { - if let DType::Extension(ext_dtype) = target { - let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?; - Ok(Scalar::extension_ref(ext_dtype.clone(), storage_scalar)) - } else { - self.cast_to_non_extension(target) - } + /// Returns the scalar value, consuming the Scalar. 
+ pub fn into_value(self) -> Option { + self.value } +} - fn cast_to_non_extension(&self, target: &DType) -> VortexResult { - assert!( - !matches!(target, DType::Extension(..)), - "cast_to_non_extension must not be called with an Extension dtype (got {target})", - ); - - if self.is_null() { - if target.is_nullable() { - return Ok(Scalar::new(target.clone(), self.value.clone())); +/// Check if the given ScalarValue is compatible with the given DType. +fn is_compatible(dtype: &DType, value: Option<&ScalarValue>) -> bool { + let Some(value) = value else { + return dtype.is_nullable(); + }; + + match dtype { + DType::Null => false, + DType::Bool(_) => matches!(value, ScalarValue::Bool(_)), + DType::Primitive(ptype, _) => { + if let ScalarValue::Primitive(pvalue) = value { + pvalue.ptype() == *ptype + } else { + false } - - vortex_bail!("Cannot cast null to {target}: target type is non-nullable") } - - match &self.dtype { - DType::Null => unreachable!(), // Handled by `if self.is_null()` case. - DType::Bool(_) => self.as_bool().cast(target), - DType::Primitive(..) => self.as_primitive().cast(target), - DType::Decimal(..) => self.as_decimal().cast(target), - DType::Utf8(_) => self.as_utf8().cast(target), - DType::Binary(_) => self.as_binary().cast(target), - DType::Struct(..) => self.as_struct().cast(target), - DType::List(..) | DType::FixedSizeList(..) => self.as_list().cast(target), - DType::Extension(..) => self.as_extension().cast(target), - } - } - - /// Converts the scalar to have a nullable version of its data type. - pub fn into_nullable(self) -> Self { - Self { - dtype: self.dtype.as_nullable(), - value: self.value, - } - } - - /// Returns the size of the scalar in bytes, uncompressed. 
- pub fn nbytes(&self) -> usize { - match self.dtype() { - DType::Null => 0, - DType::Bool(_) => 1, - DType::Primitive(ptype, _) => ptype.byte_width(), - DType::Decimal(dt, _) => { - if dt.precision() <= i128::MAX_PRECISION { - size_of::() - } else { - size_of::() - } + DType::Decimal(dec_dtype, _) => { + if let ScalarValue::Decimal(dvalue) = value { + dvalue + .fits_in_precision(*dec_dtype) + // FIXME(ngates): why the option? + .vortex_expect("Failed to check decimal precision compatibility") + } else { + false } - DType::Binary(_) | DType::Utf8(_) => self - .value() - .as_buffer() - .ok() - .flatten() - .map_or(0, |s| s.len()), - DType::Struct(..) => self - .as_struct() - .fields() - .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::()) - .unwrap_or_default(), - DType::List(..) | DType::FixedSizeList(..) => self - .as_list() - .elements() - .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::()) - .unwrap_or_default(), - DType::Extension(_) => self.as_extension().storage().nbytes(), } - } - - /// Creates a "zero"-value scalar value for the given data type. - /// - /// For nullable types the zero value is the underlying `DType`'s zero value. - /// - /// # Zero Values - /// - /// Here is the list of zero values for each [`DType`] (when the [`DType`] is non-nullable): - /// - `Bool`: `false` - /// - `Primitive`: `0` - /// - `Decimal`: `0` - /// - `Utf8`: `""` - /// - `Binary`: An empty buffer - /// - `List`: An empty list - /// - `FixedSizeList`: A list (with correct size) of zero values, which is determined by the - /// element [`DType`] - /// - `Struct`: A struct where each field has a zero value, which is determined by the field - /// [`DType`] - /// - `Extension`: The zero value of the storage [`DType`] - /// - /// This is similar to `default_value` except in its handling of nullability. 
- pub fn zero_value(dtype: DType) -> Self { - match dtype { - DType::Null => Self::null(dtype), - DType::Bool(nullability) => Self::bool(false, nullability), - DType::Primitive(pt, nullability) => { - Self::primitive_value(PValue::zero(pt), pt, nullability) - } - DType::Decimal(dt, nullability) => { - Self::decimal(DecimalValue::from(0i8), dt, nullability) - } - DType::Utf8(nullability) => Self::utf8("", nullability), - DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability), - DType::List(edt, nullability) => Self::list(edt, vec![], nullability), - DType::FixedSizeList(edt, size, nullability) => { - let elements = (0..size) - .map(|_| Scalar::zero_value(edt.as_ref().clone())) - .collect(); - Self::fixed_size_list(edt, elements, nullability) - } - DType::Struct(sf, nullability) => { - let fields: Vec<_> = sf.fields().map(Scalar::zero_value).collect(); - Self::struct_(DType::Struct(sf, nullability), fields) - } - DType::Extension(dt) => { - let scalar = Self::zero_value(dt.storage_dtype().clone()); - Self::extension_ref(dt, scalar) + DType::Utf8(_) => matches!(value, ScalarValue::Utf8(_)), + DType::Binary(_) => matches!(value, ScalarValue::Binary(_)), + DType::List(elem_dtype, _) => { + if let ScalarValue::List(elements) = value { + elements + .iter() + .all(|element| is_compatible(elem_dtype.as_ref(), element.as_ref())) + } else { + false } } - } - - /// Returns true if the scalar is a zero value i.e., equal to a scalar returned from the ` zero_value ` method. - pub fn is_zero(&self) -> bool { - match self.dtype() { - DType::Null => true, - DType::Bool(_) => self.as_bool().value() == Some(false), - DType::Primitive(pt, _) => self.as_primitive().pvalue() == Some(PValue::zero(*pt)), - DType::Decimal(..) 
=> { - self.as_decimal().decimal_value() == Some(DecimalValue::from(0i8)) + DType::FixedSizeList(elem_dtype, size, _) => { + if let ScalarValue::List(elements) = value { + if elements.len() != *size as usize { + return false; + } + elements + .iter() + .all(|element| is_compatible(elem_dtype.as_ref(), element.as_ref())) + } else { + false } - DType::Utf8(_) => self - .as_utf8() - .value() - .map(|v| v.is_empty()) - .unwrap_or(false), - DType::Binary(_) => self - .as_binary() - .value() - .map(|v| v.is_empty()) - .unwrap_or(false), - DType::Struct(..) => self - .as_struct() - .fields() - .map(|mut sf| sf.all(|f| f.is_zero())) - .unwrap_or(false), - DType::List(..) => self - .as_list() - .elements() - .map(|vals| vals.is_empty()) - .unwrap_or(false), - DType::FixedSizeList(..) => self - .as_list() - .elements() - .map(|vals| vals.iter().all(|f| f.is_zero())) - .unwrap_or(false), - DType::Extension(..) => self.as_extension().storage().is_zero(), - } - } - - /// Creates a "default" scalar value for the given data type. - /// - /// For nullable types, returns null. For non-nullable types, returns an appropriate zero/empty - /// value. 
- /// - /// # Default Values - /// - /// Here is the list of default values for each [`DType`] (when the [`DType`] is non-nullable): - /// - /// - `Null`: `null` - /// - `Bool`: `false` - /// - `Primitive`: `0` - /// - `Decimal`: `0` - /// - `Utf8`: `""` - /// - `Binary`: An empty buffer - /// - `List`: An empty list - /// - `FixedSizeList`: A list (with correct size) of default values, which is determined by the - /// element [`DType`] - /// - `Struct`: A struct where each field has a default value, which is determined by the field - /// [`DType`] - /// - `Extension`: The default value of the storage [`DType`] - pub fn default_value(dtype: DType) -> Self { - if dtype.is_nullable() { - return Self::null(dtype); } - - match dtype { - DType::Null => Self::null(dtype), - DType::Bool(nullability) => Self::bool(false, nullability), - DType::Primitive(pt, nullability) => { - Self::primitive_value(PValue::zero(pt), pt, nullability) - } - DType::Decimal(dt, nullability) => { - Self::decimal(DecimalValue::from(0i8), dt, nullability) - } - DType::Utf8(nullability) => Self::utf8("", nullability), - DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability), - DType::List(edt, nullability) => Self::list(edt, vec![], nullability), - DType::FixedSizeList(edt, size, nullability) => { - let elements = (0..size) - .map(|_| Scalar::default_value(edt.as_ref().clone())) - .collect(); - Self::fixed_size_list(edt, elements, nullability) - } - DType::Struct(sf, nullability) => { - let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect(); - Self::struct_(DType::Struct(sf, nullability), fields) - } - DType::Extension(dt) => { - let scalar = Self::default_value(dt.storage_dtype().clone()); - Self::extension_ref(dt, scalar) + DType::Struct(fields, _) => { + if let ScalarValue::List(values) = value { + if values.len() != fields.nfields() { + return false; + } + for (field, field_value) in fields.fields().zip(values.iter()) { + if !is_compatible(&field, 
field_value.as_ref()) { + return false; + } + } + true + } else { + false } - } + } // DType::Extension(ext_dtype) => match value { + // ScalarValue::Extension(ext_scalar) => ext_scalar.id() == ext_dtype.id(), + // _ => false, + // }, } } -/// This implementation block contains only `TryFrom` and `From` wrappers (`as_something`). +/// Scalar downcasting methods impl Scalar { - /// Returns a view of the scalar as a boolean scalar. - /// - /// # Panics - /// - /// Panics if the scalar is not a boolean type. + /// Converts the Scalar into a BoolScalar, panicking if the conversion fails. pub fn as_bool(&self) -> BoolScalar<'_> { - BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool") + self.as_bool_opt() + .vortex_expect("Scalar is not a BoolScalar") } - /// Returns a view of the scalar as a boolean scalar if it has a boolean type. + /// Attempts to convert the Scalar into a BoolScalar. pub fn as_bool_opt(&self) -> Option> { - matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool()) + let DType::Bool(n) = &self.dtype else { + return None; + }; + Some(BoolScalar { + nullability: *n, + value: match &self.value { + None => None, + Some(ScalarValue::Bool(b)) => Some(*b), + _ => unreachable!(), + }, + _marker: Default::default(), + }) } - /// Returns a view of the scalar as a primitive scalar. - /// - /// # Panics - /// - /// Panics if the scalar is not a primitive type. pub fn as_primitive(&self) -> PrimitiveScalar<'_> { - PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive") + self.as_primitive_opt() + .vortex_expect("Scalar is not a PrimitiveScalar") } - /// Returns a view of the scalar as a primitive scalar if it has a primitive type. 
pub fn as_primitive_opt(&self) -> Option> { - matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive()) + let DType::Primitive(ptype, n) = &self.dtype else { + return None; + }; + Some(PrimitiveScalar { + ptype: *ptype, + nullability: *n, + pvalue: match &self.value { + None => None, + Some(ScalarValue::Primitive(p)) => Some(p), + _ => unreachable!(), + }, + }) } - /// Returns a view of the scalar as a decimal scalar. - /// - /// # Panics - /// - /// Panics if the scalar is not a decimal type. pub fn as_decimal(&self) -> DecimalScalar<'_> { - DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal") + self.as_decimal_opt() + .vortex_expect("Scalar is not a DecimalScalar") } - /// Returns a view of the scalar as a decimal scalar if it has a decimal type. pub fn as_decimal_opt(&self) -> Option> { - matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal()) + let DType::Decimal(dec_dtype, n) = &self.dtype else { + return None; + }; + Some(DecimalScalar { + decimal_type: dec_dtype, + nullability: *n, + dvalue: match &self.value { + None => None, + Some(ScalarValue::Decimal(d)) => Some(d), + _ => unreachable!(), + }, + }) } - /// Returns a view of the scalar as a UTF-8 string scalar. - /// - /// # Panics - /// - /// Panics if the scalar is not a UTF-8 type. pub fn as_utf8(&self) -> Utf8Scalar<'_> { - Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8") + self.as_utf8_opt() + .vortex_expect("Scalar is not a Utf8Scalar") } - /// Returns a view of the scalar as a UTF-8 string scalar if it has a UTF-8 type. pub fn as_utf8_opt(&self) -> Option> { - matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8()) + let DType::Utf8(n) = &self.dtype else { + return None; + }; + Some(Utf8Scalar { + nullability: *n, + value: match &self.value { + None => None, + Some(ScalarValue::Utf8(b)) => Some(b), + _ => unreachable!(), + }, + }) } - /// Returns a view of the scalar as a binary scalar. 
- /// - /// # Panics - /// - /// Panics if the scalar is not a binary type. pub fn as_binary(&self) -> BinaryScalar<'_> { - BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary") + self.as_binary_opt() + .vortex_expect("Scalar is not a BinaryScalar") } - /// Returns a view of the scalar as a binary scalar if it has a binary type. pub fn as_binary_opt(&self) -> Option> { - matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary()) - } - - /// Returns a view of the scalar as a struct scalar. - /// - /// # Panics - /// - /// Panics if the scalar is not a struct type. - pub fn as_struct(&self) -> StructScalar<'_> { - StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct") - } - - /// Returns a view of the scalar as a struct scalar if it has a struct type. - pub fn as_struct_opt(&self) -> Option> { - matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct()) + let DType::Binary(n) = &self.dtype else { + return None; + }; + Some(BinaryScalar { + nullability: *n, + value: match &self.value { + None => None, + Some(ScalarValue::Binary(b)) => Some(b), + _ => unreachable!(), + }, + }) } - /// Returns a view of the scalar as a list scalar. - /// - /// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and - /// [`DType::FixedSizeList`]. - /// - /// # Panics - /// - /// Panics if the scalar is not a list type. pub fn as_list(&self) -> ListScalar<'_> { - ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list") + self.as_list_opt() + .vortex_expect("Scalar is not a ListScalar") } - /// Returns a view of the scalar as a list scalar if it has a list type. - /// - /// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and - /// [`DType::FixedSizeList`]. pub fn as_list_opt(&self) -> Option> { - matches!(self.dtype, DType::List(..) | DType::FixedSizeList(..)).then(|| self.as_list()) - } - - /// Returns a view of the scalar as an extension scalar. 
- /// - /// # Panics - /// - /// Panics if the scalar is not an extension type. - pub fn as_extension(&self) -> ExtScalar<'_> { - ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension") + let DType::List(element_dtype, n) = &self.dtype else { + return None; + }; + Some(ListScalar { + element_dtype, + nullability: *n, + elements: match &self.value { + None => None, + Some(ScalarValue::List(e)) => Some(e.as_slice()), + _ => unreachable!(), + }, + }) } - /// Returns a view of the scalar as an extension scalar if it has an extension type. - pub fn as_extension_opt(&self) -> Option> { - matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension()) + pub fn as_fixed_size_list(&self) -> FixedSizeListScalar<'_> { + self.as_fixed_size_list_opt() + .vortex_expect("Scalar is not a FixedSizeListScalar") } -} -/// It is common to represent a nullable type `T` as an `Option`, so we implement a blanket -/// implementation for all `Option` to simply be a nullable `T`. -impl From> for Scalar -where - T: NativeDType, - Scalar: From, -{ - /// A blanket implementation for all `Option`. - fn from(value: Option) -> Self { - value - .map(Scalar::from) - .map(|x| x.into_nullable()) - .unwrap_or_else(|| Scalar { - dtype: T::dtype().as_nullable(), - value: ScalarValue(InnerScalarValue::Null), - }) + pub fn as_fixed_size_list_opt(&self) -> Option> { + let DType::FixedSizeList(element_dtype, element_size, n) = &self.dtype else { + return None; + }; + Some(FixedSizeListScalar { + list_size: *element_size, + element_dtype, + nullability: *n, + elements: match &self.value { + None => None, + Some(ScalarValue::List(e)) => Some(e.as_slice()), + _ => unreachable!(), + }, + }) } -} -impl From> for Scalar -where - T: NativeDType, - Scalar: From, -{ - /// Converts a vector into a `Scalar` (where the value is a `ListScalar`). 
- fn from(vec: Vec) -> Self { - Scalar { - dtype: DType::List(Arc::from(T::dtype()), Nullability::NonNullable), - value: ScalarValue::from(vec), - } + pub fn as_struct(&self) -> StructScalar<'_> { + self.as_struct_opt() + .vortex_expect("Scalar is not a StructScalar") } -} - -impl TryFrom for Vec -where - T: for<'b> TryFrom<&'b Scalar, Error = VortexError>, -{ - type Error = VortexError; - fn try_from(value: Scalar) -> Result { - Vec::try_from(&value) + pub fn as_struct_opt(&self) -> Option> { + let DType::Struct(fields, n) = &self.dtype else { + return None; + }; + Some(StructScalar { + fields, + nullability: *n, + values: match &self.value { + None => None, + Some(ScalarValue::List(s)) => Some(s.as_slice()), + _ => unreachable!(), + }, + }) } -} -impl<'a, T> TryFrom<&'a Scalar> for Vec -where - T: for<'b> TryFrom<&'b Scalar, Error = VortexError>, -{ - type Error = VortexError; - - fn try_from(value: &'a Scalar) -> Result { - ListScalar::try_from(value)? - .elements() - .ok_or_else(|| vortex_err!("Expected non-null list"))? - .into_iter() - .map(|e| T::try_from(&e)) - .collect::>>() + pub fn as_extension(&self) -> ExtensionScalar<'_> { + self.as_extension_opt() + .vortex_expect("Scalar is not an ExtScalarRef") } -} -impl PartialEq for Scalar { - fn eq(&self, other: &Self) -> bool { - if !self.dtype.eq_ignore_nullability(&other.dtype) { - return false; - } - - match self.dtype() { - DType::Null => true, - DType::Bool(_) => self.as_bool() == other.as_bool(), - DType::Primitive(..) => self.as_primitive() == other.as_primitive(), - DType::Decimal(..) => self.as_decimal() == other.as_decimal(), - DType::Utf8(_) => self.as_utf8() == other.as_utf8(), - DType::Binary(_) => self.as_binary() == other.as_binary(), - DType::Struct(..) => self.as_struct() == other.as_struct(), - DType::List(..) | DType::FixedSizeList(..) 
=> self.as_list() == other.as_list(), - DType::Extension(_) => self.as_extension() == other.as_extension(), - } + pub fn as_extension_opt(&self) -> Option> { + let DType::Extension(ext_dtype) = &self.dtype else { + return None; + }; + Some(ExtensionScalar { + ext_dtype, + ext_scalar: match &self.value { + None => None, + Some(ScalarValue::Extension(e)) => Some(e), + _ => unreachable!(), + }, + }) } } -impl Eq for Scalar {} - impl PartialOrd for Scalar { /// Compares two scalar values for ordering. /// @@ -580,62 +364,6 @@ impl PartialOrd for Scalar { if !self.dtype().eq_ignore_nullability(other.dtype()) { return None; } - match self.dtype() { - DType::Null => Some(Ordering::Equal), - DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()), - DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()), - DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()), - DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()), - DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()), - DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()), - DType::List(..) | DType::FixedSizeList(..) => { - self.as_list().partial_cmp(&other.as_list()) - } - DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()), - } - } -} - -impl Hash for Scalar { - fn hash(&self, state: &mut H) { - match self.dtype() { - DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value - DType::Bool(_) => self.as_bool().hash(state), - DType::Primitive(..) => self.as_primitive().hash(state), - DType::Decimal(..) => self.as_decimal().hash(state), - DType::Utf8(_) => self.as_utf8().hash(state), - DType::Binary(_) => self.as_binary().hash(state), - DType::Struct(..) => self.as_struct().hash(state), - DType::List(..) | DType::FixedSizeList(..) 
=> self.as_list().hash(state), - DType::Extension(_) => self.as_extension().hash(state), - } - } -} - -impl AsRef for Scalar { - fn as_ref(&self) -> &Self { - self - } -} - -impl From> for Scalar { - fn from(pscalar: PrimitiveScalar<'_>) -> Self { - let dtype = pscalar.dtype().clone(); - let value = pscalar - .pvalue() - .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue))) - .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)); - Self::new(dtype, value) - } -} - -impl From> for Scalar { - fn from(decimal_scalar: DecimalScalar<'_>) -> Self { - let dtype = decimal_scalar.dtype().clone(); - let value = decimal_scalar - .decimal_value() - .map(|value| ScalarValue(InnerScalarValue::Decimal(value))) - .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)); - Self::new(dtype, value) + self.value().partial_cmp(&other.value()) } } diff --git a/vortex-scalar/src/scalar_value.rs b/vortex-scalar/src/scalar_value.rs index ffa75fd80d7..1414188a408 100644 --- a/vortex-scalar/src/scalar_value.rs +++ b/vortex-scalar/src/scalar_value.rs @@ -1,288 +1,148 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::cmp::Ordering; use std::fmt::Display; -use std::sync::Arc; +use std::fmt::Formatter; -use bytes::BufMut; use itertools::Itertools; -use prost::Message; use vortex_buffer::BufferString; use vortex_buffer::ByteBuffer; -use vortex_dtype::NativeDType; -use vortex_dtype::i256; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_proto::scalar as pb; +use vortex_error::vortex_panic; -use crate::Scalar; -use crate::decimal::DecimalValue; -use crate::pvalue::PValue; +use crate::DecimalValue; +// use crate::ExtScalarRef; +use crate::PValue; -/// Represents the internal data of a scalar value. Must be interpreted by wrapping up with a -/// [`vortex_dtype::DType`] to make a [`super::Scalar`]. 
-/// -/// Note that these values can be deserialized from JSON or other formats. So a [`PValue`] may not -/// have the correct width for what the [`vortex_dtype::DType`] expects. Primitive values should therefore be -/// read using [`super::PrimitiveScalar`] which will handle the conversion. -#[derive(Debug, Clone)] -pub struct ScalarValue(pub(crate) InnerScalarValue); - -/// It is common to represent a nullable type `T` as an `Option`, so we implement a blanket -/// implementation for all `Option` to simply be a nullable `T`. -impl From> for ScalarValue -where - T: NativeDType, - ScalarValue: From, -{ - fn from(value: Option) -> Self { - value - .map(ScalarValue::from) - .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)) - } -} - -impl From> for ScalarValue -where - T: NativeDType, - Scalar: From, -{ - /// Converts a vector into a `ScalarValue` (specifically a `ListScalar`). - fn from(value: Vec) -> Self { - ScalarValue(InnerScalarValue::List( - value - .into_iter() - .map(|x| { - let scalar: Scalar = T::into(x); - scalar.into_value() - }) - .collect::>(), - )) - } -} - -#[derive(Debug, Clone)] -pub(crate) enum InnerScalarValue { - Null, +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ScalarValue { Bool(bool), Primitive(PValue), Decimal(DecimalValue), - Buffer(Arc), - BufferString(Arc), - List(Arc<[ScalarValue]>), + Utf8(BufferString), + Binary(ByteBuffer), + List(Vec>), + // Extension(ExtScalarRef), } impl ScalarValue { - /// Serializes the scalar value to Protocol Buffers format. - pub fn to_protobytes(&self) -> B { - let pb_scalar = pb::ScalarValue::from(self); - - let mut buf = B::default(); - pb_scalar - .encode(&mut buf) - .vortex_expect("protobuf encoding should succeed"); - buf - } - - /// Deserializes a scalar value from Protocol Buffers format. - pub fn from_protobytes(buf: &[u8]) -> VortexResult { - ScalarValue::try_from(&pb::ScalarValue::decode(buf)?) 
- } -} - -fn to_hex(slice: &[u8]) -> String { - slice - .iter() - .format_with("", |f, b| b(&format_args!("{f:02x}"))) - .to_string() -} - -impl Display for ScalarValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl Display for InnerScalarValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + pub fn as_bool(&self) -> bool { match self { - Self::Bool(b) => write!(f, "{b}"), - Self::Primitive(pvalue) => write!(f, "{pvalue}"), - Self::Decimal(value) => write!(f, "{value}"), - Self::Buffer(buf) => { - if buf.len() > 10 { - write!( - f, - "{}..{}", - to_hex(&buf[0..5]), - to_hex(&buf[buf.len() - 5..buf.len()]), - ) - } else { - write!(f, "{}", to_hex(buf)) - } - } - Self::BufferString(bufstr) => { - let bufstr = bufstr.as_str(); - let str_len = bufstr.chars().count(); - - if str_len > 10 { - let prefix = String::from_iter(bufstr.chars().take(5)); - let suffix = String::from_iter(bufstr.chars().skip(str_len - 5)); - - write!(f, "\"{prefix}..{suffix}\"") - } else { - write!(f, "\"{bufstr}\"") - } - } - Self::List(elems) => { - write!(f, "[{}]", elems.iter().format(",")) - } - Self::Null => write!(f, "null"), + ScalarValue::Bool(b) => *b, + _ => vortex_panic!("ScalarValue is not a Bool"), } } -} - -impl ScalarValue { - /// Creates a null scalar value. - pub const fn null() -> Self { - ScalarValue(InnerScalarValue::Null) - } - /// Returns true if this is a null value. - #[inline] - pub fn is_null(&self) -> bool { - self.0.is_null() - } - - /// Returns scalar as a null value - #[inline] - pub(crate) fn as_null(&self) -> VortexResult<()> { - self.0.as_null() - } - - /// Returns scalar as a boolean value - #[inline] - pub(crate) fn as_bool(&self) -> VortexResult> { - self.0.as_bool() - } - - /// Return scalar as a primitive value. 
PValues don't match dtypes but will be castable to the scalars dtype - #[inline] - pub(crate) fn as_pvalue(&self) -> VortexResult> { - self.0.as_pvalue() - } - - /// Returns scalar as a decimal value - #[inline] - pub(crate) fn as_decimal(&self) -> VortexResult> { - self.0.as_decimal() - } - - /// Returns scalar as a binary buffer - #[inline] - pub(crate) fn as_buffer(&self) -> VortexResult>> { - self.0.as_buffer() - } - - /// Returns scalar as a string buffer - #[inline] - pub(crate) fn as_buffer_string(&self) -> VortexResult>> { - self.0.as_buffer_string() - } - - /// Returns scalar as a list value - #[inline] - pub(crate) fn as_list(&self) -> VortexResult>> { - self.0.as_list() - } -} - -impl InnerScalarValue { - #[inline] - pub(crate) fn is_null(&self) -> bool { - matches!(self, InnerScalarValue::Null) + pub fn as_primitive(&self) -> &PValue { + match self { + ScalarValue::Primitive(p) => p, + _ => vortex_panic!("ScalarValue is not a Primitive"), + } } - #[inline] - pub(crate) fn as_null(&self) -> VortexResult<()> { - if matches!(self, InnerScalarValue::Null) { - Ok(()) - } else { - Err(vortex_err!("Expected a Null scalar, found {self}")) + pub fn as_decimal(&self) -> &DecimalValue { + match self { + ScalarValue::Decimal(d) => d, + _ => vortex_panic!("ScalarValue is not a Decimal"), } } - #[inline] - pub(crate) fn as_bool(&self) -> VortexResult> { + pub fn as_utf8(&self) -> &BufferString { match self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::Bool(b) => Ok(Some(*b)), - other => Err(vortex_err!("Expected a bool scalar, found {other}",)), + ScalarValue::Utf8(s) => s, + _ => vortex_panic!("ScalarValue is not a Utf8"), } } - /// FIXME(ngates): PValues are such a footgun... we should probably remove this. - /// But the other accessors can sometimes be useful? e.g. as_buffer. But maybe we just force - /// the user to switch over Utf8 and Binary and use the correct Scalar wrapper? 
- #[inline] - pub(crate) fn as_pvalue(&self) -> VortexResult> { + pub fn as_binary(&self) -> &ByteBuffer { match self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::Primitive(pvalue) => Ok(Some(*pvalue)), - other => Err(vortex_err!("Expected a primitive scalar, found {other}")), + ScalarValue::Binary(b) => b, + _ => vortex_panic!("ScalarValue is not a Binary"), } } - #[inline] - pub(crate) fn as_decimal(&self) -> VortexResult> { + pub fn as_list(&self) -> &[Option] { match self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::Decimal(v) => Ok(Some(*v)), - InnerScalarValue::Buffer(b) => Ok(Some(match b.len() { - 1 => DecimalValue::I8(b[0] as i8), - 2 => DecimalValue::I16(i16::from_le_bytes(b.as_slice().try_into()?)), - 4 => DecimalValue::I32(i32::from_le_bytes(b.as_slice().try_into()?)), - 8 => DecimalValue::I64(i64::from_le_bytes(b.as_slice().try_into()?)), - 16 => DecimalValue::I128(i128::from_le_bytes(b.as_slice().try_into()?)), - 32 => DecimalValue::I256(i256::from_le_bytes(b.as_slice().try_into()?)), - l => vortex_bail!("Buffer is not a decimal value length {l}"), - })), - _ => vortex_bail!("Expected a decimal scalar, found {:?}", self), + ScalarValue::List(elements) => elements, + _ => vortex_panic!("ScalarValue is not a List"), } } - #[inline] - pub(crate) fn as_buffer(&self) -> VortexResult>> { - match &self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::Buffer(b) => Ok(Some(b.clone())), - InnerScalarValue::BufferString(b) => { - Ok(Some(Arc::new(b.as_ref().clone().into_inner()))) - } - _ => Err(vortex_err!("Expected a binary scalar, found {:?}", self)), + // pub fn as_extension(&self) -> &ExtScalarRef { + // match self { + // ScalarValue::Extension(e) => e, + // _ => vortex_panic!("ScalarValue is not an Extension"), + // } + // } +} + +impl PartialOrd for ScalarValue { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (ScalarValue::Bool(a), ScalarValue::Bool(b)) => a.partial_cmp(b), + 
(ScalarValue::Primitive(a), ScalarValue::Primitive(b)) => a.partial_cmp(b), + (ScalarValue::Decimal(a), ScalarValue::Decimal(b)) => a.partial_cmp(b), + (ScalarValue::Utf8(a), ScalarValue::Utf8(b)) => a.partial_cmp(b), + (ScalarValue::Binary(a), ScalarValue::Binary(b)) => a.partial_cmp(b), + (ScalarValue::List(a), ScalarValue::List(b)) => a.partial_cmp(b), + // (ScalarValue::Extension(a), ScalarValue::Extension(b)) => a.partial_cmp(b), + _ => None, } } +} + +impl Display for ScalarValue { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ScalarValue::Bool(b) => write!(f, "{}", b), + ScalarValue::Primitive(p) => write!(f, "{}", p), + ScalarValue::Decimal(d) => write!(f, "{}", d), + ScalarValue::Utf8(s) => { + let bufstr = s.as_str(); + let str_len = bufstr.chars().count(); - #[inline] - pub(crate) fn as_buffer_string(&self) -> VortexResult>> { - match &self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::Buffer(b) => { - Ok(Some(Arc::new(BufferString::try_from(b.as_ref().clone())?))) + if str_len > 10 { + let prefix = String::from_iter(bufstr.chars().take(5)); + let suffix = String::from_iter(bufstr.chars().skip(str_len - 5)); + + write!(f, "\"{prefix}..{suffix}\"") + } else { + write!(f, "\"{bufstr}\"") + } + } + ScalarValue::Binary(b) => { + if b.len() > 10 { + write!( + f, + "{}..{}", + to_hex(&b[0..5]), + to_hex(&b[b.len() - 5..b.len()]), + ) + } else { + write!(f, "{}", to_hex(b)) + } } - InnerScalarValue::BufferString(b) => Ok(Some(b.clone())), - _ => Err(vortex_err!("Expected a string scalar, found {:?}", self)), + ScalarValue::List(elements) => { + write!(f, "[")?; + for (i, element) in elements.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + match element { + None => write!(f, "null")?, + Some(e) => write!(f, "{}", e)?, + } + } + write!(f, "]") + } // + // ScalarValue::Extrension(e) => write!(f, "{}", e), } } +} - #[inline] - pub(crate) fn as_list(&self) -> VortexResult>> { - match &self { - 
InnerScalarValue::Null => Ok(None), - InnerScalarValue::List(l) => Ok(Some(l)), - _ => Err(vortex_err!("Expected a list scalar, found {:?}", self)), - } - } +fn to_hex(slice: &[u8]) -> String { + slice + .iter() + .format_with("", |f, b| b(&format_args!("{f:02x}"))) + .to_string() } From 251fdd272e14f68561e3a378021abb3ad17c2c6b Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 4 Feb 2026 12:11:17 -0500 Subject: [PATCH 02/22] progress towards refactoring Signed-off-by: Connor Tsui --- encodings/sparse/src/canonical.rs | 4 +- .../src/arrays/constant/vtable/canonical.rs | 2 +- vortex-array/src/builders/struct_.rs | 2 +- vortex-python/src/scalar/into_py.rs | 2 +- vortex-scalar/src/arbitrary.rs | 96 ++-- vortex-scalar/src/arrow/tests.rs | 4 +- vortex-scalar/src/binary.rs | 333 ++++++------- vortex-scalar/src/bool.rs | 23 +- vortex-scalar/src/cast.rs | 50 ++ vortex-scalar/src/display.rs | 45 +- vortex-scalar/src/lib.rs | 7 +- vortex-scalar/src/primitive.rs | 1 - vortex-scalar/src/proto.rs | 378 ++++++-------- vortex-scalar/src/scalar.rs | 462 +++++++++--------- vortex-scalar/src/struct_.rs | 95 ++-- 15 files changed, 734 insertions(+), 770 deletions(-) create mode 100644 vortex-scalar/src/cast.rs diff --git a/encodings/sparse/src/canonical.rs b/encodings/sparse/src/canonical.rs index 222c4befeb0..78a092fdbd7 100644 --- a/encodings/sparse/src/canonical.rs +++ b/encodings/sparse/src/canonical.rs @@ -368,8 +368,8 @@ fn execute_sparse_struct( // Resolution is unnecessary b/c we're just pushing the patches into the fields. 
unresolved_patches: &Patches, len: usize, -) -> VortexResult { - let (fill_values, top_level_fill_validity) = match fill_struct.fields() { +) -> VortexResult { + let (fill_values, top_level_fill_validity) = match fill_struct.fields_iter() { Some(fill_values) => (fill_values.collect::>(), Validity::AllValid), None => ( struct_fields diff --git a/vortex-array/src/arrays/constant/vtable/canonical.rs b/vortex-array/src/arrays/constant/vtable/canonical.rs index 42c3391f6fc..60afa03a870 100644 --- a/vortex-array/src/arrays/constant/vtable/canonical.rs +++ b/vortex-array/src/arrays/constant/vtable/canonical.rs @@ -134,7 +134,7 @@ pub(crate) fn constant_canonicalize(array: &ConstantArray) -> VortexResult { let value = StructScalar::try_from(scalar).vortex_expect("must be struct"); - let fields: Vec<_> = match value.fields() { + let fields: Vec<_> = match value.fields_iter() { Some(fields) => fields .into_iter() .map(|s| ConstantArray::new(s, array.len()).into_array()) diff --git a/vortex-array/src/builders/struct_.rs b/vortex-array/src/builders/struct_.rs index ff4d5f527cf..4a8c34057ae 100644 --- a/vortex-array/src/builders/struct_.rs +++ b/vortex-array/src/builders/struct_.rs @@ -73,7 +73,7 @@ impl StructBuilder { ); } - if let Some(fields) = struct_scalar.fields() { + if let Some(fields) = struct_scalar.fields_iter() { for (builder, field) in self.builders.iter_mut().zip_eq(fields) { builder.append_scalar(&field)?; } diff --git a/vortex-python/src/scalar/into_py.rs b/vortex-python/src/scalar/into_py.rs index bbc04f20670..624eb2fe7a8 100644 --- a/vortex-python/src/scalar/into_py.rs +++ b/vortex-python/src/scalar/into_py.rs @@ -100,7 +100,7 @@ impl<'py> IntoPyObject<'py> for PyVortex> { type Error = PyErr; fn into_pyobject(self, py: Python<'py>) -> Result { - let Some(fields) = self.0.fields() else { + let Some(fields) = self.0.fields_iter() else { return Ok(py.None().into_pyobject(py)?); }; diff --git a/vortex-scalar/src/arbitrary.rs b/vortex-scalar/src/arbitrary.rs index 
45e4d9e85ec..2fb93b0a379 100644 --- a/vortex-scalar/src/arbitrary.rs +++ b/vortex-scalar/src/arbitrary.rs @@ -7,7 +7,6 @@ //! It is used by the fuzzer to test the correctness of the scalar value implementation. use std::iter; -use std::sync::Arc; use arbitrary::Result; use arbitrary::Unstructured; @@ -21,56 +20,67 @@ use vortex_dtype::half::f16; use vortex_dtype::match_each_decimal_value_type; use crate::DecimalValue; -use crate::InnerScalarValue; use crate::PValue; use crate::Scalar; use crate::ScalarValue; -/// Generate an arbitrary scalar value of the given data type. +/// Generates an arbitrary [`Scalar`] of the given [`DType`]. pub fn random_scalar(u: &mut Unstructured, dtype: &DType) -> Result { - Ok(Scalar::new(dtype.clone(), random_scalar_value(u, dtype)?)) -} + // For nullable types, return null ~25% of the time. This is just to make sure we don't generate + // too few nulls. + if dtype.is_nullable() && u.ratio(1, 4)? { + return Ok(Scalar::null(dtype.clone())); + } -fn random_scalar_value(u: &mut Unstructured, dtype: &DType) -> Result { - match dtype { - DType::Null => Ok(ScalarValue(InnerScalarValue::Null)), - DType::Bool(_) => Ok(ScalarValue(InnerScalarValue::Bool(u.arbitrary()?))), - DType::Primitive(p, _) => Ok(ScalarValue(InnerScalarValue::Primitive(random_pvalue( - u, p, - )?))), - DType::Decimal(decimal_type, _) => random_decimal(u, decimal_type), - DType::Utf8(_) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new( - BufferString::from(u.arbitrary::()?), - )))), - DType::Binary(_) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new( - ByteBuffer::from(u.arbitrary::>()?), - )))), - DType::Struct(sdt, _) => Ok(ScalarValue(InnerScalarValue::List( - sdt.fields() - .map(|d| random_scalar_value(u, &d)) - .collect::>>()? - .into(), - ))), - DType::List(edt, _) => Ok(ScalarValue(InnerScalarValue::List( - iter::from_fn(|| { - // Creates `Some(_)` with 1/4 probability. 
- u.arbitrary() - .unwrap_or(false) - .then(|| random_scalar_value(u, edt)) - }) - .collect::>>()? - .into(), - ))), - DType::FixedSizeList(edt, size, _) => Ok(ScalarValue(InnerScalarValue::List( - (0..*size) - .map(|_| random_scalar_value(u, edt)) - .collect::>>()? - .into(), - ))), + Ok(match dtype { + DType::Null => Scalar::null(dtype.clone()), + DType::Bool(_) => Scalar::new_value(dtype.clone(), ScalarValue::Bool(u.arbitrary()?)), + DType::Primitive(p, _) => { + Scalar::new_value(dtype.clone(), ScalarValue::Primitive(random_pvalue(u, p)?)) + } + DType::Decimal(decimal_type, _) => { + Scalar::new_value(dtype.clone(), random_decimal(u, decimal_type)?) + } + DType::Utf8(_) => Scalar::new_value( + dtype.clone(), + ScalarValue::Utf8(BufferString::from(u.arbitrary::()?)), + ), + DType::Binary(_) => Scalar::new_value( + dtype.clone(), + ScalarValue::Binary(ByteBuffer::from(u.arbitrary::>()?)), + ), + DType::Struct(sdt, _) => Scalar::new_value( + dtype.clone(), + ScalarValue::List( + sdt.fields() + .map(|d| random_scalar(u, &d).map(|s| s.into_value())) + .collect::>>()?, + ), + ), + DType::List(edt, _) => Scalar::new_value( + dtype.clone(), + ScalarValue::List( + iter::from_fn(|| { + // Generate elements with 1/4 probability. + u.arbitrary() + .unwrap_or(false) + .then(|| random_scalar(u, edt).map(|s| s.into_value())) + }) + .collect::>>()?, + ), + ), + DType::FixedSizeList(edt, size, _) => Scalar::new_value( + dtype.clone(), + ScalarValue::List( + (0..*size) + .map(|_| random_scalar(u, edt).map(|s| s.into_value())) + .collect::>>()?, + ), + ), DType::Extension(..) 
=> { unreachable!("Can't yet generate arbitrary scalars for ext dtype") } - } + }) } fn random_pvalue(u: &mut Unstructured, ptype: &PType) -> Result { @@ -101,5 +111,5 @@ pub fn random_decimal(u: &mut Unstructured, decimal_type: &DecimalDType) -> Resu } ); - Ok(ScalarValue(InnerScalarValue::Decimal(value))) + Ok(ScalarValue::Decimal(value)) } diff --git a/vortex-scalar/src/arrow/tests.rs b/vortex-scalar/src/arrow/tests.rs index 5b1d0390e00..b34e2f707b4 100644 --- a/vortex-scalar/src/arrow/tests.rs +++ b/vortex-scalar/src/arrow/tests.rs @@ -121,7 +121,7 @@ fn test_primitive_f64_to_arrow() { #[test] fn test_null_primitive_to_arrow() { - let scalar = Scalar::null_typed::(); + let scalar = Scalar::null(i32::dtype().as_nullable()); let result = Arc::::try_from(&scalar); assert!(result.is_ok()); } @@ -135,7 +135,7 @@ fn test_utf8_scalar_to_arrow() { #[test] fn test_null_utf8_scalar_to_arrow() { - let scalar = Scalar::null_typed::(); + let scalar = Scalar::null(String::dtype().as_nullable()); let result = Arc::::try_from(&scalar); assert!(result.is_ok()); } diff --git a/vortex-scalar/src/binary.rs b/vortex-scalar/src/binary.rs index 79d8b352ca7..3a9d3e6c423 100644 --- a/vortex-scalar/src/binary.rs +++ b/vortex-scalar/src/binary.rs @@ -3,7 +3,6 @@ use std::fmt::Display; use std::fmt::Formatter; -use std::sync::Arc; use itertools::Itertools; use vortex_buffer::ByteBuffer; @@ -15,7 +14,6 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; @@ -26,7 +24,7 @@ use crate::ScalarValue; #[derive(Debug, Clone, Hash)] pub struct BinaryScalar<'a> { dtype: &'a DType, - value: Option>, + value: Option<&'a ByteBuffer>, } impl Display for BinaryScalar<'_> { @@ -62,19 +60,28 @@ impl Ord for BinaryScalar<'_> { } } +impl<'a> TryFrom<&'a Scalar> for BinaryScalar<'a> { + type Error = VortexError; + + fn try_from(scalar: &'a Scalar) -> Result { + Self::try_new(scalar.dtype(), 
scalar.value()) + } +} + impl<'a> BinaryScalar<'a> { /// Creates a binary scalar from a data type and scalar value. /// /// # Errors /// /// Returns an error if the data type is not a binary type. - pub fn from_scalar_value(dtype: &'a DType, value: ScalarValue) -> VortexResult { + pub fn try_new(dtype: &'a DType, value: Option<&'a ScalarValue>) -> VortexResult { if !matches!(dtype, DType::Binary(..)) { vortex_bail!("Can only construct binary scalar from binary dtype, found {dtype}") } + Ok(Self { dtype, - value: value.as_buffer()?, + value: value.map(|value| value.as_binary()), }) } @@ -85,66 +92,69 @@ impl<'a> BinaryScalar<'a> { } /// Returns the binary value as a byte buffer, or None if null. - pub fn value(&self) -> Option { - self.value.as_ref().map(|v| v.as_ref().clone()) + pub fn to_value(&self) -> Option { + self.value.map(|v| v.clone()) } /// Returns a reference to the binary value, or None if null. /// This avoids cloning the underlying ByteBuffer. pub fn value_ref(&self) -> Option<&ByteBuffer> { - self.value.as_ref().map(|v| v.as_ref()) - } - - /// Constructs the next scalar at most `max_length` bytes that's lexicographically greater than - /// this. - /// - /// Returns None if constructing a greater value would overflow. - pub fn upper_bound(self, max_length: usize) -> Option { - if let Some(value) = self.value { - if value.len() > max_length { - let sliced = value.slice(0..max_length); - drop(value); - let mut sliced_mut = sliced.into_mut(); - for b in sliced_mut.iter_mut().rev() { - let (incr, overflow) = b.overflowing_add(1); - *b = incr; - if !overflow { - return Some(Self { - dtype: self.dtype, - value: Some(Arc::new(sliced_mut.freeze())), - }); + self.value + } + + // TODO(connor): Figure out how to deal with the lifetime. + /* + /// Constructs the next scalar at most `max_length` bytes that's lexicographically greater than + /// this. + /// + /// Returns None if constructing a greater value would overflow. 
+ pub fn upper_bound(self, max_length: usize) -> Option { + if let Some(value) = self.value { + if value.len() > max_length { + let sliced = value.slice(0..max_length); + drop(value); + let mut sliced_mut = sliced.into_mut(); + for b in sliced_mut.iter_mut().rev() { + let (incr, overflow) = b.overflowing_add(1); + *b = incr; + if !overflow { + return Some(Self { + dtype: self.dtype, + value: Some(sliced_mut.freeze()), + }); + } } + None + } else { + Some(Self { + dtype: self.dtype, + value: Some(value), + }) } - None } else { - Some(Self { - dtype: self.dtype, - value: Some(value), - }) + Some(self) } - } else { - Some(self) } - } - /// Construct a value at most `max_length` in size that's less than ourselves. - pub fn lower_bound(self, max_length: usize) -> Self { - if let Some(value) = self.value { - if value.len() > max_length { - Self { - dtype: self.dtype, - value: Some(Arc::new(value.slice(0..max_length))), + /// Construct a value at most `max_length` in size that's less than ourselves. 
+ pub fn lower_bound(self, max_length: usize) -> Self { + if let Some(value) = self.value { + if value.len() > max_length { + Self { + dtype: self.dtype, + value: Some(value.slice(0..max_length)), + } + } else { + Self { + dtype: self.dtype, + value: Some(value), + } } } else { - Self { - dtype: self.dtype, - value: Some(value), - } + self } - } else { - self } - } + */ pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { if !matches!(dtype, DType::Binary(..)) { @@ -152,14 +162,12 @@ impl<'a> BinaryScalar<'a> { "Cannot cast binary to {dtype}: binary scalars can only be cast to binary types with different nullability" ) } - Ok(Scalar::new( + Ok(Scalar::new_value( dtype.clone(), - ScalarValue(InnerScalarValue::Buffer( - self.value - .as_ref() - .vortex_expect("nullness handled in Scalar::cast") - .clone(), - )), + ScalarValue::Binary( + self.to_value() + .vortex_expect("nullness handled in Scalar::cast"), + ), )) } @@ -177,27 +185,13 @@ impl<'a> BinaryScalar<'a> { impl Scalar { /// Creates a new binary scalar from a byte buffer. 
pub fn binary(buffer: impl Into, nullability: Nullability) -> Self { - Self::new( + Self::new_value( DType::Binary(nullability), - ScalarValue(InnerScalarValue::Buffer(Arc::new(buffer.into()))), + ScalarValue::Binary(buffer.into()), ) } } -impl<'a> TryFrom<&'a Scalar> for BinaryScalar<'a> { - type Error = VortexError; - - fn try_from(value: &'a Scalar) -> Result { - if !matches!(value.dtype(), DType::Binary(_)) { - vortex_bail!("Expected binary scalar, found {}", value.dtype()) - } - Ok(Self { - dtype: value.dtype(), - value: value.value().as_buffer()?, - }) - } -} - impl<'a> TryFrom<&'a Scalar> for ByteBuffer { type Error = VortexError; @@ -207,7 +201,7 @@ impl<'a> TryFrom<&'a Scalar> for ByteBuffer { .ok_or_else(|| vortex_err!("Cannot extract buffer from non-buffer scalar"))?; binary - .value() + .to_value() .ok_or_else(|| vortex_err!("Cannot extract present value from null scalar")) } } @@ -219,7 +213,7 @@ impl<'a> TryFrom<&'a Scalar> for Option { Ok(scalar .as_binary_opt() .ok_or_else(|| vortex_err!("Cannot extract buffer from non-buffer scalar"))? 
- .value()) + .to_value()) } } @@ -247,28 +241,22 @@ impl From<&[u8]> for Scalar { impl From for Scalar { fn from(value: ByteBuffer) -> Self { - Self::new(DType::Binary(Nullability::NonNullable), value.into()) - } -} - -impl From> for Scalar { - fn from(value: Arc) -> Self { - Self::new( + Self::new_value( DType::Binary(Nullability::NonNullable), - ScalarValue(InnerScalarValue::Buffer(value)), + ScalarValue::Binary(value), ) } } impl From<&[u8]> for ScalarValue { fn from(value: &[u8]) -> Self { - ScalarValue::from(ByteBuffer::from(value.to_vec())) + ScalarValue::Binary(ByteBuffer::from(value.to_vec())) } } impl From for ScalarValue { fn from(value: ByteBuffer) -> Self { - ScalarValue(InnerScalarValue::Buffer(Arc::new(value))) + ScalarValue::Binary(value) } } @@ -279,48 +267,51 @@ mod tests { use rstest::rstest; use vortex_buffer::buffer; use vortex_dtype::Nullability; - use vortex_error::VortexExpect; use crate::BinaryScalar; + use crate::PValue; use crate::Scalar; + use crate::ScalarValue; + + /* + #[test] + fn lower_bound() { + let binary = Scalar::binary(buffer![0u8, 5, 47, 33, 129], Nullability::NonNullable); + let expected = Scalar::binary(buffer![0u8, 5], Nullability::NonNullable); + assert_eq!( + BinaryScalar::try_from(&binary) + .vortex_expect("binary scalar conversion should succeed") + .lower_bound(2), + BinaryScalar::try_from(&expected) + .vortex_expect("binary scalar conversion should succeed") + ); + } - #[test] - fn lower_bound() { - let binary = Scalar::binary(buffer![0u8, 5, 47, 33, 129], Nullability::NonNullable); - let expected = Scalar::binary(buffer![0u8, 5], Nullability::NonNullable); - assert_eq!( - BinaryScalar::try_from(&binary) - .vortex_expect("binary scalar conversion should succeed") - .lower_bound(2), - BinaryScalar::try_from(&expected) - .vortex_expect("binary scalar conversion should succeed") - ); - } - - #[test] - fn upper_bound() { - let binary = Scalar::binary(buffer![0u8, 5, 255, 234, 23], Nullability::NonNullable); - let 
expected = Scalar::binary(buffer![0u8, 6, 0], Nullability::NonNullable); - assert_eq!( - BinaryScalar::try_from(&binary) - .vortex_expect("binary scalar conversion should succeed") - .upper_bound(3) - .vortex_expect("must have upper bound"), - BinaryScalar::try_from(&expected) - .vortex_expect("binary scalar conversion should succeed") - ); - } + #[test] + fn upper_bound() { + let binary = Scalar::binary(buffer![0u8, 5, 255, 234, 23], Nullability::NonNullable); + let expected = Scalar::binary(buffer![0u8, 6, 0], Nullability::NonNullable); + assert_eq!( + BinaryScalar::try_from(&binary) + .vortex_expect("binary scalar conversion should succeed") + .upper_bound(3) + .vortex_expect("must have upper bound"), + BinaryScalar::try_from(&expected) + .vortex_expect("binary scalar conversion should succeed") + ); + } - #[test] - fn upper_bound_overflow() { - let binary = Scalar::binary(buffer![255u8, 255, 255], Nullability::NonNullable); - assert!( - BinaryScalar::try_from(&binary) - .vortex_expect("binary scalar conversion should succeed") - .upper_bound(2) - .is_none() - ); - } + #[test] + fn upper_bound_overflow() { + let binary = Scalar::binary(buffer![255u8, 255, 255], Nullability::NonNullable); + assert!( + BinaryScalar::try_from(&binary) + .vortex_expect("binary scalar conversion should succeed") + .upper_bound(2) + .is_none() + ); + } + */ #[rstest] #[case(&[1u8, 2, 3], &[1u8, 2, 3], true)] @@ -366,7 +357,7 @@ mod tests { let null_binary = Scalar::null(vortex_dtype::DType::Binary(Nullability::Nullable)); let scalar = BinaryScalar::try_from(&null_binary).unwrap(); - assert!(scalar.value().is_none()); + assert!(scalar.to_value().is_none()); assert!(scalar.value_ref().is_none()); assert!(scalar.len().is_none()); assert!(scalar.is_empty().is_none()); @@ -400,8 +391,8 @@ mod tests { let value_ref = scalar.value_ref().unwrap(); assert_eq!(value_ref.as_slice(), &data); - // value should clone - let value = scalar.value().unwrap(); + // to_value should clone + let value = 
scalar.to_value().unwrap(); assert_eq!(value.as_slice(), &data); } @@ -418,7 +409,7 @@ mod tests { assert_eq!(result.dtype(), &DType::Binary(Nullability::Nullable)); let casted = BinaryScalar::try_from(&result).unwrap(); - assert_eq!(casted.value().unwrap().as_slice(), &[1, 2, 3]); + assert_eq!(casted.value_ref().unwrap().as_slice(), &[1, 2, 3]); } #[test] @@ -441,9 +432,9 @@ mod tests { use vortex_dtype::PType; let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let value = crate::ScalarValue(crate::InnerScalarValue::Primitive(crate::PValue::I32(42))); + let value = ScalarValue::Primitive(PValue::I32(42)); - let result = BinaryScalar::from_scalar_value(&dtype, value); + let result = BinaryScalar::try_new(&dtype, Some(&value)); assert!(result.is_err()); } @@ -456,44 +447,46 @@ mod tests { assert!(result.is_err()); } - #[test] - fn test_upper_bound_null() { - let null_binary = Scalar::null(vortex_dtype::DType::Binary(Nullability::Nullable)); - let scalar = BinaryScalar::try_from(&null_binary).unwrap(); + /* + #[test] + fn test_upper_bound_null() { + let null_binary = Scalar::null(vortex_dtype::DType::Binary(Nullability::Nullable)); + let scalar = BinaryScalar::try_from(&null_binary).unwrap(); - let result = scalar.upper_bound(10); - assert!(result.is_some()); - assert!(result.unwrap().value().is_none()); - } + let result = scalar.upper_bound(10); + assert!(result.is_some()); + assert!(result.unwrap().value().is_none()); + } - #[test] - fn test_lower_bound_null() { - let null_binary = Scalar::null(vortex_dtype::DType::Binary(Nullability::Nullable)); - let scalar = BinaryScalar::try_from(&null_binary).unwrap(); + #[test] + fn test_lower_bound_null() { + let null_binary = Scalar::null(vortex_dtype::DType::Binary(Nullability::Nullable)); + let scalar = BinaryScalar::try_from(&null_binary).unwrap(); - let result = scalar.lower_bound(10); - assert!(result.value().is_none()); - } + let result = scalar.lower_bound(10); + 
assert!(result.value_ref().is_none()); + } - #[test] - fn test_upper_bound_exact_length() { - let binary = Scalar::binary(buffer![1u8, 2, 3], Nullability::NonNullable); - let scalar = BinaryScalar::try_from(&binary).unwrap(); + #[test] + fn test_upper_bound_exact_length() { + let binary = Scalar::binary(buffer![1u8, 2, 3], Nullability::NonNullable); + let scalar = BinaryScalar::try_from(&binary).unwrap(); - let result = scalar.upper_bound(3); - assert!(result.is_some()); - let upper = result.unwrap(); - assert_eq!(upper.value().unwrap().as_slice(), &[1, 2, 3]); - } + let result = scalar.upper_bound(3); + assert!(result.is_some()); + let upper = result.unwrap(); + assert_eq!(upper.value_raf().unwrap().as_slice(), &[1, 2, 3]); + } - #[test] - fn test_lower_bound_exact_length() { - let binary = Scalar::binary(buffer![1u8, 2, 3], Nullability::NonNullable); - let scalar = BinaryScalar::try_from(&binary).unwrap(); + #[test] + fn test_lower_bound_exact_length() { + let binary = Scalar::binary(buffer![1u8, 2, 3], Nullability::NonNullable); + let scalar = BinaryScalar::try_from(&binary).unwrap(); - let result = scalar.lower_bound(3); - assert_eq!(result.value().unwrap().as_slice(), &[1, 2, 3]); - } + let result = scalar.lower_bound(3); + assert_eq!(result.value_ref().unwrap().as_slice(), &[1, 2, 3]); + } + */ #[test] fn test_from_slice() { @@ -505,7 +498,7 @@ mod tests { &vortex_dtype::DType::Binary(Nullability::NonNullable) ); let binary = BinaryScalar::try_from(&scalar).unwrap(); - assert_eq!(binary.value().unwrap().as_slice(), data); + assert_eq!(binary.value_ref().unwrap().as_slice(), data); } #[test] @@ -558,12 +551,10 @@ mod tests { #[test] fn test_from_arc_bytebuffer() { - use std::sync::Arc; - use vortex_buffer::ByteBuffer; let data = vec![10u8, 20, 30]; - let buffer = Arc::new(ByteBuffer::from(data.clone())); + let buffer = ByteBuffer::from(data.clone()); let scalar: Scalar = buffer.into(); assert_eq!( @@ -571,17 +562,18 @@ mod tests { 
&vortex_dtype::DType::Binary(Nullability::NonNullable) ); let binary = BinaryScalar::try_from(&scalar).unwrap(); - assert_eq!(binary.value().unwrap().as_slice(), &data); + assert_eq!(binary.value_ref().unwrap().as_slice(), &data); } #[test] fn test_scalar_value_from_slice() { let data: &[u8] = &[100u8, 200]; - let value: crate::ScalarValue = data.into(); + let value: ScalarValue = data.into(); - let scalar = Scalar::new(vortex_dtype::DType::Binary(Nullability::NonNullable), value); + let scalar = + Scalar::new_value(vortex_dtype::DType::Binary(Nullability::NonNullable), value); let binary = BinaryScalar::try_from(&scalar).unwrap(); - assert_eq!(binary.value().unwrap().as_slice(), data); + assert_eq!(binary.value_ref().unwrap().as_slice(), data); } #[test] @@ -590,10 +582,11 @@ mod tests { let data = vec![111u8, 222]; let buffer = ByteBuffer::from(data.clone()); - let value: crate::ScalarValue = buffer.into(); + let value: ScalarValue = buffer.into(); - let scalar = Scalar::new(vortex_dtype::DType::Binary(Nullability::NonNullable), value); + let scalar = + Scalar::new_value(vortex_dtype::DType::Binary(Nullability::NonNullable), value); let binary = BinaryScalar::try_from(&scalar).unwrap(); - assert_eq!(binary.value().unwrap().as_slice(), &data); + assert_eq!(binary.value_ref().unwrap().as_slice(), &data); } } diff --git a/vortex-scalar/src/bool.rs b/vortex-scalar/src/bool.rs index 0ef1acf2e9f..09bd73d8524 100644 --- a/vortex-scalar/src/bool.rs +++ b/vortex-scalar/src/bool.rs @@ -14,7 +14,6 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; @@ -91,22 +90,14 @@ impl<'a> BoolScalar<'a> { /// Converts this boolean scalar into a general scalar. 
pub fn into_scalar(self) -> Scalar { - Scalar::new( - self.dtype.clone(), - self.value - .map(|x| ScalarValue(InnerScalarValue::Bool(x))) - .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)), - ) + Scalar::new(self.dtype.clone(), self.value.map(|x| ScalarValue::Bool(x))) } } impl Scalar { /// Creates a new boolean scalar with the given value and nullability. pub fn bool(value: bool, nullability: Nullability) -> Self { - Self::new( - DType::Bool(nullability), - ScalarValue(InnerScalarValue::Bool(value)), - ) + Self::new_value(DType::Bool(nullability), ScalarValue::Bool(value)) } } @@ -119,7 +110,7 @@ impl<'a> TryFrom<&'a Scalar> for BoolScalar<'a> { } Ok(Self { dtype: value.dtype(), - value: value.value().as_bool()?, + value: value.value().map(|value| value.as_bool()), }) } } @@ -159,13 +150,13 @@ impl TryFrom for Option { impl From for Scalar { fn from(value: bool) -> Self { - Self::new(DType::Bool(NonNullable), value.into()) + Self::new_value(DType::Bool(NonNullable), value.into()) } } impl From for ScalarValue { fn from(value: bool) -> Self { - ScalarValue(InnerScalarValue::Bool(value)) + ScalarValue::Bool(value) } } @@ -328,11 +319,11 @@ mod test { #[test] fn test_scalar_value_from_bool() { let value: ScalarValue = true.into(); - let scalar = Scalar::new(DType::Bool(NonNullable), value); + let scalar = Scalar::new_value(DType::Bool(NonNullable), value); assert!(bool::try_from(&scalar).unwrap()); let value: ScalarValue = false.into(); - let scalar = Scalar::new(DType::Bool(NonNullable), value); + let scalar = Scalar::new_value(DType::Bool(NonNullable), value); assert!(!bool::try_from(&scalar).unwrap()); } diff --git a/vortex-scalar/src/cast.rs b/vortex-scalar/src/cast.rs new file mode 100644 index 00000000000..33a1159210a --- /dev/null +++ b/vortex-scalar/src/cast.rs @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_dtype::DType; +use vortex_error::VortexExpect; +use 
vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use crate::Scalar; + +impl Scalar { + /// Cast this scalar to another data type. + pub fn cast(&self, dtype: &DType) -> VortexResult { + // If the types are the same, return a clone. + if self.dtype() == dtype { + return Ok(self.clone()); + } + + // Check for nullability casting. + if self.dtype().eq_ignore_nullability(dtype) { + // Cast from non-nullable to nullable or vice versa. + // The try_new with check will handle nullability checks. + return Scalar::try_new(dtype.clone(), self.value().cloned()); + } + + match (self.dtype(), dtype) { + (_, DType::Null) => { + // Can cast anything to null if the value is null. + if self.value().is_none() { + return Ok(Scalar::null(dtype.clone())); + } + vortex_bail!("Cannot cast non-null value {} to null dtype", self); + } + _ => { + vortex_bail!( + "Casting scalar from {} to {} is not supported", + self.dtype(), + dtype + ); + } + } + } + + /// Cast the scalar into a nullable version of its current type. + pub fn into_nullable(self) -> Scalar { + let (dtype, value) = self.into_parts(); + Self::try_new(dtype.as_nullable(), value) + .vortex_expect("Casting to nullable should always succeed") + } +} diff --git a/vortex-scalar/src/display.rs b/vortex-scalar/src/display.rs index a5af9ebbc07..bf0c2014703 100644 --- a/vortex-scalar/src/display.rs +++ b/vortex-scalar/src/display.rs @@ -19,7 +19,10 @@ impl Display for Scalar { DType::Binary(_) => write!(f, "{}", self.as_binary()), DType::Struct(..) => write!(f, "{}", self.as_struct()), DType::List(..) | DType::FixedSizeList(..) 
=> write!(f, "{}", self.as_list()), - DType::Extension(_) => write!(f, "{}", self.as_extension()), + DType::Extension(_) => { + todo!() + // write!(f, "{}", self.as_extension()) + } } } } @@ -29,6 +32,7 @@ mod tests { use vortex_buffer::ByteBuffer; use vortex_dtype::DType; use vortex_dtype::FieldName; + use vortex_dtype::NativeDType; use vortex_dtype::Nullability::NonNullable; use vortex_dtype::Nullability::Nullable; use vortex_dtype::PType; @@ -38,7 +42,6 @@ mod tests { use vortex_dtype::datetime::TimeUnit; use vortex_dtype::datetime::Timestamp; - use crate::InnerScalarValue; use crate::PValue; use crate::Scalar; use crate::ScalarValue; @@ -129,7 +132,7 @@ mod tests { assert_eq!( format!( "{}", - Scalar::struct_(dtype(), vec![Scalar::null_typed::()]) + Scalar::struct_(dtype(), vec![Scalar::null(u32::dtype().as_nullable())]) ), "{foo: null}" ); @@ -204,9 +207,9 @@ mod tests { assert_eq!( format!( "{}", - Scalar::new( + Scalar::new_value( dtype(), - ScalarValue(InnerScalarValue::Primitive(PValue::I32(3 * MINUTES + 25))) + ScalarValue::Primitive(PValue::I32(3 * MINUTES + 25)) ) ), "00:03:25" @@ -224,10 +227,7 @@ mod tests { assert_eq!( format!( "{}", - Scalar::new( - dtype(), - ScalarValue(InnerScalarValue::Primitive(PValue::I32(25))) - ) + Scalar::new_value(dtype(), ScalarValue::Primitive(PValue::I32(25))) ), "1970-01-26" ); @@ -235,10 +235,7 @@ mod tests { assert_eq!( format!( "{}", - Scalar::new( - dtype(), - ScalarValue(InnerScalarValue::Primitive(PValue::I32(365))) - ) + Scalar::new_value(dtype(), ScalarValue::Primitive(PValue::I32(365))) ), "1971-01-01" ); @@ -246,10 +243,7 @@ mod tests { assert_eq!( format!( "{}", - Scalar::new( - dtype(), - ScalarValue(InnerScalarValue::Primitive(PValue::I32(365 * 4))) - ) + Scalar::new_value(dtype(), ScalarValue::Primitive(PValue::I32(365 * 4))) ), "1973-12-31" ); @@ -266,11 +260,9 @@ mod tests { assert_eq!( format!( "{}", - Scalar::new( + Scalar::new_value( dtype(), - ScalarValue(InnerScalarValue::Primitive(PValue::I32( - 3 
* DAYS + 2 * HOURS + 5 * MINUTES + 10 - ))) + ScalarValue::Primitive(PValue::I32(3 * DAYS + 2 * HOURS + 5 * MINUTES + 10)) ) ), "1970-01-04T02:05:10" @@ -292,10 +284,7 @@ mod tests { assert_eq!( format!( "{}", - Scalar::new( - dtype(), - ScalarValue(InnerScalarValue::Primitive(PValue::I32(0))) - ) + Scalar::new_value(dtype(), ScalarValue::Primitive(PValue::I32(0))) ), "1970-01-01T10:00:00+10:00[Pacific/Guam]" ); @@ -303,11 +292,9 @@ mod tests { assert_eq!( format!( "{}", - Scalar::new( + Scalar::new_value( dtype(), - ScalarValue(InnerScalarValue::Primitive(PValue::I32( - 3 * DAYS + 2 * HOURS + 5 * MINUTES + 10 - ))) + ScalarValue::Primitive(PValue::I32(3 * DAYS + 2 * HOURS + 5 * MINUTES + 10)) ) ), "1970-01-04T12:05:10+10:00[Pacific/Guam]" diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 2cbb62be58a..92b5d217fc9 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -18,7 +18,8 @@ mod decimal; mod display; mod extension; mod list; -mod null; +// mod null; +mod cast; mod primitive; mod proto; mod pvalue; @@ -39,5 +40,5 @@ pub use scalar_value::*; pub use struct_::*; pub use utf8::*; -#[cfg(test)] -mod tests; +// #[cfg(test)] +// mod tests; diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index 87dbaca11be..4a8bf36abfd 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -26,7 +26,6 @@ use vortex_error::VortexResult; use vortex_error::vortex_err; use vortex_error::vortex_panic; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; use crate::pvalue::CoercePValue; diff --git a/vortex-scalar/src/proto.rs b/vortex-scalar/src/proto.rs index b243005141f..9a5a3c145db 100644 --- a/vortex-scalar/src/proto.rs +++ b/vortex-scalar/src/proto.rs @@ -1,14 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::Arc; - use num_traits::ToBytes; use vortex_buffer::BufferString; use 
vortex_buffer::ByteBuffer; use vortex_dtype::DType; use vortex_dtype::half::f16; -use vortex_error::VortexError; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_err; @@ -18,7 +15,6 @@ use vortex_proto::scalar::scalar_value::Kind; use vortex_session::VortexSession; use crate::DecimalValue; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; use crate::pvalue::PValue; @@ -31,49 +27,51 @@ impl From<&Scalar> for pb::Scalar { .try_into() .vortex_expect("Failed to convert DType to proto"), ), - value: Some((value.value()).into()), + value: Some(scalar_value_to_proto(value.value())), } } } -impl From<&ScalarValue> for pb::ScalarValue { - fn from(value: &ScalarValue) -> Self { - match value { - ScalarValue(InnerScalarValue::Null) => pb::ScalarValue { - kind: Some(Kind::NullValue(0)), - }, - ScalarValue(InnerScalarValue::Bool(v)) => pb::ScalarValue { - kind: Some(Kind::BoolValue(*v)), - }, - ScalarValue(InnerScalarValue::Primitive(v)) => v.into(), - ScalarValue(InnerScalarValue::Decimal(v)) => { - let inner_value = match v { - DecimalValue::I8(v) => v.to_le_bytes().to_vec(), - DecimalValue::I16(v) => v.to_le_bytes().to_vec(), - DecimalValue::I32(v) => v.to_le_bytes().to_vec(), - DecimalValue::I64(v) => v.to_le_bytes().to_vec(), - DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(), - DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(), - }; - - pb::ScalarValue { - kind: Some(Kind::BytesValue(inner_value)), - } +fn scalar_value_to_proto(value: Option<&ScalarValue>) -> pb::ScalarValue { + let Some(value) = value else { + return pb::ScalarValue { + // TODO(connor): Document why there is a value here??? 
+ kind: Some(Kind::NullValue(0)), + }; + }; + + match value { + ScalarValue::Bool(v) => pb::ScalarValue { + kind: Some(Kind::BoolValue(*v)), + }, + ScalarValue::Primitive(v) => v.into(), + ScalarValue::Decimal(v) => { + let inner_value = match v { + DecimalValue::I8(v) => v.to_le_bytes().to_vec(), + DecimalValue::I16(v) => v.to_le_bytes().to_vec(), + DecimalValue::I32(v) => v.to_le_bytes().to_vec(), + DecimalValue::I64(v) => v.to_le_bytes().to_vec(), + DecimalValue::I128(v128) => v128.to_le_bytes().to_vec(), + DecimalValue::I256(v256) => v256.to_le_bytes().to_vec(), + }; + + pb::ScalarValue { + kind: Some(Kind::BytesValue(inner_value)), } - ScalarValue(InnerScalarValue::Buffer(v)) => pb::ScalarValue { - kind: Some(Kind::BytesValue(v.as_slice().to_vec())), - }, - ScalarValue(InnerScalarValue::BufferString(v)) => pb::ScalarValue { - kind: Some(Kind::StringValue(v.as_str().to_string())), - }, - ScalarValue(InnerScalarValue::List(v)) => { - let mut values = Vec::with_capacity(v.len()); - for elem in v.iter() { - values.push(pb::ScalarValue::from(elem)); - } - pb::ScalarValue { - kind: Some(Kind::ListValue(ListValue { values })), - } + } + ScalarValue::Binary(v) => pb::ScalarValue { + kind: Some(Kind::BytesValue(v.as_slice().to_vec())), + }, + ScalarValue::Utf8(v) => pb::ScalarValue { + kind: Some(Kind::StringValue(v.as_str().to_string())), + }, + ScalarValue::List(v) => { + let mut values = Vec::with_capacity(v.len()); + for elem in v.iter() { + values.push(scalar_value_to_proto(elem.as_ref())); + } + pb::ScalarValue { + kind: Some(Kind::ListValue(ListValue { values })), } } } @@ -120,7 +118,7 @@ impl From<&PValue> for pb::ScalarValue { } impl Scalar { - /// Creates a Scalar from its protobuf representation. + /// Creates a [`Scalar`] from its protobuf representation. 
pub fn from_proto(value: &pb::Scalar, session: &VortexSession) -> VortexResult { let dtype = DType::from_proto( value @@ -130,7 +128,7 @@ impl Scalar { session, )?; - let value = ScalarValue::try_from( + let value = scalar_value_from_proto( value .value .as_ref() @@ -141,40 +139,32 @@ impl Scalar { } } -impl TryFrom<&pb::ScalarValue> for ScalarValue { - type Error = VortexError; - - fn try_from(value: &pb::ScalarValue) -> Result { - let kind = value - .kind - .as_ref() - .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?; - - match kind { - Kind::NullValue(_) => Ok(ScalarValue(InnerScalarValue::Null)), - Kind::BoolValue(v) => Ok(ScalarValue(InnerScalarValue::Bool(*v))), - Kind::Int64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::I64(*v)))), - Kind::Uint64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::U64(*v)))), - Kind::F16Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F16( - f16::from_bits(u16::try_from(*v).map_err(|_| { - vortex_err!("f16 bitwise representation has more than 16 bits: {}", v) - })?), - )))), - Kind::F32Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F32(*v)))), - Kind::F64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F64(*v)))), - Kind::StringValue(v) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new( - BufferString::from(v.clone()), - )))), - Kind::BytesValue(v) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new( - ByteBuffer::from(v.clone()), - )))), - Kind::ListValue(v) => { - let mut values = Vec::with_capacity(v.values.len()); - for elem in v.values.iter() { - values.push(elem.try_into()?); - } - Ok(ScalarValue(InnerScalarValue::List(values.into()))) +fn scalar_value_from_proto(value: &pb::ScalarValue) -> VortexResult> { + let kind = value + .kind + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?; + + match kind { + Kind::NullValue(_) => Ok(None), + Kind::BoolValue(v) => 
Ok(Some(ScalarValue::Bool(*v))), + Kind::Int64Value(v) => Ok(Some(ScalarValue::Primitive(PValue::I64(*v)))), + Kind::Uint64Value(v) => Ok(Some(ScalarValue::Primitive(PValue::U64(*v)))), + Kind::F16Value(v) => Ok(Some(ScalarValue::Primitive(PValue::F16(f16::from_bits( + u16::try_from(*v).map_err(|_| { + vortex_err!("f16 bitwise representation has more than 16 bits: {}", v) + })?, + ))))), + Kind::F32Value(v) => Ok(Some(ScalarValue::Primitive(PValue::F32(*v)))), + Kind::F64Value(v) => Ok(Some(ScalarValue::Primitive(PValue::F64(*v)))), + Kind::StringValue(v) => Ok(Some(ScalarValue::Utf8(BufferString::from(v.clone())))), + Kind::BytesValue(v) => Ok(Some(ScalarValue::Binary(ByteBuffer::from(v.clone())))), + Kind::ListValue(v) => { + let mut values = Vec::with_capacity(v.values.len()); + for elem in v.values.iter() { + values.push(scalar_value_from_proto(elem)?); } + Ok(Some(ScalarValue::List(values))) } } } @@ -183,29 +173,27 @@ impl TryFrom<&pb::ScalarValue> for ScalarValue { mod tests { use std::sync::Arc; - use rstest::rstest; use vortex_buffer::BufferString; use vortex_dtype::DType; - use vortex_dtype::DecimalDType; - use vortex_dtype::FieldDType; use vortex_dtype::Nullability; use vortex_dtype::PType; - use vortex_dtype::StructFields; use vortex_dtype::half::f16; - use vortex_dtype::i256; use vortex_error::vortex_panic; use vortex_proto::scalar as pb; + use vortex_session::VortexSession; use super::*; - use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; - use crate::tests::SESSION; + + fn session() -> VortexSession { + VortexSession::empty() + } fn round_trip(scalar: Scalar) { assert_eq!( scalar, - Scalar::from_proto(&pb::Scalar::from(&scalar), &SESSION).unwrap(), + Scalar::from_proto(&pb::Scalar::from(&scalar), &session()).unwrap(), ); } @@ -216,52 +204,47 @@ mod tests { #[test] fn test_bool() { - round_trip(Scalar::new( + round_trip(Scalar::new_value( DType::Bool(Nullability::Nullable), - ScalarValue(InnerScalarValue::Bool(true)), + 
ScalarValue::Bool(true), )); } #[test] fn test_primitive() { - round_trip(Scalar::new( + round_trip(Scalar::new_value( DType::Primitive(PType::I32, Nullability::Nullable), - ScalarValue(InnerScalarValue::Primitive(42i32.into())), + ScalarValue::Primitive(42i32.into()), )); } #[test] fn test_buffer() { - round_trip(Scalar::new( + round_trip(Scalar::new_value( DType::Binary(Nullability::Nullable), - ScalarValue(InnerScalarValue::Buffer(Arc::new(vec![1, 2, 3].into()))), + ScalarValue::Binary(vec![1, 2, 3].into()), )); } #[test] fn test_buffer_string() { - round_trip(Scalar::new( + round_trip(Scalar::new_value( DType::Utf8(Nullability::Nullable), - ScalarValue(InnerScalarValue::BufferString(Arc::new( - BufferString::from("hello".to_string()), - ))), + ScalarValue::Utf8(BufferString::from("hello".to_string())), )); } #[test] fn test_list() { - round_trip(Scalar::new( + round_trip(Scalar::new_value( DType::List( Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)), Nullability::Nullable, ), - ScalarValue(InnerScalarValue::List( - vec![ - ScalarValue(InnerScalarValue::Primitive(42i32.into())), - ScalarValue(InnerScalarValue::Primitive(43i32.into())), - ] - .into(), - )), + ScalarValue::List(vec![ + Some(ScalarValue::Primitive(42i32.into())), + Some(ScalarValue::Primitive(43i32.into())), + ]), )); } @@ -275,86 +258,43 @@ mod tests { #[test] fn test_i8() { - round_trip(Scalar::new( + round_trip(Scalar::new_value( DType::Primitive(PType::I8, Nullability::Nullable), - ScalarValue(InnerScalarValue::Primitive(i8::MIN.into())), + ScalarValue::Primitive(i8::MIN.into()), )); - round_trip(Scalar::new( + round_trip(Scalar::new_value( DType::Primitive(PType::I8, Nullability::Nullable), - ScalarValue(InnerScalarValue::Primitive(0i8.into())), + ScalarValue::Primitive(0i8.into()), )); - round_trip(Scalar::new( + round_trip(Scalar::new_value( DType::Primitive(PType::I8, Nullability::Nullable), - ScalarValue(InnerScalarValue::Primitive(i8::MAX.into())), + 
ScalarValue::Primitive(i8::MAX.into()), )); } - #[rstest] - #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))] - #[case(Scalar::utf8("hello", Nullability::NonNullable))] - #[case(Scalar::primitive(1u8, Nullability::NonNullable))] - #[case(Scalar::primitive( - f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])), - Nullability::NonNullable - ))] - #[case(Scalar::list(Arc::new(PType::U8.into()), vec![Scalar::primitive(1u8, Nullability::NonNullable)], Nullability::NonNullable - ))] - #[case(Scalar::struct_(DType::Struct( - StructFields::from_iter([ - ("a", FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable))), - ("b", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))), - ]), - Nullability::NonNullable), - vec![ - Scalar::primitive(23592960u32, Nullability::NonNullable), - Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable), - ], - ))] - #[case(Scalar::struct_(DType::Struct( - StructFields::from_iter([ - ("a", FieldDType::from(DType::Primitive(PType::U64, Nullability::NonNullable))), - ("b", FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable))), - ("c", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))), - ]), - Nullability::NonNullable), - vec![ - Scalar::primitive(415118687234u64, Nullability::NonNullable), - Scalar::primitive(2.6584664e36f32, Nullability::NonNullable), - Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable), - ], - ))] - #[case(Scalar::decimal( - DecimalValue::I256(i256::from_i128(12345643673471)), - DecimalDType::new(10, 2), - Nullability::NonNullable - ))] - #[case(Scalar::decimal( - DecimalValue::I16(23412), - DecimalDType::new(3, 2), - Nullability::NonNullable - ))] - fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) { - let written = scalar.value().to_protobytes::>(); - let scalar_read_back = ScalarValue::from_protobytes(&written).unwrap(); - assert_eq!( - 
Scalar::new(scalar.dtype().clone(), scalar_read_back), - scalar - ); - } + // TODO(ct): Re-enable these tests once ScalarValue has protobytes support. + // #[rstest] + // #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))] + // #[case(Scalar::utf8("hello", Nullability::NonNullable))] + // #[case(Scalar::primitive(1u8, Nullability::NonNullable))] + // fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) { + // ... + // } #[test] fn test_backcompat_f16_serialized_as_u64() { - // Note that this is a backwards compatibility test for poor design in the previous implementation. - // Previously, f16 ScalarValues were serialized as `pb::ScalarValue::Uint64Value(v.to_bits() as u64)`. + // Note that this is a backwards compatibility test for poor design in the previous + // implementation. Previously, f16 ScalarValues were serialized as + // `pb::ScalarValue::Uint64Value(v.to_bits() as u64)`. let pb_scalar_value = pb::ScalarValue { kind: Some(Kind::Uint64Value(f16::from_f32(0.42).to_bits() as u64)), }; - let scalar_value = ScalarValue::try_from(&pb_scalar_value).unwrap(); + let scalar_value = scalar_value_from_proto(&pb_scalar_value).unwrap(); assert_eq!( - scalar_value.as_pvalue().unwrap(), - Some(PValue::U64(14008u64)) + scalar_value.as_ref().map(|v| v.as_primitive()), + Some(&PValue::U64(14008u64)) ); let scalar = Scalar::new( @@ -370,7 +310,7 @@ mod tests { #[test] fn test_scalar_value_direct_roundtrip_f16() { - // Test that ScalarValue with f16 roundtrips correctly without going through Scalar + // Test that ScalarValue with f16 roundtrips correctly without going through Scalar. 
let f16_values = vec![ f16::from_f32(0.0), f16::from_f32(1.0), @@ -384,17 +324,17 @@ mod tests { ]; for f16_val in f16_values { - let scalar_value = ScalarValue(InnerScalarValue::Primitive(PValue::F16(f16_val))); - let written = scalar_value.to_protobytes::>(); - let read_back = ScalarValue::from_protobytes(&written).unwrap(); + let scalar_value = ScalarValue::Primitive(PValue::F16(f16_val)); + let pb_value = scalar_value_to_proto(Some(&scalar_value)); + let read_back = scalar_value_from_proto(&pb_value).unwrap(); - match (&scalar_value.0, &read_back.0) { + match (&scalar_value, read_back.as_ref()) { ( - InnerScalarValue::Primitive(PValue::F16(original)), - InnerScalarValue::Primitive(PValue::F16(roundtripped)), + ScalarValue::Primitive(PValue::F16(original)), + Some(ScalarValue::Primitive(PValue::F16(roundtripped))), ) => { if original.is_nan() && roundtripped.is_nan() { - // NaN values are equal for our purposes + // NaN values are equal for our purposes. continue; } assert_eq!( @@ -413,55 +353,43 @@ mod tests { #[test] fn test_scalar_value_direct_roundtrip_preserves_values() { - // Test that ScalarValue roundtripping preserves values (but not necessarily exact types) - // Note: Proto encoding consolidates integer types (u8/u16/u32 → u64, i8/i16/i32 → i64) - - // Test cases that should roundtrip exactly - let exact_roundtrip_cases = vec![ - ("null", ScalarValue(InnerScalarValue::Null)), - ("bool_true", ScalarValue(InnerScalarValue::Bool(true))), - ("bool_false", ScalarValue(InnerScalarValue::Bool(false))), + // Test that ScalarValue roundtripping preserves values (but not necessarily exact types). + // Note: Proto encoding consolidates integer types (u8/u16/u32 → u64, i8/i16/i32 → i64). + + // Test cases that should roundtrip exactly. 
+ let exact_roundtrip_cases: Vec<(&str, Option)> = vec![ + ("null", None), + ("bool_true", Some(ScalarValue::Bool(true))), + ("bool_false", Some(ScalarValue::Bool(false))), ( "u64", - ScalarValue(InnerScalarValue::Primitive(PValue::U64( - 18446744073709551615, - ))), + Some(ScalarValue::Primitive(PValue::U64(18446744073709551615))), ), ( "i64", - ScalarValue(InnerScalarValue::Primitive(PValue::I64( - -9223372036854775808, - ))), + Some(ScalarValue::Primitive(PValue::I64(-9223372036854775808))), ), ( "f32", - ScalarValue(InnerScalarValue::Primitive(PValue::F32( - std::f32::consts::E, - ))), + Some(ScalarValue::Primitive(PValue::F32(std::f32::consts::E))), ), ( "f64", - ScalarValue(InnerScalarValue::Primitive(PValue::F64( - std::f64::consts::PI, - ))), + Some(ScalarValue::Primitive(PValue::F64(std::f64::consts::PI))), ), ( "string", - ScalarValue(InnerScalarValue::BufferString(Arc::new( - BufferString::from("test"), - ))), + Some(ScalarValue::Utf8(BufferString::from("test"))), ), ( "bytes", - ScalarValue(InnerScalarValue::Buffer(Arc::new( - vec![1, 2, 3, 4, 5].into(), - ))), + Some(ScalarValue::Binary(vec![1, 2, 3, 4, 5].into())), ), ]; for (name, value) in exact_roundtrip_cases { - let written = value.to_protobytes::>(); - let read_back = ScalarValue::from_protobytes(&written).unwrap(); + let pb_value = scalar_value_to_proto(value.as_ref()); + let read_back = scalar_value_from_proto(&pb_value).unwrap(); let original_debug = format!("{value:?}"); let roundtrip_debug = format!("{read_back:?}"); @@ -471,32 +399,24 @@ mod tests { ); } - // Test cases where type changes but value is preserved - // Unsigned integers consolidate to U64 + // Test cases where type changes but value is preserved. + // Unsigned integers consolidate to U64. 
let unsigned_cases = vec![ - ( - "u8", - ScalarValue(InnerScalarValue::Primitive(PValue::U8(255))), - 255u64, - ), - ( - "u16", - ScalarValue(InnerScalarValue::Primitive(PValue::U16(65535))), - 65535u64, - ), + ("u8", ScalarValue::Primitive(PValue::U8(255)), 255u64), + ("u16", ScalarValue::Primitive(PValue::U16(65535)), 65535u64), ( "u32", - ScalarValue(InnerScalarValue::Primitive(PValue::U32(4294967295))), + ScalarValue::Primitive(PValue::U32(4294967295)), 4294967295u64, ), ]; for (name, value, expected) in unsigned_cases { - let written = value.to_protobytes::>(); - let read_back = ScalarValue::from_protobytes(&written).unwrap(); + let pb_value = scalar_value_to_proto(Some(&value)); + let read_back = scalar_value_from_proto(&pb_value).unwrap(); - match &read_back.0 { - InnerScalarValue::Primitive(PValue::U64(v)) => { + match read_back.as_ref() { + Some(ScalarValue::Primitive(PValue::U64(v))) => { assert_eq!( *v, expected, "ScalarValue {name} value not preserved: expected {expected}, got {v}" @@ -506,31 +426,27 @@ mod tests { } } - // Signed integers consolidate to I64 + // Signed integers consolidate to I64. 
let signed_cases = vec![ - ( - "i8", - ScalarValue(InnerScalarValue::Primitive(PValue::I8(-128))), - -128i64, - ), + ("i8", ScalarValue::Primitive(PValue::I8(-128)), -128i64), ( "i16", - ScalarValue(InnerScalarValue::Primitive(PValue::I16(-32768))), + ScalarValue::Primitive(PValue::I16(-32768)), -32768i64, ), ( "i32", - ScalarValue(InnerScalarValue::Primitive(PValue::I32(-2147483648))), + ScalarValue::Primitive(PValue::I32(-2147483648)), -2147483648i64, ), ]; for (name, value, expected) in signed_cases { - let written = value.to_protobytes::>(); - let read_back = ScalarValue::from_protobytes(&written).unwrap(); + let pb_value = scalar_value_to_proto(Some(&value)); + let read_back = scalar_value_from_proto(&pb_value).unwrap(); - match &read_back.0 { - InnerScalarValue::Primitive(PValue::I64(v)) => { + match read_back.as_ref() { + Some(ScalarValue::Primitive(PValue::I64(v))) => { assert_eq!( *v, expected, "ScalarValue {name} value not preserved: expected {expected}, got {v}" diff --git a/vortex-scalar/src/scalar.rs b/vortex-scalar/src/scalar.rs index 6c24cc80995..47167053399 100644 --- a/vortex-scalar/src/scalar.rs +++ b/vortex-scalar/src/scalar.rs @@ -11,47 +11,195 @@ use vortex_error::vortex_ensure; use crate::BinaryScalar; use crate::BoolScalar; use crate::DecimalScalar; -use crate::FixedSizeListScalar; +use crate::ExtScalar; +// use crate::FixedSizeListScalar; use crate::ListScalar; use crate::PrimitiveScalar; use crate::ScalarValue; use crate::StructScalar; use crate::Utf8Scalar; -use crate::extension::ExtensionScalar; +// use crate::extension::ExtensionScalar; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Scalar { /// The type of the scalar. dtype: DType, - /// The value of the scalar. This is `None` if the value is null, otherwise it is `Some`. + /// The value of the scalar. This is [`None`] if the value is null, otherwise it is [`Some`]. /// - /// Invariant: If the `dtype` is non-nullable, then this value _cannot_ be `None`. 
+ /// Invariant: If the [`DType`] is non-nullable, then this value _cannot_ be [`None`]. value: Option, } impl Scalar { - /// Create a new Scalar with the given DType and value without checking compatibility. + /// Creates a new null [`Scalar`] with the given [`DType`]. /// - /// # Safety + /// # Panics /// - /// The caller must ensure that the given DType and value are compatible per the rules defined - /// in `is_compatible`. - pub unsafe fn new_unchecked(dtype: DType, value: Option) -> Self { - Self { dtype, value } + /// Panics if the given [`DType`] is non-nullable. + pub fn null(dtype: DType) -> Self { + assert!( + dtype.is_nullable(), + "Cannot create null scalar with non-nullable dtype {dtype}" + ); + + Self { dtype, value: None } + } + + // Constructors for potentially null scalar values. + + /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`]. + /// + /// # Panics + /// + /// Panics if the given [`DType`] and [`ScalarValue`] are incompatible. + pub fn new(dtype: DType, value: Option) -> Self { + Self::try_new(dtype, value).vortex_expect("Failed to create Scalar") } - /// Create a new Scalar with the given DType and value. + /// Attempts to create a new [`Scalar`] with the given [`DType`] and potentially null + /// [`ScalarValue`]. + /// + /// # Errors + /// + /// Returns an error if the given [`DType`] and [`ScalarValue`] are incompatible. pub fn try_new(dtype: DType, value: Option) -> VortexResult { vortex_ensure!( - is_compatible(&dtype, value.as_ref()), - "Incompatible dtype {} with value {}", - dtype, + Self::is_compatible(&dtype, value.as_ref()), + "Incompatible dtype {dtype} with value {}", value.map(|v| format!("{}", v)).unwrap_or_default() ); + Ok(Self { dtype, value }) } + /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`] + /// without checking compatibility. 
+ /// + /// # Safety + /// + /// The caller must ensure that the given [`DType`] and [`ScalarValue`] are compatible per the + /// rules defined in [`Self::is_compatible`]. + pub unsafe fn new_unchecked(dtype: DType, value: Option) -> Self { + Self { dtype, value } + } + + // Constructors for non-null scalar values. + + /// Creates a new [`Scalar`] with the given [`DType`] and [`ScalarValue`]. + /// + /// # Panics + /// + /// Panics if the given [`DType`] and [`ScalarValue`] are incompatible. + pub fn new_value(dtype: DType, value: ScalarValue) -> Self { + Self::try_new_value(dtype, value).vortex_expect("Failed to create Scalar") + } + + /// Attempts to create a new [`Scalar`] with the given [`DType`] and [`ScalarValue`]. + /// + /// # Errors + /// + /// Returns an error if the given [`DType`] and [`ScalarValue`] are incompatible. + pub fn try_new_value(dtype: DType, value: ScalarValue) -> VortexResult { + vortex_ensure!( + Self::is_compatible(&dtype, Some(&value)), + "Incompatible dtype {dtype} with value {value}" + ); + + Ok(Self { + dtype, + value: Some(value), + }) + } + + /// Creates a new [`Scalar`] with the given [`DType`] and [`ScalarValue`] without checking + /// compatibility. + /// + /// # Safety + /// + /// The caller must ensure that the given [`DType`] and [`ScalarValue`] are compatible per the + /// rules defined in [`Scalar::is_compatible`]. + pub unsafe fn new_value_unchecked(dtype: DType, value: ScalarValue) -> Self { + Self { + dtype, + value: Some(value), + } + } + + /// Check if the given [`ScalarValue`] is compatible with the given [`DType`]. 
+ pub fn is_compatible(dtype: &DType, value: Option<&ScalarValue>) -> bool { + let Some(value) = value else { + return dtype.is_nullable(); + }; + + match dtype { + DType::Null => false, + DType::Bool(_) => matches!(value, ScalarValue::Bool(_)), + DType::Primitive(ptype, _) => { + if let ScalarValue::Primitive(pvalue) = value { + pvalue.ptype() == *ptype + } else { + false + } + } + DType::Decimal(dec_dtype, _) => { + if let ScalarValue::Decimal(dvalue) = value { + dvalue + .fits_in_precision(*dec_dtype) + // FIXME(ngates): why the option? + .vortex_expect("Failed to check decimal precision compatibility") + } else { + false + } + } + DType::Utf8(_) => matches!(value, ScalarValue::Utf8(_)), + DType::Binary(_) => matches!(value, ScalarValue::Binary(_)), + DType::List(elem_dtype, _) => { + if let ScalarValue::List(elements) = value { + elements + .iter() + .all(|element| Self::is_compatible(elem_dtype.as_ref(), element.as_ref())) + } else { + false + } + } + DType::FixedSizeList(elem_dtype, size, _) => { + if let ScalarValue::List(elements) = value { + if elements.len() != *size as usize { + return false; + } + elements + .iter() + .all(|element| Self::is_compatible(elem_dtype.as_ref(), element.as_ref())) + } else { + false + } + } + DType::Struct(fields, _) => { + if let ScalarValue::List(values) = value { + if values.len() != fields.nfields() { + return false; + } + for (field, field_value) in fields.fields().zip(values.iter()) { + if !Self::is_compatible(&field, field_value.as_ref()) { + return false; + } + } + true + } else { + false + } + } + DType::Extension(_ext_dtype) => { + todo!() + // match value { + // ScalarValue::Extension(ext_scalar) => ext_scalar.id() == ext_dtype.id(), + // _ => false, + // } + } + } + } + /// Returns the parts of the Scalar. pub fn into_parts(self) -> (DType, Option) { (self.dtype, self.value) @@ -78,256 +226,124 @@ impl Scalar { } } -/// Check if the given ScalarValue is compatible with the given DType. 
-fn is_compatible(dtype: &DType, value: Option<&ScalarValue>) -> bool { - let Some(value) = value else { - return dtype.is_nullable(); - }; - - match dtype { - DType::Null => false, - DType::Bool(_) => matches!(value, ScalarValue::Bool(_)), - DType::Primitive(ptype, _) => { - if let ScalarValue::Primitive(pvalue) = value { - pvalue.ptype() == *ptype - } else { - false - } - } - DType::Decimal(dec_dtype, _) => { - if let ScalarValue::Decimal(dvalue) = value { - dvalue - .fits_in_precision(*dec_dtype) - // FIXME(ngates): why the option? - .vortex_expect("Failed to check decimal precision compatibility") - } else { - false - } - } - DType::Utf8(_) => matches!(value, ScalarValue::Utf8(_)), - DType::Binary(_) => matches!(value, ScalarValue::Binary(_)), - DType::List(elem_dtype, _) => { - if let ScalarValue::List(elements) = value { - elements - .iter() - .all(|element| is_compatible(elem_dtype.as_ref(), element.as_ref())) - } else { - false - } - } - DType::FixedSizeList(elem_dtype, size, _) => { - if let ScalarValue::List(elements) = value { - if elements.len() != *size as usize { - return false; - } - elements - .iter() - .all(|element| is_compatible(elem_dtype.as_ref(), element.as_ref())) - } else { - false - } - } - DType::Struct(fields, _) => { - if let ScalarValue::List(values) = value { - if values.len() != fields.nfields() { - return false; - } - for (field, field_value) in fields.fields().zip(values.iter()) { - if !is_compatible(&field, field_value.as_ref()) { - return false; - } - } - true - } else { - false - } - } // DType::Extension(ext_dtype) => match value { - // ScalarValue::Extension(ext_scalar) => ext_scalar.id() == ext_dtype.id(), - // _ => false, - // }, - } -} - -/// Scalar downcasing methods +/// Scalar downcasting methods to typed views. impl Scalar { - /// Converts the Scalar into a BoolScalar, panicking if the conversion fails. + /// Returns a view of the scalar as a boolean scalar. 
+ /// + /// # Panics + /// + /// Panics if the scalar is not a boolean type. pub fn as_bool(&self) -> BoolScalar<'_> { - self.as_bool_opt() - .vortex_expect("Scalar is not a BoolScalar") + BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool") } - /// Attempts to convert the Scalar into a BoolScalar. + /// Returns a view of the scalar as a boolean scalar if it has a boolean type. pub fn as_bool_opt(&self) -> Option> { - let DType::Bool(n) = &self.dtype else { - return None; - }; - Some(BoolScalar { - nullability: *n, - value: match &self.value { - None => None, - Some(ScalarValue::Bool(b)) => Some(*b), - _ => unreachable!(), - }, - _marker: Default::default(), - }) + matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool()) } + /// Returns a view of the scalar as a primitive scalar. + /// + /// # Panics + /// + /// Panics if the scalar is not a primitive type. pub fn as_primitive(&self) -> PrimitiveScalar<'_> { - self.as_primitive_opt() - .vortex_expect("Scalar is not a PrimitiveScalar") + PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive") } + /// Returns a view of the scalar as a primitive scalar if it has a primitive type. pub fn as_primitive_opt(&self) -> Option> { - let DType::Primitive(ptype, n) = &self.dtype else { - return None; - }; - Some(PrimitiveScalar { - ptype: *ptype, - nullability: *n, - pvalue: match &self.value { - None => None, - Some(ScalarValue::Primitive(p)) => Some(p), - _ => unreachable!(), - }, - }) + matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive()) } + /// Returns a view of the scalar as a decimal scalar. + /// + /// # Panics + /// + /// Panics if the scalar is not a decimal type. 
pub fn as_decimal(&self) -> DecimalScalar<'_> { - self.as_decimal_opt() - .vortex_expect("Scalar is not a DecimalScalar") + DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal") } + /// Returns a view of the scalar as a decimal scalar if it has a decimal type. pub fn as_decimal_opt(&self) -> Option> { - let DType::Decimal(dec_dtype, n) = &self.dtype else { - return None; - }; - Some(DecimalScalar { - decimal_type: dec_dtype, - nullability: *n, - dvalue: match &self.value { - None => None, - Some(ScalarValue::Decimal(d)) => Some(d), - _ => unreachable!(), - }, - }) + matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal()) } + /// Returns a view of the scalar as a UTF-8 string scalar. + /// + /// # Panics + /// + /// Panics if the scalar is not a UTF-8 type. pub fn as_utf8(&self) -> Utf8Scalar<'_> { - self.as_utf8_opt() - .vortex_expect("Scalar is not a Utf8Scalar") + Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8") } + /// Returns a view of the scalar as a UTF-8 string scalar if it has a UTF-8 type. pub fn as_utf8_opt(&self) -> Option> { - let DType::Utf8(n) = &self.dtype else { - return None; - }; - Some(Utf8Scalar { - nullability: *n, - value: match &self.value { - None => None, - Some(ScalarValue::Utf8(b)) => Some(b), - _ => unreachable!(), - }, - }) + matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8()) } + /// Returns a view of the scalar as a binary scalar. + /// + /// # Panics + /// + /// Panics if the scalar is not a binary type. pub fn as_binary(&self) -> BinaryScalar<'_> { - self.as_binary_opt() - .vortex_expect("Scalar is not a BinaryScalar") + BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary") } + /// Returns a view of the scalar as a binary scalar if it has a binary type. 
pub fn as_binary_opt(&self) -> Option> { - let DType::Binary(n) = &self.dtype else { - return None; - }; - Some(BinaryScalar { - nullability: *n, - value: match &self.value { - None => None, - Some(ScalarValue::Binary(b)) => Some(b), - _ => unreachable!(), - }, - }) - } - - pub fn as_list(&self) -> ListScalar<'_> { - self.as_list_opt() - .vortex_expect("Scalar is not a ListScalar") + matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary()) } - pub fn as_list_opt(&self) -> Option> { - let DType::List(element_dtype, n) = &self.dtype else { - return None; - }; - Some(ListScalar { - element_dtype, - nullability: *n, - elements: match &self.value { - None => None, - Some(ScalarValue::List(e)) => Some(e.as_slice()), - _ => unreachable!(), - }, - }) - } - - pub fn as_fixed_size_list(&self) -> FixedSizeListScalar<'_> { - self.as_fixed_size_list_opt() - .vortex_expect("Scalar is not a FixedSizeListScalar") + /// Returns a view of the scalar as a struct scalar. + /// + /// # Panics + /// + /// Panics if the scalar is not a struct type. + pub fn as_struct(&self) -> StructScalar<'_> { + StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct") } - pub fn as_fixed_size_list_opt(&self) -> Option> { - let DType::FixedSizeList(element_dtype, element_size, n) = &self.dtype else { - return None; - }; - Some(FixedSizeListScalar { - list_size: *element_size, - element_dtype, - nullability: *n, - elements: match &self.value { - None => None, - Some(ScalarValue::List(e)) => Some(e.as_slice()), - _ => unreachable!(), - }, - }) + /// Returns a view of the scalar as a struct scalar if it has a struct type. + pub fn as_struct_opt(&self) -> Option> { + matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct()) } - pub fn as_struct(&self) -> StructScalar<'_> { - self.as_struct_opt() - .vortex_expect("Scalar is not a StructScalar") + /// Returns a view of the scalar as a list scalar. 
+ /// + /// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and + /// [`DType::FixedSizeList`]. + /// + /// # Panics + /// + /// Panics if the scalar is not a list type. + pub fn as_list(&self) -> ListScalar<'_> { + ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list") } - pub fn as_struct_opt(&self) -> Option> { - let DType::Struct(fields, n) = &self.dtype else { - return None; - }; - Some(StructScalar { - fields, - nullability: *n, - values: match &self.value { - None => None, - Some(ScalarValue::List(s)) => Some(s.as_slice()), - _ => unreachable!(), - }, - }) + /// Returns a view of the scalar as a list scalar if it has a list type. + /// + /// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and + /// [`DType::FixedSizeList`]. + pub fn as_list_opt(&self) -> Option> { + matches!(self.dtype, DType::List(..) | DType::FixedSizeList(..)).then(|| self.as_list()) } - pub fn as_extension(&self) -> ExtensionScalar<'_> { - self.as_extension_opt() - .vortex_expect("Scalar is not an ExtScalarRef") + /// Returns a view of the scalar as an extension scalar. + /// + /// # Panics + /// + /// Panics if the scalar is not an extension type. + pub fn as_extension(&self) -> ExtScalar<'_> { + ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension") } - pub fn as_extension_opt(&self) -> Option> { - let DType::Extension(ext_dtype) = &self.dtype else { - return None; - }; - Some(ExtensionScalar { - ext_dtype, - ext_scalar: match &self.value { - None => None, - Some(ScalarValue::Extension(e)) => Some(e), - _ => unreachable!(), - }, - }) + /// Returns a view of the scalar as an extension scalar if it has an extension type. 
+ pub fn as_extension_opt(&self) -> Option> { + matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension()) } } diff --git a/vortex-scalar/src/struct_.rs b/vortex-scalar/src/struct_.rs index d8ef1af7dd9..8d6e3110d36 100644 --- a/vortex-scalar/src/struct_.rs +++ b/vortex-scalar/src/struct_.rs @@ -6,8 +6,6 @@ use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; use std::hash::Hasher; -use std::ops::Deref; -use std::sync::Arc; use itertools::Itertools; use vortex_dtype::DType; @@ -21,7 +19,6 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; @@ -32,7 +29,7 @@ use crate::ScalarValue; #[derive(Debug, Clone)] pub struct StructScalar<'a> { dtype: &'a DType, - fields: Option<&'a Arc<[ScalarValue]>>, + pub(crate) fields: Option<&'a [Option]>, } impl Display for StructScalar<'_> { @@ -64,7 +61,7 @@ impl PartialEq for StructScalar<'_> { return false; } - match (self.fields(), other.fields()) { + match (self.fields_iter(), other.fields_iter()) { (Some(lhs), Some(rhs)) => lhs.zip(rhs).all(|(l_s, r_s)| l_s == r_s), (None, None) => true, (Some(_), None) | (None, Some(_)) => false, @@ -81,7 +78,7 @@ impl PartialOrd for StructScalar<'_> { return None; } - match (self.fields(), other.fields()) { + match (self.fields_iter(), other.fields_iter()) { (Some(lhs), Some(rhs)) => { for (l_s, r_s) in lhs.zip(rhs) { match l_s.partial_cmp(&r_s)? 
{ @@ -103,7 +100,7 @@ impl PartialOrd for StructScalar<'_> { impl Hash for StructScalar<'_> { fn hash(&self, state: &mut H) { self.dtype.hash(state); - if let Some(fields) = self.fields() { + if let Some(fields) = self.fields_iter() { for f in fields { f.hash(state); } @@ -113,14 +110,14 @@ impl Hash for StructScalar<'_> { impl<'a> StructScalar<'a> { #[inline] - pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult { + pub(crate) fn try_new(dtype: &'a DType, value: Option<&'a ScalarValue>) -> VortexResult { if !matches!(dtype, DType::Struct(..)) { vortex_bail!("Expected struct scalar, found {}", dtype) } Ok(Self { dtype, - fields: value.as_list()?, + fields: value.map(|value| value.as_list()), }) } @@ -174,7 +171,7 @@ impl<'a> StructScalar<'a> { } /// Returns the fields of the struct scalar, or None if the scalar is null. - pub fn fields(&self) -> Option> { + pub fn fields_iter(&self) -> Option> { let fields = self.fields?; Some( fields @@ -184,10 +181,6 @@ impl<'a> StructScalar<'a> { ) } - pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> { - self.fields.map(Arc::deref) - } - /// Casts this struct scalar to another struct type. 
/// /// # Errors @@ -210,7 +203,7 @@ impl<'a> StructScalar<'a> { ); } - if let Some(fs) = self.field_values() { + if let Some(fs) = self.fields { let fields = fs .iter() .enumerate() @@ -228,9 +221,9 @@ impl<'a> StructScalar<'a> { .map(|s| s.into_value()) }) .collect::>>()?; - Ok(Scalar::new( + Ok(Scalar::new_value( dtype.clone(), - ScalarValue(InnerScalarValue::List(fields.into())), + ScalarValue::List(fields.into()), )) } else { Ok(Scalar::null(dtype.clone())) @@ -247,26 +240,28 @@ impl<'a> StructScalar<'a> { .dtype .as_struct_fields_opt() .ok_or_else(|| vortex_err!("Not a struct dtype"))?; - let projected_dtype = struct_dtype.project(projection)?; - let new_fields = if let Some(fs) = self.field_values() { - ScalarValue(InnerScalarValue::List( - projection - .iter() - .map(|name| { - struct_dtype - .find(name) - .vortex_expect("DType has been successfully projected already") - }) - .map(|i| fs[i].clone()) - .collect(), - )) - } else { - ScalarValue(InnerScalarValue::Null) + let projected_dtype = DType::Struct( + struct_dtype.project(projection)?, + self.dtype().nullability(), + ); + + let Some(fs) = self.fields else { + return Ok(Scalar::null(projected_dtype)); }; - Ok(Scalar::new( - DType::Struct(projected_dtype, self.dtype().nullability()), - new_fields, - )) + + let new_fields = ScalarValue::List( + projection + .iter() + .map(|name| { + struct_dtype + .find(name) + .vortex_expect("DType has been successfully projected already") + }) + .map(|i| fs[i].clone()) + .collect(), + ); + + Ok(Scalar::new_value(projected_dtype, new_fields)) } } @@ -300,10 +295,7 @@ impl Scalar { let mut value_children = Vec::with_capacity(children.len()); value_children.extend(children.into_iter().map(|x| x.into_value())); - Self::new( - dtype, - ScalarValue(InnerScalarValue::List(value_children.into())), - ) + Self::new_value(dtype, ScalarValue::List(value_children.into())) } } @@ -323,6 +315,7 @@ mod tests { use vortex_dtype::StructFields; use super::*; + use crate::PValue; fn 
setup_types() -> (DType, DType, DType) { let f0_dt = DType::Primitive(I32, Nullability::NonNullable); @@ -436,7 +429,11 @@ mod tests { let scalar = Scalar::struct_(dtype, vec![f0_val, f1_val]); - let fields = scalar.as_struct().fields().unwrap().collect::>(); + let fields = scalar + .as_struct() + .fields_iter() + .unwrap() + .collect::>(); assert_eq!(fields.len(), 2); assert_eq!(fields[0].as_primitive().typed_value::().unwrap(), 100); assert_eq!(fields[1].as_utf8().value().unwrap(), "test".into()); @@ -448,8 +445,8 @@ mod tests { let null_scalar = Scalar::null(dtype); assert!(null_scalar.as_struct().is_null()); - assert!(null_scalar.as_struct().fields().is_none()); - assert!(null_scalar.as_struct().field_values().is_none()); + assert!(null_scalar.as_struct().fields_iter().is_none()); + assert!(null_scalar.as_struct().fields.is_none()); } #[test] @@ -482,7 +479,11 @@ mod tests { let result = source_scalar.as_struct().cast(&target_dtype).unwrap(); assert_eq!(result.dtype(), &target_dtype); - let fields = result.as_struct().fields().unwrap().collect::>(); + let fields = result + .as_struct() + .fields_iter() + .unwrap() + .collect::>(); assert_eq!(fields[0].as_primitive().typed_value::().unwrap(), 42); assert_eq!(fields[1].as_primitive().typed_value::().unwrap(), 123); } @@ -545,7 +546,7 @@ mod tests { assert_eq!(projected_struct.names().len(), 1); assert_eq!(projected_struct.names()[0].as_ref(), "b"); - let fields = projected_struct.fields().unwrap().collect::>(); + let fields = projected_struct.fields_iter().unwrap().collect::>(); assert_eq!(fields.len(), 1); assert_eq!(fields[0].as_utf8().value().unwrap().as_str(), "hello"); } @@ -668,9 +669,9 @@ mod tests { #[test] fn test_struct_try_new_non_struct_dtype() { let dtype = DType::Primitive(I32, Nullability::NonNullable); - let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42))); + let value = ScalarValue::Primitive(PValue::I32(42)); - let result = StructScalar::try_new(&dtype, &value); + let 
result = StructScalar::try_new(&dtype, Some(&value)); assert!(result.is_err()); } From 074ba08c4033adcdc9fe77e06ed2b1add67c5f89 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 4 Feb 2026 12:40:05 -0500 Subject: [PATCH 03/22] it compiles! Signed-off-by: Connor Tsui --- vortex-scalar/src/arrow/mod.rs | 4 +- vortex-scalar/src/arrow/tests.rs | 3 +- vortex-scalar/src/binary.rs | 2 +- vortex-scalar/src/bool.rs | 2 +- vortex-scalar/src/decimal/scalar.rs | 9 +- vortex-scalar/src/decimal/value.rs | 7 +- vortex-scalar/src/display.rs | 15 +- vortex-scalar/src/extension.rs | 22 +-- vortex-scalar/src/list.rs | 31 +-- vortex-scalar/src/primitive.rs | 285 +++++++++------------------- vortex-scalar/src/scalar.rs | 6 +- vortex-scalar/src/scalar_value.rs | 18 ++ vortex-scalar/src/struct_.rs | 16 +- vortex-scalar/src/utf8.rs | 151 ++++++--------- 14 files changed, 220 insertions(+), 351 deletions(-) diff --git a/vortex-scalar/src/arrow/mod.rs b/vortex-scalar/src/arrow/mod.rs index 530502b5f88..3cb9e0ed44a 100644 --- a/vortex-scalar/src/arrow/mod.rs +++ b/vortex-scalar/src/arrow/mod.rs @@ -108,10 +108,10 @@ impl TryFrom<&Scalar> for Arc { ))), }, DType::Utf8(_) => { - value_to_arrow_scalar!(value.as_utf8().value(), StringViewArray) + value_to_arrow_scalar!(value.as_utf8().to_value(), StringViewArray) } DType::Binary(_) => { - value_to_arrow_scalar!(value.as_binary().value(), BinaryViewArray) + value_to_arrow_scalar!(value.as_binary().to_value(), BinaryViewArray) } DType::Struct(..) 
=> { todo!("struct scalar conversion") diff --git a/vortex-scalar/src/arrow/tests.rs b/vortex-scalar/src/arrow/tests.rs index b34e2f707b4..8deff832eee 100644 --- a/vortex-scalar/src/arrow/tests.rs +++ b/vortex-scalar/src/arrow/tests.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use arrow_array::Datum; use rstest::rstest; use vortex_dtype::DType; +use vortex_dtype::NativeDType; use vortex_dtype::Nullability; use vortex_dtype::PType; use vortex_dtype::datetime::Date; @@ -35,7 +36,7 @@ fn test_bool_scalar_to_arrow() { #[test] fn test_null_bool_scalar_to_arrow() { - let scalar = Scalar::null_typed::(); + let scalar = Scalar::null(bool::dtype().as_nullable()); let result = Arc::::try_from(&scalar); assert!(result.is_ok()); } diff --git a/vortex-scalar/src/binary.rs b/vortex-scalar/src/binary.rs index 3a9d3e6c423..9cd97cb67f0 100644 --- a/vortex-scalar/src/binary.rs +++ b/vortex-scalar/src/binary.rs @@ -93,7 +93,7 @@ impl<'a> BinaryScalar<'a> { /// Returns the binary value as a byte buffer, or None if null. pub fn to_value(&self) -> Option { - self.value.map(|v| v.clone()) + self.value.cloned() } /// Returns a reference to the binary value, or None if null. diff --git a/vortex-scalar/src/bool.rs b/vortex-scalar/src/bool.rs index 09bd73d8524..526ba90275c 100644 --- a/vortex-scalar/src/bool.rs +++ b/vortex-scalar/src/bool.rs @@ -90,7 +90,7 @@ impl<'a> BoolScalar<'a> { /// Converts this boolean scalar into a general scalar. 
pub fn into_scalar(self) -> Scalar { - Scalar::new(self.dtype.clone(), self.value.map(|x| ScalarValue::Bool(x))) + Scalar::new(self.dtype.clone(), self.value.map(ScalarValue::Bool)) } } diff --git a/vortex-scalar/src/decimal/scalar.rs b/vortex-scalar/src/decimal/scalar.rs index 07a825ed263..9c306867f01 100644 --- a/vortex-scalar/src/decimal/scalar.rs +++ b/vortex-scalar/src/decimal/scalar.rs @@ -16,7 +16,6 @@ use vortex_error::vortex_err; use vortex_error::vortex_panic; use crate::DecimalValue; -use crate::InnerScalarValue; use crate::NumericOperator; use crate::Scalar; use crate::ScalarValue; @@ -35,9 +34,9 @@ impl<'a> DecimalScalar<'a> { /// # Errors /// /// Returns an error if the data type is not a decimal type. - pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult { + pub fn try_new(dtype: &'a DType, value: Option<&ScalarValue>) -> VortexResult { let decimal_type = DecimalDType::try_from(dtype)?; - let value = value.as_decimal()?; + let value = value.map(|v| *v.as_decimal()); Ok(Self { dtype, @@ -66,9 +65,7 @@ impl<'a> DecimalScalar<'a> { // Same decimal type, just change nullability if needed return Ok(Scalar::new( dtype.clone(), - ScalarValue(InnerScalarValue::Decimal( - self.value.unwrap_or(DecimalValue::I128(0)), - )), + self.value.map(ScalarValue::Decimal), )); } diff --git a/vortex-scalar/src/decimal/value.rs b/vortex-scalar/src/decimal/value.rs index e4de73a53b5..8b5da4f685c 100644 --- a/vortex-scalar/src/decimal/value.rs +++ b/vortex-scalar/src/decimal/value.rs @@ -23,7 +23,6 @@ use vortex_error::VortexExpect; use vortex_error::vortex_err; use crate::DecimalScalar; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; @@ -34,9 +33,9 @@ impl Scalar { decimal_type: DecimalDType, nullability: Nullability, ) -> Self { - Self::new( + Self::new_value( DType::Decimal(decimal_type, nullability), - ScalarValue(InnerScalarValue::Decimal(value)), + ScalarValue::Decimal(value), ) } } @@ -205,7 +204,7 @@ 
decimal_scalar_pack!(u64, i128, I128); impl From for ScalarValue { fn from(value: DecimalValue) -> Self { - Self(InnerScalarValue::Decimal(value)) + Self::Decimal(value) } } diff --git a/vortex-scalar/src/display.rs b/vortex-scalar/src/display.rs index bf0c2014703..41e09916781 100644 --- a/vortex-scalar/src/display.rs +++ b/vortex-scalar/src/display.rs @@ -138,10 +138,7 @@ mod tests { ); assert_eq!( - format!( - "{}", - Scalar::struct_(dtype(), vec![Scalar::from(Some(32_u32))]) - ), + format!("{}", Scalar::struct_(dtype(), vec![Scalar::from(32_u32)])), "{foo: 32u32}" ); } @@ -176,10 +173,7 @@ mod tests { assert_eq!( format!( "{}", - Scalar::struct_( - dtype.clone(), - vec![Scalar::from(Some(true)), Scalar::null(f2)] - ) + Scalar::struct_(dtype.clone(), vec![Scalar::from(true), Scalar::null(f2)]) ), "{foo: true, bar: null}" ); @@ -187,10 +181,7 @@ mod tests { assert_eq!( format!( "{}", - Scalar::struct_( - dtype, - vec![Scalar::from(Some(true)), Scalar::from(Some(32_u32))] - ) + Scalar::struct_(dtype, vec![Scalar::from(true), Scalar::from(32_u32)]) ), "{foo: true, bar: 32u32}" ); diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs index 5983b9a13ab..51d511ec347 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/extension.rs @@ -24,7 +24,7 @@ use crate::ScalarValue; #[derive(Debug, Clone)] pub struct ExtScalar<'a> { ext_dtype: &'a ExtDTypeRef, - value: &'a ScalarValue, + value: Option<&'a ScalarValue>, } impl Display for ExtScalar<'_> { @@ -80,7 +80,7 @@ impl<'a> ExtScalar<'a> { /// # Errors /// /// Returns an error if the data type is not an extension type. - pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult { + pub fn try_new(dtype: &'a DType, value: Option<&'a ScalarValue>) -> VortexResult { let DType::Extension(ext_dtype) = dtype else { vortex_bail!("Expected extension scalar, found {}", dtype) }; @@ -90,7 +90,7 @@ impl<'a> ExtScalar<'a> { /// Returns the storage scalar of the extension scalar. 
pub fn storage(&self) -> Scalar { - Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone()) + Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.cloned()) } /// Returns the extension data type. @@ -99,7 +99,7 @@ impl<'a> ExtScalar<'a> { } pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { - if self.value.is_null() && !dtype.is_nullable() { + if self.value.is_none() && !dtype.is_nullable() { vortex_bail!( "cannot cast extension dtype with id {} and storage type {} to {}", self.ext_dtype.id(), @@ -110,13 +110,13 @@ impl<'a> ExtScalar<'a> { if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) { // Casting from an extension type to the underlying storage type is OK. - return Ok(Scalar::new(dtype.clone(), self.value.clone())); + return Ok(Scalar::new(dtype.clone(), self.value.cloned())); } if let DType::Extension(ext_dtype) = dtype && self.ext_dtype.eq_ignore_nullability(ext_dtype) { - return Ok(Scalar::new(dtype.clone(), self.value.clone())); + return Ok(Scalar::new(dtype.clone(), self.value.cloned())); } vortex_bail!( @@ -141,13 +141,13 @@ impl Scalar { pub fn extension(options: V::Metadata, value: Scalar) -> Self { let ext_dtype = ExtDType::::try_new(options, value.dtype().clone()) .vortex_expect("Failed to create extension dtype"); - Self::new(DType::Extension(ext_dtype.erased()), value.value().clone()) + Self::new(DType::Extension(ext_dtype.erased()), value.into_value()) } /// Creates a new extension scalar wrapping the given storage value. 
pub fn extension_ref(ext_dtype: ExtDTypeRef, value: Scalar) -> Self { assert_eq!(ext_dtype.storage_dtype(), value.dtype()); - Self::new(DType::Extension(ext_dtype), value.value().clone()) + Self::new(DType::Extension(ext_dtype), value.into_value()) } } @@ -163,7 +163,7 @@ mod tests { use vortex_error::VortexResult; use crate::ExtScalar; - use crate::InnerScalarValue; + use crate::PValue; use crate::Scalar; use crate::ScalarValue; @@ -409,9 +409,9 @@ mod tests { #[test] fn test_ext_scalar_try_new_non_extension() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42))); + let value = ScalarValue::Primitive(PValue::I32(42)); - let result = ExtScalar::try_new(&dtype, &value); + let result = ExtScalar::try_new(&dtype, Some(&value)); assert!(result.is_err()); } diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index 34b829c3013..7c053638642 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -16,7 +16,6 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; @@ -34,7 +33,10 @@ use crate::ScalarValue; pub struct ListScalar<'a> { dtype: &'a DType, element_dtype: &'a Arc, - elements: Option>, + /// The elements of the list. `None` if the entire list is null. + /// Each element is `Option` where `None` represents a null element within the + /// list. + elements: Option<&'a [Option]>, } impl Display for ListScalar<'_> { @@ -133,7 +135,6 @@ impl<'a> ListScalar<'a> { /// Returns None if the list is null or the index is out of bounds. pub fn element(&self, idx: usize) -> Option { self.elements - .as_ref() .and_then(|l| l.get(idx)) .map(|value| Scalar::new(self.element_dtype().clone(), value.clone())) } @@ -142,7 +143,7 @@ impl<'a> ListScalar<'a> { /// /// Returns None if the list is null. 
pub fn elements(&self) -> Option> { - self.elements.as_ref().map(|elems| { + self.elements.map(|elems| { elems .iter() .map(|e| Scalar::new(self.element_dtype().clone(), e.clone())) @@ -180,11 +181,10 @@ impl<'a> ListScalar<'a> { ) } - Ok(Scalar::new( + Ok(Scalar::new_value( dtype.clone(), - ScalarValue(InnerScalarValue::List( + ScalarValue::List( self.elements - .as_ref() .vortex_expect("nullness handled in Scalar::cast") .iter() .map(|element| { @@ -193,8 +193,8 @@ impl<'a> ListScalar<'a> { .cast(target_element_dtype) .map(|x| x.into_value()) }) - .collect::>>()?, - )), + .collect::>>>()?, + ), )) } } @@ -215,7 +215,7 @@ impl Scalar { ) -> Self { let element_dtype = element_dtype.into(); - let children: Arc<[ScalarValue]> = children + let children: Vec> = children .into_iter() .map(|child| { if child.dtype() != &*element_dtype { @@ -238,7 +238,7 @@ impl Scalar { ListKind::FixedSize => DType::FixedSizeList(element_dtype, size, nullability), }; - Self::new(dtype, ScalarValue(InnerScalarValue::List(children))) + Self::new_value(dtype, ScalarValue::List(children)) } /// Creates a new list scalar with the given element type and children. @@ -287,7 +287,7 @@ impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> { Ok(Self { dtype: value.dtype(), element_dtype, - elements: value.value().as_list()?.cloned(), + elements: value.value().map(|v| v.as_list()), }) } } @@ -504,6 +504,8 @@ mod tests { assert_eq!(hash1, hash2); } + // TODO(connor): These tests use a non-existent `Vec::try_from(&Scalar)` impl. 
+ /* #[test] fn test_vec_conversion() { let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)); @@ -526,6 +528,7 @@ mod tests { let result: VortexResult> = Vec::try_from(&list_scalar); assert!(result.unwrap().is_empty()); } + */ #[test] fn test_list_cast() { @@ -582,10 +585,10 @@ mod tests { assert_eq!(list.len(), 2); let elem0 = list.element(0).unwrap(); - assert_eq!(elem0.as_utf8().value().unwrap().as_str(), "hello"); + assert_eq!(elem0.as_utf8().to_value().unwrap().as_str(), "hello"); let elem1 = list.element(1).unwrap(); - assert_eq!(elem1.as_utf8().value().unwrap().as_str(), "world"); + assert_eq!(elem1.as_utf8().to_value().unwrap().as_str(), "world"); } #[test] diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index 4a8bf36abfd..e56a60f48f5 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -76,17 +76,18 @@ impl<'a> PrimitiveScalar<'a> { /// /// Returns an error if the data type is not a primitive type or if the value /// cannot be converted to the expected primitive type. - pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult { + pub fn try_new(dtype: &'a DType, value: Option<&ScalarValue>) -> VortexResult { let ptype = PType::try_from(dtype)?; // Read the serialized value into the correct PValue. // The serialized form may come back over the wire as e.g. any integer type. - let pvalue = match_each_native_ptype!(ptype, |T| { - value - .as_pvalue()? - .map(|pv| VortexResult::Ok(PValue::from(::coerce(pv)?))) - .transpose()? - }); + let pvalue = match value { + None => None, + Some(v) => { + let pv = v.as_primitive(); + match_each_native_ptype!(ptype, |T| { Some(PValue::from(::coerce(*pv)?)) }) + } + }; Ok(Self { dtype, @@ -286,9 +287,9 @@ impl Scalar { /// Note that an explicit PType is passed since any compatible PValue may be used as the value /// for a primitive type. 
pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self { - Self::new( + Self::new_value( DType::Primitive(ptype, nullability), - ScalarValue(InnerScalarValue::Primitive(value)), + ScalarValue::Primitive(value), ) } @@ -321,8 +322,7 @@ impl Scalar { primitive .pvalue .map(|p| p.reinterpret_cast(ptype)) - .map(|x| ScalarValue(InnerScalarValue::Primitive(x))) - .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)), + .map(ScalarValue::Primitive), ) } } @@ -364,16 +364,16 @@ macro_rules! primitive_scalar { impl From<$T> for Scalar { fn from(value: $T) -> Self { - Scalar::new( + Scalar::new_value( DType::Primitive(<$T>::PTYPE, Nullability::NonNullable), - ScalarValue(InnerScalarValue::Primitive(value.into())), + ScalarValue::Primitive(value.into()), ) } } impl From<$T> for ScalarValue { fn from(value: $T) -> Self { - ScalarValue(InnerScalarValue::Primitive(value.into())) + ScalarValue::Primitive(value.into()) } } }; @@ -423,14 +423,14 @@ impl From for Scalar { impl From for ScalarValue { fn from(value: PValue) -> Self { - ScalarValue(InnerScalarValue::Primitive(value)) + ScalarValue::Primitive(value) } } /// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics. 
impl From for ScalarValue { fn from(value: usize) -> Self { - ScalarValue(InnerScalarValue::Primitive((value as u64).into())) + ScalarValue::Primitive((value as u64).into()) } } @@ -572,7 +572,6 @@ mod tests { use vortex_dtype::PType; use vortex_error::VortexExpect; - use crate::InnerScalarValue; use crate::PValue; use crate::PrimitiveScalar; use crate::ScalarValue; @@ -580,16 +579,10 @@ mod tests { #[test] fn test_integer_subtract() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let p_scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))), - ) - .unwrap(); - let p_scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(5)); + let value2 = ScalarValue::Primitive(PValue::I32(4)); + let p_scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let p_scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2); let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::(); assert_eq!(value_or_null_or_type_error.unwrap(), 1); @@ -601,32 +594,20 @@ mod tests { #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")] fn test_integer_subtract_overflow() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let p_scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MIN))), - ) - .unwrap(); - let p_scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(i32::MIN)); + let value2 = ScalarValue::Primitive(PValue::I32(i32::MAX)); + let p_scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let p_scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); let _ = 
p_scalar1 - p_scalar2; } #[test] fn test_float_subtract() { let dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let p_scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.99f32))), - ) - .unwrap(); - let p_scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.0f32))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::F32(1.99f32)); + let value2 = ScalarValue::Primitive(PValue::F32(1.0f32)); + let p_scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let p_scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2).unwrap(); let value_or_null_or_type_error = pscalar_or_overflow.as_::(); assert_eq!(value_or_null_or_type_error.unwrap(), 0.99f32); @@ -637,21 +618,12 @@ mod tests { #[test] fn test_primitive_scalar_equality() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))), - ) - .unwrap(); - let scalar3 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(43))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(42)); + let value2 = ScalarValue::Primitive(PValue::I32(42)); + let value3 = ScalarValue::Primitive(PValue::I32(43)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + let scalar3 = PrimitiveScalar::try_new(&dtype, Some(&value3)).unwrap(); assert_eq!(scalar1, scalar2); assert_ne!(scalar1, scalar3); @@ -660,16 +632,10 @@ mod tests { #[test] fn test_primitive_scalar_partial_ord() { let dtype = DType::Primitive(PType::I32, 
Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(10)); + let value2 = ScalarValue::Primitive(PValue::I32(20)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); assert!(scalar1 < scalar2); assert!(scalar2 > scalar1); @@ -682,8 +648,7 @@ mod tests { #[test] fn test_primitive_scalar_null_handling() { let dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let null_scalar = - PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap(); + let null_scalar = PrimitiveScalar::try_new(&dtype, None).unwrap(); assert_eq!(null_scalar.pvalue(), None); assert_eq!(null_scalar.typed_value::(), None); @@ -692,11 +657,8 @@ mod tests { #[test] fn test_typed_value_correct_type() { let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))), - ) - .unwrap(); + let value = ScalarValue::Primitive(PValue::F64(3.5)); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); assert_eq!(scalar.typed_value::(), Some(3.5)); } @@ -705,11 +667,8 @@ mod tests { #[should_panic(expected = "Attempting to read")] fn test_typed_value_wrong_type() { let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))), - ) - .unwrap(); + let value = ScalarValue::Primitive(PValue::F64(3.5)); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); let _ = scalar.typed_value::(); } @@ -742,11 +701,8 @@ mod tests { }; let dtype = 
DType::Primitive(source_type, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(source_pvalue)), - ) - .unwrap(); + let value = ScalarValue::Primitive(source_pvalue); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); let target_dtype = DType::Primitive(target_type, Nullability::NonNullable); let result = scalar.cast(&target_dtype); @@ -771,11 +727,8 @@ mod tests { #[test] fn test_as_conversion_success() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))), - ) - .unwrap(); + let value = ScalarValue::Primitive(PValue::I32(42)); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); assert_eq!(scalar.as_::(), Some(42i64)); assert_eq!(scalar.as_::(), Some(42.0)); @@ -784,11 +737,8 @@ mod tests { #[test] fn test_as_conversion_overflow() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(-1))), - ) - .unwrap(); + let value = ScalarValue::Primitive(PValue::I32(-1)); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); // Converting -1 to u32 should fail let result = scalar.as_opt::(); @@ -798,8 +748,7 @@ mod tests { #[test] fn test_as_conversion_null() { let dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let scalar = - PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap(); + let scalar = PrimitiveScalar::try_new(&dtype, None).unwrap(); assert_eq!(scalar.as_::(), None); assert_eq!(scalar.as_::(), None); @@ -822,16 +771,10 @@ mod tests { use crate::primitive::NumericOperator; let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))), - 
) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(10)); + let value2 = ScalarValue::Primitive(PValue::I32(20)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); let result = scalar1 .checked_binary_numeric(&scalar2, NumericOperator::Add) @@ -844,16 +787,10 @@ mod tests { use crate::primitive::NumericOperator; let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(1))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(i32::MAX)); + let value2 = ScalarValue::Primitive(PValue::I32(1)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); // Add should overflow and return None let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Add); @@ -865,13 +802,9 @@ mod tests { use crate::primitive::NumericOperator; let dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))), - ) - .unwrap(); - let null_scalar = - PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap(); + let value = ScalarValue::Primitive(PValue::I32(10)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); + let null_scalar = PrimitiveScalar::try_new(&dtype, None).unwrap(); // Operation with null should return null let result = scalar1 @@ -885,16 +818,10 @@ mod tests { use crate::primitive::NumericOperator; let dtype = 
DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(6))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(5)); + let value2 = ScalarValue::Primitive(PValue::I32(6)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); let result = scalar1 .checked_binary_numeric(&scalar2, NumericOperator::Mul) @@ -907,16 +834,10 @@ mod tests { use crate::primitive::NumericOperator; let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(20)); + let value2 = ScalarValue::Primitive(PValue::I32(4)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); let result = scalar1 .checked_binary_numeric(&scalar2, NumericOperator::Div) @@ -929,16 +850,10 @@ mod tests { use crate::primitive::NumericOperator; let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(4)); + let value2 = ScalarValue::Primitive(PValue::I32(20)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = 
PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); // RDiv means right / left, so 20 / 4 = 5 let result = scalar1 @@ -952,16 +867,10 @@ mod tests { use crate::primitive::NumericOperator; let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(0))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(10)); + let value2 = ScalarValue::Primitive(PValue::I32(0)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); // Division by zero should return None for integers let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Div); @@ -973,16 +882,10 @@ mod tests { use crate::primitive::NumericOperator; let dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F32(2.5))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::F32(10.0)); + let value2 = ScalarValue::Primitive(PValue::F32(2.5)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); // Test all operations with floats let add_result = scalar1 @@ -1029,16 +932,10 @@ mod tests { let dtype1 = DType::Primitive(PType::I32, Nullability::NonNullable); let dtype2 = DType::Primitive(PType::F32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype1, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype2, - 
&ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))), - ) - .unwrap(); + let value1 = ScalarValue::Primitive(PValue::I32(10)); + let value2 = ScalarValue::Primitive(PValue::F32(10.0)); + let scalar1 = PrimitiveScalar::try_new(&dtype1, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype2, Some(&value2)).unwrap(); // Different types should not be comparable assert_eq!(scalar1.partial_cmp(&scalar2), None); @@ -1047,20 +944,14 @@ mod tests { #[test] fn test_scalar_value_from_usize() { let value: ScalarValue = 42usize.into(); - assert!(matches!( - value.0, - InnerScalarValue::Primitive(PValue::U64(42)) - )); + assert!(matches!(value, ScalarValue::Primitive(PValue::U64(42)))); } #[test] fn test_getters() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))), - ) - .unwrap(); + let value = ScalarValue::Primitive(PValue::I32(42)); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); assert_eq!(scalar.dtype(), &dtype); assert_eq!(scalar.ptype(), PType::I32); diff --git a/vortex-scalar/src/scalar.rs b/vortex-scalar/src/scalar.rs index 47167053399..292da746684 100644 --- a/vortex-scalar/src/scalar.rs +++ b/vortex-scalar/src/scalar.rs @@ -20,6 +20,10 @@ use crate::StructScalar; use crate::Utf8Scalar; // use crate::extension::ExtensionScalar; +/// A typed scalar value. +/// +/// Scalars represent a single value with an associated [`DType`]. The value can be null, in which +/// case the [`value`][Scalar::value] method returns `None`. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Scalar { /// The type of the scalar. @@ -38,7 +42,7 @@ impl Scalar { /// /// Panics if the given [`DType`] is non-nullable. 
pub fn null(dtype: DType) -> Self { - vortex_ensure!( + assert!( dtype.is_nullable(), "Cannot create null scalar with non-nullable dtype {dtype}" ); diff --git a/vortex-scalar/src/scalar_value.rs b/vortex-scalar/src/scalar_value.rs index 1414188a408..a477a035442 100644 --- a/vortex-scalar/src/scalar_value.rs +++ b/vortex-scalar/src/scalar_value.rs @@ -14,18 +14,29 @@ use crate::DecimalValue; // use crate::ExtScalarRef; use crate::PValue; +/// The value stored in a [`Scalar`][crate::Scalar]. +/// +/// This enum represents the possible non-null values that can be stored in a scalar. When the +/// scalar is null, the value is represented as `None` in the `Option` field. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum ScalarValue { + /// A boolean value. Bool(bool), + /// A primitive numeric value. Primitive(PValue), + /// A decimal value. Decimal(DecimalValue), + /// A UTF-8 encoded string value. Utf8(BufferString), + /// A binary (byte array) value. Binary(ByteBuffer), + /// A list of potentially null scalar values. List(Vec>), // Extension(ExtScalarRef), } impl ScalarValue { + /// Returns the boolean value, panicking if the value is not a [`Bool`][ScalarValue::Bool]. pub fn as_bool(&self) -> bool { match self { ScalarValue::Bool(b) => *b, @@ -33,6 +44,8 @@ impl ScalarValue { } } + /// Returns the primitive value, panicking if the value is not a + /// [`Primitive`][ScalarValue::Primitive]. pub fn as_primitive(&self) -> &PValue { match self { ScalarValue::Primitive(p) => p, @@ -40,6 +53,8 @@ impl ScalarValue { } } + /// Returns the decimal value, panicking if the value is not a + /// [`Decimal`][ScalarValue::Decimal]. pub fn as_decimal(&self) -> &DecimalValue { match self { ScalarValue::Decimal(d) => d, @@ -47,6 +62,7 @@ impl ScalarValue { } } + /// Returns the UTF-8 string value, panicking if the value is not a [`Utf8`][ScalarValue::Utf8]. 
pub fn as_utf8(&self) -> &BufferString { match self { ScalarValue::Utf8(s) => s, @@ -54,6 +70,7 @@ impl ScalarValue { } } + /// Returns the binary value, panicking if the value is not a [`Binary`][ScalarValue::Binary]. pub fn as_binary(&self) -> &ByteBuffer { match self { ScalarValue::Binary(b) => b, @@ -61,6 +78,7 @@ impl ScalarValue { } } + /// Returns the list elements, panicking if the value is not a [`List`][ScalarValue::List]. pub fn as_list(&self) -> &[Option] { match self { ScalarValue::List(elements) => elements, diff --git a/vortex-scalar/src/struct_.rs b/vortex-scalar/src/struct_.rs index 8d6e3110d36..339d826e794 100644 --- a/vortex-scalar/src/struct_.rs +++ b/vortex-scalar/src/struct_.rs @@ -221,10 +221,7 @@ impl<'a> StructScalar<'a> { .map(|s| s.into_value()) }) .collect::>>()?; - Ok(Scalar::new_value( - dtype.clone(), - ScalarValue::List(fields.into()), - )) + Ok(Scalar::new_value(dtype.clone(), ScalarValue::List(fields))) } else { Ok(Scalar::null(dtype.clone())) } @@ -295,7 +292,7 @@ impl Scalar { let mut value_children = Vec::with_capacity(children.len()); value_children.extend(children.into_iter().map(|x| x.into_value())); - Self::new_value(dtype, ScalarValue::List(value_children.into())) + Self::new_value(dtype, ScalarValue::List(value_children)) } } @@ -414,7 +411,10 @@ mod tests { let field_b = scalar.as_struct().field("b"); assert!(field_b.is_some()); - assert_eq!(field_b.unwrap().as_utf8().value().unwrap(), "world".into()); + assert_eq!( + field_b.unwrap().as_utf8().to_value().unwrap(), + "world".into() + ); // Non-existent field let field_c = scalar.as_struct().field("c"); @@ -436,7 +436,7 @@ mod tests { .collect::>(); assert_eq!(fields.len(), 2); assert_eq!(fields[0].as_primitive().typed_value::().unwrap(), 100); - assert_eq!(fields[1].as_utf8().value().unwrap(), "test".into()); + assert_eq!(fields[1].as_utf8().to_value().unwrap(), "test".into()); } #[test] @@ -548,7 +548,7 @@ mod tests { let fields = 
projected_struct.fields_iter().unwrap().collect::>(); assert_eq!(fields.len(), 1); - assert_eq!(fields[0].as_utf8().value().unwrap().as_str(), "hello"); + assert_eq!(fields[0].as_utf8().to_value().unwrap().as_str(), "hello"); } #[test] diff --git a/vortex-scalar/src/utf8.rs b/vortex-scalar/src/utf8.rs index c274ac6b735..7d7c50075fe 100644 --- a/vortex-scalar/src/utf8.rs +++ b/vortex-scalar/src/utf8.rs @@ -5,7 +5,6 @@ use std::cmp; use std::fmt; use std::fmt::Display; use std::fmt::Formatter; -use std::sync::Arc; use vortex_buffer::BufferString; use vortex_dtype::DType; @@ -18,7 +17,6 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_utils::aliases::StringEscape; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; @@ -95,7 +93,7 @@ mod private { #[derive(Debug, Clone, Hash, Eq)] pub struct Utf8Scalar<'a> { dtype: &'a DType, - value: Option>, + value: Option<&'a BufferString>, } impl Display for Utf8Scalar<'_> { @@ -131,13 +129,14 @@ impl<'a> Utf8Scalar<'a> { /// # Errors /// /// Returns an error if the data type is not a UTF-8 type. - pub fn from_scalar_value(dtype: &'a DType, value: ScalarValue) -> VortexResult { + pub fn try_new(dtype: &'a DType, value: Option<&'a ScalarValue>) -> VortexResult { if !matches!(dtype, DType::Utf8(..)) { vortex_bail!("Can only construct utf8 scalar from utf8 dtype, found {dtype}") } + Ok(Self { dtype, - value: value.as_buffer_string()?, + value: value.map(|value| value.as_utf8()), }) } @@ -148,16 +147,18 @@ impl<'a> Utf8Scalar<'a> { } /// Returns the string value, or None if null. - pub fn value(&self) -> Option { - self.value.as_ref().map(|v| v.as_ref().clone()) + pub fn to_value(&self) -> Option { + self.value.cloned() } /// Returns a reference to the string value, or None if null. /// This avoids cloning the underlying BufferString. 
pub fn value_ref(&self) -> Option<&BufferString> { - self.value.as_ref().map(|v| v.as_ref()) + self.value } + // TODO(connor): Figure out how to deal with the lifetime. + /* /// Constructs the next scalar at most `max_length` bytes that's lexicographically greater than /// this. /// @@ -177,7 +178,7 @@ impl<'a> Utf8Scalar<'a> { let incremented = sliced_buf.increment().ok()?; Some(Self { dtype: self.dtype, - value: Some(Arc::new(incremented)), + value: Some(incremented), }) } else { Some(Self { @@ -201,9 +202,9 @@ impl<'a> Utf8Scalar<'a> { Self { dtype: self.dtype, - value: Some(Arc::new(unsafe { + value: Some(unsafe { BufferString::new_unchecked(value.inner().slice(0..utf8_split_pos)) - })), + }), } } else { Self { @@ -215,6 +216,7 @@ impl<'a> Utf8Scalar<'a> { self } } + */ pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { if !matches!(dtype, DType::Utf8(..)) { @@ -222,14 +224,12 @@ impl<'a> Utf8Scalar<'a> { "Cannot cast utf8 to {dtype}: UTF-8 scalars can only be cast to UTF-8 types with different nullability" ) } - Ok(Scalar::new( + Ok(Scalar::new_value( dtype.clone(), - ScalarValue(InnerScalarValue::BufferString( - self.value - .as_ref() - .vortex_expect("nullness handled in Scalar::cast") - .clone(), - )), + ScalarValue::Utf8( + self.to_value() + .vortex_expect("nullness handled in Scalar::cast"), + ), )) } @@ -269,9 +269,9 @@ impl Scalar { where B: TryInto, { - Ok(Self::new( + Ok(Self::new_value( DType::Utf8(nullability), - ScalarValue(InnerScalarValue::BufferString(Arc::new(str.try_into()?))), + ScalarValue::Utf8(str.try_into()?), )) } } @@ -280,13 +280,7 @@ impl<'a> TryFrom<&'a Scalar> for Utf8Scalar<'a> { type Error = VortexError; fn try_from(value: &'a Scalar) -> Result { - if !matches!(value.dtype(), DType::Utf8(_)) { - vortex_bail!("Expected utf8 scalar, found {}", value.dtype()) - } - Ok(Self { - dtype: value.dtype(), - value: value.value().as_buffer_string()?, - }) + Self::try_new(value.dtype(), value.value()) } } @@ -308,39 +302,22 @@ impl 
TryFrom for String { impl From<&str> for Scalar { fn from(value: &str) -> Self { - Self::new( + Self::new_value( DType::Utf8(NonNullable), - ScalarValue(InnerScalarValue::BufferString(Arc::new( - value.to_string().into(), - ))), + ScalarValue::Utf8(value.to_string().into()), ) } } impl From for Scalar { fn from(value: String) -> Self { - Self::new( - DType::Utf8(NonNullable), - ScalarValue(InnerScalarValue::BufferString(Arc::new(value.into()))), - ) + Self::new_value(DType::Utf8(NonNullable), ScalarValue::Utf8(value.into())) } } impl From for Scalar { fn from(value: BufferString) -> Self { - Self::new( - DType::Utf8(NonNullable), - ScalarValue(InnerScalarValue::BufferString(Arc::new(value))), - ) - } -} - -impl From> for Scalar { - fn from(value: Arc) -> Self { - Self::new( - DType::Utf8(NonNullable), - ScalarValue(InnerScalarValue::BufferString(value)), - ) + Self::new_value(DType::Utf8(NonNullable), ScalarValue::Utf8(value)) } } @@ -365,7 +342,7 @@ impl<'a> TryFrom<&'a Scalar> for Option { type Error = VortexError; fn try_from(scalar: &'a Scalar) -> Result { - Ok(Utf8Scalar::try_from(scalar)?.value()) + Ok(Utf8Scalar::try_from(scalar)?.to_value()) } } @@ -379,21 +356,19 @@ impl TryFrom for Option { impl From<&str> for ScalarValue { fn from(value: &str) -> Self { - ScalarValue(InnerScalarValue::BufferString(Arc::new( - value.to_string().into(), - ))) + ScalarValue::Utf8(value.to_string().into()) } } impl From for ScalarValue { fn from(value: String) -> Self { - ScalarValue(InnerScalarValue::BufferString(Arc::new(value.into()))) + ScalarValue::Utf8(value.into()) } } impl From for ScalarValue { fn from(value: BufferString) -> Self { - ScalarValue(InnerScalarValue::BufferString(Arc::new(value))) + ScalarValue::Utf8(value) } } @@ -403,11 +378,13 @@ mod tests { use rstest::rstest; use vortex_dtype::Nullability; - use vortex_error::VortexExpect; use crate::Scalar; use crate::Utf8Scalar; + // TODO(connor): Tests for lower_bound and upper_bound are commented out + // 
because the methods are commented out due to lifetime issues. + /* #[test] fn lower_bound() { let utf8 = Scalar::utf8("snowman⛄️snowman", Nullability::NonNullable); @@ -444,6 +421,7 @@ mod tests { .is_none() ); } + */ #[rstest] #[case("hello", "hello", true)] @@ -485,7 +463,7 @@ mod tests { let null_utf8 = Scalar::null(vortex_dtype::DType::Utf8(Nullability::Nullable)); let scalar = Utf8Scalar::try_from(&null_utf8).unwrap(); - assert!(scalar.value().is_none()); + assert!(scalar.to_value().is_none()); assert!(scalar.value_ref().is_none()); assert!(scalar.len().is_none()); assert!(scalar.is_empty().is_none()); @@ -515,8 +493,8 @@ mod tests { let value_ref = scalar.value_ref().unwrap(); assert_eq!(value_ref.as_str(), data); - // value should clone - let value = scalar.value().unwrap(); + // to_value should clone + let value = scalar.to_value().unwrap(); assert_eq!(value.as_str(), data); } @@ -533,7 +511,7 @@ mod tests { assert_eq!(result.dtype(), &DType::Utf8(Nullability::Nullable)); let casted = Utf8Scalar::try_from(&result).unwrap(); - assert_eq!(casted.value().unwrap().as_str(), "test"); + assert_eq!(casted.to_value().unwrap().as_str(), "test"); } #[test] @@ -550,15 +528,15 @@ mod tests { } #[test] - fn test_from_scalar_value_non_utf8_dtype() { + fn test_try_new_non_utf8_dtype() { use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let value = crate::ScalarValue(crate::InnerScalarValue::Primitive(crate::PValue::I32(42))); + let value = crate::ScalarValue::Primitive(crate::PValue::I32(42)); - let result = Utf8Scalar::from_scalar_value(&dtype, value); + let result = Utf8Scalar::try_new(&dtype, Some(&value)); assert!(result.is_err()); } @@ -571,6 +549,9 @@ mod tests { assert!(result.is_err()); } + // TODO(connor): Tests for upper_bound and lower_bound are commented out + // because the methods are commented out due to lifetime issues. 
+ /* #[test] fn test_upper_bound_null() { let null_utf8 = Scalar::null(vortex_dtype::DType::Utf8(Nullability::Nullable)); @@ -578,7 +559,7 @@ mod tests { let result = scalar.upper_bound(10); assert!(result.is_some()); - assert!(result.unwrap().value().is_none()); + assert!(result.unwrap().to_value().is_none()); } #[test] @@ -587,7 +568,7 @@ mod tests { let scalar = Utf8Scalar::try_from(&null_utf8).unwrap(); let result = scalar.lower_bound(10); - assert!(result.value().is_none()); + assert!(result.to_value().is_none()); } #[test] @@ -598,7 +579,7 @@ mod tests { let result = scalar.upper_bound(3); assert!(result.is_some()); let upper = result.unwrap(); - assert_eq!(upper.value().unwrap().as_str(), "abc"); + assert_eq!(upper.to_value().unwrap().as_str(), "abc"); } #[test] @@ -607,8 +588,9 @@ mod tests { let scalar = Utf8Scalar::try_from(&utf8).unwrap(); let result = scalar.lower_bound(3); - assert_eq!(result.value().unwrap().as_str(), "abc"); + assert_eq!(result.to_value().unwrap().as_str(), "abc"); } + */ #[test] fn test_from_str() { @@ -620,7 +602,7 @@ mod tests { &vortex_dtype::DType::Utf8(Nullability::NonNullable) ); let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); - assert_eq!(utf8.value().unwrap().as_str(), data); + assert_eq!(utf8.to_value().unwrap().as_str(), data); } #[test] @@ -633,7 +615,7 @@ mod tests { &vortex_dtype::DType::Utf8(Nullability::NonNullable) ); let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); - assert_eq!(utf8.value().unwrap().as_str(), "hello world"); + assert_eq!(utf8.to_value().unwrap().as_str(), "hello world"); } #[test] @@ -648,24 +630,7 @@ mod tests { &vortex_dtype::DType::Utf8(Nullability::NonNullable) ); let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); - assert_eq!(utf8.value().unwrap().as_str(), "test"); - } - - #[test] - fn test_from_arc_buffer_string() { - use std::sync::Arc; - - use vortex_buffer::BufferString; - - let data = Arc::new(BufferString::from("test")); - let scalar: Scalar = data.into(); - - assert_eq!( - 
scalar.dtype(), - &vortex_dtype::DType::Utf8(Nullability::NonNullable) - ); - let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); - assert_eq!(utf8.value().unwrap().as_str(), "test"); + assert_eq!(utf8.to_value().unwrap().as_str(), "test"); } #[test] @@ -730,9 +695,9 @@ mod tests { let data = "test"; let value: crate::ScalarValue = data.into(); - let scalar = Scalar::new(vortex_dtype::DType::Utf8(Nullability::NonNullable), value); + let scalar = Scalar::new_value(vortex_dtype::DType::Utf8(Nullability::NonNullable), value); let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); - assert_eq!(utf8.value().unwrap().as_str(), data); + assert_eq!(utf8.to_value().unwrap().as_str(), data); } #[test] @@ -740,9 +705,9 @@ mod tests { let data = String::from("test"); let value: crate::ScalarValue = data.clone().into(); - let scalar = Scalar::new(vortex_dtype::DType::Utf8(Nullability::NonNullable), value); + let scalar = Scalar::new_value(vortex_dtype::DType::Utf8(Nullability::NonNullable), value); let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); - assert_eq!(utf8.value().unwrap().as_str(), &data); + assert_eq!(utf8.to_value().unwrap().as_str(), &data); } #[test] @@ -752,9 +717,9 @@ mod tests { let data = BufferString::from("test"); let value: crate::ScalarValue = data.into(); - let scalar = Scalar::new(vortex_dtype::DType::Utf8(Nullability::NonNullable), value); + let scalar = Scalar::new_value(vortex_dtype::DType::Utf8(Nullability::NonNullable), value); let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); - assert_eq!(utf8.value().unwrap().as_str(), "test"); + assert_eq!(utf8.to_value().unwrap().as_str(), "test"); } #[test] @@ -763,7 +728,7 @@ mod tests { let scalar = Scalar::utf8(emoji_str, Nullability::NonNullable); let utf8_scalar = Utf8Scalar::try_from(&scalar).unwrap(); - assert_eq!(utf8_scalar.value().unwrap().as_str(), emoji_str); + assert_eq!(utf8_scalar.to_value().unwrap().as_str(), emoji_str); assert!(utf8_scalar.len().unwrap() > emoji_str.chars().count()); // Byte 
length > char count } From fcf8d0c25711010d6b75f3d68dd1efa8bb86865b Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 4 Feb 2026 16:03:55 -0500 Subject: [PATCH 04/22] add some nice things to scalars Signed-off-by: Connor Tsui --- vortex-scalar/src/cast.rs | 38 +++++------ vortex-scalar/src/decimal/scalar.rs | 9 +++ vortex-scalar/src/decimal/value.rs | 27 ++++++++ vortex-scalar/src/primitive.rs | 6 ++ vortex-scalar/src/proto.rs | 44 ++++++++++++ vortex-scalar/src/pvalue.rs | 41 +++++++---- vortex-scalar/src/scalar.rs | 101 ++++++++++++++++++++++++---- vortex-scalar/src/scalar_value.rs | 64 ++++++++++++++++++ 8 files changed, 285 insertions(+), 45 deletions(-) diff --git a/vortex-scalar/src/cast.rs b/vortex-scalar/src/cast.rs index 33a1159210a..80c22d09fff 100644 --- a/vortex-scalar/src/cast.rs +++ b/vortex-scalar/src/cast.rs @@ -4,40 +4,38 @@ use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use crate::Scalar; impl Scalar { /// Cast this scalar to another data type. - pub fn cast(&self, dtype: &DType) -> VortexResult { + pub fn cast(&self, target: &DType) -> VortexResult { // If the types are the same, return a clone. - if self.dtype() == dtype { + if self.dtype() == target { return Ok(self.clone()); } // Check for nullability casting. - if self.dtype().eq_ignore_nullability(dtype) { + if self.dtype().eq_ignore_nullability(target) { // Cast from non-nullable to nullable or vice versa. // The try_new with check will handle nullability checks. - return Scalar::try_new(dtype.clone(), self.value().cloned()); + return Scalar::try_new(target.clone(), self.value().cloned()); } - match (self.dtype(), dtype) { - (_, DType::Null) => { - // Can cast anything to null if the value is null. 
- if self.value().is_none() { - return Ok(Scalar::null(dtype.clone())); - } - vortex_bail!("Cannot cast non-null value {} to null dtype", self); - } - _ => { - vortex_bail!( - "Casting scalar from {} to {} is not supported", - self.dtype(), - dtype - ); - } + if (self.value.is_none() || matches!(self.dtype, DType::Null)) && target.is_nullable() { + return Ok(Scalar::new(target.clone(), self.value.clone())); + } + + match &self.dtype { + DType::Null => unreachable!("Handled by the if case above"), + DType::Bool(_) => self.as_bool().cast(target), + DType::Primitive(..) => self.as_primitive().cast(target), + DType::Decimal(..) => self.as_decimal().cast(target), + DType::Utf8(_) => self.as_utf8().cast(target), + DType::Binary(_) => self.as_binary().cast(target), + DType::Struct(..) => self.as_struct().cast(target), + DType::List(..) | DType::FixedSizeList(..) => self.as_list().cast(target), + DType::Extension(..) => self.as_extension().cast(target), } } diff --git a/vortex-scalar/src/decimal/scalar.rs b/vortex-scalar/src/decimal/scalar.rs index 9c306867f01..f80e51a97f5 100644 --- a/vortex-scalar/src/decimal/scalar.rs +++ b/vortex-scalar/src/decimal/scalar.rs @@ -247,6 +247,15 @@ impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> { } } +impl From> for Scalar { + fn from(ds: DecimalScalar<'_>) -> Self { + Scalar::new( + ds.dtype().clone(), + ds.decimal_value().map(ScalarValue::Decimal), + ) + } +} + impl PartialEq for DecimalScalar<'_> { fn eq(&self, other: &Self) -> bool { self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value diff --git a/vortex-scalar/src/decimal/value.rs b/vortex-scalar/src/decimal/value.rs index 8b5da4f685c..9a95f2b2254 100644 --- a/vortex-scalar/src/decimal/value.rs +++ b/vortex-scalar/src/decimal/value.rs @@ -13,6 +13,7 @@ use num_traits::CheckedMul; use num_traits::CheckedSub; use vortex_dtype::DType; use vortex_dtype::DecimalDType; +use vortex_dtype::DecimalType; use vortex_dtype::NativeDecimalType; use 
vortex_dtype::Nullability; use vortex_dtype::ToI256; @@ -67,6 +68,32 @@ impl DecimalValue { match_each_decimal_value!(self, |value| { T::from(*value) }) } + /// Returns true if this decimal value is zero. + pub fn is_zero(&self) -> bool { + match self { + DecimalValue::I8(v) => *v == 0, + DecimalValue::I16(v) => *v == 0, + DecimalValue::I32(v) => *v == 0, + DecimalValue::I64(v) => *v == 0, + DecimalValue::I128(v) => *v == 0, + DecimalValue::I256(v) => *v == i256::ZERO, + } + } + + /// Returns the 0 value given the [`DecimalType`]. + pub fn zero(decimal_type: &DecimalDType) -> Self { + let smallest_type = DecimalType::smallest_decimal_value_type(decimal_type); + + match smallest_type { + DecimalType::I8 => DecimalValue::I8(0), + DecimalType::I16 => DecimalValue::I16(0), + DecimalType::I32 => DecimalValue::I32(0), + DecimalType::I64 => DecimalValue::I64(0), + DecimalType::I128 => DecimalValue::I128(0), + DecimalType::I256 => DecimalValue::I256(i256::ZERO), + } + } + /// Check if this decimal value fits within the precision constraints of the given decimal type. /// /// The precision defines the total number of significant digits that can be represented. 
diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index e56a60f48f5..c67b6b8ddf0 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -246,6 +246,12 @@ impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> { } } +impl From> for Scalar { + fn from(ps: PrimitiveScalar<'_>) -> Self { + Scalar::new(ps.dtype().clone(), ps.pvalue().map(ScalarValue::Primitive)) + } +} + impl Sub for PrimitiveScalar<'_> { type Output = Self; diff --git a/vortex-scalar/src/proto.rs b/vortex-scalar/src/proto.rs index 9a5a3c145db..924e209171d 100644 --- a/vortex-scalar/src/proto.rs +++ b/vortex-scalar/src/proto.rs @@ -169,6 +169,50 @@ fn scalar_value_from_proto(value: &pb::ScalarValue) -> VortexResult(&self) -> B { + use prost::Message; + let proto = scalar_value_to_proto(Some(self)); + let mut buf = B::default(); + proto + .encode(&mut buf) + .vortex_expect("Failed to encode scalar value"); + buf + } + + /// Deserialize a [`ScalarValue`] from protobuf bytes. + pub fn from_protobytes(bytes: &[u8]) -> VortexResult { + use prost::Message; + let proto = pb::ScalarValue::decode(bytes)?; + scalar_value_from_proto(&proto)? + .ok_or_else(|| vortex_err!("Cannot deserialize null as ScalarValue")) + } + + /// Serialize an optional [`ScalarValue`] to protobuf bytes (handles null values). + pub fn option_to_protobytes(value: Option<&ScalarValue>) -> B { + use prost::Message; + let proto = scalar_value_to_proto(value); + let mut buf = B::default(); + proto + .encode(&mut buf) + .vortex_expect("Failed to encode scalar value"); + buf + } + + /// Deserialize an optional [`ScalarValue`] from protobuf bytes (handles null values). + pub fn option_from_protobytes(bytes: &[u8]) -> VortexResult> { + use prost::Message; + let proto = pb::ScalarValue::decode(bytes)?; + scalar_value_from_proto(&proto) + } + + /// Convert from a protobuf [`ScalarValue`] to an optional [`ScalarValue`]. 
+ pub fn option_from_proto(proto: &pb::ScalarValue) -> VortexResult> { + scalar_value_from_proto(proto) + } +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/vortex-scalar/src/pvalue.rs b/vortex-scalar/src/pvalue.rs index 021979010b5..a90fb3d7ab7 100644 --- a/vortex-scalar/src/pvalue.rs +++ b/vortex-scalar/src/pvalue.rs @@ -136,8 +136,25 @@ macro_rules! as_primitive { } impl PValue { + /// Returns true if this decimal value is zero. + pub fn is_zero(&self) -> bool { + matches!( + self, + PValue::U8(0) + | PValue::U16(0) + | PValue::U32(0) + | PValue::U64(0) + | PValue::I8(0) + | PValue::I16(0) + | PValue::I32(0) + | PValue::I64(0) + ) || matches!(self, PValue::F16(f) if f.to_f32() == Some(0.0)) + || matches!(self, PValue::F32(f) if *f == 0.0) + || matches!(self, PValue::F64(f) if *f == 0.0) + } + /// Creates a zero value for the given primitive type. - pub fn zero(ptype: PType) -> PValue { + pub fn zero(ptype: &PType) -> PValue { match ptype { PType::U8 => PValue::U8(0), PType::U16 => PValue::U16(0), @@ -599,17 +616,17 @@ mod test { #[test] fn test_zero_values() { - assert_eq!(PValue::zero(PType::U8), PValue::U8(0)); - assert_eq!(PValue::zero(PType::U16), PValue::U16(0)); - assert_eq!(PValue::zero(PType::U32), PValue::U32(0)); - assert_eq!(PValue::zero(PType::U64), PValue::U64(0)); - assert_eq!(PValue::zero(PType::I8), PValue::I8(0)); - assert_eq!(PValue::zero(PType::I16), PValue::I16(0)); - assert_eq!(PValue::zero(PType::I32), PValue::I32(0)); - assert_eq!(PValue::zero(PType::I64), PValue::I64(0)); - assert_eq!(PValue::zero(PType::F16), PValue::F16(f16::from_f32(0.0))); - assert_eq!(PValue::zero(PType::F32), PValue::F32(0.0)); - assert_eq!(PValue::zero(PType::F64), PValue::F64(0.0)); + assert_eq!(PValue::zero(&PType::U8), PValue::U8(0)); + assert_eq!(PValue::zero(&PType::U16), PValue::U16(0)); + assert_eq!(PValue::zero(&PType::U32), PValue::U32(0)); + assert_eq!(PValue::zero(&PType::U64), PValue::U64(0)); + assert_eq!(PValue::zero(&PType::I8), 
PValue::I8(0)); + assert_eq!(PValue::zero(&PType::I16), PValue::I16(0)); + assert_eq!(PValue::zero(&PType::I32), PValue::I32(0)); + assert_eq!(PValue::zero(&PType::I64), PValue::I64(0)); + assert_eq!(PValue::zero(&PType::F16), PValue::F16(f16::from_f32(0.0))); + assert_eq!(PValue::zero(&PType::F32), PValue::F32(0.0)); + assert_eq!(PValue::zero(&PType::F64), PValue::F64(0.0)); } #[test] diff --git a/vortex-scalar/src/scalar.rs b/vortex-scalar/src/scalar.rs index 292da746684..2c674d0b29f 100644 --- a/vortex-scalar/src/scalar.rs +++ b/vortex-scalar/src/scalar.rs @@ -4,6 +4,7 @@ use std::cmp::Ordering; use vortex_dtype::DType; +use vortex_dtype::NativeDType; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure; @@ -27,15 +28,17 @@ use crate::Utf8Scalar; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Scalar { /// The type of the scalar. - dtype: DType, + pub(crate) dtype: DType, /// The value of the scalar. This is [`None`] if the value is null, otherwise it is [`Some`]. /// /// Invariant: If the [`DType`] is non-nullable, then this value _cannot_ be [`None`]. - value: Option, + pub(crate) value: Option, } impl Scalar { + // Constructors for null scalars. + /// Creates a new null [`Scalar`] with the given [`DType`]. /// /// # Panics @@ -50,7 +53,18 @@ impl Scalar { Self { dtype, value: None } } - // Constructors for potentially null scalar values. + // TODO(connor): Find places to use this instead of `null()`. + /// Creates a new null [`Scalar`] for the given scalar type. + /// + /// The resulting scalar will have a nullable version of the type's data type. + pub fn null_native() -> Self { + Self { + dtype: T::dtype().as_nullable(), + value: None, + } + } + + // Constructors for potentially null scalars. /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`]. /// @@ -88,7 +102,7 @@ impl Scalar { Self { dtype, value } } - // Constructors for non-null scalar values. 
+ // Constructors for non-null scalars. /// Creates a new [`Scalar`] with the given [`DType`] and [`ScalarValue`]. /// @@ -204,30 +218,91 @@ impl Scalar { } } - /// Returns the parts of the Scalar. + /// Returns the parts of the [`Scalar`]. pub fn into_parts(self) -> (DType, Option) { (self.dtype, self.value) } - /// Returns the DType of the Scalar. + /// Returns the [`DType`] of the [`Scalar`]. pub fn dtype(&self) -> &DType { &self.dtype } - /// Returns true if the Scalar is null. - pub fn is_null(&self) -> bool { - self.value.is_none() - } - - /// Returns the scalar value. + /// Returns an optional [`ScalarValue`] of the [`Scalar`], where `None` means the value is null. pub fn value(&self) -> Option<&ScalarValue> { self.value.as_ref() } - /// Returns the scalar value, consuming the Scalar. + /// Returns the internal optional [`ScalarValue`], where `None` means the value is null, + /// consuming the [`Scalar`]. pub fn into_value(self) -> Option { self.value } + + /// Returns `true` if the [`Scalar`] has a non-null value. + pub fn is_valid(&self) -> bool { + self.value.is_some() + } + + /// Returns `true` if the [`Scalar`] is null. + pub fn is_null(&self) -> bool { + self.value.is_none() + } + + /// Returns a default value for the given [`DType`]. + /// + /// For nullable types, this returns a null scalar. For non-nullable types, this returns the + /// zero value for the type. + /// + /// See [`ScalarValue::zero_value`] for more details about "zero" values. + pub fn default_value(dtype: &DType) -> Self { + if dtype.is_nullable() { + Self::null(dtype.clone()) + } else { + Self::zero_value(dtype) + } + } + + /// Returns a non-null zero / identity value for the given [`DType`]. + /// + /// See [`ScalarValue::zero_value`] for more details about "zero" values. + pub fn zero_value(dtype: &DType) -> Self { + let value = ScalarValue::zero_value(dtype); + Self::new_value(dtype.clone(), value) + } + + // /// Returns the size of the scalar in bytes, uncompressed. 
+ // pub fn nbytes(&self) -> usize { + // match self.dtype() { + // DType::Null => 0, + // DType::Bool(_) => 1, + // DType::Primitive(ptype, _) => ptype.byte_width(), + // DType::Decimal(dt, _) => { + // if dt.precision() <= i128::MAX_PRECISION { + // size_of::() + // } else { + // size_of::() + // } + // } + // DType::Binary(_) | DType::Utf8(_) => self + // .value() + // .as_buffer() + // .ok() + // .flatten() + // .map_or(0, |s| s.len()), + // DType::Struct(..) => self + // .as_struct() + // .fields() + // .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::()) + // .unwrap_or_default(), + // DType::List(..) | DType::FixedSizeList(..) => self + // .as_list() + // .elements() + // .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::()) + // .unwrap_or_default(), + // DType::Extension(_) => self.as_extension().storage().nbytes(), + // } + // } } /// Scalar downcasing methods to typed views. diff --git a/vortex-scalar/src/scalar_value.rs b/vortex-scalar/src/scalar_value.rs index a477a035442..7a00553852b 100644 --- a/vortex-scalar/src/scalar_value.rs +++ b/vortex-scalar/src/scalar_value.rs @@ -8,6 +8,7 @@ use std::fmt::Formatter; use itertools::Itertools; use vortex_buffer::BufferString; use vortex_buffer::ByteBuffer; +use vortex_dtype::DType; use vortex_error::vortex_panic; use crate::DecimalValue; @@ -35,6 +36,69 @@ pub enum ScalarValue { // Extension(ExtScalarRef), } +impl ScalarValue { + /// Returns true if the scalar represents the zero / identity value for its [`DType`]. + /// + /// Returns false if the scalar is null. + /// + /// See [`Scalar::zero_value`] for more details about "zero" values. + pub fn is_zero(&self) -> bool { + // TODO(connor): Is it better to just do == Self::zero_value()? 
+ match self { + ScalarValue::Bool(b) => !*b, + ScalarValue::Primitive(p) => p.is_zero(), + ScalarValue::Decimal(d) => d.is_zero(), + ScalarValue::Utf8(s) => s.is_empty(), + ScalarValue::Binary(b) => b.is_empty(), + ScalarValue::List(elems) => elems.is_empty(), + } + } + + /// Returns the zero / identity value for the given [`DType`]. + /// + /// # Zero Values + /// + /// Here is the list of zero values for each [`DType`] (when the [`DType`] is non-nullable): + /// + /// - `Null`: Does not have a "zero" value + /// - `Bool`: `false` + /// - `Primitive`: `0` + /// - `Decimal`: `0` + /// - `Utf8`: `""` + /// - `Binary`: An empty buffer + /// - `List`: An empty list + /// - `FixedSizeList`: A list (with correct size) of zero values, which is determined by the + /// element [`DType`] + /// - `Struct`: A struct where each field has a zero value, which is determined by the field + /// [`DType`] + /// + /// - `Extension`: TODO(connor): Is this right? + /// The zero value of the storage [`DType`] + pub fn zero_value(dtype: &DType) -> Self { + match dtype { + DType::Null => vortex_panic!("Null dtype has no zero value"), + DType::Bool(_) => Self::Bool(false), + DType::Primitive(ptype, _) => Self::Primitive(PValue::zero(ptype)), + DType::Decimal(dt, ..) => Self::Decimal(DecimalValue::zero(dt)), + DType::Utf8(_) => Self::Utf8(BufferString::empty()), + DType::Binary(_) => Self::Binary(ByteBuffer::empty()), + DType::List(..) => Self::List(vec![]), + DType::FixedSizeList(edt, size, _) => { + let elements = (0..*size).map(|_| Some(Self::zero_value(edt))).collect(); + Self::List(elements) + } + DType::Struct(fields, _) => { + let field_values = fields + .fields() + .map(|f| Some(Self::zero_value(&f))) + .collect(); + Self::List(field_values) + } + DType::Extension(_) => vortex_panic!("Extension dtype has no zero value"), // TODO(connor): Fix this! + } + } +} + impl ScalarValue { /// Returns the boolean value, panicking if the value is not a [`Bool`][ScalarValue::Bool]. 
pub fn as_bool(&self) -> bool { From 63ec1bf28dd135688f7565530144d57f1f8d5c6b Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 4 Feb 2026 17:44:05 -0500 Subject: [PATCH 05/22] clean up vortex-scalar Signed-off-by: Connor Tsui --- vortex-scalar/src/arrow.rs | 608 +++++++++++++++++++++++ vortex-scalar/src/arrow/mod.rs | 203 -------- vortex-scalar/src/arrow/tests.rs | 407 --------------- vortex-scalar/src/binary.rs | 84 +--- vortex-scalar/src/bool.rs | 68 +-- vortex-scalar/src/convert/decimal.rs | 122 +++++ vortex-scalar/src/convert/from_scalar.rs | 221 ++++++++ vortex-scalar/src/convert/into_scalar.rs | 124 +++++ vortex-scalar/src/convert/mod.rs | 12 + vortex-scalar/src/convert/primitive.rs | 137 +++++ vortex-scalar/src/decimal/macros.rs | 52 -- vortex-scalar/src/decimal/mod.rs | 8 +- vortex-scalar/src/decimal/scalar.rs | 6 +- vortex-scalar/src/decimal/value.rs | 90 ---- vortex-scalar/src/lib.rs | 6 +- vortex-scalar/src/list.rs | 24 +- vortex-scalar/src/primitive.rs | 122 ----- vortex-scalar/src/scalar_value.rs | 4 +- vortex-scalar/src/struct_.rs | 9 - vortex-scalar/src/utf8.rs | 101 +--- 20 files changed, 1247 insertions(+), 1161 deletions(-) create mode 100644 vortex-scalar/src/arrow.rs delete mode 100644 vortex-scalar/src/arrow/mod.rs delete mode 100644 vortex-scalar/src/arrow/tests.rs create mode 100644 vortex-scalar/src/convert/decimal.rs create mode 100644 vortex-scalar/src/convert/from_scalar.rs create mode 100644 vortex-scalar/src/convert/into_scalar.rs create mode 100644 vortex-scalar/src/convert/mod.rs create mode 100644 vortex-scalar/src/convert/primitive.rs delete mode 100644 vortex-scalar/src/decimal/macros.rs diff --git a/vortex-scalar/src/arrow.rs b/vortex-scalar/src/arrow.rs new file mode 100644 index 00000000000..912d96a6e2e --- /dev/null +++ b/vortex-scalar/src/arrow.rs @@ -0,0 +1,608 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use arrow_array::Scalar as 
ArrowScalar; +use arrow_array::*; +use vortex_dtype::DType; +use vortex_dtype::PType; +use vortex_dtype::datetime::AnyTemporal; +use vortex_dtype::datetime::TemporalMetadata; +use vortex_dtype::datetime::TimeUnit; +use vortex_error::VortexError; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; + +use crate::Scalar; +use crate::decimal::DecimalValue; + +macro_rules! value_to_arrow_scalar { + ($V:expr, $AR:ty) => { + Ok(std::sync::Arc::new( + $V.map(<$AR>::new_scalar) + .unwrap_or_else(|| arrow_array::Scalar::new(<$AR>::new_null(1))), + )) + }; +} + +macro_rules! timestamp_to_arrow_scalar { + ($V:expr, $TZ:expr, $AR:ty) => {{ + let array = match $V { + Some(v) => <$AR>::new_scalar(v).into_inner(), + None => <$AR>::new_null(1), + } + .with_timezone_opt($TZ); + Ok(Arc::new(ArrowScalar::new(array))) + }}; +} + +impl TryFrom<&Scalar> for Arc { + type Error = VortexError; + + fn try_from(value: &Scalar) -> Result, Self::Error> { + match value.dtype() { + DType::Null => Ok(Arc::new(NullArray::new(1))), + DType::Bool(_) => value_to_arrow_scalar!(value.as_bool().value(), BooleanArray), + DType::Primitive(ptype, ..) 
=> { + let scalar = value.as_primitive(); + Ok(match ptype { + PType::U8 => scalar + .typed_value() + .map(|i| Arc::new(UInt8Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(UInt8Array::new_null(1))), + PType::U16 => scalar + .typed_value() + .map(|i| Arc::new(UInt16Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(UInt16Array::new_null(1))), + PType::U32 => scalar + .typed_value() + .map(|i| Arc::new(UInt32Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(UInt32Array::new_null(1))), + PType::U64 => scalar + .typed_value() + .map(|i| Arc::new(UInt64Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(UInt64Array::new_null(1))), + PType::I8 => scalar + .typed_value() + .map(|i| Arc::new(Int8Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(Int8Array::new_null(1))), + PType::I16 => scalar + .typed_value() + .map(|i| Arc::new(Int16Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(Int16Array::new_null(1))), + PType::I32 => scalar + .typed_value() + .map(|i| Arc::new(Int32Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(Int32Array::new_null(1))), + PType::I64 => scalar + .typed_value() + .map(|i| Arc::new(Int64Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(Int64Array::new_null(1))), + PType::F16 => scalar + .typed_value() + .map(|i| Arc::new(Float16Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(Float16Array::new_null(1))), + PType::F32 => scalar + .typed_value() + .map(|i| Arc::new(Float32Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(Float32Array::new_null(1))), + PType::F64 => scalar + .typed_value() + .map(|i| Arc::new(Float64Array::new_scalar(i)) as Arc) + .unwrap_or_else(|| Arc::new(Float64Array::new_null(1))), + }) + } + DType::Decimal(..) => match value.as_decimal().decimal_value() { + // TODO(joe): replace with decimal32, etc. 
+ Some(DecimalValue::I8(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))), + Some(DecimalValue::I16(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))), + Some(DecimalValue::I32(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))), + Some(DecimalValue::I64(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))), + Some(DecimalValue::I128(v128)) => Ok(Arc::new(Decimal128Array::new_scalar(v128))), + Some(DecimalValue::I256(v256)) => { + Ok(Arc::new(Decimal256Array::new_scalar(v256.into()))) + } + None => Ok(Arc::new(arrow_array::Scalar::new( + Decimal128Array::new_null(1), + ))), + }, + DType::Utf8(_) => { + value_to_arrow_scalar!(value.as_utf8().to_value(), StringViewArray) + } + DType::Binary(_) => { + value_to_arrow_scalar!(value.as_binary().to_value(), BinaryViewArray) + } + DType::Struct(..) => { + todo!("struct scalar conversion") + } + DType::List(..) => { + todo!("list scalar conversion") + } + DType::FixedSizeList(..) => { + todo!("fixed-size list scalar conversion") + } + DType::Extension(ext) => { + let Some(temporal) = ext.metadata_opt::() else { + vortex_bail!("Cannot convert extension scalar {} to Arrow", ext.id()) + }; + + let storage_scalar = value.as_extension().storage(); + let primitive = storage_scalar + .as_primitive_opt() + .ok_or_else(|| vortex_err!("Expected primitive scalar"))?; + + match temporal { + TemporalMetadata::Timestamp(unit, tz) => { + let value = primitive.as_::(); + match unit { + TimeUnit::Nanoseconds => { + timestamp_to_arrow_scalar!( + value, + tz.clone(), + TimestampNanosecondArray + ) + } + TimeUnit::Microseconds => { + timestamp_to_arrow_scalar!( + value, + tz.clone(), + TimestampMicrosecondArray + ) + } + TimeUnit::Milliseconds => { + timestamp_to_arrow_scalar!( + value, + tz.clone(), + TimestampMillisecondArray + ) + } + TimeUnit::Seconds => { + timestamp_to_arrow_scalar!(value, tz.clone(), TimestampSecondArray) + } + TimeUnit::Days => { + vortex_bail!("Unsupported TimeUnit {unit} for {}", 
ext.id()) + } + } + } + TemporalMetadata::Date(unit) => match unit { + TimeUnit::Milliseconds => { + value_to_arrow_scalar!(primitive.as_::(), Date64Array) + } + TimeUnit::Days => { + value_to_arrow_scalar!(primitive.as_::(), Date32Array) + } + TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => { + vortex_bail!("Unsupported TimeUnit {unit} for {}", ext.id()) + } + }, + TemporalMetadata::Time(unit) => match unit { + TimeUnit::Nanoseconds => { + value_to_arrow_scalar!(primitive.as_::(), Time64NanosecondArray) + } + TimeUnit::Microseconds => { + value_to_arrow_scalar!(primitive.as_::(), Time64MicrosecondArray) + } + TimeUnit::Milliseconds => { + value_to_arrow_scalar!(primitive.as_::(), Time32MillisecondArray) + } + TimeUnit::Seconds => { + value_to_arrow_scalar!(primitive.as_::(), Time32SecondArray) + } + TimeUnit::Days => { + vortex_bail!("Unsupported TimeUnit {unit} for {}", ext.id()) + } + }, + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::Datum; + use rstest::rstest; + use vortex_dtype::DType; + use vortex_dtype::NativeDType; + use vortex_dtype::Nullability; + use vortex_dtype::PType; + use vortex_dtype::datetime::Date; + use vortex_dtype::datetime::Time; + use vortex_dtype::datetime::TimeUnit; + use vortex_dtype::datetime::Timestamp; + use vortex_dtype::datetime::TimestampOptions; + use vortex_dtype::extension::ExtDTypeVTable; + use vortex_error::VortexResult; + use vortex_error::vortex_bail; + + use crate::Scalar; + + #[test] + fn test_null_scalar_to_arrow() { + let scalar = Scalar::null(DType::Null); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_bool_scalar_to_arrow() { + let scalar = Scalar::bool(true, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_null_bool_scalar_to_arrow() { + let scalar = Scalar::null(bool::dtype().as_nullable()); + let result = Arc::::try_from(&scalar); + 
assert!(result.is_ok()); + } + + #[test] + fn test_primitive_u8_to_arrow() { + let scalar = Scalar::primitive(42u8, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_primitive_u16_to_arrow() { + let scalar = Scalar::primitive(1000u16, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_primitive_u32_to_arrow() { + let scalar = Scalar::primitive(100000u32, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_primitive_u64_to_arrow() { + let scalar = Scalar::primitive(10000000000u64, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_primitive_i8_to_arrow() { + let scalar = Scalar::primitive(-42i8, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_primitive_i16_to_arrow() { + let scalar = Scalar::primitive(-1000i16, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_primitive_i32_to_arrow() { + let scalar = Scalar::primitive(-100000i32, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_primitive_i64_to_arrow() { + let scalar = Scalar::primitive(-10000000000i64, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_primitive_f16_to_arrow() { + use vortex_dtype::half::f16; + + let scalar = Scalar::primitive(f16::from_f32(1.234), Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_primitive_f32_to_arrow() { + let scalar = Scalar::primitive(1.234f32, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + 
#[test] + fn test_primitive_f64_to_arrow() { + let scalar = Scalar::primitive(1.234567890123f64, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_null_primitive_to_arrow() { + let scalar = Scalar::null(i32::dtype().as_nullable()); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_utf8_scalar_to_arrow() { + let scalar = Scalar::utf8("hello world".to_string(), Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_null_utf8_scalar_to_arrow() { + let scalar = Scalar::null(String::dtype().as_nullable()); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_binary_scalar_to_arrow() { + let data = vec![1u8, 2, 3, 4, 5]; + let scalar = Scalar::binary(data, Nullability::NonNullable); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_null_binary_scalar_to_arrow() { + let scalar = Scalar::null(DType::Binary(Nullability::Nullable)); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + fn test_decimal_scalars_to_arrow() { + use vortex_dtype::DecimalDType; + + use crate::decimal::DecimalValue; + + // Test various decimal value types + let decimal_dtype = DecimalDType::new(5, 2); + + let scalar_i8 = Scalar::decimal( + DecimalValue::I8(100), + decimal_dtype, + Nullability::NonNullable, + ); + assert!(Arc::::try_from(&scalar_i8).is_ok()); + + let scalar_i16 = Scalar::decimal( + DecimalValue::I16(10000), + decimal_dtype, + Nullability::NonNullable, + ); + assert!(Arc::::try_from(&scalar_i16).is_ok()); + + let scalar_i32 = Scalar::decimal( + DecimalValue::I32(1000000), + decimal_dtype, + Nullability::NonNullable, + ); + assert!(Arc::::try_from(&scalar_i32).is_ok()); + + let scalar_i64 = Scalar::decimal( + DecimalValue::I64(100000000000), + decimal_dtype, + Nullability::NonNullable, + ); + 
assert!(Arc::::try_from(&scalar_i64).is_ok()); + + let scalar_i128 = Scalar::decimal( + DecimalValue::I128(123456789012345678901234567890i128), + decimal_dtype, + Nullability::NonNullable, + ); + assert!(Arc::::try_from(&scalar_i128).is_ok()); + + // Test i256 + use vortex_dtype::i256; + let value_i256 = i256::from_i128(123456789012345678901234567890i128); + let scalar_i256 = Scalar::decimal( + DecimalValue::I256(value_i256), + decimal_dtype, + Nullability::NonNullable, + ); + assert!(Arc::::try_from(&scalar_i256).is_ok()); + } + + #[test] + fn test_null_decimal_to_arrow() { + use vortex_dtype::DecimalDType; + + let decimal_dtype = DecimalDType::new(10, 2); + let scalar = Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)); + let result = Arc::::try_from(&scalar); + assert!(result.is_ok()); + } + + #[test] + #[should_panic(expected = "struct scalar conversion")] + fn test_struct_scalar_to_arrow_todo() { + use vortex_dtype::FieldDType; + use vortex_dtype::StructFields; + + let struct_dtype = DType::Struct( + StructFields::from_iter([( + "field1", + FieldDType::from(DType::Primitive(PType::I32, Nullability::NonNullable)), + )]), + Nullability::NonNullable, + ); + + let struct_scalar = Scalar::struct_( + struct_dtype, + vec![Scalar::primitive(42i32, Nullability::NonNullable)], + ); + Arc::::try_from(&struct_scalar).unwrap(); + } + + #[test] + #[should_panic(expected = "list scalar conversion")] + fn test_list_scalar_to_arrow_todo() { + let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)); + let list_scalar = Scalar::list( + element_dtype, + vec![ + Scalar::primitive(1i32, Nullability::NonNullable), + Scalar::primitive(2i32, Nullability::NonNullable), + ], + Nullability::NonNullable, + ); + + Arc::::try_from(&list_scalar).unwrap(); + } + + #[test] + #[should_panic(expected = "Cannot convert extension scalar")] + fn test_non_temporal_extension_to_arrow_todo() { + use vortex_dtype::ExtID; + + #[derive(Clone, Debug, Default, 
PartialEq, Eq, Hash)] + struct SomeExt; + impl ExtDTypeVTable for SomeExt { + type Metadata = String; + + fn id(&self) -> ExtID { + ExtID::new_ref("some_ext") + } + + fn serialize(&self, _options: &Self::Metadata) -> VortexResult> { + vortex_bail!("not implemented") + } + + fn deserialize(&self, _data: &[u8]) -> VortexResult { + vortex_bail!("not implemented") + } + + fn validate_dtype( + &self, + _options: &Self::Metadata, + _storage_dtype: &DType, + ) -> VortexResult<()> { + Ok(()) + } + } + + let scalar = Scalar::extension::( + "".into(), + Scalar::primitive(42i32, Nullability::NonNullable), + ); + + Arc::::try_from(&scalar).unwrap(); + } + + #[rstest] + #[case(TimeUnit::Nanoseconds, PType::I64, 123456789i64)] + #[case(TimeUnit::Microseconds, PType::I64, 123456789i64)] + #[case(TimeUnit::Milliseconds, PType::I32, 123456i64)] + #[case(TimeUnit::Seconds, PType::I32, 1234i64)] + fn test_temporal_time_to_arrow( + #[case] time_unit: TimeUnit, + #[case] ptype: PType, + #[case] value: i64, + ) { + let scalar = Scalar::extension::