diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 59cb5d5451c9..da03e3c314ec 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -20,7 +20,7 @@ """ from itertools import groupby -from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple +from typing import IO, TYPE_CHECKING, Any, Iterator, List, Optional, Tuple import pyspark from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError @@ -125,28 +125,20 @@ def __repr__(self): class ArrowStreamSerializer(Serializer): """ - Serializes Arrow record batches as a stream. + Serializes Arrow record batches as a plain stream. Parameters ---------- write_start_stream : bool If True, writes the START_ARROW_STREAM marker before the first output batch. Default False. - num_dfs : int - Number of dataframes per group. - For num_dfs=0, plain batch stream without group-count protocol. - For num_dfs=1, grouped loading (1 dataframe per group). - For num_dfs=2, cogrouped loading (2 dataframes per group). - Default 0. """ - def __init__(self, write_start_stream: bool = False, num_dfs: int = 0) -> None: + def __init__(self, write_start_stream: bool = False) -> None: super().__init__() - assert num_dfs in (0, 1, 2), "num_dfs must be 0, 1, or 2" self._write_start_stream: bool = write_start_stream - self._num_dfs: int = num_dfs - def dump_stream(self, iterator, stream): + def dump_stream(self, iterator: Iterator["pa.RecordBatch"], stream: IO[bytes]) -> None: """Optionally prepend START_ARROW_STREAM, then write batches.""" if self._write_start_stream: iterator = self._write_stream_start(iterator, stream) @@ -163,7 +155,7 @@ def dump_stream(self, iterator, stream): writer.close() @classmethod - def _read_arrow_stream(cls, stream) -> Iterator["pa.RecordBatch"]: + def _read_arrow_stream(cls, stream: IO[bytes]) -> Iterator["pa.RecordBatch"]: """Read a plain Arrow IPC stream, yielding one RecordBatch per message.""" import pyarrow as pa @@ -171,46 +163,12 @@ def _read_arrow_stream(cls, stream) -> Iterator["pa.RecordBatch"]: for batch in reader: yield batch - def load_stream(self, stream): - """Load batches: plain stream if num_dfs=0, grouped otherwise.""" - if self._num_dfs == 0: - yield from self._read_arrow_stream(stream) - elif self._num_dfs == 1: - # Grouped loading: yield single dataframe groups - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): - yield batches - elif self._num_dfs == 2: - # Cogrouped loading: yield tuples of (left_batches, right_batches) - for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): - yield left_batches, right_batches - else: - assert False, f"Unexpected num_dfs: {self._num_dfs}" - - def _load_group_dataframes(self, stream, num_dfs: int = 1) -> Iterator: - """ - Yield groups of dataframes from the stream using the group-count protocol. - """ - dataframes_in_group = None - - while dataframes_in_group is None or dataframes_in_group > 0: - dataframes_in_group = read_int(stream) - - if dataframes_in_group == num_dfs: - if num_dfs == 1: - # Single dataframe: can use lazy iterator - yield (self._read_arrow_stream(stream),) - else: - # Multiple dataframes: must eagerly load sequentially - # to maintain correct stream position - yield tuple(list(self._read_arrow_stream(stream)) for _ in range(num_dfs)) - elif dataframes_in_group > 0: - raise PySparkValueError( - errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP", - messageParameters={"dataframes_in_group": str(dataframes_in_group)}, - ) + def load_stream(self, stream: IO[bytes]) -> Iterator["pa.RecordBatch"]: + """Load batches from a plain Arrow stream.""" + yield from self._read_arrow_stream(stream) def _write_stream_start( - self, batch_iterator: Iterator["pa.RecordBatch"], stream + self, batch_iterator: Iterator["pa.RecordBatch"], stream: IO[bytes] ) -> Iterator["pa.RecordBatch"]: """Write START_ARROW_STREAM before the first batch, then pass batches through.""" import itertools @@ -225,10 +183,57 @@ def _write_stream_start( yield from itertools.chain([first], batch_iterator) def __repr__(self) -> str: - return "ArrowStreamSerializer(write_start_stream=%s, num_dfs=%d)" % ( - self._write_start_stream, - self._num_dfs, - ) + return "ArrowStreamSerializer(write_start_stream=%s)" % self._write_start_stream + + +class ArrowStreamGroupSerializer(ArrowStreamSerializer): + """ + Extends :class:`ArrowStreamSerializer` with group-count protocol for loading + grouped Arrow record batches (1 dataframe per group). + """ + + def load_stream(self, stream: IO[bytes]) -> Iterator[Iterator["pa.RecordBatch"]]: + """Yield one iterator of record batches per group from the stream.""" + dataframes_in_group: Optional[int] = None + + while dataframes_in_group is None or dataframes_in_group > 0: + dataframes_in_group = read_int(stream) + + if dataframes_in_group == 1: + yield self._read_arrow_stream(stream) + elif dataframes_in_group > 0: + raise PySparkValueError( + errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP", + messageParameters={"dataframes_in_group": str(dataframes_in_group)}, + ) + + +class ArrowStreamCoGroupSerializer(ArrowStreamSerializer): + """ + Extends :class:`ArrowStreamSerializer` with group-count protocol for loading + cogrouped Arrow record batches (2 dataframes per group). + """ + + def load_stream( + self, stream: IO[bytes] + ) -> Iterator[Tuple[List["pa.RecordBatch"], List["pa.RecordBatch"]]]: + """Yield pairs of (left_batches, right_batches) from the stream.""" + dataframes_in_group: Optional[int] = None + + while dataframes_in_group is None or dataframes_in_group > 0: + dataframes_in_group = read_int(stream) + + if dataframes_in_group == 2: + # Must eagerly load each dataframe to maintain correct stream position + yield ( + list(self._read_arrow_stream(stream)), + list(self._read_arrow_stream(stream)), + ) + elif dataframes_in_group > 0: + raise PySparkValueError( + errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP", + messageParameters={"dataframes_in_group": str(dataframes_in_group)}, + ) class ArrowStreamUDFSerializer(ArrowStreamSerializer): @@ -370,7 +375,7 @@ def load_stream(self, stream): """ Load grouped Arrow record batches from stream. """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): + for batches in ArrowStreamGroupSerializer.load_stream(self, stream): yield batches # Make sure the batches are fully iterated before getting the next group for _ in batches: @@ -807,7 +812,7 @@ def load_stream(self, stream): Each group yields Iterator[Tuple[pa.Array, ...]], allowing UDF to process batches one by one without consuming all batches upfront. """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): + for batches in ArrowStreamGroupSerializer.load_stream(self, stream): # Lazily read and convert Arrow batches one at a time from the stream # This avoids loading all batches into memory for the group columns_iter = (batch.columns for batch in batches) @@ -851,7 +856,7 @@ def load_stream(self, stream): Each group yields Iterator[Tuple[pd.Series, ...]], allowing UDF to process batches one by one without consuming all batches upfront. """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): + for batches in ArrowStreamGroupSerializer.load_stream(self, stream): # Lazily read and convert Arrow batches to pandas Series one at a time # from the stream. This avoids loading all batches into memory for the group series_iter = map( @@ -906,7 +911,7 @@ def load_stream(self, stream): Deserialize Grouped ArrowRecordBatches and yield raw Iterator[pa.RecordBatch]. Each outer iterator element represents a group. """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): + for batches in ArrowStreamGroupSerializer.load_stream(self, stream): yield batches # Make sure the batches are fully iterated before getting the next group for _ in batches: @@ -941,8 +946,7 @@ def load_stream(self, stream): """ Deserialize Cogrouped ArrowRecordBatches and yield as two `pyarrow.RecordBatch`es. """ - for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): - yield left_batches, right_batches + yield from ArrowStreamCoGroupSerializer.load_stream(self, stream) class CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer): @@ -953,7 +957,7 @@ def load_stream(self, stream): """ import pyarrow as pa - for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): + for left_batches, right_batches in ArrowStreamCoGroupSerializer.load_stream(self, stream): left_table = pa.Table.from_batches(left_batches) right_table = pa.Table.from_batches(right_batches) yield ( diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py index d0b6cfbe3234..08ee80284350 100644 --- a/python/pyspark/sql/streaming/python_streaming_source_runner.py +++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py @@ -134,7 +134,7 @@ def send_batch_func( write_int(NON_EMPTY_PYARROW_RECORD_BATCHES, outfile) write_int(SpecialLengths.START_ARROW_STREAM, outfile) serializer = ArrowStreamSerializer() - serializer.dump_stream(batches, outfile) + serializer.dump_stream(iter(batches), outfile) else: write_int(EMPTY_PYARROW_RECORD_BATCHES, outfile) diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 490ae184c273..73f5da953a4d 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -18,7 +18,7 @@ import json import os import socket -from typing import Any, Dict, List, Union, Optional, Tuple, Iterator +from typing import IO, Any, Dict, List, Union, Optional, Tuple, Iterator, cast from pyspark.serializers import write_int, read_int, UTF8Deserializer from pyspark.sql.pandas.serializers import ArrowStreamSerializer @@ -537,11 +537,11 @@ def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None: pd.DataFrame(state, columns=column_names), schema ) batch = pa.RecordBatch.from_pandas(pandas_df) - self.serializer.dump_stream(iter([batch]), self.sockfile) + self.serializer.dump_stream(iter([batch]), cast(IO[bytes], self.sockfile)) self.sockfile.flush() def _read_arrow_state(self) -> Any: - return self.serializer.load_stream(self.sockfile) + return self.serializer.load_stream(cast(IO[bytes], self.sockfile)) def _send_list_state(self, schema: StructType, state: List[Tuple]) -> None: for value in state: