From 59c401f2f0a30ef212a6caec43df54b2f8fc5cc4 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Wed, 25 Mar 2026 17:20:00 -0700 Subject: [PATCH 1/3] refactor: extract ArrowStreamGroupSerializer and ArrowStreamCoGroupSerializer --- python/pyspark/sql/pandas/serializers.py | 155 +++++++++++++---------- 1 file changed, 88 insertions(+), 67 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 59cb5d5451c9..b92d348bcf22 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -20,7 +20,7 @@ """ from itertools import groupby -from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple +from typing import IO, TYPE_CHECKING, Any, Iterator, List, Optional, Tuple import pyspark from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError @@ -125,28 +125,20 @@ def __repr__(self): class ArrowStreamSerializer(Serializer): """ - Serializes Arrow record batches as a stream. + Serializes Arrow record batches as a plain stream. Parameters ---------- write_start_stream : bool If True, writes the START_ARROW_STREAM marker before the first output batch. Default False. - num_dfs : int - Number of dataframes per group. - For num_dfs=0, plain batch stream without group-count protocol. - For num_dfs=1, grouped loading (1 dataframe per group). - For num_dfs=2, cogrouped loading (2 dataframes per group). - Default 0. """ - def __init__(self, write_start_stream: bool = False, num_dfs: int = 0) -> None: + def __init__(self, write_start_stream: bool = False) -> None: super().__init__() - assert num_dfs in (0, 1, 2), "num_dfs must be 0, 1, or 2" self._write_start_stream: bool = write_start_stream - self._num_dfs: int = num_dfs - def dump_stream(self, iterator, stream): + def dump_stream(self, iterator: Iterator["pa.RecordBatch"], stream: IO[bytes]) -> None: """Optionally prepend START_ARROW_STREAM, then write batches.""" if self._write_start_stream: iterator = self._write_stream_start(iterator, stream) @@ -163,7 +155,7 @@ def dump_stream(self, iterator, stream): writer.close() @classmethod - def _read_arrow_stream(cls, stream) -> Iterator["pa.RecordBatch"]: + def _read_arrow_stream(cls, stream: IO[bytes]) -> Iterator["pa.RecordBatch"]: """Read a plain Arrow IPC stream, yielding one RecordBatch per message.""" import pyarrow as pa @@ -171,46 +163,12 @@ def _read_arrow_stream(cls, stream) -> Iterator["pa.RecordBatch"]: for batch in reader: yield batch - def load_stream(self, stream): - """Load batches: plain stream if num_dfs=0, grouped otherwise.""" - if self._num_dfs == 0: - yield from self._read_arrow_stream(stream) - elif self._num_dfs == 1: - # Grouped loading: yield single dataframe groups - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): - yield batches - elif self._num_dfs == 2: - # Cogrouped loading: yield tuples of (left_batches, right_batches) - for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): - yield left_batches, right_batches - else: - assert False, f"Unexpected num_dfs: {self._num_dfs}" - - def _load_group_dataframes(self, stream, num_dfs: int = 1) -> Iterator: - """ - Yield groups of dataframes from the stream using the group-count protocol. - """ - dataframes_in_group = None - - while dataframes_in_group is None or dataframes_in_group > 0: - dataframes_in_group = read_int(stream) - - if dataframes_in_group == num_dfs: - if num_dfs == 1: - # Single dataframe: can use lazy iterator - yield (self._read_arrow_stream(stream),) - else: - # Multiple dataframes: must eagerly load sequentially - # to maintain correct stream position - yield tuple(list(self._read_arrow_stream(stream)) for _ in range(num_dfs)) - elif dataframes_in_group > 0: - raise PySparkValueError( - errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP", - messageParameters={"dataframes_in_group": str(dataframes_in_group)}, - ) + def load_stream(self, stream: IO[bytes]) -> Iterator["pa.RecordBatch"]: + """Load batches from a plain Arrow stream.""" + yield from self._read_arrow_stream(stream) def _write_stream_start( - self, batch_iterator: Iterator["pa.RecordBatch"], stream + self, batch_iterator: Iterator["pa.RecordBatch"], stream: IO[bytes] ) -> Iterator["pa.RecordBatch"]: """Write START_ARROW_STREAM before the first batch, then pass batches through.""" import itertools @@ -225,10 +183,57 @@ def _write_stream_start( yield from itertools.chain([first], batch_iterator) def __repr__(self) -> str: - return "ArrowStreamSerializer(write_start_stream=%s, num_dfs=%d)" % ( - self._write_start_stream, - self._num_dfs, - ) + return "ArrowStreamSerializer(write_start_stream=%s)" % self._write_start_stream + + +class ArrowStreamGroupSerializer(ArrowStreamSerializer): + """ + Extends :class:`ArrowStreamSerializer` with group-count protocol for loading + grouped Arrow record batches (1 dataframe per group). + """ + + def load_stream(self, stream: IO[bytes]) -> Iterator[Iterator["pa.RecordBatch"]]: + """Yield one iterator of record batches per group from the stream.""" + dataframes_in_group: Optional[int] = None + + while dataframes_in_group is None or dataframes_in_group > 0: + dataframes_in_group = read_int(stream) + + if dataframes_in_group == 1: + yield self._read_arrow_stream(stream) + elif dataframes_in_group > 0: + raise PySparkValueError( + errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP", + messageParameters={"dataframes_in_group": str(dataframes_in_group)}, + ) + + +class ArrowStreamCoGroupSerializer(ArrowStreamSerializer): + """ + Extends :class:`ArrowStreamSerializer` with group-count protocol for loading + cogrouped Arrow record batches (2 dataframes per group). + """ + + def load_stream( + self, stream: IO[bytes] + ) -> Iterator[Tuple[List["pa.RecordBatch"], List["pa.RecordBatch"]]]: + """Yield pairs of (left_batches, right_batches) from the stream.""" + dataframes_in_group: Optional[int] = None + + while dataframes_in_group is None or dataframes_in_group > 0: + dataframes_in_group = read_int(stream) + + if dataframes_in_group == 2: + # Must eagerly load each dataframe to maintain correct stream position + yield ( + list(self._read_arrow_stream(stream)), + list(self._read_arrow_stream(stream)), + ) + elif dataframes_in_group > 0: + raise PySparkValueError( + errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP", + messageParameters={"dataframes_in_group": str(dataframes_in_group)}, + ) class ArrowStreamUDFSerializer(ArrowStreamSerializer): @@ -370,7 +375,7 @@ def load_stream(self, stream): """ Load grouped Arrow record batches from stream. """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): + for batches in ArrowStreamGroupSerializer.load_stream(self, stream): yield batches # Make sure the batches are fully iterated before getting the next group for _ in batches: @@ -807,7 +812,7 @@ def load_stream(self, stream): Each group yields Iterator[Tuple[pa.Array, ...]], allowing UDF to process batches one by one without consuming all batches upfront. """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): + for batches in ArrowStreamGroupSerializer.load_stream(self, stream): # Lazily read and convert Arrow batches one at a time from the stream # This avoids loading all batches into memory for the group columns_iter = (batch.columns for batch in batches) @@ -851,7 +856,7 @@ def load_stream(self, stream): Each group yields Iterator[Tuple[pd.Series, ...]], allowing UDF to process batches one by one without consuming all batches upfront. """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): + for batches in ArrowStreamGroupSerializer.load_stream(self, stream): # Lazily read and convert Arrow batches to pandas Series one at a time # from the stream. This avoids loading all batches into memory for the group series_iter = map( @@ -906,7 +911,7 @@ def load_stream(self, stream): Deserialize Grouped ArrowRecordBatches and yield raw Iterator[pa.RecordBatch]. Each outer iterator element represents a group. """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): + for batches in ArrowStreamGroupSerializer.load_stream(self, stream): yield batches # Make sure the batches are fully iterated before getting the next group for _ in batches: @@ -924,7 +929,7 @@ def __repr__(self): return "GroupPandasUDFSerializer" -class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer): +class CogroupArrowUDFSerializer(ArrowStreamCoGroupSerializer): """ Serializes pyarrow.RecordBatch data with Arrow streaming format. @@ -937,12 +942,28 @@ class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer): If True, then DataFrames will get columns by name """ - def load_stream(self, stream): - """ - Deserialize Cogrouped ArrowRecordBatches and yield as two `pyarrow.RecordBatch`es. - """ - for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): - yield left_batches, right_batches + def __init__(self, *, assign_cols_by_name): + super().__init__() + self._assign_cols_by_name = assign_cols_by_name + + def dump_stream(self, iterator, stream): + import pyarrow as pa + + batch_iter = ((batch, arrow_type) for batches, arrow_type in iterator for batch in batches) + + if self._assign_cols_by_name: + batch_iter = ( + ( + pa.RecordBatch.from_arrays( + [batch.column(field.name) for field in arrow_type], + names=[field.name for field in arrow_type], + ), + arrow_type, + ) + for batch, arrow_type in batch_iter + ) + + super().dump_stream(batch_iter, stream) class CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer): @@ -953,7 +974,7 @@ def load_stream(self, stream): """ import pyarrow as pa - for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): + for left_batches, right_batches in ArrowStreamCoGroupSerializer.load_stream(self, stream): left_table = pa.Table.from_batches(left_batches) right_table = pa.Table.from_batches(right_batches) yield ( From 404026f5bd6b3674654d62b9d3d0bc85202e4955 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Wed, 25 Mar 2026 18:02:38 -0700 Subject: [PATCH 2/3] fix: keep CogroupArrowUDFSerializer inheriting ArrowStreamGroupUDFSerializer --- python/pyspark/sql/pandas/serializers.py | 29 +++++------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index b92d348bcf22..da03e3c314ec 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -929,7 +929,7 @@ def __repr__(self): return "GroupPandasUDFSerializer" -class CogroupArrowUDFSerializer(ArrowStreamCoGroupSerializer): +class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer): """ Serializes pyarrow.RecordBatch data with Arrow streaming format. @@ -942,28 +942,11 @@ class CogroupArrowUDFSerializer(ArrowStreamCoGroupSerializer): If True, then DataFrames will get columns by name """ - def __init__(self, *, assign_cols_by_name): - super().__init__() - self._assign_cols_by_name = assign_cols_by_name - - def dump_stream(self, iterator, stream): - import pyarrow as pa - - batch_iter = ((batch, arrow_type) for batches, arrow_type in iterator for batch in batches) - - if self._assign_cols_by_name: - batch_iter = ( - ( - pa.RecordBatch.from_arrays( - [batch.column(field.name) for field in arrow_type], - names=[field.name for field in arrow_type], - ), - arrow_type, - ) - for batch, arrow_type in batch_iter - ) - - super().dump_stream(batch_iter, stream) + def load_stream(self, stream): + """ + Deserialize Cogrouped ArrowRecordBatches and yield as two `pyarrow.RecordBatch`es. + """ + yield from ArrowStreamCoGroupSerializer.load_stream(self, stream) class CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer): From fc0e50d9ac85872ac8a5dc3be287650c6d6ed149 Mon Sep 17 00:00:00 2001 From: Yicong Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Thu, 26 Mar 2026 05:51:30 +0000 Subject: [PATCH 3/3] fix: resolve mypy type errors in ArrowStreamSerializer call sites --- .../pyspark/sql/streaming/python_streaming_source_runner.py | 2 +- .../pyspark/sql/streaming/stateful_processor_api_client.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py index d0b6cfbe3234..08ee80284350 100644 --- a/python/pyspark/sql/streaming/python_streaming_source_runner.py +++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py @@ -134,7 +134,7 @@ def send_batch_func( write_int(NON_EMPTY_PYARROW_RECORD_BATCHES, outfile) write_int(SpecialLengths.START_ARROW_STREAM, outfile) serializer = ArrowStreamSerializer() - serializer.dump_stream(batches, outfile) + serializer.dump_stream(iter(batches), outfile) else: write_int(EMPTY_PYARROW_RECORD_BATCHES, outfile) diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 490ae184c273..73f5da953a4d 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -18,7 +18,7 @@ import json import os import socket -from typing import Any, Dict, List, Union, Optional, Tuple, Iterator +from typing import IO, Any, Dict, List, Union, Optional, Tuple, Iterator, cast from pyspark.serializers import write_int, read_int, UTF8Deserializer from pyspark.sql.pandas.serializers import ArrowStreamSerializer @@ -537,11 +537,11 @@ def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None: pd.DataFrame(state, columns=column_names), schema ) batch = pa.RecordBatch.from_pandas(pandas_df) - self.serializer.dump_stream(iter([batch]), self.sockfile) + self.serializer.dump_stream(iter([batch]), cast(IO[bytes], self.sockfile)) self.sockfile.flush() def _read_arrow_state(self) -> Any: - return self.serializer.load_stream(self.sockfile) + return self.serializer.load_stream(cast(IO[bytes], self.sockfile)) def _send_list_state(self, schema: StructType, state: List[Tuple]) -> None: for value in state: