From c2eec90cfc8f546d1e4714980345d18ace60f37c Mon Sep 17 00:00:00 2001 From: Fokko Date: Thu, 8 May 2025 11:48:28 +0200 Subject: [PATCH 1/5] Write small decimals as INTs Resolves #1979 --- pyiceberg/io/pyarrow.py | 14 ++++- tests/integration/test_writes/test_writes.py | 65 ++++++++++++++++++++ 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 99223f1253..a8dad5676b 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -638,7 +638,13 @@ def visit_fixed(self, fixed_type: FixedType) -> pa.DataType: return pa.binary(len(fixed_type)) def visit_decimal(self, decimal_type: DecimalType) -> pa.DataType: - return pa.decimal128(decimal_type.precision, decimal_type.scale) + return ( + pa.decimal32(decimal_type.precision, decimal_type.scale) + if decimal_type.precision <= 9 + else pa.decimal64(decimal_type.precision, decimal_type.scale) + if decimal_type.precision <= 18 + else pa.decimal128(decimal_type.precision, decimal_type.scale) + ) def visit_boolean(self, _: BooleanType) -> pa.DataType: return pa.bool_() @@ -1749,6 +1755,8 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}: return values.cast(target_type) raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}") + else: + pass return values def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field: @@ -2437,7 +2445,9 @@ def write_parquet(task: WriteTask) -> DataFile: ) fo = io.new_output(file_path) with fo.create(overwrite=True) as fos: - with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer: + with pq.ParquetWriter( + fos, schema=arrow_table.schema, store_decimal_as_integer=True, **parquet_writer_kwargs + ) as writer: writer.write(arrow_table, row_group_size=row_group_size) statistics = data_file_statistics_from_parquet_metadata( parquet_metadata=writer.writer.metadata, diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 8ea2e93cb1..6898143394 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -20,6 +20,7 @@ import random import time from datetime import date, datetime, timedelta +from decimal import Decimal from pathlib import Path from typing import Any, Dict from urllib.parse import urlparse @@ -50,6 +51,7 @@ from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform from pyiceberg.types import ( DateType, + DecimalType, DoubleType, IntegerType, ListType, @@ -1684,3 +1686,66 @@ def test_write_optional_list(session_catalog: Catalog) -> None: session_catalog.load_table(identifier).append(df_2) assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4 + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_evolve_and_write( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_evolve_and_write" + tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}, schema=Schema()) + other_table = session_catalog.load_table(identifier) + + numbers = pa.array([1, 2, 3, 4], type=pa.int32()) + + with tbl.update_schema() as upd: + # This is not known by other_table + upd.add_column("id", IntegerType()) + + with other_table.transaction() as tx: + # Refreshes the underlying metadata, and the schema + other_table.refresh() + tx.append( + pa.Table.from_arrays( + [ + numbers, + ], + schema=pa.schema( + [ + pa.field("id", pa.int32(), nullable=True), + ] + ), + ) + ) + + assert session_catalog.load_table(identifier).scan().to_arrow().column(0).combine_chunks() == numbers + + +@pytest.mark.integration +def test_read_write_decimals(session_catalog: Catalog) -> None: + """Roundtrip decimal types to make sure that we correctly write them as ints""" + identifier = "default.test_read_write_decimals" + + arrow_table = pa.Table.from_pydict( + { + "decimal8": pa.array([Decimal("123.45"), Decimal("678.91")], pa.decimal128(8, 2)), + "decimal16": pa.array([Decimal("12345679.123456"), Decimal("67891234.678912")], pa.decimal128(16, 6)), + "decimal19": pa.array([Decimal("1234567890123.123456"), Decimal("9876543210703.654321")], pa.decimal128(19, 6)), + }, + ) + + tbl = _create_table( + session_catalog, + identifier, + properties={"format-version": 2}, + schema=Schema( + NestedField(1, "decimal8", DecimalType(8, 2)), + NestedField(2, "decimal16", DecimalType(16, 6)), + NestedField(3, "decimal19", DecimalType(19, 6)), + ), + ) + + tbl.append(arrow_table) + + assert tbl.scan().to_arrow() == arrow_table From e5c41c7001f66a50623f43aaa0fa8bda0177ebb7 Mon Sep 17 00:00:00 2001 From: Fokko Date: Thu, 8 May 2025 11:50:58 +0200 Subject: [PATCH 2/5] Cleanup --- pyiceberg/io/pyarrow.py | 2 -- tests/integration/test_writes/test_writes.py | 34 -------------------- 2 files changed, 36 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index c7dcd64349..4675ecc7ca 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1753,8 +1753,6 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}: return values.cast(target_type) raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}") - else: - pass return values def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field: diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 8f241a4b84..1f6c52f0b7 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -1780,40 +1780,6 @@ def test_write_optional_list(session_catalog: Catalog) -> None: assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4 -@pytest.mark.integration -@pytest.mark.parametrize("format_version", [1, 2]) -def test_evolve_and_write( - spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int -) -> None: - identifier = "default.test_evolve_and_write" - tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}, schema=Schema()) - other_table = session_catalog.load_table(identifier) - - numbers = pa.array([1, 2, 3, 4], type=pa.int32()) - - with tbl.update_schema() as upd: - # This is not known by other_table - upd.add_column("id", IntegerType()) - - with other_table.transaction() as tx: - # Refreshes the underlying metadata, and the schema - other_table.refresh() - tx.append( - pa.Table.from_arrays( - [ - numbers, - ], - schema=pa.schema( - [ - pa.field("id", pa.int32(), nullable=True), - ] - ), - ) - ) - - assert session_catalog.load_table(identifier).scan().to_arrow().column(0).combine_chunks() == numbers - - @pytest.mark.integration def test_read_write_decimals(session_catalog: Catalog) -> None: """Roundtrip decimal types to make sure that we correctly write them as ints""" From 24fb3885f8d03ff80db5ff8682bd7810cbc92ac5 Mon Sep 17 00:00:00 2001 From: Fokko Date: Thu, 8 May 2025 11:52:04 +0200 Subject: [PATCH 3/5] Oops --- tests/integration/test_writes/test_writes.py | 34 ++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 1f6c52f0b7..8f241a4b84 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -1780,6 +1780,40 @@ def test_write_optional_list(session_catalog: Catalog) -> None: assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4 +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_evolve_and_write( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_evolve_and_write" + tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}, schema=Schema()) + other_table = session_catalog.load_table(identifier) + + numbers = pa.array([1, 2, 3, 4], type=pa.int32()) + + with tbl.update_schema() as upd: + # This is not known by other_table + upd.add_column("id", IntegerType()) + + with other_table.transaction() as tx: + # Refreshes the underlying metadata, and the schema + other_table.refresh() + tx.append( + pa.Table.from_arrays( + [ + numbers, + ], + schema=pa.schema( + [ + pa.field("id", pa.int32(), nullable=True), + ] + ), + ) + ) + + assert session_catalog.load_table(identifier).scan().to_arrow().column(0).combine_chunks() == numbers + + @pytest.mark.integration def test_read_write_decimals(session_catalog: Catalog) -> None: """Roundtrip decimal types to make sure that we correctly write them as ints""" From ca845bd0d9f97a88a6d1e9571e36cbdda44371f7 Mon Sep 17 00:00:00 2001 From: Fokko Date: Sat, 10 May 2025 22:35:10 +0200 Subject: [PATCH 4/5] Throw on >38 precision --- pyiceberg/io/pyarrow.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 4675ecc7ca..0f00a2149a 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -636,13 +636,14 @@ def visit_fixed(self, fixed_type: FixedType) -> pa.DataType: return pa.binary(len(fixed_type)) def visit_decimal(self, decimal_type: DecimalType) -> pa.DataType: - return ( - pa.decimal32(decimal_type.precision, decimal_type.scale) - if decimal_type.precision <= 9 - else pa.decimal64(decimal_type.precision, decimal_type.scale) - if decimal_type.precision <= 18 - else pa.decimal128(decimal_type.precision, decimal_type.scale) - ) + if decimal_type.precision <= 9: + return pa.decimal32(decimal_type.precision, decimal_type.scale) + elif decimal_type.precision <= 18: + return pa.decimal64(decimal_type.precision, decimal_type.scale) + elif decimal_type.precision <= 38: + return pa.decimal128(decimal_type.precision, decimal_type.scale) + else: + raise ValueError(f"Precision above 38 is not supported: {decimal_type.precision}") def visit_boolean(self, _: BooleanType) -> pa.DataType: return pa.bool_() From 8ededbfaa5ebcf0a45721bcb5419e7c9e5b65f7f Mon Sep 17 00:00:00 2001 From: Fokko Date: Sat, 10 May 2025 22:54:46 +0200 Subject: [PATCH 5/5] Move back to decimal128 --- pyiceberg/io/pyarrow.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 0f00a2149a..1aaab32dbe 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -636,14 +636,13 @@ def visit_fixed(self, fixed_type: FixedType) -> pa.DataType: return pa.binary(len(fixed_type)) def visit_decimal(self, decimal_type: DecimalType) -> pa.DataType: - if decimal_type.precision <= 9: - return pa.decimal32(decimal_type.precision, decimal_type.scale) - elif decimal_type.precision <= 18: - return pa.decimal64(decimal_type.precision, decimal_type.scale) - elif decimal_type.precision <= 38: - return pa.decimal128(decimal_type.precision, decimal_type.scale) - else: - raise ValueError(f"Precision above 38 is not supported: {decimal_type.precision}") + # It looks like decimal{32,64} is not fully implemented: + # https://github.com/apache/arrow/issues/25483 + # https://github.com/apache/arrow/issues/43956 + # However, if we keep it as 128 in memory, and based on the + # precision/scale Arrow will map it to INT{32,64} + # https://github.com/apache/arrow/blob/598938711a8376cbfdceaf5c77ab0fd5057e6c02/cpp/src/parquet/arrow/schema.cc#L380-L392 + return pa.decimal128(decimal_type.precision, decimal_type.scale) def visit_boolean(self, _: BooleanType) -> pa.DataType: return pa.bool_()