Skip to content

Commit 41aef66

Browse files
committed
fix: preserve dictionary encoding in to_arrow_batch_reader
1 parent 9d36e23 commit 41aef66

2 files changed

Lines changed: 81 additions & 0 deletions

File tree

pyiceberg/table/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2196,6 +2196,20 @@ def _to_arrow_batch_reader_via_file_scan_tasks(
21962196
from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow
21972197

21982198
target_schema = schema_to_pyarrow(projected_schema)
2199+
if dictionary_columns:
2200+
"""schema_to_pyarrow returns plain types. ArrowScan yields
2201+
dictionary-encoded batches for the requested columns, so rebuild
2202+
target_schema with dictionary types for those fields. Without this,
2203+
.cast(target_schema) would silently convert dictionary arrays back
2204+
to their plain value type and erase the encoding."""
2205+
dict_col_set = set(dictionary_columns)
2206+
target_schema = pa.schema(
2207+
[
2208+
field.with_type(pa.dictionary(pa.int32(), field.type)) if field.name in dict_col_set else field
2209+
for field in target_schema
2210+
],
2211+
metadata=target_schema.metadata,
2212+
)
21992213
batches = ArrowScan(
22002214
scan.table_metadata,
22012215
scan.io,

tests/io/test_pyarrow.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5229,6 +5229,73 @@ def test_partition_column_projection_with_schema_evolution(catalog: InMemoryCata
52295229
assert result_sorted["new_column"].to_pylist() == [None, None, "new1", "new2"]
52305230

52315231

5232+
def test_to_arrow_batch_reader_preserves_dictionary_columns(tmpdir: str) -> None:
5233+
"""_to_arrow_batch_reader_via_file_scan_tasks must not strip dictionary encoding.
5234+
5235+
Regression test for https://github.com/apache/iceberg-python/issues/3540.
5236+
Before the fix, RecordBatchReader.cast(target_schema) was called with a
5237+
plain schema, silently converting dictionary arrays back to their value
5238+
type so to_arrow_batch_reader(dictionary_columns=...).read_all() returned
5239+
plain strings instead of dictionary-encoded arrays.
5240+
"""
5241+
from pyiceberg.expressions import AlwaysTrue
5242+
from pyiceberg.io.pyarrow import PyArrowFileIO
5243+
from pyiceberg.partitioning import PartitionSpec
5244+
from pyiceberg.table import FileScanTask, _to_arrow_batch_reader_via_file_scan_tasks
5245+
from pyiceberg.table.metadata import TableMetadataV2
5246+
5247+
arrow_schema = pa.schema(
5248+
[
5249+
pa.field("id", pa.int32(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}),
5250+
pa.field("label", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "2"}),
5251+
]
5252+
)
5253+
arrow_table = pa.table(
5254+
[pa.array([1, 2, 3, 4], type=pa.int32()), pa.array(["a", "b", "a", "b"], type=pa.string())],
5255+
schema=arrow_schema,
5256+
)
5257+
data_file = _write_table_to_data_file(f"{tmpdir}/test_batch_reader_dict.parquet", arrow_schema, arrow_table)
5258+
data_file.spec_id = 0
5259+
5260+
iceberg_schema = Schema(
5261+
NestedField(1, "id", IntegerType(), required=False),
5262+
NestedField(2, "label", StringType(), required=False),
5263+
)
5264+
table_metadata = TableMetadataV2(
5265+
location=f"file://{tmpdir}",
5266+
last_column_id=2,
5267+
format_version=2,
5268+
schemas=[iceberg_schema],
5269+
partition_specs=[PartitionSpec()],
5270+
)
5271+
5272+
class _MockScan:
5273+
def __init__(self) -> None:
5274+
self.table_metadata = table_metadata
5275+
self.io = PyArrowFileIO()
5276+
self.row_filter = AlwaysTrue()
5277+
self.case_sensitive = True
5278+
self.limit = None
5279+
5280+
tasks = [FileScanTask(data_file)]
5281+
result = _to_arrow_batch_reader_via_file_scan_tasks(
5282+
_MockScan(), # type: ignore[arg-type]
5283+
iceberg_schema,
5284+
tasks,
5285+
dictionary_columns=("label",),
5286+
).read_all()
5287+
5288+
# label must be dictionary-encoded, not plain string
5289+
assert pa.types.is_dictionary(result.schema.field("label").type), (
5290+
f"expected dictionary type, got {result.schema.field('label').type}"
5291+
)
5292+
# id is not in dictionary_columns — must remain int32
5293+
assert result.schema.field("id").type == pa.int32()
5294+
# Values must be identical to the source data
5295+
assert result.column("label").to_pylist() == ["a", "b", "a", "b"]
5296+
assert result.column("id").to_pylist() == [1, 2, 3, 4]
5297+
5298+
52325299
def test_dictionary_columns_produces_dict_encoded_output(tmpdir: str) -> None:
52335300
"""dictionary_columns passed to ArrowScan must yield dictionary-encoded arrays.
52345301

0 commit comments

Comments
 (0)