Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 66 additions & 62 deletions python/pyspark/sql/pandas/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""

from itertools import groupby
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
from typing import IO, TYPE_CHECKING, Any, Iterator, List, Optional, Tuple

import pyspark
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
Expand Down Expand Up @@ -125,28 +125,20 @@ def __repr__(self):

class ArrowStreamSerializer(Serializer):
"""
Serializes Arrow record batches as a stream.
Serializes Arrow record batches as a plain stream.

Parameters
----------
write_start_stream : bool
If True, writes the START_ARROW_STREAM marker before the first
output batch. Default False.
num_dfs : int
Number of dataframes per group.
For num_dfs=0, plain batch stream without group-count protocol.
For num_dfs=1, grouped loading (1 dataframe per group).
For num_dfs=2, cogrouped loading (2 dataframes per group).
Default 0.
"""

def __init__(self, write_start_stream: bool = False, num_dfs: int = 0) -> None:
def __init__(self, write_start_stream: bool = False) -> None:
super().__init__()
assert num_dfs in (0, 1, 2), "num_dfs must be 0, 1, or 2"
self._write_start_stream: bool = write_start_stream
self._num_dfs: int = num_dfs

def dump_stream(self, iterator, stream):
def dump_stream(self, iterator: Iterator["pa.RecordBatch"], stream: IO[bytes]) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally we want to make input less restrictive and output more restrictive. Do we have to restrict the input iterator to Iterator rather than Iterable? I saw we had to call iter() on lists to make tests work. What if we take an Iterable as an input and call iter() inside the function to get the actual iterator? If an Iterator is passed in, iter() is basically an identity function that returns the iterator itself — we lose nothing.

"""Optionally prepend START_ARROW_STREAM, then write batches."""
if self._write_start_stream:
iterator = self._write_stream_start(iterator, stream)
Expand All @@ -163,54 +155,20 @@ def dump_stream(self, iterator, stream):
writer.close()

@classmethod
def _read_arrow_stream(cls, stream) -> Iterator["pa.RecordBatch"]:
def _read_arrow_stream(cls, stream: IO[bytes]) -> Iterator["pa.RecordBatch"]:
"""Read a plain Arrow IPC stream, yielding one RecordBatch per message."""
import pyarrow as pa

reader = pa.ipc.open_stream(stream)
for batch in reader:
yield batch

def load_stream(self, stream):
"""Load batches: plain stream if num_dfs=0, grouped otherwise."""
if self._num_dfs == 0:
yield from self._read_arrow_stream(stream)
elif self._num_dfs == 1:
# Grouped loading: yield single dataframe groups
for (batches,) in self._load_group_dataframes(stream, num_dfs=1):
yield batches
elif self._num_dfs == 2:
# Cogrouped loading: yield tuples of (left_batches, right_batches)
for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2):
yield left_batches, right_batches
else:
assert False, f"Unexpected num_dfs: {self._num_dfs}"

def _load_group_dataframes(self, stream, num_dfs: int = 1) -> Iterator:
"""
Yield groups of dataframes from the stream using the group-count protocol.
"""
dataframes_in_group = None

while dataframes_in_group is None or dataframes_in_group > 0:
dataframes_in_group = read_int(stream)

if dataframes_in_group == num_dfs:
if num_dfs == 1:
# Single dataframe: can use lazy iterator
yield (self._read_arrow_stream(stream),)
else:
# Multiple dataframes: must eagerly load sequentially
# to maintain correct stream position
yield tuple(list(self._read_arrow_stream(stream)) for _ in range(num_dfs))
elif dataframes_in_group > 0:
raise PySparkValueError(
errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
messageParameters={"dataframes_in_group": str(dataframes_in_group)},
)
def load_stream(self, stream: IO[bytes]) -> Iterator["pa.RecordBatch"]:
    """Yield every record batch read from a plain Arrow IPC stream."""
    for batch in self._read_arrow_stream(stream):
        yield batch

def _write_stream_start(
self, batch_iterator: Iterator["pa.RecordBatch"], stream
self, batch_iterator: Iterator["pa.RecordBatch"], stream: IO[bytes]
) -> Iterator["pa.RecordBatch"]:
"""Write START_ARROW_STREAM before the first batch, then pass batches through."""
import itertools
Expand All @@ -225,10 +183,57 @@ def _write_stream_start(
yield from itertools.chain([first], batch_iterator)

def __repr__(self) -> str:
return "ArrowStreamSerializer(write_start_stream=%s, num_dfs=%d)" % (
self._write_start_stream,
self._num_dfs,
)
return "ArrowStreamSerializer(write_start_stream=%s)" % self._write_start_stream


class ArrowStreamGroupSerializer(ArrowStreamSerializer):
    """
    Extends :class:`ArrowStreamSerializer` with group-count protocol for loading
    grouped Arrow record batches (1 dataframe per group).
    """

    def load_stream(self, stream: IO[bytes]) -> Iterator[Iterator["pa.RecordBatch"]]:
        """Yield one lazy iterator of record batches per group read from ``stream``."""
        while True:
            # The protocol prefixes each group with the number of dataframes it
            # contains; a value <= 0 marks the end of the stream.
            num_dataframes = read_int(stream)
            if num_dataframes <= 0:
                break
            if num_dataframes != 1:
                raise PySparkValueError(
                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
                    messageParameters={"dataframes_in_group": str(num_dataframes)},
                )
            yield self._read_arrow_stream(stream)


class ArrowStreamCoGroupSerializer(ArrowStreamSerializer):
    """
    Extends :class:`ArrowStreamSerializer` with group-count protocol for loading
    cogrouped Arrow record batches (2 dataframes per group).
    """

    def load_stream(
        self, stream: IO[bytes]
    ) -> Iterator[Tuple[List["pa.RecordBatch"], List["pa.RecordBatch"]]]:
        """Yield ``(left_batches, right_batches)`` pairs read from ``stream``."""
        while True:
            # Each group is prefixed with its dataframe count; <= 0 ends the stream.
            num_dataframes = read_int(stream)
            if num_dataframes <= 0:
                break
            if num_dataframes != 2:
                raise PySparkValueError(
                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
                    messageParameters={"dataframes_in_group": str(num_dataframes)},
                )
            # Eagerly materialize both sides so the stream position is past this
            # group before the next group header is read.
            left_batches = list(self._read_arrow_stream(stream))
            right_batches = list(self._read_arrow_stream(stream))
            yield left_batches, right_batches


class ArrowStreamUDFSerializer(ArrowStreamSerializer):
Expand Down Expand Up @@ -370,7 +375,7 @@ def load_stream(self, stream):
"""
Load grouped Arrow record batches from stream.
"""
for (batches,) in self._load_group_dataframes(stream, num_dfs=1):
for batches in ArrowStreamGroupSerializer.load_stream(self, stream):
yield batches
# Make sure the batches are fully iterated before getting the next group
for _ in batches:
Expand Down Expand Up @@ -807,7 +812,7 @@ def load_stream(self, stream):
Each group yields Iterator[Tuple[pa.Array, ...]], allowing UDF to process batches one by one
without consuming all batches upfront.
"""
for (batches,) in self._load_group_dataframes(stream, num_dfs=1):
for batches in ArrowStreamGroupSerializer.load_stream(self, stream):
# Lazily read and convert Arrow batches one at a time from the stream
# This avoids loading all batches into memory for the group
columns_iter = (batch.columns for batch in batches)
Expand Down Expand Up @@ -851,7 +856,7 @@ def load_stream(self, stream):
Each group yields Iterator[Tuple[pd.Series, ...]], allowing UDF to
process batches one by one without consuming all batches upfront.
"""
for (batches,) in self._load_group_dataframes(stream, num_dfs=1):
for batches in ArrowStreamGroupSerializer.load_stream(self, stream):
# Lazily read and convert Arrow batches to pandas Series one at a time
# from the stream. This avoids loading all batches into memory for the group
series_iter = map(
Expand Down Expand Up @@ -906,7 +911,7 @@ def load_stream(self, stream):
Deserialize Grouped ArrowRecordBatches and yield raw Iterator[pa.RecordBatch].
Each outer iterator element represents a group.
"""
for (batches,) in self._load_group_dataframes(stream, num_dfs=1):
for batches in ArrowStreamGroupSerializer.load_stream(self, stream):
yield batches
# Make sure the batches are fully iterated before getting the next group
for _ in batches:
Expand Down Expand Up @@ -941,8 +946,7 @@ def load_stream(self, stream):
"""
Deserialize Cogrouped ArrowRecordBatches and yield as two `pyarrow.RecordBatch`es.
"""
for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2):
yield left_batches, right_batches
yield from ArrowStreamCoGroupSerializer.load_stream(self, stream)


class CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
Expand All @@ -953,7 +957,7 @@ def load_stream(self, stream):
"""
import pyarrow as pa

for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2):
for left_batches, right_batches in ArrowStreamCoGroupSerializer.load_stream(self, stream):
left_table = pa.Table.from_batches(left_batches)
right_table = pa.Table.from_batches(right_batches)
yield (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def send_batch_func(
write_int(NON_EMPTY_PYARROW_RECORD_BATCHES, outfile)
write_int(SpecialLengths.START_ARROW_STREAM, outfile)
serializer = ArrowStreamSerializer()
serializer.dump_stream(batches, outfile)
serializer.dump_stream(iter(batches), outfile)
else:
write_int(EMPTY_PYARROW_RECORD_BATCHES, outfile)

Expand Down
6 changes: 3 additions & 3 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import json
import os
import socket
from typing import Any, Dict, List, Union, Optional, Tuple, Iterator
from typing import IO, Any, Dict, List, Union, Optional, Tuple, Iterator, cast

from pyspark.serializers import write_int, read_int, UTF8Deserializer
from pyspark.sql.pandas.serializers import ArrowStreamSerializer
Expand Down Expand Up @@ -537,11 +537,11 @@ def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None:
pd.DataFrame(state, columns=column_names), schema
)
batch = pa.RecordBatch.from_pandas(pandas_df)
self.serializer.dump_stream(iter([batch]), self.sockfile)
self.serializer.dump_stream(iter([batch]), cast(IO[bytes], self.sockfile))
self.sockfile.flush()

def _read_arrow_state(self) -> Any:
return self.serializer.load_stream(self.sockfile)
return self.serializer.load_stream(cast(IO[bytes], self.sockfile))

def _send_list_state(self, schema: StructType, state: List[Tuple]) -> None:
for value in state:
Expand Down