4 changes: 4 additions & 0 deletions docs/content/program-api/python-api.md
@@ -195,6 +195,8 @@ record_batch = ...
table_write.write_arrow_batch(record_batch)

# 2.4 Write Ray Dataset (requires ray to be installed)

# Note: Ray Data converts large_binary() to binary() when reading. write_ray()
# automatically converts binary() back to large_binary() to match the table schema.
import ray
ray_dataset = ray.data.read_json("/path/to/data.jsonl")
table_write.write_ray(ray_dataset, overwrite=False, concurrency=2)
@@ -471,6 +473,8 @@ df = ray_dataset.to_pandas()
- `**read_args`: Additional kwargs passed to the datasource (e.g., `per_task_row_limit`
in Ray 2.52.0+).

**Note:** Ray Data converts `large_binary()` to `binary()` when reading. When writing back via `write_ray()`, the conversion is handled automatically.
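
For illustration, a minimal roundtrip sketch (assuming `read_builder`, `table_read`, and `table_write` were created as in the earlier examples, and that the table declares a `large_binary()` column):

```python
# Read the table into a Ray Dataset. Ray Data reports large_binary() columns
# as binary() in the resulting dataset.
splits = read_builder.new_scan().plan().splits()
ray_dataset = table_read.to_ray(splits)

# Write the dataset back. write_ray() casts binary() back to large_binary()
# to match the table schema, so no manual conversion is needed.
table_write.write_ray(ray_dataset, overwrite=False, concurrency=1)
```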

**Ray Block Size Configuration:**

If you need to configure Ray's block size (e.g., when Paimon splits exceed Ray's default
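
For illustration (not part of this diff), a sketch of one way to raise Ray Data's block-size cap, assuming Ray's `DataContext` API; tune the value to your split sizes:

```python
import ray

# Raise the target block size so large Paimon splits are not broken into an
# excessive number of small Ray blocks.
ctx = ray.data.DataContext.get_current()
ctx.target_max_block_size = 512 * 1024 * 1024  # 512 MiB
```
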
4 changes: 4 additions & 0 deletions paimon-python/pypaimon/read/datasource/ray_datasource.py
@@ -46,6 +46,10 @@ class RayDatasource(Datasource):

This datasource enables distributed parallel reading of Paimon table splits,
allowing Ray to read multiple splits concurrently across the cluster.

.. note::
Ray Data converts ``large_binary()`` to ``binary()`` when reading.
When writing back via :meth:`TableWrite.write_ray`, the conversion is handled automatically.
"""

def __init__(self, table_read: TableRead, splits: List[Split]):
5 changes: 5 additions & 0 deletions paimon-python/pypaimon/read/table_read.py
@@ -194,6 +194,11 @@ def to_ray(
**read_args,
) -> "ray.data.dataset.Dataset":
"""Convert Paimon table data to Ray Dataset.

.. note::
Ray Data converts ``large_binary()`` to ``binary()`` when reading.
When writing back via :meth:`write_ray`, the conversion is handled automatically.

Args:
splits: List of splits to read from the Paimon table.
ray_remote_args: Optional kwargs passed to :func:`ray.remote` in read tasks.
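
A hedged usage sketch of the documented `to_ray` arguments (assuming a `read_builder` set up as usual; values are illustrative, and `per_task_row_limit` is forwarded through `**read_args` and requires Ray 2.52.0+):

```python
splits = read_builder.new_scan().plan().splits()
ray_dataset = table_read.to_ray(
    splits,
    ray_remote_args={"num_cpus": 1},  # passed to ray.remote for read tasks
    per_task_row_limit=10_000,        # extra datasource kwarg (Ray 2.52.0+)
)
```
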
110 changes: 99 additions & 11 deletions paimon-python/pypaimon/tests/ray_data_test.py
@@ -22,10 +22,12 @@
import shutil

import pyarrow as pa
import pyarrow.types as pa_types
import ray

from pypaimon import CatalogFactory, Schema
from pypaimon.common.options.core_options import CoreOptions
from pypaimon.schema.data_types import PyarrowFieldParser


class RayDataTest(unittest.TestCase):
@@ -115,22 +117,108 @@ def test_basic_ray_data_read(self):
self.assertIsNotNone(ray_dataset, "Ray dataset should not be None")
self.assertEqual(ray_dataset.count(), 5, "Should have 5 rows")

# Test basic operations
sample_data = ray_dataset.take(3)
self.assertEqual(len(sample_data), 3, "Should have 3 sample rows")
def test_ray_data_read_and_write_with_blob(self):
import time
pa_schema = pa.schema([
('id', pa.int64()),
('name', pa.string()),
('data', pa.large_binary()), # Table uses large_binary for blob
])

# Convert to pandas for verification
df = ray_dataset.to_pandas()
self.assertEqual(len(df), 5, "DataFrame should have 5 rows")
# Sort by id to ensure order-independent comparison
df_sorted = df.sort_values(by='id').reset_index(drop=True)
self.assertEqual(list(df_sorted['id']), [1, 2, 3, 4, 5], "ID column should match")
schema = Schema.from_pyarrow_schema(
pa_schema,
options={
'row-tracking.enabled': 'true',
'data-evolution.enabled': 'true',
'blob-field': 'data',
}
)

table_name = f'default.test_ray_read_write_blob_{int(time.time() * 1000000)}'
self.catalog.create_table(table_name, schema, False)
table = self.catalog.get_table(table_name)

# Step 1: Write data to Paimon table using write_arrow (large_binary type)
initial_data = pa.Table.from_pydict({
'id': [1, 2, 3],
'name': ['Alice', 'Bob', 'Charlie'],
'data': [b'blob_data_1', b'blob_data_2', b'blob_data_3'],
}, schema=pa_schema)

write_builder = table.new_batch_write_builder()
writer = write_builder.new_write()
writer.write_arrow(initial_data)
commit_messages = writer.prepare_commit()
commit = write_builder.new_commit()
commit.commit(commit_messages)
writer.close()

# Step 2: Read from Paimon table using to_ray()
read_builder = table.new_read_builder()
table_read = read_builder.new_read()
table_scan = read_builder.new_scan()
splits = table_scan.plan().splits()

ray_dataset = table_read.to_ray(splits)

df_check = ray_dataset.to_pandas()
ray_table_check = pa.Table.from_pandas(df_check)
ray_schema_check = ray_table_check.schema

ray_data_field = ray_schema_check.field('data')
self.assertTrue(
pa_types.is_binary(ray_data_field.type),
f"Ray Dataset from Paimon should have binary() type (Ray Data converts large_binary to binary), but got {ray_data_field.type}"
)
self.assertFalse(
pa_types.is_large_binary(ray_data_field.type),
f"Ray Dataset from Paimon should NOT have large_binary() type, but got {ray_data_field.type}"
)

table_schema = table.table_schema
table_pa_schema = PyarrowFieldParser.from_paimon_schema(table_schema.fields)
table_data_field = table_pa_schema.field('data')

self.assertTrue(
pa_types.is_large_binary(table_data_field.type),
f"Paimon table should have large_binary() type for data, but got {table_data_field.type}"
)

# Step 3: Write Ray Dataset back to Paimon table using write_ray()
write_builder2 = table.new_batch_write_builder()
writer2 = write_builder2.new_write()

writer2.write_ray(
ray_dataset,
overwrite=False,
concurrency=1
)
writer2.close()

# Step 4: Verify the data was written correctly
read_builder2 = table.new_read_builder()
table_read2 = read_builder2.new_read()
result = table_read2.to_arrow(read_builder2.new_scan().plan().splits())

self.assertEqual(result.num_rows, 6, "Table should have 6 rows after roundtrip")

result_df = result.to_pandas()
result_df_sorted = result_df.sort_values(by='id').reset_index(drop=True)

self.assertEqual(list(result_df_sorted['id']), [1, 1, 2, 2, 3, 3], "ID column should match")
self.assertEqual(
list(df_sorted['name']),
['Alice', 'Bob', 'Charlie', 'David', 'Eve'],
list(result_df_sorted['name']),
['Alice', 'Alice', 'Bob', 'Bob', 'Charlie', 'Charlie'],
"Name column should match"
)

written_data_values = [bytes(d) if d is not None else None for d in result_df_sorted['data']]
self.assertEqual(
written_data_values,
[b'blob_data_1', b'blob_data_1', b'blob_data_2', b'blob_data_2', b'blob_data_3', b'blob_data_3'],
"Blob data column should match"
)

def test_basic_ray_data_write(self):
"""Test basic Ray Data write from PyPaimon table."""
# Create schema
103 changes: 97 additions & 6 deletions paimon-python/pypaimon/write/table_write.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import pyarrow as pa
import pyarrow.types as pa_types

from pypaimon.schema.data_types import PyarrowFieldParser
from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER
@@ -29,6 +30,14 @@
from ray.data import Dataset


def _is_binary_type_compatible(input_type: pa.DataType, table_type: pa.DataType) -> bool:
if pa_types.is_binary(input_type) and pa_types.is_large_binary(table_type):
return True
if pa_types.is_large_binary(input_type) and pa_types.is_binary(table_type):
return True
return False


class TableWrite:
def __init__(self, table, commit_user):
from pypaimon.table.file_store_table import FileStoreTable
@@ -44,8 +53,46 @@ def write_arrow(self, table: pa.Table):
for batch in batches_iterator:
self.write_arrow_batch(batch)

def _convert_binary_types(self, data: pa.RecordBatch) -> pa.RecordBatch:
write_cols = self.file_store_write.write_cols
table_schema = self.table_pyarrow_schema

converted_arrays = []
needs_conversion = False

for i, field in enumerate(data.schema):
array = data.column(i)
expected_type = None

if write_cols is None or field.name in write_cols:
try:
expected_type = table_schema.field(field.name).type
except KeyError:
pass

if expected_type and field.type != expected_type and _is_binary_type_compatible(field.type, expected_type):
try:
array = pa.compute.cast(array, expected_type)
needs_conversion = True
except (pa.ArrowInvalid, pa.ArrowCapacityError, ValueError) as e:
direction = f"{field.type} to {expected_type}"
raise ValueError(
f"Failed to convert field '{field.name}' from {direction}. "
f"If converting to binary(), ensure no value exceeds 2GB limit: {e}"
) from e

converted_arrays.append(array)

if needs_conversion:
new_fields = [pa.field(field.name, arr.type, nullable=field.nullable)
for field, arr in zip(data.schema, converted_arrays)]
return pa.RecordBatch.from_arrays(converted_arrays, schema=pa.schema(new_fields))

return data

def write_arrow_batch(self, data: pa.RecordBatch):
self._validate_pyarrow_schema(data.schema)
data = self._convert_binary_types(data)
partitions, buckets = self.row_key_extractor.extract_partition_bucket_batch(data)

partition_bucket_groups = defaultdict(list)
@@ -59,7 +106,7 @@ def write_arrow_batch(self, data: pa.RecordBatch):

def write_pandas(self, dataframe):
pa_schema = PyarrowFieldParser.from_paimon_schema(self.table.table_schema.fields)
record_batch = pa.RecordBatch.from_pandas(dataframe, schema=pa_schema)
record_batch = pa.RecordBatch.from_pandas(dataframe, schema=pa_schema, preserve_index=False)
return self.write_arrow_batch(record_batch)

def with_write_type(self, write_cols: List[str]):
@@ -81,6 +128,11 @@ def write_ray(
"""
Write a Ray Dataset to Paimon table.

.. note::
Ray Data converts ``large_binary()`` to ``binary()`` when reading.
This method automatically converts ``binary()`` back to ``large_binary()``
to match the table schema.

Args:
dataset: Ray Dataset to write. This is a distributed data collection
from Ray Data (ray.data.Dataset).
@@ -102,11 +154,50 @@ def close(self):
self.file_store_write.close()

def _validate_pyarrow_schema(self, data_schema: pa.Schema):
if data_schema != self.table_pyarrow_schema and data_schema.names != self.file_store_write.write_cols:
raise ValueError(f"Input schema isn't consistent with table schema and write cols. "
f"Input schema is: {data_schema} "
f"Table schema is: {self.table_pyarrow_schema} "
f"Write cols is: {self.file_store_write.write_cols}")
write_cols = self.file_store_write.write_cols

if write_cols is None:
if data_schema.names != self.table_pyarrow_schema.names:
raise ValueError(
f"Input schema isn't consistent with table schema and write cols. "
f"Input schema is: {data_schema} "
f"Table schema is: {self.table_pyarrow_schema} "
f"Write cols is: {self.file_store_write.write_cols}"
)
for input_field, table_field in zip(data_schema, self.table_pyarrow_schema):
if input_field.type != table_field.type:
if not _is_binary_type_compatible(input_field.type, table_field.type):
raise ValueError(
f"Input schema doesn't match table schema. "
f"Field '{input_field.name}' type mismatch.\n"
f"Input type: {input_field.type}\n"
f"Table type: {table_field.type}\n"
f"Input schema: {data_schema}\n"
f"Table schema: {self.table_pyarrow_schema}"
)
else:
if list(data_schema.names) != write_cols:
raise ValueError(
f"Input schema field names don't match write_cols. "
f"Field names and order must match write_cols.\n"
f"Input schema names: {list(data_schema.names)}\n"
f"Write cols: {write_cols}"
)
table_field_map = {field.name: field for field in self.table_pyarrow_schema}
for field_name in write_cols:
if field_name not in table_field_map:
raise ValueError(
f"Field '{field_name}' in write_cols is not in table schema."
)
input_field = data_schema.field(field_name)
table_field = table_field_map[field_name]
if input_field.type != table_field.type:
if not _is_binary_type_compatible(input_field.type, table_field.type):
raise ValueError(
f"Field '{field_name}' type mismatch.\n"
f"Input type: {input_field.type}\n"
f"Table type: {table_field.type}"
)


class BatchTableWrite(TableWrite):