diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 522af0f344..1aaab32dbe 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -636,6 +636,12 @@ def visit_fixed(self, fixed_type: FixedType) -> pa.DataType: return pa.binary(len(fixed_type)) def visit_decimal(self, decimal_type: DecimalType) -> pa.DataType: + # It looks like decimal{32,64} is not fully implemented: + # https://github.com/apache/arrow/issues/25483 + # https://github.com/apache/arrow/issues/43956 + # However, if we keep it as 128 in memory, then based on the + # precision/scale Arrow will map it to INT{32,64} + # https://github.com/apache/arrow/blob/598938711a8376cbfdceaf5c77ab0fd5057e6c02/cpp/src/parquet/arrow/schema.cc#L380-L392 return pa.decimal128(decimal_type.precision, decimal_type.scale) def visit_boolean(self, _: BooleanType) -> pa.DataType: @@ -2442,7 +2448,9 @@ def write_parquet(task: WriteTask) -> DataFile: ) fo = io.new_output(file_path) with fo.create(overwrite=True) as fos: - with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer: + with pq.ParquetWriter( + fos, schema=arrow_table.schema, store_decimal_as_integer=True, **parquet_writer_kwargs + ) as writer: writer.write(arrow_table, row_group_size=row_group_size) statistics = data_file_statistics_from_parquet_metadata( parquet_metadata=writer.writer.metadata, diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 46d54f0491..8f241a4b84 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -20,6 +20,7 @@ import random import time from datetime import date, datetime, timedelta +from decimal import Decimal from pathlib import Path from typing import Any, Dict from urllib.parse import urlparse @@ -50,6 +51,7 @@ from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform from pyiceberg.types import ( DateType, + DecimalType, DoubleType, 
IntegerType, ListType, @@ -1810,3 +1812,32 @@ def test_evolve_and_write( ) assert session_catalog.load_table(identifier).scan().to_arrow().column(0).combine_chunks() == numbers + + +@pytest.mark.integration +def test_read_write_decimals(session_catalog: Catalog) -> None: + """Roundtrip decimal types to make sure that we correctly write them as ints""" + identifier = "default.test_read_write_decimals" + + arrow_table = pa.Table.from_pydict( + { + "decimal8": pa.array([Decimal("123.45"), Decimal("678.91")], pa.decimal128(8, 2)), + "decimal16": pa.array([Decimal("12345679.123456"), Decimal("67891234.678912")], pa.decimal128(16, 6)), + "decimal19": pa.array([Decimal("1234567890123.123456"), Decimal("9876543210703.654321")], pa.decimal128(19, 6)), + }, + ) + + tbl = _create_table( + session_catalog, + identifier, + properties={"format-version": 2}, + schema=Schema( + NestedField(1, "decimal8", DecimalType(8, 2)), + NestedField(2, "decimal16", DecimalType(16, 6)), + NestedField(3, "decimal19", DecimalType(19, 6)), + ), + ) + + tbl.append(arrow_table) + + assert tbl.scan().to_arrow() == arrow_table