From ad3ee78084de709e68598c90978976aebf169867 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 8 Apr 2026 12:16:23 -0700 Subject: [PATCH 01/18] Add Nexus system payload codec rewrite support --- pyproject.toml | 16 +- scripts/gen_nexus_system_test_models.py | 75 ++ temporalio/bridge/worker.py | 46 +- temporalio/nexus/system/__init__.py | 41 + .../system/_workflow_service_generated.py | 821 ++++++++++++++++++ temporalio/worker/_command_aware_visitor.py | 22 +- tests/nexus/test_temporal_system_nexus.py | 178 ++++ 7 files changed, 1192 insertions(+), 7 deletions(-) create mode 100644 scripts/gen_nexus_system_test_models.py create mode 100644 temporalio/nexus/system/__init__.py create mode 100644 temporalio/nexus/system/_workflow_service_generated.py create mode 100644 tests/nexus/test_temporal_system_nexus.py diff --git a/pyproject.toml b/pyproject.toml index 4bcd3f03e..3c810ccc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,8 +67,14 @@ dev = [ ] [tool.poe.tasks] -build-develop = "uv run maturin develop --uv" -build-develop-with-release = { cmd = "uv run maturin develop --release --uv" } +build-develop = [ + { ref = "gen-nexus-system-test-models" }, + { cmd = "uv run maturin develop --uv" }, +] +build-develop-with-release = [ + { ref = "gen-nexus-system-test-models" }, + { cmd = "uv run maturin develop --release --uv" }, +] format = [ { cmd = "uv run ruff check --select I --fix" }, { cmd = "uv run ruff format" }, @@ -79,6 +85,7 @@ gen-protos = [ { cmd = "uv run scripts/gen_protos.py" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, + { ref = "gen-nexus-system-test-models" }, { ref = "format" }, ] gen-protos-docker = [ @@ -98,10 +105,12 @@ bridge-lint = { cmd = "cargo clippy -- -D warnings", cwd = "temporalio/bridge" } # https://github.com/PyCQA/pydocstyle/pull/511? lint-docs = "uv run pydocstyle --ignore-decorators=overload" lint-types = [ + { ref = "gen-nexus-system-test-models" }, { cmd = "uv run pyright" }, { cmd = "uv run mypy --namespace-packages --check-untyped-defs ." }, { cmd = "uv run basedpyright" }, ] +gen-nexus-system-test-models = "uv run scripts/gen_nexus_system_test_models.py" run-bench = "uv run python scripts/run_bench.py" test = "uv run pytest" @@ -141,12 +150,14 @@ exclude = [ # Ignore generated code 'temporalio/api', 'temporalio/bridge/proto', + 'temporalio/nexus/system/_workflow_service_generated.py', ] [tool.pydocstyle] convention = "google" # https://github.com/PyCQA/pydocstyle/issues/363#issuecomment-625563088 match_dir = "^(?!(docs|scripts|tests|api|proto|\\.)).*" +match = "^(?!_workflow_service_generated\\.py$).*\\.py" add_ignore = [ # We like to wrap at a certain number of chars, even long summary sentences. # https://github.com/PyCQA/pydocstyle/issues/184 @@ -212,6 +223,7 @@ exclude = [ "temporalio/api", "temporalio/bridge/proto", "temporalio/bridge/_visitor.py", + "temporalio/nexus/system/_workflow_service_generated.py", "tests/worker/workflow_sandbox/testmodules/proto", ] diff --git a/scripts/gen_nexus_system_test_models.py b/scripts/gen_nexus_system_test_models.py new file mode 100644 index 000000000..d707727a5 --- /dev/null +++ b/scripts/gen_nexus_system_test_models.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import subprocess +import sys +from pathlib import Path + + +def main() -> None: + repo_root = Path(__file__).resolve().parent.parent + workspace_root = repo_root.parent + nexus_rpc_gen_root = workspace_root / "nexus-rpc-gen" / "src" + input_schema = ( + workspace_root + / "temporal-api" + / "nexus" + / "temporal-json-schema-models-nexusrpc.yaml" + ) + output_file = ( + repo_root / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" + ) + + if not nexus_rpc_gen_root.is_dir(): + raise RuntimeError(f"Expected nexus-rpc-gen checkout at {nexus_rpc_gen_root}") + if not input_schema.is_file(): + raise RuntimeError(f"Expected Temporal Nexus schema at {input_schema}") + + subprocess.run( + [ + "npm", + "run", + "cli", + "--", + "--lang", + "py", + "--out-file", + str(output_file), + "--temporal-nexus-payload-codec-support", + str(input_schema), + ], + cwd=nexus_rpc_gen_root, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "check", + "--select", + "I", + "--fix", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "format", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + + +if __name__ == "__main__": + try: + main() + except Exception as err: + print(f"Failed to generate Nexus system test models: {err}", file=sys.stderr) + raise diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index c2e426d28..52a54ce9c 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -20,7 +20,9 @@ import temporalio.bridge.runtime import temporalio.bridge.temporal_sdk_bridge import temporalio.converter +import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload +from temporalio.api.enums.v1.command_type_pb2 import CommandType from temporalio.bridge._visitor import VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, @@ -28,6 +30,7 @@ from temporalio.bridge.temporal_sdk_bridge import ( PollShutdownError, # type: ignore # noqa: F401 ) +from temporalio.worker import _command_aware_visitor from temporalio.worker._command_aware_visitor import CommandAwarePayloadVisitor @@ -279,14 +282,50 @@ async def finalize_shutdown(self) -> None: class _Visitor(VisitorFunctions): - def __init__(self, f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]]): + def __init__( + self, + f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]], + payload_codec: temporalio.converter.PayloadCodec | None = None, + ): self._f = f + self._payload_codec = payload_codec async def visit_payload(self, payload: Payload) -> None: + if self._payload_codec: + rewritten_payload = await self._maybe_rewrite_nexus_payload(payload) + if rewritten_payload is not None: + if rewritten_payload is not payload: + payload.CopyFrom(rewritten_payload) + return new_payload = (await self._f([payload]))[0] if new_payload is not payload: payload.CopyFrom(new_payload) + async def _maybe_rewrite_nexus_payload(self, payload: Payload) -> Payload | None: + command_info = _command_aware_visitor.current_command_info.get() + if ( + command_info is None + or command_info.command_type + != CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION + or not command_info.nexus_service + or not command_info.nexus_operation + ): + return None + + rewrite = temporalio.nexus.system.get_payload_codec_rewriter( + command_info.nexus_service, + command_info.nexus_operation, + ) + if rewrite is None: + return None + + rewritten_payload = await rewrite(payload, self._payload_codec) + if not isinstance(rewritten_payload, Payload): + raise TypeError( + "temporal nexus payload codec rewriter must return a Payload" + ) + return rewritten_payload + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: if len(payloads) == 0: return @@ -316,4 +355,7 @@ async def encode_completion( """Encode all payloads in the completion.""" await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers - ).visit(_Visitor(data_converter._encode_payload_sequence), completion) + ).visit( + _Visitor(data_converter._encode_payload_sequence, data_converter.payload_codec), + completion, + ) diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py new file mode 100644 index 000000000..7eff3a93d --- /dev/null +++ b/temporalio/nexus/system/__init__.py @@ -0,0 +1,41 @@ +"""Generated system Nexus service models. + +This package contains code generated from Temporal's system Nexus schemas. +Higher-level ergonomic APIs may wrap these generated types. +""" + +from collections.abc import Awaitable, Callable + +import temporalio.api.common.v1 +import temporalio.converter + +from ._workflow_service_generated import ( + WorkflowService, + WorkflowServiceSignalWithStartWorkflowExecutionInput, + WorkflowServiceSignalWithStartWorkflowExecutionOutput, + __temporal_nexus_payload_codec_rewriters__, +) + +TemporalNexusPayloadCodecRewriter = Callable[ + [ + temporalio.api.common.v1.Payload, + temporalio.converter.PayloadCodec | None, + ], + Awaitable[temporalio.api.common.v1.Payload], +] + + +def get_payload_codec_rewriter( + service: str, + operation: str, +) -> TemporalNexusPayloadCodecRewriter | None: + """Return the generated payload codec rewriter for a system Nexus operation.""" + return __temporal_nexus_payload_codec_rewriters__.get((service, operation)) + + +__all__ = ( + "WorkflowService", + "WorkflowServiceSignalWithStartWorkflowExecutionInput", + "WorkflowServiceSignalWithStartWorkflowExecutionOutput", + "get_payload_codec_rewriter", +) diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py new file mode 100644 index 000000000..1723631a1 --- /dev/null +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -0,0 +1,821 @@ +# Generated by nexus-rpc-gen. DO NOT EDIT! + +from __future__ import annotations + +import json +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from google.protobuf.json_format import MessageToDict, ParseDict +from nexusrpc import Operation, service +from pydantic import BaseModel, Field + +import temporalio.api.common.v1 +import temporalio.converter + + +async def _temporal_nexus_encode_payload_json( + value: dict, payload_codec: temporalio.converter.PayloadCodec +) -> dict: + payload = ParseDict(value, temporalio.api.common.v1.Payload()) + [encoded_payload] = await payload_codec.encode([payload]) + return MessageToDict(encoded_payload) + + +async def _temporal_nexus_encode_payloads_json( + value: dict, payload_codec: temporalio.converter.PayloadCodec +) -> dict: + payloads = ParseDict(value, temporalio.api.common.v1.Payloads()) + encoded_payloads = await payload_codec.encode(payloads.payloads) + del payloads.payloads[:] + payloads.payloads.extend(encoded_payloads) + return MessageToDict(payloads) + + +async def _temporal_nexus_encode_payload_map_json( + message_type: type, value: dict, payload_codec: temporalio.converter.PayloadCodec +) -> dict: + message = ParseDict(value, message_type()) + keys = list(message.fields.keys()) + encoded_payloads = await payload_codec.encode([message.fields[key] for key in keys]) + for key, encoded_payload in zip(keys, encoded_payloads): + message.fields[key].CopyFrom(encoded_payload) + return MessageToDict(message) + + +async def _temporal_nexus_encode_json_value( + value: object, payload_codec: temporalio.converter.PayloadCodec +) -> object: + if isinstance(value, list): + return [ + await _temporal_nexus_encode_json_value(item, payload_codec) + for item in value + ] + if not isinstance(value, dict): + return value + if "indexedFields" in value: + return value + if "payloads" in value and isinstance(value["payloads"], list): + return await _temporal_nexus_encode_payloads_json(value, payload_codec) + if "fields" in value and isinstance(value["fields"], dict): + return await _temporal_nexus_encode_payload_map_json( + temporalio.api.common.v1.Header, value, payload_codec + ) + if "data" in value and "metadata" in value: + return await _temporal_nexus_encode_payload_json(value, payload_codec) + rewritten: dict[str, object] = {} + for key, item in value.items(): + rewritten[key] = ( + item + if key == "indexedFields" + else await _temporal_nexus_encode_json_value(item, payload_codec) + ) + return rewritten + + +class Header(BaseModel): + """Contains metadata that can be attached to a variety of requests, like starting a + workflow, and + can be propagated between, for example, workflows and activities. + """ + + fields: Optional[Dict[str, Any]] = None + + +class Input(BaseModel): + """Serialized arguments to the workflow. These are passed as arguments to the workflow + function. + + See `Payload` + + Serialized value(s) to provide with the signal + """ + + payloads: Optional[List[Any]] = None + + +class BatchJob(BaseModel): + """A link to a built-in batch job. + Batch jobs can be used to perform operations on a set of workflows (e.g. terminate, + signal, cancel, etc). + This link can be put on workflow history events generated by actions taken by a batch job. + """ + + job_id: Optional[str] = Field(None, alias="jobId") + + +class EventType(Enum): + EVENT_TYPE_ACTIVITY_PROPERTIES_MODIFIED_EXTERNALLY = ( + "EVENT_TYPE_ACTIVITY_PROPERTIES_MODIFIED_EXTERNALLY" + ) + EVENT_TYPE_ACTIVITY_TASK_CANCELED = "EVENT_TYPE_ACTIVITY_TASK_CANCELED" + EVENT_TYPE_ACTIVITY_TASK_CANCEL_REQUESTED = ( + "EVENT_TYPE_ACTIVITY_TASK_CANCEL_REQUESTED" + ) + EVENT_TYPE_ACTIVITY_TASK_COMPLETED = "EVENT_TYPE_ACTIVITY_TASK_COMPLETED" + EVENT_TYPE_ACTIVITY_TASK_FAILED = "EVENT_TYPE_ACTIVITY_TASK_FAILED" + EVENT_TYPE_ACTIVITY_TASK_SCHEDULED = "EVENT_TYPE_ACTIVITY_TASK_SCHEDULED" + EVENT_TYPE_ACTIVITY_TASK_STARTED = "EVENT_TYPE_ACTIVITY_TASK_STARTED" + EVENT_TYPE_ACTIVITY_TASK_TIMED_OUT = "EVENT_TYPE_ACTIVITY_TASK_TIMED_OUT" + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_CANCELED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_CANCELED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_COMPLETED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_COMPLETED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_STARTED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_STARTED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TERMINATED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TERMINATED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TIMED_OUT = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TIMED_OUT" + ) + EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_CANCEL_REQUESTED = ( + "EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_CANCEL_REQUESTED" + ) + EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_SIGNALED = ( + "EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_SIGNALED" + ) + EVENT_TYPE_MARKER_RECORDED = "EVENT_TYPE_MARKER_RECORDED" + EVENT_TYPE_NEXUS_OPERATION_CANCELED = "EVENT_TYPE_NEXUS_OPERATION_CANCELED" + EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED = ( + "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED" + ) + EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED = ( + "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED" + ) + EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED = ( + "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED" + ) + EVENT_TYPE_NEXUS_OPERATION_COMPLETED = "EVENT_TYPE_NEXUS_OPERATION_COMPLETED" + EVENT_TYPE_NEXUS_OPERATION_FAILED = "EVENT_TYPE_NEXUS_OPERATION_FAILED" + EVENT_TYPE_NEXUS_OPERATION_SCHEDULED = "EVENT_TYPE_NEXUS_OPERATION_SCHEDULED" + EVENT_TYPE_NEXUS_OPERATION_STARTED = "EVENT_TYPE_NEXUS_OPERATION_STARTED" + EVENT_TYPE_NEXUS_OPERATION_TIMED_OUT = "EVENT_TYPE_NEXUS_OPERATION_TIMED_OUT" + EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED = ( + "EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED" + ) + EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED = ( + "EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED" + ) + EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED = ( + "EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED" + ) + EVENT_TYPE_TIMER_CANCELED = "EVENT_TYPE_TIMER_CANCELED" + EVENT_TYPE_TIMER_FIRED = "EVENT_TYPE_TIMER_FIRED" + EVENT_TYPE_TIMER_STARTED = "EVENT_TYPE_TIMER_STARTED" + EVENT_TYPE_UNSPECIFIED = "EVENT_TYPE_UNSPECIFIED" + EVENT_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES = ( + "EVENT_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED = "EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED" + EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED = "EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED" + EVENT_TYPE_WORKFLOW_EXECUTION_CONTINUED_AS_NEW = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_CONTINUED_AS_NEW" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_FAILED = "EVENT_TYPE_WORKFLOW_EXECUTION_FAILED" + EVENT_TYPE_WORKFLOW_EXECUTION_OPTIONS_UPDATED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_OPTIONS_UPDATED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_PAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_PAUSED" + EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED = "EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED" + EVENT_TYPE_WORKFLOW_EXECUTION_STARTED = "EVENT_TYPE_WORKFLOW_EXECUTION_STARTED" + EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT = "EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT" + EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED" + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ADMITTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ADMITTED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_COMPLETED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_COMPLETED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_REJECTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_REJECTED" + ) + EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED = "EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED" + EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED_EXTERNALLY = ( + "EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED_EXTERNALLY" + ) + EVENT_TYPE_WORKFLOW_TASK_COMPLETED = "EVENT_TYPE_WORKFLOW_TASK_COMPLETED" + EVENT_TYPE_WORKFLOW_TASK_FAILED = "EVENT_TYPE_WORKFLOW_TASK_FAILED" + EVENT_TYPE_WORKFLOW_TASK_SCHEDULED = "EVENT_TYPE_WORKFLOW_TASK_SCHEDULED" + EVENT_TYPE_WORKFLOW_TASK_STARTED = "EVENT_TYPE_WORKFLOW_TASK_STARTED" + EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT = "EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT" + + +class EventRef(BaseModel): + """EventReference is a direct reference to a history event through the event ID.""" + + event_id: Optional[str] = Field(None, alias="eventId") + event_type: Optional[EventType] = Field(None, alias="eventType") + + +class RequestIDRef(BaseModel): + """RequestIdReference is a indirect reference to a history event through the request ID.""" + + event_type: Optional[EventType] = Field(None, alias="eventType") + request_id: Optional[str] = Field(None, alias="requestId") + + +class WorkflowEvent(BaseModel): + event_ref: Optional[EventRef] = Field(None, alias="eventRef") + namespace: Optional[str] = None + request_id_ref: Optional[RequestIDRef] = Field(None, alias="requestIdRef") + run_id: Optional[str] = Field(None, alias="runId") + workflow_id: Optional[str] = Field(None, alias="workflowId") + + +class Openapiv3(BaseModel): + """Link can be associated with history events. It might contain information about an + external entity + related to the history event. For example, workflow A makes a Nexus call that starts + workflow B: + in this case, a history event in workflow A could contain a Link to the workflow started + event in + workflow B, and vice-versa. + """ + + batch_job: Optional[BatchJob] = Field(None, alias="batchJob") + workflow_event: Optional[WorkflowEvent] = Field(None, alias="workflowEvent") + + +class Memo(BaseModel): + """A user-defined set of *unindexed* fields that are exposed when listing/searching workflows""" + + fields: Optional[Dict[str, Any]] = None + + +class Priority(BaseModel): + """Priority metadata + + Priority contains metadata that controls relative ordering of task processing + when tasks are backed up in a queue. Initially, Priority will be used in + matching (workflow and activity) task queues. Later it may be used in history + task queues and in rate limiting decisions. + + Priority is attached to workflows and activities. By default, activities + inherit Priority from the workflow that created them, but may override fields + when an activity is started or modified. + + Despite being named "Priority", this message also contains fields that + control "fairness" mechanisms. + + For all fields, the field not present or equal to zero/empty string means to + inherit the value from the calling workflow, or if there is no calling + workflow, then use the default value. + + For all fields other than fairness_key, the zero value isn't meaningful so + there's no confusion between inherit/default and a meaningful value. For + fairness_key, the empty string will be interpreted as "inherit". This means + that if a workflow has a non-empty fairness key, you can't override the + fairness key of its activity to the empty string. + + The overall semantics of Priority are: + 1. First, consider "priority": higher priority (lower number) goes first. + 2. Then, consider fairness: try to dispatch tasks for different fairness keys + in proportion to their weight. + + Applications may use any subset of mechanisms that are useful to them and + leave the other fields to use default values. + + Not all queues in the system may support the "full" semantics of all priority + fields. (Currently only support in matching task queues is planned.) + """ + + fairness_key: Optional[str] = Field(None, alias="fairnessKey") + """Fairness key is a short string that's used as a key for a fairness + balancing mechanism. It may correspond to a tenant id, or to a fixed + string like "high" or "low". The default is the empty string. + + The fairness mechanism attempts to dispatch tasks for a given key in + proportion to its weight. For example, using a thousand distinct tenant + ids, each with a weight of 1.0 (the default) will result in each tenant + getting a roughly equal share of task dispatch throughput. + + (Note: this does not imply equal share of worker capacity! Fairness + decisions are made based on queue statistics, not + current worker load.) + + As another example, using keys "high" and "low" with weight 9.0 and 1.0 + respectively will prefer dispatching "high" tasks over "low" tasks at a + 9:1 ratio, while allowing either key to use all worker capacity if the + other is not present. + + All fairness mechanisms, including rate limits, are best-effort and + probabilistic. The results may not match what a "perfect" algorithm with + infinite resources would produce. The more unique keys are used, the less + accurate the results will be. + + Fairness keys are limited to 64 bytes. + """ + fairness_weight: Optional[float] = Field(None, alias="fairnessWeight") + """Fairness weight for a task can come from multiple sources for + flexibility. From highest to lowest precedence: + 1. Weights for a small set of keys can be overridden in task queue + configuration with an API. + 2. It can be attached to the workflow/activity in this field. + 3. The default weight of 1.0 will be used. + + Weight values are clamped to the range [0.001, 1000]. + """ + priority_key: Optional[int] = Field(None, alias="priorityKey") + """Priority key is a positive integer from 1 to n, where smaller integers + correspond to higher priorities (tasks run sooner). In general, tasks in + a queue should be processed in close to priority order, although small + deviations are possible. + + The maximum priority value (minimum priority) is determined by server + configuration, and defaults to 5. + + If priority is not present (or zero), then the effective priority will be + the default priority, which is calculated by (min+max)/2. With the + default max of 5, and min of 1, that comes out to 3. + """ + + +class RetryPolicy(BaseModel): + """Retry policy for the workflow + + How retries ought to be handled, usable by both workflows and activities + """ + + backoff_coefficient: Optional[float] = Field(None, alias="backoffCoefficient") + """Coefficient used to calculate the next retry interval. + The next retry interval is previous interval multiplied by the coefficient. + Must be 1 or larger. + """ + initial_interval: Optional[str] = Field(None, alias="initialInterval") + """Interval of the first retry. If retryBackoffCoefficient is 1.0 then it is used for all + retries. + """ + maximum_attempts: Optional[int] = Field(None, alias="maximumAttempts") + """Maximum number of attempts. When exceeded the retries stop even if not expired yet. + 1 disables retries. 0 means unlimited (up to the timeouts) + """ + maximum_interval: Optional[str] = Field(None, alias="maximumInterval") + """Maximum interval between retries. Exponential backoff leads to interval increase. + This value is the cap of the increase. Default is 100x of the initial interval. + """ + non_retryable_error_types: Optional[List[str]] = Field( + None, alias="nonRetryableErrorTypes" + ) + """Non-Retryable errors types. Will stop retrying if the error type matches this list. Note + that + this is not a substring match, the error *type* (not message) must match exactly. + """ + + +class SearchAttributes(BaseModel): + """A user-defined set of *indexed* fields that are used/exposed when listing/searching + workflows. + The payload is not serialized in a user-defined way. + """ + + indexed_fields: Optional[Dict[str, Any]] = Field(None, alias="indexedFields") + + +class Kind(Enum): + """Default: TASK_QUEUE_KIND_NORMAL.""" + + TASK_QUEUE_KIND_NORMAL = "TASK_QUEUE_KIND_NORMAL" + TASK_QUEUE_KIND_STICKY = "TASK_QUEUE_KIND_STICKY" + TASK_QUEUE_KIND_UNSPECIFIED = "TASK_QUEUE_KIND_UNSPECIFIED" + + +class TaskQueue(BaseModel): + """The task queue to start this workflow on, if it will be started + + See https://docs.temporal.io/docs/concepts/task-queues/ + """ + + kind: Optional[Kind] = None + """Default: TASK_QUEUE_KIND_NORMAL.""" + + name: Optional[str] = None + normal_name: Optional[str] = Field(None, alias="normalName") + """Iff kind == TASK_QUEUE_KIND_STICKY, then this field contains the name of + the normal task queue that the sticky worker is running on. + """ + + +class TimeSkippingConfig(BaseModel): + """Time-skipping configuration. If not set, time skipping is disabled. + + Configuration for time skipping during a workflow execution. + When enabled, virtual time advances automatically whenever there is no in-flight work. + In-flight work includes activities, child workflows, Nexus operations, signal/cancel + external workflow operations, + and possibly other features added in the future. + User timers are not classified as in-flight work and will be skipped over. + When time advances, it skips to the earlier of the next user timer or the configured + bound, if either exists. + """ + + disable_propagation: Optional[bool] = Field(None, alias="disablePropagation") + """If set, the enabled field is not propagated to transitively related workflows.""" + + enabled: Optional[bool] = None + """Enables or disables time skipping for this workflow execution. + By default, this field is propagated to transitively related workflows (child + workflows/start-as-new/reset) + at the time they are started. + Changes made after a transitively related workflow has started are not propagated. + """ + max_elapsed_duration: Optional[str] = Field(None, alias="maxElapsedDuration") + """Maximum elapsed time since time skipping was enabled. + This includes both skipped time and real time elapsing. + """ + max_skipped_duration: Optional[str] = Field(None, alias="maxSkippedDuration") + """Maximum total virtual time that can be skipped.""" + + max_target_time: Optional[datetime] = Field(None, alias="maxTargetTime") + """Absolute virtual timestamp at which time skipping is disabled. + Time skipping will not advance beyond this point. + """ + + +class UserMetadata(BaseModel): + """Metadata on the workflow if it is started. This is carried over to the + WorkflowExecutionInfo + for use by user interfaces to display the fixed as-of-start summary and details of the + workflow. + + Information a user can set, often for use by user interfaces. + """ + + details: Any + """Long-form text that provides details. This payload should be a "json/plain"-encoded + payload + that is a single JSON string for use in user interfaces. User interface formatting may + apply to + this text in common use. The payload data section is limited to 20000 bytes by default. + """ + summary: Any + """Short-form text that provides a summary. This payload should be a "json/plain"-encoded + payload + that is a single JSON string for use in user interfaces. User interface formatting may + not + apply to this text when used in "title" situations. The payload data section is limited + to 400 + bytes by default. + """ + + +class VersioningOverrideBehavior(Enum): + """Required. + Deprecated. Use `override`. + """ + + VERSIONING_BEHAVIOR_AUTO_UPGRADE = "VERSIONING_BEHAVIOR_AUTO_UPGRADE" + VERSIONING_BEHAVIOR_PINNED = "VERSIONING_BEHAVIOR_PINNED" + VERSIONING_BEHAVIOR_UNSPECIFIED = "VERSIONING_BEHAVIOR_UNSPECIFIED" + + +class Deployment(BaseModel): + """Required if behavior is `PINNED`. Must be null if behavior is `AUTO_UPGRADE`. + Identifies the worker deployment to pin the workflow to. + Deprecated. Use `override.pinned.version`. + + `Deployment` identifies a deployment of Temporal workers. The combination of deployment + series + name + build ID serves as the identifier. User can use `WorkerDeploymentOptions` in their + worker + programs to specify these values. + Deprecated. + """ + + build_id: Optional[str] = Field(None, alias="buildId") + """Build ID changes with each version of the worker when the worker program code and/or + config + changes. + """ + series_name: Optional[str] = Field(None, alias="seriesName") + """Different versions of the same worker service/application are related together by having + a + shared series name. + Out of all deployments of a series, one can be designated as the current deployment, + which + receives new workflow executions and new tasks of workflows with + `VERSIONING_BEHAVIOR_AUTO_UPGRADE` versioning behavior. + """ + + +class PinnedBehavior(Enum): + """Defaults to PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED. + See `PinnedOverrideBehavior` for details. + """ + + PINNED_OVERRIDE_BEHAVIOR_PINNED = "PINNED_OVERRIDE_BEHAVIOR_PINNED" + PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED = "PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED" + + +class Version(BaseModel): + """Specifies the Worker Deployment Version to pin this workflow to. + Required if the target workflow is not already pinned to a version. + + If omitted and the target workflow is already pinned, the effective + pinned version will be the existing pinned version. + + If omitted and the target workflow is not pinned, the override request + will be rejected with a PreconditionFailed error. + + A Worker Deployment Version (Version, for short) represents a + version of workers within a Worker Deployment. (see documentation of + WorkerDeploymentVersionInfo) + Version records are created in Temporal server automatically when their + first poller arrives to the server. + Experimental. Worker Deployment Versions are experimental and might significantly change + in the future. + """ + + build_id: Optional[str] = Field(None, alias="buildId") + """A unique identifier for this Version within the Deployment it is a part of. + Not necessarily unique within the namespace. + The combination of `deployment_name` and `build_id` uniquely identifies this + Version within the namespace, because Deployment names are unique within a namespace. + """ + deployment_name: Optional[str] = Field(None, alias="deploymentName") + """Identifies the Worker Deployment this Version is part of.""" + + +class Pinned(BaseModel): + """Override the workflow to have Pinned behavior.""" + + behavior: Optional[PinnedBehavior] = None + """Defaults to PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED. + See `PinnedOverrideBehavior` for details. + """ + version: Optional[Version] = None + """Specifies the Worker Deployment Version to pin this workflow to. + Required if the target workflow is not already pinned to a version. + + If omitted and the target workflow is already pinned, the effective + pinned version will be the existing pinned version. + + If omitted and the target workflow is not pinned, the override request + will be rejected with a PreconditionFailed error. + """ + + +class VersioningOverride(BaseModel): + """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task + completion. + To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. + + Used to override the versioning behavior (and pinned deployment version, if applicable) + of a + specific workflow execution. If set, this override takes precedence over worker-sent + values. + See `WorkflowExecutionInfo.VersioningInfo` for more information. + + To remove the override, call `UpdateWorkflowExecutionOptions` with a null + `VersioningOverride`, and use the `update_mask` to indicate that it should be mutated. + + Pinned behavior overrides are automatically inherited by child workflows, workflow + retries, continue-as-new + workflows, and cron workflows. + """ + + auto_upgrade: Optional[bool] = Field(None, alias="autoUpgrade") + """Override the workflow to have AutoUpgrade behavior.""" + + behavior: Optional[VersioningOverrideBehavior] = None + """Required. + Deprecated. Use `override`. + """ + deployment: Optional[Deployment] = None + """Required if behavior is `PINNED`. Must be null if behavior is `AUTO_UPGRADE`. + Identifies the worker deployment to pin the workflow to. + Deprecated. Use `override.pinned.version`. + """ + pinned: Optional[Pinned] = None + """Override the workflow to have Pinned behavior.""" + + pinned_version: Optional[str] = Field(None, alias="pinnedVersion") + """Required if behavior is `PINNED`. Must be absent if behavior is not `PINNED`. + Identifies the worker deployment version to pin the workflow to, in the format + ".". + Deprecated. Use `override.pinned.version`. + """ + + +class WorkflowIDConflictPolicy(Enum): + """Defines how to resolve a workflow id conflict with a *running* workflow. + The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. + Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* + workflow. + """ + + WORKFLOW_ID_CONFLICT_POLICY_FAIL = "WORKFLOW_ID_CONFLICT_POLICY_FAIL" + WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING = ( + "WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING" + ) + WORKFLOW_ID_CONFLICT_POLICY_UNSPECIFIED = "WORKFLOW_ID_CONFLICT_POLICY_UNSPECIFIED" + WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING = ( + "WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING" + ) + + +class WorkflowIDReusePolicy(Enum): + """Defines whether to allow re-using the workflow id from a previously *closed* workflow. + The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *running* + workflow. + """ + + WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE = ( + "WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE" + ) + WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY = ( + "WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY" + ) + WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE = ( + "WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE" + ) + WORKFLOW_ID_REUSE_POLICY_TERMINATE_IF_RUNNING = ( + "WORKFLOW_ID_REUSE_POLICY_TERMINATE_IF_RUNNING" + ) + WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED = "WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED" + + +class WorkflowType(BaseModel): + """Represents the identifier used by a workflow author to define the workflow. Typically, + the + name of a function. This is sometimes referred to as the workflow's "name" + """ + + name: Optional[str] = None + + +class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): + control: Optional[str] = None + """Deprecated.""" + + cron_schedule: Optional[str] = Field(None, alias="cronSchedule") + """See https://docs.temporal.io/docs/content/what-is-a-temporal-cron-job/""" + + header: Optional[Header] = None + identity: Optional[str] = None + """The identity of the worker/client""" + + input: Optional[Input] = None + """Serialized arguments to the workflow. These are passed as arguments to the workflow + function. + """ + links: Optional[List[Openapiv3]] = None + """Links to be associated with the WorkflowExecutionStarted and WorkflowExecutionSignaled + events. + """ + memo: Optional[Memo] = None + namespace: Optional[str] = None + priority: Optional[Priority] = None + """Priority metadata""" + + request_id: Optional[str] = Field(None, alias="requestId") + """Used to de-dupe signal w/ start requests""" + + retry_policy: Optional[RetryPolicy] = Field(None, alias="retryPolicy") + """Retry policy for the workflow""" + + search_attributes: Optional[SearchAttributes] = Field( + None, alias="searchAttributes" + ) + signal_input: Optional[Input] = Field(None, alias="signalInput") + """Serialized value(s) to provide with the signal""" + + signal_name: Optional[str] = Field(None, alias="signalName") + """The workflow author-defined name of the signal to send to the workflow""" + + task_queue: Optional[TaskQueue] = Field(None, alias="taskQueue") + """The task queue to start this workflow on, if it will be started""" + + time_skipping_config: Optional[TimeSkippingConfig] = Field( + None, alias="timeSkippingConfig" + ) + """Time-skipping configuration. If not set, time skipping is disabled.""" + + user_metadata: Optional[UserMetadata] = Field(None, alias="userMetadata") + """Metadata on the workflow if it is started. This is carried over to the + WorkflowExecutionInfo + for use by user interfaces to display the fixed as-of-start summary and details of the + workflow. + """ + versioning_override: Optional[VersioningOverride] = Field( + None, alias="versioningOverride" + ) + """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task + completion. + To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. + """ + workflow_execution_timeout: Optional[str] = Field( + None, alias="workflowExecutionTimeout" + ) + """Total workflow execution timeout including retries and continue as new""" + + workflow_id: Optional[str] = Field(None, alias="workflowId") + workflow_id_conflict_policy: Optional[WorkflowIDConflictPolicy] = Field( + None, alias="workflowIdConflictPolicy" + ) + """Defines how to resolve a workflow id conflict with a *running* workflow. + The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. + Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* + workflow. + """ + workflow_id_reuse_policy: Optional[WorkflowIDReusePolicy] = Field( + None, alias="workflowIdReusePolicy" + ) + """Defines whether to allow re-using the workflow id from a previously *closed* workflow. + The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *running* + workflow. + """ + workflow_run_timeout: Optional[str] = Field(None, alias="workflowRunTimeout") + """Timeout of a single workflow run""" + + workflow_start_delay: Optional[str] = Field(None, alias="workflowStartDelay") + """Time to wait before dispatching the first workflow task. Cannot be used with + `cron_schedule`. + Note that the signal will be delivered with the first workflow task. If the workflow + gets + another SignalWithStartWorkflow before the delay a workflow task will be dispatched + immediately + and the rest of the delay period will be ignored, even if that request also had a delay. + Signal via SignalWorkflowExecution will not unblock the workflow. + """ + workflow_task_timeout: Optional[str] = Field(None, alias="workflowTaskTimeout") + """Timeout of a single workflow task""" + + workflow_type: Optional[WorkflowType] = Field(None, alias="workflowType") + + +class WorkflowServiceSignalWithStartWorkflowExecutionOutput(BaseModel): + run_id: Optional[str] = Field(None, alias="runId") + """The run id of the workflow that was started - or just signaled, if it was already running.""" + + started: Optional[bool] = None + """If true, a new workflow was started.""" + + +async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input( + payload: temporalio.api.common.v1.Payload, + payload_codec: temporalio.converter.PayloadCodec | None, +) -> temporalio.api.common.v1.Payload: + if payload_codec is None: + return payload + try: + value = json.loads(payload.data) + except json.JSONDecodeError: + return payload + rewritten = await _temporal_nexus_encode_json_value(value, payload_codec) + return temporalio.api.common.v1.Payload( + metadata=dict(payload.metadata), + data=json.dumps(rewritten, separators=(",", ":"), sort_keys=True).encode(), + ) + + +__temporal_nexus_payload_codec_rewriters__ = { + ( + "WorkflowService", + "SignalWithStartWorkflowExecution", + ): _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input, +} + + +@service +class WorkflowService: + signal_with_start_workflow_execution: Operation[ + WorkflowServiceSignalWithStartWorkflowExecutionInput, + WorkflowServiceSignalWithStartWorkflowExecutionOutput, + ] = Operation(name="SignalWithStartWorkflowExecution") diff --git a/temporalio/worker/_command_aware_visitor.py b/temporalio/worker/_command_aware_visitor.py index 2d7f3990b..85c38ff06 100644 --- a/temporalio/worker/_command_aware_visitor.py +++ b/temporalio/worker/_command_aware_visitor.py @@ -31,6 +31,8 @@ class CommandInfo: command_type: CommandType.ValueType command_seq: int + nexus_service: str | None = None + nexus_operation: str | None = None current_command_info: contextvars.ContextVar[CommandInfo | None] = ( @@ -81,7 +83,12 @@ async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( async def _visit_coresdk_workflow_commands_ScheduleNexusOperation( self, fs: VisitorFunctions, o: ScheduleNexusOperation ) -> None: - with current_command(CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, o.seq): + with current_command( + CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, + o.seq, + nexus_service=o.service, + nexus_operation=o.operation, + ): await super()._visit_coresdk_workflow_commands_ScheduleNexusOperation(fs, o) # Workflow activation jobs with payloads @@ -150,11 +157,20 @@ async def _visit_coresdk_workflow_activation_ResolveNexusOperation( @contextmanager def current_command( - command_type: CommandType.ValueType, command_seq: int + command_type: CommandType.ValueType, + command_seq: int, + *, + nexus_service: str | None = None, + nexus_operation: str | None = None, ) -> Iterator[None]: """Context manager for setting command info.""" token = current_command_info.set( - CommandInfo(command_type=command_type, command_seq=command_seq) + CommandInfo( + command_type=command_type, + command_seq=command_seq, + nexus_service=nexus_service, + nexus_operation=nexus_operation, + ) ) try: yield diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py new file mode 100644 index 000000000..f9dd6b7a9 --- /dev/null +++ b/tests/nexus/test_temporal_system_nexus.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import dataclasses +import json +import uuid +from collections.abc import Sequence +from typing import cast + +import nexusrpc.handler +import pytest +from google.protobuf.json_format import MessageToDict + +import temporalio.api.common.v1 +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.converter import PayloadCodec +from temporalio.nexus.system import ( + WorkflowService, + WorkflowServiceSignalWithStartWorkflowExecutionInput, + WorkflowServiceSignalWithStartWorkflowExecutionOutput, +) +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner +from tests.helpers.nexus import make_nexus_endpoint_name + + +@nexusrpc.handler.service_handler(service=WorkflowService) +class WorkflowServicePayloadHandler: + @nexusrpc.handler.sync_operation + async def signal_with_start_workflow_execution( + self, + _ctx: nexusrpc.handler.StartOperationContext, + request: WorkflowServiceSignalWithStartWorkflowExecutionInput, + ) -> WorkflowServiceSignalWithStartWorkflowExecutionOutput: + for field_name in ("input", "signalInput"): + payloads = request.model_dump(by_alias=True)[field_name]["payloads"] + assert "test-codec" in payloads[0]["metadata"] + for field_name in ("memo", "header"): + fields = request.model_dump(by_alias=True)[field_name]["fields"] + assert "test-codec" in next(iter(fields.values()))["metadata"] + return WorkflowServiceSignalWithStartWorkflowExecutionOutput( + runId=f"{request.workflow_id}-run" + ) + + +@workflow.defn +class SystemNexusCallerWithPayloadsWorkflow: + @workflow.run + async def run(self, task_queue: str) -> str: + nexus_client = workflow.create_nexus_client( + service=WorkflowService, + endpoint=make_nexus_endpoint_name(task_queue), + ) + request = WorkflowServiceSignalWithStartWorkflowExecutionInput.model_validate( + { + "namespace": "default", + "workflowId": "system-nexus-workflow-id", + "signalName": "test-signal", + "input": MessageToDict( + temporalio.api.common.v1.Payloads( + payloads=[ + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"workflow-input"', + ) + ] + ) + ), + "signalInput": MessageToDict( + temporalio.api.common.v1.Payloads( + payloads=[ + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"signal-input"', + ) + ] + ) + ), + "memo": MessageToDict( + temporalio.api.common.v1.Memo( + fields={ + "memo-key": temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"memo-value"', + ) + } + ) + ), + "header": MessageToDict( + temporalio.api.common.v1.Header( + fields={ + "header-key": temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"header-value"', + ) + } + ) + ), + } + ) + handle = await nexus_client.start_operation( + WorkflowService.signal_with_start_workflow_execution, + request, + ) + result = await handle + return cast(str, result.run_id) + + +class RejectOuterSystemNexusCodec(PayloadCodec): + def __init__(self) -> None: + self.encode_count = 0 + + async def encode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + encoded: list[temporalio.api.common.v1.Payload] = [] + for payload in payloads: + try: + body = json.loads(payload.data) + except json.JSONDecodeError: + body = None + if isinstance(body, dict) and { + "namespace", + "workflowId", + "signalName", + }.issubset(body): + raise RuntimeError( + "outer system nexus envelope should not be codec encoded" + ) + self.encode_count += 1 + encoded.append( + temporalio.api.common.v1.Payload( + metadata={**payload.metadata, "test-codec": b"true"}, + data=payload.data, + ) + ) + return encoded + + async def decode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + return list(payloads) + + +async def test_workflow_service_signal_with_start_nested_payloads_use_codec_without_encoding_outer_envelope( + env: WorkflowEnvironment, +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with the Java test server") + + codec = RejectOuterSystemNexusCodec() + config = env.client.config() + config["data_converter"] = dataclasses.replace( + pydantic_data_converter, + payload_codec=codec, + ) + client = Client(**config) + + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[SystemNexusCallerWithPayloadsWorkflow], + nexus_service_handlers=[WorkflowServicePayloadHandler()], + workflow_runner=UnsandboxedWorkflowRunner(), + ) as worker: + endpoint_name = make_nexus_endpoint_name(worker.task_queue) + await env.create_nexus_endpoint(endpoint_name, worker.task_queue) + result = await client.execute_workflow( + SystemNexusCallerWithPayloadsWorkflow.run, + worker.task_queue, + id=str(uuid.uuid4()), + task_queue=worker.task_queue, + ) + + assert result == "system-nexus-workflow-id-run" + assert codec.encode_count >= 4 From f934a822d0f111333db7cf97c2c7ea6c591a092e Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 8 Apr 2026 12:31:09 -0700 Subject: [PATCH 02/18] Make Nexus system model generation repo-local --- README.md | 7 ++ pyproject.toml | 10 +- scripts/gen_nexus_system_models.py | 118 ++++++++++++++++++ scripts/gen_nexus_system_test_models.py | 75 ----------- .../system/_workflow_service.nexusrpc.yaml | 9 ++ .../system/_workflow_service_generated.py | 45 ------- 6 files changed, 139 insertions(+), 125 deletions(-) create mode 100644 scripts/gen_nexus_system_models.py delete mode 100644 scripts/gen_nexus_system_test_models.py create mode 100644 temporalio/nexus/system/_workflow_service.nexusrpc.yaml diff --git a/README.md b/README.md index 6e42a2019..d1e162ab9 100644 --- a/README.md +++ b/README.md @@ -1933,6 +1933,8 @@ To build the SDK from source for use as a dependency, the following prerequisite * [uv](https://docs.astral.sh/uv/) * [Rust](https://www.rust-lang.org/) * [Protobuf Compiler](https://protobuf.dev/) +* [Node.js](https://nodejs.org/) +* [`pnpm`](https://pnpm.io/) Use `uv` to install `poe`: @@ -2074,6 +2076,11 @@ back from this downgrade, restore both of those files and run `uv sync --all-ext run for protobuf version 3 by setting the `TEMPORAL_TEST_PROTO3` env var to `1` prior to running tests. +The local build and lint flows also regenerate Temporal system Nexus models. By default this pulls +in `nexus-rpc-gen@0.1.0-alpha.4` via `npx`. To use an existing checkout instead, set +`TEMPORAL_NEXUS_RPC_GEN_DIR` to the `nexus-rpc-gen` repo root or its `src` directory before +running `poe build-develop`, `poe lint`, or `poe gen-protos`. + ### Style * Mostly [Google Style Guide](https://google.github.io/styleguide/pyguide.html). Notable exceptions: diff --git a/pyproject.toml b/pyproject.toml index 3c810ccc9..440ed4195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,11 +68,11 @@ dev = [ [tool.poe.tasks] build-develop = [ - { ref = "gen-nexus-system-test-models" }, + { ref = "gen-nexus-system-models" }, { cmd = "uv run maturin develop --uv" }, ] build-develop-with-release = [ - { ref = "gen-nexus-system-test-models" }, + { ref = "gen-nexus-system-models" }, { cmd = "uv run maturin develop --release --uv" }, ] format = [ @@ -85,7 +85,7 @@ gen-protos = [ { cmd = "uv run scripts/gen_protos.py" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, - { ref = "gen-nexus-system-test-models" }, + { ref = "gen-nexus-system-models" }, { ref = "format" }, ] gen-protos-docker = [ @@ -105,12 +105,12 @@ bridge-lint = { cmd = "cargo clippy -- -D warnings", cwd = "temporalio/bridge" } # https://github.com/PyCQA/pydocstyle/pull/511? lint-docs = "uv run pydocstyle --ignore-decorators=overload" lint-types = [ - { ref = "gen-nexus-system-test-models" }, + { ref = "gen-nexus-system-models" }, { cmd = "uv run pyright" }, { cmd = "uv run mypy --namespace-packages --check-untyped-defs ." }, { cmd = "uv run basedpyright" }, ] -gen-nexus-system-test-models = "uv run scripts/gen_nexus_system_test_models.py" +gen-nexus-system-models = "uv run scripts/gen_nexus_system_models.py" run-bench = "uv run python scripts/run_bench.py" test = "uv run pytest" diff --git a/scripts/gen_nexus_system_models.py b/scripts/gen_nexus_system_models.py new file mode 100644 index 000000000..d90ebdf8f --- /dev/null +++ b/scripts/gen_nexus_system_models.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import os +import subprocess +import sys +from pathlib import Path + +NEXUS_RPC_GEN_ENV_VAR = "TEMPORAL_NEXUS_RPC_GEN_DIR" +NEXUS_RPC_GEN_VERSION = "0.1.0-alpha.4" + + +def main() -> None: + repo_root = Path(__file__).resolve().parent.parent + override_root = normalize_nexus_rpc_gen_root( + Path.cwd(), env_value=NEXUS_RPC_GEN_ENV_VAR + ) + input_schema = ( + repo_root + / "temporalio" + / "nexus" + / "system" + / "_workflow_service.nexusrpc.yaml" + ) + output_file = ( + repo_root / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" + ) + + if not input_schema.is_file(): + raise RuntimeError(f"Expected Nexus schema at {input_schema}") + + run_nexus_rpc_gen( + override_root=override_root, + output_file=output_file, + input_schema=input_schema, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "check", + "--select", + "I", + "--fix", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "format", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + + +def run_nexus_rpc_gen( + *, override_root: Path | None, output_file: Path, input_schema: Path +) -> None: + common_args = [ + "--lang", + "py", + "--out-file", + str(output_file), + "--temporal-nexus-payload-codec-support", + str(input_schema), + ] + if override_root is None: + subprocess.run( + ["npx", "--yes", f"nexus-rpc-gen@{NEXUS_RPC_GEN_VERSION}", *common_args], + check=True, + ) + return + + subprocess.run( + [ + "node", + "packages/nexus-rpc-gen/dist/index.js", + *common_args, + ], + cwd=override_root, + check=True, + ) + + +def normalize_nexus_rpc_gen_root(base_dir: Path, env_value: str) -> Path | None: + raw_root = env_get(env_value) + if raw_root is None: + return None + candidate = Path(raw_root) + if not candidate.is_absolute(): + candidate = base_dir / candidate + candidate = candidate.resolve() + if (candidate / "package.json").is_file() and (candidate / "packages").is_dir(): + return candidate + if (candidate / "src" / "package.json").is_file(): + return candidate / "src" + raise RuntimeError( + f"{NEXUS_RPC_GEN_ENV_VAR} must point to the nexus-rpc-gen repo root or its src directory" + ) + + +def env_get(name: str) -> str | None: + return os.environ.get(name) + + +if __name__ == "__main__": + try: + main() + except Exception as err: + print(f"Failed to generate Nexus system models: {err}", file=sys.stderr) + raise diff --git a/scripts/gen_nexus_system_test_models.py b/scripts/gen_nexus_system_test_models.py deleted file mode 100644 index d707727a5..000000000 --- a/scripts/gen_nexus_system_test_models.py +++ /dev/null @@ -1,75 +0,0 @@ -from __future__ import annotations - -import subprocess -import sys -from pathlib import Path - - -def main() -> None: - repo_root = Path(__file__).resolve().parent.parent - workspace_root = repo_root.parent - nexus_rpc_gen_root = workspace_root / "nexus-rpc-gen" / "src" - input_schema = ( - workspace_root - / "temporal-api" - / "nexus" - / "temporal-json-schema-models-nexusrpc.yaml" - ) - output_file = ( - repo_root / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" - ) - - if not nexus_rpc_gen_root.is_dir(): - raise RuntimeError(f"Expected nexus-rpc-gen checkout at {nexus_rpc_gen_root}") - if not input_schema.is_file(): - raise RuntimeError(f"Expected Temporal Nexus schema at {input_schema}") - - subprocess.run( - [ - "npm", - "run", - "cli", - "--", - "--lang", - "py", - "--out-file", - str(output_file), - "--temporal-nexus-payload-codec-support", - str(input_schema), - ], - cwd=nexus_rpc_gen_root, - check=True, - ) - subprocess.run( - [ - "uv", - "run", - "ruff", - "check", - "--select", - "I", - "--fix", - str(output_file), - ], - cwd=repo_root, - check=True, - ) - subprocess.run( - [ - "uv", - "run", - "ruff", - "format", - str(output_file), - ], - cwd=repo_root, - check=True, - ) - - -if __name__ == "__main__": - try: - main() - except Exception as err: - print(f"Failed to generate Nexus system test models: {err}", file=sys.stderr) - raise diff --git a/temporalio/nexus/system/_workflow_service.nexusrpc.yaml b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml new file mode 100644 index 000000000..c5b8cc671 --- /dev/null +++ b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml @@ -0,0 +1,9 @@ +nexusrpc: 1.0.0 +services: + WorkflowService: + operations: + SignalWithStartWorkflowExecution: + input: + $ref: ../../bridge/sdk-core/crates/common/protos/api_upstream/openapi/openapiv3.yaml#/components/schemas/SignalWithStartWorkflowExecutionRequest + output: + $ref: ../../bridge/sdk-core/crates/common/protos/api_upstream/openapi/openapiv3.yaml#/components/schemas/SignalWithStartWorkflowExecutionResponse diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index 1723631a1..085dad9b1 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -3,7 +3,6 @@ from __future__ import annotations import json -from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional @@ -202,9 +201,6 @@ class EventType(Enum): "EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED" ) EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT = "EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT" - EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED" - ) EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED" EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED = ( "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED" @@ -424,42 +420,6 @@ class TaskQueue(BaseModel): """ -class TimeSkippingConfig(BaseModel): - """Time-skipping configuration. If not set, time skipping is disabled. - - Configuration for time skipping during a workflow execution. - When enabled, virtual time advances automatically whenever there is no in-flight work. - In-flight work includes activities, child workflows, Nexus operations, signal/cancel - external workflow operations, - and possibly other features added in the future. - User timers are not classified as in-flight work and will be skipped over. - When time advances, it skips to the earlier of the next user timer or the configured - bound, if either exists. - """ - - disable_propagation: Optional[bool] = Field(None, alias="disablePropagation") - """If set, the enabled field is not propagated to transitively related workflows.""" - - enabled: Optional[bool] = None - """Enables or disables time skipping for this workflow execution. - By default, this field is propagated to transitively related workflows (child - workflows/start-as-new/reset) - at the time they are started. - Changes made after a transitively related workflow has started are not propagated. - """ - max_elapsed_duration: Optional[str] = Field(None, alias="maxElapsedDuration") - """Maximum elapsed time since time skipping was enabled. - This includes both skipped time and real time elapsing. - """ - max_skipped_duration: Optional[str] = Field(None, alias="maxSkippedDuration") - """Maximum total virtual time that can be skipped.""" - - max_target_time: Optional[datetime] = Field(None, alias="maxTargetTime") - """Absolute virtual timestamp at which time skipping is disabled. - Time skipping will not advance beyond this point. - """ - - class UserMetadata(BaseModel): """Metadata on the workflow if it is started. This is carried over to the WorkflowExecutionInfo @@ -718,11 +678,6 @@ class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): task_queue: Optional[TaskQueue] = Field(None, alias="taskQueue") """The task queue to start this workflow on, if it will be started""" - time_skipping_config: Optional[TimeSkippingConfig] = Field( - None, alias="timeSkippingConfig" - ) - """Time-skipping configuration. If not set, time skipping is disabled.""" - user_metadata: Optional[UserMetadata] = Field(None, alias="userMetadata") """Metadata on the workflow if it is started. This is carried over to the WorkflowExecutionInfo From 9d4ba7831199963913fd851004e4ae637d278418 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 8 Apr 2026 12:33:03 -0700 Subject: [PATCH 03/18] Clarify Nexus generator fallback requirements --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d1e162ab9..736f04e8a 100644 --- a/README.md +++ b/README.md @@ -1934,7 +1934,6 @@ To build the SDK from source for use as a dependency, the following prerequisite * [Rust](https://www.rust-lang.org/) * [Protobuf Compiler](https://protobuf.dev/) * [Node.js](https://nodejs.org/) -* [`pnpm`](https://pnpm.io/) Use `uv` to install `poe`: @@ -2079,7 +2078,8 @@ tests. The local build and lint flows also regenerate Temporal system Nexus models. By default this pulls in `nexus-rpc-gen@0.1.0-alpha.4` via `npx`. To use an existing checkout instead, set `TEMPORAL_NEXUS_RPC_GEN_DIR` to the `nexus-rpc-gen` repo root or its `src` directory before -running `poe build-develop`, `poe lint`, or `poe gen-protos`. +running `poe build-develop`, `poe lint`, or `poe gen-protos`. The local checkout override path +also requires [`pnpm`](https://pnpm.io/) to be installed. ### Style From 7960be50e4daf0b2e3313af67b32dd052ed38a31 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 8 Apr 2026 13:54:13 -0700 Subject: [PATCH 04/18] Refine Nexus system payload rewriting --- pyproject.toml | 3 +- scripts/gen_nexus_system_models.py | 2 + temporalio/bridge/worker.py | 78 +++---- temporalio/nexus/system/__init__.py | 23 +- .../system/_workflow_service.nexusrpc.yaml | 2 + .../system/_workflow_service_generated.py | 204 +++++++++++------- tests/nexus/test_temporal_system_nexus.py | 66 +++++- 7 files changed, 251 insertions(+), 127 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 440ed4195..a1f5a6661 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,7 @@ environment = { PATH = "$PATH:$HOME/.cargo/bin", CARGO_NET_GIT_FETCH_WITH_CLI = ignore_missing_imports = true exclude = [ # Ignore generated code + 'build', 'temporalio/api', 'temporalio/bridge/proto', 'temporalio/nexus/system/_workflow_service_generated.py', @@ -156,7 +157,7 @@ exclude = [ [tool.pydocstyle] convention = "google" # https://github.com/PyCQA/pydocstyle/issues/363#issuecomment-625563088 -match_dir = "^(?!(docs|scripts|tests|api|proto|\\.)).*" +match_dir = "^(?!(build|docs|scripts|tests|api|proto|\\.)).*" match = "^(?!_workflow_service_generated\\.py$).*\\.py" add_ignore = [ # We like to wrap at a certain number of chars, even long summary sentences. diff --git a/scripts/gen_nexus_system_models.py b/scripts/gen_nexus_system_models.py index d90ebdf8f..2009a43a7 100644 --- a/scripts/gen_nexus_system_models.py +++ b/scripts/gen_nexus_system_models.py @@ -11,6 +11,8 @@ def main() -> None: repo_root = Path(__file__).resolve().parent.parent + # TODO: Remove the local .nexusrpc.yaml shim once the upstream API repo + # checks in the Nexus definition we can consume directly. override_root = normalize_nexus_rpc_gen_root( Path.cwd(), env_value=NEXUS_RPC_GEN_ENV_VAR ) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 52a54ce9c..587bbaa12 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -285,46 +285,13 @@ class _Visitor(VisitorFunctions): def __init__( self, f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]], - payload_codec: temporalio.converter.PayloadCodec | None = None, ): self._f = f - self._payload_codec = payload_codec async def visit_payload(self, payload: Payload) -> None: - if self._payload_codec: - rewritten_payload = await self._maybe_rewrite_nexus_payload(payload) - if rewritten_payload is not None: - if rewritten_payload is not payload: - payload.CopyFrom(rewritten_payload) - return - new_payload = (await self._f([payload]))[0] - if new_payload is not payload: - payload.CopyFrom(new_payload) - - async def _maybe_rewrite_nexus_payload(self, payload: Payload) -> Payload | None: - command_info = _command_aware_visitor.current_command_info.get() - if ( - command_info is None - or command_info.command_type - != CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION - or not command_info.nexus_service - or not command_info.nexus_operation - ): - return None - - rewrite = temporalio.nexus.system.get_payload_codec_rewriter( - command_info.nexus_service, - command_info.nexus_operation, - ) - if rewrite is None: - return None - - rewritten_payload = await rewrite(payload, self._payload_codec) - if not isinstance(rewritten_payload, Payload): - raise TypeError( - "temporal nexus payload codec rewriter must return a Payload" - ) - return rewritten_payload + rewritten_payload = (await self._f([payload]))[0] + if rewritten_payload is not payload: + payload.CopyFrom(rewritten_payload) async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: if len(payloads) == 0: @@ -336,6 +303,43 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: payloads.extend(new_payloads) +async def _encode_completion_payloads( + data_converter: temporalio.converter.DataConverter, + payloads: Sequence[Payload], +) -> list[Payload]: + if len(payloads) != 1: + return await data_converter._encode_payload_sequence(payloads) + + # A single payload may be the outer envelope for a system Nexus operation. + # In that case we leave the envelope itself unencoded so the server can read + # it, but still route any nested Temporal payloads through normal payload + # processing via the generated operation-specific rewriter. + payload = payloads[0] + command_info = _command_aware_visitor.current_command_info.get() + if ( + command_info is None + or command_info.command_type + != CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION + or not command_info.nexus_service + or not command_info.nexus_operation + ): + return await data_converter._encode_payload_sequence(payloads) + + rewrite = temporalio.nexus.system.get_payload_rewriter( + command_info.nexus_service, + command_info.nexus_operation, + ) + if rewrite is None: + return await data_converter._encode_payload_sequence(payloads) + + rewritten_payload = await rewrite( + payload, + data_converter._encode_payload_sequence, + False, + ) + return [rewritten_payload] + + async def decode_activation( activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, data_converter: temporalio.converter.DataConverter, @@ -356,6 +360,6 @@ async def encode_completion( await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers ).visit( - _Visitor(data_converter._encode_payload_sequence, data_converter.payload_codec), + _Visitor(lambda payloads: _encode_completion_payloads(data_converter, payloads)), completion, ) diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index 7eff3a93d..52a80cda1 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -4,38 +4,41 @@ Higher-level ergonomic APIs may wrap these generated types. """ -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence import temporalio.api.common.v1 -import temporalio.converter from ._workflow_service_generated import ( WorkflowService, WorkflowServiceSignalWithStartWorkflowExecutionInput, WorkflowServiceSignalWithStartWorkflowExecutionOutput, - __temporal_nexus_payload_codec_rewriters__, + __temporal_nexus_payload_rewriters__, ) -TemporalNexusPayloadCodecRewriter = Callable[ +TemporalNexusPayloadRewriter = Callable[ [ temporalio.api.common.v1.Payload, - temporalio.converter.PayloadCodec | None, + Callable[ + [Sequence[temporalio.api.common.v1.Payload]], + Awaitable[list[temporalio.api.common.v1.Payload]], + ], + bool, ], Awaitable[temporalio.api.common.v1.Payload], ] -def get_payload_codec_rewriter( +def get_payload_rewriter( service: str, operation: str, -) -> TemporalNexusPayloadCodecRewriter | None: - """Return the generated payload codec rewriter for a system Nexus operation.""" - return __temporal_nexus_payload_codec_rewriters__.get((service, operation)) +) -> TemporalNexusPayloadRewriter | None: + """Return the generated nested-payload rewriter for a system Nexus operation.""" + return __temporal_nexus_payload_rewriters__.get((service, operation)) __all__ = ( "WorkflowService", "WorkflowServiceSignalWithStartWorkflowExecutionInput", "WorkflowServiceSignalWithStartWorkflowExecutionOutput", - "get_payload_codec_rewriter", + "get_payload_rewriter", ) diff --git a/temporalio/nexus/system/_workflow_service.nexusrpc.yaml b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml index c5b8cc671..edea24b4e 100644 --- a/temporalio/nexus/system/_workflow_service.nexusrpc.yaml +++ b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml @@ -1,3 +1,5 @@ +# TODO: Remove this local shim once the upstream API repo checks in the Nexus +# definition and the generator can consume it directly. nexusrpc: 1.0.0 services: WorkflowService: diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index 085dad9b1..df330d8de 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -2,6 +2,7 @@ from __future__ import annotations +import collections.abc import json from enum import Enum from typing import Any, Dict, List, Optional @@ -11,66 +12,6 @@ from pydantic import BaseModel, Field import temporalio.api.common.v1 -import temporalio.converter - - -async def _temporal_nexus_encode_payload_json( - value: dict, payload_codec: temporalio.converter.PayloadCodec -) -> dict: - payload = ParseDict(value, temporalio.api.common.v1.Payload()) - [encoded_payload] = await payload_codec.encode([payload]) - return MessageToDict(encoded_payload) - - -async def _temporal_nexus_encode_payloads_json( - value: dict, payload_codec: temporalio.converter.PayloadCodec -) -> dict: - payloads = ParseDict(value, temporalio.api.common.v1.Payloads()) - encoded_payloads = await payload_codec.encode(payloads.payloads) - del payloads.payloads[:] - payloads.payloads.extend(encoded_payloads) - return MessageToDict(payloads) - - -async def _temporal_nexus_encode_payload_map_json( - message_type: type, value: dict, payload_codec: temporalio.converter.PayloadCodec -) -> dict: - message = ParseDict(value, message_type()) - keys = list(message.fields.keys()) - encoded_payloads = await payload_codec.encode([message.fields[key] for key in keys]) - for key, encoded_payload in zip(keys, encoded_payloads): - message.fields[key].CopyFrom(encoded_payload) - return MessageToDict(message) - - -async def _temporal_nexus_encode_json_value( - value: object, payload_codec: temporalio.converter.PayloadCodec -) -> object: - if isinstance(value, list): - return [ - await _temporal_nexus_encode_json_value(item, payload_codec) - for item in value - ] - if not isinstance(value, dict): - return value - if "indexedFields" in value: - return value - if "payloads" in value and isinstance(value["payloads"], list): - return await _temporal_nexus_encode_payloads_json(value, payload_codec) - if "fields" in value and isinstance(value["fields"], dict): - return await _temporal_nexus_encode_payload_map_json( - temporalio.api.common.v1.Header, value, payload_codec - ) - if "data" in value and "metadata" in value: - return await _temporal_nexus_encode_payload_json(value, payload_codec) - rewritten: dict[str, object] = {} - for key, item in value.items(): - rewritten[key] = ( - item - if key == "indexedFields" - else await _temporal_nexus_encode_json_value(item, payload_codec) - ) - return rewritten class Header(BaseModel): @@ -743,34 +684,151 @@ class WorkflowServiceSignalWithStartWorkflowExecutionOutput(BaseModel): """If true, a new workflow was started.""" +@service +class WorkflowService: + signal_with_start_workflow_execution: Operation[ + WorkflowServiceSignalWithStartWorkflowExecutionInput, + WorkflowServiceSignalWithStartWorkflowExecutionOutput, + ] = Operation(name="SignalWithStartWorkflowExecution") + + +class _TemporalNexusPayloadRewriter: + def __init__( + self, + payload_visitor: collections.abc.Callable[ + [collections.abc.Sequence[temporalio.api.common.v1.Payload]], + collections.abc.Awaitable[list[temporalio.api.common.v1.Payload]], + ], + visit_search_attributes: bool = False, + ): + self._payload_visitor = payload_visitor + self._visit_search_attributes = visit_search_attributes + + async def _rewrite_payload_json(self, value: dict) -> dict: + payload = ParseDict(value, temporalio.api.common.v1.Payload()) + [rewritten_payload] = await self._payload_visitor([payload]) + return MessageToDict(rewritten_payload) + + async def _rewrite_payloads_json(self, value: dict) -> dict: + payloads = ParseDict(value, temporalio.api.common.v1.Payloads()) + rewritten_payloads = await self._payload_visitor(payloads.payloads) + del payloads.payloads[:] + payloads.payloads.extend(rewritten_payloads) + return MessageToDict(payloads) + + async def _rewrite_payload_map_json(self, message_type: type, value: dict) -> dict: + message = message_type() + keys = list(value.keys()) + rewritten_payloads = await self._payload_visitor( + [ParseDict(value[key], temporalio.api.common.v1.Payload()) for key in keys] + ) + for key, rewritten_payload in zip(keys, rewritten_payloads): + message.fields[key].CopyFrom(rewritten_payload) + return MessageToDict(message).get("fields", {}) + + async def _temporal_nexus_rewrite_header_json(self, value: dict) -> dict: + rewritten = dict(value) + if rewritten.get("fields") is not None: + rewritten["fields"] = await self._rewrite_payload_map_json( + temporalio.api.common.v1.Header, rewritten["fields"] + ) + return rewritten + + async def _temporal_nexus_rewrite_input_json(self, value: dict) -> dict: + return await self._rewrite_payloads_json(value) + + async def _temporal_nexus_rewrite_memo_json(self, value: dict) -> dict: + rewritten = dict(value) + if rewritten.get("fields") is not None: + rewritten["fields"] = await self._rewrite_payload_map_json( + temporalio.api.common.v1.Memo, rewritten["fields"] + ) + return rewritten + + async def _temporal_nexus_rewrite_search_attributes_json(self, value: dict) -> dict: + if not self._visit_search_attributes: + return value + rewritten = dict(value) + if rewritten.get("indexedFields") is not None: + rewritten["indexedFields"] = await self._rewrite_payload_map_json( + temporalio.api.common.v1.SearchAttributes, rewritten["indexedFields"] + ) + return rewritten + + async def _temporal_nexus_rewrite_user_metadata_json(self, value: dict) -> dict: + rewritten = dict(value) + if rewritten.get("details") is not None: + rewritten["details"] = await self._rewrite_payload_json( + rewritten["details"] + ) + if rewritten.get("summary") is not None: + rewritten["summary"] = await self._rewrite_payload_json( + rewritten["summary"] + ) + return rewritten + + async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input_json( + self, value: dict + ) -> dict: + rewritten = dict(value) + if rewritten.get("header") is not None: + rewritten["header"] = await self._temporal_nexus_rewrite_header_json( + rewritten["header"] + ) + if rewritten.get("input") is not None: + rewritten["input"] = await self._temporal_nexus_rewrite_input_json( + rewritten["input"] + ) + if rewritten.get("memo") is not None: + rewritten["memo"] = await self._temporal_nexus_rewrite_memo_json( + rewritten["memo"] + ) + if rewritten.get("searchAttributes") is not None: + rewritten[ + "searchAttributes" + ] = await self._temporal_nexus_rewrite_search_attributes_json( + rewritten["searchAttributes"] + ) + if rewritten.get("signalInput") is not None: + rewritten["signalInput"] = await self._temporal_nexus_rewrite_input_json( + rewritten["signalInput"] + ) + if rewritten.get("userMetadata") is not None: + rewritten[ + "userMetadata" + ] = await self._temporal_nexus_rewrite_user_metadata_json( + rewritten["userMetadata"] + ) + return rewritten + + async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input( payload: temporalio.api.common.v1.Payload, - payload_codec: temporalio.converter.PayloadCodec | None, + payload_visitor: collections.abc.Callable[ + [collections.abc.Sequence[temporalio.api.common.v1.Payload]], + collections.abc.Awaitable[list[temporalio.api.common.v1.Payload]], + ], + visit_search_attributes: bool = False, ) -> temporalio.api.common.v1.Payload: - if payload_codec is None: - return payload try: value = json.loads(payload.data) except json.JSONDecodeError: return payload - rewritten = await _temporal_nexus_encode_json_value(value, payload_codec) + if not isinstance(value, dict): + return payload + rewriter = _TemporalNexusPayloadRewriter(payload_visitor, visit_search_attributes) + rewritten = await rewriter._temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input_json( + value + ) return temporalio.api.common.v1.Payload( metadata=dict(payload.metadata), data=json.dumps(rewritten, separators=(",", ":"), sort_keys=True).encode(), ) -__temporal_nexus_payload_codec_rewriters__ = { +__temporal_nexus_payload_rewriters__ = { ( "WorkflowService", "SignalWithStartWorkflowExecution", ): _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input, } - - -@service -class WorkflowService: - signal_with_start_workflow_execution: Operation[ - WorkflowServiceSignalWithStartWorkflowExecutionInput, - WorkflowServiceSignalWithStartWorkflowExecutionOutput, - ] = Operation(name="SignalWithStartWorkflowExecution") diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py index f9dd6b7a9..ecaf300da 100644 --- a/tests/nexus/test_temporal_system_nexus.py +++ b/tests/nexus/test_temporal_system_nexus.py @@ -14,7 +14,7 @@ from temporalio import workflow from temporalio.client import Client from temporalio.contrib.pydantic import pydantic_data_converter -from temporalio.converter import PayloadCodec +from temporalio.converter import ExternalStorage, PayloadCodec from temporalio.nexus.system import ( WorkflowService, WorkflowServiceSignalWithStartWorkflowExecutionInput, @@ -24,6 +24,7 @@ from temporalio.worker import Worker from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner from tests.helpers.nexus import make_nexus_endpoint_name +from tests.test_extstore import InMemoryTestDriver @nexusrpc.handler.service_handler(service=WorkflowService) @@ -34,12 +35,21 @@ async def signal_with_start_workflow_execution( _ctx: nexusrpc.handler.StartOperationContext, request: WorkflowServiceSignalWithStartWorkflowExecutionInput, ) -> WorkflowServiceSignalWithStartWorkflowExecutionOutput: + request_dict = request.model_dump(by_alias=True) for field_name in ("input", "signalInput"): - payloads = request.model_dump(by_alias=True)[field_name]["payloads"] - assert "test-codec" in payloads[0]["metadata"] + payloads = request_dict[field_name]["payloads"] + assert payloads[0]["externalPayloads"] for field_name in ("memo", "header"): - fields = request.model_dump(by_alias=True)[field_name]["fields"] - assert "test-codec" in next(iter(fields.values()))["metadata"] + fields = request_dict[field_name]["fields"] + assert next(iter(fields.values()))["externalPayloads"] + for field_name in ("summary", "details"): + payload = request_dict["userMetadata"][field_name] + assert payload["externalPayloads"] + search_attribute_payload = request_dict["searchAttributes"]["indexedFields"][ + "custom-key" + ] + assert "externalPayloads" not in search_attribute_payload + assert "test-codec" not in search_attribute_payload["metadata"] return WorkflowServiceSignalWithStartWorkflowExecutionOutput( runId=f"{request.workflow_id}-run" ) @@ -98,6 +108,30 @@ async def run(self, task_queue: str) -> str: } ) ), + "userMetadata": { + "summary": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"summary-value"', + ) + ), + "details": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"details-value"', + ) + ), + }, + "searchAttributes": { + "indexedFields": { + "custom-key": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"search-attribute-value"', + ) + ) + } + }, } ) handle = await nexus_client.start_operation( @@ -151,10 +185,15 @@ async def test_workflow_service_signal_with_start_nested_payloads_use_codec_with pytest.skip("Nexus tests don't work with the Java test server") codec = RejectOuterSystemNexusCodec() + driver = InMemoryTestDriver() config = env.client.config() config["data_converter"] = dataclasses.replace( pydantic_data_converter, payload_codec=codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1, + ), ) client = Client(**config) @@ -175,4 +214,19 @@ async def test_workflow_service_signal_with_start_nested_payloads_use_codec_with ) assert result == "system-nexus-workflow-id-run" - assert codec.encode_count >= 4 + assert codec.encode_count >= 6 + stored_payloads: list[temporalio.api.common.v1.Payload] = [] + for stored_payload_bytes in driver._storage.values(): + stored_payload = temporalio.api.common.v1.Payload() + stored_payload.ParseFromString(stored_payload_bytes) + stored_payloads.append(stored_payload) + assert stored_payload.metadata["test-codec"] == b"true" + stored_payload_data = {payload.data for payload in stored_payloads} + assert { + b'"workflow-input"', + b'"signal-input"', + b'"memo-value"', + b'"header-value"', + b'"summary-value"', + b'"details-value"', + }.issubset(stored_payload_data) From e0523210ab594767cacd0181a12ab65e05f17cba Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 9 Apr 2026 09:37:12 -0700 Subject: [PATCH 05/18] Use generated system nexus module directly --- pyproject.toml | 12 +- temporalio/bridge/worker.py | 7 +- temporalio/nexus/system/__init__.py | 27 ++- .../system/_workflow_service_generated.py | 164 +++++++------ temporalio/worker/_workflow_instance.py | 15 +- tests/nexus/test_temporal_system_nexus.py | 229 +++++++++++------- 6 files changed, 262 insertions(+), 192 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a1f5a6661..e65b0ed54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,14 +67,8 @@ dev = [ ] [tool.poe.tasks] -build-develop = [ - { ref = "gen-nexus-system-models" }, - { cmd = "uv run maturin develop --uv" }, -] -build-develop-with-release = [ - { ref = "gen-nexus-system-models" }, - { cmd = "uv run maturin develop --release --uv" }, -] +build-develop = "uv run maturin develop --uv" +build-develop-with-release = { cmd = "uv run maturin develop --release --uv" } format = [ { cmd = "uv run ruff check --select I --fix" }, { cmd = "uv run ruff format" }, @@ -92,6 +86,7 @@ gen-protos-docker = [ { cmd = "uv run scripts/gen_protos_docker.py" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, + { ref = "gen-nexus-system-models" }, { ref = "format" }, ] lint = [ @@ -105,7 +100,6 @@ bridge-lint = { cmd = "cargo clippy -- -D warnings", cwd = "temporalio/bridge" } # https://github.com/PyCQA/pydocstyle/pull/511? lint-docs = "uv run pydocstyle --ignore-decorators=overload" lint-types = [ - { ref = "gen-nexus-system-models" }, { cmd = "uv run pyright" }, { cmd = "uv run mypy --namespace-packages --check-untyped-defs ." }, { cmd = "uv run basedpyright" }, diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 587bbaa12..8539d695d 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -326,8 +326,7 @@ async def _encode_completion_payloads( return await data_converter._encode_payload_sequence(payloads) rewrite = temporalio.nexus.system.get_payload_rewriter( - command_info.nexus_service, - command_info.nexus_operation, + command_info.nexus_service, command_info.nexus_operation ) if rewrite is None: return await data_converter._encode_payload_sequence(payloads) @@ -360,6 +359,8 @@ async def encode_completion( await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers ).visit( - _Visitor(lambda payloads: _encode_completion_payloads(data_converter, payloads)), + _Visitor( + lambda payloads: _encode_completion_payloads(data_converter, payloads) + ), completion, ) diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index 52a80cda1..2f4dc7926 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -7,13 +7,10 @@ from collections.abc import Awaitable, Callable, Sequence import temporalio.api.common.v1 +import temporalio.converter -from ._workflow_service_generated import ( - WorkflowService, - WorkflowServiceSignalWithStartWorkflowExecutionInput, - WorkflowServiceSignalWithStartWorkflowExecutionOutput, - __temporal_nexus_payload_rewriters__, -) +from . import _workflow_service_generated as generated +from ._workflow_service_generated import __temporal_nexus_payload_rewriters__ TemporalNexusPayloadRewriter = Callable[ [ @@ -27,6 +24,8 @@ Awaitable[temporalio.api.common.v1.Payload], ] +_SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.JSONPlainPayloadConverter() + def get_payload_rewriter( service: str, @@ -36,9 +35,19 @@ def get_payload_rewriter( return __temporal_nexus_payload_rewriters__.get((service, operation)) +def is_system_operation(service: str, operation: str) -> bool: + """Return whether a Nexus operation uses the generated system envelope.""" + return get_payload_rewriter(service, operation) is not None + + +def get_payload_converter() -> temporalio.converter.EncodingPayloadConverter: + """Return the fixed payload converter for system Nexus outer envelopes.""" + return _SYSTEM_NEXUS_PAYLOAD_CONVERTER + + __all__ = ( - "WorkflowService", - "WorkflowServiceSignalWithStartWorkflowExecutionInput", - "WorkflowServiceSignalWithStartWorkflowExecutionOutput", + "generated", + "get_payload_converter", "get_payload_rewriter", + "is_system_operation", ) diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index df330d8de..e2273b871 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -4,17 +4,18 @@ import collections.abc import json +from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Optional from google.protobuf.json_format import MessageToDict, ParseDict from nexusrpc import Operation, service -from pydantic import BaseModel, Field import temporalio.api.common.v1 -class Header(BaseModel): +@dataclass +class Header: """Contains metadata that can be attached to a variety of requests, like starting a workflow, and can be propagated between, for example, workflows and activities. @@ -23,7 +24,8 @@ class Header(BaseModel): fields: Optional[Dict[str, Any]] = None -class Input(BaseModel): +@dataclass +class Input: """Serialized arguments to the workflow. These are passed as arguments to the workflow function. @@ -35,14 +37,15 @@ class Input(BaseModel): payloads: Optional[List[Any]] = None -class BatchJob(BaseModel): +@dataclass +class BatchJob: """A link to a built-in batch job. Batch jobs can be used to perform operations on a set of workflows (e.g. terminate, signal, cancel, etc). This link can be put on workflow history events generated by actions taken by a batch job. """ - job_id: Optional[str] = Field(None, alias="jobId") + jobId: Optional[str] = None class EventType(Enum): @@ -166,29 +169,33 @@ class EventType(Enum): EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT = "EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT" -class EventRef(BaseModel): +@dataclass +class EventRef: """EventReference is a direct reference to a history event through the event ID.""" - event_id: Optional[str] = Field(None, alias="eventId") - event_type: Optional[EventType] = Field(None, alias="eventType") + eventId: Optional[str] = None + eventType: Optional[EventType] = None -class RequestIDRef(BaseModel): +@dataclass +class RequestIDRef: """RequestIdReference is a indirect reference to a history event through the request ID.""" - event_type: Optional[EventType] = Field(None, alias="eventType") - request_id: Optional[str] = Field(None, alias="requestId") + eventType: Optional[EventType] = None + requestId: Optional[str] = None -class WorkflowEvent(BaseModel): - event_ref: Optional[EventRef] = Field(None, alias="eventRef") +@dataclass +class WorkflowEvent: + eventRef: Optional[EventRef] = None namespace: Optional[str] = None - request_id_ref: Optional[RequestIDRef] = Field(None, alias="requestIdRef") - run_id: Optional[str] = Field(None, alias="runId") - workflow_id: Optional[str] = Field(None, alias="workflowId") + requestIdRef: Optional[RequestIDRef] = None + runId: Optional[str] = None + workflowId: Optional[str] = None -class Openapiv3(BaseModel): +@dataclass +class Openapiv3: """Link can be associated with history events. It might contain information about an external entity related to the history event. For example, workflow A makes a Nexus call that starts @@ -198,17 +205,19 @@ class Openapiv3(BaseModel): workflow B, and vice-versa. """ - batch_job: Optional[BatchJob] = Field(None, alias="batchJob") - workflow_event: Optional[WorkflowEvent] = Field(None, alias="workflowEvent") + batchJob: Optional[BatchJob] = None + workflowEvent: Optional[WorkflowEvent] = None -class Memo(BaseModel): +@dataclass +class Memo: """A user-defined set of *unindexed* fields that are exposed when listing/searching workflows""" fields: Optional[Dict[str, Any]] = None -class Priority(BaseModel): +@dataclass +class Priority: """Priority metadata Priority contains metadata that controls relative ordering of task processing @@ -245,7 +254,7 @@ class Priority(BaseModel): fields. (Currently only support in matching task queues is planned.) """ - fairness_key: Optional[str] = Field(None, alias="fairnessKey") + fairnessKey: Optional[str] = None """Fairness key is a short string that's used as a key for a fairness balancing mechanism. It may correspond to a tenant id, or to a fixed string like "high" or "low". The default is the empty string. @@ -271,7 +280,7 @@ class Priority(BaseModel): Fairness keys are limited to 64 bytes. """ - fairness_weight: Optional[float] = Field(None, alias="fairnessWeight") + fairnessWeight: Optional[float] = None """Fairness weight for a task can come from multiple sources for flexibility. From highest to lowest precedence: 1. Weights for a small set of keys can be overridden in task queue @@ -281,7 +290,7 @@ class Priority(BaseModel): Weight values are clamped to the range [0.001, 1000]. """ - priority_key: Optional[int] = Field(None, alias="priorityKey") + priorityKey: Optional[int] = None """Priority key is a positive integer from 1 to n, where smaller integers correspond to higher priorities (tasks run sooner). In general, tasks in a queue should be processed in close to priority order, although small @@ -296,45 +305,45 @@ class Priority(BaseModel): """ -class RetryPolicy(BaseModel): +@dataclass +class RetryPolicy: """Retry policy for the workflow How retries ought to be handled, usable by both workflows and activities """ - backoff_coefficient: Optional[float] = Field(None, alias="backoffCoefficient") + backoffCoefficient: Optional[float] = None """Coefficient used to calculate the next retry interval. The next retry interval is previous interval multiplied by the coefficient. Must be 1 or larger. """ - initial_interval: Optional[str] = Field(None, alias="initialInterval") + initialInterval: Optional[str] = None """Interval of the first retry. If retryBackoffCoefficient is 1.0 then it is used for all retries. """ - maximum_attempts: Optional[int] = Field(None, alias="maximumAttempts") + maximumAttempts: Optional[int] = None """Maximum number of attempts. When exceeded the retries stop even if not expired yet. 1 disables retries. 0 means unlimited (up to the timeouts) """ - maximum_interval: Optional[str] = Field(None, alias="maximumInterval") + maximumInterval: Optional[str] = None """Maximum interval between retries. Exponential backoff leads to interval increase. This value is the cap of the increase. Default is 100x of the initial interval. """ - non_retryable_error_types: Optional[List[str]] = Field( - None, alias="nonRetryableErrorTypes" - ) + nonRetryableErrorTypes: Optional[List[str]] = None """Non-Retryable errors types. Will stop retrying if the error type matches this list. Note that this is not a substring match, the error *type* (not message) must match exactly. """ -class SearchAttributes(BaseModel): +@dataclass +class SearchAttributes: """A user-defined set of *indexed* fields that are used/exposed when listing/searching workflows. The payload is not serialized in a user-defined way. """ - indexed_fields: Optional[Dict[str, Any]] = Field(None, alias="indexedFields") + indexedFields: Optional[Dict[str, Any]] = None class Kind(Enum): @@ -345,7 +354,8 @@ class Kind(Enum): TASK_QUEUE_KIND_UNSPECIFIED = "TASK_QUEUE_KIND_UNSPECIFIED" -class TaskQueue(BaseModel): +@dataclass +class TaskQueue: """The task queue to start this workflow on, if it will be started See https://docs.temporal.io/docs/concepts/task-queues/ @@ -355,13 +365,14 @@ class TaskQueue(BaseModel): """Default: TASK_QUEUE_KIND_NORMAL.""" name: Optional[str] = None - normal_name: Optional[str] = Field(None, alias="normalName") + normalName: Optional[str] = None """Iff kind == TASK_QUEUE_KIND_STICKY, then this field contains the name of the normal task queue that the sticky worker is running on. """ -class UserMetadata(BaseModel): +@dataclass +class UserMetadata: """Metadata on the workflow if it is started. This is carried over to the WorkflowExecutionInfo for use by user interfaces to display the fixed as-of-start summary and details of the @@ -398,7 +409,8 @@ class VersioningOverrideBehavior(Enum): VERSIONING_BEHAVIOR_UNSPECIFIED = "VERSIONING_BEHAVIOR_UNSPECIFIED" -class Deployment(BaseModel): +@dataclass +class Deployment: """Required if behavior is `PINNED`. Must be null if behavior is `AUTO_UPGRADE`. Identifies the worker deployment to pin the workflow to. Deprecated. Use `override.pinned.version`. @@ -411,12 +423,12 @@ class Deployment(BaseModel): Deprecated. """ - build_id: Optional[str] = Field(None, alias="buildId") + buildId: Optional[str] = None """Build ID changes with each version of the worker when the worker program code and/or config changes. """ - series_name: Optional[str] = Field(None, alias="seriesName") + seriesName: Optional[str] = None """Different versions of the same worker service/application are related together by having a shared series name. @@ -436,7 +448,8 @@ class PinnedBehavior(Enum): PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED = "PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED" -class Version(BaseModel): +@dataclass +class Version: """Specifies the Worker Deployment Version to pin this workflow to. Required if the target workflow is not already pinned to a version. @@ -455,17 +468,18 @@ class Version(BaseModel): in the future. """ - build_id: Optional[str] = Field(None, alias="buildId") + buildId: Optional[str] = None """A unique identifier for this Version within the Deployment it is a part of. Not necessarily unique within the namespace. The combination of `deployment_name` and `build_id` uniquely identifies this Version within the namespace, because Deployment names are unique within a namespace. """ - deployment_name: Optional[str] = Field(None, alias="deploymentName") + deploymentName: Optional[str] = None """Identifies the Worker Deployment this Version is part of.""" -class Pinned(BaseModel): +@dataclass +class Pinned: """Override the workflow to have Pinned behavior.""" behavior: Optional[PinnedBehavior] = None @@ -484,7 +498,8 @@ class Pinned(BaseModel): """ -class VersioningOverride(BaseModel): +@dataclass +class VersioningOverride: """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task completion. To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. @@ -503,7 +518,7 @@ class VersioningOverride(BaseModel): workflows, and cron workflows. """ - auto_upgrade: Optional[bool] = Field(None, alias="autoUpgrade") + autoUpgrade: Optional[bool] = None """Override the workflow to have AutoUpgrade behavior.""" behavior: Optional[VersioningOverrideBehavior] = None @@ -518,7 +533,7 @@ class VersioningOverride(BaseModel): pinned: Optional[Pinned] = None """Override the workflow to have Pinned behavior.""" - pinned_version: Optional[str] = Field(None, alias="pinnedVersion") + pinnedVersion: Optional[str] = None """Required if behavior is `PINNED`. Must be absent if behavior is not `PINNED`. Identifies the worker deployment version to pin the workflow to, in the format ".". @@ -568,7 +583,8 @@ class WorkflowIDReusePolicy(Enum): WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED = "WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED" -class WorkflowType(BaseModel): +@dataclass +class WorkflowType: """Represents the identifier used by a workflow author to define the workflow. Typically, the name of a function. This is sometimes referred to as the workflow's "name" @@ -577,11 +593,12 @@ class WorkflowType(BaseModel): name: Optional[str] = None -class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): +@dataclass +class WorkflowServiceSignalWithStartWorkflowExecutionInput: control: Optional[str] = None """Deprecated.""" - cron_schedule: Optional[str] = Field(None, alias="cronSchedule") + cronSchedule: Optional[str] = None """See https://docs.temporal.io/docs/content/what-is-a-temporal-cron-job/""" header: Optional[Header] = None @@ -601,46 +618,38 @@ class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): priority: Optional[Priority] = None """Priority metadata""" - request_id: Optional[str] = Field(None, alias="requestId") + requestId: Optional[str] = None """Used to de-dupe signal w/ start requests""" - retry_policy: Optional[RetryPolicy] = Field(None, alias="retryPolicy") + retryPolicy: Optional[RetryPolicy] = None """Retry policy for the workflow""" - search_attributes: Optional[SearchAttributes] = Field( - None, alias="searchAttributes" - ) - signal_input: Optional[Input] = Field(None, alias="signalInput") + searchAttributes: Optional[SearchAttributes] = None + signalInput: Optional[Input] = None """Serialized value(s) to provide with the signal""" - signal_name: Optional[str] = Field(None, alias="signalName") + signalName: Optional[str] = None """The workflow author-defined name of the signal to send to the workflow""" - task_queue: Optional[TaskQueue] = Field(None, alias="taskQueue") + taskQueue: Optional[TaskQueue] = None """The task queue to start this workflow on, if it will be started""" - user_metadata: Optional[UserMetadata] = Field(None, alias="userMetadata") + userMetadata: Optional[UserMetadata] = None """Metadata on the workflow if it is started. This is carried over to the WorkflowExecutionInfo for use by user interfaces to display the fixed as-of-start summary and details of the workflow. """ - versioning_override: Optional[VersioningOverride] = Field( - None, alias="versioningOverride" - ) + versioningOverride: Optional[VersioningOverride] = None """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task completion. To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. """ - workflow_execution_timeout: Optional[str] = Field( - None, alias="workflowExecutionTimeout" - ) + workflowExecutionTimeout: Optional[str] = None """Total workflow execution timeout including retries and continue as new""" - workflow_id: Optional[str] = Field(None, alias="workflowId") - workflow_id_conflict_policy: Optional[WorkflowIDConflictPolicy] = Field( - None, alias="workflowIdConflictPolicy" - ) + workflowId: Optional[str] = None + workflowIdConflictPolicy: Optional[WorkflowIDConflictPolicy] = None """Defines how to resolve a workflow id conflict with a *running* workflow. The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. @@ -648,19 +657,17 @@ class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* workflow. """ - workflow_id_reuse_policy: Optional[WorkflowIDReusePolicy] = Field( - None, alias="workflowIdReusePolicy" - ) + workflowIdReusePolicy: Optional[WorkflowIDReusePolicy] = None """Defines whether to allow re-using the workflow id from a previously *closed* workflow. The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. See `workflow_id_reuse_policy` for handling a workflow id duplication with a *running* workflow. """ - workflow_run_timeout: Optional[str] = Field(None, alias="workflowRunTimeout") + workflowRunTimeout: Optional[str] = None """Timeout of a single workflow run""" - workflow_start_delay: Optional[str] = Field(None, alias="workflowStartDelay") + workflowStartDelay: Optional[str] = None """Time to wait before dispatching the first workflow task. Cannot be used with `cron_schedule`. Note that the signal will be delivered with the first workflow task. If the workflow @@ -670,14 +677,15 @@ class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): and the rest of the delay period will be ignored, even if that request also had a delay. Signal via SignalWorkflowExecution will not unblock the workflow. """ - workflow_task_timeout: Optional[str] = Field(None, alias="workflowTaskTimeout") + workflowTaskTimeout: Optional[str] = None """Timeout of a single workflow task""" - workflow_type: Optional[WorkflowType] = Field(None, alias="workflowType") + workflowType: Optional[WorkflowType] = None -class WorkflowServiceSignalWithStartWorkflowExecutionOutput(BaseModel): - run_id: Optional[str] = Field(None, alias="runId") +@dataclass +class WorkflowServiceSignalWithStartWorkflowExecutionOutput: + runId: Optional[str] = None """The run id of the workflow that was started - or just signaled, if it was already running.""" started: Optional[bool] = None diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 1bfa77c3c..8fb60667c 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -57,6 +57,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus.system import temporalio.workflow from temporalio.service import __version__ @@ -3345,14 +3346,24 @@ def _resolve_failure(self, err: BaseException) -> None: self._result_fut.set_result(None) def _apply_schedule_command(self) -> None: - payload = self._payload_converter.to_payload(self._input.input) command = self._instance._add_command() v = command.schedule_nexus_operation v.seq = self._seq v.endpoint = self._input.endpoint v.service = self._input.service v.operation = self._input.operation_name - v.input.CopyFrom(payload) + payload_converter = ( + temporalio.nexus.system.get_payload_converter() + if temporalio.nexus.system.is_system_operation(v.service, v.operation) + else self._payload_converter + ) + payload = payload_converter.to_payload(self._input.input) + if payload is None: + raise RuntimeError( + "Nexus operation input could not be converted to a payload" + ) + payload_message: temporalio.api.common.v1.Payload = payload + v.input.CopyFrom(payload_message) if self._input.schedule_to_close_timeout is not None: v.schedule_to_close_timeout.FromTimedelta( self._input.schedule_to_close_timeout diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py index ecaf300da..d02849bca 100644 --- a/tests/nexus/test_temporal_system_nexus.py +++ b/tests/nexus/test_temporal_system_nexus.py @@ -11,15 +11,15 @@ from google.protobuf.json_format import MessageToDict import temporalio.api.common.v1 +import temporalio.converter from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.pydantic import pydantic_data_converter -from temporalio.converter import ExternalStorage, PayloadCodec -from temporalio.nexus.system import ( - WorkflowService, - WorkflowServiceSignalWithStartWorkflowExecutionInput, - WorkflowServiceSignalWithStartWorkflowExecutionOutput, +from temporalio.converter import ( + DefaultPayloadConverter, + ExternalStorage, + PayloadCodec, ) +from temporalio.nexus.system import generated from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner @@ -27,15 +27,17 @@ from tests.test_extstore import InMemoryTestDriver -@nexusrpc.handler.service_handler(service=WorkflowService) +@nexusrpc.handler.service_handler(service=generated.WorkflowService) class WorkflowServicePayloadHandler: @nexusrpc.handler.sync_operation async def signal_with_start_workflow_execution( self, _ctx: nexusrpc.handler.StartOperationContext, - request: WorkflowServiceSignalWithStartWorkflowExecutionInput, - ) -> WorkflowServiceSignalWithStartWorkflowExecutionOutput: - request_dict = request.model_dump(by_alias=True) + request: generated.WorkflowServiceSignalWithStartWorkflowExecutionInput, + ) -> generated.WorkflowServiceSignalWithStartWorkflowExecutionOutput: + assert request.workflowId == "system-nexus-workflow-id" + assert request.signalName == "test-signal" + request_dict = dataclasses.asdict(request) for field_name in ("input", "signalInput"): payloads = request_dict[field_name]["payloads"] assert payloads[0]["externalPayloads"] @@ -50,8 +52,8 @@ async def signal_with_start_workflow_execution( ] assert "externalPayloads" not in search_attribute_payload assert "test-codec" not in search_attribute_payload["metadata"] - return WorkflowServiceSignalWithStartWorkflowExecutionOutput( - runId=f"{request.workflow_id}-run" + return generated.WorkflowServiceSignalWithStartWorkflowExecutionOutput( + runId=f"{request.workflowId}-run" ) @@ -60,86 +62,84 @@ class SystemNexusCallerWithPayloadsWorkflow: @workflow.run async def run(self, task_queue: str) -> str: nexus_client = workflow.create_nexus_client( - service=WorkflowService, + service=generated.WorkflowService, endpoint=make_nexus_endpoint_name(task_queue), ) - request = WorkflowServiceSignalWithStartWorkflowExecutionInput.model_validate( - { - "namespace": "default", - "workflowId": "system-nexus-workflow-id", - "signalName": "test-signal", - "input": MessageToDict( - temporalio.api.common.v1.Payloads( - payloads=[ - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"workflow-input"', - ) - ] - ) - ), - "signalInput": MessageToDict( - temporalio.api.common.v1.Payloads( - payloads=[ - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"signal-input"', - ) - ] - ) - ), - "memo": MessageToDict( - temporalio.api.common.v1.Memo( - fields={ - "memo-key": temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"memo-value"', - ) - } + request = generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( + namespace="default", + workflowId="system-nexus-workflow-id", + signalName="test-signal", + input=generated.Input( + payloads=[ + MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"workflow-input"', + ) ) - ), - "header": MessageToDict( - temporalio.api.common.v1.Header( - fields={ - "header-key": temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"header-value"', - ) - } + ] + ), + signalInput=generated.Input( + payloads=[ + MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"signal-input"', + ) ) - ), - "userMetadata": { - "summary": MessageToDict( + ] + ), + memo=generated.Memo( + fields={ + "memo-key": MessageToDict( temporalio.api.common.v1.Payload( metadata={"encoding": b"json/plain"}, - data=b'"summary-value"', + data=b'"memo-value"', ) - ), - "details": MessageToDict( + ) + } + ), + header=generated.Header( + fields={ + "header-key": MessageToDict( temporalio.api.common.v1.Payload( metadata={"encoding": b"json/plain"}, - data=b'"details-value"', + data=b'"header-value"', ) - ), - }, - "searchAttributes": { - "indexedFields": { - "custom-key": MessageToDict( - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"search-attribute-value"', - ) + ) + } + ), + userMetadata=generated.UserMetadata( + summary=MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"summary-value"', + ) + ), + details=MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"details-value"', + ) + ), + ), + searchAttributes=generated.SearchAttributes( + indexedFields={ + "custom-key": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"search-attribute-value"', ) - } - }, - } + ) + } + ), ) handle = await nexus_client.start_operation( - WorkflowService.signal_with_start_workflow_execution, + generated.WorkflowService.signal_with_start_workflow_execution, request, ) result = await handle - return cast(str, result.run_id) + return cast(str, result.runId) class RejectOuterSystemNexusCodec(PayloadCodec): @@ -175,7 +175,42 @@ async def encode( async def decode( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - return list(payloads) + decoded: list[temporalio.api.common.v1.Payload] = [] + for payload in payloads: + try: + body = json.loads(payload.data) + except json.JSONDecodeError: + body = None + if isinstance(body, dict) and { + "namespace", + "workflowId", + "signalName", + }.issubset(body): + raise RuntimeError( + "outer system nexus envelope should not be codec decoded" + ) + decoded.append(payload) + return decoded + + +class BadSystemNexusEnvelopePayloadConverter(DefaultPayloadConverter): + def to_payloads( + self, values: Sequence[object] + ) -> list[temporalio.api.common.v1.Payload]: + payloads: list[temporalio.api.common.v1.Payload] = [] + for value in values: + if isinstance( + value, generated.WorkflowServiceSignalWithStartWorkflowExecutionInput + ): + payloads.append( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'{"workflow_id":"bad-envelope"}', + ) + ) + else: + payloads.extend(super().to_payloads([value])) + return payloads async def test_workflow_service_signal_with_start_nested_payloads_use_codec_without_encoding_outer_envelope( @@ -186,31 +221,43 @@ async def test_workflow_service_signal_with_start_nested_payloads_use_codec_with codec = RejectOuterSystemNexusCodec() driver = InMemoryTestDriver() - config = env.client.config() - config["data_converter"] = dataclasses.replace( - pydantic_data_converter, + caller_config = env.client.config() + caller_config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_converter_class=BadSystemNexusEnvelopePayloadConverter, payload_codec=codec, external_storage=ExternalStorage( drivers=[driver], payload_size_threshold=1, ), ) - client = Client(**config) + caller_client = Client(**caller_config) + handler_config = env.client.config() + handler_config["data_converter"] = temporalio.converter.default() + handler_client = Client(**handler_config) + caller_task_queue = str(uuid.uuid4()) + handler_task_queue = str(uuid.uuid4()) - async with Worker( - client, - task_queue=str(uuid.uuid4()), + caller_worker = Worker( + caller_client, + task_queue=caller_task_queue, workflows=[SystemNexusCallerWithPayloadsWorkflow], - nexus_service_handlers=[WorkflowServicePayloadHandler()], workflow_runner=UnsandboxedWorkflowRunner(), - ) as worker: - endpoint_name = make_nexus_endpoint_name(worker.task_queue) - await env.create_nexus_endpoint(endpoint_name, worker.task_queue) - result = await client.execute_workflow( + ) + handler_worker = Worker( + handler_client, + task_queue=handler_task_queue, + nexus_service_handlers=[WorkflowServicePayloadHandler()], + ) + + async with caller_worker, handler_worker: + endpoint_name = make_nexus_endpoint_name(handler_task_queue) + await env.create_nexus_endpoint(endpoint_name, handler_task_queue) + result = await caller_client.execute_workflow( SystemNexusCallerWithPayloadsWorkflow.run, - worker.task_queue, + handler_task_queue, id=str(uuid.uuid4()), - task_queue=worker.task_queue, + task_queue=caller_task_queue, ) assert result == "system-nexus-workflow-id-run" From ccf4ef395e654e72f3da8ce22a89b6aa9a8c7e43 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 9 Apr 2026 10:04:24 -0700 Subject: [PATCH 06/18] Typecheck generated nexus system models --- pyproject.toml | 1 - temporalio/nexus/system/__init__.py | 4 ++-- temporalio/nexus/system/_workflow_service_generated.py | 1 + temporalio/worker/_workflow_instance.py | 8 +------- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e65b0ed54..40bd7189e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,7 +218,6 @@ exclude = [ "temporalio/api", "temporalio/bridge/proto", "temporalio/bridge/_visitor.py", - "temporalio/nexus/system/_workflow_service_generated.py", "tests/worker/workflow_sandbox/testmodules/proto", ] diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index 2f4dc7926..bb1a2f37e 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -24,7 +24,7 @@ Awaitable[temporalio.api.common.v1.Payload], ] -_SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.JSONPlainPayloadConverter() +_SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.default().payload_converter def get_payload_rewriter( @@ -40,7 +40,7 @@ def is_system_operation(service: str, operation: str) -> bool: return get_payload_rewriter(service, operation) is not None -def get_payload_converter() -> temporalio.converter.EncodingPayloadConverter: +def get_payload_converter() -> temporalio.converter.PayloadConverter: """Return the fixed payload converter for system Nexus outer envelopes.""" return _SYSTEM_NEXUS_PAYLOAD_CONVERTER diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index e2273b871..4ac22c76e 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -1,4 +1,5 @@ # Generated by nexus-rpc-gen. DO NOT EDIT! +# pyright: reportDeprecated=false from __future__ import annotations diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 8fb60667c..3454eb2ad 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -3357,13 +3357,7 @@ def _apply_schedule_command(self) -> None: if temporalio.nexus.system.is_system_operation(v.service, v.operation) else self._payload_converter ) - payload = payload_converter.to_payload(self._input.input) - if payload is None: - raise RuntimeError( - "Nexus operation input could not be converted to a payload" - ) - payload_message: temporalio.api.common.v1.Payload = payload - v.input.CopyFrom(payload_message) + v.input.CopyFrom(payload_converter.to_payload(self._input.input)) if self._input.schedule_to_close_timeout is not None: v.schedule_to_close_timeout.FromTimedelta( self._input.schedule_to_close_timeout From 629567c34d2974e16132c50840648b1f732e19ac Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 9 Apr 2026 10:07:04 -0700 Subject: [PATCH 07/18] Use neutral payload variable names --- temporalio/bridge/worker.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 8539d695d..9f0a43b0e 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -289,9 +289,9 @@ def __init__( self._f = f async def visit_payload(self, payload: Payload) -> None: - rewritten_payload = (await self._f([payload]))[0] - if rewritten_payload is not payload: - payload.CopyFrom(rewritten_payload) + new_payload = (await self._f([payload]))[0] + if new_payload is not payload: + payload.CopyFrom(new_payload) async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: if len(payloads) == 0: @@ -331,12 +331,12 @@ async def _encode_completion_payloads( if rewrite is None: return await data_converter._encode_payload_sequence(payloads) - rewritten_payload = await rewrite( + new_payload = await rewrite( payload, data_converter._encode_payload_sequence, False, ) - return [rewritten_payload] + return [new_payload] async def decode_activation( From 28ffda60d18c6fcb7509a145f4a8f1ec12f465fd Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 14 Apr 2026 12:20:20 -0700 Subject: [PATCH 08/18] Move Python system nexus envelope handling into visitor --- scripts/gen_payload_visitor.py | 36 ++++ temporalio/bridge/_visitor.py | 219 +++++++------------- temporalio/bridge/worker.py | 50 +---- temporalio/worker/_command_aware_visitor.py | 9 - tests/worker/test_visitor.py | 145 +++++++++++++ 5 files changed, 270 insertions(+), 189 deletions(-) diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index eabfd9e6a..a4d3f88c3 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -104,6 +104,11 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: \"\"\"Called when encountering multiple payloads together.\"\"\" raise NotImplementedError() + @abc.abstractmethod + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + \"\"\"Called when encountering a recognized system Nexus envelope payload.\"\"\" + raise NotImplementedError() + class PayloadVisitor: \"\"\"A visitor for payloads. Applies a function to every payload in a tree of messages. @@ -126,6 +131,26 @@ async def visit( else: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") + async def _visit_system_nexus_payload(self, fs, service, operation, payload) -> None: + import temporalio.nexus.system + + rewrite = temporalio.nexus.system.get_payload_rewriter(service, operation) + if rewrite is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + async def payload_visitor(payloads): + new_payloads = list(payloads) + await fs.visit_payloads(new_payloads) + return new_payloads + + new_payload = await rewrite( + payload, payload_visitor, not self.skip_search_attributes + ) + if new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) + """ return header + "\n".join(self.methods) @@ -202,6 +227,17 @@ def walk(self, desc: Descriptor) -> bool: # Process regular fields first for field in regular_fields: + if ( + desc.full_name == "coresdk.workflow_commands.ScheduleNexusOperation" + and field.name == "input" + ): + has_payload = True + lines.append( + """\ + if o.HasField("input"): + await self._visit_system_nexus_payload(fs, o.service, o.operation, o.input)""" + ) + continue # Repeated fields (including maps which are represented as repeated messages) if field.label == FieldDescriptor.LABEL_REPEATED: if ( diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index 16876fb59..c496c53ff 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,3 +1,4 @@ + # This file is generated by gen_payload_visitor.py. Changes should be made there. import abc from typing import Any, MutableSequence @@ -6,10 +7,9 @@ class VisitorFunctions(abc.ABC): - """Set of functions which can be called by the visitor. + """Set of functions which can be called by the visitor. Allows handling payloads as a sequence. """ - @abc.abstractmethod async def visit_payload(self, payload: Payload) -> None: """Called when encountering a single payload.""" @@ -20,12 +20,15 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: """Called when encountering multiple payloads together.""" raise NotImplementedError() + @abc.abstractmethod + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + """Called when encountering a recognized system Nexus envelope payload.""" + raise NotImplementedError() class PayloadVisitor: - """A visitor for payloads. + """A visitor for payloads. Applies a function to every payload in a tree of messages. """ - def __init__( self, *, skip_search_attributes: bool = False, skip_headers: bool = False ): @@ -33,7 +36,9 @@ def __init__( self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers - async def visit(self, fs: VisitorFunctions, root: Any) -> None: + async def visit( + self, fs: VisitorFunctions, root: Any + ) -> None: """Visits the given root message with the given function.""" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) @@ -42,24 +47,42 @@ async def visit(self, fs: VisitorFunctions, root: Any) -> None: else: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") + async def _visit_system_nexus_payload(self, fs, service, operation, payload) -> None: + import temporalio.nexus.system + + rewrite = temporalio.nexus.system.get_payload_rewriter(service, operation) + if rewrite is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + async def payload_visitor(payloads): + new_payloads = list(payloads) + await fs.visit_payloads(new_payloads) + return new_payloads + + new_payload = await rewrite( + payload, payload_visitor, not self.skip_search_attributes + ) + if new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) + async def _visit_temporal_api_common_v1_Payload(self, fs, o): await fs.visit_payload(o) - + async def _visit_temporal_api_common_v1_Payloads(self, fs, o): await fs.visit_payloads(o.payloads) - + async def _visit_payload_container(self, fs, o): await fs.visit_payloads(o) - + async def _visit_temporal_api_failure_v1_ApplicationFailureInfo(self, fs, o): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) async def _visit_temporal_api_failure_v1_TimeoutFailureInfo(self, fs, o): if o.HasField("last_heartbeat_details"): - await self._visit_temporal_api_common_v1_Payloads( - fs, o.last_heartbeat_details - ) + await self._visit_temporal_api_common_v1_Payloads(fs, o.last_heartbeat_details) async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, fs, o): if o.HasField("details"): @@ -67,9 +90,7 @@ async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, fs, o): async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, fs, o): if o.HasField("last_heartbeat_details"): - await self._visit_temporal_api_common_v1_Payloads( - fs, o.last_heartbeat_details - ) + await self._visit_temporal_api_common_v1_Payloads(fs, o.last_heartbeat_details) async def _visit_temporal_api_failure_v1_Failure(self, fs, o): if o.HasField("encoded_attributes"): @@ -77,21 +98,13 @@ async def _visit_temporal_api_failure_v1_Failure(self, fs, o): if o.HasField("cause"): await self._visit_temporal_api_failure_v1_Failure(fs, o.cause) if o.HasField("application_failure_info"): - await self._visit_temporal_api_failure_v1_ApplicationFailureInfo( - fs, o.application_failure_info - ) + await self._visit_temporal_api_failure_v1_ApplicationFailureInfo(fs, o.application_failure_info) elif o.HasField("timeout_failure_info"): - await self._visit_temporal_api_failure_v1_TimeoutFailureInfo( - fs, o.timeout_failure_info - ) + await self._visit_temporal_api_failure_v1_TimeoutFailureInfo(fs, o.timeout_failure_info) elif o.HasField("canceled_failure_info"): - await self._visit_temporal_api_failure_v1_CanceledFailureInfo( - fs, o.canceled_failure_info - ) + await self._visit_temporal_api_failure_v1_CanceledFailureInfo(fs, o.canceled_failure_info) elif o.HasField("reset_workflow_failure_info"): - await self._visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( - fs, o.reset_workflow_failure_info - ) + await self._visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(fs, o.reset_workflow_failure_info) async def _visit_temporal_api_common_v1_Memo(self, fs, o): for v in o.fields.values(): @@ -111,15 +124,11 @@ async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): if o.HasField("continued_failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.continued_failure) if o.HasField("last_completion_result"): - await self._visit_temporal_api_common_v1_Payloads( - fs, o.last_completion_result - ) + await self._visit_temporal_api_common_v1_Payloads(fs, o.last_completion_result) if o.HasField("memo"): await self._visit_temporal_api_common_v1_Memo(fs, o.memo) if o.HasField("search_attributes"): - await self._visit_temporal_api_common_v1_SearchAttributes( - fs, o.search_attributes - ) + await self._visit_temporal_api_common_v1_SearchAttributes(fs, o.search_attributes) async def _visit_coresdk_workflow_activation_QueryWorkflow(self, fs, o): await self._visit_payload_container(fs, o.arguments) @@ -157,19 +166,13 @@ async def _visit_coresdk_workflow_activation_ResolveActivity(self, fs, o): if o.HasField("result"): await self._visit_coresdk_activity_result_ActivityResolution(fs, o.result) - async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( - self, fs, o - ): + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled(self, fs, o): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( - self, fs, o - ): + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart(self, fs, o): if o.HasField("cancelled"): - await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( - fs, o.cancelled - ) + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled(fs, o.cancelled) async def _visit_coresdk_child_workflow_Success(self, fs, o): if o.HasField("result"): @@ -191,21 +194,15 @@ async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): elif o.HasField("cancelled"): await self._visit_coresdk_child_workflow_Cancellation(fs, o.cancelled) - async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( - self, fs, o - ): + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution(self, fs, o): if o.HasField("result"): await self._visit_coresdk_child_workflow_ChildWorkflowResult(fs, o.result) - async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( - self, fs, o - ): + async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow(self, fs, o): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( - self, fs, o - ): + async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow(self, fs, o): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) @@ -215,9 +212,7 @@ async def _visit_coresdk_workflow_activation_DoUpdate(self, fs, o): for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( - self, fs, o - ): + async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart(self, fs, o): if o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) @@ -237,47 +232,27 @@ async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, fs, o): async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): if o.HasField("initialize_workflow"): - await self._visit_coresdk_workflow_activation_InitializeWorkflow( - fs, o.initialize_workflow - ) + await self._visit_coresdk_workflow_activation_InitializeWorkflow(fs, o.initialize_workflow) elif o.HasField("query_workflow"): - await self._visit_coresdk_workflow_activation_QueryWorkflow( - fs, o.query_workflow - ) + await self._visit_coresdk_workflow_activation_QueryWorkflow(fs, o.query_workflow) elif o.HasField("signal_workflow"): - await self._visit_coresdk_workflow_activation_SignalWorkflow( - fs, o.signal_workflow - ) + await self._visit_coresdk_workflow_activation_SignalWorkflow(fs, o.signal_workflow) elif o.HasField("resolve_activity"): - await self._visit_coresdk_workflow_activation_ResolveActivity( - fs, o.resolve_activity - ) + await self._visit_coresdk_workflow_activation_ResolveActivity(fs, o.resolve_activity) elif o.HasField("resolve_child_workflow_execution_start"): - await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( - fs, o.resolve_child_workflow_execution_start - ) + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart(fs, o.resolve_child_workflow_execution_start) elif o.HasField("resolve_child_workflow_execution"): - await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( - fs, o.resolve_child_workflow_execution - ) + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution(fs, o.resolve_child_workflow_execution) elif o.HasField("resolve_signal_external_workflow"): - await self._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( - fs, o.resolve_signal_external_workflow - ) + await self._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow(fs, o.resolve_signal_external_workflow) elif o.HasField("resolve_request_cancel_external_workflow"): - await self._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( - fs, o.resolve_request_cancel_external_workflow - ) + await self._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow(fs, o.resolve_request_cancel_external_workflow) elif o.HasField("do_update"): await self._visit_coresdk_workflow_activation_DoUpdate(fs, o.do_update) elif o.HasField("resolve_nexus_operation_start"): - await self._visit_coresdk_workflow_activation_ResolveNexusOperationStart( - fs, o.resolve_nexus_operation_start - ) + await self._visit_coresdk_workflow_activation_ResolveNexusOperationStart(fs, o.resolve_nexus_operation_start) elif o.HasField("resolve_nexus_operation"): - await self._visit_coresdk_workflow_activation_ResolveNexusOperation( - fs, o.resolve_nexus_operation - ) + await self._visit_coresdk_workflow_activation_ResolveNexusOperation(fs, o.resolve_nexus_operation) async def _visit_coresdk_workflow_activation_WorkflowActivation(self, fs, o): for v in o.jobs: @@ -313,9 +288,7 @@ async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, fs, o): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( - self, fs, o - ): + async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(self, fs, o): await self._visit_payload_container(fs, o.arguments) for v in o.memo.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) @@ -323,9 +296,7 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) if o.HasField("search_attributes"): - await self._visit_temporal_api_common_v1_SearchAttributes( - fs, o.search_attributes - ) + await self._visit_temporal_api_common_v1_SearchAttributes(fs, o.search_attributes) async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o): await self._visit_payload_container(fs, o.input) @@ -335,13 +306,9 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, for v in o.memo.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) if o.HasField("search_attributes"): - await self._visit_temporal_api_common_v1_SearchAttributes( - fs, o.search_attributes - ) + await self._visit_temporal_api_common_v1_SearchAttributes(fs, o.search_attributes) - async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( - self, fs, o - ): + async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(self, fs, o): await self._visit_payload_container(fs, o.args) if not self.skip_headers: for v in o.headers.values(): @@ -353,13 +320,9 @@ async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o): await self._visit_temporal_api_common_v1_Payload(fs, v) await self._visit_payload_container(fs, o.arguments) - async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( - self, fs, o - ): + async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(self, fs, o): if o.HasField("search_attributes"): - await self._visit_temporal_api_common_v1_SearchAttributes( - fs, o.search_attributes - ) + await self._visit_temporal_api_common_v1_SearchAttributes(fs, o.search_attributes) async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o): if o.HasField("upserted_memo"): @@ -373,59 +336,35 @@ async def _visit_coresdk_workflow_commands_UpdateResponse(self, fs, o): async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, fs, o): if o.HasField("input"): - await self._visit_temporal_api_common_v1_Payload(fs, o.input) + await self._visit_system_nexus_payload(fs, o.service, o.operation, o.input) async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): if o.HasField("user_metadata"): await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) if o.HasField("schedule_activity"): - await self._visit_coresdk_workflow_commands_ScheduleActivity( - fs, o.schedule_activity - ) + await self._visit_coresdk_workflow_commands_ScheduleActivity(fs, o.schedule_activity) elif o.HasField("respond_to_query"): - await self._visit_coresdk_workflow_commands_QueryResult( - fs, o.respond_to_query - ) + await self._visit_coresdk_workflow_commands_QueryResult(fs, o.respond_to_query) elif o.HasField("complete_workflow_execution"): - await self._visit_coresdk_workflow_commands_CompleteWorkflowExecution( - fs, o.complete_workflow_execution - ) + await self._visit_coresdk_workflow_commands_CompleteWorkflowExecution(fs, o.complete_workflow_execution) elif o.HasField("fail_workflow_execution"): - await self._visit_coresdk_workflow_commands_FailWorkflowExecution( - fs, o.fail_workflow_execution - ) + await self._visit_coresdk_workflow_commands_FailWorkflowExecution(fs, o.fail_workflow_execution) elif o.HasField("continue_as_new_workflow_execution"): - await self._visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( - fs, o.continue_as_new_workflow_execution - ) + await self._visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(fs, o.continue_as_new_workflow_execution) elif o.HasField("start_child_workflow_execution"): - await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution( - fs, o.start_child_workflow_execution - ) + await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution(fs, o.start_child_workflow_execution) elif o.HasField("signal_external_workflow_execution"): - await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( - fs, o.signal_external_workflow_execution - ) + await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(fs, o.signal_external_workflow_execution) elif o.HasField("schedule_local_activity"): - await self._visit_coresdk_workflow_commands_ScheduleLocalActivity( - fs, o.schedule_local_activity - ) + await self._visit_coresdk_workflow_commands_ScheduleLocalActivity(fs, o.schedule_local_activity) elif o.HasField("upsert_workflow_search_attributes"): - await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( - fs, o.upsert_workflow_search_attributes - ) + await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(fs, o.upsert_workflow_search_attributes) elif o.HasField("modify_workflow_properties"): - await self._visit_coresdk_workflow_commands_ModifyWorkflowProperties( - fs, o.modify_workflow_properties - ) + await self._visit_coresdk_workflow_commands_ModifyWorkflowProperties(fs, o.modify_workflow_properties) elif o.HasField("update_response"): - await self._visit_coresdk_workflow_commands_UpdateResponse( - fs, o.update_response - ) + await self._visit_coresdk_workflow_commands_UpdateResponse(fs, o.update_response) elif o.HasField("schedule_nexus_operation"): - await self._visit_coresdk_workflow_commands_ScheduleNexusOperation( - fs, o.schedule_nexus_operation - ) + await self._visit_coresdk_workflow_commands_ScheduleNexusOperation(fs, o.schedule_nexus_operation) async def _visit_coresdk_workflow_completion_Success(self, fs, o): for v in o.commands: @@ -435,9 +374,7 @@ async def _visit_coresdk_workflow_completion_Failure(self, fs, o): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( - self, fs, o - ): + async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion(self, fs, o): if o.HasField("successful"): await self._visit_coresdk_workflow_completion_Success(fs, o.successful) elif o.HasField("failed"): diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 9f0a43b0e..996ed8c19 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -20,9 +20,7 @@ import temporalio.bridge.runtime import temporalio.bridge.temporal_sdk_bridge import temporalio.converter -import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload -from temporalio.api.enums.v1.command_type_pb2 import CommandType from temporalio.bridge._visitor import VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, @@ -30,7 +28,6 @@ from temporalio.bridge.temporal_sdk_bridge import ( PollShutdownError, # type: ignore # noqa: F401 ) -from temporalio.worker import _command_aware_visitor from temporalio.worker._command_aware_visitor import CommandAwarePayloadVisitor @@ -285,8 +282,10 @@ class _Visitor(VisitorFunctions): def __init__( self, f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]], + visit_system_nexus_envelope: Callable[[Payload], Awaitable[None]] | None = None, ): self._f = f + self._visit_system_nexus_envelope = visit_system_nexus_envelope async def visit_payload(self, payload: Payload) -> None: new_payload = (await self._f([payload]))[0] @@ -302,41 +301,9 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: del payloads[:] payloads.extend(new_payloads) - -async def _encode_completion_payloads( - data_converter: temporalio.converter.DataConverter, - payloads: Sequence[Payload], -) -> list[Payload]: - if len(payloads) != 1: - return await data_converter._encode_payload_sequence(payloads) - - # A single payload may be the outer envelope for a system Nexus operation. - # In that case we leave the envelope itself unencoded so the server can read - # it, but still route any nested Temporal payloads through normal payload - # processing via the generated operation-specific rewriter. - payload = payloads[0] - command_info = _command_aware_visitor.current_command_info.get() - if ( - command_info is None - or command_info.command_type - != CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION - or not command_info.nexus_service - or not command_info.nexus_operation - ): - return await data_converter._encode_payload_sequence(payloads) - - rewrite = temporalio.nexus.system.get_payload_rewriter( - command_info.nexus_service, command_info.nexus_operation - ) - if rewrite is None: - return await data_converter._encode_payload_sequence(payloads) - - new_payload = await rewrite( - payload, - data_converter._encode_payload_sequence, - False, - ) - return [new_payload] + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + if self._visit_system_nexus_envelope is not None: + await self._visit_system_nexus_envelope(payload) async def decode_activation( @@ -356,11 +323,16 @@ async def encode_completion( encode_headers: bool, ) -> None: """Encode all payloads in the completion.""" + + async def visit_system_nexus_envelope(payload: Payload) -> None: + data_converter._validate_payload_limits([payload]) + await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers ).visit( _Visitor( - lambda payloads: _encode_completion_payloads(data_converter, payloads) + data_converter._encode_payload_sequence, + visit_system_nexus_envelope=visit_system_nexus_envelope, ), completion, ) diff --git a/temporalio/worker/_command_aware_visitor.py b/temporalio/worker/_command_aware_visitor.py index 85c38ff06..10aea1422 100644 --- a/temporalio/worker/_command_aware_visitor.py +++ b/temporalio/worker/_command_aware_visitor.py @@ -31,8 +31,6 @@ class CommandInfo: command_type: CommandType.ValueType command_seq: int - nexus_service: str | None = None - nexus_operation: str | None = None current_command_info: contextvars.ContextVar[CommandInfo | None] = ( @@ -86,8 +84,6 @@ async def _visit_coresdk_workflow_commands_ScheduleNexusOperation( with current_command( CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, o.seq, - nexus_service=o.service, - nexus_operation=o.operation, ): await super()._visit_coresdk_workflow_commands_ScheduleNexusOperation(fs, o) @@ -159,17 +155,12 @@ async def _visit_coresdk_workflow_activation_ResolveNexusOperation( def current_command( command_type: CommandType.ValueType, command_seq: int, - *, - nexus_service: str | None = None, - nexus_operation: str | None = None, ) -> Iterator[None]: """Context manager for setting command info.""" token = current_command_info.set( CommandInfo( command_type=command_type, command_seq=command_seq, - nexus_service=nexus_service, - nexus_operation=nexus_operation, ) ) try: diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 5604b8542..b913cbffb 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -1,10 +1,14 @@ import dataclasses +import json from collections.abc import MutableSequence +import pytest from google.protobuf.duration_pb2 import Duration +from google.protobuf.json_format import MessageToDict import temporalio.bridge.worker import temporalio.converter +from temporalio import nexus from temporalio.api.common.v1.message_pb2 import ( Payload, Payloads, @@ -22,6 +26,7 @@ ContinueAsNewWorkflowExecution, ScheduleActivity, ScheduleLocalActivity, + ScheduleNexusOperation, SignalExternalWorkflowExecution, StartChildWorkflowExecution, UpdateResponse, @@ -31,6 +36,10 @@ Success, WorkflowActivationCompletion, ) +from temporalio.converter._payload_limits import ( + _PayloadSizeError, + _ServerPayloadErrorLimits, +) from tests.worker.test_workflow import SimpleCodec @@ -42,6 +51,9 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: for payload in payloads: payload.metadata["visited"] = b"True" + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + payload.metadata["visited"] = b"True" + async def test_workflow_activation_completion(): comp = WorkflowActivationCompletion( @@ -205,6 +217,139 @@ async def test_visit_payloads_on_other_commands(): assert ur.completed.metadata["visited"] +async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): + envelope = ( + nexus.system.generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( + namespace="default", + workflowId="workflow-id", + signalName="signal-name", + input=nexus.system.generated.Input( + payloads=[ + MessageToDict( + Payload( + metadata={"encoding": b"json/plain"}, data=b'"input-value"' + ) + ) + ] + ), + signalInput=nexus.system.generated.Input( + payloads=[ + MessageToDict( + Payload( + metadata={"encoding": b"json/plain"}, data=b'"signal-value"' + ) + ) + ] + ), + memo=nexus.system.generated.Memo( + fields={ + "memo-key": MessageToDict( + Payload( + metadata={"encoding": b"json/plain"}, data=b'"memo-value"' + ) + ) + } + ), + searchAttributes=nexus.system.generated.SearchAttributes( + indexedFields={ + "search-key": MessageToDict( + Payload( + metadata={"encoding": b"json/plain"}, data=b'"search-value"' + ) + ) + } + ), + ) + ) + comp = WorkflowActivationCompletion( + run_id="1", + successful=Success( + commands=[ + WorkflowCommand( + schedule_nexus_operation=ScheduleNexusOperation( + seq=1, + service="WorkflowService", + operation="SignalWithStartWorkflowExecution", + input=Payload( + metadata={"encoding": b"json/plain"}, + data=json.dumps( + dataclasses.asdict(envelope), + separators=(",", ":"), + sort_keys=True, + ).encode(), + ), + ) + ) + ], + ), + ) + + await PayloadVisitor(skip_search_attributes=True).visit(Visitor(), comp) + + input_payload = comp.successful.commands[0].schedule_nexus_operation.input + assert input_payload.metadata["visited"] + rewritten = json.loads(input_payload.data) + assert rewritten["input"]["payloads"][0]["metadata"]["visited"] == "VHJ1ZQ==" + assert rewritten["signalInput"]["payloads"][0]["metadata"]["visited"] == "VHJ1ZQ==" + assert rewritten["memo"]["fields"]["memo-key"]["metadata"]["visited"] == "VHJ1ZQ==" + assert ( + "visited" + not in rewritten["searchAttributes"]["indexedFields"]["search-key"]["metadata"] + ) + + +async def test_bridge_encoding_checks_system_nexus_envelope_size(): + envelope = ( + nexus.system.generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( + namespace="default", + workflowId="workflow-id", + signalName="signal-name", + requestId="x" * 2048, + input=nexus.system.generated.Input( + payloads=[ + MessageToDict( + Payload( + metadata={"encoding": b"json/plain"}, data=b'"input-value"' + ) + ) + ] + ), + ) + ) + comp = WorkflowActivationCompletion( + run_id="1", + successful=Success( + commands=[ + WorkflowCommand( + schedule_nexus_operation=ScheduleNexusOperation( + seq=1, + service="WorkflowService", + operation="SignalWithStartWorkflowExecution", + input=Payload( + metadata={"encoding": b"json/plain"}, + data=json.dumps( + dataclasses.asdict(envelope), + separators=(",", ":"), + sort_keys=True, + ).encode(), + ), + ) + ) + ], + ), + ) + + data_converter = temporalio.converter.default()._with_payload_error_limits( + _ServerPayloadErrorLimits( + memo_size_error=1024 * 1024, + payload_size_error=512, + ) + ) + + with pytest.raises(_PayloadSizeError, match="payloads with size that exceeded"): + await temporalio.bridge.worker.encode_completion(comp, data_converter, True) + + async def test_bridge_encoding(): comp = WorkflowActivationCompletion( run_id="1", From f1c943726c3d607706bf47cb73d127bb9452b7d7 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 14 Apr 2026 12:38:24 -0700 Subject: [PATCH 09/18] Rename Python temporal nexus payload visitors --- scripts/gen_payload_visitor.py | 6 +- temporalio/bridge/_visitor.py | 6 +- temporalio/nexus/system/__init__.py | 16 +- .../system/_workflow_service_generated.py | 138 +++++++++--------- 4 files changed, 81 insertions(+), 85 deletions(-) diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index a4d3f88c3..0e86e0f56 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -134,8 +134,8 @@ async def visit( async def _visit_system_nexus_payload(self, fs, service, operation, payload) -> None: import temporalio.nexus.system - rewrite = temporalio.nexus.system.get_payload_rewriter(service, operation) - if rewrite is None: + visitor = temporalio.nexus.system.get_payload_visitor(service, operation) + if visitor is None: await self._visit_temporal_api_common_v1_Payload(fs, payload) return @@ -144,7 +144,7 @@ async def payload_visitor(payloads): await fs.visit_payloads(new_payloads) return new_payloads - new_payload = await rewrite( + new_payload = await visitor( payload, payload_visitor, not self.skip_search_attributes ) if new_payload is not payload: diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index c496c53ff..1516e27d4 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -50,8 +50,8 @@ async def visit( async def _visit_system_nexus_payload(self, fs, service, operation, payload) -> None: import temporalio.nexus.system - rewrite = temporalio.nexus.system.get_payload_rewriter(service, operation) - if rewrite is None: + visitor = temporalio.nexus.system.get_payload_visitor(service, operation) + if visitor is None: await self._visit_temporal_api_common_v1_Payload(fs, payload) return @@ -60,7 +60,7 @@ async def payload_visitor(payloads): await fs.visit_payloads(new_payloads) return new_payloads - new_payload = await rewrite( + new_payload = await visitor( payload, payload_visitor, not self.skip_search_attributes ) if new_payload is not payload: diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index bb1a2f37e..bbce702d1 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -10,9 +10,9 @@ import temporalio.converter from . import _workflow_service_generated as generated -from ._workflow_service_generated import __temporal_nexus_payload_rewriters__ +from ._workflow_service_generated import __temporal_nexus_payload_visitors__ -TemporalNexusPayloadRewriter = Callable[ +TemporalNexusPayloadVisitor = Callable[ [ temporalio.api.common.v1.Payload, Callable[ @@ -27,17 +27,17 @@ _SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.default().payload_converter -def get_payload_rewriter( +def get_payload_visitor( service: str, operation: str, -) -> TemporalNexusPayloadRewriter | None: - """Return the generated nested-payload rewriter for a system Nexus operation.""" - return __temporal_nexus_payload_rewriters__.get((service, operation)) +) -> TemporalNexusPayloadVisitor | None: + """Return the generated nested-payload visitor for a system Nexus operation.""" + return __temporal_nexus_payload_visitors__.get((service, operation)) def is_system_operation(service: str, operation: str) -> bool: """Return whether a Nexus operation uses the generated system envelope.""" - return get_payload_rewriter(service, operation) is not None + return get_payload_visitor(service, operation) is not None def get_payload_converter() -> temporalio.converter.PayloadConverter: @@ -48,6 +48,6 @@ def get_payload_converter() -> temporalio.converter.PayloadConverter: __all__ = ( "generated", "get_payload_converter", - "get_payload_rewriter", + "get_payload_visitor", "is_system_operation", ) diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index 4ac22c76e..684132b96 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -701,7 +701,7 @@ class WorkflowService: ] = Operation(name="SignalWithStartWorkflowExecution") -class _TemporalNexusPayloadRewriter: +class _TemporalNexusPayloadVisitor: def __init__( self, payload_visitor: collections.abc.Callable[ @@ -713,105 +713,101 @@ def __init__( self._payload_visitor = payload_visitor self._visit_search_attributes = visit_search_attributes - async def _rewrite_payload_json(self, value: dict) -> dict: + async def _visit_payload_json(self, value: dict) -> dict: payload = ParseDict(value, temporalio.api.common.v1.Payload()) - [rewritten_payload] = await self._payload_visitor([payload]) - return MessageToDict(rewritten_payload) + [visited_payload] = await self._payload_visitor([payload]) + return MessageToDict(visited_payload) - async def _rewrite_payloads_json(self, value: dict) -> dict: + async def _visit_payloads_json(self, value: dict) -> dict: payloads = ParseDict(value, temporalio.api.common.v1.Payloads()) - rewritten_payloads = await self._payload_visitor(payloads.payloads) + visited_payloads = await self._payload_visitor(payloads.payloads) del payloads.payloads[:] - payloads.payloads.extend(rewritten_payloads) + payloads.payloads.extend(visited_payloads) return MessageToDict(payloads) - async def _rewrite_payload_map_json(self, message_type: type, value: dict) -> dict: + async def _visit_payload_map_json(self, message_type: type, value: dict) -> dict: message = message_type() keys = list(value.keys()) - rewritten_payloads = await self._payload_visitor( + visited_payloads = await self._payload_visitor( [ParseDict(value[key], temporalio.api.common.v1.Payload()) for key in keys] ) - for key, rewritten_payload in zip(keys, rewritten_payloads): - message.fields[key].CopyFrom(rewritten_payload) + for key, visited_payload in zip(keys, visited_payloads): + message.fields[key].CopyFrom(visited_payload) return MessageToDict(message).get("fields", {}) - async def _temporal_nexus_rewrite_header_json(self, value: dict) -> dict: - rewritten = dict(value) - if rewritten.get("fields") is not None: - rewritten["fields"] = await self._rewrite_payload_map_json( - temporalio.api.common.v1.Header, rewritten["fields"] + async def _temporal_nexus_visit_header_json(self, value: dict) -> dict: + visited = dict(value) + if visited.get("fields") is not None: + visited["fields"] = await self._visit_payload_map_json( + temporalio.api.common.v1.Header, visited["fields"] ) - return rewritten + return visited - async def _temporal_nexus_rewrite_input_json(self, value: dict) -> dict: - return await self._rewrite_payloads_json(value) + async def _temporal_nexus_visit_input_json(self, value: dict) -> dict: + return await self._visit_payloads_json(value) - async def _temporal_nexus_rewrite_memo_json(self, value: dict) -> dict: - rewritten = dict(value) - if rewritten.get("fields") is not None: - rewritten["fields"] = await self._rewrite_payload_map_json( - temporalio.api.common.v1.Memo, rewritten["fields"] + async def _temporal_nexus_visit_memo_json(self, value: dict) -> dict: + visited = dict(value) + if visited.get("fields") is not None: + visited["fields"] = await self._visit_payload_map_json( + temporalio.api.common.v1.Memo, visited["fields"] ) - return rewritten + return visited - async def _temporal_nexus_rewrite_search_attributes_json(self, value: dict) -> dict: + async def _temporal_nexus_visit_search_attributes_json(self, value: dict) -> dict: if not self._visit_search_attributes: return value - rewritten = dict(value) - if rewritten.get("indexedFields") is not None: - rewritten["indexedFields"] = await self._rewrite_payload_map_json( - temporalio.api.common.v1.SearchAttributes, rewritten["indexedFields"] + visited = dict(value) + if visited.get("indexedFields") is not None: + visited["indexedFields"] = await self._visit_payload_map_json( + temporalio.api.common.v1.SearchAttributes, visited["indexedFields"] ) - return rewritten + return visited - async def _temporal_nexus_rewrite_user_metadata_json(self, value: dict) -> dict: - rewritten = dict(value) - if rewritten.get("details") is not None: - rewritten["details"] = await self._rewrite_payload_json( - rewritten["details"] - ) - if rewritten.get("summary") is not None: - rewritten["summary"] = await self._rewrite_payload_json( - rewritten["summary"] - ) - return rewritten + async def _temporal_nexus_visit_user_metadata_json(self, value: dict) -> dict: + visited = dict(value) + if visited.get("details") is not None: + visited["details"] = await self._visit_payload_json(visited["details"]) + if visited.get("summary") is not None: + visited["summary"] = await self._visit_payload_json(visited["summary"]) + return visited - async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input_json( + async def _temporal_nexus_visit_workflow_service_signal_with_start_workflow_execution_input_json( self, value: dict ) -> dict: - rewritten = dict(value) - if rewritten.get("header") is not None: - rewritten["header"] = await self._temporal_nexus_rewrite_header_json( - rewritten["header"] + visited = dict(value) + if visited.get("header") is not None: + visited["header"] = await self._temporal_nexus_visit_header_json( + visited["header"] ) - if rewritten.get("input") is not None: - rewritten["input"] = await self._temporal_nexus_rewrite_input_json( - rewritten["input"] + if visited.get("input") is not None: + visited["input"] = await self._temporal_nexus_visit_input_json( + visited["input"] ) - if rewritten.get("memo") is not None: - rewritten["memo"] = await self._temporal_nexus_rewrite_memo_json( - rewritten["memo"] + if visited.get("memo") is not None: + visited["memo"] = await self._temporal_nexus_visit_memo_json( + visited["memo"] ) - if rewritten.get("searchAttributes") is not None: - rewritten[ + if visited.get("searchAttributes") is not None: + visited[ "searchAttributes" - ] = await self._temporal_nexus_rewrite_search_attributes_json( - rewritten["searchAttributes"] + ] = await self._temporal_nexus_visit_search_attributes_json( + visited["searchAttributes"] ) - if rewritten.get("signalInput") is not None: - rewritten["signalInput"] = await self._temporal_nexus_rewrite_input_json( - rewritten["signalInput"] + if visited.get("signalInput") is not None: + visited["signalInput"] = await self._temporal_nexus_visit_input_json( + visited["signalInput"] ) - if rewritten.get("userMetadata") is not None: - rewritten[ + if visited.get("userMetadata") is not None: + visited[ "userMetadata" - ] = await self._temporal_nexus_rewrite_user_metadata_json( - rewritten["userMetadata"] + ] = await self._temporal_nexus_visit_user_metadata_json( + visited["userMetadata"] ) - return rewritten + return visited -async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input( +async def _temporal_nexus_visit_workflow_service_signal_with_start_workflow_execution_input( payload: temporalio.api.common.v1.Payload, payload_visitor: collections.abc.Callable[ [collections.abc.Sequence[temporalio.api.common.v1.Payload]], @@ -825,19 +821,19 @@ async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_ex return payload if not isinstance(value, dict): return payload - rewriter = _TemporalNexusPayloadRewriter(payload_visitor, visit_search_attributes) - rewritten = await rewriter._temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input_json( + visitor = _TemporalNexusPayloadVisitor(payload_visitor, visit_search_attributes) + visited = await visitor._temporal_nexus_visit_workflow_service_signal_with_start_workflow_execution_input_json( value ) return temporalio.api.common.v1.Payload( metadata=dict(payload.metadata), - data=json.dumps(rewritten, separators=(",", ":"), sort_keys=True).encode(), + data=json.dumps(visited, separators=(",", ":"), sort_keys=True).encode(), ) -__temporal_nexus_payload_rewriters__ = { +__temporal_nexus_payload_visitors__ = { ( "WorkflowService", "SignalWithStartWorkflowExecution", - ): _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input, + ): _temporal_nexus_visit_workflow_service_signal_with_start_workflow_execution_input, } From 2999e9d58cf3c90566741c2f6a4f89849e24eabe Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 15 Apr 2026 13:56:42 -0700 Subject: [PATCH 10/18] Work on user API --- temporalio/bridge/_visitor.py | 198 +++++++++++----- .../contrib/opentelemetry/_interceptor.py | 10 + .../opentelemetry/_otel_interceptor.py | 10 + temporalio/nexus/system/__init__.py | 212 +++++++++++++++++- .../system/_workflow_service_generated.py | 26 +-- temporalio/worker/__init__.py | 2 + temporalio/worker/_interceptor.py | 39 ++++ temporalio/worker/_workflow_instance.py | 122 ++++++++++ temporalio/workflow.py | 71 ++++++ tests/nexus/test_temporal_system_nexus.py | 146 +++++++++++- tests/worker/test_visitor.py | 16 +- 11 files changed, 766 insertions(+), 86 deletions(-) diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index 1516e27d4..3a2d98818 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,4 +1,3 @@ - # This file is generated by gen_payload_visitor.py. Changes should be made there. import abc from typing import Any, MutableSequence @@ -7,9 +6,10 @@ class VisitorFunctions(abc.ABC): - """Set of functions which can be called by the visitor. + """Set of functions which can be called by the visitor. Allows handling payloads as a sequence. """ + @abc.abstractmethod async def visit_payload(self, payload: Payload) -> None: """Called when encountering a single payload.""" @@ -25,10 +25,12 @@ async def visit_system_nexus_envelope(self, payload: Payload) -> None: """Called when encountering a recognized system Nexus envelope payload.""" raise NotImplementedError() + class PayloadVisitor: - """A visitor for payloads. + """A visitor for payloads. Applies a function to every payload in a tree of messages. """ + def __init__( self, *, skip_search_attributes: bool = False, skip_headers: bool = False ): @@ -36,9 +38,7 @@ def __init__( self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers - async def visit( - self, fs: VisitorFunctions, root: Any - ) -> None: + async def visit(self, fs: VisitorFunctions, root: Any) -> None: """Visits the given root message with the given function.""" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) @@ -47,7 +47,9 @@ async def visit( else: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") - async def _visit_system_nexus_payload(self, fs, service, operation, payload) -> None: + async def _visit_system_nexus_payload( + self, fs, service, operation, payload + ) -> None: import temporalio.nexus.system visitor = temporalio.nexus.system.get_payload_visitor(service, operation) @@ -69,20 +71,22 @@ async def payload_visitor(payloads): async def _visit_temporal_api_common_v1_Payload(self, fs, o): await fs.visit_payload(o) - + async def _visit_temporal_api_common_v1_Payloads(self, fs, o): await fs.visit_payloads(o.payloads) - + async def _visit_payload_container(self, fs, o): await fs.visit_payloads(o) - + async def _visit_temporal_api_failure_v1_ApplicationFailureInfo(self, fs, o): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) async def _visit_temporal_api_failure_v1_TimeoutFailureInfo(self, fs, o): if o.HasField("last_heartbeat_details"): - await self._visit_temporal_api_common_v1_Payloads(fs, o.last_heartbeat_details) + await self._visit_temporal_api_common_v1_Payloads( + fs, o.last_heartbeat_details + ) async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, fs, o): if o.HasField("details"): @@ -90,7 +94,9 @@ async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, fs, o): async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, fs, o): if o.HasField("last_heartbeat_details"): - await self._visit_temporal_api_common_v1_Payloads(fs, o.last_heartbeat_details) + await self._visit_temporal_api_common_v1_Payloads( + fs, o.last_heartbeat_details + ) async def _visit_temporal_api_failure_v1_Failure(self, fs, o): if o.HasField("encoded_attributes"): @@ -98,13 +104,21 @@ async def _visit_temporal_api_failure_v1_Failure(self, fs, o): if o.HasField("cause"): await self._visit_temporal_api_failure_v1_Failure(fs, o.cause) if o.HasField("application_failure_info"): - await self._visit_temporal_api_failure_v1_ApplicationFailureInfo(fs, o.application_failure_info) + await self._visit_temporal_api_failure_v1_ApplicationFailureInfo( + fs, o.application_failure_info + ) elif o.HasField("timeout_failure_info"): - await self._visit_temporal_api_failure_v1_TimeoutFailureInfo(fs, o.timeout_failure_info) + await self._visit_temporal_api_failure_v1_TimeoutFailureInfo( + fs, o.timeout_failure_info + ) elif o.HasField("canceled_failure_info"): - await self._visit_temporal_api_failure_v1_CanceledFailureInfo(fs, o.canceled_failure_info) + await self._visit_temporal_api_failure_v1_CanceledFailureInfo( + fs, o.canceled_failure_info + ) elif o.HasField("reset_workflow_failure_info"): - await self._visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(fs, o.reset_workflow_failure_info) + await self._visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( + fs, o.reset_workflow_failure_info + ) async def _visit_temporal_api_common_v1_Memo(self, fs, o): for v in o.fields.values(): @@ -124,11 +138,15 @@ async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): if o.HasField("continued_failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.continued_failure) if o.HasField("last_completion_result"): - await self._visit_temporal_api_common_v1_Payloads(fs, o.last_completion_result) + await self._visit_temporal_api_common_v1_Payloads( + fs, o.last_completion_result + ) if o.HasField("memo"): await self._visit_temporal_api_common_v1_Memo(fs, o.memo) if o.HasField("search_attributes"): - await self._visit_temporal_api_common_v1_SearchAttributes(fs, o.search_attributes) + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes + ) async def _visit_coresdk_workflow_activation_QueryWorkflow(self, fs, o): await self._visit_payload_container(fs, o.arguments) @@ -166,13 +184,19 @@ async def _visit_coresdk_workflow_activation_ResolveActivity(self, fs, o): if o.HasField("result"): await self._visit_coresdk_activity_result_ActivityResolution(fs, o.result) - async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( + self, fs, o + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( + self, fs, o + ): if o.HasField("cancelled"): - await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled(fs, o.cancelled) + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( + fs, o.cancelled + ) async def _visit_coresdk_child_workflow_Success(self, fs, o): if o.HasField("result"): @@ -194,15 +218,21 @@ async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): elif o.HasField("cancelled"): await self._visit_coresdk_child_workflow_Cancellation(fs, o.cancelled) - async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( + self, fs, o + ): if o.HasField("result"): await self._visit_coresdk_child_workflow_ChildWorkflowResult(fs, o.result) - async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( + self, fs, o + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( + self, fs, o + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) @@ -212,7 +242,9 @@ async def _visit_coresdk_workflow_activation_DoUpdate(self, fs, o): for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( + self, fs, o + ): if o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) @@ -232,27 +264,47 @@ async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, fs, o): async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): if o.HasField("initialize_workflow"): - await self._visit_coresdk_workflow_activation_InitializeWorkflow(fs, o.initialize_workflow) + await self._visit_coresdk_workflow_activation_InitializeWorkflow( + fs, o.initialize_workflow + ) elif o.HasField("query_workflow"): - await self._visit_coresdk_workflow_activation_QueryWorkflow(fs, o.query_workflow) + await self._visit_coresdk_workflow_activation_QueryWorkflow( + fs, o.query_workflow + ) elif o.HasField("signal_workflow"): - await self._visit_coresdk_workflow_activation_SignalWorkflow(fs, o.signal_workflow) + await self._visit_coresdk_workflow_activation_SignalWorkflow( + fs, o.signal_workflow + ) elif o.HasField("resolve_activity"): - await self._visit_coresdk_workflow_activation_ResolveActivity(fs, o.resolve_activity) + await self._visit_coresdk_workflow_activation_ResolveActivity( + fs, o.resolve_activity + ) elif o.HasField("resolve_child_workflow_execution_start"): - await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart(fs, o.resolve_child_workflow_execution_start) + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( + fs, o.resolve_child_workflow_execution_start + ) elif o.HasField("resolve_child_workflow_execution"): - await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution(fs, o.resolve_child_workflow_execution) + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( + fs, o.resolve_child_workflow_execution + ) elif o.HasField("resolve_signal_external_workflow"): - await self._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow(fs, o.resolve_signal_external_workflow) + await self._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( + fs, o.resolve_signal_external_workflow + ) elif o.HasField("resolve_request_cancel_external_workflow"): - await self._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow(fs, o.resolve_request_cancel_external_workflow) + await self._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( + fs, o.resolve_request_cancel_external_workflow + ) elif o.HasField("do_update"): await self._visit_coresdk_workflow_activation_DoUpdate(fs, o.do_update) elif o.HasField("resolve_nexus_operation_start"): - await self._visit_coresdk_workflow_activation_ResolveNexusOperationStart(fs, o.resolve_nexus_operation_start) + await self._visit_coresdk_workflow_activation_ResolveNexusOperationStart( + fs, o.resolve_nexus_operation_start + ) elif o.HasField("resolve_nexus_operation"): - await self._visit_coresdk_workflow_activation_ResolveNexusOperation(fs, o.resolve_nexus_operation) + await self._visit_coresdk_workflow_activation_ResolveNexusOperation( + fs, o.resolve_nexus_operation + ) async def _visit_coresdk_workflow_activation_WorkflowActivation(self, fs, o): for v in o.jobs: @@ -288,7 +340,9 @@ async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, fs, o): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( + self, fs, o + ): await self._visit_payload_container(fs, o.arguments) for v in o.memo.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) @@ -296,7 +350,9 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(self, for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) if o.HasField("search_attributes"): - await self._visit_temporal_api_common_v1_SearchAttributes(fs, o.search_attributes) + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes + ) async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o): await self._visit_payload_container(fs, o.input) @@ -306,9 +362,13 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, for v in o.memo.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) if o.HasField("search_attributes"): - await self._visit_temporal_api_common_v1_SearchAttributes(fs, o.search_attributes) + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes + ) - async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( + self, fs, o + ): await self._visit_payload_container(fs, o.args) if not self.skip_headers: for v in o.headers.values(): @@ -320,9 +380,13 @@ async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o): await self._visit_temporal_api_common_v1_Payload(fs, v) await self._visit_payload_container(fs, o.arguments) - async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(self, fs, o): + async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( + self, fs, o + ): if o.HasField("search_attributes"): - await self._visit_temporal_api_common_v1_SearchAttributes(fs, o.search_attributes) + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes + ) async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o): if o.HasField("upserted_memo"): @@ -342,29 +406,53 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): if o.HasField("user_metadata"): await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) if o.HasField("schedule_activity"): - await self._visit_coresdk_workflow_commands_ScheduleActivity(fs, o.schedule_activity) + await self._visit_coresdk_workflow_commands_ScheduleActivity( + fs, o.schedule_activity + ) elif o.HasField("respond_to_query"): - await self._visit_coresdk_workflow_commands_QueryResult(fs, o.respond_to_query) + await self._visit_coresdk_workflow_commands_QueryResult( + fs, o.respond_to_query + ) elif o.HasField("complete_workflow_execution"): - await self._visit_coresdk_workflow_commands_CompleteWorkflowExecution(fs, o.complete_workflow_execution) + await self._visit_coresdk_workflow_commands_CompleteWorkflowExecution( + fs, o.complete_workflow_execution + ) elif o.HasField("fail_workflow_execution"): - await self._visit_coresdk_workflow_commands_FailWorkflowExecution(fs, o.fail_workflow_execution) + await self._visit_coresdk_workflow_commands_FailWorkflowExecution( + fs, o.fail_workflow_execution + ) elif o.HasField("continue_as_new_workflow_execution"): - await self._visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(fs, o.continue_as_new_workflow_execution) + await self._visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( + fs, o.continue_as_new_workflow_execution + ) elif o.HasField("start_child_workflow_execution"): - await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution(fs, o.start_child_workflow_execution) + await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution( + fs, o.start_child_workflow_execution + ) elif o.HasField("signal_external_workflow_execution"): - await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(fs, o.signal_external_workflow_execution) + await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( + fs, o.signal_external_workflow_execution + ) elif o.HasField("schedule_local_activity"): - await self._visit_coresdk_workflow_commands_ScheduleLocalActivity(fs, o.schedule_local_activity) + await self._visit_coresdk_workflow_commands_ScheduleLocalActivity( + fs, o.schedule_local_activity + ) elif o.HasField("upsert_workflow_search_attributes"): - await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(fs, o.upsert_workflow_search_attributes) + await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( + fs, o.upsert_workflow_search_attributes + ) elif o.HasField("modify_workflow_properties"): - await self._visit_coresdk_workflow_commands_ModifyWorkflowProperties(fs, o.modify_workflow_properties) + await self._visit_coresdk_workflow_commands_ModifyWorkflowProperties( + fs, o.modify_workflow_properties + ) elif o.HasField("update_response"): - await self._visit_coresdk_workflow_commands_UpdateResponse(fs, o.update_response) + await self._visit_coresdk_workflow_commands_UpdateResponse( + fs, o.update_response + ) elif o.HasField("schedule_nexus_operation"): - await self._visit_coresdk_workflow_commands_ScheduleNexusOperation(fs, o.schedule_nexus_operation) + await self._visit_coresdk_workflow_commands_ScheduleNexusOperation( + fs, o.schedule_nexus_operation + ) async def _visit_coresdk_workflow_completion_Success(self, fs, o): for v in o.commands: @@ -374,7 +462,9 @@ async def _visit_coresdk_workflow_completion_Failure(self, fs, o): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion(self, fs, o): + async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( + self, fs, o + ): if o.HasField("successful"): await self._visit_coresdk_workflow_completion_Success(fs, o.successful) elif o.HasField("failed"): diff --git a/temporalio/contrib/opentelemetry/_interceptor.py b/temporalio/contrib/opentelemetry/_interceptor.py index 69a2cfb0c..e09b355df 100644 --- a/temporalio/contrib/opentelemetry/_interceptor.py +++ b/temporalio/contrib/opentelemetry/_interceptor.py @@ -772,6 +772,16 @@ async def signal_external_workflow( ) await super().signal_external_workflow(input) + async def signal_with_start_external_workflow( + self, input: temporalio.worker.SignalWithStartExternalWorkflowInput + ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: + self.root._completed_span( + f"SignalWithStartWorkflow:{input.signal}", + add_to_outbound_str=input, + kind=opentelemetry.trace.SpanKind.CLIENT, + ) + return await super().signal_with_start_external_workflow(input) + def start_activity( self, input: temporalio.worker.StartActivityInput ) -> temporalio.workflow.ActivityHandle: diff --git a/temporalio/contrib/opentelemetry/_otel_interceptor.py b/temporalio/contrib/opentelemetry/_otel_interceptor.py index 1756f93e1..34e61eb65 100644 --- a/temporalio/contrib/opentelemetry/_otel_interceptor.py +++ b/temporalio/contrib/opentelemetry/_otel_interceptor.py @@ -545,6 +545,16 @@ async def signal_external_workflow( input.headers = _context_to_headers(input.headers) await super().signal_external_workflow(input) + async def signal_with_start_external_workflow( + self, input: temporalio.worker.SignalWithStartExternalWorkflowInput + ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: + with self._workflow_maybe_span( + f"SignalWithStartWorkflow:{input.signal}", + kind=opentelemetry.trace.SpanKind.CLIENT, + ): + input.headers = _context_to_nexus_headers(input.headers or {}) + return await super().signal_with_start_external_workflow(input) + def start_activity( self, input: temporalio.worker.StartActivityInput ) -> temporalio.workflow.ActivityHandle: diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index bbce702d1..b97fc4847 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -4,9 +4,15 @@ Higher-level ergonomic APIs may wrap these generated types. """ -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence +from datetime import timedelta +from typing import Any, cast + +from google.protobuf.json_format import MessageToDict import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.common import temporalio.converter from . import _workflow_service_generated as generated @@ -27,6 +33,209 @@ _SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.default().payload_converter +def _payload_to_json_value( + converter: temporalio.converter.PayloadConverter, value: Any +) -> dict[str, Any]: + return MessageToDict(converter.to_payload(value)) + + +def _payloads_to_input( + converter: temporalio.converter.PayloadConverter, values: Sequence[Any] +) -> generated.Input | None: + payloads = converter.to_payloads(values) if values else [] + if not payloads: + return None + return generated.Input(payloads=[MessageToDict(payload) for payload in payloads]) + + +def _search_attributes_to_json_map( + attributes: temporalio.common.TypedSearchAttributes, +) -> dict[str, Any]: + return { + pair.key.name: MessageToDict( + temporalio.converter.encode_typed_search_attribute_value( + pair.key, pair.value + ) + ) + for pair in attributes + } + + +def _retry_policy_to_generated( + retry_policy: temporalio.common.RetryPolicy, +) -> generated.RetryPolicy: + retry_policy._validate() + return generated.RetryPolicy( + initialInterval=f"{retry_policy.initial_interval.total_seconds()}s", + backoffCoefficient=retry_policy.backoff_coefficient, + maximumInterval=f"{(retry_policy.maximum_interval or retry_policy.initial_interval * 100).total_seconds()}s", + maximumAttempts=retry_policy.maximum_attempts, + nonRetryableErrorTypes=( + list(retry_policy.non_retryable_error_types) + if retry_policy.non_retryable_error_types + else None + ), + ) + + +def _priority_to_generated( + priority: temporalio.common.Priority, +) -> generated.Priority | None: + if ( + priority.priority_key is None + and priority.fairness_key is None + and priority.fairness_weight is None + ): + return None + return generated.Priority( + priorityKey=priority.priority_key, + fairnessKey=priority.fairness_key, + fairnessWeight=priority.fairness_weight, + ) + + +def _workflow_id_reuse_policy_to_generated( + policy: temporalio.common.WorkflowIDReusePolicy, +) -> generated.WorkflowIDReusePolicy: + return generated.WorkflowIDReusePolicy( + temporalio.api.enums.v1.WorkflowIdReusePolicy.Name( + cast("temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", int(policy)) + ) + ) + + +def _workflow_id_conflict_policy_to_generated( + policy: temporalio.common.WorkflowIDConflictPolicy, +) -> generated.WorkflowIDConflictPolicy: + return generated.WorkflowIDConflictPolicy( + temporalio.api.enums.v1.WorkflowIdConflictPolicy.Name( + cast( + "temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType", + int(policy), + ) + ) + ) + + +def _versioning_override_to_generated( + versioning_override: temporalio.common.VersioningOverride, +) -> generated.VersioningOverride: + if isinstance(versioning_override, temporalio.common.AutoUpgradeVersioningOverride): + return generated.VersioningOverride( + autoUpgrade=True, + behavior=generated.VersioningOverrideBehavior.VERSIONING_BEHAVIOR_AUTO_UPGRADE, + ) + if isinstance(versioning_override, temporalio.common.PinnedVersioningOverride): + return generated.VersioningOverride( + behavior=generated.VersioningOverrideBehavior.VERSIONING_BEHAVIOR_PINNED, + pinnedVersion=versioning_override.version.to_canonical_string(), + pinned=generated.Pinned( + behavior=generated.PinnedBehavior.PINNED_OVERRIDE_BEHAVIOR_PINNED, + version=generated.Version( + deploymentName=versioning_override.version.deployment_name, + buildId=versioning_override.version.build_id, + ), + ), + deployment=generated.Deployment( + seriesName=versioning_override.version.deployment_name, + buildId=versioning_override.version.build_id, + ), + ) + raise TypeError( + f"Unsupported versioning override type: {type(versioning_override)!r}" + ) + + +def build_signal_with_start_workflow_execution_input( + *, + namespace: str, + workflow_id: str, + workflow: str, + workflow_args: Sequence[Any], + signal: str, + signal_args: Sequence[Any], + task_queue: str, + request_id: str, + payload_converter: temporalio.converter.PayloadConverter, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, +) -> generated.WorkflowServiceSignalWithStartWorkflowExecutionInput: + """Build the generated system Nexus input for signal-with-start.""" + return generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( + namespace=namespace, + workflowId=workflow_id, + workflowType=generated.WorkflowType(name=workflow), + taskQueue=generated.TaskQueue(name=task_queue), + input=_payloads_to_input(payload_converter, workflow_args), + workflowExecutionTimeout=( + f"{execution_timeout.total_seconds()}s" if execution_timeout else None + ), + workflowRunTimeout=f"{run_timeout.total_seconds()}s" if run_timeout else None, + workflowTaskTimeout=( + f"{task_timeout.total_seconds()}s" if task_timeout else None + ), + requestId=request_id, + workflowIdReusePolicy=_workflow_id_reuse_policy_to_generated(id_reuse_policy), + workflowIdConflictPolicy=_workflow_id_conflict_policy_to_generated( + id_conflict_policy + ), + retryPolicy=( + _retry_policy_to_generated(retry_policy) if retry_policy else None + ), + cronSchedule=cron_schedule, + memo=( + generated.Memo( + fields={ + key: _payload_to_json_value(payload_converter, value) + for key, value in memo.items() + } + ) + if memo + else None + ), + searchAttributes=( + generated.SearchAttributes( + indexedFields=_search_attributes_to_json_map(search_attributes) + ) + if search_attributes + else None + ), + signalName=signal, + signalInput=_payloads_to_input(payload_converter, signal_args), + userMetadata=( + generated.UserMetadata( + summary=_payload_to_json_value(payload_converter, static_summary) + if static_summary is not None + else None, + details=_payload_to_json_value(payload_converter, static_details) + if static_details is not None + else None, + ) + if static_summary is not None or static_details is not None + else None + ), + workflowStartDelay=(f"{start_delay.total_seconds()}s" if start_delay else None), + priority=_priority_to_generated(priority), + versioningOverride=( + _versioning_override_to_generated(versioning_override) + if versioning_override + else None + ), + ) + + def get_payload_visitor( service: str, operation: str, @@ -46,6 +255,7 @@ def get_payload_converter() -> temporalio.converter.PayloadConverter: __all__ = ( + "build_signal_with_start_workflow_execution_input", "generated", "get_payload_converter", "get_payload_visitor", diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index 684132b96..585923e9a 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -49,7 +49,7 @@ class BatchJob: jobId: Optional[str] = None -class EventType(Enum): +class EventType(str, Enum): EVENT_TYPE_ACTIVITY_PROPERTIES_MODIFIED_EXTERNALLY = ( "EVENT_TYPE_ACTIVITY_PROPERTIES_MODIFIED_EXTERNALLY" ) @@ -175,14 +175,14 @@ class EventRef: """EventReference is a direct reference to a history event through the event ID.""" eventId: Optional[str] = None - eventType: Optional[EventType] = None + eventType: Optional["EventType"] = None @dataclass class RequestIDRef: """RequestIdReference is a indirect reference to a history event through the request ID.""" - eventType: Optional[EventType] = None + eventType: Optional["EventType"] = None requestId: Optional[str] = None @@ -347,7 +347,7 @@ class SearchAttributes: indexedFields: Optional[Dict[str, Any]] = None -class Kind(Enum): +class Kind(str, Enum): """Default: TASK_QUEUE_KIND_NORMAL.""" TASK_QUEUE_KIND_NORMAL = "TASK_QUEUE_KIND_NORMAL" @@ -362,7 +362,7 @@ class TaskQueue: See https://docs.temporal.io/docs/concepts/task-queues/ """ - kind: Optional[Kind] = None + kind: Optional["Kind"] = None """Default: TASK_QUEUE_KIND_NORMAL.""" name: Optional[str] = None @@ -400,7 +400,7 @@ class UserMetadata: """ -class VersioningOverrideBehavior(Enum): +class VersioningOverrideBehavior(str, Enum): """Required. Deprecated. Use `override`. """ @@ -440,7 +440,7 @@ class Deployment: """ -class PinnedBehavior(Enum): +class PinnedBehavior(str, Enum): """Defaults to PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED. See `PinnedOverrideBehavior` for details. """ @@ -483,7 +483,7 @@ class Version: class Pinned: """Override the workflow to have Pinned behavior.""" - behavior: Optional[PinnedBehavior] = None + behavior: Optional["PinnedBehavior"] = None """Defaults to PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED. See `PinnedOverrideBehavior` for details. """ @@ -522,7 +522,7 @@ class VersioningOverride: autoUpgrade: Optional[bool] = None """Override the workflow to have AutoUpgrade behavior.""" - behavior: Optional[VersioningOverrideBehavior] = None + behavior: Optional["VersioningOverrideBehavior"] = None """Required. Deprecated. Use `override`. """ @@ -542,7 +542,7 @@ class VersioningOverride: """ -class WorkflowIDConflictPolicy(Enum): +class WorkflowIDConflictPolicy(str, Enum): """Defines how to resolve a workflow id conflict with a *running* workflow. The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. @@ -561,7 +561,7 @@ class WorkflowIDConflictPolicy(Enum): ) -class WorkflowIDReusePolicy(Enum): +class WorkflowIDReusePolicy(str, Enum): """Defines whether to allow re-using the workflow id from a previously *closed* workflow. The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. @@ -650,7 +650,7 @@ class WorkflowServiceSignalWithStartWorkflowExecutionInput: """Total workflow execution timeout including retries and continue as new""" workflowId: Optional[str] = None - workflowIdConflictPolicy: Optional[WorkflowIDConflictPolicy] = None + workflowIdConflictPolicy: Optional["WorkflowIDConflictPolicy"] = None """Defines how to resolve a workflow id conflict with a *running* workflow. The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. @@ -658,7 +658,7 @@ class WorkflowServiceSignalWithStartWorkflowExecutionInput: See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* workflow. """ - workflowIdReusePolicy: Optional[WorkflowIDReusePolicy] = None + workflowIdReusePolicy: Optional["WorkflowIDReusePolicy"] = None """Defines whether to allow re-using the workflow id from a previously *closed* workflow. The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. diff --git a/temporalio/worker/__init__.py b/temporalio/worker/__init__.py index 55966b35d..cd8feb720 100644 --- a/temporalio/worker/__init__.py +++ b/temporalio/worker/__init__.py @@ -17,6 +17,7 @@ NexusOperationInboundInterceptor, SignalChildWorkflowInput, SignalExternalWorkflowInput, + SignalWithStartExternalWorkflowInput, StartActivityInput, StartChildWorkflowInput, StartLocalActivityInput, @@ -94,6 +95,7 @@ "HandleUpdateInput", "SignalChildWorkflowInput", "SignalExternalWorkflowInput", + "SignalWithStartExternalWorkflowInput", "StartActivityInput", "StartChildWorkflowInput", "StartLocalActivityInput", diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index f0d616f2c..371c5e054 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -241,6 +241,36 @@ class SignalExternalWorkflowInput: headers: Mapping[str, temporalio.api.common.v1.Payload] +@dataclass +class SignalWithStartExternalWorkflowInput: + """Input for + :py:meth:`WorkflowOutboundInterceptor.signal_with_start_external_workflow`. + """ + + signal: str + signal_args: Sequence[Any] + namespace: str + workflow_id: str + workflow: str + workflow_args: Sequence[Any] + task_queue: str + execution_timeout: timedelta | None + run_timeout: timedelta | None + task_timeout: timedelta | None + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy + retry_policy: temporalio.common.RetryPolicy | None + cron_schedule: str + memo: Mapping[str, Any] | None + search_attributes: temporalio.common.TypedSearchAttributes | None + static_summary: str | None + static_details: str | None + start_delay: timedelta | None + priority: temporalio.common.Priority + versioning_override: temporalio.common.VersioningOverride | None + headers: Mapping[str, str] | None + + @dataclass class StartActivityInput: """Input for :py:meth:`WorkflowOutboundInterceptor.start_activity`.""" @@ -450,6 +480,15 @@ async def signal_external_workflow( """ return await self.next.signal_external_workflow(input) + async def signal_with_start_external_workflow( + self, input: SignalWithStartExternalWorkflowInput + ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: + """Called for every + :py:meth:`temporalio.workflow.ExternalWorkflowHandle.signal_with_start_workflow` + call. + """ + return await self.next.signal_with_start_external_workflow(input) + def start_activity( self, input: StartActivityInput ) -> temporalio.workflow.ActivityHandle[Any]: diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 3454eb2ad..5aa493d01 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -12,6 +12,7 @@ import sys import threading import traceback +import uuid import warnings from abc import ABC, abstractmethod from collections import deque @@ -71,6 +72,7 @@ HandleUpdateInput, SignalChildWorkflowInput, SignalExternalWorkflowInput, + SignalWithStartExternalWorkflowInput, StartActivityInput, StartChildWorkflowInput, StartLocalActivityInput, @@ -1952,6 +1954,61 @@ async def _outbound_signal_external_workflow( temporalio.common._apply_headers(input.headers, v.headers) await self._signal_external_workflow(command) + async def _outbound_signal_with_start_external_workflow( + self, input: SignalWithStartExternalWorkflowInput + ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: + payload_converter = self._payload_converter_with_context( + temporalio.converter.WorkflowSerializationContext( + namespace=input.namespace, + workflow_id=input.workflow_id, + ) + ) + request = ( + temporalio.nexus.system.build_signal_with_start_workflow_execution_input( + namespace=input.namespace, + workflow_id=input.workflow_id, + workflow=input.workflow, + workflow_args=input.workflow_args, + signal=input.signal, + signal_args=input.signal_args, + task_queue=input.task_queue, + request_id=str(uuid.uuid4()), + payload_converter=payload_converter, + execution_timeout=input.execution_timeout, + run_timeout=input.run_timeout, + task_timeout=input.task_timeout, + id_reuse_policy=input.id_reuse_policy, + id_conflict_policy=input.id_conflict_policy, + retry_policy=input.retry_policy, + cron_schedule=input.cron_schedule, + memo=input.memo, + search_attributes=input.search_attributes, + static_summary=input.static_summary, + static_details=input.static_details, + start_delay=input.start_delay, + priority=input.priority, + versioning_override=input.versioning_override, + ) + ) + handle = await self._outbound_start_nexus_operation( + StartNexusOperationInput( + endpoint=temporalio.workflow._SYSTEM_NEXUS_ENDPOINT, + service=temporalio.nexus.system.generated.WorkflowService.__name__, + operation=temporalio.nexus.system.generated.WorkflowService.signal_with_start_workflow_execution, + input=request, + schedule_to_close_timeout=None, + schedule_to_start_timeout=None, + start_to_close_timeout=None, + cancellation_type=temporalio.workflow.NexusOperationCancellationType.WAIT_COMPLETED, + headers=input.headers, + summary=None, + ) + ) + result = await handle + return self.workflow_get_external_workflow_handle( + input.workflow_id, run_id=result.runId + ) + async def _outbound_start_child_workflow( self, input: StartChildWorkflowInput ) -> _ChildWorkflowHandle: @@ -2878,6 +2935,11 @@ async def signal_external_workflow( ) -> None: await self._instance._outbound_signal_external_workflow(input) + async def signal_with_start_external_workflow( + self, input: SignalWithStartExternalWorkflowInput + ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: + return await self._instance._outbound_signal_with_start_external_workflow(input) + def start_activity( self, input: StartActivityInput ) -> temporalio.workflow.ActivityHandle[Any]: @@ -3287,6 +3349,66 @@ async def signal( ) ) + async def signal_with_start_workflow( + self, + signal: str | Callable, + workflow: str | Callable[..., Awaitable[Any]], + signal_arg: Any = temporalio.common._arg_unset, + workflow_arg: Any = temporalio.common._arg_unset, + *, + signal_args: Sequence[Any] = [], + workflow_args: Sequence[Any] = [], + task_queue: str, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: + self._instance._assert_not_read_only("signal with start external handle") + workflow_name, _ = temporalio.workflow._Definition.get_name_and_result_type( + workflow + ) + return await self._instance._outbound.signal_with_start_external_workflow( + SignalWithStartExternalWorkflowInput( + signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str( + signal + ), + signal_args=temporalio.common._arg_or_args(signal_arg, signal_args), + namespace=self._instance._info.namespace, + workflow_id=self._id, + workflow=workflow_name, + workflow_args=temporalio.common._arg_or_args( + workflow_arg, workflow_args + ), + task_queue=task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + priority=priority, + versioning_override=versioning_override, + headers=None, + ) + ) + async def cancel(self) -> None: self._instance._assert_not_read_only("cancel external handle") command = self._instance._add_command() diff --git a/temporalio/workflow.py b/temporalio/workflow.py index dd8565f78..7168f229e 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -4282,6 +4282,11 @@ class ChildWorkflowConfig(TypedDict, total=False): priority: temporalio.common.Priority +_SYSTEM_NEXUS_ENDPOINT = "temporal-system" +# TODO: Switch this back to "__temporal_system" once the server supports reserved +# endpoint names for system operations. + + # Overload for no-param workflow @overload async def start_child_workflow( @@ -4738,6 +4743,72 @@ async def cancel(self) -> None: """ raise NotImplementedError + async def signal_with_start_workflow( + self, + signal: str | Callable, # type: ignore[reportUnusedParameter] + workflow: str | Callable[..., Awaitable[Any]], # type: ignore[reportUnusedParameter] + signal_arg: Any = temporalio.common._arg_unset, # type: ignore[reportUnusedParameter] + workflow_arg: Any = temporalio.common._arg_unset, # type: ignore[reportUnusedParameter] + *, + signal_args: Sequence[Any] = [], # type: ignore[reportUnusedParameter] + workflow_args: Sequence[Any] = [], # type: ignore[reportUnusedParameter] + task_queue: str, # type: ignore[reportUnusedParameter] + execution_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] + run_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] + task_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, # type: ignore[reportUnusedParameter] + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, # type: ignore[reportUnusedParameter] + retry_policy: temporalio.common.RetryPolicy | None = None, # type: ignore[reportUnusedParameter] + cron_schedule: str = "", # type: ignore[reportUnusedParameter] + memo: Mapping[str, Any] | None = None, # type: ignore[reportUnusedParameter] + search_attributes: temporalio.common.TypedSearchAttributes | None = None, # type: ignore[reportUnusedParameter] + static_summary: str | None = None, # type: ignore[reportUnusedParameter] + static_details: str | None = None, # type: ignore[reportUnusedParameter] + start_delay: timedelta | None = None, # type: ignore[reportUnusedParameter] + priority: temporalio.common.Priority = temporalio.common.Priority.default, # type: ignore[reportUnusedParameter] + versioning_override: temporalio.common.VersioningOverride | None = None, # type: ignore[reportUnusedParameter] + ) -> ExternalWorkflowHandle[SelfType]: + """Signal the workflow, or start it and signal it if it is not running. + + This uses the system Nexus ``SignalWithStartWorkflowExecution`` operation + under the hood. If this handle has a ``run_id``, it is ignored because + signal-with-start operates on workflow ID only. + + Args: + signal: Name or method reference for the signal. + workflow: String name or class method decorated with ``@workflow.run`` + for the workflow to start. + signal_arg: Single argument to the signal. + workflow_arg: Single argument to the workflow. + signal_args: Multiple arguments to the signal. Cannot be set if + signal_arg is. + workflow_args: Multiple arguments to the workflow. Cannot be set if + workflow_arg is. + task_queue: Task queue to run the workflow on if it is started. + execution_timeout: Total workflow execution timeout including + retries and continue as new. + run_timeout: Timeout of a single workflow run. + task_timeout: Timeout of a single workflow task. + id_reuse_policy: How already-existing IDs are treated. + id_conflict_policy: How already-running IDs are treated. + retry_policy: Retry policy for the workflow. + cron_schedule: See https://docs.temporal.io/docs/content/what-is-a-temporal-cron-job/ + memo: Memo for the workflow. + search_attributes: Typed search attributes for the workflow. + static_summary: A single-line fixed summary for this workflow + execution that may appear in the UI/CLI. + static_details: General fixed details for this workflow execution + that may appear in UI/CLI. + start_delay: Time to wait before dispatching the first workflow task. + priority: Priority to use for this workflow. + versioning_override: Versioning override to apply if the workflow is + started. + + Returns: + A handle for the resulting workflow run. + """ + raise NotImplementedError + def get_external_workflow_handle( workflow_id: str, diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py index d02849bca..69d3a7b3b 100644 --- a/tests/nexus/test_temporal_system_nexus.py +++ b/tests/nexus/test_temporal_system_nexus.py @@ -21,11 +21,20 @@ ) from temporalio.nexus.system import generated from temporalio.testing import WorkflowEnvironment -from temporalio.worker import Worker +from temporalio.worker import ( + Interceptor, + SignalWithStartExternalWorkflowInput, + Worker, + WorkflowInboundInterceptor, + WorkflowInterceptorClassInput, + WorkflowOutboundInterceptor, +) from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner from tests.helpers.nexus import make_nexus_endpoint_name from tests.test_extstore import InMemoryTestDriver +interceptor_traces: list[tuple[str, object]] = [] + @nexusrpc.handler.service_handler(service=generated.WorkflowService) class WorkflowServicePayloadHandler: @@ -42,16 +51,19 @@ async def signal_with_start_workflow_execution( payloads = request_dict[field_name]["payloads"] assert payloads[0]["externalPayloads"] for field_name in ("memo", "header"): - fields = request_dict[field_name]["fields"] - assert next(iter(fields.values()))["externalPayloads"] + fields = (request_dict.get(field_name) or {}).get("fields") + if fields: + assert next(iter(fields.values()))["externalPayloads"] for field_name in ("summary", "details"): - payload = request_dict["userMetadata"][field_name] - assert payload["externalPayloads"] - search_attribute_payload = request_dict["searchAttributes"]["indexedFields"][ - "custom-key" - ] - assert "externalPayloads" not in search_attribute_payload - assert "test-codec" not in search_attribute_payload["metadata"] + payload = (request_dict.get("userMetadata") or {}).get(field_name) + if payload: + assert payload["externalPayloads"] + if search_attributes := (request_dict.get("searchAttributes") or {}).get( + "indexedFields" + ): + search_attribute_payload = search_attributes["custom-key"] + assert "externalPayloads" not in search_attribute_payload + assert "test-codec" not in search_attribute_payload["metadata"] return generated.WorkflowServiceSignalWithStartWorkflowExecutionOutput( runId=f"{request.workflowId}-run" ) @@ -142,6 +154,24 @@ async def run(self, task_queue: str) -> str: return cast(str, result.runId) +@workflow.defn +class ExternalHandleSignalWithStartWorkflowCaller: + @workflow.run + async def run(self, task_queue: str) -> str: + handle = workflow.get_external_workflow_handle("system-nexus-workflow-id") + started_handle = await handle.signal_with_start_workflow( + "test-signal", + "test-workflow", + "signal-input", + "workflow-input", + task_queue=task_queue, + memo={"memo-key": "memo-value"}, + static_summary="summary-value", + static_details="details-value", + ) + return cast(str, started_handle.run_id) + + class RejectOuterSystemNexusCodec(PayloadCodec): def __init__(self) -> None: self.encode_count = 0 @@ -213,6 +243,28 @@ def to_payloads( return payloads +class TracingWorkflowInterceptor(Interceptor): + def workflow_interceptor_class( + self, input: WorkflowInterceptorClassInput + ) -> type[WorkflowInboundInterceptor] | None: + return _TracingWorkflowInboundInterceptor + + +class _TracingWorkflowInboundInterceptor(WorkflowInboundInterceptor): + def init(self, outbound: WorkflowOutboundInterceptor) -> None: + super().init(_TracingWorkflowOutboundInterceptor(outbound)) + + +class _TracingWorkflowOutboundInterceptor(WorkflowOutboundInterceptor): + async def signal_with_start_external_workflow( + self, input: SignalWithStartExternalWorkflowInput + ) -> workflow.ExternalWorkflowHandle[object]: + interceptor_traces.append( + ("workflow.signal_with_start_external_workflow", input) + ) + return await super().signal_with_start_external_workflow(input) + + async def test_workflow_service_signal_with_start_nested_payloads_use_codec_without_encoding_outer_envelope( env: WorkflowEnvironment, ): @@ -277,3 +329,77 @@ async def test_workflow_service_signal_with_start_nested_payloads_use_codec_with b'"summary-value"', b'"details-value"', }.issubset(stored_payload_data) + + +async def test_external_workflow_handle_signal_with_start_workflow_uses_system_nexus( + env: WorkflowEnvironment, + monkeypatch: pytest.MonkeyPatch, +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with the Java test server") + + codec = RejectOuterSystemNexusCodec() + interceptor_traces.clear() + driver = InMemoryTestDriver() + caller_config = env.client.config() + caller_config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_codec=codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1, + ), + ) + caller_client = Client(**caller_config) + handler_config = env.client.config() + handler_config["data_converter"] = temporalio.converter.default() + handler_client = Client(**handler_config) + caller_task_queue = str(uuid.uuid4()) + handler_task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(handler_task_queue) + monkeypatch.setattr(workflow, "_SYSTEM_NEXUS_ENDPOINT", endpoint_name) + + caller_worker = Worker( + caller_client, + task_queue=caller_task_queue, + workflows=[ExternalHandleSignalWithStartWorkflowCaller], + workflow_runner=UnsandboxedWorkflowRunner(), + interceptors=[TracingWorkflowInterceptor()], + ) + handler_worker = Worker( + handler_client, + task_queue=handler_task_queue, + nexus_service_handlers=[WorkflowServicePayloadHandler()], + ) + + async with caller_worker, handler_worker: + await env.create_nexus_endpoint(endpoint_name, handler_task_queue) + result = await caller_client.execute_workflow( + ExternalHandleSignalWithStartWorkflowCaller.run, + args=[handler_task_queue], + id=str(uuid.uuid4()), + task_queue=caller_task_queue, + ) + + assert result == "system-nexus-workflow-id-run" + assert codec.encode_count >= 5 + stored_payloads: list[temporalio.api.common.v1.Payload] = [] + for stored_payload_bytes in driver._storage.values(): + stored_payload = temporalio.api.common.v1.Payload() + stored_payload.ParseFromString(stored_payload_bytes) + stored_payloads.append(stored_payload) + assert stored_payload.metadata["test-codec"] == b"true" + stored_payload_data = {payload.data for payload in stored_payloads} + assert { + b'"workflow-input"', + b'"signal-input"', + b'"memo-value"', + b'"summary-value"', + b'"details-value"', + }.issubset(stored_payload_data) + trace = interceptor_traces.pop() + assert trace[0] == "workflow.signal_with_start_external_workflow" + trace_input = cast(SignalWithStartExternalWorkflowInput, trace[1]) + assert trace_input.workflow_id == "system-nexus-workflow-id" + assert trace_input.signal == "test-signal" + assert trace_input.workflow == "test-workflow" diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index b913cbffb..71027c7d0 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -8,7 +8,7 @@ import temporalio.bridge.worker import temporalio.converter -from temporalio import nexus +import temporalio.nexus.system as nexus_system from temporalio.api.common.v1.message_pb2 import ( Payload, Payloads, @@ -219,11 +219,11 @@ async def test_visit_payloads_on_other_commands(): async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): envelope = ( - nexus.system.generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( + nexus_system.generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( namespace="default", workflowId="workflow-id", signalName="signal-name", - input=nexus.system.generated.Input( + input=nexus_system.generated.Input( payloads=[ MessageToDict( Payload( @@ -232,7 +232,7 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): ) ] ), - signalInput=nexus.system.generated.Input( + signalInput=nexus_system.generated.Input( payloads=[ MessageToDict( Payload( @@ -241,7 +241,7 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): ) ] ), - memo=nexus.system.generated.Memo( + memo=nexus_system.generated.Memo( fields={ "memo-key": MessageToDict( Payload( @@ -250,7 +250,7 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): ) } ), - searchAttributes=nexus.system.generated.SearchAttributes( + searchAttributes=nexus_system.generated.SearchAttributes( indexedFields={ "search-key": MessageToDict( Payload( @@ -300,12 +300,12 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): async def test_bridge_encoding_checks_system_nexus_envelope_size(): envelope = ( - nexus.system.generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( + nexus_system.generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( namespace="default", workflowId="workflow-id", signalName="signal-name", requestId="x" * 2048, - input=nexus.system.generated.Input( + input=nexus_system.generated.Input( payloads=[ MessageToDict( Payload( From 46ca02c1f338c3357ca6c4da4f19feaeb79fa8b2 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 15 Apr 2026 15:15:33 -0700 Subject: [PATCH 11/18] Align system nexus signal-with-start workflow API --- temporalio/nexus/system/__init__.py | 2 -- temporalio/worker/_workflow_instance.py | 1 - temporalio/workflow.py | 5 +---- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index b97fc4847..609b59f1c 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -155,7 +155,6 @@ def build_signal_with_start_workflow_execution_input( signal: str, signal_args: Sequence[Any], task_queue: str, - request_id: str, payload_converter: temporalio.converter.PayloadConverter, execution_timeout: timedelta | None = None, run_timeout: timedelta | None = None, @@ -186,7 +185,6 @@ def build_signal_with_start_workflow_execution_input( workflowTaskTimeout=( f"{task_timeout.total_seconds()}s" if task_timeout else None ), - requestId=request_id, workflowIdReusePolicy=_workflow_id_reuse_policy_to_generated(id_reuse_policy), workflowIdConflictPolicy=_workflow_id_conflict_policy_to_generated( id_conflict_policy diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 5aa493d01..f85c6598e 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1972,7 +1972,6 @@ async def _outbound_signal_with_start_external_workflow( signal=input.signal, signal_args=input.signal_args, task_queue=input.task_queue, - request_id=str(uuid.uuid4()), payload_converter=payload_converter, execution_timeout=input.execution_timeout, run_timeout=input.run_timeout, diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 7168f229e..db8c77988 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -4282,10 +4282,7 @@ class ChildWorkflowConfig(TypedDict, total=False): priority: temporalio.common.Priority -_SYSTEM_NEXUS_ENDPOINT = "temporal-system" -# TODO: Switch this back to "__temporal_system" once the server supports reserved -# endpoint names for system operations. - +_SYSTEM_NEXUS_ENDPOINT = "__temporal_system" # Overload for no-param workflow @overload From df0ee0e94b885cda879774b0a1947839a3307fdc Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 17 Apr 2026 12:02:44 -0700 Subject: [PATCH 12/18] Update Python system nexus generated model usage --- scripts/gen_nexus_system_models.py | 9 +- temporalio/nexus/system/__init__.py | 73 ++- .../system/_workflow_service_generated.py | 472 +++--------------- temporalio/worker/_interceptor.py | 3 +- temporalio/worker/_workflow_instance.py | 18 +- temporalio/workflow.py | 63 ++- tests/nexus/test_temporal_system_nexus.py | 253 +++------- tests/worker/test_visitor.py | 224 ++++----- 8 files changed, 361 insertions(+), 754 deletions(-) diff --git a/scripts/gen_nexus_system_models.py b/scripts/gen_nexus_system_models.py index 2009a43a7..a0e56249b 100644 --- a/scripts/gen_nexus_system_models.py +++ b/scripts/gen_nexus_system_models.py @@ -19,9 +19,14 @@ def main() -> None: input_schema = ( repo_root / "temporalio" + / "bridge" + / "sdk-core" + / "crates" + / "common" + / "protos" + / "api_upstream" / "nexus" - / "system" - / "_workflow_service.nexusrpc.yaml" + / "temporal-json-schema-models-nexusrpc.yaml" ) output_file = ( repo_root / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index 609b59f1c..362502d3b 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -11,7 +11,6 @@ from google.protobuf.json_format import MessageToDict import temporalio.api.common.v1 -import temporalio.api.enums.v1 import temporalio.common import temporalio.converter @@ -32,27 +31,60 @@ _SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.default().payload_converter +_WORKFLOW_ID_REUSE_POLICY_TO_GENERATED = { + temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE: generated.WorkflowIDReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE, + temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE_FAILED_ONLY: generated.WorkflowIDReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY, + temporalio.common.WorkflowIDReusePolicy.REJECT_DUPLICATE: generated.WorkflowIDReusePolicy.WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE, + temporalio.common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING: generated.WorkflowIDReusePolicy.WORKFLOW_ID_REUSE_POLICY_TERMINATE_IF_RUNNING, +} + +_WORKFLOW_ID_CONFLICT_POLICY_TO_GENERATED = { + temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED: generated.WorkflowIDConflictPolicy.WORKFLOW_ID_CONFLICT_POLICY_UNSPECIFIED, + temporalio.common.WorkflowIDConflictPolicy.FAIL: generated.WorkflowIDConflictPolicy.WORKFLOW_ID_CONFLICT_POLICY_FAIL, + temporalio.common.WorkflowIDConflictPolicy.USE_EXISTING: generated.WorkflowIDConflictPolicy.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING, + temporalio.common.WorkflowIDConflictPolicy.TERMINATE_EXISTING: generated.WorkflowIDConflictPolicy.WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING, +} + def _payload_to_json_value( converter: temporalio.converter.PayloadConverter, value: Any -) -> dict[str, Any]: - return MessageToDict(converter.to_payload(value)) +) -> generated.Payload: + return _proto_payload_to_generated(converter.to_payload(value)) + + +def _proto_payload_to_generated( + payload: temporalio.api.common.v1.Payload, +) -> generated.Payload: + value = MessageToDict(payload) + return generated.Payload( + data=cast("str | None", value.get("data")), + externalPayloads=[ + generated.PayloadExternalPayloadDetails(**details) + for details in cast( + "list[dict[str, str]]", value.get("externalPayloads", []) + ) + ] + or None, + metadata=cast("dict[str, str] | None", value.get("metadata")), + ) def _payloads_to_input( converter: temporalio.converter.PayloadConverter, values: Sequence[Any] -) -> generated.Input | None: +) -> generated.Payloads | None: payloads = converter.to_payloads(values) if values else [] if not payloads: return None - return generated.Input(payloads=[MessageToDict(payload) for payload in payloads]) + return generated.Payloads( + payloads=[_proto_payload_to_generated(payload) for payload in payloads] + ) def _search_attributes_to_json_map( attributes: temporalio.common.TypedSearchAttributes, -) -> dict[str, Any]: +) -> dict[str, generated.Payload]: return { - pair.key.name: MessageToDict( + pair.key.name: _proto_payload_to_generated( temporalio.converter.encode_typed_search_attribute_value( pair.key, pair.value ) @@ -97,24 +129,13 @@ def _priority_to_generated( def _workflow_id_reuse_policy_to_generated( policy: temporalio.common.WorkflowIDReusePolicy, ) -> generated.WorkflowIDReusePolicy: - return generated.WorkflowIDReusePolicy( - temporalio.api.enums.v1.WorkflowIdReusePolicy.Name( - cast("temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", int(policy)) - ) - ) + return _WORKFLOW_ID_REUSE_POLICY_TO_GENERATED[policy] def _workflow_id_conflict_policy_to_generated( policy: temporalio.common.WorkflowIDConflictPolicy, ) -> generated.WorkflowIDConflictPolicy: - return generated.WorkflowIDConflictPolicy( - temporalio.api.enums.v1.WorkflowIdConflictPolicy.Name( - cast( - "temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType", - int(policy), - ) - ) - ) + return _WORKFLOW_ID_CONFLICT_POLICY_TO_GENERATED[policy] def _versioning_override_to_generated( @@ -129,9 +150,9 @@ def _versioning_override_to_generated( return generated.VersioningOverride( behavior=generated.VersioningOverrideBehavior.VERSIONING_BEHAVIOR_PINNED, pinnedVersion=versioning_override.version.to_canonical_string(), - pinned=generated.Pinned( - behavior=generated.PinnedBehavior.PINNED_OVERRIDE_BEHAVIOR_PINNED, - version=generated.Version( + pinned=generated.VersioningOverridePinnedOverride( + behavior=generated.VersioningOverridePinnedOverrideBehavior.PINNED_OVERRIDE_BEHAVIOR_PINNED, + version=generated.WorkerDeploymentVersion( deploymentName=versioning_override.version.deployment_name, buildId=versioning_override.version.build_id, ), @@ -155,6 +176,7 @@ def build_signal_with_start_workflow_execution_input( signal: str, signal_args: Sequence[Any], task_queue: str, + request_id: str | None, payload_converter: temporalio.converter.PayloadConverter, execution_timeout: timedelta | None = None, run_timeout: timedelta | None = None, @@ -170,9 +192,9 @@ def build_signal_with_start_workflow_execution_input( start_delay: timedelta | None = None, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: temporalio.common.VersioningOverride | None = None, -) -> generated.WorkflowServiceSignalWithStartWorkflowExecutionInput: +) -> generated.SignalWithStartWorkflowExecutionRequest: """Build the generated system Nexus input for signal-with-start.""" - return generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( + return generated.SignalWithStartWorkflowExecutionRequest( namespace=namespace, workflowId=workflow_id, workflowType=generated.WorkflowType(name=workflow), @@ -185,6 +207,7 @@ def build_signal_with_start_workflow_execution_input( workflowTaskTimeout=( f"{task_timeout.total_seconds()}s" if task_timeout else None ), + requestId=request_id, workflowIdReusePolicy=_workflow_id_reuse_policy_to_generated(id_reuse_policy), workflowIdConflictPolicy=_workflow_id_conflict_policy_to_generated( id_conflict_policy diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index 585923e9a..5be522798 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -6,8 +6,9 @@ import collections.abc import json from dataclasses import dataclass +from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from google.protobuf.json_format import MessageToDict, ParseDict from nexusrpc import Operation, service @@ -16,36 +17,36 @@ @dataclass -class Header: - """Contains metadata that can be attached to a variety of requests, like starting a - workflow, and - can be propagated between, for example, workflows and activities. - """ - - fields: Optional[Dict[str, Any]] = None +class PayloadExternalPayloadDetails: + sizeBytes: Optional[str] = None @dataclass -class Input: - """Serialized arguments to the workflow. These are passed as arguments to the workflow - function. +class Payload: + data: Optional[str] = None + externalPayloads: Optional[List[PayloadExternalPayloadDetails]] = None + metadata: Optional[Dict[str, str]] = None - See `Payload` - Serialized value(s) to provide with the signal - """ +@dataclass +class Header: + fields: Optional[Dict[str, Payload]] = None - payloads: Optional[List[Any]] = None + +@dataclass +class Payloads: + payloads: Optional[List[Payload]] = None @dataclass -class BatchJob: - """A link to a built-in batch job. - Batch jobs can be used to perform operations on a set of workflows (e.g. terminate, - signal, cancel, etc). - This link can be put on workflow history events generated by actions taken by a batch job. - """ +class LinkActivity: + activityId: Optional[str] = None + namespace: Optional[str] = None + runId: Optional[str] = None + +@dataclass +class LinkBatchJob: jobId: Optional[str] = None @@ -146,6 +147,9 @@ class EventType(str, Enum): "EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED" ) EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT = "EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT" + EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED" + ) EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED" EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED = ( "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED" @@ -171,240 +175,89 @@ class EventType(str, Enum): @dataclass -class EventRef: - """EventReference is a direct reference to a history event through the event ID.""" - +class WorkflowEventEventReference: eventId: Optional[str] = None eventType: Optional["EventType"] = None @dataclass -class RequestIDRef: - """RequestIdReference is a indirect reference to a history event through the request ID.""" - +class WorkflowEventRequestIDReference: eventType: Optional["EventType"] = None requestId: Optional[str] = None @dataclass -class WorkflowEvent: - eventRef: Optional[EventRef] = None +class LinkWorkflowEvent: + eventRef: Optional[WorkflowEventEventReference] = None namespace: Optional[str] = None - requestIdRef: Optional[RequestIDRef] = None + requestIdRef: Optional[WorkflowEventRequestIDReference] = None runId: Optional[str] = None workflowId: Optional[str] = None @dataclass -class Openapiv3: - """Link can be associated with history events. It might contain information about an - external entity - related to the history event. For example, workflow A makes a Nexus call that starts - workflow B: - in this case, a history event in workflow A could contain a Link to the workflow started - event in - workflow B, and vice-versa. - """ - - batchJob: Optional[BatchJob] = None - workflowEvent: Optional[WorkflowEvent] = None +class Link: + activity: Optional[LinkActivity] = None + batchJob: Optional[LinkBatchJob] = None + workflowEvent: Optional[LinkWorkflowEvent] = None @dataclass class Memo: - """A user-defined set of *unindexed* fields that are exposed when listing/searching workflows""" - - fields: Optional[Dict[str, Any]] = None + fields: Optional[Dict[str, Payload]] = None @dataclass class Priority: - """Priority metadata - - Priority contains metadata that controls relative ordering of task processing - when tasks are backed up in a queue. Initially, Priority will be used in - matching (workflow and activity) task queues. Later it may be used in history - task queues and in rate limiting decisions. - - Priority is attached to workflows and activities. By default, activities - inherit Priority from the workflow that created them, but may override fields - when an activity is started or modified. - - Despite being named "Priority", this message also contains fields that - control "fairness" mechanisms. - - For all fields, the field not present or equal to zero/empty string means to - inherit the value from the calling workflow, or if there is no calling - workflow, then use the default value. - - For all fields other than fairness_key, the zero value isn't meaningful so - there's no confusion between inherit/default and a meaningful value. For - fairness_key, the empty string will be interpreted as "inherit". This means - that if a workflow has a non-empty fairness key, you can't override the - fairness key of its activity to the empty string. - - The overall semantics of Priority are: - 1. First, consider "priority": higher priority (lower number) goes first. - 2. Then, consider fairness: try to dispatch tasks for different fairness keys - in proportion to their weight. - - Applications may use any subset of mechanisms that are useful to them and - leave the other fields to use default values. - - Not all queues in the system may support the "full" semantics of all priority - fields. (Currently only support in matching task queues is planned.) - """ - fairnessKey: Optional[str] = None - """Fairness key is a short string that's used as a key for a fairness - balancing mechanism. It may correspond to a tenant id, or to a fixed - string like "high" or "low". The default is the empty string. - - The fairness mechanism attempts to dispatch tasks for a given key in - proportion to its weight. For example, using a thousand distinct tenant - ids, each with a weight of 1.0 (the default) will result in each tenant - getting a roughly equal share of task dispatch throughput. - - (Note: this does not imply equal share of worker capacity! Fairness - decisions are made based on queue statistics, not - current worker load.) - - As another example, using keys "high" and "low" with weight 9.0 and 1.0 - respectively will prefer dispatching "high" tasks over "low" tasks at a - 9:1 ratio, while allowing either key to use all worker capacity if the - other is not present. - - All fairness mechanisms, including rate limits, are best-effort and - probabilistic. The results may not match what a "perfect" algorithm with - infinite resources would produce. The more unique keys are used, the less - accurate the results will be. - - Fairness keys are limited to 64 bytes. - """ fairnessWeight: Optional[float] = None - """Fairness weight for a task can come from multiple sources for - flexibility. From highest to lowest precedence: - 1. Weights for a small set of keys can be overridden in task queue - configuration with an API. - 2. It can be attached to the workflow/activity in this field. - 3. The default weight of 1.0 will be used. - - Weight values are clamped to the range [0.001, 1000]. - """ priorityKey: Optional[int] = None - """Priority key is a positive integer from 1 to n, where smaller integers - correspond to higher priorities (tasks run sooner). In general, tasks in - a queue should be processed in close to priority order, although small - deviations are possible. - - The maximum priority value (minimum priority) is determined by server - configuration, and defaults to 5. - - If priority is not present (or zero), then the effective priority will be - the default priority, which is calculated by (min+max)/2. With the - default max of 5, and min of 1, that comes out to 3. - """ @dataclass class RetryPolicy: - """Retry policy for the workflow - - How retries ought to be handled, usable by both workflows and activities - """ - backoffCoefficient: Optional[float] = None - """Coefficient used to calculate the next retry interval. - The next retry interval is previous interval multiplied by the coefficient. - Must be 1 or larger. - """ initialInterval: Optional[str] = None - """Interval of the first retry. If retryBackoffCoefficient is 1.0 then it is used for all - retries. - """ maximumAttempts: Optional[int] = None - """Maximum number of attempts. When exceeded the retries stop even if not expired yet. - 1 disables retries. 0 means unlimited (up to the timeouts) - """ maximumInterval: Optional[str] = None - """Maximum interval between retries. Exponential backoff leads to interval increase. - This value is the cap of the increase. Default is 100x of the initial interval. - """ nonRetryableErrorTypes: Optional[List[str]] = None - """Non-Retryable errors types. Will stop retrying if the error type matches this list. Note - that - this is not a substring match, the error *type* (not message) must match exactly. - """ @dataclass class SearchAttributes: - """A user-defined set of *indexed* fields that are used/exposed when listing/searching - workflows. - The payload is not serialized in a user-defined way. - """ - - indexedFields: Optional[Dict[str, Any]] = None + indexedFields: Optional[Dict[str, Payload]] = None class Kind(str, Enum): - """Default: TASK_QUEUE_KIND_NORMAL.""" - TASK_QUEUE_KIND_NORMAL = "TASK_QUEUE_KIND_NORMAL" TASK_QUEUE_KIND_STICKY = "TASK_QUEUE_KIND_STICKY" TASK_QUEUE_KIND_UNSPECIFIED = "TASK_QUEUE_KIND_UNSPECIFIED" + TASK_QUEUE_KIND_WORKER_COMMANDS = "TASK_QUEUE_KIND_WORKER_COMMANDS" @dataclass class TaskQueue: - """The task queue to start this workflow on, if it will be started - - See https://docs.temporal.io/docs/concepts/task-queues/ - """ - kind: Optional["Kind"] = None - """Default: TASK_QUEUE_KIND_NORMAL.""" - name: Optional[str] = None normalName: Optional[str] = None - """Iff kind == TASK_QUEUE_KIND_STICKY, then this field contains the name of - the normal task queue that the sticky worker is running on. - """ + + +@dataclass +class TimeSkippingConfig: + disablePropagation: Optional[bool] = None + enabled: Optional[bool] = None + maxElapsedDuration: Optional[str] = None + maxSkippedDuration: Optional[str] = None + maxTargetTime: Optional[datetime] = None @dataclass class UserMetadata: - """Metadata on the workflow if it is started. This is carried over to the - WorkflowExecutionInfo - for use by user interfaces to display the fixed as-of-start summary and details of the - workflow. - - Information a user can set, often for use by user interfaces. - """ - - details: Any - """Long-form text that provides details. This payload should be a "json/plain"-encoded - payload - that is a single JSON string for use in user interfaces. User interface formatting may - apply to - this text in common use. The payload data section is limited to 20000 bytes by default. - """ - summary: Any - """Short-form text that provides a summary. This payload should be a "json/plain"-encoded - payload - that is a single JSON string for use in user interfaces. User interface formatting may - not - apply to this text when used in "title" situations. The payload data section is limited - to 400 - bytes by default. - """ + details: Optional[Payload] = None + summary: Optional[Payload] = None class VersioningOverrideBehavior(str, Enum): - """Required. - Deprecated. Use `override`. - """ - VERSIONING_BEHAVIOR_AUTO_UPGRADE = "VERSIONING_BEHAVIOR_AUTO_UPGRADE" VERSIONING_BEHAVIOR_PINNED = "VERSIONING_BEHAVIOR_PINNED" VERSIONING_BEHAVIOR_UNSPECIFIED = "VERSIONING_BEHAVIOR_UNSPECIFIED" @@ -412,145 +265,37 @@ class VersioningOverrideBehavior(str, Enum): @dataclass class Deployment: - """Required if behavior is `PINNED`. Must be null if behavior is `AUTO_UPGRADE`. - Identifies the worker deployment to pin the workflow to. - Deprecated. Use `override.pinned.version`. - - `Deployment` identifies a deployment of Temporal workers. The combination of deployment - series - name + build ID serves as the identifier. User can use `WorkerDeploymentOptions` in their - worker - programs to specify these values. - Deprecated. - """ - buildId: Optional[str] = None - """Build ID changes with each version of the worker when the worker program code and/or - config - changes. - """ seriesName: Optional[str] = None - """Different versions of the same worker service/application are related together by having - a - shared series name. - Out of all deployments of a series, one can be designated as the current deployment, - which - receives new workflow executions and new tasks of workflows with - `VERSIONING_BEHAVIOR_AUTO_UPGRADE` versioning behavior. - """ - -class PinnedBehavior(str, Enum): - """Defaults to PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED. - See `PinnedOverrideBehavior` for details. - """ +class VersioningOverridePinnedOverrideBehavior(str, Enum): PINNED_OVERRIDE_BEHAVIOR_PINNED = "PINNED_OVERRIDE_BEHAVIOR_PINNED" PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED = "PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED" @dataclass -class Version: - """Specifies the Worker Deployment Version to pin this workflow to. - Required if the target workflow is not already pinned to a version. - - If omitted and the target workflow is already pinned, the effective - pinned version will be the existing pinned version. - - If omitted and the target workflow is not pinned, the override request - will be rejected with a PreconditionFailed error. - - A Worker Deployment Version (Version, for short) represents a - version of workers within a Worker Deployment. (see documentation of - WorkerDeploymentVersionInfo) - Version records are created in Temporal server automatically when their - first poller arrives to the server. - Experimental. Worker Deployment Versions are experimental and might significantly change - in the future. - """ - +class WorkerDeploymentVersion: buildId: Optional[str] = None - """A unique identifier for this Version within the Deployment it is a part of. - Not necessarily unique within the namespace. - The combination of `deployment_name` and `build_id` uniquely identifies this - Version within the namespace, because Deployment names are unique within a namespace. - """ deploymentName: Optional[str] = None - """Identifies the Worker Deployment this Version is part of.""" @dataclass -class Pinned: - """Override the workflow to have Pinned behavior.""" - - behavior: Optional["PinnedBehavior"] = None - """Defaults to PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED. - See `PinnedOverrideBehavior` for details. - """ - version: Optional[Version] = None - """Specifies the Worker Deployment Version to pin this workflow to. - Required if the target workflow is not already pinned to a version. - - If omitted and the target workflow is already pinned, the effective - pinned version will be the existing pinned version. - - If omitted and the target workflow is not pinned, the override request - will be rejected with a PreconditionFailed error. - """ +class VersioningOverridePinnedOverride: + behavior: Optional["VersioningOverridePinnedOverrideBehavior"] = None + version: Optional[WorkerDeploymentVersion] = None @dataclass class VersioningOverride: - """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task - completion. - To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. - - Used to override the versioning behavior (and pinned deployment version, if applicable) - of a - specific workflow execution. If set, this override takes precedence over worker-sent - values. - See `WorkflowExecutionInfo.VersioningInfo` for more information. - - To remove the override, call `UpdateWorkflowExecutionOptions` with a null - `VersioningOverride`, and use the `update_mask` to indicate that it should be mutated. - - Pinned behavior overrides are automatically inherited by child workflows, workflow - retries, continue-as-new - workflows, and cron workflows. - """ - autoUpgrade: Optional[bool] = None - """Override the workflow to have AutoUpgrade behavior.""" - behavior: Optional["VersioningOverrideBehavior"] = None - """Required. - Deprecated. Use `override`. - """ deployment: Optional[Deployment] = None - """Required if behavior is `PINNED`. Must be null if behavior is `AUTO_UPGRADE`. - Identifies the worker deployment to pin the workflow to. - Deprecated. Use `override.pinned.version`. - """ - pinned: Optional[Pinned] = None - """Override the workflow to have Pinned behavior.""" - + pinned: Optional[VersioningOverridePinnedOverride] = None pinnedVersion: Optional[str] = None - """Required if behavior is `PINNED`. Must be absent if behavior is not `PINNED`. - Identifies the worker deployment version to pin the workflow to, in the format - ".". - Deprecated. Use `override.pinned.version`. - """ class WorkflowIDConflictPolicy(str, Enum): - """Defines how to resolve a workflow id conflict with a *running* workflow. - The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. - Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. - - See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* - workflow. - """ - WORKFLOW_ID_CONFLICT_POLICY_FAIL = "WORKFLOW_ID_CONFLICT_POLICY_FAIL" WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING = ( "WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING" @@ -562,13 +307,6 @@ class WorkflowIDConflictPolicy(str, Enum): class WorkflowIDReusePolicy(str, Enum): - """Defines whether to allow re-using the workflow id from a previously *closed* workflow. - The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. - - See `workflow_id_reuse_policy` for handling a workflow id duplication with a *running* - workflow. - """ - WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE = ( "WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE" ) @@ -586,118 +324,51 @@ class WorkflowIDReusePolicy(str, Enum): @dataclass class WorkflowType: - """Represents the identifier used by a workflow author to define the workflow. Typically, - the - name of a function. This is sometimes referred to as the workflow's "name" - """ - name: Optional[str] = None @dataclass -class WorkflowServiceSignalWithStartWorkflowExecutionInput: +class SignalWithStartWorkflowExecutionRequest: control: Optional[str] = None - """Deprecated.""" - cronSchedule: Optional[str] = None - """See https://docs.temporal.io/docs/content/what-is-a-temporal-cron-job/""" - header: Optional[Header] = None identity: Optional[str] = None - """The identity of the worker/client""" - - input: Optional[Input] = None - """Serialized arguments to the workflow. These are passed as arguments to the workflow - function. - """ - links: Optional[List[Openapiv3]] = None - """Links to be associated with the WorkflowExecutionStarted and WorkflowExecutionSignaled - events. - """ + input: Optional[Payloads] = None + links: Optional[List[Link]] = None memo: Optional[Memo] = None namespace: Optional[str] = None priority: Optional[Priority] = None - """Priority metadata""" - requestId: Optional[str] = None - """Used to de-dupe signal w/ start requests""" - retryPolicy: Optional[RetryPolicy] = None - """Retry policy for the workflow""" - searchAttributes: Optional[SearchAttributes] = None - signalInput: Optional[Input] = None - """Serialized value(s) to provide with the signal""" - + signalInput: Optional[Payloads] = None signalName: Optional[str] = None - """The workflow author-defined name of the signal to send to the workflow""" - taskQueue: Optional[TaskQueue] = None - """The task queue to start this workflow on, if it will be started""" - + timeSkippingConfig: Optional[TimeSkippingConfig] = None userMetadata: Optional[UserMetadata] = None - """Metadata on the workflow if it is started. This is carried over to the - WorkflowExecutionInfo - for use by user interfaces to display the fixed as-of-start summary and details of the - workflow. - """ versioningOverride: Optional[VersioningOverride] = None - """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task - completion. - To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. - """ workflowExecutionTimeout: Optional[str] = None - """Total workflow execution timeout including retries and continue as new""" - workflowId: Optional[str] = None workflowIdConflictPolicy: Optional["WorkflowIDConflictPolicy"] = None - """Defines how to resolve a workflow id conflict with a *running* workflow. - The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. - Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. - - See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* - workflow. - """ workflowIdReusePolicy: Optional["WorkflowIDReusePolicy"] = None - """Defines whether to allow re-using the workflow id from a previously *closed* workflow. - The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. - - See `workflow_id_reuse_policy` for handling a workflow id duplication with a *running* - workflow. - """ workflowRunTimeout: Optional[str] = None - """Timeout of a single workflow run""" - workflowStartDelay: Optional[str] = None - """Time to wait before dispatching the first workflow task. Cannot be used with - `cron_schedule`. - Note that the signal will be delivered with the first workflow task. If the workflow - gets - another SignalWithStartWorkflow before the delay a workflow task will be dispatched - immediately - and the rest of the delay period will be ignored, even if that request also had a delay. - Signal via SignalWorkflowExecution will not unblock the workflow. - """ workflowTaskTimeout: Optional[str] = None - """Timeout of a single workflow task""" - workflowType: Optional[WorkflowType] = None @dataclass -class WorkflowServiceSignalWithStartWorkflowExecutionOutput: +class SignalWithStartWorkflowExecutionResponse: runId: Optional[str] = None - """The run id of the workflow that was started - or just signaled, if it was already running.""" - + signalLink: Optional[Link] = None started: Optional[bool] = None - """If true, a new workflow was started.""" @service class WorkflowService: signal_with_start_workflow_execution: Operation[ - WorkflowServiceSignalWithStartWorkflowExecutionInput, - WorkflowServiceSignalWithStartWorkflowExecutionOutput, + SignalWithStartWorkflowExecutionRequest, + SignalWithStartWorkflowExecutionResponse, ] = Operation(name="SignalWithStartWorkflowExecution") @@ -743,7 +414,10 @@ async def _temporal_nexus_visit_header_json(self, value: dict) -> dict: ) return visited - async def _temporal_nexus_visit_input_json(self, value: dict) -> dict: + async def _temporal_nexus_visit_payload_json(self, value: dict) -> dict: + return await self._visit_payload_json(value) + + async def _temporal_nexus_visit_payloads_json(self, value: dict) -> dict: return await self._visit_payloads_json(value) async def _temporal_nexus_visit_memo_json(self, value: dict) -> dict: @@ -772,7 +446,7 @@ async def _temporal_nexus_visit_user_metadata_json(self, value: dict) -> dict: visited["summary"] = await self._visit_payload_json(visited["summary"]) return visited - async def _temporal_nexus_visit_workflow_service_signal_with_start_workflow_execution_input_json( + async def _temporal_nexus_visit_signal_with_start_workflow_execution_request_json( self, value: dict ) -> dict: visited = dict(value) @@ -781,7 +455,7 @@ async def _temporal_nexus_visit_workflow_service_signal_with_start_workflow_exec visited["header"] ) if visited.get("input") is not None: - visited["input"] = await self._temporal_nexus_visit_input_json( + visited["input"] = await self._temporal_nexus_visit_payloads_json( visited["input"] ) if visited.get("memo") is not None: @@ -795,7 +469,7 @@ async def _temporal_nexus_visit_workflow_service_signal_with_start_workflow_exec visited["searchAttributes"] ) if visited.get("signalInput") is not None: - visited["signalInput"] = await self._temporal_nexus_visit_input_json( + visited["signalInput"] = await self._temporal_nexus_visit_payloads_json( visited["signalInput"] ) if visited.get("userMetadata") is not None: @@ -807,7 +481,7 @@ async def _temporal_nexus_visit_workflow_service_signal_with_start_workflow_exec return visited -async def _temporal_nexus_visit_workflow_service_signal_with_start_workflow_execution_input( +async def _temporal_nexus_visit_signal_with_start_workflow_execution_request( payload: temporalio.api.common.v1.Payload, payload_visitor: collections.abc.Callable[ [collections.abc.Sequence[temporalio.api.common.v1.Payload]], @@ -822,7 +496,7 @@ async def _temporal_nexus_visit_workflow_service_signal_with_start_workflow_exec if not isinstance(value, dict): return payload visitor = _TemporalNexusPayloadVisitor(payload_visitor, visit_search_attributes) - visited = await visitor._temporal_nexus_visit_workflow_service_signal_with_start_workflow_execution_input_json( + visited = await visitor._temporal_nexus_visit_signal_with_start_workflow_execution_request_json( value ) return temporalio.api.common.v1.Payload( @@ -835,5 +509,5 @@ async def _temporal_nexus_visit_workflow_service_signal_with_start_workflow_exec ( "WorkflowService", "SignalWithStartWorkflowExecution", - ): _temporal_nexus_visit_workflow_service_signal_with_start_workflow_execution_input, + ): _temporal_nexus_visit_signal_with_start_workflow_execution_request, } diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 371c5e054..25579e69f 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -266,6 +266,7 @@ class SignalWithStartExternalWorkflowInput: static_summary: str | None static_details: str | None start_delay: timedelta | None + request_id: str | None priority: temporalio.common.Priority versioning_override: temporalio.common.VersioningOverride | None headers: Mapping[str, str] | None @@ -484,7 +485,7 @@ async def signal_with_start_external_workflow( self, input: SignalWithStartExternalWorkflowInput ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: """Called for every - :py:meth:`temporalio.workflow.ExternalWorkflowHandle.signal_with_start_workflow` + :py:meth:`temporalio.workflow.ExternalWorkflowHandle.signal_with_start` call. """ return await self.next.signal_with_start_external_workflow(input) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index f85c6598e..10ae8bb1f 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -12,7 +12,6 @@ import sys import threading import traceback -import uuid import warnings from abc import ABC, abstractmethod from collections import deque @@ -1972,6 +1971,7 @@ async def _outbound_signal_with_start_external_workflow( signal=input.signal, signal_args=input.signal_args, task_queue=input.task_queue, + request_id=input.request_id, payload_converter=payload_converter, execution_timeout=input.execution_timeout, run_timeout=input.run_timeout, @@ -3348,15 +3348,13 @@ async def signal( ) ) - async def signal_with_start_workflow( + async def signal_with_start( self, signal: str | Callable, workflow: str | Callable[..., Awaitable[Any]], - signal_arg: Any = temporalio.common._arg_unset, - workflow_arg: Any = temporalio.common._arg_unset, *, - signal_args: Sequence[Any] = [], - workflow_args: Sequence[Any] = [], + signal_args: Sequence[Any] = (), + workflow_args: Sequence[Any] = (), task_queue: str, execution_timeout: timedelta | None = None, run_timeout: timedelta | None = None, @@ -3370,6 +3368,7 @@ async def signal_with_start_workflow( static_summary: str | None = None, static_details: str | None = None, start_delay: timedelta | None = None, + request_id: str | None = None, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: temporalio.common.VersioningOverride | None = None, ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: @@ -3382,13 +3381,11 @@ async def signal_with_start_workflow( signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str( signal ), - signal_args=temporalio.common._arg_or_args(signal_arg, signal_args), + signal_args=signal_args, namespace=self._instance._info.namespace, workflow_id=self._id, workflow=workflow_name, - workflow_args=temporalio.common._arg_or_args( - workflow_arg, workflow_args - ), + workflow_args=workflow_args, task_queue=task_queue, execution_timeout=execution_timeout, run_timeout=run_timeout, @@ -3402,6 +3399,7 @@ async def signal_with_start_workflow( static_summary=static_summary, static_details=static_details, start_delay=start_delay, + request_id=request_id, priority=priority, versioning_override=versioning_override, headers=None, diff --git a/temporalio/workflow.py b/temporalio/workflow.py index db8c77988..4b678107d 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -4284,6 +4284,7 @@ class ChildWorkflowConfig(TypedDict, total=False): _SYSTEM_NEXUS_ENDPOINT = "__temporal_system" + # Overload for no-param workflow @overload async def start_child_workflow( @@ -4740,12 +4741,10 @@ async def cancel(self) -> None: """ raise NotImplementedError - async def signal_with_start_workflow( + async def signal_with_start( self, signal: str | Callable, # type: ignore[reportUnusedParameter] workflow: str | Callable[..., Awaitable[Any]], # type: ignore[reportUnusedParameter] - signal_arg: Any = temporalio.common._arg_unset, # type: ignore[reportUnusedParameter] - workflow_arg: Any = temporalio.common._arg_unset, # type: ignore[reportUnusedParameter] *, signal_args: Sequence[Any] = [], # type: ignore[reportUnusedParameter] workflow_args: Sequence[Any] = [], # type: ignore[reportUnusedParameter] @@ -4762,6 +4761,7 @@ async def signal_with_start_workflow( static_summary: str | None = None, # type: ignore[reportUnusedParameter] static_details: str | None = None, # type: ignore[reportUnusedParameter] start_delay: timedelta | None = None, # type: ignore[reportUnusedParameter] + request_id: str | None = None, # type: ignore[reportUnusedParameter] priority: temporalio.common.Priority = temporalio.common.Priority.default, # type: ignore[reportUnusedParameter] versioning_override: temporalio.common.VersioningOverride | None = None, # type: ignore[reportUnusedParameter] ) -> ExternalWorkflowHandle[SelfType]: @@ -4775,12 +4775,8 @@ async def signal_with_start_workflow( signal: Name or method reference for the signal. workflow: String name or class method decorated with ``@workflow.run`` for the workflow to start. - signal_arg: Single argument to the signal. - workflow_arg: Single argument to the workflow. - signal_args: Multiple arguments to the signal. Cannot be set if - signal_arg is. - workflow_args: Multiple arguments to the workflow. Cannot be set if - workflow_arg is. + signal_args: Arguments to the signal. + workflow_args: Arguments to the workflow. task_queue: Task queue to run the workflow on if it is started. execution_timeout: Total workflow execution timeout including retries and continue as new. @@ -4797,6 +4793,7 @@ async def signal_with_start_workflow( static_details: General fixed details for this workflow execution that may appear in UI/CLI. start_delay: Time to wait before dispatching the first workflow task. + request_id: Optional idempotency request ID for the start request. priority: Priority to use for this workflow. versioning_override: Versioning override to apply if the workflow is started. @@ -4806,6 +4803,54 @@ async def signal_with_start_workflow( """ raise NotImplementedError + async def signal_with_start_workflow( + self, + signal: str | Callable, # type: ignore[reportUnusedParameter] + workflow: str | Callable[..., Awaitable[Any]], # type: ignore[reportUnusedParameter] + *, + signal_args: Sequence[Any] = [], # type: ignore[reportUnusedParameter] + workflow_args: Sequence[Any] = [], # type: ignore[reportUnusedParameter] + task_queue: str, # type: ignore[reportUnusedParameter] + execution_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] + run_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] + task_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, # type: ignore[reportUnusedParameter] + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, # type: ignore[reportUnusedParameter] + retry_policy: temporalio.common.RetryPolicy | None = None, # type: ignore[reportUnusedParameter] + cron_schedule: str = "", # type: ignore[reportUnusedParameter] + memo: Mapping[str, Any] | None = None, # type: ignore[reportUnusedParameter] + search_attributes: temporalio.common.TypedSearchAttributes | None = None, # type: ignore[reportUnusedParameter] + static_summary: str | None = None, # type: ignore[reportUnusedParameter] + static_details: str | None = None, # type: ignore[reportUnusedParameter] + start_delay: timedelta | None = None, # type: ignore[reportUnusedParameter] + request_id: str | None = None, # type: ignore[reportUnusedParameter] + priority: temporalio.common.Priority = temporalio.common.Priority.default, # type: ignore[reportUnusedParameter] + versioning_override: temporalio.common.VersioningOverride | None = None, # type: ignore[reportUnusedParameter] + ) -> ExternalWorkflowHandle[SelfType]: + """Deprecated alias for :py:meth:`signal_with_start`.""" + return await self.signal_with_start( + signal, + workflow, + signal_args=signal_args, + workflow_args=workflow_args, + task_queue=task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + request_id=request_id, + priority=priority, + versioning_override=versioning_override, + ) + def get_external_workflow_handle( workflow_id: str, diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py index 69d3a7b3b..49a0129aa 100644 --- a/tests/nexus/test_temporal_system_nexus.py +++ b/tests/nexus/test_temporal_system_nexus.py @@ -4,14 +4,14 @@ import json import uuid from collections.abc import Sequence -from typing import cast +from typing import Any, cast import nexusrpc.handler import pytest -from google.protobuf.json_format import MessageToDict import temporalio.api.common.v1 import temporalio.converter +import temporalio.nexus.system as nexus_system from temporalio import workflow from temporalio.client import Client from temporalio.converter import ( @@ -34,6 +34,7 @@ from tests.test_extstore import InMemoryTestDriver interceptor_traces: list[tuple[str, object]] = [] +received_requests: list[dict[str, Any]] = [] @nexusrpc.handler.service_handler(service=generated.WorkflowService) @@ -42,128 +43,26 @@ class WorkflowServicePayloadHandler: async def signal_with_start_workflow_execution( self, _ctx: nexusrpc.handler.StartOperationContext, - request: generated.WorkflowServiceSignalWithStartWorkflowExecutionInput, - ) -> generated.WorkflowServiceSignalWithStartWorkflowExecutionOutput: + request: generated.SignalWithStartWorkflowExecutionRequest, + ) -> generated.SignalWithStartWorkflowExecutionResponse: assert request.workflowId == "system-nexus-workflow-id" assert request.signalName == "test-signal" - request_dict = dataclasses.asdict(request) - for field_name in ("input", "signalInput"): - payloads = request_dict[field_name]["payloads"] - assert payloads[0]["externalPayloads"] - for field_name in ("memo", "header"): - fields = (request_dict.get(field_name) or {}).get("fields") - if fields: - assert next(iter(fields.values()))["externalPayloads"] - for field_name in ("summary", "details"): - payload = (request_dict.get("userMetadata") or {}).get(field_name) - if payload: - assert payload["externalPayloads"] - if search_attributes := (request_dict.get("searchAttributes") or {}).get( - "indexedFields" - ): - search_attribute_payload = search_attributes["custom-key"] - assert "externalPayloads" not in search_attribute_payload - assert "test-codec" not in search_attribute_payload["metadata"] - return generated.WorkflowServiceSignalWithStartWorkflowExecutionOutput( + received_requests.append(dataclasses.asdict(request)) + return generated.SignalWithStartWorkflowExecutionResponse( runId=f"{request.workflowId}-run" ) -@workflow.defn -class SystemNexusCallerWithPayloadsWorkflow: - @workflow.run - async def run(self, task_queue: str) -> str: - nexus_client = workflow.create_nexus_client( - service=generated.WorkflowService, - endpoint=make_nexus_endpoint_name(task_queue), - ) - request = generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( - namespace="default", - workflowId="system-nexus-workflow-id", - signalName="test-signal", - input=generated.Input( - payloads=[ - MessageToDict( - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"workflow-input"', - ) - ) - ] - ), - signalInput=generated.Input( - payloads=[ - MessageToDict( - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"signal-input"', - ) - ) - ] - ), - memo=generated.Memo( - fields={ - "memo-key": MessageToDict( - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"memo-value"', - ) - ) - } - ), - header=generated.Header( - fields={ - "header-key": MessageToDict( - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"header-value"', - ) - ) - } - ), - userMetadata=generated.UserMetadata( - summary=MessageToDict( - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"summary-value"', - ) - ), - details=MessageToDict( - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"details-value"', - ) - ), - ), - searchAttributes=generated.SearchAttributes( - indexedFields={ - "custom-key": MessageToDict( - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"search-attribute-value"', - ) - ) - } - ), - ) - handle = await nexus_client.start_operation( - generated.WorkflowService.signal_with_start_workflow_execution, - request, - ) - result = await handle - return cast(str, result.runId) - - @workflow.defn class ExternalHandleSignalWithStartWorkflowCaller: @workflow.run async def run(self, task_queue: str) -> str: handle = workflow.get_external_workflow_handle("system-nexus-workflow-id") - started_handle = await handle.signal_with_start_workflow( + started_handle = await handle.signal_with_start( "test-signal", "test-workflow", - "signal-input", - "workflow-input", + signal_args=["signal-input"], + workflow_args=["workflow-input"], task_queue=task_queue, memo={"memo-key": "memo-value"}, static_summary="summary-value", @@ -229,9 +128,7 @@ def to_payloads( ) -> list[temporalio.api.common.v1.Payload]: payloads: list[temporalio.api.common.v1.Payload] = [] for value in values: - if isinstance( - value, generated.WorkflowServiceSignalWithStartWorkflowExecutionInput - ): + if isinstance(value, generated.SignalWithStartWorkflowExecutionRequest): payloads.append( temporalio.api.common.v1.Payload( metadata={"encoding": b"json/plain"}, @@ -265,70 +162,52 @@ async def signal_with_start_external_workflow( return await super().signal_with_start_external_workflow(input) -async def test_workflow_service_signal_with_start_nested_payloads_use_codec_without_encoding_outer_envelope( - env: WorkflowEnvironment, -): - if env.supports_time_skipping: - pytest.skip("Nexus tests don't work with the Java test server") +def _pop_received_request() -> dict[str, Any]: + assert len(received_requests) == 1 + return received_requests.pop() - codec = RejectOuterSystemNexusCodec() - driver = InMemoryTestDriver() - caller_config = env.client.config() - caller_config["data_converter"] = dataclasses.replace( - temporalio.converter.default(), - payload_converter_class=BadSystemNexusEnvelopePayloadConverter, - payload_codec=codec, - external_storage=ExternalStorage( - drivers=[driver], - payload_size_threshold=1, - ), - ) - caller_client = Client(**caller_config) - handler_config = env.client.config() - handler_config["data_converter"] = temporalio.converter.default() - handler_client = Client(**handler_config) - caller_task_queue = str(uuid.uuid4()) - handler_task_queue = str(uuid.uuid4()) - caller_worker = Worker( - caller_client, - task_queue=caller_task_queue, - workflows=[SystemNexusCallerWithPayloadsWorkflow], - workflow_runner=UnsandboxedWorkflowRunner(), - ) - handler_worker = Worker( - handler_client, - task_queue=handler_task_queue, - nexus_service_handlers=[WorkflowServicePayloadHandler()], +def _assert_request_payload_was_externally_stored( + request_dict: dict[str, Any], field_name: str +) -> None: + payloads = cast("dict[str, list[dict[str, object]]]", request_dict[field_name])[ + "payloads" + ] + assert len(payloads) == 1 + assert payloads[0]["externalPayloads"] + + +def _assert_request_user_metadata_was_externally_stored( + request_dict: dict[str, Any], +) -> None: + user_metadata = cast( + "dict[str, dict[str, object]] | None", request_dict.get("userMetadata") ) + assert user_metadata is not None + assert user_metadata["summary"]["externalPayloads"] + assert user_metadata["details"]["externalPayloads"] - async with caller_worker, handler_worker: - endpoint_name = make_nexus_endpoint_name(handler_task_queue) - await env.create_nexus_endpoint(endpoint_name, handler_task_queue) - result = await caller_client.execute_workflow( - SystemNexusCallerWithPayloadsWorkflow.run, - handler_task_queue, - id=str(uuid.uuid4()), - task_queue=caller_task_queue, - ) - assert result == "system-nexus-workflow-id-run" - assert codec.encode_count >= 6 - stored_payloads: list[temporalio.api.common.v1.Payload] = [] +def _assert_stored_payloads_include( + driver: InMemoryTestDriver, expected_payload_data: set[bytes] +) -> None: + stored_payload_data: set[bytes] = set() for stored_payload_bytes in driver._storage.values(): stored_payload = temporalio.api.common.v1.Payload() stored_payload.ParseFromString(stored_payload_bytes) - stored_payloads.append(stored_payload) assert stored_payload.metadata["test-codec"] == b"true" - stored_payload_data = {payload.data for payload in stored_payloads} - assert { - b'"workflow-input"', - b'"signal-input"', - b'"memo-value"', - b'"header-value"', - b'"summary-value"', - b'"details-value"', - }.issubset(stored_payload_data) + stored_payload_data.add(stored_payload.data) + assert expected_payload_data.issubset(stored_payload_data) + + +def _assert_signal_with_start_interceptor_trace() -> None: + assert len(interceptor_traces) == 1 + trace_name, trace_value = interceptor_traces.pop() + assert trace_name == "workflow.signal_with_start_external_workflow" + trace_input = cast(SignalWithStartExternalWorkflowInput, trace_value) + assert trace_input.workflow_id == "system-nexus-workflow-id" + assert trace_input.signal == "test-signal" + assert trace_input.workflow == "test-workflow" async def test_external_workflow_handle_signal_with_start_workflow_uses_system_nexus( @@ -340,6 +219,7 @@ async def test_external_workflow_handle_signal_with_start_workflow_uses_system_n codec = RejectOuterSystemNexusCodec() interceptor_traces.clear() + received_requests.clear() driver = InMemoryTestDriver() caller_config = env.client.config() caller_config["data_converter"] = dataclasses.replace( @@ -382,24 +262,19 @@ async def test_external_workflow_handle_signal_with_start_workflow_uses_system_n ) assert result == "system-nexus-workflow-id-run" + request_dict = _pop_received_request() + _assert_request_payload_was_externally_stored(request_dict, "input") + _assert_request_payload_was_externally_stored(request_dict, "signalInput") + _assert_request_user_metadata_was_externally_stored(request_dict) assert codec.encode_count >= 5 - stored_payloads: list[temporalio.api.common.v1.Payload] = [] - for stored_payload_bytes in driver._storage.values(): - stored_payload = temporalio.api.common.v1.Payload() - stored_payload.ParseFromString(stored_payload_bytes) - stored_payloads.append(stored_payload) - assert stored_payload.metadata["test-codec"] == b"true" - stored_payload_data = {payload.data for payload in stored_payloads} - assert { - b'"workflow-input"', - b'"signal-input"', - b'"memo-value"', - b'"summary-value"', - b'"details-value"', - }.issubset(stored_payload_data) - trace = interceptor_traces.pop() - assert trace[0] == "workflow.signal_with_start_external_workflow" - trace_input = cast(SignalWithStartExternalWorkflowInput, trace[1]) - assert trace_input.workflow_id == "system-nexus-workflow-id" - assert trace_input.signal == "test-signal" - assert trace_input.workflow == "test-workflow" + _assert_stored_payloads_include( + driver, + { + b'"workflow-input"', + b'"signal-input"', + b'"memo-value"', + b'"summary-value"', + b'"details-value"', + }, + ) + _assert_signal_with_start_interceptor_trace() diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 71027c7d0..d57077951 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -1,10 +1,10 @@ +import base64 import dataclasses import json from collections.abc import MutableSequence import pytest from google.protobuf.duration_pb2 import Duration -from google.protobuf.json_format import MessageToDict import temporalio.bridge.worker import temporalio.converter @@ -217,49 +217,84 @@ async def test_visit_payloads_on_other_commands(): assert ur.completed.metadata["visited"] +async def test_bridge_encoding(): + comp = WorkflowActivationCompletion( + run_id="1", + successful=Success( + commands=[ + WorkflowCommand( + schedule_activity=ScheduleActivity( + seq=1, + activity_id="1", + activity_type="", + task_queue="", + headers={"foo": Payload(data=b"bar")}, + arguments=[ + Payload(data=b"repeated1"), + Payload(data=b"repeated2"), + ], + schedule_to_close_timeout=Duration(seconds=5), + priority=Priority(), + ), + user_metadata=UserMetadata(summary=Payload(data=b"Summary")), + ) + ], + ), + ) + + data_converter = dataclasses.replace( + temporalio.converter.default(), + payload_codec=SimpleCodec(), + ) + + await temporalio.bridge.worker.encode_completion(comp, data_converter, True) + + cmd = comp.successful.commands[0] + sa = cmd.schedule_activity + assert sa.headers["foo"].metadata["simple-codec"] + assert len(sa.arguments) == 1 + assert sa.arguments[0].metadata["simple-codec"] + + assert cmd.user_metadata.summary.metadata["simple-codec"] + + async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): - envelope = ( - nexus_system.generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( - namespace="default", - workflowId="workflow-id", - signalName="signal-name", - input=nexus_system.generated.Input( - payloads=[ - MessageToDict( - Payload( - metadata={"encoding": b"json/plain"}, data=b'"input-value"' - ) - ) - ] - ), - signalInput=nexus_system.generated.Input( - payloads=[ - MessageToDict( - Payload( - metadata={"encoding": b"json/plain"}, data=b'"signal-value"' - ) - ) - ] - ), - memo=nexus_system.generated.Memo( - fields={ - "memo-key": MessageToDict( - Payload( - metadata={"encoding": b"json/plain"}, data=b'"memo-value"' - ) - ) - } - ), - searchAttributes=nexus_system.generated.SearchAttributes( - indexedFields={ - "search-key": MessageToDict( - Payload( - metadata={"encoding": b"json/plain"}, data=b'"search-value"' - ) - ) - } - ), - ) + envelope = nexus_system.generated.SignalWithStartWorkflowExecutionRequest( + namespace="default", + workflowId="workflow-id", + signalName="signal-name", + input=nexus_system.generated.Payloads( + payloads=[ + nexus_system.generated.Payload( + data="ImlucHV0LXZhbHVlIg==", + metadata={"encoding": "anNvbi9wbGFpbg=="}, + ) + ] + ), + signalInput=nexus_system.generated.Payloads( + payloads=[ + nexus_system.generated.Payload( + data="InNpZ25hbC12YWx1ZSI=", + metadata={"encoding": "anNvbi9wbGFpbg=="}, + ) + ] + ), + memo=nexus_system.generated.Memo( + fields={ + "memo-key": nexus_system.generated.Payload( + data="Im1lbW8tdmFsdWUi", + metadata={"encoding": "anNvbi9wbGFpbg=="}, + ) + } + ), + searchAttributes=nexus_system.generated.SearchAttributes( + indexedFields={ + "search-key": nexus_system.generated.Payload( + data="InNlYXJjaC12YWx1ZSI=", + metadata={"encoding": "anNvbi9wbGFpbg=="}, + ) + } + ), ) comp = WorkflowActivationCompletion( run_id="1", @@ -270,14 +305,7 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): seq=1, service="WorkflowService", operation="SignalWithStartWorkflowExecution", - input=Payload( - metadata={"encoding": b"json/plain"}, - data=json.dumps( - dataclasses.asdict(envelope), - separators=(",", ":"), - sort_keys=True, - ).encode(), - ), + input=nexus_system.get_payload_converter().to_payload(envelope), ) ) ], @@ -289,9 +317,18 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): input_payload = comp.successful.commands[0].schedule_nexus_operation.input assert input_payload.metadata["visited"] rewritten = json.loads(input_payload.data) - assert rewritten["input"]["payloads"][0]["metadata"]["visited"] == "VHJ1ZQ==" - assert rewritten["signalInput"]["payloads"][0]["metadata"]["visited"] == "VHJ1ZQ==" - assert rewritten["memo"]["fields"]["memo-key"]["metadata"]["visited"] == "VHJ1ZQ==" + assert ( + base64.b64decode(rewritten["input"]["payloads"][0]["metadata"]["visited"]) + == b"True" + ) + assert ( + base64.b64decode(rewritten["signalInput"]["payloads"][0]["metadata"]["visited"]) + == b"True" + ) + assert ( + base64.b64decode(rewritten["memo"]["fields"]["memo-key"]["metadata"]["visited"]) + == b"True" + ) assert ( "visited" not in rewritten["searchAttributes"]["indexedFields"]["search-key"]["metadata"] @@ -299,22 +336,19 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): async def test_bridge_encoding_checks_system_nexus_envelope_size(): - envelope = ( - nexus_system.generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( - namespace="default", - workflowId="workflow-id", - signalName="signal-name", - requestId="x" * 2048, - input=nexus_system.generated.Input( - payloads=[ - MessageToDict( - Payload( - metadata={"encoding": b"json/plain"}, data=b'"input-value"' - ) - ) - ] - ), - ) + envelope = nexus_system.generated.SignalWithStartWorkflowExecutionRequest( + namespace="default", + workflowId="workflow-id", + signalName="signal-name", + requestId="x" * 2048, + input=nexus_system.generated.Payloads( + payloads=[ + nexus_system.generated.Payload( + data="ImlucHV0LXZhbHVlIg==", + metadata={"encoding": "anNvbi9wbGFpbg=="}, + ) + ] + ), ) comp = WorkflowActivationCompletion( run_id="1", @@ -325,14 +359,7 @@ async def test_bridge_encoding_checks_system_nexus_envelope_size(): seq=1, service="WorkflowService", operation="SignalWithStartWorkflowExecution", - input=Payload( - metadata={"encoding": b"json/plain"}, - data=json.dumps( - dataclasses.asdict(envelope), - separators=(",", ":"), - sort_keys=True, - ).encode(), - ), + input=nexus_system.get_payload_converter().to_payload(envelope), ) ) ], @@ -348,44 +375,3 @@ async def test_bridge_encoding_checks_system_nexus_envelope_size(): with pytest.raises(_PayloadSizeError, match="payloads with size that exceeded"): await temporalio.bridge.worker.encode_completion(comp, data_converter, True) - - -async def test_bridge_encoding(): - comp = WorkflowActivationCompletion( - run_id="1", - successful=Success( - commands=[ - WorkflowCommand( - schedule_activity=ScheduleActivity( - seq=1, - activity_id="1", - activity_type="", - task_queue="", - headers={"foo": Payload(data=b"bar")}, - arguments=[ - Payload(data=b"repeated1"), - Payload(data=b"repeated2"), - ], - schedule_to_close_timeout=Duration(seconds=5), - priority=Priority(), - ), - user_metadata=UserMetadata(summary=Payload(data=b"Summary")), - ) - ], - ), - ) - - data_converter = dataclasses.replace( - temporalio.converter.default(), - payload_codec=SimpleCodec(), - ) - - await temporalio.bridge.worker.encode_completion(comp, data_converter, True) - - cmd = comp.successful.commands[0] - sa = cmd.schedule_activity - assert sa.headers["foo"].metadata["simple-codec"] - assert len(sa.arguments) == 1 - assert sa.arguments[0].metadata["simple-codec"] - - assert cmd.user_metadata.summary.metadata["simple-codec"] From 6b2f5619058fd3f79dc0f5ee78f72c2b8da226cf Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 21 Apr 2026 10:17:31 -0700 Subject: [PATCH 13/18] Tighten Python system nexus payload conversion --- temporalio/nexus/system/__init__.py | 151 ++++++++++++---- .../system/_workflow_service.nexusrpc.yaml | 11 -- .../system/_workflow_service_generated.py | 137 +++++++++------ temporalio/worker/_workflow_instance.py | 30 ++-- tests/nexus/test_temporal_system_nexus.py | 166 ++++++++++++++++-- tests/worker/test_visitor.py | 16 +- 6 files changed, 375 insertions(+), 136 deletions(-) delete mode 100644 temporalio/nexus/system/_workflow_service.nexusrpc.yaml diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index 362502d3b..63037d6dc 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -4,16 +4,21 @@ Higher-level ergonomic APIs may wrap these generated types. """ +import dataclasses from collections.abc import Awaitable, Callable, Mapping, Sequence from datetime import timedelta +from enum import Enum from typing import Any, cast -from google.protobuf.json_format import MessageToDict +import google.protobuf.message +from google.protobuf.json_format import MessageToDict, Parse, ParseDict import temporalio.api.common.v1 import temporalio.common import temporalio.converter +from ...converter import CompositePayloadConverter, JSONProtoPayloadConverter +from ...converter._payload_converter import value_to_type from . import _workflow_service_generated as generated from ._workflow_service_generated import __temporal_nexus_payload_visitors__ @@ -46,6 +51,72 @@ } +class _SystemNexusJSONProtoPayloadConverter(JSONProtoPayloadConverter): + def to_payload(self, value: Any) -> temporalio.api.common.v1.Payload | None: + proto_type = _get_generated_proto_type(value) + if proto_type is not None: + return super().to_payload( + ParseDict( + dataclasses.asdict(value), + proto_type(), + ignore_unknown_fields=True, + ) + ) + return super().to_payload(value) + + def from_payload( + self, + payload: temporalio.api.common.v1.Payload, + type_hint: type | None = None, + ) -> Any: + proto_type = _get_generated_proto_type(type_hint) + if proto_type is not None and type_hint is not None: + proto_value = Parse( + payload.data, + proto_type(), + ignore_unknown_fields=True, + ) + return value_to_type( + type_hint, + MessageToDict(proto_value, preserving_proto_field_name=True), + [_SystemNexusStrEnumConverter()], + ) + return super().from_payload(payload, type_hint) + + +class SystemNexusPayloadConverter(CompositePayloadConverter): + """Payload converter for system Nexus outer envelopes.""" + + def __init__(self) -> None: + """Create a payload converter for system Nexus outer envelopes.""" + super().__init__(_SystemNexusJSONProtoPayloadConverter()) + + +def _get_generated_proto_type( + value_or_type: Any, +) -> type[google.protobuf.message.Message] | None: + candidate = ( + value_or_type if isinstance(value_or_type, type) else type(value_or_type) + ) + proto_type = getattr(candidate, "__temporal_nexus_proto_type__", None) + if isinstance(proto_type, type) and issubclass( + proto_type, google.protobuf.message.Message + ): + return proto_type + return None + + +class _SystemNexusStrEnumConverter(temporalio.converter.JSONTypeConverter): + # Generated enums subclass str and Enum, not StrEnum, so the default + # value_to_type enum handling does not reconstruct them. + def to_typed_value(self, hint: type, value: Any) -> Any: + if isinstance(hint, type) and issubclass(hint, Enum) and issubclass(hint, str): + if not isinstance(value, str): + raise TypeError(f"Expected value to be str, was {type(value)}") + return hint(value) + return temporalio.converter.JSONTypeConverter.Unhandled + + def _payload_to_json_value( converter: temporalio.converter.PayloadConverter, value: Any ) -> generated.Payload: @@ -58,7 +129,7 @@ def _proto_payload_to_generated( value = MessageToDict(payload) return generated.Payload( data=cast("str | None", value.get("data")), - externalPayloads=[ + external_payloads=[ generated.PayloadExternalPayloadDetails(**details) for details in cast( "list[dict[str, str]]", value.get("externalPayloads", []) @@ -98,11 +169,11 @@ def _retry_policy_to_generated( ) -> generated.RetryPolicy: retry_policy._validate() return generated.RetryPolicy( - initialInterval=f"{retry_policy.initial_interval.total_seconds()}s", - backoffCoefficient=retry_policy.backoff_coefficient, - maximumInterval=f"{(retry_policy.maximum_interval or retry_policy.initial_interval * 100).total_seconds()}s", - maximumAttempts=retry_policy.maximum_attempts, - nonRetryableErrorTypes=( + initial_interval=f"{retry_policy.initial_interval.total_seconds()}s", + backoff_coefficient=retry_policy.backoff_coefficient, + maximum_interval=f"{(retry_policy.maximum_interval or retry_policy.initial_interval * 100).total_seconds()}s", + maximum_attempts=retry_policy.maximum_attempts, + non_retryable_error_types=( list(retry_policy.non_retryable_error_types) if retry_policy.non_retryable_error_types else None @@ -120,9 +191,9 @@ def _priority_to_generated( ): return None return generated.Priority( - priorityKey=priority.priority_key, - fairnessKey=priority.fairness_key, - fairnessWeight=priority.fairness_weight, + priority_key=priority.priority_key, + fairness_key=priority.fairness_key, + fairness_weight=priority.fairness_weight, ) @@ -143,23 +214,23 @@ def _versioning_override_to_generated( ) -> generated.VersioningOverride: if isinstance(versioning_override, temporalio.common.AutoUpgradeVersioningOverride): return generated.VersioningOverride( - autoUpgrade=True, + auto_upgrade=True, behavior=generated.VersioningOverrideBehavior.VERSIONING_BEHAVIOR_AUTO_UPGRADE, ) if isinstance(versioning_override, temporalio.common.PinnedVersioningOverride): return generated.VersioningOverride( behavior=generated.VersioningOverrideBehavior.VERSIONING_BEHAVIOR_PINNED, - pinnedVersion=versioning_override.version.to_canonical_string(), + pinned_version=versioning_override.version.to_canonical_string(), pinned=generated.VersioningOverridePinnedOverride( behavior=generated.VersioningOverridePinnedOverrideBehavior.PINNED_OVERRIDE_BEHAVIOR_PINNED, version=generated.WorkerDeploymentVersion( - deploymentName=versioning_override.version.deployment_name, - buildId=versioning_override.version.build_id, + deployment_name=versioning_override.version.deployment_name, + build_id=versioning_override.version.build_id, ), ), deployment=generated.Deployment( - seriesName=versioning_override.version.deployment_name, - buildId=versioning_override.version.build_id, + series_name=versioning_override.version.deployment_name, + build_id=versioning_override.version.build_id, ), ) raise TypeError( @@ -196,26 +267,31 @@ def build_signal_with_start_workflow_execution_input( """Build the generated system Nexus input for signal-with-start.""" return generated.SignalWithStartWorkflowExecutionRequest( namespace=namespace, - workflowId=workflow_id, - workflowType=generated.WorkflowType(name=workflow), - taskQueue=generated.TaskQueue(name=task_queue), + workflow_id=workflow_id, + workflow_type=generated.WorkflowType(name=workflow), + task_queue=generated.TaskQueue(name=task_queue), input=_payloads_to_input(payload_converter, workflow_args), - workflowExecutionTimeout=( + workflow_execution_timeout=( f"{execution_timeout.total_seconds()}s" if execution_timeout else None ), - workflowRunTimeout=f"{run_timeout.total_seconds()}s" if run_timeout else None, - workflowTaskTimeout=( + workflow_run_timeout=f"{run_timeout.total_seconds()}s" if run_timeout else None, + workflow_task_timeout=( f"{task_timeout.total_seconds()}s" if task_timeout else None ), - requestId=request_id, - workflowIdReusePolicy=_workflow_id_reuse_policy_to_generated(id_reuse_policy), - workflowIdConflictPolicy=_workflow_id_conflict_policy_to_generated( - id_conflict_policy + request_id=request_id, + workflow_id_reuse_policy=_workflow_id_reuse_policy_to_generated( + id_reuse_policy + ), + workflow_id_conflict_policy=( + _workflow_id_conflict_policy_to_generated(id_conflict_policy) + if id_conflict_policy + != temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED + else None ), - retryPolicy=( + retry_policy=( _retry_policy_to_generated(retry_policy) if retry_policy else None ), - cronSchedule=cron_schedule, + cron_schedule=cron_schedule or None, memo=( generated.Memo( fields={ @@ -226,16 +302,16 @@ def build_signal_with_start_workflow_execution_input( if memo else None ), - searchAttributes=( + search_attributes=( generated.SearchAttributes( - indexedFields=_search_attributes_to_json_map(search_attributes) + indexed_fields=_search_attributes_to_json_map(search_attributes) ) if search_attributes else None ), - signalName=signal, - signalInput=_payloads_to_input(payload_converter, signal_args), - userMetadata=( + signal_name=signal, + signal_input=_payloads_to_input(payload_converter, signal_args), + user_metadata=( generated.UserMetadata( summary=_payload_to_json_value(payload_converter, static_summary) if static_summary is not None @@ -247,9 +323,11 @@ def build_signal_with_start_workflow_execution_input( if static_summary is not None or static_details is not None else None ), - workflowStartDelay=(f"{start_delay.total_seconds()}s" if start_delay else None), + workflow_start_delay=( + f"{start_delay.total_seconds()}s" if start_delay else None + ), priority=_priority_to_generated(priority), - versioningOverride=( + versioning_override=( _versioning_override_to_generated(versioning_override) if versioning_override else None @@ -272,7 +350,7 @@ def is_system_operation(service: str, operation: str) -> bool: def get_payload_converter() -> temporalio.converter.PayloadConverter: """Return the fixed payload converter for system Nexus outer envelopes.""" - return _SYSTEM_NEXUS_PAYLOAD_CONVERTER + return SystemNexusPayloadConverter() __all__ = ( @@ -281,4 +359,5 @@ def get_payload_converter() -> temporalio.converter.PayloadConverter: "get_payload_converter", "get_payload_visitor", "is_system_operation", + "SystemNexusPayloadConverter", ) diff --git a/temporalio/nexus/system/_workflow_service.nexusrpc.yaml b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml deleted file mode 100644 index edea24b4e..000000000 --- a/temporalio/nexus/system/_workflow_service.nexusrpc.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# TODO: Remove this local shim once the upstream API repo checks in the Nexus -# definition and the generator can consume it directly. -nexusrpc: 1.0.0 -services: - WorkflowService: - operations: - SignalWithStartWorkflowExecution: - input: - $ref: ../../bridge/sdk-core/crates/common/protos/api_upstream/openapi/openapiv3.yaml#/components/schemas/SignalWithStartWorkflowExecutionRequest - output: - $ref: ../../bridge/sdk-core/crates/common/protos/api_upstream/openapi/openapiv3.yaml#/components/schemas/SignalWithStartWorkflowExecutionResponse diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index 5be522798..22fabfa80 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -8,23 +8,24 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Dict, List, Optional +from typing import ClassVar, Dict, List, Optional from google.protobuf.json_format import MessageToDict, ParseDict from nexusrpc import Operation, service import temporalio.api.common.v1 +import temporalio.api.workflowservice.v1 @dataclass class PayloadExternalPayloadDetails: - sizeBytes: Optional[str] = None + size_bytes: Optional[str] = None @dataclass class Payload: data: Optional[str] = None - externalPayloads: Optional[List[PayloadExternalPayloadDetails]] = None + external_payloads: Optional[List[PayloadExternalPayloadDetails]] = None metadata: Optional[Dict[str, str]] = None @@ -40,14 +41,14 @@ class Payloads: @dataclass class LinkActivity: - activityId: Optional[str] = None + activity_id: Optional[str] = None namespace: Optional[str] = None - runId: Optional[str] = None + run_id: Optional[str] = None @dataclass class LinkBatchJob: - jobId: Optional[str] = None + job_id: Optional[str] = None class EventType(str, Enum): @@ -176,30 +177,30 @@ class EventType(str, Enum): @dataclass class WorkflowEventEventReference: - eventId: Optional[str] = None - eventType: Optional["EventType"] = None + event_id: Optional[str] = None + event_type: Optional["EventType"] = None @dataclass class WorkflowEventRequestIDReference: - eventType: Optional["EventType"] = None - requestId: Optional[str] = None + event_type: Optional["EventType"] = None + request_id: Optional[str] = None @dataclass class LinkWorkflowEvent: - eventRef: Optional[WorkflowEventEventReference] = None + event_ref: Optional[WorkflowEventEventReference] = None namespace: Optional[str] = None - requestIdRef: Optional[WorkflowEventRequestIDReference] = None - runId: Optional[str] = None - workflowId: Optional[str] = None + request_id_ref: Optional[WorkflowEventRequestIDReference] = None + run_id: Optional[str] = None + workflow_id: Optional[str] = None @dataclass class Link: activity: Optional[LinkActivity] = None - batchJob: Optional[LinkBatchJob] = None - workflowEvent: Optional[LinkWorkflowEvent] = None + batch_job: Optional[LinkBatchJob] = None + workflow_event: Optional[LinkWorkflowEvent] = None @dataclass @@ -209,23 +210,23 @@ class Memo: @dataclass class Priority: - fairnessKey: Optional[str] = None - fairnessWeight: Optional[float] = None - priorityKey: Optional[int] = None + fairness_key: Optional[str] = None + fairness_weight: Optional[float] = None + priority_key: Optional[int] = None @dataclass class RetryPolicy: - backoffCoefficient: Optional[float] = None - initialInterval: Optional[str] = None - maximumAttempts: Optional[int] = None - maximumInterval: Optional[str] = None - nonRetryableErrorTypes: Optional[List[str]] = None + backoff_coefficient: Optional[float] = None + initial_interval: Optional[str] = None + maximum_attempts: Optional[int] = None + maximum_interval: Optional[str] = None + non_retryable_error_types: Optional[List[str]] = None @dataclass class SearchAttributes: - indexedFields: Optional[Dict[str, Payload]] = None + indexed_fields: Optional[Dict[str, Payload]] = None class Kind(str, Enum): @@ -239,16 +240,16 @@ class Kind(str, Enum): class TaskQueue: kind: Optional["Kind"] = None name: Optional[str] = None - normalName: Optional[str] = None + normal_name: Optional[str] = None @dataclass class TimeSkippingConfig: - disablePropagation: Optional[bool] = None + disable_propagation: Optional[bool] = None enabled: Optional[bool] = None - maxElapsedDuration: Optional[str] = None - maxSkippedDuration: Optional[str] = None - maxTargetTime: Optional[datetime] = None + max_elapsed_duration: Optional[str] = None + max_skipped_duration: Optional[str] = None + max_target_time: Optional[datetime] = None @dataclass @@ -265,8 +266,8 @@ class VersioningOverrideBehavior(str, Enum): @dataclass class Deployment: - buildId: Optional[str] = None - seriesName: Optional[str] = None + build_id: Optional[str] = None + series_name: Optional[str] = None class VersioningOverridePinnedOverrideBehavior(str, Enum): @@ -276,8 +277,8 @@ class VersioningOverridePinnedOverrideBehavior(str, Enum): @dataclass class WorkerDeploymentVersion: - buildId: Optional[str] = None - deploymentName: Optional[str] = None + build_id: Optional[str] = None + deployment_name: Optional[str] = None @dataclass @@ -288,11 +289,11 @@ class VersioningOverridePinnedOverride: @dataclass class VersioningOverride: - autoUpgrade: Optional[bool] = None + auto_upgrade: Optional[bool] = None behavior: Optional["VersioningOverrideBehavior"] = None deployment: Optional[Deployment] = None pinned: Optional[VersioningOverridePinnedOverride] = None - pinnedVersion: Optional[str] = None + pinned_version: Optional[str] = None class WorkflowIDConflictPolicy(str, Enum): @@ -330,7 +331,7 @@ class WorkflowType: @dataclass class SignalWithStartWorkflowExecutionRequest: control: Optional[str] = None - cronSchedule: Optional[str] = None + cron_schedule: Optional[str] = None header: Optional[Header] = None identity: Optional[str] = None input: Optional[Payloads] = None @@ -338,31 +339,55 @@ class SignalWithStartWorkflowExecutionRequest: memo: Optional[Memo] = None namespace: Optional[str] = None priority: Optional[Priority] = None - requestId: Optional[str] = None - retryPolicy: Optional[RetryPolicy] = None - searchAttributes: Optional[SearchAttributes] = None - signalInput: Optional[Payloads] = None - signalName: Optional[str] = None - taskQueue: Optional[TaskQueue] = None - timeSkippingConfig: Optional[TimeSkippingConfig] = None - userMetadata: Optional[UserMetadata] = None - versioningOverride: Optional[VersioningOverride] = None - workflowExecutionTimeout: Optional[str] = None - workflowId: Optional[str] = None - workflowIdConflictPolicy: Optional["WorkflowIDConflictPolicy"] = None - workflowIdReusePolicy: Optional["WorkflowIDReusePolicy"] = None - workflowRunTimeout: Optional[str] = None - workflowStartDelay: Optional[str] = None - workflowTaskTimeout: Optional[str] = None - workflowType: Optional[WorkflowType] = None + request_id: Optional[str] = None + retry_policy: Optional[RetryPolicy] = None + search_attributes: Optional[SearchAttributes] = None + signal_input: Optional[Payloads] = None + signal_name: Optional[str] = None + task_queue: Optional[TaskQueue] = None + time_skipping_config: Optional[TimeSkippingConfig] = None + user_metadata: Optional[UserMetadata] = None + versioning_override: Optional[VersioningOverride] = None + workflow_execution_timeout: Optional[str] = None + workflow_id: Optional[str] = None + workflow_id_conflict_policy: Optional["WorkflowIDConflictPolicy"] = None + workflow_id_reuse_policy: Optional["WorkflowIDReusePolicy"] = None + workflow_run_timeout: Optional[str] = None + workflow_start_delay: Optional[str] = None + workflow_task_timeout: Optional[str] = None + workflow_type: Optional[WorkflowType] = None + + __temporal_nexus_proto_type__: ClassVar[ + type[temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest] + ] = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest + + @property + def proto_type( + self, + ) -> type[ + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest + ]: + return self.__temporal_nexus_proto_type__ @dataclass class SignalWithStartWorkflowExecutionResponse: - runId: Optional[str] = None - signalLink: Optional[Link] = None + run_id: Optional[str] = None + signal_link: Optional[Link] = None started: Optional[bool] = None + __temporal_nexus_proto_type__: ClassVar[ + type[temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse] + ] = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse + + @property + def proto_type( + self, + ) -> type[ + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse + ]: + return self.__temporal_nexus_proto_type__ + @service class WorkflowService: diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 10ae8bb1f..8b72bc7cf 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -62,6 +62,7 @@ from temporalio.service import __version__ from ..api.failure.v1.message_pb2 import Failure +from ..converter import PayloadConverter from . import _command_aware_visitor from ._interceptor import ( ContinueAsNewInput, @@ -1993,7 +1994,7 @@ async def _outbound_signal_with_start_external_workflow( StartNexusOperationInput( endpoint=temporalio.workflow._SYSTEM_NEXUS_ENDPOINT, service=temporalio.nexus.system.generated.WorkflowService.__name__, - operation=temporalio.nexus.system.generated.WorkflowService.signal_with_start_workflow_execution, + operation=temporalio.nexus.system.generated.WorkflowService.signal_with_start_workflow_execution.name, input=request, schedule_to_close_timeout=None, schedule_to_start_timeout=None, @@ -2001,11 +2002,12 @@ async def _outbound_signal_with_start_external_workflow( cancellation_type=temporalio.workflow.NexusOperationCancellationType.WAIT_COMPLETED, headers=input.headers, summary=None, + output_type=temporalio.nexus.system.generated.SignalWithStartWorkflowExecutionResponse, ) ) result = await handle return self.workflow_get_external_workflow_handle( - input.workflow_id, run_id=result.runId + input.workflow_id, run_id=result.run_id ) async def _outbound_start_child_workflow( @@ -2086,8 +2088,20 @@ async def operation_handle_fn() -> OutputT: cancel_command = self._add_command() handle._apply_cancel_command(cancel_command) + is_system_operation = temporalio.nexus.system.is_system_operation( + input.service, input.operation_name + ) + payload_converter = ( + temporalio.nexus.system.get_payload_converter() + if is_system_operation + else self._context_free_payload_converter + ) handle = _NexusOperationHandle( - self, self._next_seq("nexus_operation"), input, operation_handle_fn() + self, + self._next_seq("nexus_operation"), + input, + operation_handle_fn(), + payload_converter, ) handle._apply_schedule_command() self._pending_nexus_operations[handle._seq] = handle @@ -3424,6 +3438,7 @@ def __init__( seq: int, input: StartNexusOperationInput[Any, OutputT], fn: Coroutine[Any, Any, OutputT], + payload_converter: PayloadConverter, ): self._instance = instance self._seq = seq @@ -3431,7 +3446,7 @@ def __init__( self._task = asyncio.Task(fn) self._start_fut: asyncio.Future[str | None] = instance.create_future() self._result_fut: asyncio.Future[OutputT | None] = instance.create_future() - self._payload_converter = self._instance._context_free_payload_converter + self._payload_converter = payload_converter self._failure_converter = self._instance._context_free_failure_converter @property @@ -3471,12 +3486,7 @@ def _apply_schedule_command(self) -> None: v.endpoint = self._input.endpoint v.service = self._input.service v.operation = self._input.operation_name - payload_converter = ( - temporalio.nexus.system.get_payload_converter() - if temporalio.nexus.system.is_system_operation(v.service, v.operation) - else self._payload_converter - ) - v.input.CopyFrom(payload_converter.to_payload(self._input.input)) + v.input.CopyFrom(self._payload_converter.to_payload(self._input.input)) if self._input.schedule_to_close_timeout is not None: v.schedule_to_close_timeout.FromTimedelta( self._input.schedule_to_close_timeout diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py index 49a0129aa..41b063230 100644 --- a/tests/nexus/test_temporal_system_nexus.py +++ b/tests/nexus/test_temporal_system_nexus.py @@ -4,21 +4,20 @@ import json import uuid from collections.abc import Sequence -from typing import Any, cast +from datetime import timedelta +from typing import Any, ClassVar, Protocol, cast import nexusrpc.handler import pytest +from google.protobuf.descriptor import FieldDescriptor +from google.protobuf.message import Message import temporalio.api.common.v1 import temporalio.converter import temporalio.nexus.system as nexus_system from temporalio import workflow from temporalio.client import Client -from temporalio.converter import ( - DefaultPayloadConverter, - ExternalStorage, - PayloadCodec, -) +from temporalio.converter import DefaultPayloadConverter, ExternalStorage, PayloadCodec from temporalio.nexus.system import generated from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( @@ -37,6 +36,13 @@ received_requests: list[dict[str, Any]] = [] +class _AnnotatedSystemNexusMessage(Protocol): + __temporal_nexus_proto_type__: ClassVar[type[Message]] + + @property + def proto_type(self) -> type[Message]: ... + + @nexusrpc.handler.service_handler(service=generated.WorkflowService) class WorkflowServicePayloadHandler: @nexusrpc.handler.sync_operation @@ -45,11 +51,11 @@ async def signal_with_start_workflow_execution( _ctx: nexusrpc.handler.StartOperationContext, request: generated.SignalWithStartWorkflowExecutionRequest, ) -> generated.SignalWithStartWorkflowExecutionResponse: - assert request.workflowId == "system-nexus-workflow-id" - assert request.signalName == "test-signal" + assert request.workflow_id == "system-nexus-workflow-id" + assert request.signal_name == "test-signal" received_requests.append(dataclasses.asdict(request)) return generated.SignalWithStartWorkflowExecutionResponse( - runId=f"{request.workflowId}-run" + run_id=f"{request.workflow_id}-run" ) @@ -174,18 +180,18 @@ def _assert_request_payload_was_externally_stored( "payloads" ] assert len(payloads) == 1 - assert payloads[0]["externalPayloads"] + assert payloads[0]["external_payloads"] def _assert_request_user_metadata_was_externally_stored( request_dict: dict[str, Any], ) -> None: user_metadata = cast( - "dict[str, dict[str, object]] | None", request_dict.get("userMetadata") + "dict[str, dict[str, object]] | None", request_dict.get("user_metadata") ) assert user_metadata is not None - assert user_metadata["summary"]["externalPayloads"] - assert user_metadata["details"]["externalPayloads"] + assert user_metadata["summary"]["external_payloads"] + assert user_metadata["details"]["external_payloads"] def _assert_stored_payloads_include( @@ -210,6 +216,132 @@ def _assert_signal_with_start_interceptor_trace() -> None: assert trace_input.workflow == "test-workflow" +def _build_proto_sample(message_type: type[Message]) -> Message: + message = message_type() + _populate_proto_sample(message) + return message + + +def _populate_proto_sample(message: Message, *, path: str = "value") -> None: + seen_oneofs: set[str] = set() + for field in message.DESCRIPTOR.fields: + if field.containing_oneof is not None: + if field.containing_oneof.name in seen_oneofs: + continue + seen_oneofs.add(field.containing_oneof.name) + if field.label == FieldDescriptor.LABEL_REPEATED: + if ( + field.message_type is not None + and field.message_type.GetOptions().map_entry + ): + _populate_proto_map_entry(message, field, path=path) + elif field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + getattr(message, field.name).add(), + path=f"{path}.{field.name}[0]", + ) + else: + getattr(message, field.name).append( + _proto_scalar_sample(field, path=f"{path}.{field.name}[0]") + ) + elif field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + getattr(message, field.name), + path=f"{path}.{field.name}", + ) + else: + setattr( + message, + field.name, + _proto_scalar_sample(field, path=f"{path}.{field.name}"), + ) + + +def _populate_proto_map_entry( + message: Message, + field: FieldDescriptor, + *, + path: str, +) -> None: + key_field = field.message_type.fields_by_name["key"] + value_field = field.message_type.fields_by_name["value"] + key = _proto_scalar_sample(key_field, path=f"{path}.{field.name}.key") + container = getattr(message, field.name) + if value_field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + container[key], + path=f"{path}.{field.name}[{key!r}]", + ) + else: + container[key] = _proto_scalar_sample( + value_field, + path=f"{path}.{field.name}[{key!r}]", + ) + + +def _proto_scalar_sample(field: FieldDescriptor, *, path: str) -> Any: + if field.type == FieldDescriptor.TYPE_BYTES: + return b"test" + if field.cpp_type == FieldDescriptor.CPPTYPE_STRING: + return f"{path}-value" + if field.cpp_type == FieldDescriptor.CPPTYPE_BOOL: + return True + if field.cpp_type in ( + FieldDescriptor.CPPTYPE_INT32, + FieldDescriptor.CPPTYPE_INT64, + FieldDescriptor.CPPTYPE_UINT32, + FieldDescriptor.CPPTYPE_UINT64, + ): + return 1 + if field.cpp_type in ( + FieldDescriptor.CPPTYPE_FLOAT, + FieldDescriptor.CPPTYPE_DOUBLE, + ): + return 1.5 + if field.cpp_type == FieldDescriptor.CPPTYPE_ENUM: + for enum_value in field.enum_type.values: + if enum_value.number != 0: + return enum_value.number + return field.enum_type.values[0].number + raise TypeError(f"Unhandled proto scalar sample at {path}: {field!r}") + + +def test_generated_system_nexus_proto_roundtrip() -> None: + payload_converter = nexus_system.get_payload_converter() + annotated_types = sorted( + ( + value + for value in vars(generated).values() + if isinstance(value, type) + and dataclasses.is_dataclass(value) + and hasattr(value, "__temporal_nexus_proto_type__") + ), + key=lambda value: value.__name__, + ) + assert annotated_types + + for annotated_type in annotated_types: + annotated_message_type = cast( + type[_AnnotatedSystemNexusMessage], annotated_type + ) + proto_type = annotated_message_type.__temporal_nexus_proto_type__ + proto_value = _build_proto_sample(proto_type) + payload = payload_converter.to_payload(proto_value) + assert payload is not None + assert ( + payload.metadata["messageType"] == proto_type.DESCRIPTOR.full_name.encode() + ) + value = payload_converter.from_payload(payload, annotated_message_type) + assert dataclasses.is_dataclass(value) + assert value.proto_type is proto_type + roundtripped_payload = payload_converter.to_payload(value) + assert roundtripped_payload is not None + roundtripped = payload_converter.from_payload( + roundtripped_payload, annotated_message_type + ) + assert roundtripped == value + + async def test_external_workflow_handle_signal_with_start_workflow_uses_system_nexus( env: WorkflowEnvironment, monkeypatch: pytest.MonkeyPatch, @@ -232,7 +364,10 @@ async def test_external_workflow_handle_signal_with_start_workflow_uses_system_n ) caller_client = Client(**caller_config) handler_config = env.client.config() - handler_config["data_converter"] = temporalio.converter.default() + handler_config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_converter_class=nexus_system.SystemNexusPayloadConverter, + ) handler_client = Client(**handler_config) caller_task_queue = str(uuid.uuid4()) handler_task_queue = str(uuid.uuid4()) @@ -259,12 +394,13 @@ async def test_external_workflow_handle_signal_with_start_workflow_uses_system_n args=[handler_task_queue], id=str(uuid.uuid4()), task_queue=caller_task_queue, + execution_timeout=timedelta(seconds=5), ) assert result == "system-nexus-workflow-id-run" request_dict = _pop_received_request() _assert_request_payload_was_externally_stored(request_dict, "input") - _assert_request_payload_was_externally_stored(request_dict, "signalInput") + _assert_request_payload_was_externally_stored(request_dict, "signal_input") _assert_request_user_metadata_was_externally_stored(request_dict) assert codec.encode_count >= 5 _assert_stored_payloads_include( diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index d57077951..196f8e9e0 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -261,8 +261,8 @@ async def test_bridge_encoding(): async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): envelope = nexus_system.generated.SignalWithStartWorkflowExecutionRequest( namespace="default", - workflowId="workflow-id", - signalName="signal-name", + workflow_id="workflow-id", + signal_name="signal-name", input=nexus_system.generated.Payloads( payloads=[ nexus_system.generated.Payload( @@ -271,7 +271,7 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): ) ] ), - signalInput=nexus_system.generated.Payloads( + signal_input=nexus_system.generated.Payloads( payloads=[ nexus_system.generated.Payload( data="InNpZ25hbC12YWx1ZSI=", @@ -287,8 +287,8 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): ) } ), - searchAttributes=nexus_system.generated.SearchAttributes( - indexedFields={ + search_attributes=nexus_system.generated.SearchAttributes( + indexed_fields={ "search-key": nexus_system.generated.Payload( data="InNlYXJjaC12YWx1ZSI=", metadata={"encoding": "anNvbi9wbGFpbg=="}, @@ -338,9 +338,9 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): async def test_bridge_encoding_checks_system_nexus_envelope_size(): envelope = nexus_system.generated.SignalWithStartWorkflowExecutionRequest( namespace="default", - workflowId="workflow-id", - signalName="signal-name", - requestId="x" * 2048, + workflow_id="workflow-id", + signal_name="signal-name", + request_id="x" * 2048, input=nexus_system.generated.Payloads( payloads=[ nexus_system.generated.Payload( From 9136695f7ea36f9446dbeda01cb543bae48a3562 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 22 Apr 2026 12:04:20 -0700 Subject: [PATCH 14/18] Move API off of external handle --- .../contrib/opentelemetry/_interceptor.py | 6 +- .../opentelemetry/_otel_interceptor.py | 6 +- temporalio/worker/__init__.py | 4 +- temporalio/worker/_interceptor.py | 14 +- temporalio/worker/_workflow_instance.py | 129 ++++---- temporalio/workflow.py | 278 +++++++++++------- tests/nexus/test_temporal_system_nexus.py | 20 +- 7 files changed, 256 insertions(+), 201 deletions(-) diff --git a/temporalio/contrib/opentelemetry/_interceptor.py b/temporalio/contrib/opentelemetry/_interceptor.py index e09b355df..502a25f16 100644 --- a/temporalio/contrib/opentelemetry/_interceptor.py +++ b/temporalio/contrib/opentelemetry/_interceptor.py @@ -772,15 +772,15 @@ async def signal_external_workflow( ) await super().signal_external_workflow(input) - async def signal_with_start_external_workflow( - self, input: temporalio.worker.SignalWithStartExternalWorkflowInput + async def signal_with_start_workflow( + self, input: temporalio.worker.SignalWithStartWorkflowInput ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: self.root._completed_span( f"SignalWithStartWorkflow:{input.signal}", add_to_outbound_str=input, kind=opentelemetry.trace.SpanKind.CLIENT, ) - return await super().signal_with_start_external_workflow(input) + return await super().signal_with_start_workflow(input) def start_activity( self, input: temporalio.worker.StartActivityInput diff --git a/temporalio/contrib/opentelemetry/_otel_interceptor.py b/temporalio/contrib/opentelemetry/_otel_interceptor.py index 34e61eb65..4b6d8b537 100644 --- a/temporalio/contrib/opentelemetry/_otel_interceptor.py +++ b/temporalio/contrib/opentelemetry/_otel_interceptor.py @@ -545,15 +545,15 @@ async def signal_external_workflow( input.headers = _context_to_headers(input.headers) await super().signal_external_workflow(input) - async def signal_with_start_external_workflow( - self, input: temporalio.worker.SignalWithStartExternalWorkflowInput + async def signal_with_start_workflow( + self, input: temporalio.worker.SignalWithStartWorkflowInput ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: with self._workflow_maybe_span( f"SignalWithStartWorkflow:{input.signal}", kind=opentelemetry.trace.SpanKind.CLIENT, ): input.headers = _context_to_nexus_headers(input.headers or {}) - return await super().signal_with_start_external_workflow(input) + return await super().signal_with_start_workflow(input) def start_activity( self, input: temporalio.worker.StartActivityInput diff --git a/temporalio/worker/__init__.py b/temporalio/worker/__init__.py index cd8feb720..fb099f73b 100644 --- a/temporalio/worker/__init__.py +++ b/temporalio/worker/__init__.py @@ -17,7 +17,7 @@ NexusOperationInboundInterceptor, SignalChildWorkflowInput, SignalExternalWorkflowInput, - SignalWithStartExternalWorkflowInput, + SignalWithStartWorkflowInput, StartActivityInput, StartChildWorkflowInput, StartLocalActivityInput, @@ -95,7 +95,7 @@ "HandleUpdateInput", "SignalChildWorkflowInput", "SignalExternalWorkflowInput", - "SignalWithStartExternalWorkflowInput", + "SignalWithStartWorkflowInput", "StartActivityInput", "StartChildWorkflowInput", "StartLocalActivityInput", diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 25579e69f..ec3c606b6 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -242,10 +242,8 @@ class SignalExternalWorkflowInput: @dataclass -class SignalWithStartExternalWorkflowInput: - """Input for - :py:meth:`WorkflowOutboundInterceptor.signal_with_start_external_workflow`. - """ +class SignalWithStartWorkflowInput: + """Input for :py:meth:`WorkflowOutboundInterceptor.signal_with_start_workflow`.""" signal: str signal_args: Sequence[Any] @@ -481,14 +479,14 @@ async def signal_external_workflow( """ return await self.next.signal_external_workflow(input) - async def signal_with_start_external_workflow( - self, input: SignalWithStartExternalWorkflowInput + async def signal_with_start_workflow( + self, input: SignalWithStartWorkflowInput ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: """Called for every - :py:meth:`temporalio.workflow.ExternalWorkflowHandle.signal_with_start` + :py:func:`temporalio.workflow.signal_with_start_workflow` call. """ - return await self.next.signal_with_start_external_workflow(input) + return await self.next.signal_with_start_workflow(input) def start_activity( self, input: StartActivityInput diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 8b72bc7cf..d2a90ef9a 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -72,7 +72,7 @@ HandleUpdateInput, SignalChildWorkflowInput, SignalExternalWorkflowInput, - SignalWithStartExternalWorkflowInput, + SignalWithStartWorkflowInput, StartActivityInput, StartChildWorkflowInput, StartLocalActivityInput, @@ -1549,6 +1549,65 @@ async def workflow_start_child_workflow( ) ) + async def workflow_signal_with_start_workflow( + self, + workflow_id: str, + signal: str | Callable, + workflow: Any, + *, + signal_args: Sequence[Any], + workflow_args: Sequence[Any], + task_queue: str, + execution_timeout: timedelta | None, + run_timeout: timedelta | None, + task_timeout: timedelta | None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + retry_policy: temporalio.common.RetryPolicy | None, + cron_schedule: str, + memo: Mapping[str, Any] | None, + search_attributes: temporalio.common.TypedSearchAttributes | None, + static_summary: str | None, + static_details: str | None, + start_delay: timedelta | None, + request_id: str | None, + priority: temporalio.common.Priority, + versioning_override: temporalio.common.VersioningOverride | None, + ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: + self._assert_not_read_only("signal with start workflow") + workflow_name, _ = temporalio.workflow._Definition.get_name_and_result_type( + workflow + ) + return await self._outbound.signal_with_start_workflow( + SignalWithStartWorkflowInput( + signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str( + signal + ), + signal_args=signal_args, + namespace=self._info.namespace, + workflow_id=workflow_id, + workflow=workflow_name, + workflow_args=workflow_args, + task_queue=task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + request_id=request_id, + priority=priority, + versioning_override=versioning_override, + headers=None, + ) + ) + def workflow_start_local_activity( self, activity: Any, @@ -1954,8 +2013,8 @@ async def _outbound_signal_external_workflow( temporalio.common._apply_headers(input.headers, v.headers) await self._signal_external_workflow(command) - async def _outbound_signal_with_start_external_workflow( - self, input: SignalWithStartExternalWorkflowInput + async def _outbound_signal_with_start_workflow( + self, input: SignalWithStartWorkflowInput ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: payload_converter = self._payload_converter_with_context( temporalio.converter.WorkflowSerializationContext( @@ -2948,10 +3007,10 @@ async def signal_external_workflow( ) -> None: await self._instance._outbound_signal_external_workflow(input) - async def signal_with_start_external_workflow( - self, input: SignalWithStartExternalWorkflowInput + async def signal_with_start_workflow( + self, input: SignalWithStartWorkflowInput ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: - return await self._instance._outbound_signal_with_start_external_workflow(input) + return await self._instance._outbound_signal_with_start_workflow(input) def start_activity( self, input: StartActivityInput @@ -3362,64 +3421,6 @@ async def signal( ) ) - async def signal_with_start( - self, - signal: str | Callable, - workflow: str | Callable[..., Awaitable[Any]], - *, - signal_args: Sequence[Any] = (), - workflow_args: Sequence[Any] = (), - task_queue: str, - execution_timeout: timedelta | None = None, - run_timeout: timedelta | None = None, - task_timeout: timedelta | None = None, - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, - retry_policy: temporalio.common.RetryPolicy | None = None, - cron_schedule: str = "", - memo: Mapping[str, Any] | None = None, - search_attributes: temporalio.common.TypedSearchAttributes | None = None, - static_summary: str | None = None, - static_details: str | None = None, - start_delay: timedelta | None = None, - request_id: str | None = None, - priority: temporalio.common.Priority = temporalio.common.Priority.default, - versioning_override: temporalio.common.VersioningOverride | None = None, - ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: - self._instance._assert_not_read_only("signal with start external handle") - workflow_name, _ = temporalio.workflow._Definition.get_name_and_result_type( - workflow - ) - return await self._instance._outbound.signal_with_start_external_workflow( - SignalWithStartExternalWorkflowInput( - signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str( - signal - ), - signal_args=signal_args, - namespace=self._instance._info.namespace, - workflow_id=self._id, - workflow=workflow_name, - workflow_args=workflow_args, - task_queue=task_queue, - execution_timeout=execution_timeout, - run_timeout=run_timeout, - task_timeout=task_timeout, - id_reuse_policy=id_reuse_policy, - id_conflict_policy=id_conflict_policy, - retry_policy=retry_policy, - cron_schedule=cron_schedule, - memo=memo, - search_attributes=search_attributes, - static_summary=static_summary, - static_details=static_details, - start_delay=start_delay, - request_id=request_id, - priority=priority, - versioning_override=versioning_override, - headers=None, - ) - ) - async def cancel(self) -> None: self._instance._assert_not_read_only("cancel external handle") command = self._instance._add_command() diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 4b678107d..dfd8ea5d3 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -847,6 +847,33 @@ async def workflow_start_child_workflow( priority: temporalio.common.Priority = temporalio.common.Priority.default, ) -> ChildWorkflowHandle[Any, Any]: ... + @abstractmethod + async def workflow_signal_with_start_workflow( + self, + workflow_id: str, + signal: str | Callable, + workflow: Any, + *, + signal_args: Sequence[Any], + workflow_args: Sequence[Any], + task_queue: str, + execution_timeout: timedelta | None, + run_timeout: timedelta | None, + task_timeout: timedelta | None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + retry_policy: temporalio.common.RetryPolicy | None, + cron_schedule: str, + memo: Mapping[str, Any] | None, + search_attributes: temporalio.common.TypedSearchAttributes | None, + static_summary: str | None, + static_details: str | None, + start_delay: timedelta | None, + request_id: str | None, + priority: temporalio.common.Priority, + versioning_override: temporalio.common.VersioningOverride | None, + ) -> ExternalWorkflowHandle[Any]: ... + @abstractmethod def workflow_start_local_activity( self, @@ -4658,6 +4685,147 @@ async def execute_child_workflow( return await handle +@overload +async def signal_with_start_workflow( + workflow_id: str, + signal: str | Callable, + workflow: MethodAsyncNoParam[SelfType, Any] + | MethodAsyncSingleParam[SelfType, Any, Any] + | Callable[Concatenate[SelfType, MultiParamSpec], Awaitable[Any]], + *, + signal_args: Sequence[Any] = [], + workflow_args: Sequence[Any] = [], + task_queue: str, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + request_id: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, +) -> ExternalWorkflowHandle[SelfType]: ... + + +@overload +async def signal_with_start_workflow( + workflow_id: str, + signal: str | Callable, + workflow: str, + *, + signal_args: Sequence[Any] = [], + workflow_args: Sequence[Any] = [], + task_queue: str, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + request_id: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, +) -> ExternalWorkflowHandle[Any]: ... + + +async def signal_with_start_workflow( + workflow_id: str, + signal: str | Callable, + workflow: str | Callable[..., Awaitable[Any]], + *, + signal_args: Sequence[Any] = [], + workflow_args: Sequence[Any] = [], + task_queue: str, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + request_id: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, +) -> ExternalWorkflowHandle[Any]: + """Signal a workflow, or start it and signal it if it is not running. + + This uses the system Nexus ``SignalWithStartWorkflowExecution`` operation + under the hood. + + Args: + workflow_id: Workflow ID to signal or start. + signal: Name or method reference for the signal. + workflow: String name or class method decorated with ``@workflow.run`` + for the workflow to start. + signal_args: Arguments to the signal. + workflow_args: Arguments to the workflow. + task_queue: Task queue to run the workflow on if it is started. + execution_timeout: Total workflow execution timeout including + retries and continue as new. + run_timeout: Timeout of a single workflow run. + task_timeout: Timeout of a single workflow task. + id_reuse_policy: How already-existing IDs are treated. + id_conflict_policy: How already-running IDs are treated. + retry_policy: Retry policy for the workflow. + cron_schedule: See https://docs.temporal.io/docs/content/what-is-a-temporal-cron-job/ + memo: Memo for the workflow. + search_attributes: Typed search attributes for the workflow. + static_summary: A single-line fixed summary for this workflow + execution that may appear in the UI/CLI. + static_details: General fixed details for this workflow execution + that may appear in UI/CLI. + start_delay: Time to wait before dispatching the first workflow task. + request_id: Optional idempotency request ID for the start request. + priority: Priority to use for this workflow. + versioning_override: Versioning override to apply if the workflow is + started. + + Returns: + A handle for the resulting workflow run. + """ + return await _Runtime.current().workflow_signal_with_start_workflow( + workflow_id, + signal, + workflow, + signal_args=signal_args, + workflow_args=workflow_args, + task_queue=task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + request_id=request_id, + priority=priority, + versioning_override=versioning_override, + ) + + class NexusOperationHandle(Generic[OutputT]): """Handle for interacting with a Nexus operation.""" @@ -4741,116 +4909,6 @@ async def cancel(self) -> None: """ raise NotImplementedError - async def signal_with_start( - self, - signal: str | Callable, # type: ignore[reportUnusedParameter] - workflow: str | Callable[..., Awaitable[Any]], # type: ignore[reportUnusedParameter] - *, - signal_args: Sequence[Any] = [], # type: ignore[reportUnusedParameter] - workflow_args: Sequence[Any] = [], # type: ignore[reportUnusedParameter] - task_queue: str, # type: ignore[reportUnusedParameter] - execution_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] - run_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] - task_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, # type: ignore[reportUnusedParameter] - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, # type: ignore[reportUnusedParameter] - retry_policy: temporalio.common.RetryPolicy | None = None, # type: ignore[reportUnusedParameter] - cron_schedule: str = "", # type: ignore[reportUnusedParameter] - memo: Mapping[str, Any] | None = None, # type: ignore[reportUnusedParameter] - search_attributes: temporalio.common.TypedSearchAttributes | None = None, # type: ignore[reportUnusedParameter] - static_summary: str | None = None, # type: ignore[reportUnusedParameter] - static_details: str | None = None, # type: ignore[reportUnusedParameter] - start_delay: timedelta | None = None, # type: ignore[reportUnusedParameter] - request_id: str | None = None, # type: ignore[reportUnusedParameter] - priority: temporalio.common.Priority = temporalio.common.Priority.default, # type: ignore[reportUnusedParameter] - versioning_override: temporalio.common.VersioningOverride | None = None, # type: ignore[reportUnusedParameter] - ) -> ExternalWorkflowHandle[SelfType]: - """Signal the workflow, or start it and signal it if it is not running. - - This uses the system Nexus ``SignalWithStartWorkflowExecution`` operation - under the hood. If this handle has a ``run_id``, it is ignored because - signal-with-start operates on workflow ID only. - - Args: - signal: Name or method reference for the signal. - workflow: String name or class method decorated with ``@workflow.run`` - for the workflow to start. - signal_args: Arguments to the signal. - workflow_args: Arguments to the workflow. - task_queue: Task queue to run the workflow on if it is started. - execution_timeout: Total workflow execution timeout including - retries and continue as new. - run_timeout: Timeout of a single workflow run. - task_timeout: Timeout of a single workflow task. - id_reuse_policy: How already-existing IDs are treated. - id_conflict_policy: How already-running IDs are treated. - retry_policy: Retry policy for the workflow. - cron_schedule: See https://docs.temporal.io/docs/content/what-is-a-temporal-cron-job/ - memo: Memo for the workflow. - search_attributes: Typed search attributes for the workflow. - static_summary: A single-line fixed summary for this workflow - execution that may appear in the UI/CLI. - static_details: General fixed details for this workflow execution - that may appear in UI/CLI. - start_delay: Time to wait before dispatching the first workflow task. - request_id: Optional idempotency request ID for the start request. - priority: Priority to use for this workflow. - versioning_override: Versioning override to apply if the workflow is - started. - - Returns: - A handle for the resulting workflow run. - """ - raise NotImplementedError - - async def signal_with_start_workflow( - self, - signal: str | Callable, # type: ignore[reportUnusedParameter] - workflow: str | Callable[..., Awaitable[Any]], # type: ignore[reportUnusedParameter] - *, - signal_args: Sequence[Any] = [], # type: ignore[reportUnusedParameter] - workflow_args: Sequence[Any] = [], # type: ignore[reportUnusedParameter] - task_queue: str, # type: ignore[reportUnusedParameter] - execution_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] - run_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] - task_timeout: timedelta | None = None, # type: ignore[reportUnusedParameter] - id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, # type: ignore[reportUnusedParameter] - id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, # type: ignore[reportUnusedParameter] - retry_policy: temporalio.common.RetryPolicy | None = None, # type: ignore[reportUnusedParameter] - cron_schedule: str = "", # type: ignore[reportUnusedParameter] - memo: Mapping[str, Any] | None = None, # type: ignore[reportUnusedParameter] - search_attributes: temporalio.common.TypedSearchAttributes | None = None, # type: ignore[reportUnusedParameter] - static_summary: str | None = None, # type: ignore[reportUnusedParameter] - static_details: str | None = None, # type: ignore[reportUnusedParameter] - start_delay: timedelta | None = None, # type: ignore[reportUnusedParameter] - request_id: str | None = None, # type: ignore[reportUnusedParameter] - priority: temporalio.common.Priority = temporalio.common.Priority.default, # type: ignore[reportUnusedParameter] - versioning_override: temporalio.common.VersioningOverride | None = None, # type: ignore[reportUnusedParameter] - ) -> ExternalWorkflowHandle[SelfType]: - """Deprecated alias for :py:meth:`signal_with_start`.""" - return await self.signal_with_start( - signal, - workflow, - signal_args=signal_args, - workflow_args=workflow_args, - task_queue=task_queue, - execution_timeout=execution_timeout, - run_timeout=run_timeout, - task_timeout=task_timeout, - id_reuse_policy=id_reuse_policy, - id_conflict_policy=id_conflict_policy, - retry_policy=retry_policy, - cron_schedule=cron_schedule, - memo=memo, - search_attributes=search_attributes, - static_summary=static_summary, - static_details=static_details, - start_delay=start_delay, - request_id=request_id, - priority=priority, - versioning_override=versioning_override, - ) - def get_external_workflow_handle( workflow_id: str, diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py index 41b063230..aa8e91e70 100644 --- a/tests/nexus/test_temporal_system_nexus.py +++ b/tests/nexus/test_temporal_system_nexus.py @@ -22,7 +22,7 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( Interceptor, - SignalWithStartExternalWorkflowInput, + SignalWithStartWorkflowInput, Worker, WorkflowInboundInterceptor, WorkflowInterceptorClassInput, @@ -63,8 +63,8 @@ async def signal_with_start_workflow_execution( class ExternalHandleSignalWithStartWorkflowCaller: @workflow.run async def run(self, task_queue: str) -> str: - handle = workflow.get_external_workflow_handle("system-nexus-workflow-id") - started_handle = await handle.signal_with_start( + started_handle = await workflow.signal_with_start_workflow( + "system-nexus-workflow-id", "test-signal", "test-workflow", signal_args=["signal-input"], @@ -159,13 +159,11 @@ def init(self, outbound: WorkflowOutboundInterceptor) -> None: class _TracingWorkflowOutboundInterceptor(WorkflowOutboundInterceptor): - async def signal_with_start_external_workflow( - self, input: SignalWithStartExternalWorkflowInput + async def signal_with_start_workflow( + self, input: SignalWithStartWorkflowInput ) -> workflow.ExternalWorkflowHandle[object]: - interceptor_traces.append( - ("workflow.signal_with_start_external_workflow", input) - ) - return await super().signal_with_start_external_workflow(input) + interceptor_traces.append(("workflow.signal_with_start_workflow", input)) + return await super().signal_with_start_workflow(input) def _pop_received_request() -> dict[str, Any]: @@ -209,8 +207,8 @@ def _assert_stored_payloads_include( def _assert_signal_with_start_interceptor_trace() -> None: assert len(interceptor_traces) == 1 trace_name, trace_value = interceptor_traces.pop() - assert trace_name == "workflow.signal_with_start_external_workflow" - trace_input = cast(SignalWithStartExternalWorkflowInput, trace_value) + assert trace_name == "workflow.signal_with_start_workflow" + trace_input = cast(SignalWithStartWorkflowInput, trace_value) assert trace_input.workflow_id == "system-nexus-workflow-id" assert trace_input.signal == "test-signal" assert trace_input.workflow == "test-workflow" From 98262360784411e22152e122d9f9eeab01c938ac Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 23 Apr 2026 12:14:12 -0700 Subject: [PATCH 15/18] Use shared workflow request builder for system nexus --- scripts/gen_nexus_system_models.py | 3 +- scripts/gen_payload_visitor.py | 289 ++++++---- temporalio/_workflow_requests.py | 142 +++++ temporalio/bridge/_visitor.py | 208 ++++--- temporalio/client.py | 130 +++-- temporalio/nexus/system/__init__.py | 379 ++++--------- temporalio/nexus/system/_payload_visitor.py | 125 +++++ .../system/_workflow_service_generated.py | 529 +----------------- temporalio/worker/_workflow_instance.py | 3 +- tests/nexus/test_temporal_system_nexus.py | 160 ++---- tests/worker/test_visitor.py | 79 +-- 11 files changed, 848 insertions(+), 1199 deletions(-) create mode 100644 temporalio/_workflow_requests.py create mode 100644 temporalio/nexus/system/_payload_visitor.py diff --git a/scripts/gen_nexus_system_models.py b/scripts/gen_nexus_system_models.py index a0e56249b..a133965d8 100644 --- a/scripts/gen_nexus_system_models.py +++ b/scripts/gen_nexus_system_models.py @@ -26,7 +26,7 @@ def main() -> None: / "protos" / "api_upstream" / "nexus" - / "temporal-json-schema-models-nexusrpc.yaml" + / "temporal-proto-models-nexusrpc.yaml" ) output_file = ( repo_root / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" @@ -75,7 +75,6 @@ def run_nexus_rpc_gen( "py", "--out-file", str(output_file), - "--temporal-nexus-payload-codec-support", str(input_schema), ] if override_root is None: diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index 0e86e0f56..2b66103cb 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -1,11 +1,18 @@ import subprocess import sys +from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from typing import Optional +from types import ModuleType +from typing import cast +import google.protobuf.message from google.protobuf.descriptor import Descriptor, FieldDescriptor from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes +from temporalio.api.workflowservice.v1.request_response_pb2 import ( + SignalWithStartWorkflowExecutionRequest, + SignalWithStartWorkflowExecutionResponse, +) from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( WorkflowActivation, ) @@ -17,28 +24,63 @@ def name_for(desc: Descriptor) -> str: - # Use fully-qualified name to avoid collisions; replace dots with underscores return desc.full_name.replace(".", "_") -def emit_loop( - field_name: str, - iter_expr: str, - child_method: str, -) -> str: - # Helper to emit a for-loop over a collection with optional headers guard +def python_type_for(desc: Descriptor) -> str: + module = desc.file.package + if module.startswith("temporal.api."): + module = "temporalio.api." + module[len("temporal.api.") :] + return f"{module}.{desc.name}" + return "object" + + +def load_generated_system_module() -> ModuleType: + module_path = ( + base_dir / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" + ) + module_name = "temporalio_nexus_system_generated" + spec = spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Cannot load generated system module from {module_path}") + module = module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def discover_system_nexus_roots() -> list[Descriptor]: + module = load_generated_system_module() + roots: list[Descriptor] = [] + for operation in getattr(module, "__nexus_operation_registry__", {}).values(): + for proto_type in (operation.input_type, operation.output_type): + if ( + isinstance(proto_type, type) + and issubclass(proto_type, google.protobuf.message.Message) + and proto_type.DESCRIPTOR is not None + ): + roots.append(cast(Descriptor, proto_type.DESCRIPTOR)) + deduped: list[Descriptor] = [] + seen: set[str] = set() + for root in roots: + if root.full_name not in seen: + seen.add(root.full_name) + deduped.append(root) + return deduped + + +def emit_loop(field_name: str, iter_expr: str, child_method: str) -> str: if field_name == "headers": return f"""\ if not self.skip_headers: for v in {iter_expr}: await self._visit_{child_method}(fs, v)""" - elif field_name == "search_attributes": + if field_name == "search_attributes": return f"""\ if not self.skip_search_attributes: for v in {iter_expr}: await self._visit_{child_method}(fs, v)""" - else: - return f"""\ + return f"""\ for v in {iter_expr}: await self._visit_{child_method}(fs, v)""" @@ -46,54 +88,68 @@ def emit_loop( def emit_singular( field_name: str, access_expr: str, child_method: str, presence_word: str | None ) -> str: - # Helper to emit a singular field visit with presence check and optional headers guard if presence_word: if field_name == "headers": return f"""\ if not self.skip_headers: {presence_word} o.HasField("{field_name}"): await self._visit_{child_method}(fs, {access_expr})""" - else: - return f"""\ + return f"""\ {presence_word} o.HasField("{field_name}"): await self._visit_{child_method}(fs, {access_expr})""" - else: - if field_name == "headers": - return f"""\ + if field_name == "headers": + return f"""\ if not self.skip_headers: await self._visit_{child_method}(fs, {access_expr})""" - else: - return f"""\ + return f"""\ await self._visit_{child_method}(fs, {access_expr})""" -class VisitorGenerator: - def generate(self, roots: list[Descriptor]) -> str: - """ - Generate Python source code that, given a function f(Payload) -> Payload, - applies it to every Payload contained within a WorkflowActivation tree. - - The generated code defines async visitor functions for each reachable - protobuf message type starting from WorkflowActivation, including support - for repeated fields and map entries, and a convenience entrypoint - function `visit`. - """ +class PayloadVisitorGenerator: + def __init__(self) -> None: + self.generated: dict[str, bool] = { + Payload.DESCRIPTOR.full_name: True, + Payloads.DESCRIPTOR.full_name: True, + } + self.in_progress: set[str] = set() + self.methods: list[str] = [ + """\ + async def _visit_temporal_api_common_v1_Payload( + self, fs: VisitorFunctions, o: Any + ): + await fs.visit_payload(o) + """, + """\ + async def _visit_temporal_api_common_v1_Payloads( + self, fs: VisitorFunctions, o: Any + ): + await fs.visit_payloads(o.payloads) + """, + """\ + async def _visit_payload_container(self, fs: VisitorFunctions, o: Any): + await fs.visit_payloads(o) + """, + ] - for r in roots: - self.walk(r) + def generate(self, roots: list[Descriptor]) -> str: + for root in roots: + self.walk(root) header = """ # This file is generated by gen_payload_visitor.py. Changes should be made there. import abc -from typing import Any, MutableSequence +from collections.abc import MutableSequence +from typing import Any +import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload class VisitorFunctions(abc.ABC): - \"\"\"Set of functions which can be called by the visitor. + \"\"\"Set of functions which can be called by the visitor. Allows handling payloads as a sequence. \"\"\" + @abc.abstractmethod async def visit_payload(self, payload: Payload) -> None: \"\"\"Called when encountering a single payload.\"\"\" @@ -109,21 +165,21 @@ async def visit_system_nexus_envelope(self, payload: Payload) -> None: \"\"\"Called when encountering a recognized system Nexus envelope payload.\"\"\" raise NotImplementedError() + class PayloadVisitor: - \"\"\"A visitor for payloads. + \"\"\"A visitor for payloads. Applies a function to every payload in a tree of messages. \"\"\" + def __init__( self, *, skip_search_attributes: bool = False, skip_headers: bool = False ): - \"\"\"Creates a new payload visitor.\"\"\" + \"\"\"Create a new payload visitor.\"\"\" self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers - async def visit( - self, fs: VisitorFunctions, root: Any - ) -> None: - \"\"\"Visits the given root message with the given function.\"\"\" + async def visit(self, fs: VisitorFunctions, root: Any) -> None: + \"\"\"Visit the given root message with the given function set.\"\"\" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) if method is not None: @@ -131,101 +187,70 @@ async def visit( else: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") - async def _visit_system_nexus_payload(self, fs, service, operation, payload) -> None: - import temporalio.nexus.system - - visitor = temporalio.nexus.system.get_payload_visitor(service, operation) - if visitor is None: + async def _visit_system_nexus_payload( + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, + ) -> None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + self.skip_search_attributes, + ) + if new_payload is None: await self._visit_temporal_api_common_v1_Payload(fs, payload) return - async def payload_visitor(payloads): - new_payloads = list(payloads) - await fs.visit_payloads(new_payloads) - return new_payloads - - new_payload = await visitor( - payload, payload_visitor, not self.skip_search_attributes - ) if new_payload is not payload: payload.CopyFrom(new_payload) await fs.visit_system_nexus_envelope(payload) """ - return header + "\n".join(self.methods) - def __init__(self): - # Track which message descriptors have visitor methods generated - self.generated: dict[str, bool] = { - Payload.DESCRIPTOR.full_name: True, - Payloads.DESCRIPTOR.full_name: True, - } - self.in_progress: set[str] = set() - self.methods: list[str] = [ - """\ - async def _visit_temporal_api_common_v1_Payload(self, fs, o): - await fs.visit_payload(o) - """, - """\ - async def _visit_temporal_api_common_v1_Payloads(self, fs, o): - await fs.visit_payloads(o.payloads) - """, - """\ - async def _visit_payload_container(self, fs, o): - await fs.visit_payloads(o) - """, - ] - - def check_repeated(self, child_desc, field, iter_expr) -> str | None: - # Special case for repeated payloads, handle them directly + def check_repeated( + self, child_desc: Descriptor, field: FieldDescriptor, iter_expr: str + ) -> str | None: if child_desc.full_name == Payload.DESCRIPTOR.full_name: return emit_singular(field.name, iter_expr, "payload_container", None) - else: - child_needed = self.walk(child_desc) - if child_needed: - return emit_loop( - field.name, - iter_expr, - name_for(child_desc), - ) - else: - return None + child_needed = self.walk(child_desc) + if child_needed: + return emit_loop(field.name, iter_expr, name_for(child_desc)) + return None def walk(self, desc: Descriptor) -> bool: key = desc.full_name if key in self.generated: return self.generated[key] if key in self.in_progress: - # Break cycles; Assume the child will be needed (Used by Failure -> Cause) return True has_payload = False self.in_progress.add(key) - lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"] - # If this is the SearchAttributes message, allow skipping + lines: list[str] = [ + f" async def _visit_{name_for(desc)}(" + "self, fs: VisitorFunctions, o: Any" + "):" + ] if desc.full_name == SearchAttributes.DESCRIPTOR.full_name: lines.append(" if self.skip_search_attributes:") lines.append(" return") - # Group fields by oneof to generate if/elif chains oneof_fields: dict[int, list[FieldDescriptor]] = {} regular_fields: list[FieldDescriptor] = [] for field in desc.fields: if field.type != FieldDescriptor.TYPE_MESSAGE: continue - - # Skip synthetic oneofs (proto3 optional fields) if field.containing_oneof is not None: - oneof_idx = field.containing_oneof.index - if oneof_idx not in oneof_fields: - oneof_fields[oneof_idx] = [] - oneof_fields[oneof_idx].append(field) + oneof_fields.setdefault(field.containing_oneof.index, []).append(field) else: regular_fields.append(field) - # Process regular fields first for field in regular_fields: if ( desc.full_name == "coresdk.workflow_commands.ScheduleNexusOperation" @@ -238,7 +263,7 @@ def walk(self, desc: Descriptor) -> bool: await self._visit_system_nexus_payload(fs, o.service, o.operation, o.input)""" ) continue - # Repeated fields (including maps which are represented as repeated messages) + if field.label == FieldDescriptor.LABEL_REPEATED: if ( field.message_type is not None @@ -295,8 +320,7 @@ def walk(self, desc: Descriptor) -> bool: ) ) - # Process oneof fields as if/elif chains - for oneof_idx, fields in oneof_fields.items(): + for fields in oneof_fields.values(): oneof_lines = [] first = True for field in fields: @@ -304,12 +328,15 @@ def walk(self, desc: Descriptor) -> bool: child_has_payload = self.walk(child_desc) has_payload |= child_has_payload if child_has_payload: - if_word = "if" if first else "elif" - first = False - line = emit_singular( - field.name, f"o.{field.name}", name_for(child_desc), if_word + oneof_lines.append( + emit_singular( + field.name, + f"o.{field.name}", + name_for(child_desc), + "if" if first else "elif", + ) ) - oneof_lines.append(line) + first = False if oneof_lines: lines.extend(oneof_lines) @@ -320,22 +347,50 @@ def walk(self, desc: Descriptor) -> bool: return has_payload -def write_generated_visitors_into_visitor_generated_py() -> None: - """Write the generated visitor code into _visitor.py.""" +def write_bridge_visitors() -> None: out_path = base_dir / "temporalio" / "bridge" / "_visitor.py" - - # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion, - # and all messages from selected API modules - roots: list[Descriptor] = [ + roots = [ WorkflowActivation.DESCRIPTOR, WorkflowActivationCompletion.DESCRIPTOR, ] + out_path.write_text(PayloadVisitorGenerator().generate(roots)) + - code = VisitorGenerator().generate(roots) - out_path.write_text(code) +def write_system_nexus_payload_visitors() -> None: + out_path = base_dir / "temporalio" / "nexus" / "system" / "_payload_visitor.py" + roots = discover_system_nexus_roots() + out_path.write_text(PayloadVisitorGenerator().generate(roots)) if __name__ == "__main__": print("Generating temporalio/bridge/_visitor.py...", file=sys.stderr) - write_generated_visitors_into_visitor_generated_py() - subprocess.run(["uv", "run", "ruff", "format", "temporalio/bridge/_visitor.py"]) + write_bridge_visitors() + print("Generating temporalio/nexus/system/_payload_visitor.py...", file=sys.stderr) + write_system_nexus_payload_visitors() + subprocess.run( + [ + "uv", + "run", + "ruff", + "check", + "--select", + "I", + "--fix", + "temporalio/bridge/_visitor.py", + "temporalio/nexus/system/_payload_visitor.py", + ], + cwd=base_dir, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "format", + "temporalio/bridge/_visitor.py", + "temporalio/nexus/system/_payload_visitor.py", + ], + cwd=base_dir, + check=True, + ) diff --git a/temporalio/_workflow_requests.py b/temporalio/_workflow_requests.py new file mode 100644 index 000000000..eb2b72606 --- /dev/null +++ b/temporalio/_workflow_requests.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from collections.abc import Sequence +from datetime import timedelta +from typing import cast + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.api.sdk.v1 +import temporalio.api.workflowservice.v1 +import temporalio.common + + +def populate_start_workflow_execution_request( + req: ( + temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest + | temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest + ), + *, + namespace: str, + workflow_id: str, + workflow: str, + task_queue: str, + input_payloads: Sequence[temporalio.api.common.v1.Payload] = (), + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + identity: str | None = None, + request_id: str | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: temporalio.api.common.v1.Memo | None = None, + search_attributes: temporalio.api.common.v1.SearchAttributes | None = None, + header: temporalio.api.common.v1.Header | None = None, + user_metadata: temporalio.api.sdk.v1.UserMetadata | None = None, + start_delay: timedelta | None = None, + priority: temporalio.common.Priority | None = None, + versioning_override: temporalio.common.VersioningOverride | None = None, +) -> None: + """Populate a workflow-service start-style request from pre-encoded pieces.""" + req.namespace = namespace + req.workflow_id = workflow_id + req.workflow_type.name = workflow + req.task_queue.name = task_queue + if input_payloads: + req.input.payloads.extend(input_payloads) + if execution_timeout is not None: + req.workflow_execution_timeout.FromTimedelta(execution_timeout) + if run_timeout is not None: + req.workflow_run_timeout.FromTimedelta(run_timeout) + if task_timeout is not None: + req.workflow_task_timeout.FromTimedelta(task_timeout) + if identity is not None: + req.identity = identity + if request_id is not None: + req.request_id = request_id + req.workflow_id_reuse_policy = cast( + "temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", + int(id_reuse_policy), + ) + req.workflow_id_conflict_policy = cast( + "temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType", + int(id_conflict_policy), + ) + if retry_policy is not None: + retry_policy.apply_to_proto(req.retry_policy) + req.cron_schedule = cron_schedule + if memo is not None: + req.memo.CopyFrom(memo) + if search_attributes is not None: + req.search_attributes.CopyFrom(search_attributes) + if header is not None: + req.header.CopyFrom(header) + if user_metadata is not None: + req.user_metadata.CopyFrom(user_metadata) + if start_delay is not None: + req.workflow_start_delay.FromTimedelta(start_delay) + if priority is not None: + req.priority.CopyFrom(priority._to_proto()) + if versioning_override is not None: + req.versioning_override.CopyFrom(versioning_override._to_proto()) + + +def build_signal_with_start_workflow_execution_request( + *, + namespace: str, + workflow_id: str, + workflow: str, + task_queue: str, + signal_name: str, + workflow_input_payloads: Sequence[temporalio.api.common.v1.Payload] = (), + signal_input_payloads: Sequence[temporalio.api.common.v1.Payload] = (), + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + identity: str | None = None, + request_id: str | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: temporalio.api.common.v1.Memo | None = None, + search_attributes: temporalio.api.common.v1.SearchAttributes | None = None, + header: temporalio.api.common.v1.Header | None = None, + user_metadata: temporalio.api.sdk.v1.UserMetadata | None = None, + start_delay: timedelta | None = None, + priority: temporalio.common.Priority | None = None, + versioning_override: temporalio.common.VersioningOverride | None = None, +) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest: + """Build a signal-with-start workflow-service request from pre-encoded pieces.""" + req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( + signal_name=signal_name + ) + if signal_input_payloads: + req.signal_input.payloads.extend(signal_input_payloads) + populate_start_workflow_execution_request( + req, + namespace=namespace, + workflow_id=workflow_id, + workflow=workflow, + task_queue=task_queue, + input_payloads=workflow_input_payloads, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + identity=identity, + request_id=request_id, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + header=header, + user_metadata=user_metadata, + start_delay=start_delay, + priority=priority, + versioning_override=versioning_override, + ) + return req diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index 3a2d98818..b6c61d730 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,7 +1,9 @@ # This file is generated by gen_payload_visitor.py. Changes should be made there. import abc -from typing import Any, MutableSequence +from collections.abc import MutableSequence +from typing import Any +import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload @@ -34,12 +36,12 @@ class PayloadVisitor: def __init__( self, *, skip_search_attributes: bool = False, skip_headers: bool = False ): - """Creates a new payload visitor.""" + """Create a new payload visitor.""" self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers async def visit(self, fs: VisitorFunctions, root: Any) -> None: - """Visits the given root message with the given function.""" + """Visit the given root message with the given function set.""" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) if method is not None: @@ -48,57 +50,69 @@ async def visit(self, fs: VisitorFunctions, root: Any) -> None: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") async def _visit_system_nexus_payload( - self, fs, service, operation, payload + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, ) -> None: - import temporalio.nexus.system - - visitor = temporalio.nexus.system.get_payload_visitor(service, operation) - if visitor is None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + self.skip_search_attributes, + ) + if new_payload is None: await self._visit_temporal_api_common_v1_Payload(fs, payload) return - async def payload_visitor(payloads): - new_payloads = list(payloads) - await fs.visit_payloads(new_payloads) - return new_payloads - - new_payload = await visitor( - payload, payload_visitor, not self.skip_search_attributes - ) if new_payload is not payload: payload.CopyFrom(new_payload) await fs.visit_system_nexus_envelope(payload) - async def _visit_temporal_api_common_v1_Payload(self, fs, o): + async def _visit_temporal_api_common_v1_Payload(self, fs: VisitorFunctions, o: Any): await fs.visit_payload(o) - async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + async def _visit_temporal_api_common_v1_Payloads( + self, fs: VisitorFunctions, o: Any + ): await fs.visit_payloads(o.payloads) - async def _visit_payload_container(self, fs, o): + async def _visit_payload_container(self, fs: VisitorFunctions, o: Any): await fs.visit_payloads(o) - async def _visit_temporal_api_failure_v1_ApplicationFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_ApplicationFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) - async def _visit_temporal_api_failure_v1_TimeoutFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_TimeoutFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( fs, o.last_heartbeat_details ) - async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_CanceledFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) - async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( fs, o.last_heartbeat_details ) - async def _visit_temporal_api_failure_v1_Failure(self, fs, o): + async def _visit_temporal_api_failure_v1_Failure( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("encoded_attributes"): await self._visit_temporal_api_common_v1_Payload(fs, o.encoded_attributes) if o.HasField("cause"): @@ -120,17 +134,21 @@ async def _visit_temporal_api_failure_v1_Failure(self, fs, o): fs, o.reset_workflow_failure_info ) - async def _visit_temporal_api_common_v1_Memo(self, fs, o): + async def _visit_temporal_api_common_v1_Memo(self, fs: VisitorFunctions, o: Any): for v in o.fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_temporal_api_common_v1_SearchAttributes(self, fs, o): + async def _visit_temporal_api_common_v1_SearchAttributes( + self, fs: VisitorFunctions, o: Any + ): if self.skip_search_attributes: return for v in o.indexed_fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_InitializeWorkflow( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: for v in o.headers.values(): @@ -148,31 +166,43 @@ async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): fs, o.search_attributes ) - async def _visit_coresdk_workflow_activation_QueryWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_QueryWorkflow( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_SignalWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_SignalWorkflow( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_activity_result_Success(self, fs, o): + async def _visit_coresdk_activity_result_Success( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_activity_result_Failure(self, fs, o): + async def _visit_coresdk_activity_result_Failure( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_activity_result_Cancellation(self, fs, o): + async def _visit_coresdk_activity_result_Cancellation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): + async def _visit_coresdk_activity_result_ActivityResolution( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("completed"): await self._visit_coresdk_activity_result_Success(fs, o.completed) elif o.HasField("failed"): @@ -180,37 +210,43 @@ async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): elif o.HasField("cancelled"): await self._visit_coresdk_activity_result_Cancellation(fs, o.cancelled) - async def _visit_coresdk_workflow_activation_ResolveActivity(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveActivity( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_coresdk_activity_result_ActivityResolution(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("cancelled"): await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( fs, o.cancelled ) - async def _visit_coresdk_child_workflow_Success(self, fs, o): + async def _visit_coresdk_child_workflow_Success(self, fs: VisitorFunctions, o: Any): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_child_workflow_Failure(self, fs, o): + async def _visit_coresdk_child_workflow_Failure(self, fs: VisitorFunctions, o: Any): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_child_workflow_Cancellation(self, fs, o): + async def _visit_coresdk_child_workflow_Cancellation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): + async def _visit_coresdk_child_workflow_ChildWorkflowResult( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("completed"): await self._visit_coresdk_child_workflow_Success(fs, o.completed) elif o.HasField("failed"): @@ -219,36 +255,40 @@ async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): await self._visit_coresdk_child_workflow_Cancellation(fs, o.cancelled) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("result"): await self._visit_coresdk_child_workflow_ChildWorkflowResult(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_activation_DoUpdate(self, fs, o): + async def _visit_coresdk_workflow_activation_DoUpdate( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): + async def _visit_coresdk_nexus_NexusOperationResult( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) elif o.HasField("failed"): @@ -258,11 +298,15 @@ async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): elif o.HasField("timed_out"): await self._visit_temporal_api_failure_v1_Failure(fs, o.timed_out) - async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveNexusOperation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_coresdk_nexus_NexusOperationResult(fs, o.result) - async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): + async def _visit_coresdk_workflow_activation_WorkflowActivationJob( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("initialize_workflow"): await self._visit_coresdk_workflow_activation_InitializeWorkflow( fs, o.initialize_workflow @@ -306,42 +350,56 @@ async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): fs, o.resolve_nexus_operation ) - async def _visit_coresdk_workflow_activation_WorkflowActivation(self, fs, o): + async def _visit_coresdk_workflow_activation_WorkflowActivation( + self, fs: VisitorFunctions, o: Any + ): for v in o.jobs: await self._visit_coresdk_workflow_activation_WorkflowActivationJob(fs, v) - async def _visit_temporal_api_sdk_v1_UserMetadata(self, fs, o): + async def _visit_temporal_api_sdk_v1_UserMetadata( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("summary"): await self._visit_temporal_api_common_v1_Payload(fs, o.summary) if o.HasField("details"): await self._visit_temporal_api_common_v1_Payload(fs, o.details) - async def _visit_coresdk_workflow_commands_ScheduleActivity(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleActivity( + self, fs: VisitorFunctions, o: Any + ): if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) await self._visit_payload_container(fs, o.arguments) - async def _visit_coresdk_workflow_commands_QuerySuccess(self, fs, o): + async def _visit_coresdk_workflow_commands_QuerySuccess( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("response"): await self._visit_temporal_api_common_v1_Payload(fs, o.response) - async def _visit_coresdk_workflow_commands_QueryResult(self, fs, o): + async def _visit_coresdk_workflow_commands_QueryResult( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("succeeded"): await self._visit_coresdk_workflow_commands_QuerySuccess(fs, o.succeeded) elif o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_FailWorkflowExecution( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( - self, fs, o + self, fs: VisitorFunctions, o: Any ): await self._visit_payload_container(fs, o.arguments) for v in o.memo.values(): @@ -354,7 +412,9 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( fs, o.search_attributes ) - async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): @@ -367,42 +427,52 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, ) async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( - self, fs, o + self, fs: VisitorFunctions, o: Any ): await self._visit_payload_container(fs, o.args) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleLocalActivity( + self, fs: VisitorFunctions, o: Any + ): if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) await self._visit_payload_container(fs, o.arguments) async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("search_attributes"): await self._visit_temporal_api_common_v1_SearchAttributes( fs, o.search_attributes ) - async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o): + async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("upserted_memo"): await self._visit_temporal_api_common_v1_Memo(fs, o.upserted_memo) - async def _visit_coresdk_workflow_commands_UpdateResponse(self, fs, o): + async def _visit_coresdk_workflow_commands_UpdateResponse( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("rejected"): await self._visit_temporal_api_failure_v1_Failure(fs, o.rejected) elif o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) - async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleNexusOperation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("input"): await self._visit_system_nexus_payload(fs, o.service, o.operation, o.input) - async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): + async def _visit_coresdk_workflow_commands_WorkflowCommand( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("user_metadata"): await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) if o.HasField("schedule_activity"): @@ -454,16 +524,20 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): fs, o.schedule_nexus_operation ) - async def _visit_coresdk_workflow_completion_Success(self, fs, o): + async def _visit_coresdk_workflow_completion_Success( + self, fs: VisitorFunctions, o: Any + ): for v in o.commands: await self._visit_coresdk_workflow_commands_WorkflowCommand(fs, v) - async def _visit_coresdk_workflow_completion_Failure(self, fs, o): + async def _visit_coresdk_workflow_completion_Failure( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("successful"): await self._visit_coresdk_workflow_completion_Success(fs, o.successful) diff --git a/temporalio/client.py b/temporalio/client.py index 22b07b1c1..9bd6b440a 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -39,6 +39,7 @@ from google.protobuf.internal.containers import MessageMap from typing_extensions import Required, Self, TypedDict +import temporalio._workflow_requests import temporalio.activity import temporalio.api.activity.v1 import temporalio.api.common.v1 @@ -8083,15 +8084,51 @@ async def _build_signal_with_start_workflow_execution_request( workflow_id=input.id, ) ) - req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( - signal_name=input.start_signal - ) - if input.start_signal_args: - req.signal_input.payloads.extend( - await data_converter.encode(input.start_signal_args) + request_memo = None + if input.memo is not None: + request_memo = temporalio.api.common.v1.Memo() + await data_converter._encode_memo_existing(input.memo, request_memo) + request_search_attributes = None + if input.search_attributes is not None: + request_search_attributes = temporalio.api.common.v1.SearchAttributes() + temporalio.converter.encode_search_attributes( + input.search_attributes, request_search_attributes ) - await self._populate_start_workflow_execution_request(req, input) - return req + request_header = None + if input.headers is not None: # type:ignore[reportUnnecessaryComparison] + request_header = temporalio.api.common.v1.Header() + await self._apply_headers(input.headers, request_header.fields) + return temporalio._workflow_requests.build_signal_with_start_workflow_execution_request( + namespace=self._client.namespace, + workflow_id=input.id, + workflow=input.workflow, + task_queue=input.task_queue, + signal_name=input.start_signal, + workflow_input_payloads=await data_converter.encode(input.args) + if input.args + else (), + signal_input_payloads=await data_converter.encode(input.start_signal_args) + if input.start_signal_args + else (), + execution_timeout=input.execution_timeout, + run_timeout=input.run_timeout, + task_timeout=input.task_timeout, + identity=self._client.identity, + request_id=str(uuid.uuid4()), + id_reuse_policy=input.id_reuse_policy, + id_conflict_policy=input.id_conflict_policy, + retry_policy=input.retry_policy, + cron_schedule=input.cron_schedule, + memo=request_memo, + search_attributes=request_search_attributes, + header=request_header, + user_metadata=await _encode_user_metadata( + data_converter, input.static_summary, input.static_details + ), + start_delay=input.start_delay, + priority=input.priority, + versioning_override=input.versioning_override, + ) async def _build_update_with_start_start_workflow_execution_request( self, input: UpdateWithStartStartWorkflowInput @@ -8114,51 +8151,48 @@ async def _populate_start_workflow_execution_request( workflow_id=input.id, ) ) - req.namespace = self._client.namespace - req.workflow_id = input.id - req.workflow_type.name = input.workflow - req.task_queue.name = input.task_queue - if input.args: - req.input.payloads.extend(await data_converter.encode(input.args)) - if input.execution_timeout is not None: - req.workflow_execution_timeout.FromTimedelta(input.execution_timeout) - if input.run_timeout is not None: - req.workflow_run_timeout.FromTimedelta(input.run_timeout) - if input.task_timeout is not None: - req.workflow_task_timeout.FromTimedelta(input.task_timeout) - req.identity = self._client.identity - req.request_id = str(uuid.uuid4()) - req.workflow_id_reuse_policy = cast( - "temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", - int(input.id_reuse_policy), - ) - req.workflow_id_conflict_policy = cast( - "temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType", - int(input.id_conflict_policy), - ) - - if input.retry_policy is not None: - input.retry_policy.apply_to_proto(req.retry_policy) - req.cron_schedule = input.cron_schedule + request_memo = None if input.memo is not None: - await data_converter._encode_memo_existing(input.memo, req.memo) + request_memo = temporalio.api.common.v1.Memo() + await data_converter._encode_memo_existing(input.memo, request_memo) + request_search_attributes = None if input.search_attributes is not None: + request_search_attributes = temporalio.api.common.v1.SearchAttributes() temporalio.converter.encode_search_attributes( - input.search_attributes, req.search_attributes + input.search_attributes, request_search_attributes ) - metadata = await _encode_user_metadata( - data_converter, input.static_summary, input.static_details - ) - if metadata is not None: - req.user_metadata.CopyFrom(metadata) - if input.start_delay is not None: - req.workflow_start_delay.FromTimedelta(input.start_delay) + request_header = None if input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await self._apply_headers(input.headers, req.header.fields) - if input.priority is not None: # type:ignore[reportUnnecessaryComparison] - req.priority.CopyFrom(input.priority._to_proto()) - if input.versioning_override is not None: - req.versioning_override.CopyFrom(input.versioning_override._to_proto()) + request_header = temporalio.api.common.v1.Header() + await self._apply_headers(input.headers, request_header.fields) + temporalio._workflow_requests.populate_start_workflow_execution_request( + req, + namespace=self._client.namespace, + workflow_id=input.id, + workflow=input.workflow, + task_queue=input.task_queue, + input_payloads=await data_converter.encode(input.args) + if input.args + else (), + execution_timeout=input.execution_timeout, + run_timeout=input.run_timeout, + task_timeout=input.task_timeout, + identity=self._client.identity, + request_id=str(uuid.uuid4()), + id_reuse_policy=input.id_reuse_policy, + id_conflict_policy=input.id_conflict_policy, + retry_policy=input.retry_policy, + cron_schedule=input.cron_schedule, + memo=request_memo, + search_attributes=request_search_attributes, + header=request_header, + user_metadata=await _encode_user_metadata( + data_converter, input.static_summary, input.static_details + ), + start_delay=input.start_delay, + priority=input.priority, + versioning_override=input.versioning_override, + ) async def cancel_workflow(self, input: CancelWorkflowInput) -> None: await self._client.workflow_service.request_cancel_workflow_execution( diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index 63037d6dc..e8fbff1c4 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -1,87 +1,33 @@ -"""Generated system Nexus service models. +"""Generated system Nexus service models.""" -This package contains code generated from Temporal's system Nexus schemas. -Higher-level ergonomic APIs may wrap these generated types. -""" - -import dataclasses -from collections.abc import Awaitable, Callable, Mapping, Sequence +from collections.abc import Mapping, MutableSequence, Sequence from datetime import timedelta -from enum import Enum -from typing import Any, cast - -import google.protobuf.message -from google.protobuf.json_format import MessageToDict, Parse, ParseDict +from typing import Any, Protocol, cast import temporalio.api.common.v1 +import temporalio.api.sdk.v1 +import temporalio.api.workflowservice.v1 import temporalio.common import temporalio.converter -from ...converter import CompositePayloadConverter, JSONProtoPayloadConverter -from ...converter._payload_converter import value_to_type +from ... import _workflow_requests +from ...converter import BinaryProtoPayloadConverter, CompositePayloadConverter from . import _workflow_service_generated as generated -from ._workflow_service_generated import __temporal_nexus_payload_visitors__ - -TemporalNexusPayloadVisitor = Callable[ - [ - temporalio.api.common.v1.Payload, - Callable[ - [Sequence[temporalio.api.common.v1.Payload]], - Awaitable[list[temporalio.api.common.v1.Payload]], - ], - bool, - ], - Awaitable[temporalio.api.common.v1.Payload], -] - -_SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.default().payload_converter +from ._payload_visitor import PayloadVisitor -_WORKFLOW_ID_REUSE_POLICY_TO_GENERATED = { - temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE: generated.WorkflowIDReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE, - temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE_FAILED_ONLY: generated.WorkflowIDReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY, - temporalio.common.WorkflowIDReusePolicy.REJECT_DUPLICATE: generated.WorkflowIDReusePolicy.WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE, - temporalio.common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING: generated.WorkflowIDReusePolicy.WORKFLOW_ID_REUSE_POLICY_TERMINATE_IF_RUNNING, -} -_WORKFLOW_ID_CONFLICT_POLICY_TO_GENERATED = { - temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED: generated.WorkflowIDConflictPolicy.WORKFLOW_ID_CONFLICT_POLICY_UNSPECIFIED, - temporalio.common.WorkflowIDConflictPolicy.FAIL: generated.WorkflowIDConflictPolicy.WORKFLOW_ID_CONFLICT_POLICY_FAIL, - temporalio.common.WorkflowIDConflictPolicy.USE_EXISTING: generated.WorkflowIDConflictPolicy.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING, - temporalio.common.WorkflowIDConflictPolicy.TERMINATE_EXISTING: generated.WorkflowIDConflictPolicy.WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING, -} +class PayloadVisitorFunctions(Protocol): + async def visit_payload( + self, payload: temporalio.api.common.v1.Payload + ) -> None: ... + async def visit_payloads( + self, payloads: MutableSequence[temporalio.api.common.v1.Payload] + ) -> None: ... -class _SystemNexusJSONProtoPayloadConverter(JSONProtoPayloadConverter): - def to_payload(self, value: Any) -> temporalio.api.common.v1.Payload | None: - proto_type = _get_generated_proto_type(value) - if proto_type is not None: - return super().to_payload( - ParseDict( - dataclasses.asdict(value), - proto_type(), - ignore_unknown_fields=True, - ) - ) - return super().to_payload(value) - - def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: type | None = None, - ) -> Any: - proto_type = _get_generated_proto_type(type_hint) - if proto_type is not None and type_hint is not None: - proto_value = Parse( - payload.data, - proto_type(), - ignore_unknown_fields=True, - ) - return value_to_type( - type_hint, - MessageToDict(proto_value, preserving_proto_field_name=True), - [_SystemNexusStrEnumConverter()], - ) - return super().from_payload(payload, type_hint) + async def visit_system_nexus_envelope( + self, payload: temporalio.api.common.v1.Payload + ) -> None: ... class SystemNexusPayloadConverter(CompositePayloadConverter): @@ -89,153 +35,31 @@ class SystemNexusPayloadConverter(CompositePayloadConverter): def __init__(self) -> None: """Create a payload converter for system Nexus outer envelopes.""" - super().__init__(_SystemNexusJSONProtoPayloadConverter()) - - -def _get_generated_proto_type( - value_or_type: Any, -) -> type[google.protobuf.message.Message] | None: - candidate = ( - value_or_type if isinstance(value_or_type, type) else type(value_or_type) - ) - proto_type = getattr(candidate, "__temporal_nexus_proto_type__", None) - if isinstance(proto_type, type) and issubclass( - proto_type, google.protobuf.message.Message - ): - return proto_type - return None - - -class _SystemNexusStrEnumConverter(temporalio.converter.JSONTypeConverter): - # Generated enums subclass str and Enum, not StrEnum, so the default - # value_to_type enum handling does not reconstruct them. - def to_typed_value(self, hint: type, value: Any) -> Any: - if isinstance(hint, type) and issubclass(hint, Enum) and issubclass(hint, str): - if not isinstance(value, str): - raise TypeError(f"Expected value to be str, was {type(value)}") - return hint(value) - return temporalio.converter.JSONTypeConverter.Unhandled - + super().__init__(BinaryProtoPayloadConverter()) -def _payload_to_json_value( - converter: temporalio.converter.PayloadConverter, value: Any -) -> generated.Payload: - return _proto_payload_to_generated(converter.to_payload(value)) - -def _proto_payload_to_generated( - payload: temporalio.api.common.v1.Payload, -) -> generated.Payload: - value = MessageToDict(payload) - return generated.Payload( - data=cast("str | None", value.get("data")), - external_payloads=[ - generated.PayloadExternalPayloadDetails(**details) - for details in cast( - "list[dict[str, str]]", value.get("externalPayloads", []) - ) - ] - or None, - metadata=cast("dict[str, str] | None", value.get("metadata")), - ) +def _set_payload_map( + target: Any, + values: Mapping[str, Any], + payload_converter: temporalio.converter.PayloadConverter, +) -> None: + for key, value in values.items(): + target[key].CopyFrom(payload_converter.to_payload(value)) -def _payloads_to_input( - converter: temporalio.converter.PayloadConverter, values: Sequence[Any] -) -> generated.Payloads | None: - payloads = converter.to_payloads(values) if values else [] - if not payloads: +def _build_user_metadata( + payload_converter: temporalio.converter.PayloadConverter, + static_summary: str | None, + static_details: str | None, +) -> temporalio.api.sdk.v1.UserMetadata | None: + if static_summary is None and static_details is None: return None - return generated.Payloads( - payloads=[_proto_payload_to_generated(payload) for payload in payloads] - ) - - -def _search_attributes_to_json_map( - attributes: temporalio.common.TypedSearchAttributes, -) -> dict[str, generated.Payload]: - return { - pair.key.name: _proto_payload_to_generated( - temporalio.converter.encode_typed_search_attribute_value( - pair.key, pair.value - ) - ) - for pair in attributes - } - - -def _retry_policy_to_generated( - retry_policy: temporalio.common.RetryPolicy, -) -> generated.RetryPolicy: - retry_policy._validate() - return generated.RetryPolicy( - initial_interval=f"{retry_policy.initial_interval.total_seconds()}s", - backoff_coefficient=retry_policy.backoff_coefficient, - maximum_interval=f"{(retry_policy.maximum_interval or retry_policy.initial_interval * 100).total_seconds()}s", - maximum_attempts=retry_policy.maximum_attempts, - non_retryable_error_types=( - list(retry_policy.non_retryable_error_types) - if retry_policy.non_retryable_error_types - else None - ), - ) - - -def _priority_to_generated( - priority: temporalio.common.Priority, -) -> generated.Priority | None: - if ( - priority.priority_key is None - and priority.fairness_key is None - and priority.fairness_weight is None - ): - return None - return generated.Priority( - priority_key=priority.priority_key, - fairness_key=priority.fairness_key, - fairness_weight=priority.fairness_weight, - ) - - -def _workflow_id_reuse_policy_to_generated( - policy: temporalio.common.WorkflowIDReusePolicy, -) -> generated.WorkflowIDReusePolicy: - return _WORKFLOW_ID_REUSE_POLICY_TO_GENERATED[policy] - - -def _workflow_id_conflict_policy_to_generated( - policy: temporalio.common.WorkflowIDConflictPolicy, -) -> generated.WorkflowIDConflictPolicy: - return _WORKFLOW_ID_CONFLICT_POLICY_TO_GENERATED[policy] - - -def _versioning_override_to_generated( - versioning_override: temporalio.common.VersioningOverride, -) -> generated.VersioningOverride: - if isinstance(versioning_override, temporalio.common.AutoUpgradeVersioningOverride): - return generated.VersioningOverride( - auto_upgrade=True, - behavior=generated.VersioningOverrideBehavior.VERSIONING_BEHAVIOR_AUTO_UPGRADE, - ) - if isinstance(versioning_override, temporalio.common.PinnedVersioningOverride): - return generated.VersioningOverride( - behavior=generated.VersioningOverrideBehavior.VERSIONING_BEHAVIOR_PINNED, - pinned_version=versioning_override.version.to_canonical_string(), - pinned=generated.VersioningOverridePinnedOverride( - behavior=generated.VersioningOverridePinnedOverrideBehavior.PINNED_OVERRIDE_BEHAVIOR_PINNED, - version=generated.WorkerDeploymentVersion( - deployment_name=versioning_override.version.deployment_name, - build_id=versioning_override.version.build_id, - ), - ), - deployment=generated.Deployment( - series_name=versioning_override.version.deployment_name, - build_id=versioning_override.version.build_id, - ), - ) - raise TypeError( - f"Unsupported versioning override type: {type(versioning_override)!r}" - ) + metadata = temporalio.api.sdk.v1.UserMetadata() + if static_summary is not None: + metadata.summary.CopyFrom(payload_converter.to_payload(static_summary)) + if static_details is not None: + metadata.details.CopyFrom(payload_converter.to_payload(static_details)) + return metadata def build_signal_with_start_workflow_execution_input( @@ -263,89 +87,74 @@ def build_signal_with_start_workflow_execution_input( start_delay: timedelta | None = None, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: temporalio.common.VersioningOverride | None = None, -) -> generated.SignalWithStartWorkflowExecutionRequest: - """Build the generated system Nexus input for signal-with-start.""" - return generated.SignalWithStartWorkflowExecutionRequest( +) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest: + """Build the system Nexus signal-with-start request.""" + request_memo = None + if memo is not None: + request_memo = temporalio.api.common.v1.Memo() + _set_payload_map(request_memo.fields, memo, payload_converter) + request_search_attributes = None + if search_attributes is not None: + request_search_attributes = temporalio.api.common.v1.SearchAttributes() + temporalio.converter.encode_search_attributes( + search_attributes, request_search_attributes + ) + return _workflow_requests.build_signal_with_start_workflow_execution_request( namespace=namespace, workflow_id=workflow_id, - workflow_type=generated.WorkflowType(name=workflow), - task_queue=generated.TaskQueue(name=task_queue), - input=_payloads_to_input(payload_converter, workflow_args), - workflow_execution_timeout=( - f"{execution_timeout.total_seconds()}s" if execution_timeout else None - ), - workflow_run_timeout=f"{run_timeout.total_seconds()}s" if run_timeout else None, - workflow_task_timeout=( - f"{task_timeout.total_seconds()}s" if task_timeout else None - ), - request_id=request_id, - workflow_id_reuse_policy=_workflow_id_reuse_policy_to_generated( - id_reuse_policy - ), - workflow_id_conflict_policy=( - _workflow_id_conflict_policy_to_generated(id_conflict_policy) - if id_conflict_policy - != temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED - else None - ), - retry_policy=( - _retry_policy_to_generated(retry_policy) if retry_policy else None - ), - cron_schedule=cron_schedule or None, - memo=( - generated.Memo( - fields={ - key: _payload_to_json_value(payload_converter, value) - for key, value in memo.items() - } - ) - if memo - else None - ), - search_attributes=( - generated.SearchAttributes( - indexed_fields=_search_attributes_to_json_map(search_attributes) - ) - if search_attributes - else None - ), + workflow=workflow, + task_queue=task_queue, signal_name=signal, - signal_input=_payloads_to_input(payload_converter, signal_args), - user_metadata=( - generated.UserMetadata( - summary=_payload_to_json_value(payload_converter, static_summary) - if static_summary is not None - else None, - details=_payload_to_json_value(payload_converter, static_details) - if static_details is not None - else None, - ) - if static_summary is not None or static_details is not None - else None - ), - workflow_start_delay=( - f"{start_delay.total_seconds()}s" if start_delay else None - ), - priority=_priority_to_generated(priority), - versioning_override=( - _versioning_override_to_generated(versioning_override) - if versioning_override - else None + workflow_input_payloads=payload_converter.to_payloads(workflow_args) + if workflow_args + else (), + signal_input_payloads=payload_converter.to_payloads(signal_args) + if signal_args + else (), + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + request_id=request_id, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=request_memo, + search_attributes=request_search_attributes, + user_metadata=_build_user_metadata( + payload_converter, static_summary, static_details ), + start_delay=start_delay, + priority=priority, + versioning_override=versioning_override, ) -def get_payload_visitor( +async def visit_payload( service: str, operation: str, -) -> TemporalNexusPayloadVisitor | None: - """Return the generated nested-payload visitor for a system Nexus operation.""" - return __temporal_nexus_payload_visitors__.get((service, operation)) + payload: temporalio.api.common.v1.Payload, + visitor_functions: PayloadVisitorFunctions, + skip_search_attributes: bool, +) -> temporalio.api.common.v1.Payload | None: + """Visit nested payloads inside a recognized system Nexus envelope.""" + operation_def = generated.__nexus_operation_registry__.get((service, operation)) + if operation_def is None: + return None + input_type = operation_def.input_type + if not isinstance(input_type, type): + return None + payload_converter = get_payload_converter() + value = payload_converter.from_payload(payload, input_type) + await PayloadVisitor(skip_search_attributes=skip_search_attributes).visit( + cast(Any, visitor_functions), value + ) + return payload_converter.to_payload(value) def is_system_operation(service: str, operation: str) -> bool: """Return whether a Nexus operation uses the generated system envelope.""" - return get_payload_visitor(service, operation) is not None + return (service, operation) in generated.__nexus_operation_registry__ def get_payload_converter() -> temporalio.converter.PayloadConverter: @@ -357,7 +166,7 @@ def get_payload_converter() -> temporalio.converter.PayloadConverter: "build_signal_with_start_workflow_execution_input", "generated", "get_payload_converter", - "get_payload_visitor", "is_system_operation", "SystemNexusPayloadConverter", + "visit_payload", ) diff --git a/temporalio/nexus/system/_payload_visitor.py b/temporalio/nexus/system/_payload_visitor.py new file mode 100644 index 000000000..6a247f643 --- /dev/null +++ b/temporalio/nexus/system/_payload_visitor.py @@ -0,0 +1,125 @@ +# This file is generated by gen_payload_visitor.py. Changes should be made there. +import abc +from collections.abc import MutableSequence +from typing import Any + +import temporalio.nexus.system +from temporalio.api.common.v1.message_pb2 import Payload + + +class VisitorFunctions(abc.ABC): + """Set of functions which can be called by the visitor. + Allows handling payloads as a sequence. + """ + + @abc.abstractmethod + async def visit_payload(self, payload: Payload) -> None: + """Called when encountering a single payload.""" + raise NotImplementedError() + + @abc.abstractmethod + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + """Called when encountering multiple payloads together.""" + raise NotImplementedError() + + @abc.abstractmethod + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + """Called when encountering a recognized system Nexus envelope payload.""" + raise NotImplementedError() + + +class PayloadVisitor: + """A visitor for payloads. + Applies a function to every payload in a tree of messages. + """ + + def __init__( + self, *, skip_search_attributes: bool = False, skip_headers: bool = False + ): + """Create a new payload visitor.""" + self.skip_search_attributes = skip_search_attributes + self.skip_headers = skip_headers + + async def visit(self, fs: VisitorFunctions, root: Any) -> None: + """Visit the given root message with the given function set.""" + method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") + method = getattr(self, method_name, None) + if method is not None: + await method(fs, root) + else: + raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") + + async def _visit_system_nexus_payload( + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, + ) -> None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + self.skip_search_attributes, + ) + if new_payload is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + if new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) + + async def _visit_temporal_api_common_v1_Payload(self, fs: VisitorFunctions, o: Any): + await fs.visit_payload(o) + + async def _visit_temporal_api_common_v1_Payloads( + self, fs: VisitorFunctions, o: Any + ): + await fs.visit_payloads(o.payloads) + + async def _visit_payload_container(self, fs: VisitorFunctions, o: Any): + await fs.visit_payloads(o) + + async def _visit_temporal_api_common_v1_Memo(self, fs: VisitorFunctions, o: Any): + for v in o.fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_temporal_api_common_v1_SearchAttributes( + self, fs: VisitorFunctions, o: Any + ): + if self.skip_search_attributes: + return + for v in o.indexed_fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_temporal_api_common_v1_Header(self, fs: VisitorFunctions, o: Any): + for v in o.fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_temporal_api_sdk_v1_UserMetadata( + self, fs: VisitorFunctions, o: Any + ): + if o.HasField("summary"): + await self._visit_temporal_api_common_v1_Payload(fs, o.summary) + if o.HasField("details"): + await self._visit_temporal_api_common_v1_Payload(fs, o.details) + + async def _visit_temporal_api_workflowservice_v1_SignalWithStartWorkflowExecutionRequest( + self, fs: VisitorFunctions, o: Any + ): + if o.HasField("input"): + await self._visit_temporal_api_common_v1_Payloads(fs, o.input) + if o.HasField("signal_input"): + await self._visit_temporal_api_common_v1_Payloads(fs, o.signal_input) + if o.HasField("memo"): + await self._visit_temporal_api_common_v1_Memo(fs, o.memo) + if o.HasField("search_attributes"): + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes + ) + if o.HasField("header"): + await self._visit_temporal_api_common_v1_Header(fs, o.header) + if o.HasField("user_metadata"): + await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index 22fabfa80..6d0fc9e65 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -3,536 +3,25 @@ from __future__ import annotations -import collections.abc -import json -from dataclasses import dataclass -from datetime import datetime -from enum import Enum -from typing import ClassVar, Dict, List, Optional +import typing -from google.protobuf.json_format import MessageToDict, ParseDict from nexusrpc import Operation, service -import temporalio.api.common.v1 import temporalio.api.workflowservice.v1 - -@dataclass -class PayloadExternalPayloadDetails: - size_bytes: Optional[str] = None - - -@dataclass -class Payload: - data: Optional[str] = None - external_payloads: Optional[List[PayloadExternalPayloadDetails]] = None - metadata: Optional[Dict[str, str]] = None - - -@dataclass -class Header: - fields: Optional[Dict[str, Payload]] = None - - -@dataclass -class Payloads: - payloads: Optional[List[Payload]] = None - - -@dataclass -class LinkActivity: - activity_id: Optional[str] = None - namespace: Optional[str] = None - run_id: Optional[str] = None - - -@dataclass -class LinkBatchJob: - job_id: Optional[str] = None - - -class EventType(str, Enum): - EVENT_TYPE_ACTIVITY_PROPERTIES_MODIFIED_EXTERNALLY = ( - "EVENT_TYPE_ACTIVITY_PROPERTIES_MODIFIED_EXTERNALLY" - ) - EVENT_TYPE_ACTIVITY_TASK_CANCELED = "EVENT_TYPE_ACTIVITY_TASK_CANCELED" - EVENT_TYPE_ACTIVITY_TASK_CANCEL_REQUESTED = ( - "EVENT_TYPE_ACTIVITY_TASK_CANCEL_REQUESTED" - ) - EVENT_TYPE_ACTIVITY_TASK_COMPLETED = "EVENT_TYPE_ACTIVITY_TASK_COMPLETED" - EVENT_TYPE_ACTIVITY_TASK_FAILED = "EVENT_TYPE_ACTIVITY_TASK_FAILED" - EVENT_TYPE_ACTIVITY_TASK_SCHEDULED = "EVENT_TYPE_ACTIVITY_TASK_SCHEDULED" - EVENT_TYPE_ACTIVITY_TASK_STARTED = "EVENT_TYPE_ACTIVITY_TASK_STARTED" - EVENT_TYPE_ACTIVITY_TASK_TIMED_OUT = "EVENT_TYPE_ACTIVITY_TASK_TIMED_OUT" - EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_CANCELED = ( - "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_CANCELED" - ) - EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_COMPLETED = ( - "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_COMPLETED" - ) - EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_FAILED = ( - "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_FAILED" - ) - EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_STARTED = ( - "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_STARTED" - ) - EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TERMINATED = ( - "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TERMINATED" - ) - EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TIMED_OUT = ( - "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TIMED_OUT" - ) - EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_CANCEL_REQUESTED = ( - "EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_CANCEL_REQUESTED" - ) - EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_SIGNALED = ( - "EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_SIGNALED" - ) - EVENT_TYPE_MARKER_RECORDED = "EVENT_TYPE_MARKER_RECORDED" - EVENT_TYPE_NEXUS_OPERATION_CANCELED = "EVENT_TYPE_NEXUS_OPERATION_CANCELED" - EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED = ( - "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED" - ) - EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED = ( - "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED" - ) - EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED = ( - "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED" - ) - EVENT_TYPE_NEXUS_OPERATION_COMPLETED = "EVENT_TYPE_NEXUS_OPERATION_COMPLETED" - EVENT_TYPE_NEXUS_OPERATION_FAILED = "EVENT_TYPE_NEXUS_OPERATION_FAILED" - EVENT_TYPE_NEXUS_OPERATION_SCHEDULED = "EVENT_TYPE_NEXUS_OPERATION_SCHEDULED" - EVENT_TYPE_NEXUS_OPERATION_STARTED = "EVENT_TYPE_NEXUS_OPERATION_STARTED" - EVENT_TYPE_NEXUS_OPERATION_TIMED_OUT = "EVENT_TYPE_NEXUS_OPERATION_TIMED_OUT" - EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_FAILED = ( - "EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_FAILED" - ) - EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED = ( - "EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED" - ) - EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_FAILED = ( - "EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_FAILED" - ) - EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED = ( - "EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED" - ) - EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_FAILED = ( - "EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_FAILED" - ) - EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED = ( - "EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED" - ) - EVENT_TYPE_TIMER_CANCELED = "EVENT_TYPE_TIMER_CANCELED" - EVENT_TYPE_TIMER_FIRED = "EVENT_TYPE_TIMER_FIRED" - EVENT_TYPE_TIMER_STARTED = "EVENT_TYPE_TIMER_STARTED" - EVENT_TYPE_UNSPECIFIED = "EVENT_TYPE_UNSPECIFIED" - EVENT_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES = ( - "EVENT_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES" - ) - EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED = "EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED" - EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED" - ) - EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED = "EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED" - EVENT_TYPE_WORKFLOW_EXECUTION_CONTINUED_AS_NEW = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_CONTINUED_AS_NEW" - ) - EVENT_TYPE_WORKFLOW_EXECUTION_FAILED = "EVENT_TYPE_WORKFLOW_EXECUTION_FAILED" - EVENT_TYPE_WORKFLOW_EXECUTION_OPTIONS_UPDATED = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_OPTIONS_UPDATED" - ) - EVENT_TYPE_WORKFLOW_EXECUTION_PAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_PAUSED" - EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED = "EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED" - EVENT_TYPE_WORKFLOW_EXECUTION_STARTED = "EVENT_TYPE_WORKFLOW_EXECUTION_STARTED" - EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED" - ) - EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT = "EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT" - EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED" - ) - EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED" - EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED" - ) - EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ADMITTED = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ADMITTED" - ) - EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_COMPLETED = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_COMPLETED" - ) - EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_REJECTED = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_REJECTED" - ) - EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED = "EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED" - EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED_EXTERNALLY = ( - "EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED_EXTERNALLY" - ) - EVENT_TYPE_WORKFLOW_TASK_COMPLETED = "EVENT_TYPE_WORKFLOW_TASK_COMPLETED" - EVENT_TYPE_WORKFLOW_TASK_FAILED = "EVENT_TYPE_WORKFLOW_TASK_FAILED" - EVENT_TYPE_WORKFLOW_TASK_SCHEDULED = "EVENT_TYPE_WORKFLOW_TASK_SCHEDULED" - EVENT_TYPE_WORKFLOW_TASK_STARTED = "EVENT_TYPE_WORKFLOW_TASK_STARTED" - EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT = "EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT" - - -@dataclass -class WorkflowEventEventReference: - event_id: Optional[str] = None - event_type: Optional["EventType"] = None - - -@dataclass -class WorkflowEventRequestIDReference: - event_type: Optional["EventType"] = None - request_id: Optional[str] = None - - -@dataclass -class LinkWorkflowEvent: - event_ref: Optional[WorkflowEventEventReference] = None - namespace: Optional[str] = None - request_id_ref: Optional[WorkflowEventRequestIDReference] = None - run_id: Optional[str] = None - workflow_id: Optional[str] = None - - -@dataclass -class Link: - activity: Optional[LinkActivity] = None - batch_job: Optional[LinkBatchJob] = None - workflow_event: Optional[LinkWorkflowEvent] = None - - -@dataclass -class Memo: - fields: Optional[Dict[str, Payload]] = None - - -@dataclass -class Priority: - fairness_key: Optional[str] = None - fairness_weight: Optional[float] = None - priority_key: Optional[int] = None - - -@dataclass -class RetryPolicy: - backoff_coefficient: Optional[float] = None - initial_interval: Optional[str] = None - maximum_attempts: Optional[int] = None - maximum_interval: Optional[str] = None - non_retryable_error_types: Optional[List[str]] = None - - -@dataclass -class SearchAttributes: - indexed_fields: Optional[Dict[str, Payload]] = None - - -class Kind(str, Enum): - TASK_QUEUE_KIND_NORMAL = "TASK_QUEUE_KIND_NORMAL" - TASK_QUEUE_KIND_STICKY = "TASK_QUEUE_KIND_STICKY" - TASK_QUEUE_KIND_UNSPECIFIED = "TASK_QUEUE_KIND_UNSPECIFIED" - TASK_QUEUE_KIND_WORKER_COMMANDS = "TASK_QUEUE_KIND_WORKER_COMMANDS" - - -@dataclass -class TaskQueue: - kind: Optional["Kind"] = None - name: Optional[str] = None - normal_name: Optional[str] = None - - -@dataclass -class TimeSkippingConfig: - disable_propagation: Optional[bool] = None - enabled: Optional[bool] = None - max_elapsed_duration: Optional[str] = None - max_skipped_duration: Optional[str] = None - max_target_time: Optional[datetime] = None - - -@dataclass -class UserMetadata: - details: Optional[Payload] = None - summary: Optional[Payload] = None - - -class VersioningOverrideBehavior(str, Enum): - VERSIONING_BEHAVIOR_AUTO_UPGRADE = "VERSIONING_BEHAVIOR_AUTO_UPGRADE" - VERSIONING_BEHAVIOR_PINNED = "VERSIONING_BEHAVIOR_PINNED" - VERSIONING_BEHAVIOR_UNSPECIFIED = "VERSIONING_BEHAVIOR_UNSPECIFIED" - - -@dataclass -class Deployment: - build_id: Optional[str] = None - series_name: Optional[str] = None - - -class VersioningOverridePinnedOverrideBehavior(str, Enum): - PINNED_OVERRIDE_BEHAVIOR_PINNED = "PINNED_OVERRIDE_BEHAVIOR_PINNED" - PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED = "PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED" - - -@dataclass -class WorkerDeploymentVersion: - build_id: Optional[str] = None - deployment_name: Optional[str] = None - - -@dataclass -class VersioningOverridePinnedOverride: - behavior: Optional["VersioningOverridePinnedOverrideBehavior"] = None - version: Optional[WorkerDeploymentVersion] = None - - -@dataclass -class VersioningOverride: - auto_upgrade: Optional[bool] = None - behavior: Optional["VersioningOverrideBehavior"] = None - deployment: Optional[Deployment] = None - pinned: Optional[VersioningOverridePinnedOverride] = None - pinned_version: Optional[str] = None - - -class WorkflowIDConflictPolicy(str, Enum): - WORKFLOW_ID_CONFLICT_POLICY_FAIL = "WORKFLOW_ID_CONFLICT_POLICY_FAIL" - WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING = ( - "WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING" - ) - WORKFLOW_ID_CONFLICT_POLICY_UNSPECIFIED = "WORKFLOW_ID_CONFLICT_POLICY_UNSPECIFIED" - WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING = ( - "WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING" - ) - - -class WorkflowIDReusePolicy(str, Enum): - WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE = ( - "WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE" - ) - WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY = ( - "WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY" - ) - WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE = ( - "WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE" - ) - WORKFLOW_ID_REUSE_POLICY_TERMINATE_IF_RUNNING = ( - "WORKFLOW_ID_REUSE_POLICY_TERMINATE_IF_RUNNING" - ) - WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED = "WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED" - - -@dataclass -class WorkflowType: - name: Optional[str] = None - - -@dataclass -class SignalWithStartWorkflowExecutionRequest: - control: Optional[str] = None - cron_schedule: Optional[str] = None - header: Optional[Header] = None - identity: Optional[str] = None - input: Optional[Payloads] = None - links: Optional[List[Link]] = None - memo: Optional[Memo] = None - namespace: Optional[str] = None - priority: Optional[Priority] = None - request_id: Optional[str] = None - retry_policy: Optional[RetryPolicy] = None - search_attributes: Optional[SearchAttributes] = None - signal_input: Optional[Payloads] = None - signal_name: Optional[str] = None - task_queue: Optional[TaskQueue] = None - time_skipping_config: Optional[TimeSkippingConfig] = None - user_metadata: Optional[UserMetadata] = None - versioning_override: Optional[VersioningOverride] = None - workflow_execution_timeout: Optional[str] = None - workflow_id: Optional[str] = None - workflow_id_conflict_policy: Optional["WorkflowIDConflictPolicy"] = None - workflow_id_reuse_policy: Optional["WorkflowIDReusePolicy"] = None - workflow_run_timeout: Optional[str] = None - workflow_start_delay: Optional[str] = None - workflow_task_timeout: Optional[str] = None - workflow_type: Optional[WorkflowType] = None - - __temporal_nexus_proto_type__: ClassVar[ - type[temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest] - ] = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest - - @property - def proto_type( - self, - ) -> type[ - temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest - ]: - return self.__temporal_nexus_proto_type__ - - -@dataclass -class SignalWithStartWorkflowExecutionResponse: - run_id: Optional[str] = None - signal_link: Optional[Link] = None - started: Optional[bool] = None - - __temporal_nexus_proto_type__: ClassVar[ - type[temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse] - ] = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse - - @property - def proto_type( - self, - ) -> type[ - temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse - ]: - return self.__temporal_nexus_proto_type__ +__nexus_operation_registry__: dict[ + tuple[str, str], Operation[typing.Any, typing.Any] +] = {} @service class WorkflowService: signal_with_start_workflow_execution: Operation[ - SignalWithStartWorkflowExecutionRequest, - SignalWithStartWorkflowExecutionResponse, + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest, + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse, ] = Operation(name="SignalWithStartWorkflowExecution") -class _TemporalNexusPayloadVisitor: - def __init__( - self, - payload_visitor: collections.abc.Callable[ - [collections.abc.Sequence[temporalio.api.common.v1.Payload]], - collections.abc.Awaitable[list[temporalio.api.common.v1.Payload]], - ], - visit_search_attributes: bool = False, - ): - self._payload_visitor = payload_visitor - self._visit_search_attributes = visit_search_attributes - - async def _visit_payload_json(self, value: dict) -> dict: - payload = ParseDict(value, temporalio.api.common.v1.Payload()) - [visited_payload] = await self._payload_visitor([payload]) - return MessageToDict(visited_payload) - - async def _visit_payloads_json(self, value: dict) -> dict: - payloads = ParseDict(value, temporalio.api.common.v1.Payloads()) - visited_payloads = await self._payload_visitor(payloads.payloads) - del payloads.payloads[:] - payloads.payloads.extend(visited_payloads) - return MessageToDict(payloads) - - async def _visit_payload_map_json(self, message_type: type, value: dict) -> dict: - message = message_type() - keys = list(value.keys()) - visited_payloads = await self._payload_visitor( - [ParseDict(value[key], temporalio.api.common.v1.Payload()) for key in keys] - ) - for key, visited_payload in zip(keys, visited_payloads): - message.fields[key].CopyFrom(visited_payload) - return MessageToDict(message).get("fields", {}) - - async def _temporal_nexus_visit_header_json(self, value: dict) -> dict: - visited = dict(value) - if visited.get("fields") is not None: - visited["fields"] = await self._visit_payload_map_json( - temporalio.api.common.v1.Header, visited["fields"] - ) - return visited - - async def _temporal_nexus_visit_payload_json(self, value: dict) -> dict: - return await self._visit_payload_json(value) - - async def _temporal_nexus_visit_payloads_json(self, value: dict) -> dict: - return await self._visit_payloads_json(value) - - async def _temporal_nexus_visit_memo_json(self, value: dict) -> dict: - visited = dict(value) - if visited.get("fields") is not None: - visited["fields"] = await self._visit_payload_map_json( - temporalio.api.common.v1.Memo, visited["fields"] - ) - return visited - - async def _temporal_nexus_visit_search_attributes_json(self, value: dict) -> dict: - if not self._visit_search_attributes: - return value - visited = dict(value) - if visited.get("indexedFields") is not None: - visited["indexedFields"] = await self._visit_payload_map_json( - temporalio.api.common.v1.SearchAttributes, visited["indexedFields"] - ) - return visited - - async def _temporal_nexus_visit_user_metadata_json(self, value: dict) -> dict: - visited = dict(value) - if visited.get("details") is not None: - visited["details"] = await self._visit_payload_json(visited["details"]) - if visited.get("summary") is not None: - visited["summary"] = await self._visit_payload_json(visited["summary"]) - return visited - - async def _temporal_nexus_visit_signal_with_start_workflow_execution_request_json( - self, value: dict - ) -> dict: - visited = dict(value) - if visited.get("header") is not None: - visited["header"] = await self._temporal_nexus_visit_header_json( - visited["header"] - ) - if visited.get("input") is not None: - visited["input"] = await self._temporal_nexus_visit_payloads_json( - visited["input"] - ) - if visited.get("memo") is not None: - visited["memo"] = await self._temporal_nexus_visit_memo_json( - visited["memo"] - ) - if visited.get("searchAttributes") is not None: - visited[ - "searchAttributes" - ] = await self._temporal_nexus_visit_search_attributes_json( - visited["searchAttributes"] - ) - if visited.get("signalInput") is not None: - visited["signalInput"] = await self._temporal_nexus_visit_payloads_json( - visited["signalInput"] - ) - if visited.get("userMetadata") is not None: - visited[ - "userMetadata" - ] = await self._temporal_nexus_visit_user_metadata_json( - visited["userMetadata"] - ) - return visited - - -async def _temporal_nexus_visit_signal_with_start_workflow_execution_request( - payload: temporalio.api.common.v1.Payload, - payload_visitor: collections.abc.Callable[ - [collections.abc.Sequence[temporalio.api.common.v1.Payload]], - collections.abc.Awaitable[list[temporalio.api.common.v1.Payload]], - ], - visit_search_attributes: bool = False, -) -> temporalio.api.common.v1.Payload: - try: - value = json.loads(payload.data) - except json.JSONDecodeError: - return payload - if not isinstance(value, dict): - return payload - visitor = _TemporalNexusPayloadVisitor(payload_visitor, visit_search_attributes) - visited = await visitor._temporal_nexus_visit_signal_with_start_workflow_execution_request_json( - value - ) - return temporalio.api.common.v1.Payload( - metadata=dict(payload.metadata), - data=json.dumps(visited, separators=(",", ":"), sort_keys=True).encode(), - ) - - -__temporal_nexus_payload_visitors__ = { - ( - "WorkflowService", - "SignalWithStartWorkflowExecution", - ): _temporal_nexus_visit_signal_with_start_workflow_execution_request, -} +__nexus_operation_registry__[ + ("WorkflowService", "SignalWithStartWorkflowExecution") +] = WorkflowService.signal_with_start_workflow_execution diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index d2a90ef9a..de4233e11 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -47,6 +47,7 @@ import temporalio.api.common.v1 import temporalio.api.enums.v1 import temporalio.api.sdk.v1 +import temporalio.api.workflowservice.v1 import temporalio.bridge.proto.activity_result import temporalio.bridge.proto.child_workflow import temporalio.bridge.proto.common @@ -2061,7 +2062,7 @@ async def _outbound_signal_with_start_workflow( cancellation_type=temporalio.workflow.NexusOperationCancellationType.WAIT_COMPLETED, headers=input.headers, summary=None, - output_type=temporalio.nexus.system.generated.SignalWithStartWorkflowExecutionResponse, + output_type=temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse, ) ) result = await handle diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py index aa8e91e70..f0b8e726a 100644 --- a/tests/nexus/test_temporal_system_nexus.py +++ b/tests/nexus/test_temporal_system_nexus.py @@ -1,11 +1,10 @@ from __future__ import annotations import dataclasses -import json import uuid from collections.abc import Sequence from datetime import timedelta -from typing import Any, ClassVar, Protocol, cast +from typing import Any, cast import nexusrpc.handler import pytest @@ -13,11 +12,12 @@ from google.protobuf.message import Message import temporalio.api.common.v1 +import temporalio.api.workflowservice.v1 import temporalio.converter import temporalio.nexus.system as nexus_system from temporalio import workflow from temporalio.client import Client -from temporalio.converter import DefaultPayloadConverter, ExternalStorage, PayloadCodec +from temporalio.converter import ExternalStorage, PayloadCodec from temporalio.nexus.system import generated from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( @@ -33,14 +33,9 @@ from tests.test_extstore import InMemoryTestDriver interceptor_traces: list[tuple[str, object]] = [] -received_requests: list[dict[str, Any]] = [] - - -class _AnnotatedSystemNexusMessage(Protocol): - __temporal_nexus_proto_type__: ClassVar[type[Message]] - - @property - def proto_type(self) -> type[Message]: ... +received_requests: list[ + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest +] = [] @nexusrpc.handler.service_handler(service=generated.WorkflowService) @@ -49,13 +44,19 @@ class WorkflowServicePayloadHandler: async def signal_with_start_workflow_execution( self, _ctx: nexusrpc.handler.StartOperationContext, - request: generated.SignalWithStartWorkflowExecutionRequest, - ) -> generated.SignalWithStartWorkflowExecutionResponse: + request: temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest, + ) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse: assert request.workflow_id == "system-nexus-workflow-id" assert request.signal_name == "test-signal" - received_requests.append(dataclasses.asdict(request)) - return generated.SignalWithStartWorkflowExecutionResponse( - run_id=f"{request.workflow_id}-run" + received_request = ( + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest() + ) + received_request.CopyFrom(request) + received_requests.append(received_request) + return ( + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse( + run_id=f"{request.workflow_id}-run" + ) ) @@ -86,15 +87,11 @@ async def encode( ) -> list[temporalio.api.common.v1.Payload]: encoded: list[temporalio.api.common.v1.Payload] = [] for payload in payloads: - try: - body = json.loads(payload.data) - except json.JSONDecodeError: - body = None - if isinstance(body, dict) and { - "namespace", - "workflowId", - "signalName", - }.issubset(body): + if ( + payload.metadata.get("encoding") == b"binary/protobuf" + and payload.metadata.get("messageType") + == b"temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest" + ): raise RuntimeError( "outer system nexus envelope should not be codec encoded" ) @@ -112,15 +109,11 @@ async def decode( ) -> list[temporalio.api.common.v1.Payload]: decoded: list[temporalio.api.common.v1.Payload] = [] for payload in payloads: - try: - body = json.loads(payload.data) - except json.JSONDecodeError: - body = None - if isinstance(body, dict) and { - "namespace", - "workflowId", - "signalName", - }.issubset(body): + if ( + payload.metadata.get("encoding") == b"binary/protobuf" + and payload.metadata.get("messageType") + == b"temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest" + ): raise RuntimeError( "outer system nexus envelope should not be codec decoded" ) @@ -128,24 +121,6 @@ async def decode( return decoded -class BadSystemNexusEnvelopePayloadConverter(DefaultPayloadConverter): - def to_payloads( - self, values: Sequence[object] - ) -> list[temporalio.api.common.v1.Payload]: - payloads: list[temporalio.api.common.v1.Payload] = [] - for value in values: - if isinstance(value, generated.SignalWithStartWorkflowExecutionRequest): - payloads.append( - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'{"workflow_id":"bad-envelope"}', - ) - ) - else: - payloads.extend(super().to_payloads([value])) - return payloads - - class TracingWorkflowInterceptor(Interceptor): def workflow_interceptor_class( self, input: WorkflowInterceptorClassInput @@ -166,30 +141,28 @@ async def signal_with_start_workflow( return await super().signal_with_start_workflow(input) -def _pop_received_request() -> dict[str, Any]: +def _pop_received_request() -> ( + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest +): assert len(received_requests) == 1 return received_requests.pop() def _assert_request_payload_was_externally_stored( - request_dict: dict[str, Any], field_name: str + request: temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest, + field_name: str, ) -> None: - payloads = cast("dict[str, list[dict[str, object]]]", request_dict[field_name])[ - "payloads" - ] + payloads = getattr(request, field_name).payloads assert len(payloads) == 1 - assert payloads[0]["external_payloads"] + assert payloads[0].external_payloads def _assert_request_user_metadata_was_externally_stored( - request_dict: dict[str, Any], + request: temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest, ) -> None: - user_metadata = cast( - "dict[str, dict[str, object]] | None", request_dict.get("user_metadata") - ) - assert user_metadata is not None - assert user_metadata["summary"]["external_payloads"] - assert user_metadata["details"]["external_payloads"] + assert request.HasField("user_metadata") + assert request.user_metadata.summary.external_payloads + assert request.user_metadata.details.external_payloads def _assert_stored_payloads_include( @@ -304,40 +277,23 @@ def _proto_scalar_sample(field: FieldDescriptor, *, path: str) -> Any: raise TypeError(f"Unhandled proto scalar sample at {path}: {field!r}") -def test_generated_system_nexus_proto_roundtrip() -> None: +@pytest.mark.parametrize( + "message_type", + [ + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest, + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse, + ], +) +def test_system_nexus_proto_roundtrip(message_type: type[Message]) -> None: payload_converter = nexus_system.get_payload_converter() - annotated_types = sorted( - ( - value - for value in vars(generated).values() - if isinstance(value, type) - and dataclasses.is_dataclass(value) - and hasattr(value, "__temporal_nexus_proto_type__") - ), - key=lambda value: value.__name__, - ) - assert annotated_types - - for annotated_type in annotated_types: - annotated_message_type = cast( - type[_AnnotatedSystemNexusMessage], annotated_type - ) - proto_type = annotated_message_type.__temporal_nexus_proto_type__ - proto_value = _build_proto_sample(proto_type) - payload = payload_converter.to_payload(proto_value) - assert payload is not None - assert ( - payload.metadata["messageType"] == proto_type.DESCRIPTOR.full_name.encode() - ) - value = payload_converter.from_payload(payload, annotated_message_type) - assert dataclasses.is_dataclass(value) - assert value.proto_type is proto_type - roundtripped_payload = payload_converter.to_payload(value) - assert roundtripped_payload is not None - roundtripped = payload_converter.from_payload( - roundtripped_payload, annotated_message_type - ) - assert roundtripped == value + proto_value = _build_proto_sample(message_type) + payload = payload_converter.to_payload(proto_value) + assert payload is not None + assert payload.metadata["encoding"] == b"binary/protobuf" + assert payload.metadata["messageType"] == message_type.DESCRIPTOR.full_name.encode() + roundtripped = payload_converter.from_payload(payload, message_type) + assert isinstance(roundtripped, message_type) + assert roundtripped == proto_value async def test_external_workflow_handle_signal_with_start_workflow_uses_system_nexus( @@ -396,10 +352,10 @@ async def test_external_workflow_handle_signal_with_start_workflow_uses_system_n ) assert result == "system-nexus-workflow-id-run" - request_dict = _pop_received_request() - _assert_request_payload_was_externally_stored(request_dict, "input") - _assert_request_payload_was_externally_stored(request_dict, "signal_input") - _assert_request_user_metadata_was_externally_stored(request_dict) + request = _pop_received_request() + _assert_request_payload_was_externally_stored(request, "input") + _assert_request_payload_was_externally_stored(request, "signal_input") + _assert_request_user_metadata_was_externally_stored(request) assert codec.encode_count >= 5 _assert_stored_payloads_include( driver, diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 196f8e9e0..c64a48a04 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -1,6 +1,4 @@ -import base64 import dataclasses -import json from collections.abc import MutableSequence import pytest @@ -16,6 +14,9 @@ SearchAttributes, ) from temporalio.api.sdk.v1.user_metadata_pb2 import UserMetadata +from temporalio.api.workflowservice.v1.request_response_pb2 import ( + SignalWithStartWorkflowExecutionRequest, +) from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( InitializeWorkflow, @@ -55,6 +56,10 @@ async def visit_system_nexus_envelope(self, payload: Payload) -> None: payload.metadata["visited"] = b"True" +def _json_plain_payload(value: object) -> Payload: + return temporalio.converter.default().payload_converter.to_payload(value) + + async def test_workflow_activation_completion(): comp = WorkflowActivationCompletion( run_id="1", @@ -259,42 +264,16 @@ async def test_bridge_encoding(): async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): - envelope = nexus_system.generated.SignalWithStartWorkflowExecutionRequest( + envelope = SignalWithStartWorkflowExecutionRequest( namespace="default", workflow_id="workflow-id", signal_name="signal-name", - input=nexus_system.generated.Payloads( - payloads=[ - nexus_system.generated.Payload( - data="ImlucHV0LXZhbHVlIg==", - metadata={"encoding": "anNvbi9wbGFpbg=="}, - ) - ] - ), - signal_input=nexus_system.generated.Payloads( - payloads=[ - nexus_system.generated.Payload( - data="InNpZ25hbC12YWx1ZSI=", - metadata={"encoding": "anNvbi9wbGFpbg=="}, - ) - ] - ), - memo=nexus_system.generated.Memo( - fields={ - "memo-key": nexus_system.generated.Payload( - data="Im1lbW8tdmFsdWUi", - metadata={"encoding": "anNvbi9wbGFpbg=="}, - ) - } - ), - search_attributes=nexus_system.generated.SearchAttributes( - indexed_fields={ - "search-key": nexus_system.generated.Payload( - data="InNlYXJjaC12YWx1ZSI=", - metadata={"encoding": "anNvbi9wbGFpbg=="}, - ) - } - ), + ) + envelope.input.payloads.extend([_json_plain_payload("input-value")]) + envelope.signal_input.payloads.extend([_json_plain_payload("signal-value")]) + envelope.memo.fields["memo-key"].CopyFrom(_json_plain_payload("memo-value")) + envelope.search_attributes.indexed_fields["search-key"].CopyFrom( + _json_plain_payload("search-value") ) comp = WorkflowActivationCompletion( run_id="1", @@ -316,40 +295,26 @@ async def test_visit_system_nexus_payloads_on_schedule_nexus_operation(): input_payload = comp.successful.commands[0].schedule_nexus_operation.input assert input_payload.metadata["visited"] - rewritten = json.loads(input_payload.data) - assert ( - base64.b64decode(rewritten["input"]["payloads"][0]["metadata"]["visited"]) - == b"True" - ) - assert ( - base64.b64decode(rewritten["signalInput"]["payloads"][0]["metadata"]["visited"]) - == b"True" - ) - assert ( - base64.b64decode(rewritten["memo"]["fields"]["memo-key"]["metadata"]["visited"]) - == b"True" + rewritten = nexus_system.get_payload_converter().from_payload( + input_payload, SignalWithStartWorkflowExecutionRequest ) + assert rewritten.input.payloads[0].metadata["visited"] == b"True" + assert rewritten.signal_input.payloads[0].metadata["visited"] == b"True" + assert rewritten.memo.fields["memo-key"].metadata["visited"] == b"True" assert ( "visited" - not in rewritten["searchAttributes"]["indexedFields"]["search-key"]["metadata"] + not in rewritten.search_attributes.indexed_fields["search-key"].metadata ) async def test_bridge_encoding_checks_system_nexus_envelope_size(): - envelope = nexus_system.generated.SignalWithStartWorkflowExecutionRequest( + envelope = SignalWithStartWorkflowExecutionRequest( namespace="default", workflow_id="workflow-id", signal_name="signal-name", request_id="x" * 2048, - input=nexus_system.generated.Payloads( - payloads=[ - nexus_system.generated.Payload( - data="ImlucHV0LXZhbHVlIg==", - metadata={"encoding": "anNvbi9wbGFpbg=="}, - ) - ] - ), ) + envelope.input.payloads.extend([_json_plain_payload("input-value")]) comp = WorkflowActivationCompletion( run_id="1", successful=Success( From bca2b0fa9216f7158ba7f2064868cdb3552a9e0c Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 24 Apr 2026 08:41:00 -0700 Subject: [PATCH 16/18] Refine Python payload visitor typing --- pyproject.toml | 1 - scripts/gen_nexus_system_models.py | 131 +++++++--- scripts/gen_payload_visitor.py | 133 +++++++--- temporalio/bridge/_visitor.py | 247 ++++++++++++------ temporalio/bridge/_visitor_functions.py | 25 ++ temporalio/bridge/worker.py | 10 +- temporalio/nexus/system/__init__.py | 21 +- temporalio/nexus/system/_payload_visitor.py | 63 +++-- .../system/_workflow_service_generated.py | 9 +- temporalio/worker/_command_aware_visitor.py | 3 +- tests/worker/test_visitor.py | 37 ++- 11 files changed, 451 insertions(+), 229 deletions(-) create mode 100644 temporalio/bridge/_visitor_functions.py diff --git a/pyproject.toml b/pyproject.toml index 40bd7189e..2f5cbfef8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -217,7 +217,6 @@ exclude = [ # Exclude auto generated files "temporalio/api", "temporalio/bridge/proto", - "temporalio/bridge/_visitor.py", "tests/worker/workflow_sandbox/testmodules/proto", ] diff --git a/scripts/gen_nexus_system_models.py b/scripts/gen_nexus_system_models.py index a133965d8..11f72a90f 100644 --- a/scripts/gen_nexus_system_models.py +++ b/scripts/gen_nexus_system_models.py @@ -1,11 +1,11 @@ from __future__ import annotations -import os +import importlib +import re import subprocess import sys from pathlib import Path -NEXUS_RPC_GEN_ENV_VAR = "TEMPORAL_NEXUS_RPC_GEN_DIR" NEXUS_RPC_GEN_VERSION = "0.1.0-alpha.4" @@ -13,9 +13,6 @@ def main() -> None: repo_root = Path(__file__).resolve().parent.parent # TODO: Remove the local .nexusrpc.yaml shim once the upstream API repo # checks in the Nexus definition we can consume directly. - override_root = normalize_nexus_rpc_gen_root( - Path.cwd(), env_value=NEXUS_RPC_GEN_ENV_VAR - ) input_schema = ( repo_root / "temporalio" @@ -36,10 +33,10 @@ def main() -> None: raise RuntimeError(f"Expected Nexus schema at {input_schema}") run_nexus_rpc_gen( - override_root=override_root, output_file=output_file, input_schema=input_schema, ) + add_operation_registry(repo_root, output_file) subprocess.run( [ "uv", @@ -67,9 +64,90 @@ def main() -> None: ) -def run_nexus_rpc_gen( - *, override_root: Path | None, output_file: Path, input_schema: Path -) -> None: +def add_operation_registry(repo_root: Path, output_file: Path) -> None: + source = strip_existing_operation_registry(output_file.read_text()) + source = ensure_typing_import(source) + services = discover_services(repo_root) + if not services: + output_file.write_text(source) + return + output_file.write_text(source.rstrip() + "\n\n" + emit_operation_registry(services)) + + +def strip_existing_operation_registry(source: str) -> str: + source = re.sub( + r"\nimport typing\n(?=\n__nexus_operation_registry__)", + "\n", + source, + ) + source = re.sub( + r"\n__nexus_operation_registry__: dict\[\n" + r"(?:.*\n)*?" + r"\] = \{\}\n" + r"(?:\n__nexus_operation_registry__\[\n(?:.*\n)*?\] = .+\n)+", + "\n", + source, + flags=re.MULTILINE, + ) + return source.rstrip() + "\n" + + +def ensure_typing_import(source: str) -> str: + if "\nimport typing\n" in source: + return source + marker = "from __future__ import annotations\n\n" + if marker not in source: + raise RuntimeError("Expected future-annotations import in generated output") + return source.replace(marker, marker + "import typing\n", 1) + + +def discover_services(repo_root: Path) -> list[tuple[str, str, list[tuple[str, str]]]]: + module_name = "temporalio.nexus.system._workflow_service_generated" + sys.path.insert(0, str(repo_root)) + try: + sys.modules.pop(module_name, None) + importlib.invalidate_caches() + module = importlib.import_module(module_name) + finally: + sys.path.pop(0) + services: list[tuple[str, str, list[tuple[str, str]]]] = [] + for value in vars(module).values(): + if not isinstance(value, type): + continue + definition = getattr(value, "__nexus_service_definition__", None) + if definition is None: + continue + operations = [ + (operation_definition.method_name, operation_definition.name) + for operation_definition in definition.operation_definitions.values() + ] + services.append((value.__name__, definition.name, operations)) + return services + + +def emit_operation_registry( + services: list[tuple[str, str, list[tuple[str, str]]]], +) -> str: + lines = [ + "__nexus_operation_registry__: dict[", + " tuple[str, str], Operation[typing.Any, typing.Any]", + "] = {}", + "", + ] + for class_name, service_name, operations in services: + for attr_name, operation_name in operations: + lines.extend( + [ + "__nexus_operation_registry__[", + f" ({service_name!r}, {operation_name!r})", + f"] = {class_name}.{attr_name}", + "", + ] + ) + return "\n".join(lines).rstrip() + "\n" + + +def run_nexus_rpc_gen(*, output_file: Path, input_schema: Path) -> None: common_args = [ "--lang", "py", @@ -77,45 +155,12 @@ def run_nexus_rpc_gen( str(output_file), str(input_schema), ] - if override_root is None: - subprocess.run( - ["npx", "--yes", f"nexus-rpc-gen@{NEXUS_RPC_GEN_VERSION}", *common_args], - check=True, - ) - return - subprocess.run( - [ - "node", - "packages/nexus-rpc-gen/dist/index.js", - *common_args, - ], - cwd=override_root, + ["npx", "--yes", f"nexus-rpc-gen@{NEXUS_RPC_GEN_VERSION}", *common_args], check=True, ) -def normalize_nexus_rpc_gen_root(base_dir: Path, env_value: str) -> Path | None: - raw_root = env_get(env_value) - if raw_root is None: - return None - candidate = Path(raw_root) - if not candidate.is_absolute(): - candidate = base_dir / candidate - candidate = candidate.resolve() - if (candidate / "package.json").is_file() and (candidate / "packages").is_dir(): - return candidate - if (candidate / "src" / "package.json").is_file(): - return candidate / "src" - raise RuntimeError( - f"{NEXUS_RPC_GEN_ENV_VAR} must point to the nexus-rpc-gen repo root or its src directory" - ) - - -def env_get(name: str) -> str | None: - return os.environ.get(name) - - if __name__ == "__main__": try: main() diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index 2b66103cb..fd7e3e604 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -21,6 +21,26 @@ ) base_dir = Path(__file__).parent.parent +BASIC_IMPORTED_TYPES = { + Payload.DESCRIPTOR.full_name: "Payload", + Payloads.DESCRIPTOR.full_name: "Payloads", + SearchAttributes.DESCRIPTOR.full_name: "SearchAttributes", +} + + +def normalize_python_module(module: str) -> str: + if module.startswith("temporal.sdk.core."): + return "temporalio.bridge.proto." + module[len("temporal.sdk.core.") :] + return module + + +def nested_python_name(desc: Descriptor) -> str: + names = [desc.name] + current = desc.containing_type + while current is not None: + names.append(current.name) + current = current.containing_type + return ".".join(reversed(names)) def name_for(desc: Descriptor) -> str: @@ -112,21 +132,25 @@ def __init__(self) -> None: Payloads.DESCRIPTOR.full_name: True, } self.in_progress: set[str] = set() + self.root_type_imports: dict[str, tuple[str, str]] = {} + self.type_checking_modules: set[str] = set() self.methods: list[str] = [ """\ async def _visit_temporal_api_common_v1_Payload( - self, fs: VisitorFunctions, o: Any + self, fs: VisitorFunctions, o: Payload ): await fs.visit_payload(o) """, """\ async def _visit_temporal_api_common_v1_Payloads( - self, fs: VisitorFunctions, o: Any + self, fs: VisitorFunctions, o: Payloads ): await fs.visit_payloads(o.payloads) """, """\ - async def _visit_payload_container(self, fs: VisitorFunctions, o: Any): + async def _visit_payload_container( + self, fs: VisitorFunctions, o: PayloadSequence + ): await fs.visit_payloads(o) """, ] @@ -135,36 +159,37 @@ def generate(self, roots: list[Descriptor]) -> str: for root in roots: self.walk(root) + extra_imports = "\n".join( + f"from {module} import {class_name}" + for class_name, module in sorted( + set(self.root_type_imports.values()), + key=lambda item: (item[1], item[0]), + ) + ) + type_checking_imports = "\n".join( + f" import {module}" for module in sorted(self.type_checking_modules) + ) + if extra_imports: + extra_imports += "\n" + if type_checking_imports: + type_checking_imports = ( + "\nif TYPE_CHECKING:\n" + type_checking_imports + "\n" + ) + header = """ +from __future__ import annotations + # This file is generated by gen_payload_visitor.py. Changes should be made there. -import abc -from collections.abc import MutableSequence -from typing import Any +from typing import TYPE_CHECKING +from google.protobuf.message import Message import temporalio.nexus.system -from temporalio.api.common.v1.message_pb2 import Payload - - -class VisitorFunctions(abc.ABC): - \"\"\"Set of functions which can be called by the visitor. - Allows handling payloads as a sequence. - \"\"\" - - @abc.abstractmethod - async def visit_payload(self, payload: Payload) -> None: - \"\"\"Called when encountering a single payload.\"\"\" - raise NotImplementedError() - - @abc.abstractmethod - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - \"\"\"Called when encountering multiple payloads together.\"\"\" - raise NotImplementedError() - - @abc.abstractmethod - async def visit_system_nexus_envelope(self, payload: Payload) -> None: - \"\"\"Called when encountering a recognized system Nexus envelope payload.\"\"\" - raise NotImplementedError() - +from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions +""" + header += extra_imports + header += type_checking_imports + header += """ class PayloadVisitor: \"\"\"A visitor for payloads. @@ -178,7 +203,7 @@ def __init__( self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers - async def visit(self, fs: VisitorFunctions, root: Any) -> None: + async def visit(self, fs: VisitorFunctions, root: Message) -> None: \"\"\"Visit the given root message with the given function set.\"\"\" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) @@ -212,6 +237,25 @@ async def _visit_system_nexus_payload( """ return header + "\n".join(self.methods) + def python_type_for_descriptor(self, desc: Descriptor) -> str: + basic_type = BASIC_IMPORTED_TYPES.get(desc.full_name) + if basic_type is not None: + return basic_type + cls = getattr(desc, "_concrete_class", None) + if cls is None: + return "Message" + module = normalize_python_module(cls.__module__) + if desc in ( + WorkflowActivation.DESCRIPTOR, + WorkflowActivationCompletion.DESCRIPTOR, + SignalWithStartWorkflowExecutionRequest.DESCRIPTOR, + SignalWithStartWorkflowExecutionResponse.DESCRIPTOR, + ): + self.root_type_imports[desc.full_name] = (cls.__name__, module) + return cls.__name__ + self.type_checking_modules.add(module) + return f"{module}.{nested_python_name(desc)}" + def check_repeated( self, child_desc: Descriptor, field: FieldDescriptor, iter_expr: str ) -> str | None: @@ -231,14 +275,10 @@ def walk(self, desc: Descriptor) -> bool: has_payload = False self.in_progress.add(key) - lines: list[str] = [ - f" async def _visit_{name_for(desc)}(" - "self, fs: VisitorFunctions, o: Any" - "):" - ] + body_lines: list[str] = [] if desc.full_name == SearchAttributes.DESCRIPTOR.full_name: - lines.append(" if self.skip_search_attributes:") - lines.append(" return") + body_lines.append(" if self.skip_search_attributes:") + body_lines.append(" return") oneof_fields: dict[int, list[FieldDescriptor]] = {} regular_fields: list[FieldDescriptor] = [] @@ -257,7 +297,7 @@ def walk(self, desc: Descriptor) -> bool: and field.name == "input" ): has_payload = True - lines.append( + body_lines.append( """\ if o.HasField("input"): await self._visit_system_nexus_payload(fs, o.service, o.operation, o.input)""" @@ -278,7 +318,7 @@ def walk(self, desc: Descriptor) -> bool: child_needed = self.walk(child_desc) if child_needed: has_payload = True - lines.append( + body_lines.append( emit_loop( field.name, f"o.{field.name}.values()", @@ -295,7 +335,7 @@ def walk(self, desc: Descriptor) -> bool: child_needed = self.walk(child_desc) if child_needed: has_payload = True - lines.append( + body_lines.append( emit_loop( field.name, f"o.{field.name}.keys()", @@ -308,13 +348,13 @@ def walk(self, desc: Descriptor) -> bool: ) if child is not None: has_payload = True - lines.append(child) + body_lines.append(child) else: child_desc = field.message_type child_has_payload = self.walk(child_desc) has_payload |= child_has_payload if child_has_payload: - lines.append( + body_lines.append( emit_singular( field.name, f"o.{field.name}", name_for(child_desc), "if" ) @@ -338,11 +378,18 @@ def walk(self, desc: Descriptor) -> bool: ) first = False if oneof_lines: - lines.extend(oneof_lines) + body_lines.extend(oneof_lines) self.generated[key] = has_payload self.in_progress.discard(key) if has_payload: + annotation = self.python_type_for_descriptor(desc) + lines = [ + f" async def _visit_{name_for(desc)}(" + f"self, fs: VisitorFunctions, o: {annotation}" + "):" + ] + lines.extend(body_lines) self.methods.append("\n".join(lines) + "\n") return has_payload diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index b6c61d730..dbbf16816 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,31 +1,30 @@ -# This file is generated by gen_payload_visitor.py. Changes should be made there. -import abc -from collections.abc import MutableSequence -from typing import Any - -import temporalio.nexus.system -from temporalio.api.common.v1.message_pb2 import Payload - +from __future__ import annotations -class VisitorFunctions(abc.ABC): - """Set of functions which can be called by the visitor. - Allows handling payloads as a sequence. - """ - - @abc.abstractmethod - async def visit_payload(self, payload: Payload) -> None: - """Called when encountering a single payload.""" - raise NotImplementedError() +# This file is generated by gen_payload_visitor.py. Changes should be made there. +from typing import TYPE_CHECKING - @abc.abstractmethod - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - """Called when encountering multiple payloads together.""" - raise NotImplementedError() +from google.protobuf.message import Message - @abc.abstractmethod - async def visit_system_nexus_envelope(self, payload: Payload) -> None: - """Called when encountering a recognized system Nexus envelope payload.""" - raise NotImplementedError() +import temporalio.nexus.system +from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions +from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( + WorkflowActivation, +) +from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import ( + WorkflowActivationCompletion, +) + +if TYPE_CHECKING: + import temporalio.api.common.v1.message_pb2 + import temporalio.api.failure.v1.message_pb2 + import temporalio.api.sdk.v1.user_metadata_pb2 + import temporalio.bridge.proto.activity_result.activity_result_pb2 + import temporalio.bridge.proto.child_workflow.child_workflow_pb2 + import temporalio.bridge.proto.nexus.nexus_pb2 + import temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 + import temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 + import temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 class PayloadVisitor: @@ -40,7 +39,7 @@ def __init__( self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers - async def visit(self, fs: VisitorFunctions, root: Any) -> None: + async def visit(self, fs: VisitorFunctions, root: Message) -> None: """Visit the given root message with the given function set.""" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) @@ -71,25 +70,31 @@ async def _visit_system_nexus_payload( payload.CopyFrom(new_payload) await fs.visit_system_nexus_envelope(payload) - async def _visit_temporal_api_common_v1_Payload(self, fs: VisitorFunctions, o: Any): + async def _visit_temporal_api_common_v1_Payload( + self, fs: VisitorFunctions, o: Payload + ): await fs.visit_payload(o) async def _visit_temporal_api_common_v1_Payloads( - self, fs: VisitorFunctions, o: Any + self, fs: VisitorFunctions, o: Payloads ): await fs.visit_payloads(o.payloads) - async def _visit_payload_container(self, fs: VisitorFunctions, o: Any): + async def _visit_payload_container(self, fs: VisitorFunctions, o: PayloadSequence): await fs.visit_payloads(o) async def _visit_temporal_api_failure_v1_ApplicationFailureInfo( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.api.failure.v1.message_pb2.ApplicationFailureInfo, ): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) async def _visit_temporal_api_failure_v1_TimeoutFailureInfo( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.api.failure.v1.message_pb2.TimeoutFailureInfo, ): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( @@ -97,13 +102,17 @@ async def _visit_temporal_api_failure_v1_TimeoutFailureInfo( ) async def _visit_temporal_api_failure_v1_CanceledFailureInfo( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.api.failure.v1.message_pb2.CanceledFailureInfo, ): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.api.failure.v1.message_pb2.ResetWorkflowFailureInfo, ): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( @@ -111,7 +120,7 @@ async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( ) async def _visit_temporal_api_failure_v1_Failure( - self, fs: VisitorFunctions, o: Any + self, fs: VisitorFunctions, o: temporalio.api.failure.v1.message_pb2.Failure ): if o.HasField("encoded_attributes"): await self._visit_temporal_api_common_v1_Payload(fs, o.encoded_attributes) @@ -134,12 +143,14 @@ async def _visit_temporal_api_failure_v1_Failure( fs, o.reset_workflow_failure_info ) - async def _visit_temporal_api_common_v1_Memo(self, fs: VisitorFunctions, o: Any): + async def _visit_temporal_api_common_v1_Memo( + self, fs: VisitorFunctions, o: temporalio.api.common.v1.message_pb2.Memo + ): for v in o.fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_temporal_api_common_v1_SearchAttributes( - self, fs: VisitorFunctions, o: Any + self, fs: VisitorFunctions, o: SearchAttributes ): if self.skip_search_attributes: return @@ -147,7 +158,9 @@ async def _visit_temporal_api_common_v1_SearchAttributes( await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_activation_InitializeWorkflow( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.InitializeWorkflow, ): await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: @@ -167,7 +180,9 @@ async def _visit_coresdk_workflow_activation_InitializeWorkflow( ) async def _visit_coresdk_workflow_activation_QueryWorkflow( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.QueryWorkflow, ): await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: @@ -175,7 +190,9 @@ async def _visit_coresdk_workflow_activation_QueryWorkflow( await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_activation_SignalWorkflow( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.SignalWorkflow, ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: @@ -183,25 +200,33 @@ async def _visit_coresdk_workflow_activation_SignalWorkflow( await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_activity_result_Success( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.activity_result.activity_result_pb2.Success, ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) async def _visit_coresdk_activity_result_Failure( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.activity_result.activity_result_pb2.Failure, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_activity_result_Cancellation( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.activity_result.activity_result_pb2.Cancellation, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_activity_result_ActivityResolution( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.activity_result.activity_result_pb2.ActivityResolution, ): if o.HasField("completed"): await self._visit_coresdk_activity_result_Success(fs, o.completed) @@ -211,41 +236,59 @@ async def _visit_coresdk_activity_result_ActivityResolution( await self._visit_coresdk_activity_result_Cancellation(fs, o.cancelled) async def _visit_coresdk_workflow_activation_ResolveActivity( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveActivity, ): if o.HasField("result"): await self._visit_coresdk_activity_result_ActivityResolution(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveChildWorkflowExecutionStartCancelled, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveChildWorkflowExecutionStart, ): if o.HasField("cancelled"): await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( fs, o.cancelled ) - async def _visit_coresdk_child_workflow_Success(self, fs: VisitorFunctions, o: Any): + async def _visit_coresdk_child_workflow_Success( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.child_workflow.child_workflow_pb2.Success, + ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_child_workflow_Failure(self, fs: VisitorFunctions, o: Any): + async def _visit_coresdk_child_workflow_Failure( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.child_workflow.child_workflow_pb2.Failure, + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_child_workflow_Cancellation( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.child_workflow.child_workflow_pb2.Cancellation, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_child_workflow_ChildWorkflowResult( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.child_workflow.child_workflow_pb2.ChildWorkflowResult, ): if o.HasField("completed"): await self._visit_coresdk_child_workflow_Success(fs, o.completed) @@ -255,25 +298,33 @@ async def _visit_coresdk_child_workflow_ChildWorkflowResult( await self._visit_coresdk_child_workflow_Cancellation(fs, o.cancelled) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveChildWorkflowExecution, ): if o.HasField("result"): await self._visit_coresdk_child_workflow_ChildWorkflowResult(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveSignalExternalWorkflow, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveRequestCancelExternalWorkflow, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_DoUpdate( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.DoUpdate, ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: @@ -281,13 +332,17 @@ async def _visit_coresdk_workflow_activation_DoUpdate( await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveNexusOperationStart, ): if o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) async def _visit_coresdk_nexus_NexusOperationResult( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.nexus.nexus_pb2.NexusOperationResult, ): if o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) @@ -299,13 +354,17 @@ async def _visit_coresdk_nexus_NexusOperationResult( await self._visit_temporal_api_failure_v1_Failure(fs, o.timed_out) async def _visit_coresdk_workflow_activation_ResolveNexusOperation( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveNexusOperation, ): if o.HasField("result"): await self._visit_coresdk_nexus_NexusOperationResult(fs, o.result) async def _visit_coresdk_workflow_activation_WorkflowActivationJob( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.WorkflowActivationJob, ): if o.HasField("initialize_workflow"): await self._visit_coresdk_workflow_activation_InitializeWorkflow( @@ -351,13 +410,15 @@ async def _visit_coresdk_workflow_activation_WorkflowActivationJob( ) async def _visit_coresdk_workflow_activation_WorkflowActivation( - self, fs: VisitorFunctions, o: Any + self, fs: VisitorFunctions, o: WorkflowActivation ): for v in o.jobs: await self._visit_coresdk_workflow_activation_WorkflowActivationJob(fs, v) async def _visit_temporal_api_sdk_v1_UserMetadata( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.api.sdk.v1.user_metadata_pb2.UserMetadata, ): if o.HasField("summary"): await self._visit_temporal_api_common_v1_Payload(fs, o.summary) @@ -365,7 +426,9 @@ async def _visit_temporal_api_sdk_v1_UserMetadata( await self._visit_temporal_api_common_v1_Payload(fs, o.details) async def _visit_coresdk_workflow_commands_ScheduleActivity( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.ScheduleActivity, ): if not self.skip_headers: for v in o.headers.values(): @@ -373,13 +436,17 @@ async def _visit_coresdk_workflow_commands_ScheduleActivity( await self._visit_payload_container(fs, o.arguments) async def _visit_coresdk_workflow_commands_QuerySuccess( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.QuerySuccess, ): if o.HasField("response"): await self._visit_temporal_api_common_v1_Payload(fs, o.response) async def _visit_coresdk_workflow_commands_QueryResult( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.QueryResult, ): if o.HasField("succeeded"): await self._visit_coresdk_workflow_commands_QuerySuccess(fs, o.succeeded) @@ -387,19 +454,25 @@ async def _visit_coresdk_workflow_commands_QueryResult( await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.CompleteWorkflowExecution, ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) async def _visit_coresdk_workflow_commands_FailWorkflowExecution( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.FailWorkflowExecution, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.ContinueAsNewWorkflowExecution, ): await self._visit_payload_container(fs, o.arguments) for v in o.memo.values(): @@ -413,7 +486,9 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( ) async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.StartChildWorkflowExecution, ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: @@ -427,7 +502,9 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution( ) async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.SignalExternalWorkflowExecution, ): await self._visit_payload_container(fs, o.args) if not self.skip_headers: @@ -435,7 +512,9 @@ async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_commands_ScheduleLocalActivity( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.ScheduleLocalActivity, ): if not self.skip_headers: for v in o.headers.values(): @@ -443,7 +522,9 @@ async def _visit_coresdk_workflow_commands_ScheduleLocalActivity( await self._visit_payload_container(fs, o.arguments) async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.UpsertWorkflowSearchAttributes, ): if o.HasField("search_attributes"): await self._visit_temporal_api_common_v1_SearchAttributes( @@ -451,13 +532,17 @@ async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( ) async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.ModifyWorkflowProperties, ): if o.HasField("upserted_memo"): await self._visit_temporal_api_common_v1_Memo(fs, o.upserted_memo) async def _visit_coresdk_workflow_commands_UpdateResponse( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.UpdateResponse, ): if o.HasField("rejected"): await self._visit_temporal_api_failure_v1_Failure(fs, o.rejected) @@ -465,13 +550,17 @@ async def _visit_coresdk_workflow_commands_UpdateResponse( await self._visit_temporal_api_common_v1_Payload(fs, o.completed) async def _visit_coresdk_workflow_commands_ScheduleNexusOperation( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.ScheduleNexusOperation, ): if o.HasField("input"): await self._visit_system_nexus_payload(fs, o.service, o.operation, o.input) async def _visit_coresdk_workflow_commands_WorkflowCommand( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.WorkflowCommand, ): if o.HasField("user_metadata"): await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) @@ -525,19 +614,23 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand( ) async def _visit_coresdk_workflow_completion_Success( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_completion.workflow_completion_pb2.Success, ): for v in o.commands: await self._visit_coresdk_workflow_commands_WorkflowCommand(fs, v) async def _visit_coresdk_workflow_completion_Failure( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_completion.workflow_completion_pb2.Failure, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( - self, fs: VisitorFunctions, o: Any + self, fs: VisitorFunctions, o: WorkflowActivationCompletion ): if o.HasField("successful"): await self._visit_coresdk_workflow_completion_Success(fs, o.successful) diff --git a/temporalio/bridge/_visitor_functions.py b/temporalio/bridge/_visitor_functions.py new file mode 100644 index 000000000..b2d5be153 --- /dev/null +++ b/temporalio/bridge/_visitor_functions.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import Protocol + +from google.protobuf.internal.containers import RepeatedCompositeFieldContainer + +from temporalio.api.common.v1.message_pb2 import Payload + +PayloadSequence = list[Payload] | RepeatedCompositeFieldContainer[Payload] + + +class VisitorFunctions(Protocol): + """Functions invoked by generated payload visitors.""" + + async def visit_payload(self, payload: Payload) -> None: + """Visit a single payload.""" + ... + + async def visit_payloads(self, payloads: PayloadSequence) -> None: + """Visit a sequence of payloads together.""" + ... + + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + """Visit a recognized system Nexus envelope payload.""" + ... diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 996ed8c19..b0193e2fc 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -5,11 +5,9 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, MutableSequence, Sequence +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass -from typing import ( - TypeAlias, -) +from typing import TypeAlias import temporalio.bridge.client import temporalio.bridge.proto @@ -21,7 +19,7 @@ import temporalio.bridge.temporal_sdk_bridge import temporalio.converter from temporalio.api.common.v1.message_pb2 import Payload -from temporalio.bridge._visitor import VisitorFunctions +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) @@ -292,7 +290,7 @@ async def visit_payload(self, payload: Payload) -> None: if new_payload is not payload: payload.CopyFrom(new_payload) - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + async def visit_payloads(self, payloads: PayloadSequence) -> None: if len(payloads) == 0: return new_payloads = await self._f(payloads) diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index e8fbff1c4..822cb3af7 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -1,8 +1,8 @@ """Generated system Nexus service models.""" -from collections.abc import Mapping, MutableSequence, Sequence +from collections.abc import Mapping, Sequence from datetime import timedelta -from typing import Any, Protocol, cast +from typing import Any, cast import temporalio.api.common.v1 import temporalio.api.sdk.v1 @@ -11,25 +11,12 @@ import temporalio.converter from ... import _workflow_requests +from ...bridge._visitor_functions import VisitorFunctions from ...converter import BinaryProtoPayloadConverter, CompositePayloadConverter from . import _workflow_service_generated as generated from ._payload_visitor import PayloadVisitor -class PayloadVisitorFunctions(Protocol): - async def visit_payload( - self, payload: temporalio.api.common.v1.Payload - ) -> None: ... - - async def visit_payloads( - self, payloads: MutableSequence[temporalio.api.common.v1.Payload] - ) -> None: ... - - async def visit_system_nexus_envelope( - self, payload: temporalio.api.common.v1.Payload - ) -> None: ... - - class SystemNexusPayloadConverter(CompositePayloadConverter): """Payload converter for system Nexus outer envelopes.""" @@ -134,7 +121,7 @@ async def visit_payload( service: str, operation: str, payload: temporalio.api.common.v1.Payload, - visitor_functions: PayloadVisitorFunctions, + visitor_functions: VisitorFunctions, skip_search_attributes: bool, ) -> temporalio.api.common.v1.Payload | None: """Visit nested payloads inside a recognized system Nexus envelope.""" diff --git a/temporalio/nexus/system/_payload_visitor.py b/temporalio/nexus/system/_payload_visitor.py index 6a247f643..83ad41f5c 100644 --- a/temporalio/nexus/system/_payload_visitor.py +++ b/temporalio/nexus/system/_payload_visitor.py @@ -1,31 +1,20 @@ -# This file is generated by gen_payload_visitor.py. Changes should be made there. -import abc -from collections.abc import MutableSequence -from typing import Any - -import temporalio.nexus.system -from temporalio.api.common.v1.message_pb2 import Payload +from __future__ import annotations +# This file is generated by gen_payload_visitor.py. Changes should be made there. +from typing import TYPE_CHECKING -class VisitorFunctions(abc.ABC): - """Set of functions which can be called by the visitor. - Allows handling payloads as a sequence. - """ - - @abc.abstractmethod - async def visit_payload(self, payload: Payload) -> None: - """Called when encountering a single payload.""" - raise NotImplementedError() +from google.protobuf.message import Message - @abc.abstractmethod - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - """Called when encountering multiple payloads together.""" - raise NotImplementedError() +import temporalio.nexus.system +from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes +from temporalio.api.workflowservice.v1.request_response_pb2 import ( + SignalWithStartWorkflowExecutionRequest, +) +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions - @abc.abstractmethod - async def visit_system_nexus_envelope(self, payload: Payload) -> None: - """Called when encountering a recognized system Nexus envelope payload.""" - raise NotImplementedError() +if TYPE_CHECKING: + import temporalio.api.common.v1.message_pb2 + import temporalio.api.sdk.v1.user_metadata_pb2 class PayloadVisitor: @@ -40,7 +29,7 @@ def __init__( self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers - async def visit(self, fs: VisitorFunctions, root: Any) -> None: + async def visit(self, fs: VisitorFunctions, root: Message) -> None: """Visit the given root message with the given function set.""" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) @@ -71,35 +60,43 @@ async def _visit_system_nexus_payload( payload.CopyFrom(new_payload) await fs.visit_system_nexus_envelope(payload) - async def _visit_temporal_api_common_v1_Payload(self, fs: VisitorFunctions, o: Any): + async def _visit_temporal_api_common_v1_Payload( + self, fs: VisitorFunctions, o: Payload + ): await fs.visit_payload(o) async def _visit_temporal_api_common_v1_Payloads( - self, fs: VisitorFunctions, o: Any + self, fs: VisitorFunctions, o: Payloads ): await fs.visit_payloads(o.payloads) - async def _visit_payload_container(self, fs: VisitorFunctions, o: Any): + async def _visit_payload_container(self, fs: VisitorFunctions, o: PayloadSequence): await fs.visit_payloads(o) - async def _visit_temporal_api_common_v1_Memo(self, fs: VisitorFunctions, o: Any): + async def _visit_temporal_api_common_v1_Memo( + self, fs: VisitorFunctions, o: temporalio.api.common.v1.message_pb2.Memo + ): for v in o.fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_temporal_api_common_v1_SearchAttributes( - self, fs: VisitorFunctions, o: Any + self, fs: VisitorFunctions, o: SearchAttributes ): if self.skip_search_attributes: return for v in o.indexed_fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_temporal_api_common_v1_Header(self, fs: VisitorFunctions, o: Any): + async def _visit_temporal_api_common_v1_Header( + self, fs: VisitorFunctions, o: temporalio.api.common.v1.message_pb2.Header + ): for v in o.fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_temporal_api_sdk_v1_UserMetadata( - self, fs: VisitorFunctions, o: Any + self, + fs: VisitorFunctions, + o: temporalio.api.sdk.v1.user_metadata_pb2.UserMetadata, ): if o.HasField("summary"): await self._visit_temporal_api_common_v1_Payload(fs, o.summary) @@ -107,7 +104,7 @@ async def _visit_temporal_api_sdk_v1_UserMetadata( await self._visit_temporal_api_common_v1_Payload(fs, o.details) async def _visit_temporal_api_workflowservice_v1_SignalWithStartWorkflowExecutionRequest( - self, fs: VisitorFunctions, o: Any + self, fs: VisitorFunctions, o: SignalWithStartWorkflowExecutionRequest ): if o.HasField("input"): await self._visit_temporal_api_common_v1_Payloads(fs, o.input) diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index 6d0fc9e65..ab914993f 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -1,5 +1,4 @@ # Generated by nexus-rpc-gen. DO NOT EDIT! -# pyright: reportDeprecated=false from __future__ import annotations @@ -9,10 +8,6 @@ import temporalio.api.workflowservice.v1 -__nexus_operation_registry__: dict[ - tuple[str, str], Operation[typing.Any, typing.Any] -] = {} - @service class WorkflowService: @@ -22,6 +17,10 @@ class WorkflowService: ] = Operation(name="SignalWithStartWorkflowExecution") +__nexus_operation_registry__: dict[ + tuple[str, str], Operation[typing.Any, typing.Any] +] = {} + __nexus_operation_registry__[ ("WorkflowService", "SignalWithStartWorkflowExecution") ] = WorkflowService.signal_with_start_workflow_execution diff --git a/temporalio/worker/_command_aware_visitor.py b/temporalio/worker/_command_aware_visitor.py index 10aea1422..72d41891d 100644 --- a/temporalio/worker/_command_aware_visitor.py +++ b/temporalio/worker/_command_aware_visitor.py @@ -6,7 +6,8 @@ from dataclasses import dataclass from temporalio.api.enums.v1.command_type_pb2 import CommandType -from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions +from temporalio.bridge._visitor import PayloadVisitor +from temporalio.bridge._visitor_functions import VisitorFunctions from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( ResolveActivity, ResolveChildWorkflowExecution, diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index c64a48a04..efe9a36aa 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -1,8 +1,9 @@ import dataclasses -from collections.abc import MutableSequence +from typing import get_type_hints import pytest from google.protobuf.duration_pb2 import Duration +from google.protobuf.message import Message import temporalio.bridge.worker import temporalio.converter @@ -17,7 +18,10 @@ from temporalio.api.workflowservice.v1.request_response_pb2 import ( SignalWithStartWorkflowExecutionRequest, ) -from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions +from temporalio.bridge._visitor import ( + PayloadVisitor, +) +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( InitializeWorkflow, WorkflowActivation, @@ -41,6 +45,9 @@ _PayloadSizeError, _ServerPayloadErrorLimits, ) +from temporalio.nexus.system._payload_visitor import ( + PayloadVisitor as SystemPayloadVisitor, +) from tests.worker.test_workflow import SimpleCodec @@ -48,7 +55,7 @@ class Visitor(VisitorFunctions): async def visit_payload(self, payload: Payload) -> None: payload.metadata["visited"] = b"True" - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + async def visit_payloads(self, payloads: PayloadSequence) -> None: for payload in payloads: payload.metadata["visited"] = b"True" @@ -60,6 +67,30 @@ def _json_plain_payload(value: object) -> Payload: return temporalio.converter.default().payload_converter.to_payload(value) +def test_generated_payload_visitor_type_annotations(): + assert get_type_hints(PayloadVisitor.visit)["root"] is Message + assert ( + get_type_hints(PayloadVisitor._visit_temporal_api_common_v1_Payload)["o"] + is Payload + ) + assert ( + get_type_hints(PayloadVisitor._visit_temporal_api_common_v1_Payloads)["o"] + is Payloads + ) + assert ( + get_type_hints( + PayloadVisitor._visit_coresdk_workflow_activation_WorkflowActivation + )["o"] + is WorkflowActivation + ) + assert ( + get_type_hints( + SystemPayloadVisitor._visit_temporal_api_workflowservice_v1_SignalWithStartWorkflowExecutionRequest + )["o"] + is SignalWithStartWorkflowExecutionRequest + ) + + async def test_workflow_activation_completion(): comp = WorkflowActivationCompletion( run_id="1", From 7aaff388af3f2c907c4cfe8e0f95921899d9ec7c Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 24 Apr 2026 10:18:52 -0700 Subject: [PATCH 17/18] Refine Python system nexus request handling --- scripts/gen_nexus_system_models.py | 36 +++---------------- scripts/gen_payload_visitor.py | 29 +++++---------- temporalio/bridge/_visitor.py | 36 ++++++++----------- .../contrib/opentelemetry/_interceptor.py | 2 +- .../opentelemetry/_otel_interceptor.py | 2 +- temporalio/nexus/system/__init__.py | 20 +++++------ temporalio/nexus/system/_payload_visitor.py | 16 ++++----- .../system/_workflow_service_generated.py | 11 +++--- temporalio/worker/_interceptor.py | 2 +- temporalio/worker/_workflow_instance.py | 7 ++-- temporalio/workflow.py | 12 +++---- tests/nexus/test_temporal_system_nexus.py | 2 +- 12 files changed, 63 insertions(+), 112 deletions(-) diff --git a/scripts/gen_nexus_system_models.py b/scripts/gen_nexus_system_models.py index 11f72a90f..b3abdac52 100644 --- a/scripts/gen_nexus_system_models.py +++ b/scripts/gen_nexus_system_models.py @@ -1,7 +1,6 @@ from __future__ import annotations import importlib -import re import subprocess import sys from pathlib import Path @@ -11,8 +10,6 @@ def main() -> None: repo_root = Path(__file__).resolve().parent.parent - # TODO: Remove the local .nexusrpc.yaml shim once the upstream API repo - # checks in the Nexus definition we can consume directly. input_schema = ( repo_root / "temporalio" @@ -65,7 +62,7 @@ def main() -> None: def add_operation_registry(repo_root: Path, output_file: Path) -> None: - source = strip_existing_operation_registry(output_file.read_text()) + source = output_file.read_text() source = ensure_typing_import(source) services = discover_services(repo_root) if not services: @@ -74,24 +71,6 @@ def add_operation_registry(repo_root: Path, output_file: Path) -> None: output_file.write_text(source.rstrip() + "\n\n" + emit_operation_registry(services)) -def strip_existing_operation_registry(source: str) -> str: - source = re.sub( - r"\nimport typing\n(?=\n__nexus_operation_registry__)", - "\n", - source, - ) - source = re.sub( - r"\n__nexus_operation_registry__: dict\[\n" - r"(?:.*\n)*?" - r"\] = \{\}\n" - r"(?:\n__nexus_operation_registry__\[\n(?:.*\n)*?\] = .+\n)+", - "\n", - source, - flags=re.MULTILINE, - ) - return source.rstrip() + "\n" - - def ensure_typing_import(source: str) -> str: if "\nimport typing\n" in source: return source @@ -131,19 +110,14 @@ def emit_operation_registry( lines = [ "__nexus_operation_registry__: dict[", " tuple[str, str], Operation[typing.Any, typing.Any]", - "] = {}", - "", + "] = {", ] for class_name, service_name, operations in services: for attr_name, operation_name in operations: - lines.extend( - [ - "__nexus_operation_registry__[", - f" ({service_name!r}, {operation_name!r})", - f"] = {class_name}.{attr_name}", - "", - ] + lines.append( + f" ({service_name!r}, {operation_name!r}): {class_name}.{attr_name}," ) + lines.append("}") return "\n".join(lines).rstrip() + "\n" diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index fd7e3e604..c54c752b1 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -132,8 +132,8 @@ def __init__(self) -> None: Payloads.DESCRIPTOR.full_name: True, } self.in_progress: set[str] = set() - self.root_type_imports: dict[str, tuple[str, str]] = {} - self.type_checking_modules: set[str] = set() + self.runtime_type_imports: dict[str, tuple[str, str]] = {} + self.runtime_module_imports: set[str] = set() self.methods: list[str] = [ """\ async def _visit_temporal_api_common_v1_Payload( @@ -162,25 +162,22 @@ def generate(self, roots: list[Descriptor]) -> str: extra_imports = "\n".join( f"from {module} import {class_name}" for class_name, module in sorted( - set(self.root_type_imports.values()), + set(self.runtime_type_imports.values()), key=lambda item: (item[1], item[0]), ) ) - type_checking_imports = "\n".join( - f" import {module}" for module in sorted(self.type_checking_modules) + module_imports = "\n".join( + f"import {module}" for module in sorted(self.runtime_module_imports) ) if extra_imports: extra_imports += "\n" - if type_checking_imports: - type_checking_imports = ( - "\nif TYPE_CHECKING:\n" + type_checking_imports + "\n" - ) + if module_imports: + module_imports += "\n" header = """ from __future__ import annotations # This file is generated by gen_payload_visitor.py. Changes should be made there. -from typing import TYPE_CHECKING from google.protobuf.message import Message import temporalio.nexus.system @@ -188,7 +185,7 @@ def generate(self, roots: list[Descriptor]) -> str: from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions """ header += extra_imports - header += type_checking_imports + header += module_imports header += """ class PayloadVisitor: @@ -245,15 +242,7 @@ def python_type_for_descriptor(self, desc: Descriptor) -> str: if cls is None: return "Message" module = normalize_python_module(cls.__module__) - if desc in ( - WorkflowActivation.DESCRIPTOR, - WorkflowActivationCompletion.DESCRIPTOR, - SignalWithStartWorkflowExecutionRequest.DESCRIPTOR, - SignalWithStartWorkflowExecutionResponse.DESCRIPTOR, - ): - self.root_type_imports[desc.full_name] = (cls.__name__, module) - return cls.__name__ - self.type_checking_modules.add(module) + self.runtime_module_imports.add(module) return f"{module}.{nested_python_name(desc)}" def check_repeated( diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index dbbf16816..897e256e1 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,30 +1,20 @@ from __future__ import annotations # This file is generated by gen_payload_visitor.py. Changes should be made there. -from typing import TYPE_CHECKING - from google.protobuf.message import Message +import temporalio.api.common.v1.message_pb2 +import temporalio.api.failure.v1.message_pb2 +import temporalio.api.sdk.v1.user_metadata_pb2 +import temporalio.bridge.proto.activity_result.activity_result_pb2 +import temporalio.bridge.proto.child_workflow.child_workflow_pb2 +import temporalio.bridge.proto.nexus.nexus_pb2 +import temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 +import temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 +import temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions -from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( - WorkflowActivation, -) -from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import ( - WorkflowActivationCompletion, -) - -if TYPE_CHECKING: - import temporalio.api.common.v1.message_pb2 - import temporalio.api.failure.v1.message_pb2 - import temporalio.api.sdk.v1.user_metadata_pb2 - import temporalio.bridge.proto.activity_result.activity_result_pb2 - import temporalio.bridge.proto.child_workflow.child_workflow_pb2 - import temporalio.bridge.proto.nexus.nexus_pb2 - import temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 - import temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 - import temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 class PayloadVisitor: @@ -410,7 +400,9 @@ async def _visit_coresdk_workflow_activation_WorkflowActivationJob( ) async def _visit_coresdk_workflow_activation_WorkflowActivation( - self, fs: VisitorFunctions, o: WorkflowActivation + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.WorkflowActivation, ): for v in o.jobs: await self._visit_coresdk_workflow_activation_WorkflowActivationJob(fs, v) @@ -630,7 +622,9 @@ async def _visit_coresdk_workflow_completion_Failure( await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( - self, fs: VisitorFunctions, o: WorkflowActivationCompletion + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_completion.workflow_completion_pb2.WorkflowActivationCompletion, ): if o.HasField("successful"): await self._visit_coresdk_workflow_completion_Success(fs, o.successful) diff --git a/temporalio/contrib/opentelemetry/_interceptor.py b/temporalio/contrib/opentelemetry/_interceptor.py index 502a25f16..7e19749c9 100644 --- a/temporalio/contrib/opentelemetry/_interceptor.py +++ b/temporalio/contrib/opentelemetry/_interceptor.py @@ -777,7 +777,7 @@ async def signal_with_start_workflow( ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: self.root._completed_span( f"SignalWithStartWorkflow:{input.signal}", - add_to_outbound_str=input, + add_to_outbound=input, kind=opentelemetry.trace.SpanKind.CLIENT, ) return await super().signal_with_start_workflow(input) diff --git a/temporalio/contrib/opentelemetry/_otel_interceptor.py b/temporalio/contrib/opentelemetry/_otel_interceptor.py index 4b6d8b537..f60473c63 100644 --- a/temporalio/contrib/opentelemetry/_otel_interceptor.py +++ b/temporalio/contrib/opentelemetry/_otel_interceptor.py @@ -552,7 +552,7 @@ async def signal_with_start_workflow( f"SignalWithStartWorkflow:{input.signal}", kind=opentelemetry.trace.SpanKind.CLIENT, ): - input.headers = _context_to_nexus_headers(input.headers or {}) + input.headers = _context_to_headers(input.headers) return await super().signal_with_start_workflow(input) def start_activity( diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index 822cb3af7..cf51e9714 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -16,6 +16,8 @@ from . import _workflow_service_generated as generated from ._payload_visitor import PayloadVisitor +_SYSTEM_NEXUS_ENDPOINT = "__temporal_system" + class SystemNexusPayloadConverter(CompositePayloadConverter): """Payload converter for system Nexus outer envelopes.""" @@ -25,15 +27,6 @@ def __init__(self) -> None: super().__init__(BinaryProtoPayloadConverter()) -def _set_payload_map( - target: Any, - values: Mapping[str, Any], - payload_converter: temporalio.converter.PayloadConverter, -) -> None: - for key, value in values.items(): - target[key].CopyFrom(payload_converter.to_payload(value)) - - def _build_user_metadata( payload_converter: temporalio.converter.PayloadConverter, static_summary: str | None, @@ -60,6 +53,7 @@ def build_signal_with_start_workflow_execution_input( task_queue: str, request_id: str | None, payload_converter: temporalio.converter.PayloadConverter, + headers: Mapping[str, temporalio.api.common.v1.Payload], execution_timeout: timedelta | None = None, run_timeout: timedelta | None = None, task_timeout: timedelta | None = None, @@ -79,13 +73,18 @@ def build_signal_with_start_workflow_execution_input( request_memo = None if memo is not None: request_memo = temporalio.api.common.v1.Memo() - _set_payload_map(request_memo.fields, memo, payload_converter) + for key, value in memo.items(): + request_memo.fields[key].CopyFrom(payload_converter.to_payload(value)) request_search_attributes = None if search_attributes is not None: request_search_attributes = temporalio.api.common.v1.SearchAttributes() temporalio.converter.encode_search_attributes( search_attributes, request_search_attributes ) + request_header = None + if headers: + request_header = temporalio.api.common.v1.Header() + temporalio.common._apply_headers(headers, request_header.fields) return _workflow_requests.build_signal_with_start_workflow_execution_request( namespace=namespace, workflow_id=workflow_id, @@ -108,6 +107,7 @@ def build_signal_with_start_workflow_execution_input( cron_schedule=cron_schedule, memo=request_memo, search_attributes=request_search_attributes, + header=request_header, user_metadata=_build_user_metadata( payload_converter, static_summary, static_details ), diff --git a/temporalio/nexus/system/_payload_visitor.py b/temporalio/nexus/system/_payload_visitor.py index 83ad41f5c..fb7c5e173 100644 --- a/temporalio/nexus/system/_payload_visitor.py +++ b/temporalio/nexus/system/_payload_visitor.py @@ -1,21 +1,15 @@ from __future__ import annotations # This file is generated by gen_payload_visitor.py. Changes should be made there. -from typing import TYPE_CHECKING - from google.protobuf.message import Message +import temporalio.api.common.v1.message_pb2 +import temporalio.api.sdk.v1.user_metadata_pb2 +import temporalio.api.workflowservice.v1.request_response_pb2 import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes -from temporalio.api.workflowservice.v1.request_response_pb2 import ( - SignalWithStartWorkflowExecutionRequest, -) from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions -if TYPE_CHECKING: - import temporalio.api.common.v1.message_pb2 - import temporalio.api.sdk.v1.user_metadata_pb2 - class PayloadVisitor: """A visitor for payloads. @@ -104,7 +98,9 @@ async def _visit_temporal_api_sdk_v1_UserMetadata( await self._visit_temporal_api_common_v1_Payload(fs, o.details) async def _visit_temporal_api_workflowservice_v1_SignalWithStartWorkflowExecutionRequest( - self, fs: VisitorFunctions, o: SignalWithStartWorkflowExecutionRequest + self, + fs: VisitorFunctions, + o: temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionRequest, ): if o.HasField("input"): await self._visit_temporal_api_common_v1_Payloads(fs, o.input) diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index ab914993f..732e746d2 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -19,8 +19,9 @@ class WorkflowService: __nexus_operation_registry__: dict[ tuple[str, str], Operation[typing.Any, typing.Any] -] = {} - -__nexus_operation_registry__[ - ("WorkflowService", "SignalWithStartWorkflowExecution") -] = WorkflowService.signal_with_start_workflow_execution +] = { + ( + "WorkflowService", + "SignalWithStartWorkflowExecution", + ): WorkflowService.signal_with_start_workflow_execution, +} diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index ec3c606b6..d05f15b2a 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -267,7 +267,7 @@ class SignalWithStartWorkflowInput: request_id: str | None priority: temporalio.common.Priority versioning_override: temporalio.common.VersioningOverride | None - headers: Mapping[str, str] | None + headers: Mapping[str, temporalio.api.common.v1.Payload] @dataclass diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index de4233e11..e19d20a67 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1605,7 +1605,7 @@ async def workflow_signal_with_start_workflow( request_id=request_id, priority=priority, versioning_override=versioning_override, - headers=None, + headers={}, ) ) @@ -2048,11 +2048,12 @@ async def _outbound_signal_with_start_workflow( start_delay=input.start_delay, priority=input.priority, versioning_override=input.versioning_override, + headers=input.headers, ) ) handle = await self._outbound_start_nexus_operation( StartNexusOperationInput( - endpoint=temporalio.workflow._SYSTEM_NEXUS_ENDPOINT, + endpoint=temporalio.nexus.system._SYSTEM_NEXUS_ENDPOINT, service=temporalio.nexus.system.generated.WorkflowService.__name__, operation=temporalio.nexus.system.generated.WorkflowService.signal_with_start_workflow_execution.name, input=request, @@ -2060,7 +2061,7 @@ async def _outbound_signal_with_start_workflow( schedule_to_start_timeout=None, start_to_close_timeout=None, cancellation_type=temporalio.workflow.NexusOperationCancellationType.WAIT_COMPLETED, - headers=input.headers, + headers=None, summary=None, output_type=temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse, ) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index dfd8ea5d3..a87c93388 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -2067,10 +2067,10 @@ def _assert_dynamic_handler_args( if ( not arg_types or len(arg_types) != 2 - or arg_types[0] != str + or arg_types[0] is not str or ( - arg_types[1] != Sequence[temporalio.common.RawValue] - and arg_types[1] != typing.Sequence[temporalio.common.RawValue] # type: ignore[reportDeprecated] + arg_types[1] is not Sequence[temporalio.common.RawValue] + and arg_types[1] is not typing.Sequence[temporalio.common.RawValue] # type: ignore[reportDeprecated] ) ): raise RuntimeError( @@ -4308,10 +4308,6 @@ class ChildWorkflowConfig(TypedDict, total=False): static_details: str | None priority: temporalio.common.Priority - -_SYSTEM_NEXUS_ENDPOINT = "__temporal_system" - - # Overload for no-param workflow @overload async def start_child_workflow( @@ -4962,7 +4958,7 @@ class ContinueAsNewError(BaseException): def __init__(self, *args: object) -> None: """Direct instantiation is disabled. Use :py:func:`continue_as_new`.""" - if type(self) == ContinueAsNewError: + if type(self) is ContinueAsNewError: raise RuntimeError("Cannot instantiate ContinueAsNewError directly") super().__init__(*args) diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py index f0b8e726a..5f19c37b5 100644 --- a/tests/nexus/test_temporal_system_nexus.py +++ b/tests/nexus/test_temporal_system_nexus.py @@ -326,7 +326,7 @@ async def test_external_workflow_handle_signal_with_start_workflow_uses_system_n caller_task_queue = str(uuid.uuid4()) handler_task_queue = str(uuid.uuid4()) endpoint_name = make_nexus_endpoint_name(handler_task_queue) - monkeypatch.setattr(workflow, "_SYSTEM_NEXUS_ENDPOINT", endpoint_name) + monkeypatch.setattr(nexus_system, "_SYSTEM_NEXUS_ENDPOINT", endpoint_name) caller_worker = Worker( caller_client, From 1d35209fedce072d0859ea1091849f7f954523a8 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 24 Apr 2026 10:29:49 -0700 Subject: [PATCH 18/18] Tighten Python type-check exclusions --- pyproject.toml | 3 +-- temporalio/workflow.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2f5cbfef8..425ba8091 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,10 +142,9 @@ environment = { PATH = "$PATH:$HOME/.cargo/bin", CARGO_NET_GIT_FETCH_WITH_CLI = ignore_missing_imports = true exclude = [ # Ignore generated code - 'build', + 'build/tool-cache', 'temporalio/api', 'temporalio/bridge/proto', - 'temporalio/nexus/system/_workflow_service_generated.py', ] [tool.pydocstyle] diff --git a/temporalio/workflow.py b/temporalio/workflow.py index a87c93388..b8106cec2 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -4308,6 +4308,7 @@ class ChildWorkflowConfig(TypedDict, total=False): static_details: str | None priority: temporalio.common.Priority + # Overload for no-param workflow @overload async def start_child_workflow(