-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-56222][PYTHON] Create ArrowStreamGroupSerializer and ArrowStreamCoGroupSerializer #55026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
59c401f
404026f
fc0e50d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,7 +20,7 @@ | |||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| from itertools import groupby | ||||||||||||
| from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple | ||||||||||||
| from typing import IO, TYPE_CHECKING, Any, Iterator, List, Optional, Tuple | ||||||||||||
|
|
||||||||||||
| import pyspark | ||||||||||||
| from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError | ||||||||||||
|
|
@@ -125,28 +125,20 @@ def __repr__(self): | |||||||||||
|
|
||||||||||||
| class ArrowStreamSerializer(Serializer): | ||||||||||||
| """ | ||||||||||||
| Serializes Arrow record batches as a stream. | ||||||||||||
| Serializes Arrow record batches as a plain stream. | ||||||||||||
|
|
||||||||||||
| Parameters | ||||||||||||
| ---------- | ||||||||||||
| write_start_stream : bool | ||||||||||||
| If True, writes the START_ARROW_STREAM marker before the first | ||||||||||||
| output batch. Default False. | ||||||||||||
| num_dfs : int | ||||||||||||
| Number of dataframes per group. | ||||||||||||
| For num_dfs=0, plain batch stream without group-count protocol. | ||||||||||||
| For num_dfs=1, grouped loading (1 dataframe per group). | ||||||||||||
| For num_dfs=2, cogrouped loading (2 dataframes per group). | ||||||||||||
| Default 0. | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| def __init__(self, write_start_stream: bool = False, num_dfs: int = 0) -> None: | ||||||||||||
| def __init__(self, write_start_stream: bool = False) -> None: | ||||||||||||
| super().__init__() | ||||||||||||
| assert num_dfs in (0, 1, 2), "num_dfs must be 0, 1, or 2" | ||||||||||||
| self._write_start_stream: bool = write_start_stream | ||||||||||||
| self._num_dfs: int = num_dfs | ||||||||||||
|
|
||||||||||||
| def dump_stream(self, iterator, stream): | ||||||||||||
| def dump_stream(self, iterator: Iterator["pa.RecordBatch"], stream: IO[bytes]) -> None: | ||||||||||||
| """Optionally prepend START_ARROW_STREAM, then write batches.""" | ||||||||||||
| if self._write_start_stream: | ||||||||||||
| iterator = self._write_stream_start(iterator, stream) | ||||||||||||
|
|
@@ -163,54 +155,20 @@ def dump_stream(self, iterator, stream): | |||||||||||
| writer.close() | ||||||||||||
|
|
||||||||||||
| @classmethod | ||||||||||||
| def _read_arrow_stream(cls, stream) -> Iterator["pa.RecordBatch"]: | ||||||||||||
| def _read_arrow_stream(cls, stream: IO[bytes]) -> Iterator["pa.RecordBatch"]: | ||||||||||||
| """Read a plain Arrow IPC stream, yielding one RecordBatch per message.""" | ||||||||||||
| import pyarrow as pa | ||||||||||||
|
|
||||||||||||
| reader = pa.ipc.open_stream(stream) | ||||||||||||
| for batch in reader: | ||||||||||||
| yield batch | ||||||||||||
|
|
||||||||||||
| def load_stream(self, stream): | ||||||||||||
| """Load batches: plain stream if num_dfs=0, grouped otherwise.""" | ||||||||||||
| if self._num_dfs == 0: | ||||||||||||
| yield from self._read_arrow_stream(stream) | ||||||||||||
| elif self._num_dfs == 1: | ||||||||||||
| # Grouped loading: yield single dataframe groups | ||||||||||||
| for (batches,) in self._load_group_dataframes(stream, num_dfs=1): | ||||||||||||
| yield batches | ||||||||||||
| elif self._num_dfs == 2: | ||||||||||||
| # Cogrouped loading: yield tuples of (left_batches, right_batches) | ||||||||||||
| for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): | ||||||||||||
| yield left_batches, right_batches | ||||||||||||
| else: | ||||||||||||
| assert False, f"Unexpected num_dfs: {self._num_dfs}" | ||||||||||||
|
|
||||||||||||
| def _load_group_dataframes(self, stream, num_dfs: int = 1) -> Iterator: | ||||||||||||
| """ | ||||||||||||
| Yield groups of dataframes from the stream using the group-count protocol. | ||||||||||||
| """ | ||||||||||||
| dataframes_in_group = None | ||||||||||||
|
|
||||||||||||
| while dataframes_in_group is None or dataframes_in_group > 0: | ||||||||||||
| dataframes_in_group = read_int(stream) | ||||||||||||
|
|
||||||||||||
| if dataframes_in_group == num_dfs: | ||||||||||||
| if num_dfs == 1: | ||||||||||||
| # Single dataframe: can use lazy iterator | ||||||||||||
| yield (self._read_arrow_stream(stream),) | ||||||||||||
| else: | ||||||||||||
| # Multiple dataframes: must eagerly load sequentially | ||||||||||||
| # to maintain correct stream position | ||||||||||||
| yield tuple(list(self._read_arrow_stream(stream)) for _ in range(num_dfs)) | ||||||||||||
| elif dataframes_in_group > 0: | ||||||||||||
| raise PySparkValueError( | ||||||||||||
| errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP", | ||||||||||||
| messageParameters={"dataframes_in_group": str(dataframes_in_group)}, | ||||||||||||
| ) | ||||||||||||
| def load_stream(self, stream: IO[bytes]) -> Iterator["pa.RecordBatch"]: | ||||||||||||
| """Load batches from a plain Arrow stream.""" | ||||||||||||
| yield from self._read_arrow_stream(stream) | ||||||||||||
|
|
||||||||||||
| def _write_stream_start( | ||||||||||||
| self, batch_iterator: Iterator["pa.RecordBatch"], stream | ||||||||||||
| self, batch_iterator: Iterator["pa.RecordBatch"], stream: IO[bytes] | ||||||||||||
| ) -> Iterator["pa.RecordBatch"]: | ||||||||||||
| """Write START_ARROW_STREAM before the first batch, then pass batches through.""" | ||||||||||||
| import itertools | ||||||||||||
|
|
@@ -225,10 +183,57 @@ def _write_stream_start( | |||||||||||
| yield from itertools.chain([first], batch_iterator) | ||||||||||||
|
|
||||||||||||
| def __repr__(self) -> str: | ||||||||||||
| return "ArrowStreamSerializer(write_start_stream=%s, num_dfs=%d)" % ( | ||||||||||||
| self._write_start_stream, | ||||||||||||
| self._num_dfs, | ||||||||||||
| ) | ||||||||||||
| return "ArrowStreamSerializer(write_start_stream=%s)" % self._write_start_stream | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class ArrowStreamGroupSerializer(ArrowStreamSerializer): | ||||||||||||
| """ | ||||||||||||
| Extends :class:`ArrowStreamSerializer` with group-count protocol for loading | ||||||||||||
| grouped Arrow record batches (1 dataframe per group). | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| def load_stream(self, stream: IO[bytes]) -> Iterator[Iterator["pa.RecordBatch"]]: | ||||||||||||
| """Yield one iterator of record batches per group from the stream.""" | ||||||||||||
| dataframes_in_group: Optional[int] = None | ||||||||||||
|
|
||||||||||||
| while dataframes_in_group is None or dataframes_in_group > 0: | ||||||||||||
| dataframes_in_group = read_int(stream) | ||||||||||||
|
Comment on lines
+197
to
+200
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
|
||||||||||||
| if dataframes_in_group == 1: | ||||||||||||
| yield self._read_arrow_stream(stream) | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that the base class has a clear |
||||||||||||
| elif dataframes_in_group > 0: | ||||||||||||
| raise PySparkValueError( | ||||||||||||
| errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP", | ||||||||||||
| messageParameters={"dataframes_in_group": str(dataframes_in_group)}, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class ArrowStreamCoGroupSerializer(ArrowStreamSerializer): | ||||||||||||
| """ | ||||||||||||
| Extends :class:`ArrowStreamSerializer` with group-count protocol for loading | ||||||||||||
| cogrouped Arrow record batches (2 dataframes per group). | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| def load_stream( | ||||||||||||
| self, stream: IO[bytes] | ||||||||||||
| ) -> Iterator[Tuple[List["pa.RecordBatch"], List["pa.RecordBatch"]]]: | ||||||||||||
| """Yield pairs of (left_batches, right_batches) from the stream.""" | ||||||||||||
| dataframes_in_group: Optional[int] = None | ||||||||||||
|
|
||||||||||||
| while dataframes_in_group is None or dataframes_in_group > 0: | ||||||||||||
| dataframes_in_group = read_int(stream) | ||||||||||||
|
|
||||||||||||
| if dataframes_in_group == 2: | ||||||||||||
| # Must eagerly load each dataframe to maintain correct stream position | ||||||||||||
| yield ( | ||||||||||||
| list(self._read_arrow_stream(stream)), | ||||||||||||
| list(self._read_arrow_stream(stream)), | ||||||||||||
| ) | ||||||||||||
| elif dataframes_in_group > 0: | ||||||||||||
| raise PySparkValueError( | ||||||||||||
| errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP", | ||||||||||||
| messageParameters={"dataframes_in_group": str(dataframes_in_group)}, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class ArrowStreamUDFSerializer(ArrowStreamSerializer): | ||||||||||||
|
|
@@ -370,7 +375,7 @@ def load_stream(self, stream): | |||||||||||
| """ | ||||||||||||
| Load grouped Arrow record batches from stream. | ||||||||||||
| """ | ||||||||||||
| for (batches,) in self._load_group_dataframes(stream, num_dfs=1): | ||||||||||||
| for batches in ArrowStreamGroupSerializer.load_stream(self, stream): | ||||||||||||
| yield batches | ||||||||||||
| # Make sure the batches are fully iterated before getting the next group | ||||||||||||
| for _ in batches: | ||||||||||||
|
|
@@ -807,7 +812,7 @@ def load_stream(self, stream): | |||||||||||
| Each group yields Iterator[Tuple[pa.Array, ...]], allowing UDF to process batches one by one | ||||||||||||
| without consuming all batches upfront. | ||||||||||||
| """ | ||||||||||||
| for (batches,) in self._load_group_dataframes(stream, num_dfs=1): | ||||||||||||
| for batches in ArrowStreamGroupSerializer.load_stream(self, stream): | ||||||||||||
| # Lazily read and convert Arrow batches one at a time from the stream | ||||||||||||
| # This avoids loading all batches into memory for the group | ||||||||||||
| columns_iter = (batch.columns for batch in batches) | ||||||||||||
|
|
@@ -851,7 +856,7 @@ def load_stream(self, stream): | |||||||||||
| Each group yields Iterator[Tuple[pd.Series, ...]], allowing UDF to | ||||||||||||
| process batches one by one without consuming all batches upfront. | ||||||||||||
| """ | ||||||||||||
| for (batches,) in self._load_group_dataframes(stream, num_dfs=1): | ||||||||||||
| for batches in ArrowStreamGroupSerializer.load_stream(self, stream): | ||||||||||||
| # Lazily read and convert Arrow batches to pandas Series one at a time | ||||||||||||
| # from the stream. This avoids loading all batches into memory for the group | ||||||||||||
| series_iter = map( | ||||||||||||
|
|
@@ -906,7 +911,7 @@ def load_stream(self, stream): | |||||||||||
| Deserialize Grouped ArrowRecordBatches and yield raw Iterator[pa.RecordBatch]. | ||||||||||||
| Each outer iterator element represents a group. | ||||||||||||
| """ | ||||||||||||
| for (batches,) in self._load_group_dataframes(stream, num_dfs=1): | ||||||||||||
| for batches in ArrowStreamGroupSerializer.load_stream(self, stream): | ||||||||||||
| yield batches | ||||||||||||
| # Make sure the batches are fully iterated before getting the next group | ||||||||||||
| for _ in batches: | ||||||||||||
|
|
@@ -941,8 +946,7 @@ def load_stream(self, stream): | |||||||||||
| """ | ||||||||||||
| Deserialize Cogrouped ArrowRecordBatches and yield as two `pyarrow.RecordBatch`es. | ||||||||||||
| """ | ||||||||||||
| for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): | ||||||||||||
| yield left_batches, right_batches | ||||||||||||
| yield from ArrowStreamCoGroupSerializer.load_stream(self, stream) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer): | ||||||||||||
|
|
@@ -953,7 +957,7 @@ def load_stream(self, stream): | |||||||||||
| """ | ||||||||||||
| import pyarrow as pa | ||||||||||||
|
|
||||||||||||
| for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): | ||||||||||||
| for left_batches, right_batches in ArrowStreamCoGroupSerializer.load_stream(self, stream): | ||||||||||||
| left_table = pa.Table.from_batches(left_batches) | ||||||||||||
| right_table = pa.Table.from_batches(right_batches) | ||||||||||||
| yield ( | ||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Normally we want to make input less restrictive and output more restrictive. Do we have to restrict the input
`iterator` to `Iterator` rather than `Iterable`? I saw we had to do `iter` on lists to make tests work. What if we take an `Iterable` as an input and do an `iter()` inside the function to get the actual iterator? If an `Iterator` is passed in, `iter()` is basically an identity function that returns the iterator itself — we lose nothing.