Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions python/benchmarks/bench_eval_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1981,3 +1981,168 @@ class TransformWithStatePandasUDFPeakmemBench(
_TransformWithStatePandasBenchMixin, _PeakmemBenchBase
):
pass


# -- SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF ----------------------------
# Stateful streaming with Pandas plus an initial-state dataset. The UDF
# signature is ``(api_client, mode, key, state_values, init_states)`` and
# returns ``Iterator[pandas.DataFrame]``.
#
# Unlike the plain TWS variant, the input wire stream wraps two datasets into a
# single Arrow stream whose top-level schema is
# ``struct<inputData: dataSchema, initState: initStateSchema>`` (see
# ``TransformWithStateInPySparkPythonInitialStateRunner``). Each batch carries
# either inputData or initState rows -- never both -- with the inactive column
# written as an all-null struct. Matching the JVM ``initData ++ data`` ordering,
# all initial-state batches are emitted first (initState populated), then all
# data batches (inputData populated). ``TransformWithStateInPandasInitStateSerializer``
# regroups rows by the leading key column, so each key surfaces as an init-only
# call followed by a data-only call; the empty side of each call is filtered out
# before the UDF sees it.


class _TransformWithStatePandasInitStateBenchMixin(_TransformWithStatePandasBenchMixin):
    """Scenario writer for SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF.

    The streamed input comes from the plain-TWS scenario grid. On top of it,
    every group is seeded with ``_INIT_ROWS_PER_GROUP`` initial-state rows that
    use the same schema as the input. Deserializing the initial state
    (flattening the nested struct, then regrouping per key) happens inside
    ``load_stream`` whether or not the UDF ever reads ``init_states``.
    """

    # Deliberately small compared with the streamed data (a single seeded chunk
    # per key): data deserialization should remain the dominant cost, the same
    # way production loads initial state once while input data arrives on every
    # batch.
    _INIT_ROWS_PER_GROUP = 100

    @classmethod
    def _build_init_batches(cls, name):
        """Return the initial-state Arrow batches for scenario ``name``.

        The batches use the same columns as the input data, but hold only
        ``_INIT_ROWS_PER_GROUP`` rows per group. Rows are already ordered by
        the leading key column, and the result is split into chunks of at most
        ``MockDataFactory.MAX_RECORDS_PER_BATCH`` rows.
        """
        # Fixed seed so the generated value columns are reproducible.
        np.random.seed(7)
        num_groups, _, num_value_cols, value_pool = cls._scenario_configs[name]
        rows_per_group = cls._INIT_ROWS_PER_GROUP
        total_rows = num_groups * rows_per_group

        # Key column: each group id repeated rows_per_group times, in order.
        group_ids = np.repeat(np.arange(num_groups, dtype=np.int32), rows_per_group)
        key_column = pa.array(group_ids, type=pa.int32())

        # Value columns cycle through the scenario's pool of generators.
        pool_size = len(value_pool)
        value_columns = [
            value_pool[idx % pool_size][0](total_rows) for idx in range(num_value_cols)
        ]

        # col_0 is the key; col_1 .. col_N are the value columns.
        column_names = [f"col_{idx}" for idx in range(num_value_cols + 1)]
        full_batch = pa.RecordBatch.from_arrays(
            [key_column] + value_columns, names=column_names
        )

        max_rows = MockDataFactory.MAX_RECORDS_PER_BATCH
        return [
            full_batch.slice(start, min(max_rows, total_rows - start))
            for start in range(0, total_rows, max_rows)
        ]

    @staticmethod
    def _wrap_nested(flat_batch, struct_type, *, is_init):
        """Lift ``flat_batch`` into the ``struct<inputData, initState>`` layout.

        One of the two top-level columns carries the columns of ``flat_batch``
        as a struct; the other is an all-null struct array of the same length
        and of type ``struct_type``, so that ``flatten_columns`` in the
        serializer treats that side as empty. ``is_init`` selects which side
        is populated: ``initState`` when true, ``inputData`` otherwise.
        """
        active = pa.StructArray.from_arrays(
            flat_batch.columns,
            names=flat_batch.schema.names,
        )
        inactive = pa.array([None] * flat_batch.num_rows, type=struct_type)
        if is_init:
            input_side, init_side = inactive, active
        else:
            input_side, init_side = active, inactive
        return pa.RecordBatch.from_arrays(
            [input_side, init_side], names=["inputData", "initState"]
        )

    def _tws_init_identity(api_client, mode, key, state_values, init_states):
        # Echo every input frame unchanged while processing data; emit nothing
        # for any other mode.
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode,
        )

        if mode != TransformWithStateInPandasFuncMode.PROCESS_DATA:
            return
        yield from state_values

    def _tws_init_sort(api_client, mode, key, state_values, init_states):
        # Emit each input frame sorted by its first column while processing
        # data; emit nothing for any other mode.
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode,
        )

        if mode != TransformWithStateInPandasFuncMode.PROCESS_DATA:
            return
        for frame in state_values:
            yield frame.sort_values(frame.columns[0])

    def _tws_init_count(api_client, mode, key, state_values, init_states):
        # Emit one (key, row count) record per call that saw at least one row.
        import pandas as pd
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode,
        )

        if mode != TransformWithStateInPandasFuncMode.PROCESS_DATA:
            return
        # Input rows and initial-state rows show up on separate per-key calls;
        # adding both sides means whichever one is non-empty gets counted, so
        # both deserialization paths are exercised.
        data_rows = sum(len(frame) for frame in state_values)
        init_rows = sum(len(frame) for frame in init_states)
        total = data_rows + init_rows
        if not total:
            return
        yield pd.DataFrame({"col_0": [key[0]], "col_1": [total]})

    # A ret_type of None stands for "return the full input schema": the
    # init-state worker path does not project away the key column, so the
    # identity and sort UDFs receive it and hand it back. count_udf spells out
    # its (key, total) output schema instead.
    _udfs = {
        "identity_udf": (_tws_init_identity, None),
        "sort_udf": (_tws_init_sort, None),
        "count_udf": (
            _tws_init_count,
            StructType(
                [
                    StructField("col_0", IntegerType()),
                    StructField("col_1", IntegerType()),
                ]
            ),
        ),
    }
    params = [
        list(_TransformWithStatePandasBenchMixin._scenario_configs),
        list(_udfs),
    ]
    param_names = ["scenario", "udf"]

    def _write_scenario(self, scenario, udf_name, buf):
        """Write the full worker input for one (scenario, UDF) pair into ``buf``.

        The payload contains the UDF with its return type and argument offsets,
        followed by every initial-state batch and then every data batch, each
        wrapped into the ``struct<inputData, initState>`` wire schema.
        """
        data_batches, schema = self._build_scenario(scenario)
        init_batches = self._build_init_batches(scenario)

        udf_func, ret_type = self._udfs[udf_name]
        # None means the UDF returns rows in the full input schema.
        ret_type = schema if ret_type is None else ret_type

        num_keys = self._NUM_KEY_COLS
        n_value_cols = len(schema.fields) - num_keys
        # One arg-offset group per dataset. Input data and initial state share
        # a schema, so both groups resolve to key=[0], values=[1..n].
        input_offsets = MockUDFFactory.make_grouped_arg_offsets(num_keys, n_value_cols)
        init_offsets = MockUDFFactory.make_grouped_arg_offsets(num_keys, n_value_cols)
        arg_offsets = input_offsets + init_offsets
        grouping_key_schema = StructType(schema.fields[:num_keys])

        # Derive the struct type from the first data batch; since the two
        # datasets share a schema, the same type serves both struct columns.
        first_batch = data_batches[0]
        struct_type = pa.StructArray.from_arrays(
            first_batch.columns,
            names=first_batch.schema.names,
        ).type

        # Initial-state batches are written first, then the data batches.
        wrapped_init = [
            self._wrap_nested(batch, struct_type, is_init=True) for batch in init_batches
        ]
        wrapped_data = [
            self._wrap_nested(batch, struct_type, is_init=False) for batch in data_batches
        ]
        nested_batches = wrapped_init + wrapped_data

        def write_udf(out):
            return MockProtocolWriter.write_udf_payload(udf_func, ret_type, arg_offsets, out)

        def write_data(out):
            return MockProtocolWriter.write_data_payload(iter(nested_batches), out)

        MockProtocolWriter.write_worker_input(
            PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
            write_udf,
            write_data,
            buf,
            eval_conf={
                "state_server_socket_port": str(_StubStateServer.get_port()),
                "grouping_key_schema": grouping_key_schema.json(),
            },
        )


class TransformWithStatePandasInitStateUDFTimeBench(
    _TransformWithStatePandasInitStateBenchMixin,
    _TimeBenchBase,
):
    """Init-state TWS Pandas scenarios combined with ``_TimeBenchBase``."""


class TransformWithStatePandasInitStateUDFPeakmemBench(
    _TransformWithStatePandasInitStateBenchMixin,
    _PeakmemBenchBase,
):
    """Init-state TWS Pandas scenarios combined with ``_PeakmemBenchBase``."""