Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 36 additions & 31 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
from enum import Enum
import json
import os
Expand All @@ -27,12 +28,43 @@
Row,
)
from pyspark.sql.pandas.types import convert_pandas_using_numpy_type
from pyspark.serializers import CPickleSerializer
from pyspark.serializers import PickleSerializer
from pyspark.errors import PySparkRuntimeError
import uuid

__all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"]

try:
import numpy as np

has_numpy = True
SCALAR_TYPES = frozenset((bool, int, float, str, bytes, datetime, type(None)))

def _normalize_state_value(v: Any) -> Any:
if type(v) in SCALAR_TYPES: # Fast path for common scalar values.
return v
# Convert NumPy scalar values to Python primitive values.
if isinstance(v, np.generic):
return v.tolist()
# Named tuples (collections.namedtuple or typing.NamedTuple) and Row both
# require positional arguments and cannot be instantiated with a generator expression.
if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")):
return type(v)(*map(_normalize_state_value, v))
# List / tuple: recursively normalize each element.
if isinstance(v, (list, tuple)):
return type(v)(map(_normalize_state_value, v))
# Dict: normalize both keys and values.
if isinstance(v, dict):
return {_normalize_state_value(k): _normalize_state_value(val) for k, val in v.items()}
# Address a couple of pandas dtypes too.
if hasattr(v, "to_pytimedelta"):
return v.to_pytimedelta()
if hasattr(v, "to_pydatetime"):
return v.to_pydatetime()
return v
except ImportError:
has_numpy = False


class StatefulProcessorHandleState(Enum):
PRE_INIT = 0
Expand Down Expand Up @@ -74,7 +106,7 @@ def __init__(
else:
self.handle_state = StatefulProcessorHandleState.CREATED
self.utf8_deserializer = UTF8Deserializer()
self.pickleSer = CPickleSerializer()
self.pickleSer = PickleSerializer()
Comment thread
funrollloops marked this conversation as resolved.
self.serializer = ArrowStreamSerializer()
# Dictionaries to store the mapping between iterator id and a tuple of data batch
# and the index of the last row that was read.
Expand Down Expand Up @@ -488,35 +520,8 @@ def _receive_str(self) -> str:
return self.utf8_deserializer.loads(self.sockfile)

def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes:
from pyspark.testing.utils import have_numpy

if have_numpy:
import numpy as np

def normalize_value(v: Any) -> Any:
# Convert NumPy types to Python primitive types.
if isinstance(v, np.generic):
return v.tolist()
# Named tuples (collections.namedtuple or typing.NamedTuple) and Row both
# require positional arguments and cannot be instantiated
# with a generator expression.
if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")):
return type(v)(*[normalize_value(e) for e in v])
# List / tuple: recursively normalize each element
if isinstance(v, (list, tuple)):
return type(v)(normalize_value(e) for e in v)
# Dict: normalize both keys and values
if isinstance(v, dict):
return {normalize_value(k): normalize_value(val) for k, val in v.items()}
# Address a couple of pandas dtypes too.
elif hasattr(v, "to_pytimedelta"):
return v.to_pytimedelta()
elif hasattr(v, "to_pydatetime"):
return v.to_pydatetime()
else:
return v

converted = tuple(normalize_value(v) for v in data)
if has_numpy:
converted = tuple(map(_normalize_state_value, data))
else:
converted = data

Expand Down