Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 52 additions & 35 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,45 @@
Row,
)
from pyspark.sql.pandas.types import convert_pandas_using_numpy_type
from pyspark.sql.utils import has_numpy
from pyspark.serializers import CPickleSerializer
from pyspark.errors import PySparkRuntimeError
import uuid

__all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"]

if has_numpy:
    import numpy as np

    def _normalize_value(v: Any) -> Any:
        """Recursively coerce NumPy / pandas values inside *v* into plain
        Python objects so downstream serialization only sees primitives."""
        # NumPy scalar (np.int64, np.float32, ...) -> Python primitive.
        if isinstance(v, np.generic):
            return v.tolist()
        # Row and namedtuples (collections.namedtuple / typing.NamedTuple)
        # must be rebuilt with positional arguments; their constructors
        # reject a single generator argument.
        if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")):
            normalized = [_normalize_value(item) for item in v]
            return type(v)(*normalized)
        # Plain sequences: rebuild the same container type element-wise.
        if isinstance(v, (list, tuple)):
            return type(v)(_normalize_value(item) for item in v)
        # Dicts: both keys and values may carry NumPy types.
        if isinstance(v, dict):
            return {
                _normalize_value(key): _normalize_value(val)
                for key, val in v.items()
            }
        # pandas scalars (Timedelta / Timestamp) expose these converters.
        if hasattr(v, "to_pytimedelta"):
            return v.to_pytimedelta()
        if hasattr(v, "to_pydatetime"):
            return v.to_pydatetime()
        return v

    def _normalize_tuple(data: Tuple) -> Tuple:
        """Normalize every element of *data*, returning a plain tuple."""
        return tuple(map(_normalize_value, data))

else:

    def _normalize_tuple(data: Tuple) -> Tuple:
        # Without NumPy there is nothing to convert; toInternal handles
        # plain tuples natively.
        return data


class StatefulProcessorHandleState(Enum):
PRE_INIT = 0
Expand Down Expand Up @@ -81,6 +114,10 @@ def __init__(
self.list_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}
self.expiry_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}

# Cache of schema-id -> fast-serialize callable, so we avoid
# rebuilding field_names / Row / closure on every _serialize_to_bytes call.
self._serializer_cache: Dict[int, Any] = {}

# statefulProcessorApiClient is initialized per batch per partition,
# so we will have new timestamps for a new batch
self._batch_timestamp = -1
Expand Down Expand Up @@ -487,43 +524,23 @@ def _receive_proto_message_with_timers(self) -> Tuple[int, str, Any, bool]:
def _receive_str(self) -> str:
return self.utf8_deserializer.loads(self.sockfile)

def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes:
from pyspark.testing.utils import have_numpy

if have_numpy:
import numpy as np

def normalize_value(v: Any) -> Any:
# Convert NumPy types to Python primitive types.
if isinstance(v, np.generic):
return v.tolist()
# Named tuples (collections.namedtuple or typing.NamedTuple) and Row both
# require positional arguments and cannot be instantiated
# with a generator expression.
if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")):
return type(v)(*[normalize_value(e) for e in v])
# List / tuple: recursively normalize each element
if isinstance(v, (list, tuple)):
return type(v)(normalize_value(e) for e in v)
# Dict: normalize both keys and values
if isinstance(v, dict):
return {normalize_value(k): normalize_value(val) for k, val in v.items()}
# Address a couple of pandas dtypes too.
elif hasattr(v, "to_pytimedelta"):
return v.to_pytimedelta()
elif hasattr(v, "to_pydatetime"):
return v.to_pydatetime()
else:
return v

converted = [normalize_value(v) for v in data]
else:
converted = list(data)
def _get_serializer(self, schema: StructType) -> Any:
    """Return a serializer callable for *schema*, building it on first use
    and caching it so repeated _serialize_to_bytes calls skip the setup."""
    cache_key = id(schema)
    # NOTE(review): keying on id() assumes each schema object stays alive
    # for as long as its cache entry; a collected schema's id may be
    # reused by a different object — confirm schema lifetime upstream.
    cached = self._serializer_cache.get(cache_key)
    if cached is not None:
        return cached

    # Bind the hot attributes once. StructType.toInternal dispatches on
    # tuples positionally, so no Row / field-name reconstruction is needed.
    to_internal = schema.toInternal
    dumps = self.pickleSer.dumps

    def _fast_serialize(data: Tuple) -> bytes:
        return dumps(to_internal(_normalize_tuple(data)))

    self._serializer_cache[cache_key] = _fast_serialize
    return _fast_serialize

def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes:
    """Pickle *data* in the internal representation dictated by *schema*."""
    return self._get_serializer(schema)(data)

def _deserialize_from_bytes(self, value: bytes) -> Any:
return self.pickleSer.loads(value)
Expand Down
8 changes: 8 additions & 0 deletions python/pyspark/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@
from pyspark.errors.exceptions.captured import CapturedException # noqa: F401
from pyspark.find_spark_home import _find_spark_home

# Feature flag: True when NumPy can be imported in this environment.
# Consumers branch on it to decide whether NumPy-aware value conversion
# is required.
try:
    import numpy as np  # noqa: F401

    has_numpy: bool = True
except ImportError:
    has_numpy = False

if TYPE_CHECKING:
from py4j.java_collections import JavaArray
from py4j.java_gateway import (
Expand Down