Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 305 additions & 6 deletions crates/iceberg/src/arrow/record_batch_transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::{
Array as ArrowArray, ArrayRef, Int32Array, RecordBatch, RecordBatchOptions, RunArray,
Array as ArrowArray, ArrayRef, Int32Array, LargeListArray, ListArray, MapArray, RecordBatch,
RecordBatchOptions, RunArray, StructArray, new_null_array,
};
use arrow_cast::cast;
use arrow_schema::{
Expand Down Expand Up @@ -610,7 +612,7 @@ impl RecordBatchTransformer {
ColumnSource::Promote {
target_type,
source_index,
} => cast(&*columns[*source_index], target_type)?,
} => cast_schema_to_target(&columns[*source_index], target_type)?,

ColumnSource::Add { target_type, value } => {
Self::create_column(target_type, value, num_rows)?
Expand All @@ -619,7 +621,113 @@ impl RecordBatchTransformer {
})
.collect()
}
}

/// Reads the Iceberg field id recorded on an Arrow field.
///
/// The id is stored as a string under `PARQUET_FIELD_ID_META_KEY` in the field's metadata.
/// Returns `None` when the key is missing or its value does not parse as an `i32`.
fn arrow_field_id(field: &Field) -> Option<i32> {
    let raw_id = field.metadata().get(PARQUET_FIELD_ID_META_KEY)?;
    raw_id.parse::<i32>().ok()
}

/// Rebuilds `array` in the shape described by `target_type`, resolving nested struct children
/// by Iceberg field id instead of by position (mirrors the iceberg-java logic).
///
/// Behaviour per target type:
///
/// * `Struct`: every target child is matched to the source child carrying the same field id
///   (see `arrow_field_id`) and converted recursively. A target child without a matching source
///   child, or without a field id of its own, becomes an all-null column. Source children that
///   no target child refers to are dropped. The source validity buffer is reused.
/// * `List` / `LargeList`: offsets and validity are reused; the values are converted recursively
///   against the target element type.
/// * `Map`: offsets and validity are reused; the entries struct is converted recursively against
///   the target entries type.
/// * Anything else is delegated to `arrow_cast::cast`.
///
/// # Errors
///
/// Returns an `Unexpected` error when `array` is not a struct/list/map array as required by
/// `target_type`, or when the converted map entries are not a struct. Errors from
/// `arrow_cast::cast` and from Arrow's array validation (the `try_new` constructors) are
/// propagated instead of panicking, e.g. when the rebuilt children do not satisfy the target
/// fields.
fn cast_schema_to_target(array: &ArrayRef, target_type: &DataType) -> Result<ArrayRef> {
    match target_type {
        DataType::Struct(target_children) => {
            let source = array.as_struct_opt().ok_or_else(|| {
                Error::new(
                    ErrorKind::Unexpected,
                    format!(
                        "expected a struct array to promote to {target_type:?}, got {:?}",
                        array.data_type()
                    ),
                )
            })?;

            // Index the source children by field id; children without an id cannot be matched.
            // If an id occurs more than once, the last occurrence wins.
            let source_by_id: HashMap<i32, usize> = source
                .fields()
                .iter()
                .enumerate()
                .filter_map(|(idx, field)| arrow_field_id(field).map(|id| (id, idx)))
                .collect();

            let len = source.len();
            let new_columns = target_children
                .iter()
                .map(|target_child| {
                    let matched = arrow_field_id(target_child)
                        .and_then(|id| source_by_id.get(&id).copied());
                    match matched {
                        Some(src_idx) => cast_schema_to_target(
                            source.column(src_idx),
                            target_child.data_type(),
                        ),
                        // Field is absent from the source data: fill it with nulls.
                        None => Ok(new_null_array(target_child.data_type(), len)),
                    }
                })
                .collect::<Result<Vec<ArrayRef>>>()?;

            // `try_new` reports validation failures as errors where `new` would panic.
            let promoted = StructArray::try_new(
                target_children.clone(),
                new_columns,
                source.nulls().cloned(),
            )?;
            Ok(Arc::new(promoted))
        }
        DataType::List(target_element) => {
            let source = array
                .as_list_opt::<i32>()
                .ok_or_else(|| list_err(array, target_type))?;
            let values = cast_schema_to_target(source.values(), target_element.data_type())?;
            let promoted = ListArray::try_new(
                target_element.clone(),
                source.offsets().clone(),
                values,
                source.nulls().cloned(),
            )?;
            Ok(Arc::new(promoted))
        }
        DataType::LargeList(target_element) => {
            let source = array
                .as_list_opt::<i64>()
                .ok_or_else(|| list_err(array, target_type))?;
            let values = cast_schema_to_target(source.values(), target_element.data_type())?;
            let promoted = LargeListArray::try_new(
                target_element.clone(),
                source.offsets().clone(),
                values,
                source.nulls().cloned(),
            )?;
            Ok(Arc::new(promoted))
        }
        DataType::Map(target_entries, sorted) => {
            let source = array
                .as_map_opt()
                .ok_or_else(|| list_err(array, target_type))?;
            let source_entries: ArrayRef = Arc::new(source.entries().clone());
            let entries = cast_schema_to_target(&source_entries, target_entries.data_type())?;
            // A target entries field that is not a struct would make `as_struct` panic;
            // surface it as an error instead.
            let entries_struct = entries
                .as_struct_opt()
                .ok_or_else(|| {
                    Error::new(
                        ErrorKind::Unexpected,
                        format!(
                            "expected struct entries to promote to {target_type:?}, got {:?}",
                            entries.data_type()
                        ),
                    )
                })?
                .clone();
            let promoted = MapArray::try_new(
                target_entries.clone(),
                source.offsets().clone(),
                entries_struct,
                source.nulls().cloned(),
                *sorted,
            )?;
            Ok(Arc::new(promoted))
        }
        _ => Ok(cast(array.as_ref(), target_type)?),
    }
}

fn list_err(array: &ArrayRef, target_type: &DataType) -> Error {
Error::new(
ErrorKind::Unexpected,
format!(
"expected a list/map array to promote to {target_type:?}, got {:?}",
array.data_type()
),
)
}

impl RecordBatchTransformer {
fn create_column(
target_type: &DataType,
prim_lit: &Option<PrimitiveLiteral>,
Expand Down Expand Up @@ -663,18 +771,209 @@ mod test {
use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::{
Array, Date32Array, Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch,
StringArray,
Array, ArrayRef, Date32Array, Float32Array, Float64Array, Int32Array, Int64Array,
LargeListArray, ListArray, MapArray, RecordBatch, StringArray, StructArray,
};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use arrow_buffer::OffsetBuffer;
use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema};
use parquet::arrow::PARQUET_FIELD_ID_META_KEY;

use crate::arrow::record_batch_transformer::{
RecordBatchTransformer, RecordBatchTransformerBuilder,
RecordBatchTransformer, RecordBatchTransformerBuilder, cast_schema_to_target,
};
use crate::spec::{Literal, NestedField, PrimitiveType, Schema, Struct, Type};

/// Returns `field` with its metadata replaced by a single `PARQUET_FIELD_ID_META_KEY` entry
/// holding `id`.
fn with_id(field: Field, id: i32) -> Field {
    let mut metadata = HashMap::new();
    metadata.insert(PARQUET_FIELD_ID_META_KEY.to_string(), id.to_string());
    field.with_metadata(metadata)
}

/// Creates a nullable field named `name` of type `dt`, tagged with field id `id`.
fn field_with_id(name: &str, dt: DataType, id: i32) -> Field {
    let nullable_field = Field::new(name, dt, true);
    with_id(nullable_field, id)
}

/// Struct type before evolution: a single `x: Int32` child with field id 5.
fn unevolved_struct_type() -> DataType {
    let children = vec![field_with_id("x", DataType::Int32, 5)];
    DataType::Struct(Fields::from(children))
}

/// Struct type after evolution: `x: Int32` (field id 5) plus the added `y: Int32` (field id 6).
fn evolved_struct_type() -> DataType {
    let existing = field_with_id("x", DataType::Int32, 5);
    let added = field_with_id("y", DataType::Int32, 6);
    DataType::Struct(Fields::from(vec![existing, added]))
}

/// Builds a struct array with the single `x` child (field id 5) holding `x_values` and no
/// struct-level validity buffer.
fn unevolved_struct_data(x_values: Vec<i32>) -> Arc<StructArray> {
    let fields = Fields::from(vec![field_with_id("x", DataType::Int32, 5)]);
    let x_column = Arc::new(Int32Array::from(x_values)) as ArrayRef;
    let data = StructArray::new(fields, vec![x_column], None);
    Arc::new(data)
}

/// Asserts that the first child of `s` is an `Int32` column holding exactly `expected_existing`.
fn assert_existing_field_kept(s: &StructArray, expected_existing: &[i32]) {
    let kept = s.column(0).as_primitive::<Int32Type>();
    assert_eq!(kept.values(), expected_existing);
}

/// Asserts that the second child of `s` is null in every row.
fn assert_added_field_null(s: &StructArray) {
    let added = s.column(1);
    assert_eq!(added.null_count(), s.len());
}

#[test]
fn promote_struct_fills_added_middle_field_by_id() {
    // The source struct carries field ids 1 ("a") and 3 ("c"); the target inserts id 2 ("b")
    // between them.
    let source_fields = Fields::from(vec![
        field_with_id("a", DataType::Int32, 1),
        field_with_id("c", DataType::Utf8, 3),
    ]);
    let a_values = Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef;
    let c_values = Arc::new(StringArray::from(vec!["x", "y"])) as ArrayRef;
    let source =
        Arc::new(StructArray::new(source_fields, vec![a_values, c_values], None)) as ArrayRef;

    let target_fields = Fields::from(vec![
        field_with_id("a", DataType::Int32, 1),
        field_with_id("b", DataType::Int32, 2),
        field_with_id("c", DataType::Utf8, 3),
    ]);
    let target = DataType::Struct(target_fields);

    let promoted = cast_schema_to_target(&source, &target).unwrap();
    let promoted_struct = promoted.as_struct();

    // "a" and "c" keep their data; the inserted "b" is entirely null.
    assert_eq!(promoted_struct.num_columns(), 3);
    assert_eq!(
        promoted_struct
            .column(0)
            .as_primitive::<Int32Type>()
            .values(),
        &[1, 2]
    );
    assert_eq!(promoted_struct.column(1).null_count(), 2);
    let c_column = promoted_struct.column(2).as_string::<i32>();
    assert_eq!(c_column.value(0), "x");
    assert_eq!(c_column.value(1), "y");
}

#[test]
fn promote_struct_missing_field_before_nested_list_struct() {
    // Source: { s (id 1), ev (id 3): list<struct{x}> }. The target adds "gap" (id 2) ahead of
    // the list column, so "ev" sits at a different position in source and target.
    let elem_field = Arc::new(field_with_id("element", unevolved_struct_type(), 4));
    let offsets: OffsetBuffer<i32> = OffsetBuffer::new(vec![0, 1, 2].into());
    let list_column = Arc::new(ListArray::new(
        elem_field.clone(),
        offsets,
        unevolved_struct_data(vec![10, 20]),
        None,
    )) as ArrayRef;
    let string_column = Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef;

    let source_fields = Fields::from(vec![
        field_with_id("s", DataType::Utf8, 1),
        field_with_id("ev", DataType::List(elem_field.clone()), 3),
    ]);
    let source = Arc::new(StructArray::new(
        source_fields,
        vec![string_column, list_column],
        None,
    )) as ArrayRef;

    let target = DataType::Struct(Fields::from(vec![
        field_with_id("s", DataType::Utf8, 1),
        field_with_id("gap", DataType::Int32, 2),
        field_with_id("ev", DataType::List(elem_field), 3),
    ]));

    let promoted = cast_schema_to_target(&source, &target).unwrap();
    let promoted_struct = promoted.as_struct();

    assert_eq!(promoted_struct.num_columns(), 3);
    // The added "gap" column is null in both rows.
    assert_eq!(promoted_struct.column(1).null_count(), 2);

    // The list column still exposes the original nested values.
    let events = promoted_struct.column(2).as_list::<i32>();
    assert_eq!(events.len(), 2);
    let first_row = events.value(0);
    let first_x = first_row
        .as_struct()
        .column(0)
        .as_primitive::<Int32Type>()
        .value(0);
    assert_eq!(first_x, 10);
}

#[test]
fn promote_list_element_struct_fills_added_field_by_id() {
    // list<struct{x}> promoted to list<struct{x, y}>: "y" must come back all-null.
    let source_element = Arc::new(field_with_id("element", unevolved_struct_type(), 4));
    let offsets: OffsetBuffer<i32> = OffsetBuffer::new(vec![0, 1, 2].into());
    let source = Arc::new(ListArray::new(
        source_element,
        offsets,
        unevolved_struct_data(vec![10, 20]),
        None,
    )) as ArrayRef;

    let target_element = Arc::new(field_with_id("element", evolved_struct_type(), 4));
    let target = DataType::List(target_element);

    let promoted = cast_schema_to_target(&source, &target).unwrap();
    let promoted_list = promoted.as_list::<i32>();
    assert_eq!(promoted_list.len(), 2);

    let element_structs = promoted_list.values().as_struct();
    assert_existing_field_kept(element_structs, &[10, 20]);
    assert_added_field_null(element_structs);
}

#[test]
fn promote_map_value_struct_fills_added_field_by_id() {
    // map<string, struct{x}> promoted to map<string, struct{x, y}>.
    let source_entry_fields = Fields::from(vec![
        with_id(Field::new("key", DataType::Utf8, false), 7),
        field_with_id("value", unevolved_struct_type(), 8),
    ]);
    let key_column = Arc::new(StringArray::from(vec!["k1", "k2"])) as ArrayRef;
    let value_column = unevolved_struct_data(vec![100, 200]) as ArrayRef;
    let entries = StructArray::new(source_entry_fields, vec![key_column, value_column], None);

    let source_entries_field = Arc::new(Field::new("entries", entries.data_type().clone(), false));
    let offsets: OffsetBuffer<i32> = OffsetBuffer::new(vec![0, 1, 2].into());
    let source = Arc::new(MapArray::new(
        source_entries_field,
        offsets,
        entries,
        None,
        false,
    )) as ArrayRef;

    let target_entries = DataType::Struct(Fields::from(vec![
        with_id(Field::new("key", DataType::Utf8, false), 7),
        field_with_id("value", evolved_struct_type(), 8),
    ]));
    let target_entries_field = Arc::new(Field::new("entries", target_entries, false));
    let target = DataType::Map(target_entries_field, false);

    let promoted = cast_schema_to_target(&source, &target).unwrap();
    let promoted_map = promoted.as_map();
    assert_eq!(promoted_map.len(), 2);

    // Keys are carried over unchanged.
    let promoted_entries = promoted_map.entries();
    let keys = promoted_entries.column(0).as_string::<i32>();
    assert_eq!(keys.value(0), "k1");
    assert_eq!(keys.value(1), "k2");

    // Values keep "x" and gain an all-null "y".
    let value_structs = promoted_entries.column(1).as_struct();
    assert_existing_field_kept(value_structs, &[100, 200]);
    assert_added_field_null(value_structs);
}

#[test]
fn promote_large_list_element_struct_fills_added_field_by_id() {
    // Same scenario as the `List` test, but with 64-bit offsets.
    let source_element = Arc::new(field_with_id("element", unevolved_struct_type(), 4));
    let offsets: OffsetBuffer<i64> = OffsetBuffer::new(vec![0i64, 1, 2].into());
    let source = Arc::new(LargeListArray::new(
        source_element,
        offsets,
        unevolved_struct_data(vec![7, 8]),
        None,
    )) as ArrayRef;

    let target_element = Arc::new(field_with_id("element", evolved_struct_type(), 4));
    let target = DataType::LargeList(target_element);

    let promoted = cast_schema_to_target(&source, &target).unwrap();
    let promoted_list = promoted.as_list::<i64>();
    assert_eq!(promoted_list.len(), 2);

    let element_structs = promoted_list.values().as_struct();
    assert_existing_field_kept(element_structs, &[7, 8]);
    assert_added_field_null(element_structs);
}

/// Helper to extract string values from either StringArray or RunEndEncoded<StringArray>
/// Returns empty string for null values
fn get_string_value(array: &dyn Array, index: usize) -> String {
Expand Down
Loading