Skip to content

Commit 7af45f9

Browse files
committed
feat: support pyarrow float16 by widening to float on read/write
PyArrow's float16 (halffloat) raised UnsupportedPyArrowTypeException during schema conversion because _ConvertToIceberg.primitive only handled float32/float64. Iceberg has no half-precision float type, but float16 -> float32 is lossless, mirroring how int8/int16 already widen to IntegerType. This commit maps float16 to FloatType and, in ArrowProjectionVisitor._cast_if_needed, widens narrower float arrays to the target type (parallel to the existing integer-widening branch), so float16 columns are written as float32, or as float64 when the target field is a DoubleType.
1 parent 9d36e23 commit 7af45f9

3 files changed

Lines changed: 47 additions & 0 deletions

File tree

pyiceberg/io/pyarrow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,9 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
14381438
else:
14391439
# Does not exist (yet)
14401440
raise TypeError(f"Unsupported integer type: {primitive}")
1441+
elif pa.types.is_float16(primitive):
1442+
# Iceberg has no half-precision float; widen to single precision (lossless)
1443+
return FloatType()
14411444
elif pa.types.is_float32(primitive):
14421445
return FloatType()
14431446
elif pa.types.is_float64(primitive):
@@ -1978,6 +1981,15 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
19781981
target_width = target_type.bit_width
19791982
if source_width < target_width:
19801983
return values.cast(target_type)
1984+
elif isinstance(field.field_type, (FloatType, DoubleType)):
1985+
# Cast smaller float types to target type for cross-platform compatibility
1986+
# Only allow widening conversions (smaller bit width to larger), e.g. float16 -> float32
1987+
# Narrowing conversions fall through to promote() handling below
1988+
if pa.types.is_floating(values.type):
1989+
source_width = values.type.bit_width
1990+
target_width = target_type.bit_width
1991+
if source_width < target_width:
1992+
return values.cast(target_type)
19811993

19821994
if field.field_type != file_field.field_type:
19831995
target_schema = schema_to_pyarrow(

tests/io/test_pyarrow.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3109,6 +3109,35 @@ def test__to_requested_schema_integer_promotion(
31093109
assert result.column(0).to_pylist() == [1, 2, 3, None]
31103110

31113111

3112+
@pytest.mark.parametrize(
3113+
"arrow_type,iceberg_type,expected_arrow_type",
3114+
[
3115+
(pa.float16(), FloatType(), pa.float32()),
3116+
(pa.float16(), DoubleType(), pa.float64()),
3117+
(pa.float32(), DoubleType(), pa.float64()),
3118+
],
3119+
)
3120+
def test__to_requested_schema_float_promotion(
3121+
arrow_type: pa.DataType,
3122+
iceberg_type: PrimitiveType,
3123+
expected_arrow_type: pa.DataType,
3124+
) -> None:
3125+
"""Test that smaller float types are cast to target Iceberg type during write."""
3126+
requested_schema = Schema(NestedField(1, "col", iceberg_type, required=False))
3127+
file_schema = requested_schema
3128+
3129+
arrow_schema = pa.schema([pa.field("col", arrow_type)])
3130+
data = pa.array([1.5, 2.25, 3.0, None], type=arrow_type)
3131+
batch = pa.RecordBatch.from_arrays([data], schema=arrow_schema)
3132+
3133+
result = _to_requested_schema(
3134+
requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False
3135+
)
3136+
3137+
assert result.schema[0].type == expected_arrow_type
3138+
assert result.column(0).to_pylist() == [1.5, 2.25, 3.0, None]
3139+
3140+
31123141
def test_pyarrow_file_io_fs_by_scheme_cache() -> None:
31133142
# It's better to set up multi-region minio servers for an integration test once `endpoint_url` argument
31143143
# becomes available for `resolve_s3_region`

tests/io/test_pyarrow_visitor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,12 @@ def test_pyarrow_int64_to_iceberg() -> None:
128128
assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
129129

130130

131+
def test_pyarrow_float16_to_iceberg() -> None:
132+
pyarrow_type = pa.float16()
133+
converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
134+
assert converted_iceberg_type == FloatType()
135+
136+
131137
def test_pyarrow_float32_to_iceberg() -> None:
132138
pyarrow_type = pa.float32()
133139
converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())

0 commit comments

Comments
 (0)