diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 490ae184c273f..b2d7c1b2da6f6 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -27,12 +27,45 @@
     Row,
 )
 from pyspark.sql.pandas.types import convert_pandas_using_numpy_type
+from pyspark.sql.utils import has_numpy
 from pyspark.serializers import CPickleSerializer
 from pyspark.errors import PySparkRuntimeError
 import uuid
 
 __all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"]
 
+if has_numpy:
+    import numpy as np
+
+    def _normalize_value(v: Any) -> Any:
+        # Convert NumPy types to Python primitive types.
+        if isinstance(v, np.generic):
+            return v.tolist()
+        # Named tuples (collections.namedtuple or typing.NamedTuple) and Row both
+        # require positional arguments and cannot be instantiated
+        # with a generator expression.
+        if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")):
+            return type(v)(*[_normalize_value(e) for e in v])
+        # List / tuple: recursively normalize each element
+        if isinstance(v, (list, tuple)):
+            return type(v)(_normalize_value(e) for e in v)
+        # Dict: normalize both keys and values
+        if isinstance(v, dict):
+            return {_normalize_value(k): _normalize_value(val) for k, val in v.items()}
+        # Address a couple of pandas dtypes too.
+        elif hasattr(v, "to_pytimedelta"):
+            return v.to_pytimedelta()
+        elif hasattr(v, "to_pydatetime"):
+            return v.to_pydatetime()
+        return v
+
+    def _normalize_tuple(data: Tuple) -> Tuple:
+        return tuple(_normalize_value(v) for v in data)
+else:
+
+    def _normalize_tuple(data: Tuple) -> Tuple:
+        return data  # toInternal handles tuples natively
+
 
 class StatefulProcessorHandleState(Enum):
     PRE_INIT = 0
@@ -81,6 +114,10 @@ def __init__(
         self.list_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}
         self.expiry_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}
 
+        # Cache of schema-id -> fast-serialize callable, so we avoid
+        # rebuilding field_names / Row / closure on every _serialize_to_bytes call.
+        self._serializer_cache: Dict[int, Any] = {}
+
         # statefulProcessorApiClient is initialized per batch per partition,
         # so we will have new timestamps for a new batch
         self._batch_timestamp = -1
@@ -487,43 +524,23 @@ def _receive_proto_message_with_timers(self) -> Tuple[int, str, Any, bool]:
     def _receive_str(self) -> str:
         return self.utf8_deserializer.loads(self.sockfile)
 
-    def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes:
-        from pyspark.testing.utils import have_numpy
-
-        if have_numpy:
-            import numpy as np
-
-            def normalize_value(v: Any) -> Any:
-                # Convert NumPy types to Python primitive types.
-                if isinstance(v, np.generic):
-                    return v.tolist()
-                # Named tuples (collections.namedtuple or typing.NamedTuple) and Row both
-                # require positional arguments and cannot be instantiated
-                # with a generator expression.
-                if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")):
-                    return type(v)(*[normalize_value(e) for e in v])
-                # List / tuple: recursively normalize each element
-                if isinstance(v, (list, tuple)):
-                    return type(v)(normalize_value(e) for e in v)
-                # Dict: normalize both keys and values
-                if isinstance(v, dict):
-                    return {normalize_value(k): normalize_value(val) for k, val in v.items()}
-                # Address a couple of pandas dtypes too.
-                elif hasattr(v, "to_pytimedelta"):
-                    return v.to_pytimedelta()
-                elif hasattr(v, "to_pydatetime"):
-                    return v.to_pydatetime()
-                else:
-                    return v
-
-            converted = [normalize_value(v) for v in data]
-        else:
-            converted = list(data)
+    def _get_serializer(self, schema: StructType) -> Any:
+        schema_id = id(schema)
+        serializer = self._serializer_cache.get(schema_id)
+        if serializer is not None:
+            return serializer
 
-        field_names = [f.name for f in schema.fields]
-        row_value = Row(**dict(zip(field_names, converted)))
+        to_internal = schema.toInternal
+        dumps = self.pickleSer.dumps
 
-        return self.pickleSer.dumps(schema.toInternal(row_value))
+        def _fast_serialize(data: Tuple) -> bytes:
+            return dumps(to_internal(_normalize_tuple(data)))
+
+        self._serializer_cache[schema_id] = _fast_serialize
+        return _fast_serialize
+
+    def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes:
+        return self._get_serializer(schema)(data)
 
     def _deserialize_from_bytes(self, value: bytes) -> Any:
         return self.pickleSer.loads(value)
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 493eb6ee0abc3..b76f077768aff 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -50,6 +50,14 @@
 from pyspark.errors.exceptions.captured import CapturedException  # noqa: F401
 from pyspark.find_spark_home import _find_spark_home
 
+has_numpy: bool = False
+try:
+    import numpy as np  # noqa: F401
+
+    has_numpy = True
+except ImportError:
+    pass
+
 if TYPE_CHECKING:
     from py4j.java_collections import JavaArray
     from py4j.java_gateway import (
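
Reviewer note: the core idea of the refactor above is to hoist the NumPy normalization helpers to module scope and to cache a per-schema serialization closure keyed by id(schema), so _serialize_to_bytes no longer rebuilds field_names, a Row, and a nested function on every call. Below is a minimal, self-contained sketch of that caching pattern, assuming plain pickle in place of PySpark's CPickleSerializer and a generic schema object in place of StructType; SerializerCache and serializer_for are illustrative names, not PySpark APIs.

    import pickle
    from typing import Any, Callable, Dict, Tuple

    try:
        import numpy as np

        def _normalize_value(v: Any) -> Any:
            # NumPy scalars (np.int64, np.float32, ...) are converted to plain
            # Python values before pickling; everything else passes through.
            return v.tolist() if isinstance(v, np.generic) else v

    except ImportError:

        def _normalize_value(v: Any) -> Any:
            return v


    class SerializerCache:
        """Build the per-schema serialization closure once and reuse it afterwards."""

        def __init__(self) -> None:
            self._cache: Dict[int, Callable[[Tuple[Any, ...]], bytes]] = {}

        def serializer_for(self, schema: Any) -> Callable[[Tuple[Any, ...]], bytes]:
            # Keyed by id(schema): the same schema object is passed for every row
            # of a partition, so object identity is a cheap, sufficient cache key.
            key = id(schema)
            serializer = self._cache.get(key)
            if serializer is None:

                def serializer(data: Tuple[Any, ...]) -> bytes:
                    return pickle.dumps(tuple(_normalize_value(v) for v in data))

                self._cache[key] = serializer
            return serializer


    if __name__ == "__main__":
        cache = SerializerCache()
        schema = object()  # stand-in for a pyspark StructType
        ser = cache.serializer_for(schema)
        assert cache.serializer_for(schema) is ser  # second lookup hits the cache
        print(len(ser((1, "a", 2.5))), "bytes")

One tradeoff worth noting: id() keys can be reused after an object is garbage collected, which is presumably acceptable here because the API client (and therefore its cache) is re-created per batch per partition, as the comment in __init__ points out.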