From 3bc9d15906b0e6d3ec96d66f060588c3107f0a61 Mon Sep 17 00:00:00 2001
From: Anton Borisov
Date: Thu, 22 Jan 2026 00:40:44 +0000
Subject: [PATCH 1/2] [TASK-191] Arrow serialization for decimal and temporal types

---
 .../src/client/table/log_fetch_buffer.rs     |  12 +-
 crates/fluss/src/client/table/scanner.rs     |   8 +-
 crates/fluss/src/client/write/accumulator.rs |   2 +-
 crates/fluss/src/client/write/batch.rs       |   8 +-
 crates/fluss/src/record/arrow.rs             | 546 ++++++++++++++----
 crates/fluss/src/row/datum.rs                | 501 +++++++++++++++-
 6 files changed, 935 insertions(+), 142 deletions(-)

diff --git a/crates/fluss/src/client/table/log_fetch_buffer.rs b/crates/fluss/src/client/table/log_fetch_buffer.rs
index ca0a2532..214a79cd 100644
--- a/crates/fluss/src/client/table/log_fetch_buffer.rs
+++ b/crates/fluss/src/client/table/log_fetch_buffer.rs
@@ -657,13 +657,13 @@ mod tests {
     use std::sync::Arc;
     use std::time::Duration;

-    fn test_read_context() -> ReadContext {
+    fn test_read_context() -> Result<ReadContext> {
         let row_type = RowType::new(vec![DataField::new(
             "id".to_string(),
             DataTypes::int(),
             None,
         )]);
-        ReadContext::new(to_arrow_schema(&row_type), false)
+        Ok(ReadContext::new(to_arrow_schema(&row_type)?, false))
     }

     struct ErrorPendingFetch {
@@ -689,7 +689,7 @@ mod tests {

     #[tokio::test]
     async fn await_not_empty_returns_wakeup_error() {
-        let buffer = LogFetchBuffer::new(test_read_context());
+        let buffer = LogFetchBuffer::new(test_read_context().unwrap());
         buffer.wakeup();

         let result = buffer.await_not_empty(Duration::from_millis(10)).await;
@@ -698,7 +698,7 @@ mod tests {

     #[tokio::test]
     async fn await_not_empty_returns_pending_error() {
-        let buffer = LogFetchBuffer::new(test_read_context());
+        let buffer = LogFetchBuffer::new(test_read_context().unwrap());
         let table_bucket = TableBucket::new(1, 0);
         buffer.pend(Box::new(ErrorPendingFetch {
             table_bucket: table_bucket.clone(),
@@ -728,7 +728,7 @@ mod tests {
                 compression_type: ArrowCompressionType::None,
                 compression_level: DEFAULT_NON_ZSTD_COMPRESSION_LEVEL,
             },
-        );
+        )?;

         let mut row = GenericRow::new();
         row.set_field(0, 1_i32);
@@ -738,7 +738,7 @@ mod tests {

         let data = builder.build()?;
         let log_records = LogRecordsBatches::new(data.clone());
-        let read_context = ReadContext::new(to_arrow_schema(&row_type), false);
+        let read_context = ReadContext::new(to_arrow_schema(&row_type)?, false);
         let mut fetch = DefaultCompletedFetch::new(
             TableBucket::new(1, 0),
             log_records,
diff --git a/crates/fluss/src/client/table/scanner.rs b/crates/fluss/src/client/table/scanner.rs
index e9b2ce10..cf0b257f 100644
--- a/crates/fluss/src/client/table/scanner.rs
+++ b/crates/fluss/src/client/table/scanner.rs
@@ -470,7 +470,7 @@ impl LogFetcher {
         log_scanner_status: Arc<LogScannerStatus>,
         projected_fields: Option<Vec<usize>>,
     ) -> Result<Self> {
-        let full_arrow_schema = to_arrow_schema(table_info.get_row_type());
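+        // to_arrow_schema is fallible after this change, so propagate the
+        // conversion error instead of panicking inside the fetcher.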
+        let full_arrow_schema = to_arrow_schema(table_info.get_row_type())?;
         let read_context =
             Self::create_read_context(full_arrow_schema.clone(), projected_fields.clone(), false)?;
         let remote_read_context =
@@ -1445,7 +1445,7 @@ mod tests {
                 compression_type: ArrowCompressionType::None,
                 compression_level: DEFAULT_NON_ZSTD_COMPRESSION_LEVEL,
             },
-        );
+        )?;
         let record = WriteRecord::for_append(
             table_path,
             1,
@@ -1477,7 +1477,7 @@ mod tests {

         let data = build_records(&table_info, Arc::new(table_path))?;
         let log_records = LogRecordsBatches::new(data.clone());
-        let read_context = ReadContext::new(to_arrow_schema(table_info.get_row_type()), false);
+        let read_context = ReadContext::new(to_arrow_schema(table_info.get_row_type())?, false);
         let completed =
             DefaultCompletedFetch::new(bucket.clone(), log_records, data.len(), read_context, 0, 0);
         fetcher.log_fetch_buffer.add(Box::new(completed));
@@ -1506,7 +1506,7 @@ mod tests {
         let bucket = TableBucket::new(1, 0);
         let data = build_records(&table_info, Arc::new(table_path))?;
         let log_records = LogRecordsBatches::new(data.clone());
-        let read_context = ReadContext::new(to_arrow_schema(table_info.get_row_type()), false);
+        let read_context = ReadContext::new(to_arrow_schema(table_info.get_row_type())?, false);
         let mut completed: Box<dyn CompletedFetch> = Box::new(DefaultCompletedFetch::new(
             bucket,
             log_records,
diff --git a/crates/fluss/src/client/write/accumulator.rs b/crates/fluss/src/client/write/accumulator.rs
index fb7b5447..46c822c1 100644
--- a/crates/fluss/src/client/write/accumulator.rs
+++ b/crates/fluss/src/client/write/accumulator.rs
@@ -112,7 +112,7 @@ impl RecordAccumulator {
                 bucket_id,
                 current_time_ms(),
                 matches!(&record.record, Record::Log(LogWriteRecord::RecordBatch(_))),
-            )),
+            )?),
             Record::Kv(kv_record) => Kv(KvWriteBatch::new(
                 self.batch_id.fetch_add(1, Ordering::Relaxed),
                 table_path.as_ref().clone(),
diff --git a/crates/fluss/src/client/write/batch.rs b/crates/fluss/src/client/write/batch.rs
index 159e3136..78381c6e 100644
--- a/crates/fluss/src/client/write/batch.rs
+++ b/crates/fluss/src/client/write/batch.rs
@@ -197,18 +197,18 @@ impl ArrowLogWriteBatch {
         bucket_id: BucketId,
         create_ms: i64,
         to_append_record_batch: bool,
-    ) -> Self {
+    ) -> Result<Self> {
         let base = InnerWriteBatch::new(batch_id, table_path, create_ms, bucket_id);
-        Self {
+        Ok(Self {
             write_batch: base,
             arrow_builder: MemoryLogRecordsArrowBuilder::new(
                 schema_id,
                 row_type,
                 to_append_record_batch,
                 arrow_compression_info,
-            ),
+            )?,
             built_records: None,
-        }
+        })
     }

     pub fn batch_id(&self) -> i64 {
diff --git a/crates/fluss/src/record/arrow.rs b/crates/fluss/src/record/arrow.rs
index 3c94b720..d3e8a5a9 100644
--- a/crates/fluss/src/record/arrow.rs
+++ b/crates/fluss/src/record/arrow.rs
@@ -22,9 +22,12 @@ use crate::metadata::{DataType, RowType};
 use crate::record::{ChangeType, ScanRecord};
 use crate::row::{ColumnarRow, GenericRow};
 use arrow::array::{
-    ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Float32Builder, Float64Builder,
-    Int8Builder, Int16Builder, Int32Builder, Int64Builder, StringBuilder, UInt8Builder,
-    UInt16Builder, UInt32Builder, UInt64Builder,
+    ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder,
+    Float32Builder, Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder,
+    StringBuilder, Time32MillisecondBuilder, Time32SecondBuilder, Time64MicrosecondBuilder,
+    Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampMillisecondBuilder,
+    TimestampNanosecondBuilder, TimestampSecondBuilder, UInt8Builder, UInt16Builder, UInt32Builder,
+    UInt64Builder,
 };
 use arrow::{
     array::RecordBatch,
@@ -172,61 +175,129 @@ pub struct RowAppendRecordBatchBuilder {
 }

 impl RowAppendRecordBatchBuilder {
-    pub fn new(row_type: &RowType) -> Self {
-        let schema_ref = to_arrow_schema(row_type);
-        let builders = Mutex::new(
-            schema_ref
-                .fields()
-                .iter()
-                .map(|field| Self::create_builder(field.data_type()))
-                .collect(),
-        );
-        Self {
+    pub fn new(row_type: &RowType) -> Result<Self> {
+        let schema_ref = to_arrow_schema(row_type)?;
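+        // One pre-typed builder per column; decimal and temporal columns get
+        // builders parameterized with precision/scale or time unit in
+        // create_builder below.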
+        let builders: Result<Vec<Box<dyn ArrayBuilder>>> = schema_ref
+            .fields()
+            .iter()
+            .map(|field| Self::create_builder(field.data_type()))
+            .collect();
+        Ok(Self {
             table_schema: schema_ref.clone(),
-            arrow_column_builders: builders,
+            arrow_column_builders: Mutex::new(builders?),
             records_count: 0,
-        }
+        })
     }

-    fn create_builder(data_type: &arrow_schema::DataType) -> Box<dyn ArrayBuilder> {
+    fn create_builder(data_type: &arrow_schema::DataType) -> Result<Box<dyn ArrayBuilder>> {
         match data_type {
-            arrow_schema::DataType::Int8 => Box::new(Int8Builder::new()),
-            arrow_schema::DataType::Int16 => Box::new(Int16Builder::new()),
-            arrow_schema::DataType::Int32 => Box::new(Int32Builder::new()),
-            arrow_schema::DataType::Int64 => Box::new(Int64Builder::new()),
-            arrow_schema::DataType::UInt8 => Box::new(UInt8Builder::new()),
-            arrow_schema::DataType::UInt16 => Box::new(UInt16Builder::new()),
-            arrow_schema::DataType::UInt32 => Box::new(UInt32Builder::new()),
-            arrow_schema::DataType::UInt64 => Box::new(UInt64Builder::new()),
-            arrow_schema::DataType::Float32 => Box::new(Float32Builder::new()),
-            arrow_schema::DataType::Float64 => Box::new(Float64Builder::new()),
-            arrow_schema::DataType::Boolean => Box::new(BooleanBuilder::new()),
-            arrow_schema::DataType::Utf8 => Box::new(StringBuilder::new()),
-            arrow_schema::DataType::Binary => Box::new(BinaryBuilder::new()),
-            dt => panic!("Unsupported data type: {dt:?}"),
+            arrow_schema::DataType::Int8 => Ok(Box::new(Int8Builder::new())),
+            arrow_schema::DataType::Int16 => Ok(Box::new(Int16Builder::new())),
+            arrow_schema::DataType::Int32 => Ok(Box::new(Int32Builder::new())),
+            arrow_schema::DataType::Int64 => Ok(Box::new(Int64Builder::new())),
+            arrow_schema::DataType::UInt8 => Ok(Box::new(UInt8Builder::new())),
+            arrow_schema::DataType::UInt16 => Ok(Box::new(UInt16Builder::new())),
+            arrow_schema::DataType::UInt32 => Ok(Box::new(UInt32Builder::new())),
+            arrow_schema::DataType::UInt64 => Ok(Box::new(UInt64Builder::new())),
+            arrow_schema::DataType::Float32 => Ok(Box::new(Float32Builder::new())),
+            arrow_schema::DataType::Float64 => Ok(Box::new(Float64Builder::new())),
+            arrow_schema::DataType::Boolean => Ok(Box::new(BooleanBuilder::new())),
+            arrow_schema::DataType::Utf8 => Ok(Box::new(StringBuilder::new())),
+            arrow_schema::DataType::Binary => Ok(Box::new(BinaryBuilder::new())),
+            arrow_schema::DataType::Decimal128(precision, scale) => {
+                let builder = Decimal128Builder::new()
+                    .with_precision_and_scale(*precision, *scale)
+                    .map_err(|e| Error::IllegalArgument {
+                        message: format!(
+                            "Invalid decimal precision {} or scale {}: {}",
+                            precision, scale, e
+                        ),
+                    })?;
+                Ok(Box::new(builder))
+            }
+            arrow_schema::DataType::Date32 => Ok(Box::new(Date32Builder::new())),
+            arrow_schema::DataType::Time32(unit) => match unit {
+                arrow_schema::TimeUnit::Second => Ok(Box::new(Time32SecondBuilder::new())),
+                arrow_schema::TimeUnit::Millisecond => {
+                    Ok(Box::new(Time32MillisecondBuilder::new()))
+                }
+                _ => Err(Error::IllegalArgument {
+                    message: format!(
+                        "Time32 only supports Second and Millisecond units, got: {:?}",
+                        unit
+                    ),
+                }),
+            },
+            arrow_schema::DataType::Time64(unit) => match unit {
+                arrow_schema::TimeUnit::Microsecond => {
+                    Ok(Box::new(Time64MicrosecondBuilder::new()))
+                }
+                arrow_schema::TimeUnit::Nanosecond => Ok(Box::new(Time64NanosecondBuilder::new())),
+                _ => Err(Error::IllegalArgument {
+                    message: format!(
+                        "Time64 only supports Microsecond and Nanosecond units, got: {:?}",
+                        unit
+                    ),
+                }),
+            },
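+            // Timestamp builders are created without a timezone, matching
+            // to_arrow_type below (which always emits Timestamp(unit, None));
+            // a stray timezone in the schema is caught by the type check in
+            // build_arrow_record_batch.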
+            arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Second, _) => {
+                Ok(Box::new(TimestampSecondBuilder::new()))
+            }
+            arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) => {
+                Ok(Box::new(TimestampMillisecondBuilder::new()))
+            }
+            arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _) => {
+                Ok(Box::new(TimestampMicrosecondBuilder::new()))
+            }
+            arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, _) => {
+                Ok(Box::new(TimestampNanosecondBuilder::new()))
+            }
+            dt => Err(Error::IllegalArgument {
+                message: format!("Unsupported data type: {dt:?}"),
+            }),
         }
     }
 }

 impl ArrowRecordBatchInnerBuilder for RowAppendRecordBatchBuilder {
     fn build_arrow_record_batch(&self) -> Result<Arc<RecordBatch>> {
-        let arrays = self
+        let arrays: Result<Vec<ArrayRef>> = self
             .arrow_column_builders
             .lock()
             .iter_mut()
-            .map(|b| b.finish())
-            .collect::<Vec<_>>();
+            .enumerate()
+            .map(|(idx, b)| {
+                let array = b.finish();
+                let expected_type = self.table_schema.field(idx).data_type();
+
+                // Validate array type matches schema
+                if array.data_type() != expected_type {
+                    return Err(Error::IllegalArgument {
+                        message: format!(
+                            "Builder type mismatch at column {}: expected {:?}, got {:?}",
+                            idx,
+                            expected_type,
+                            array.data_type()
+                        ),
+                    });
+                }
+
+                Ok(array)
+            })
+            .collect();
+
         Ok(Arc::new(RecordBatch::try_new(
             self.table_schema.clone(),
-            arrays,
+            arrays?,
         )?))
     }

     fn append(&mut self, row: &GenericRow) -> Result<bool> {
         for (idx, value) in row.values.iter().enumerate() {
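+            // Pass the column's Arrow type down so decimal and temporal
+            // datums can be coerced to the target scale and time unit.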
+            let field_type = self.table_schema.field(idx).data_type();
             let mut builder_binding = self.arrow_column_builders.lock();
             let builder = builder_binding.get_mut(idx).unwrap();
-            value.append_to(builder.as_mut())?;
+            value.append_to(builder.as_mut(), field_type)?;
         }
         self.records_count += 1;
         Ok(true)
@@ -255,15 +326,15 @@ impl MemoryLogRecordsArrowBuilder {
         row_type: &RowType,
         to_append_record_batch: bool,
         arrow_compression_info: ArrowCompressionInfo,
-    ) -> Self {
+    ) -> Result<Self> {
         let arrow_batch_builder: Box<dyn ArrowRecordBatchInnerBuilder> = {
             if to_append_record_batch {
                 Box::new(PrebuiltRecordBatchBuilder::default())
             } else {
-                Box::new(RowAppendRecordBatchBuilder::new(row_type))
+                Box::new(RowAppendRecordBatchBuilder::new(row_type)?)
             }
         };
-        MemoryLogRecordsArrowBuilder {
+        Ok(MemoryLogRecordsArrowBuilder {
             base_log_offset: BUILDER_DEFAULT_OFFSET,
             schema_id,
             magic: CURRENT_LOG_MAGIC_VALUE,
@@ -272,7 +343,7 @@ impl MemoryLogRecordsArrowBuilder {
             is_closed: false,
             arrow_record_batch_builder: arrow_batch_builder,
             arrow_compression_info,
-        }
+        })
     }

     pub fn append(&mut self, record: &WriteRecord) -> Result<bool> {
@@ -641,24 +712,24 @@ fn parse_ipc_message(
     Ok((batch_metadata, body_buffer, message.version()))
 }

-pub fn to_arrow_schema(fluss_schema: &RowType) -> SchemaRef {
-    let fields: Vec<Field> = fluss_schema
+pub fn to_arrow_schema(fluss_schema: &RowType) -> Result<SchemaRef> {
+    let fields: Result<Vec<Field>> = fluss_schema
         .fields()
         .iter()
         .map(|f| {
-            Field::new(
+            Ok(Field::new(
                 f.name(),
-                to_arrow_type(f.data_type()),
+                to_arrow_type(f.data_type())?,
                 f.data_type().is_nullable(),
-            )
+            ))
         })
         .collect();
-    SchemaRef::new(arrow_schema::Schema::new(fields))
+    Ok(SchemaRef::new(arrow_schema::Schema::new(fields?)))
 }

-pub fn to_arrow_type(fluss_type: &DataType) -> ArrowDataType {
-    match fluss_type {
+pub fn to_arrow_type(fluss_type: &DataType) -> Result<ArrowDataType> {
+    Ok(match fluss_type {
         DataType::Boolean(_) => ArrowDataType::Boolean,
         DataType::TinyInt(_) => ArrowDataType::Int8,
         DataType::SmallInt(_) => ArrowDataType::Int16,
         DataType::Int(_) => ArrowDataType::Int32,
         DataType::BigInt(_) => ArrowDataType::Int64,
         DataType::Float(_) => ArrowDataType::Float32,
         DataType::Double(_) => ArrowDataType::Float64,
         DataType::Char(_) => ArrowDataType::Utf8,
         DataType::String(_) => ArrowDataType::Utf8,
-        DataType::Decimal(decimal_type) => ArrowDataType::Decimal128(
-            decimal_type
-                .precision()
-                .try_into()
-                .expect("precision exceeds u8::MAX"),
-            decimal_type
-                .scale()
-                .try_into()
-                .expect("scale exceeds i8::MAX"),
-        ),
+        DataType::Decimal(decimal_type) => {
+            let precision =
+                decimal_type
+                    .precision()
+                    .try_into()
+                    .map_err(|_| Error::IllegalArgument {
+                        message: format!(
+                            "Decimal precision {} exceeds Arrow's maximum (u8::MAX)",
+                            decimal_type.precision()
+                        ),
+                    })?;
+            let scale = decimal_type
+                .scale()
+                .try_into()
+                .map_err(|_| Error::IllegalArgument {
+                    message: format!(
+                        "Decimal scale {} exceeds Arrow's maximum (i8::MAX)",
+                        decimal_type.scale()
+                    ),
+                })?;
+            ArrowDataType::Decimal128(precision, scale)
+        }
         DataType::Date(_) => ArrowDataType::Date32,
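+        // Precision buckets map to the finest Arrow unit that can hold them:
+        // 0 -> seconds, 1-3 -> milliseconds, 4-6 -> microseconds,
+        // 7-9 -> nanoseconds (same for Time, Timestamp, and TimestampLTz).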
- invalid => panic!("Invalid precision value for TimestampType: {invalid}"), + invalid => { + return Err(Error::IllegalArgument { + message: format!( + "Invalid precision {} for TimestampType (must be 0-9)", + invalid + ), + }); + } }, DataType::TimestampLTz(timestamp_ltz_type) => match timestamp_ltz_type.precision() { 0 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None), 1..=3 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None), 4..=6 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None), 7..=9 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None), - // This arm should never be reached due to validation in TimestampLTz. - invalid => panic!("Invalid precision value for TimestampLTzType: {invalid}"), + invalid => { + return Err(Error::IllegalArgument { + message: format!( + "Invalid precision {} for TimestampLTzType (must be 0-9)", + invalid + ), + }); + } }, DataType::Bytes(_) => ArrowDataType::Binary, - DataType::Binary(binary_type) => ArrowDataType::FixedSizeBinary( - binary_type + DataType::Binary(binary_type) => { + let length = binary_type .length() .try_into() - .expect("length exceeds i32::MAX"), - ), + .map_err(|_| Error::IllegalArgument { + message: format!( + "Binary length {} exceeds Arrow's maximum (i32::MAX)", + binary_type.length() + ), + })?; + ArrowDataType::FixedSizeBinary(length) + } DataType::Array(array_type) => ArrowDataType::List( Field::new_list_field( - to_arrow_type(array_type.get_element_type()), + to_arrow_type(array_type.get_element_type())?, fluss_type.is_nullable(), ) .into(), ), DataType::Map(map_type) => { - let key_type = to_arrow_type(map_type.key_type()); - let value_type = to_arrow_type(map_type.value_type()); + let key_type = to_arrow_type(map_type.key_type())?; + let value_type = to_arrow_type(map_type.value_type())?; let entry_fields = vec![ Field::new("key", key_type, map_type.key_type().is_nullable()), Field::new("value", value_type, map_type.value_type().is_nullable()), @@ -733,20 +837,21 @@ pub fn to_arrow_type(fluss_type: &DataType) -> ArrowDataType { false, ) } - DataType::Row(row_type) => ArrowDataType::Struct(arrow_schema::Fields::from( - row_type + DataType::Row(row_type) => { + let fields: Result> = row_type .fields() .iter() .map(|f| { - Field::new( + Ok(Field::new( f.name(), - to_arrow_type(f.data_type()), + to_arrow_type(f.data_type())?, f.data_type().is_nullable(), - ) + )) }) - .collect::>(), - )), - } + .collect(); + ArrowDataType::Struct(arrow_schema::Fields::from(fields?)) + } + }) } #[derive(Clone)] @@ -1059,81 +1164,114 @@ mod tests { #[test] fn test_to_array_type() { - assert_eq!(to_arrow_type(&DataTypes::boolean()), ArrowDataType::Boolean); - assert_eq!(to_arrow_type(&DataTypes::tinyint()), ArrowDataType::Int8); - assert_eq!(to_arrow_type(&DataTypes::smallint()), ArrowDataType::Int16); - assert_eq!(to_arrow_type(&DataTypes::bigint()), ArrowDataType::Int64); - assert_eq!(to_arrow_type(&DataTypes::int()), ArrowDataType::Int32); - assert_eq!(to_arrow_type(&DataTypes::float()), ArrowDataType::Float32); - assert_eq!(to_arrow_type(&DataTypes::double()), ArrowDataType::Float64); - assert_eq!(to_arrow_type(&DataTypes::char(16)), ArrowDataType::Utf8); - assert_eq!(to_arrow_type(&DataTypes::string()), ArrowDataType::Utf8); assert_eq!( - to_arrow_type(&DataTypes::decimal(10, 2)), + to_arrow_type(&DataTypes::boolean()).unwrap(), + ArrowDataType::Boolean + ); + assert_eq!( + to_arrow_type(&DataTypes::tinyint()).unwrap(), + ArrowDataType::Int8 + ); + assert_eq!( + 
to_arrow_type(&DataTypes::smallint()).unwrap(), + ArrowDataType::Int16 + ); + assert_eq!( + to_arrow_type(&DataTypes::bigint()).unwrap(), + ArrowDataType::Int64 + ); + assert_eq!( + to_arrow_type(&DataTypes::int()).unwrap(), + ArrowDataType::Int32 + ); + assert_eq!( + to_arrow_type(&DataTypes::float()).unwrap(), + ArrowDataType::Float32 + ); + assert_eq!( + to_arrow_type(&DataTypes::double()).unwrap(), + ArrowDataType::Float64 + ); + assert_eq!( + to_arrow_type(&DataTypes::char(16)).unwrap(), + ArrowDataType::Utf8 + ); + assert_eq!( + to_arrow_type(&DataTypes::string()).unwrap(), + ArrowDataType::Utf8 + ); + assert_eq!( + to_arrow_type(&DataTypes::decimal(10, 2)).unwrap(), ArrowDataType::Decimal128(10, 2) ); - assert_eq!(to_arrow_type(&DataTypes::date()), ArrowDataType::Date32); assert_eq!( - to_arrow_type(&DataTypes::time()), + to_arrow_type(&DataTypes::date()).unwrap(), + ArrowDataType::Date32 + ); + assert_eq!( + to_arrow_type(&DataTypes::time()).unwrap(), ArrowDataType::Time32(arrow_schema::TimeUnit::Second) ); assert_eq!( - to_arrow_type(&DataTypes::time_with_precision(3)), + to_arrow_type(&DataTypes::time_with_precision(3)).unwrap(), ArrowDataType::Time32(arrow_schema::TimeUnit::Millisecond) ); assert_eq!( - to_arrow_type(&DataTypes::time_with_precision(6)), + to_arrow_type(&DataTypes::time_with_precision(6)).unwrap(), ArrowDataType::Time64(arrow_schema::TimeUnit::Microsecond) ); assert_eq!( - to_arrow_type(&DataTypes::time_with_precision(9)), + to_arrow_type(&DataTypes::time_with_precision(9)).unwrap(), ArrowDataType::Time64(arrow_schema::TimeUnit::Nanosecond) ); assert_eq!( - to_arrow_type(&DataTypes::timestamp_with_precision(0)), + to_arrow_type(&DataTypes::timestamp_with_precision(0)).unwrap(), ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None) ); assert_eq!( - to_arrow_type(&DataTypes::timestamp_with_precision(3)), + to_arrow_type(&DataTypes::timestamp_with_precision(3)).unwrap(), ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None) ); assert_eq!( - to_arrow_type(&DataTypes::timestamp_with_precision(6)), + to_arrow_type(&DataTypes::timestamp_with_precision(6)).unwrap(), ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None) ); assert_eq!( - to_arrow_type(&DataTypes::timestamp_with_precision(9)), + to_arrow_type(&DataTypes::timestamp_with_precision(9)).unwrap(), ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None) ); assert_eq!( - to_arrow_type(&DataTypes::timestamp_ltz_with_precision(0)), + to_arrow_type(&DataTypes::timestamp_ltz_with_precision(0)).unwrap(), ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None) ); assert_eq!( - to_arrow_type(&DataTypes::timestamp_ltz_with_precision(3)), + to_arrow_type(&DataTypes::timestamp_ltz_with_precision(3)).unwrap(), ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None) ); assert_eq!( - to_arrow_type(&DataTypes::timestamp_ltz_with_precision(6)), + to_arrow_type(&DataTypes::timestamp_ltz_with_precision(6)).unwrap(), ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None) ); assert_eq!( - to_arrow_type(&DataTypes::timestamp_ltz_with_precision(9)), + to_arrow_type(&DataTypes::timestamp_ltz_with_precision(9)).unwrap(), ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None) ); - assert_eq!(to_arrow_type(&DataTypes::bytes()), ArrowDataType::Binary); assert_eq!( - to_arrow_type(&DataTypes::binary(16)), + to_arrow_type(&DataTypes::bytes()).unwrap(), + ArrowDataType::Binary + ); + assert_eq!( + to_arrow_type(&DataTypes::binary(16)).unwrap(), 
+        assert_eq!(
+            to_arrow_type(&DataTypes::binary(16)).unwrap(),
             ArrowDataType::FixedSizeBinary(16)
         );
         assert_eq!(
-            to_arrow_type(&DataTypes::array(DataTypes::int())),
+            to_arrow_type(&DataTypes::array(DataTypes::int())).unwrap(),
             ArrowDataType::List(Field::new_list_field(ArrowDataType::Int32, true).into())
         );
         assert_eq!(
-            to_arrow_type(&DataTypes::map(DataTypes::string(), DataTypes::int())),
+            to_arrow_type(&DataTypes::map(DataTypes::string(), DataTypes::int())).unwrap(),
             ArrowDataType::Map(
                 Arc::new(Field::new(
                     "entries",
                     ArrowDataType::Struct(arrow_schema::Fields::from(vec![
                         Field::new("key", ArrowDataType::Utf8, true),
                         Field::new("value", ArrowDataType::Int32, true),
                     ])),
                     false,
                 )),
                 false,
             )
         );
         assert_eq!(
             to_arrow_type(&DataTypes::row(vec![
                 DataTypes::field("f1".to_string(), DataTypes::int()),
                 DataTypes::field("f2".to_string(), DataTypes::string()),
-            ])),
+            ]))
+            .unwrap(),
             ArrowDataType::Struct(arrow_schema::Fields::from(vec![
                 Field::new("f1", ArrowDataType::Int32, true),
                 Field::new("f2", ArrowDataType::Utf8, true),
             ])),
         );
     }
@@ -1215,7 +1354,7 @@ mod tests {
             DataField::new("id".to_string(), DataTypes::int(), None),
             DataField::new("name".to_string(), DataTypes::string(), None),
         ]);
-        let schema = to_arrow_schema(&row_type);
+        let schema = to_arrow_schema(&row_type).unwrap();

         let result = ReadContext::with_projection_pushdown(schema, vec![0, 2], false);
         assert!(matches!(result, Err(IllegalArgument { .. })));
@@ -1249,4 +1388,201 @@ mod tests {
         }
         out
     }
+
+    #[test]
+    fn test_temporal_and_decimal_builder_validation() {
+        // Test valid builder creation
+        let builder =
+            RowAppendRecordBatchBuilder::create_builder(&ArrowDataType::Decimal128(10, 2))
+                .unwrap();
+        assert!(
+            builder
+                .as_any()
+                .downcast_ref::<Decimal128Builder>()
+                .is_some()
+        );
+
+        // Test error case: invalid precision/scale
+        let result =
+            RowAppendRecordBatchBuilder::create_builder(&ArrowDataType::Decimal128(100, 50));
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_decimal_rescaling_and_validation() -> Result<()> {
+        use crate::row::{Datum, Decimal, GenericRow};
+        use arrow::array::Decimal128Array;
+        use bigdecimal::BigDecimal;
+        use std::str::FromStr;
+
+        // Test 1: Rescaling from scale 3 to scale 2
+        let row_type = RowType::new(vec![DataField::new(
+            "amount".to_string(),
+            DataTypes::decimal(10, 2),
+            None,
+        )]);
+        let mut builder = RowAppendRecordBatchBuilder::new(&row_type)?;
+        let decimal = Decimal::from_big_decimal(BigDecimal::from_str("123.456").unwrap(), 10, 3)?;
+        builder.append(&GenericRow {
+            values: vec![Datum::Decimal(decimal)],
+        })?;
+        let batch = builder.build_arrow_record_batch()?;
+        let array = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<Decimal128Array>()
+            .unwrap();
+        assert_eq!(array.value(0), 12346); // 123.46 rounded
+        assert_eq!(array.scale(), 2);
+
+        // Test 2: Precision overflow (should error)
+        let row_type = RowType::new(vec![DataField::new(
+            "amount".to_string(),
+            DataTypes::decimal(5, 2),
+            None,
+        )]);
+        let mut builder = RowAppendRecordBatchBuilder::new(&row_type)?;
+        let decimal = Decimal::from_big_decimal(BigDecimal::from_str("123456.78").unwrap(), 10, 2)?;
+        let result = builder.append(&GenericRow {
+            values: vec![Datum::Decimal(decimal)],
+        });
+        assert!(result.is_err());
+        assert!(
+            result
+                .unwrap_err()
+                .to_string()
+                .contains("precision overflow")
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_temporal_types_end_to_end() -> Result<()> {
+        use crate::row::{Date, Datum, Decimal, GenericRow, Time, TimestampLtz, TimestampNtz};
+        use arrow::array::{
+            Date32Array, Decimal128Array, Int32Array, Time32MillisecondArray,
+            Time64NanosecondArray, TimestampMicrosecondArray, TimestampNanosecondArray,
+        };
+        use bigdecimal::BigDecimal;
+        use std::str::FromStr;
+
+        // Schema with decimal, date, time (ms + ns), timestamps (μs + ns)
+        let row_type = RowType::new(vec![
+            DataField::new("id".to_string(), DataTypes::int(), None),
+            DataField::new("amount".to_string(), DataTypes::decimal(10, 2), None),
+            DataField::new("date".to_string(), DataTypes::date(), None),
+            DataField::new(
+                "time_ms".to_string(),
+                DataTypes::time_with_precision(3),
+                None,
+            ),
+            DataField::new(
+                "time_ns".to_string(),
+                DataTypes::time_with_precision(9),
+                None,
+            ),
+            DataField::new(
+                "ts_us".to_string(),
+                DataTypes::timestamp_with_precision(6),
+                None,
+            ),
+            DataField::new(
+                "ts_ltz_ns".to_string(),
+                DataTypes::timestamp_ltz_with_precision(9),
+                None,
+            ),
+        ]);
+
+        let mut builder = RowAppendRecordBatchBuilder::new(&row_type)?;
+
+        // Append rows with temporal values
+        builder.append(&GenericRow {
+            values: vec![
+                Datum::Int32(1),
+                Datum::Decimal(Decimal::from_big_decimal(
+                    BigDecimal::from_str("123.456").unwrap(),
+                    10,
+                    3,
+                )?),
+                Datum::Date(Date::new(18000)),
+                Datum::Time(Time::new(43200000)),
+                Datum::Time(Time::new(12345)),
+                Datum::TimestampNtz(TimestampNtz::from_millis_nanos(1609459200000, 123456)?),
+                Datum::TimestampLtz(TimestampLtz::from_millis_nanos(1609459200000, 987654)?),
+            ],
+        })?;
+
+        let batch = builder.build_arrow_record_batch()?;
+
+        // Verify all conversions
+        assert_eq!(
+            batch
+                .column(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap()
+                .value(0),
+            1
+        );
+
+        let dec = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<Decimal128Array>()
+            .unwrap();
+        assert_eq!(dec.value(0), 12346); // 123.456 → 123.46 (scale 3 → 2)
+
+        assert_eq!(
+            batch
+                .column(2)
+                .as_any()
+                .downcast_ref::<Date32Array>()
+                .unwrap()
+                .value(0),
+            18000
+        );
+
+        assert_eq!(
+            batch
+                .column(3)
+                .as_any()
+                .downcast_ref::<Time32MillisecondArray>()
+                .unwrap()
+                .value(0),
+            43200000
+        );
+
+        assert_eq!(
+            batch
+                .column(4)
+                .as_any()
+                .downcast_ref::<Time64NanosecondArray>()
+                .unwrap()
+                .value(0),
+            12345000000
+        );
+
+        // Timestamp with sub-millisecond nanos preserved
+        assert_eq!(
+            batch
+                .column(5)
+                .as_any()
+                .downcast_ref::<TimestampMicrosecondArray>()
+                .unwrap()
+                .value(0),
+            1609459200000123
+        );
+
+        assert_eq!(
+            batch
+                .column(6)
+                .as_any()
+                .downcast_ref::<TimestampNanosecondArray>()
+                .unwrap()
+                .value(0),
+            1609459200000987654
+        );
+
+        Ok(())
+    }
 }
diff --git a/crates/fluss/src/row/datum.rs b/crates/fluss/src/row/datum.rs
index 5b21b389..42a168cc 100644
--- a/crates/fluss/src/row/datum.rs
+++ b/crates/fluss/src/row/datum.rs
@@ -19,9 +19,13 @@ use crate::error::Error::RowConvertError;
 use crate::error::Result;
 use crate::row::Decimal;
 use arrow::array::{
-    ArrayBuilder, BinaryBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int8Builder,
-    Int16Builder, Int32Builder, Int64Builder, StringBuilder,
+    ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, Float32Builder,
+    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, StringBuilder,
+    Time32MillisecondBuilder, Time32SecondBuilder, Time64MicrosecondBuilder,
+    Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampMillisecondBuilder,
+    TimestampNanosecondBuilder, TimestampSecondBuilder,
 };
+use arrow::datatypes as arrow_schema;
 use jiff::ToSpan;
 use ordered_float::OrderedFloat;
 use parse_display::Display;
@@ -83,6 +87,41 @@ impl Datum<'_> {
             _ => panic!("not a blob: {self:?}"),
         }
     }
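+
+    // Panicking accessors for the new decimal and temporal variants, in the
+    // style of as_bool/as_blob above; the TryFrom impls added below provide
+    // fallible alternatives.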
{self:?}"), + } + } + + pub fn as_timestamp_ntz(&self) -> TimestampNtz { + match self { + Self::TimestampNtz(ts) => *ts, + _ => panic!("not a timestamp ntz: {self:?}"), + } + } + + pub fn as_timestamp_ltz(&self) -> TimestampLtz { + match self { + Self::TimestampLtz(ts) => *ts, + _ => panic!("not a timestamp ltz: {self:?}"), + } + } } // ----------- implement from @@ -246,6 +285,66 @@ impl TryFrom<&Datum<'_>> for i8 { } } +impl TryFrom<&Datum<'_>> for Decimal { + type Error = (); + + #[inline] + fn try_from(from: &Datum) -> std::result::Result { + match from { + Datum::Decimal(d) => Ok(d.clone()), + _ => Err(()), + } + } +} + +impl TryFrom<&Datum<'_>> for Date { + type Error = (); + + #[inline] + fn try_from(from: &Datum) -> std::result::Result { + match from { + Datum::Date(d) => Ok(*d), + _ => Err(()), + } + } +} + +impl TryFrom<&Datum<'_>> for Time { + type Error = (); + + #[inline] + fn try_from(from: &Datum) -> std::result::Result { + match from { + Datum::Time(t) => Ok(*t), + _ => Err(()), + } + } +} + +impl TryFrom<&Datum<'_>> for TimestampNtz { + type Error = (); + + #[inline] + fn try_from(from: &Datum) -> std::result::Result { + match from { + Datum::TimestampNtz(ts) => Ok(*ts), + _ => Err(()), + } + } +} + +impl TryFrom<&Datum<'_>> for TimestampLtz { + type Error = (); + + #[inline] + fn try_from(from: &Datum) -> std::result::Result { + match from { + Datum::TimestampLtz(ts) => Ok(*ts), + _ => Err(()), + } + } +} + impl<'a> From for Datum<'a> { #[inline] fn from(b: bool) -> Datum<'a> { @@ -253,12 +352,55 @@ impl<'a> From for Datum<'a> { } } +impl<'a> From for Datum<'a> { + #[inline] + fn from(d: Decimal) -> Datum<'a> { + Datum::Decimal(d) + } +} + +impl<'a> From for Datum<'a> { + #[inline] + fn from(d: Date) -> Datum<'a> { + Datum::Date(d) + } +} + +impl<'a> From