diff --git a/README.md b/README.md index 6e42a2019..736f04e8a 100644 --- a/README.md +++ b/README.md @@ -1933,6 +1933,7 @@ To build the SDK from source for use as a dependency, the following prerequisite * [uv](https://docs.astral.sh/uv/) * [Rust](https://www.rust-lang.org/) * [Protobuf Compiler](https://protobuf.dev/) +* [Node.js](https://nodejs.org/) Use `uv` to install `poe`: @@ -2074,6 +2075,12 @@ back from this downgrade, restore both of those files and run `uv sync --all-ext run for protobuf version 3 by setting the `TEMPORAL_TEST_PROTO3` env var to `1` prior to running tests. +The local build and lint flows also regenerate Temporal system Nexus models. By default this pulls +in `nexus-rpc-gen@0.1.0-alpha.4` via `npx`. To use an existing checkout instead, set +`TEMPORAL_NEXUS_RPC_GEN_DIR` to the `nexus-rpc-gen` repo root or its `src` directory before +running `poe build-develop`, `poe lint`, or `poe gen-protos`. The local checkout override path +also requires [`pnpm`](https://pnpm.io/) to be installed. + ### Style * Mostly [Google Style Guide](https://google.github.io/styleguide/pyguide.html). Notable exceptions: diff --git a/pyproject.toml b/pyproject.toml index 4bcd3f03e..425ba8091 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,12 +79,14 @@ gen-protos = [ { cmd = "uv run scripts/gen_protos.py" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, + { ref = "gen-nexus-system-models" }, { ref = "format" }, ] gen-protos-docker = [ { cmd = "uv run scripts/gen_protos_docker.py" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, + { ref = "gen-nexus-system-models" }, { ref = "format" }, ] lint = [ @@ -102,6 +104,7 @@ lint-types = [ { cmd = "uv run mypy --namespace-packages --check-untyped-defs ." 
}, { cmd = "uv run basedpyright" }, ] +gen-nexus-system-models = "uv run scripts/gen_nexus_system_models.py" run-bench = "uv run python scripts/run_bench.py" test = "uv run pytest" @@ -139,6 +142,7 @@ environment = { PATH = "$PATH:$HOME/.cargo/bin", CARGO_NET_GIT_FETCH_WITH_CLI = ignore_missing_imports = true exclude = [ # Ignore generated code + 'build/tool-cache', 'temporalio/api', 'temporalio/bridge/proto', ] @@ -146,7 +150,8 @@ exclude = [ [tool.pydocstyle] convention = "google" # https://github.com/PyCQA/pydocstyle/issues/363#issuecomment-625563088 -match_dir = "^(?!(docs|scripts|tests|api|proto|\\.)).*" +match_dir = "^(?!(build|docs|scripts|tests|api|proto|\\.)).*" +match = "^(?!_workflow_service_generated\\.py$).*\\.py" add_ignore = [ # We like to wrap at a certain number of chars, even long summary sentences. # https://github.com/PyCQA/pydocstyle/issues/184 @@ -211,7 +216,6 @@ exclude = [ # Exclude auto generated files "temporalio/api", "temporalio/bridge/proto", - "temporalio/bridge/_visitor.py", "tests/worker/workflow_sandbox/testmodules/proto", ] diff --git a/scripts/gen_nexus_system_models.py b/scripts/gen_nexus_system_models.py new file mode 100644 index 000000000..b3abdac52 --- /dev/null +++ b/scripts/gen_nexus_system_models.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import importlib +import subprocess +import sys +from pathlib import Path + +NEXUS_RPC_GEN_VERSION = "0.1.0-alpha.4" + + +def main() -> None: + repo_root = Path(__file__).resolve().parent.parent + input_schema = ( + repo_root + / "temporalio" + / "bridge" + / "sdk-core" + / "crates" + / "common" + / "protos" + / "api_upstream" + / "nexus" + / "temporal-proto-models-nexusrpc.yaml" + ) + output_file = ( + repo_root / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" + ) + + if not input_schema.is_file(): + raise RuntimeError(f"Expected Nexus schema at {input_schema}") + + run_nexus_rpc_gen( + output_file=output_file, + input_schema=input_schema, + ) + 
add_operation_registry(repo_root, output_file) + subprocess.run( + [ + "uv", + "run", + "ruff", + "check", + "--select", + "I", + "--fix", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "format", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + + +def add_operation_registry(repo_root: Path, output_file: Path) -> None: + source = output_file.read_text() + source = ensure_typing_import(source) + services = discover_services(repo_root) + if not services: + output_file.write_text(source) + return + output_file.write_text(source.rstrip() + "\n\n" + emit_operation_registry(services)) + + +def ensure_typing_import(source: str) -> str: + if "\nimport typing\n" in source: + return source + marker = "from __future__ import annotations\n\n" + if marker not in source: + raise RuntimeError("Expected future-annotations import in generated output") + return source.replace(marker, marker + "import typing\n", 1) + + +def discover_services(repo_root: Path) -> list[tuple[str, str, list[tuple[str, str]]]]: + module_name = "temporalio.nexus.system._workflow_service_generated" + sys.path.insert(0, str(repo_root)) + try: + sys.modules.pop(module_name, None) + importlib.invalidate_caches() + module = importlib.import_module(module_name) + finally: + sys.path.pop(0) + services: list[tuple[str, str, list[tuple[str, str]]]] = [] + for value in vars(module).values(): + if not isinstance(value, type): + continue + definition = getattr(value, "__nexus_service_definition__", None) + if definition is None: + continue + operations = [ + (operation_definition.method_name, operation_definition.name) + for operation_definition in definition.operation_definitions.values() + ] + services.append((value.__name__, definition.name, operations)) + return services + + +def emit_operation_registry( + services: list[tuple[str, str, list[tuple[str, str]]]], +) -> str: + lines = [ + "__nexus_operation_registry__: dict[", + " tuple[str, 
str], Operation[typing.Any, typing.Any]", + "] = {", + ] + for class_name, service_name, operations in services: + for attr_name, operation_name in operations: + lines.append( + f" ({service_name!r}, {operation_name!r}): {class_name}.{attr_name}," + ) + lines.append("}") + return "\n".join(lines).rstrip() + "\n" + + +def run_nexus_rpc_gen(*, output_file: Path, input_schema: Path) -> None: + common_args = [ + "--lang", + "py", + "--out-file", + str(output_file), + str(input_schema), + ] + subprocess.run( + ["npx", "--yes", f"nexus-rpc-gen@{NEXUS_RPC_GEN_VERSION}", *common_args], + check=True, + ) + + +if __name__ == "__main__": + try: + main() + except Exception as err: + print(f"Failed to generate Nexus system models: {err}", file=sys.stderr) + raise diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index eabfd9e6a..c54c752b1 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -1,11 +1,18 @@ import subprocess import sys +from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from typing import Optional +from types import ModuleType +from typing import cast +import google.protobuf.message from google.protobuf.descriptor import Descriptor, FieldDescriptor from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes +from temporalio.api.workflowservice.v1.request_response_pb2 import ( + SignalWithStartWorkflowExecutionRequest, + SignalWithStartWorkflowExecutionResponse, +) from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( WorkflowActivation, ) @@ -14,31 +21,86 @@ ) base_dir = Path(__file__).parent.parent +BASIC_IMPORTED_TYPES = { + Payload.DESCRIPTOR.full_name: "Payload", + Payloads.DESCRIPTOR.full_name: "Payloads", + SearchAttributes.DESCRIPTOR.full_name: "SearchAttributes", +} + + +def normalize_python_module(module: str) -> str: + if module.startswith("temporal.sdk.core."): + return "temporalio.bridge.proto." 
+ module[len("temporal.sdk.core.") :] + return module + + +def nested_python_name(desc: Descriptor) -> str: + names = [desc.name] + current = desc.containing_type + while current is not None: + names.append(current.name) + current = current.containing_type + return ".".join(reversed(names)) def name_for(desc: Descriptor) -> str: - # Use fully-qualified name to avoid collisions; replace dots with underscores return desc.full_name.replace(".", "_") -def emit_loop( - field_name: str, - iter_expr: str, - child_method: str, -) -> str: - # Helper to emit a for-loop over a collection with optional headers guard +def python_type_for(desc: Descriptor) -> str: + module = desc.file.package + if module.startswith("temporal.api."): + module = "temporalio.api." + module[len("temporal.api.") :] + return f"{module}.{desc.name}" + return "object" + + +def load_generated_system_module() -> ModuleType: + module_path = ( + base_dir / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" + ) + module_name = "temporalio_nexus_system_generated" + spec = spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Cannot load generated system module from {module_path}") + module = module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def discover_system_nexus_roots() -> list[Descriptor]: + module = load_generated_system_module() + roots: list[Descriptor] = [] + for operation in getattr(module, "__nexus_operation_registry__", {}).values(): + for proto_type in (operation.input_type, operation.output_type): + if ( + isinstance(proto_type, type) + and issubclass(proto_type, google.protobuf.message.Message) + and proto_type.DESCRIPTOR is not None + ): + roots.append(cast(Descriptor, proto_type.DESCRIPTOR)) + deduped: list[Descriptor] = [] + seen: set[str] = set() + for root in roots: + if root.full_name not in seen: + seen.add(root.full_name) + deduped.append(root) + 
return deduped + + +def emit_loop(field_name: str, iter_expr: str, child_method: str) -> str: if field_name == "headers": return f"""\ if not self.skip_headers: for v in {iter_expr}: await self._visit_{child_method}(fs, v)""" - elif field_name == "search_attributes": + if field_name == "search_attributes": return f"""\ if not self.skip_search_attributes: for v in {iter_expr}: await self._visit_{child_method}(fs, v)""" - else: - return f"""\ + return f"""\ for v in {iter_expr}: await self._visit_{child_method}(fs, v)""" @@ -46,79 +108,100 @@ def emit_loop( def emit_singular( field_name: str, access_expr: str, child_method: str, presence_word: str | None ) -> str: - # Helper to emit a singular field visit with presence check and optional headers guard if presence_word: if field_name == "headers": return f"""\ if not self.skip_headers: {presence_word} o.HasField("{field_name}"): await self._visit_{child_method}(fs, {access_expr})""" - else: - return f"""\ + return f"""\ {presence_word} o.HasField("{field_name}"): await self._visit_{child_method}(fs, {access_expr})""" - else: - if field_name == "headers": - return f"""\ + if field_name == "headers": + return f"""\ if not self.skip_headers: await self._visit_{child_method}(fs, {access_expr})""" - else: - return f"""\ + return f"""\ await self._visit_{child_method}(fs, {access_expr})""" -class VisitorGenerator: - def generate(self, roots: list[Descriptor]) -> str: - """ - Generate Python source code that, given a function f(Payload) -> Payload, - applies it to every Payload contained within a WorkflowActivation tree. - - The generated code defines async visitor functions for each reachable - protobuf message type starting from WorkflowActivation, including support - for repeated fields and map entries, and a convenience entrypoint - function `visit`. 
- """ +class PayloadVisitorGenerator: + def __init__(self) -> None: + self.generated: dict[str, bool] = { + Payload.DESCRIPTOR.full_name: True, + Payloads.DESCRIPTOR.full_name: True, + } + self.in_progress: set[str] = set() + self.runtime_type_imports: dict[str, tuple[str, str]] = {} + self.runtime_module_imports: set[str] = set() + self.methods: list[str] = [ + """\ + async def _visit_temporal_api_common_v1_Payload( + self, fs: VisitorFunctions, o: Payload + ): + await fs.visit_payload(o) + """, + """\ + async def _visit_temporal_api_common_v1_Payloads( + self, fs: VisitorFunctions, o: Payloads + ): + await fs.visit_payloads(o.payloads) + """, + """\ + async def _visit_payload_container( + self, fs: VisitorFunctions, o: PayloadSequence + ): + await fs.visit_payloads(o) + """, + ] - for r in roots: - self.walk(r) + def generate(self, roots: list[Descriptor]) -> str: + for root in roots: + self.walk(root) + + extra_imports = "\n".join( + f"from {module} import {class_name}" + for class_name, module in sorted( + set(self.runtime_type_imports.values()), + key=lambda item: (item[1], item[0]), + ) + ) + module_imports = "\n".join( + f"import {module}" for module in sorted(self.runtime_module_imports) + ) + if extra_imports: + extra_imports += "\n" + if module_imports: + module_imports += "\n" header = """ -# This file is generated by gen_payload_visitor.py. Changes should be made there. -import abc -from typing import Any, MutableSequence - -from temporalio.api.common.v1.message_pb2 import Payload +from __future__ import annotations +# This file is generated by gen_payload_visitor.py. Changes should be made there. +from google.protobuf.message import Message -class VisitorFunctions(abc.ABC): - \"\"\"Set of functions which can be called by the visitor. - Allows handling payloads as a sequence. 
- \"\"\" - @abc.abstractmethod - async def visit_payload(self, payload: Payload) -> None: - \"\"\"Called when encountering a single payload.\"\"\" - raise NotImplementedError() - - @abc.abstractmethod - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - \"\"\"Called when encountering multiple payloads together.\"\"\" - raise NotImplementedError() +import temporalio.nexus.system +from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions +""" + header += extra_imports + header += module_imports + header += """ class PayloadVisitor: - \"\"\"A visitor for payloads. + \"\"\"A visitor for payloads. Applies a function to every payload in a tree of messages. \"\"\" + def __init__( self, *, skip_search_attributes: bool = False, skip_headers: bool = False ): - \"\"\"Creates a new payload visitor.\"\"\" + \"\"\"Create a new payload visitor.\"\"\" self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers - async def visit( - self, fs: VisitorFunctions, root: Any - ) -> None: - \"\"\"Visits the given root message with the given function.\"\"\" + async def visit(self, fs: VisitorFunctions, root: Message) -> None: + \"\"\"Visit the given root message with the given function set.\"\"\" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) if method is not None: @@ -126,83 +209,90 @@ async def visit( else: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") -""" + async def _visit_system_nexus_payload( + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, + ) -> None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + self.skip_search_attributes, + ) + if new_payload is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + if 
new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) +""" return header + "\n".join(self.methods) - def __init__(self): - # Track which message descriptors have visitor methods generated - self.generated: dict[str, bool] = { - Payload.DESCRIPTOR.full_name: True, - Payloads.DESCRIPTOR.full_name: True, - } - self.in_progress: set[str] = set() - self.methods: list[str] = [ - """\ - async def _visit_temporal_api_common_v1_Payload(self, fs, o): - await fs.visit_payload(o) - """, - """\ - async def _visit_temporal_api_common_v1_Payloads(self, fs, o): - await fs.visit_payloads(o.payloads) - """, - """\ - async def _visit_payload_container(self, fs, o): - await fs.visit_payloads(o) - """, - ] - - def check_repeated(self, child_desc, field, iter_expr) -> str | None: - # Special case for repeated payloads, handle them directly + def python_type_for_descriptor(self, desc: Descriptor) -> str: + basic_type = BASIC_IMPORTED_TYPES.get(desc.full_name) + if basic_type is not None: + return basic_type + cls = getattr(desc, "_concrete_class", None) + if cls is None: + return "Message" + module = normalize_python_module(cls.__module__) + self.runtime_module_imports.add(module) + return f"{module}.{nested_python_name(desc)}" + + def check_repeated( + self, child_desc: Descriptor, field: FieldDescriptor, iter_expr: str + ) -> str | None: if child_desc.full_name == Payload.DESCRIPTOR.full_name: return emit_singular(field.name, iter_expr, "payload_container", None) - else: - child_needed = self.walk(child_desc) - if child_needed: - return emit_loop( - field.name, - iter_expr, - name_for(child_desc), - ) - else: - return None + child_needed = self.walk(child_desc) + if child_needed: + return emit_loop(field.name, iter_expr, name_for(child_desc)) + return None def walk(self, desc: Descriptor) -> bool: key = desc.full_name if key in self.generated: return self.generated[key] if key in self.in_progress: - # Break cycles; Assume the child 
will be needed (Used by Failure -> Cause) return True has_payload = False self.in_progress.add(key) - lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"] - # If this is the SearchAttributes message, allow skipping + body_lines: list[str] = [] if desc.full_name == SearchAttributes.DESCRIPTOR.full_name: - lines.append(" if self.skip_search_attributes:") - lines.append(" return") + body_lines.append(" if self.skip_search_attributes:") + body_lines.append(" return") - # Group fields by oneof to generate if/elif chains oneof_fields: dict[int, list[FieldDescriptor]] = {} regular_fields: list[FieldDescriptor] = [] for field in desc.fields: if field.type != FieldDescriptor.TYPE_MESSAGE: continue - - # Skip synthetic oneofs (proto3 optional fields) if field.containing_oneof is not None: - oneof_idx = field.containing_oneof.index - if oneof_idx not in oneof_fields: - oneof_fields[oneof_idx] = [] - oneof_fields[oneof_idx].append(field) + oneof_fields.setdefault(field.containing_oneof.index, []).append(field) else: regular_fields.append(field) - # Process regular fields first for field in regular_fields: - # Repeated fields (including maps which are represented as repeated messages) + if ( + desc.full_name == "coresdk.workflow_commands.ScheduleNexusOperation" + and field.name == "input" + ): + has_payload = True + body_lines.append( + """\ + if o.HasField("input"): + await self._visit_system_nexus_payload(fs, o.service, o.operation, o.input)""" + ) + continue + if field.label == FieldDescriptor.LABEL_REPEATED: if ( field.message_type is not None @@ -217,7 +307,7 @@ def walk(self, desc: Descriptor) -> bool: child_needed = self.walk(child_desc) if child_needed: has_payload = True - lines.append( + body_lines.append( emit_loop( field.name, f"o.{field.name}.values()", @@ -234,7 +324,7 @@ def walk(self, desc: Descriptor) -> bool: child_needed = self.walk(child_desc) if child_needed: has_payload = True - lines.append( + body_lines.append( emit_loop( field.name, 
f"o.{field.name}.keys()", @@ -247,20 +337,19 @@ def walk(self, desc: Descriptor) -> bool: ) if child is not None: has_payload = True - lines.append(child) + body_lines.append(child) else: child_desc = field.message_type child_has_payload = self.walk(child_desc) has_payload |= child_has_payload if child_has_payload: - lines.append( + body_lines.append( emit_singular( field.name, f"o.{field.name}", name_for(child_desc), "if" ) ) - # Process oneof fields as if/elif chains - for oneof_idx, fields in oneof_fields.items(): + for fields in oneof_fields.values(): oneof_lines = [] first = True for field in fields: @@ -268,38 +357,76 @@ def walk(self, desc: Descriptor) -> bool: child_has_payload = self.walk(child_desc) has_payload |= child_has_payload if child_has_payload: - if_word = "if" if first else "elif" - first = False - line = emit_singular( - field.name, f"o.{field.name}", name_for(child_desc), if_word + oneof_lines.append( + emit_singular( + field.name, + f"o.{field.name}", + name_for(child_desc), + "if" if first else "elif", + ) ) - oneof_lines.append(line) + first = False if oneof_lines: - lines.extend(oneof_lines) + body_lines.extend(oneof_lines) self.generated[key] = has_payload self.in_progress.discard(key) if has_payload: + annotation = self.python_type_for_descriptor(desc) + lines = [ + f" async def _visit_{name_for(desc)}(" + f"self, fs: VisitorFunctions, o: {annotation}" + "):" + ] + lines.extend(body_lines) self.methods.append("\n".join(lines) + "\n") return has_payload -def write_generated_visitors_into_visitor_generated_py() -> None: - """Write the generated visitor code into _visitor.py.""" +def write_bridge_visitors() -> None: out_path = base_dir / "temporalio" / "bridge" / "_visitor.py" - - # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion, - # and all messages from selected API modules - roots: list[Descriptor] = [ + roots = [ WorkflowActivation.DESCRIPTOR, WorkflowActivationCompletion.DESCRIPTOR, ] + 
out_path.write_text(PayloadVisitorGenerator().generate(roots)) + - code = VisitorGenerator().generate(roots) - out_path.write_text(code) +def write_system_nexus_payload_visitors() -> None: + out_path = base_dir / "temporalio" / "nexus" / "system" / "_payload_visitor.py" + roots = discover_system_nexus_roots() + out_path.write_text(PayloadVisitorGenerator().generate(roots)) if __name__ == "__main__": print("Generating temporalio/bridge/_visitor.py...", file=sys.stderr) - write_generated_visitors_into_visitor_generated_py() - subprocess.run(["uv", "run", "ruff", "format", "temporalio/bridge/_visitor.py"]) + write_bridge_visitors() + print("Generating temporalio/nexus/system/_payload_visitor.py...", file=sys.stderr) + write_system_nexus_payload_visitors() + subprocess.run( + [ + "uv", + "run", + "ruff", + "check", + "--select", + "I", + "--fix", + "temporalio/bridge/_visitor.py", + "temporalio/nexus/system/_payload_visitor.py", + ], + cwd=base_dir, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "format", + "temporalio/bridge/_visitor.py", + "temporalio/nexus/system/_payload_visitor.py", + ], + cwd=base_dir, + check=True, + ) diff --git a/temporalio/_workflow_requests.py b/temporalio/_workflow_requests.py new file mode 100644 index 000000000..eb2b72606 --- /dev/null +++ b/temporalio/_workflow_requests.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from collections.abc import Sequence +from datetime import timedelta +from typing import cast + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.api.sdk.v1 +import temporalio.api.workflowservice.v1 +import temporalio.common + + +def populate_start_workflow_execution_request( + req: ( + temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest + | temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest + ), + *, + namespace: str, + workflow_id: str, + workflow: str, + task_queue: str, + input_payloads: 
Sequence[temporalio.api.common.v1.Payload] = (), + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + identity: str | None = None, + request_id: str | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: temporalio.api.common.v1.Memo | None = None, + search_attributes: temporalio.api.common.v1.SearchAttributes | None = None, + header: temporalio.api.common.v1.Header | None = None, + user_metadata: temporalio.api.sdk.v1.UserMetadata | None = None, + start_delay: timedelta | None = None, + priority: temporalio.common.Priority | None = None, + versioning_override: temporalio.common.VersioningOverride | None = None, +) -> None: + """Populate a workflow-service start-style request from pre-encoded pieces.""" + req.namespace = namespace + req.workflow_id = workflow_id + req.workflow_type.name = workflow + req.task_queue.name = task_queue + if input_payloads: + req.input.payloads.extend(input_payloads) + if execution_timeout is not None: + req.workflow_execution_timeout.FromTimedelta(execution_timeout) + if run_timeout is not None: + req.workflow_run_timeout.FromTimedelta(run_timeout) + if task_timeout is not None: + req.workflow_task_timeout.FromTimedelta(task_timeout) + if identity is not None: + req.identity = identity + if request_id is not None: + req.request_id = request_id + req.workflow_id_reuse_policy = cast( + "temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", + int(id_reuse_policy), + ) + req.workflow_id_conflict_policy = cast( + "temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType", + int(id_conflict_policy), + ) + if retry_policy is not None: + 
retry_policy.apply_to_proto(req.retry_policy) + req.cron_schedule = cron_schedule + if memo is not None: + req.memo.CopyFrom(memo) + if search_attributes is not None: + req.search_attributes.CopyFrom(search_attributes) + if header is not None: + req.header.CopyFrom(header) + if user_metadata is not None: + req.user_metadata.CopyFrom(user_metadata) + if start_delay is not None: + req.workflow_start_delay.FromTimedelta(start_delay) + if priority is not None: + req.priority.CopyFrom(priority._to_proto()) + if versioning_override is not None: + req.versioning_override.CopyFrom(versioning_override._to_proto()) + + +def build_signal_with_start_workflow_execution_request( + *, + namespace: str, + workflow_id: str, + workflow: str, + task_queue: str, + signal_name: str, + workflow_input_payloads: Sequence[temporalio.api.common.v1.Payload] = (), + signal_input_payloads: Sequence[temporalio.api.common.v1.Payload] = (), + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + identity: str | None = None, + request_id: str | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: temporalio.api.common.v1.Memo | None = None, + search_attributes: temporalio.api.common.v1.SearchAttributes | None = None, + header: temporalio.api.common.v1.Header | None = None, + user_metadata: temporalio.api.sdk.v1.UserMetadata | None = None, + start_delay: timedelta | None = None, + priority: temporalio.common.Priority | None = None, + versioning_override: temporalio.common.VersioningOverride | None = None, +) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest: + """Build a signal-with-start workflow-service 
request from pre-encoded pieces.""" + req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( + signal_name=signal_name + ) + if signal_input_payloads: + req.signal_input.payloads.extend(signal_input_payloads) + populate_start_workflow_execution_request( + req, + namespace=namespace, + workflow_id=workflow_id, + workflow=workflow, + task_queue=task_queue, + input_payloads=workflow_input_payloads, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + identity=identity, + request_id=request_id, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + header=header, + user_metadata=user_metadata, + start_delay=start_delay, + priority=priority, + versioning_override=versioning_override, + ) + return req diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index 16876fb59..897e256e1 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,24 +1,20 @@ -# This file is generated by gen_payload_visitor.py. Changes should be made there. -import abc -from typing import Any, MutableSequence - -from temporalio.api.common.v1.message_pb2 import Payload +from __future__ import annotations - -class VisitorFunctions(abc.ABC): - """Set of functions which can be called by the visitor. - Allows handling payloads as a sequence. - """ - - @abc.abstractmethod - async def visit_payload(self, payload: Payload) -> None: - """Called when encountering a single payload.""" - raise NotImplementedError() - - @abc.abstractmethod - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - """Called when encountering multiple payloads together.""" - raise NotImplementedError() +# This file is generated by gen_payload_visitor.py. Changes should be made there. 
+from google.protobuf.message import Message + +import temporalio.api.common.v1.message_pb2 +import temporalio.api.failure.v1.message_pb2 +import temporalio.api.sdk.v1.user_metadata_pb2 +import temporalio.bridge.proto.activity_result.activity_result_pb2 +import temporalio.bridge.proto.child_workflow.child_workflow_pb2 +import temporalio.bridge.proto.nexus.nexus_pb2 +import temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 +import temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 +import temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 +import temporalio.nexus.system +from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions class PayloadVisitor: @@ -29,12 +25,12 @@ class PayloadVisitor: def __init__( self, *, skip_search_attributes: bool = False, skip_headers: bool = False ): - """Creates a new payload visitor.""" + """Create a new payload visitor.""" self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers - async def visit(self, fs: VisitorFunctions, root: Any) -> None: - """Visits the given root message with the given function.""" + async def visit(self, fs: VisitorFunctions, root: Message) -> None: + """Visit the given root message with the given function set.""" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) if method is not None: @@ -42,36 +38,80 @@ async def visit(self, fs: VisitorFunctions, root: Any) -> None: else: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") - async def _visit_temporal_api_common_v1_Payload(self, fs, o): + async def _visit_system_nexus_payload( + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, + ) -> None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + 
self.skip_search_attributes, + ) + if new_payload is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + if new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) + + async def _visit_temporal_api_common_v1_Payload( + self, fs: VisitorFunctions, o: Payload + ): await fs.visit_payload(o) - async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + async def _visit_temporal_api_common_v1_Payloads( + self, fs: VisitorFunctions, o: Payloads + ): await fs.visit_payloads(o.payloads) - async def _visit_payload_container(self, fs, o): + async def _visit_payload_container(self, fs: VisitorFunctions, o: PayloadSequence): await fs.visit_payloads(o) - async def _visit_temporal_api_failure_v1_ApplicationFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_ApplicationFailureInfo( + self, + fs: VisitorFunctions, + o: temporalio.api.failure.v1.message_pb2.ApplicationFailureInfo, + ): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) - async def _visit_temporal_api_failure_v1_TimeoutFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_TimeoutFailureInfo( + self, + fs: VisitorFunctions, + o: temporalio.api.failure.v1.message_pb2.TimeoutFailureInfo, + ): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( fs, o.last_heartbeat_details ) - async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_CanceledFailureInfo( + self, + fs: VisitorFunctions, + o: temporalio.api.failure.v1.message_pb2.CanceledFailureInfo, + ): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) - async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( + self, + fs: VisitorFunctions, + o: 
temporalio.api.failure.v1.message_pb2.ResetWorkflowFailureInfo, + ): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( fs, o.last_heartbeat_details ) - async def _visit_temporal_api_failure_v1_Failure(self, fs, o): + async def _visit_temporal_api_failure_v1_Failure( + self, fs: VisitorFunctions, o: temporalio.api.failure.v1.message_pb2.Failure + ): if o.HasField("encoded_attributes"): await self._visit_temporal_api_common_v1_Payload(fs, o.encoded_attributes) if o.HasField("cause"): @@ -93,17 +133,25 @@ async def _visit_temporal_api_failure_v1_Failure(self, fs, o): fs, o.reset_workflow_failure_info ) - async def _visit_temporal_api_common_v1_Memo(self, fs, o): + async def _visit_temporal_api_common_v1_Memo( + self, fs: VisitorFunctions, o: temporalio.api.common.v1.message_pb2.Memo + ): for v in o.fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_temporal_api_common_v1_SearchAttributes(self, fs, o): + async def _visit_temporal_api_common_v1_SearchAttributes( + self, fs: VisitorFunctions, o: SearchAttributes + ): if self.skip_search_attributes: return for v in o.indexed_fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_InitializeWorkflow( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.InitializeWorkflow, + ): await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: for v in o.headers.values(): @@ -121,31 +169,55 @@ async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): fs, o.search_attributes ) - async def _visit_coresdk_workflow_activation_QueryWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_QueryWorkflow( + self, + fs: VisitorFunctions, + o: 
temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.QueryWorkflow, + ): await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_SignalWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_SignalWorkflow( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.SignalWorkflow, + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_activity_result_Success(self, fs, o): + async def _visit_coresdk_activity_result_Success( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.activity_result.activity_result_pb2.Success, + ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_activity_result_Failure(self, fs, o): + async def _visit_coresdk_activity_result_Failure( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.activity_result.activity_result_pb2.Failure, + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_activity_result_Cancellation(self, fs, o): + async def _visit_coresdk_activity_result_Cancellation( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.activity_result.activity_result_pb2.Cancellation, + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): + async def _visit_coresdk_activity_result_ActivityResolution( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.activity_result.activity_result_pb2.ActivityResolution, + ): if o.HasField("completed"): await self._visit_coresdk_activity_result_Success(fs, 
o.completed) elif o.HasField("failed"): @@ -153,37 +225,61 @@ async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): elif o.HasField("cancelled"): await self._visit_coresdk_activity_result_Cancellation(fs, o.cancelled) - async def _visit_coresdk_workflow_activation_ResolveActivity(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveActivity( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveActivity, + ): if o.HasField("result"): await self._visit_coresdk_activity_result_ActivityResolution(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( - self, fs, o + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveChildWorkflowExecutionStartCancelled, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( - self, fs, o + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveChildWorkflowExecutionStart, ): if o.HasField("cancelled"): await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( fs, o.cancelled ) - async def _visit_coresdk_child_workflow_Success(self, fs, o): + async def _visit_coresdk_child_workflow_Success( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.child_workflow.child_workflow_pb2.Success, + ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_child_workflow_Failure(self, fs, o): + async def _visit_coresdk_child_workflow_Failure( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.child_workflow.child_workflow_pb2.Failure, + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def 
_visit_coresdk_child_workflow_Cancellation(self, fs, o): + async def _visit_coresdk_child_workflow_Cancellation( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.child_workflow.child_workflow_pb2.Cancellation, + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): + async def _visit_coresdk_child_workflow_ChildWorkflowResult( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.child_workflow.child_workflow_pb2.ChildWorkflowResult, + ): if o.HasField("completed"): await self._visit_coresdk_child_workflow_Success(fs, o.completed) elif o.HasField("failed"): @@ -192,36 +288,52 @@ async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): await self._visit_coresdk_child_workflow_Cancellation(fs, o.cancelled) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( - self, fs, o + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveChildWorkflowExecution, ): if o.HasField("result"): await self._visit_coresdk_child_workflow_ChildWorkflowResult(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( - self, fs, o + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveSignalExternalWorkflow, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( - self, fs, o + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveRequestCancelExternalWorkflow, ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_activation_DoUpdate(self, fs, o): + async def _visit_coresdk_workflow_activation_DoUpdate( + self, + fs: 
VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.DoUpdate, + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( - self, fs, o + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveNexusOperationStart, ): if o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): + async def _visit_coresdk_nexus_NexusOperationResult( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.nexus.nexus_pb2.NexusOperationResult, + ): if o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) elif o.HasField("failed"): @@ -231,11 +343,19 @@ async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): elif o.HasField("timed_out"): await self._visit_temporal_api_failure_v1_Failure(fs, o.timed_out) - async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveNexusOperation( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.ResolveNexusOperation, + ): if o.HasField("result"): await self._visit_coresdk_nexus_NexusOperationResult(fs, o.result) - async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): + async def _visit_coresdk_workflow_activation_WorkflowActivationJob( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.WorkflowActivationJob, + ): if o.HasField("initialize_workflow"): await self._visit_coresdk_workflow_activation_InitializeWorkflow( fs, o.initialize_workflow @@ -279,42 +399,72 @@ async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, 
fs, o): fs, o.resolve_nexus_operation ) - async def _visit_coresdk_workflow_activation_WorkflowActivation(self, fs, o): + async def _visit_coresdk_workflow_activation_WorkflowActivation( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_activation.workflow_activation_pb2.WorkflowActivation, + ): for v in o.jobs: await self._visit_coresdk_workflow_activation_WorkflowActivationJob(fs, v) - async def _visit_temporal_api_sdk_v1_UserMetadata(self, fs, o): + async def _visit_temporal_api_sdk_v1_UserMetadata( + self, + fs: VisitorFunctions, + o: temporalio.api.sdk.v1.user_metadata_pb2.UserMetadata, + ): if o.HasField("summary"): await self._visit_temporal_api_common_v1_Payload(fs, o.summary) if o.HasField("details"): await self._visit_temporal_api_common_v1_Payload(fs, o.details) - async def _visit_coresdk_workflow_commands_ScheduleActivity(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleActivity( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.ScheduleActivity, + ): if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) await self._visit_payload_container(fs, o.arguments) - async def _visit_coresdk_workflow_commands_QuerySuccess(self, fs, o): + async def _visit_coresdk_workflow_commands_QuerySuccess( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.QuerySuccess, + ): if o.HasField("response"): await self._visit_temporal_api_common_v1_Payload(fs, o.response) - async def _visit_coresdk_workflow_commands_QueryResult(self, fs, o): + async def _visit_coresdk_workflow_commands_QueryResult( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.QueryResult, + ): if o.HasField("succeeded"): await self._visit_coresdk_workflow_commands_QuerySuccess(fs, o.succeeded) elif o.HasField("failed"): await 
self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.CompleteWorkflowExecution, + ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_FailWorkflowExecution( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.FailWorkflowExecution, + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( - self, fs, o + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.ContinueAsNewWorkflowExecution, ): await self._visit_payload_container(fs, o.arguments) for v in o.memo.values(): @@ -327,7 +477,11 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( fs, o.search_attributes ) - async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.StartChildWorkflowExecution, + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): @@ -340,42 +494,66 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, ) async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( - self, fs, o + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.SignalExternalWorkflowExecution, ): await 
self._visit_payload_container(fs, o.args) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleLocalActivity( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.ScheduleLocalActivity, + ): if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) await self._visit_payload_container(fs, o.arguments) async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( - self, fs, o + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.UpsertWorkflowSearchAttributes, ): if o.HasField("search_attributes"): await self._visit_temporal_api_common_v1_SearchAttributes( fs, o.search_attributes ) - async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o): + async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.ModifyWorkflowProperties, + ): if o.HasField("upserted_memo"): await self._visit_temporal_api_common_v1_Memo(fs, o.upserted_memo) - async def _visit_coresdk_workflow_commands_UpdateResponse(self, fs, o): + async def _visit_coresdk_workflow_commands_UpdateResponse( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.UpdateResponse, + ): if o.HasField("rejected"): await self._visit_temporal_api_failure_v1_Failure(fs, o.rejected) elif o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) - async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleNexusOperation( + self, + fs: VisitorFunctions, + o: 
temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.ScheduleNexusOperation, + ): if o.HasField("input"): - await self._visit_temporal_api_common_v1_Payload(fs, o.input) + await self._visit_system_nexus_payload(fs, o.service, o.operation, o.input) - async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): + async def _visit_coresdk_workflow_commands_WorkflowCommand( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_commands.workflow_commands_pb2.WorkflowCommand, + ): if o.HasField("user_metadata"): await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) if o.HasField("schedule_activity"): @@ -427,16 +605,26 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): fs, o.schedule_nexus_operation ) - async def _visit_coresdk_workflow_completion_Success(self, fs, o): + async def _visit_coresdk_workflow_completion_Success( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_completion.workflow_completion_pb2.Success, + ): for v in o.commands: await self._visit_coresdk_workflow_commands_WorkflowCommand(fs, v) - async def _visit_coresdk_workflow_completion_Failure(self, fs, o): + async def _visit_coresdk_workflow_completion_Failure( + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_completion.workflow_completion_pb2.Failure, + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( - self, fs, o + self, + fs: VisitorFunctions, + o: temporalio.bridge.proto.workflow_completion.workflow_completion_pb2.WorkflowActivationCompletion, ): if o.HasField("successful"): await self._visit_coresdk_workflow_completion_Success(fs, o.successful) diff --git a/temporalio/bridge/_visitor_functions.py b/temporalio/bridge/_visitor_functions.py new file mode 100644 index 000000000..b2d5be153 --- /dev/null +++ b/temporalio/bridge/_visitor_functions.py @@ -0,0 +1,25 
@@ +from __future__ import annotations + +from typing import Protocol + +from google.protobuf.internal.containers import RepeatedCompositeFieldContainer + +from temporalio.api.common.v1.message_pb2 import Payload + +PayloadSequence = list[Payload] | RepeatedCompositeFieldContainer[Payload] + + +class VisitorFunctions(Protocol): + """Functions invoked by generated payload visitors.""" + + async def visit_payload(self, payload: Payload) -> None: + """Visit a single payload.""" + ... + + async def visit_payloads(self, payloads: PayloadSequence) -> None: + """Visit a sequence of payloads together.""" + ... + + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + """Visit a recognized system Nexus envelope payload.""" + ... diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index c2e426d28..b0193e2fc 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -5,11 +5,9 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, MutableSequence, Sequence +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass -from typing import ( - TypeAlias, -) +from typing import TypeAlias import temporalio.bridge.client import temporalio.bridge.proto @@ -21,7 +19,7 @@ import temporalio.bridge.temporal_sdk_bridge import temporalio.converter from temporalio.api.common.v1.message_pb2 import Payload -from temporalio.bridge._visitor import VisitorFunctions +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) @@ -279,15 +277,20 @@ async def finalize_shutdown(self) -> None: class _Visitor(VisitorFunctions): - def __init__(self, f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]]): + def __init__( + self, + f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]], + visit_system_nexus_envelope: Callable[[Payload], Awaitable[None]] 
| None = None, + ): self._f = f + self._visit_system_nexus_envelope = visit_system_nexus_envelope async def visit_payload(self, payload: Payload) -> None: new_payload = (await self._f([payload]))[0] if new_payload is not payload: payload.CopyFrom(new_payload) - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + async def visit_payloads(self, payloads: PayloadSequence) -> None: if len(payloads) == 0: return new_payloads = await self._f(payloads) @@ -296,6 +299,10 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: del payloads[:] payloads.extend(new_payloads) + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + if self._visit_system_nexus_envelope is not None: + await self._visit_system_nexus_envelope(payload) + async def decode_activation( activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, @@ -314,6 +321,16 @@ async def encode_completion( encode_headers: bool, ) -> None: """Encode all payloads in the completion.""" + + async def visit_system_nexus_envelope(payload: Payload) -> None: + data_converter._validate_payload_limits([payload]) + await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers - ).visit(_Visitor(data_converter._encode_payload_sequence), completion) + ).visit( + _Visitor( + data_converter._encode_payload_sequence, + visit_system_nexus_envelope=visit_system_nexus_envelope, + ), + completion, + ) diff --git a/temporalio/client.py b/temporalio/client.py index 22b07b1c1..9bd6b440a 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -39,6 +39,7 @@ from google.protobuf.internal.containers import MessageMap from typing_extensions import Required, Self, TypedDict +import temporalio._workflow_requests import temporalio.activity import temporalio.api.activity.v1 import temporalio.api.common.v1 @@ -8083,15 +8084,51 @@ async def _build_signal_with_start_workflow_execution_request( workflow_id=input.id, ) ) - req = 
temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( - signal_name=input.start_signal - ) - if input.start_signal_args: - req.signal_input.payloads.extend( - await data_converter.encode(input.start_signal_args) + request_memo = None + if input.memo is not None: + request_memo = temporalio.api.common.v1.Memo() + await data_converter._encode_memo_existing(input.memo, request_memo) + request_search_attributes = None + if input.search_attributes is not None: + request_search_attributes = temporalio.api.common.v1.SearchAttributes() + temporalio.converter.encode_search_attributes( + input.search_attributes, request_search_attributes ) - await self._populate_start_workflow_execution_request(req, input) - return req + request_header = None + if input.headers is not None: # type:ignore[reportUnnecessaryComparison] + request_header = temporalio.api.common.v1.Header() + await self._apply_headers(input.headers, request_header.fields) + return temporalio._workflow_requests.build_signal_with_start_workflow_execution_request( + namespace=self._client.namespace, + workflow_id=input.id, + workflow=input.workflow, + task_queue=input.task_queue, + signal_name=input.start_signal, + workflow_input_payloads=await data_converter.encode(input.args) + if input.args + else (), + signal_input_payloads=await data_converter.encode(input.start_signal_args) + if input.start_signal_args + else (), + execution_timeout=input.execution_timeout, + run_timeout=input.run_timeout, + task_timeout=input.task_timeout, + identity=self._client.identity, + request_id=str(uuid.uuid4()), + id_reuse_policy=input.id_reuse_policy, + id_conflict_policy=input.id_conflict_policy, + retry_policy=input.retry_policy, + cron_schedule=input.cron_schedule, + memo=request_memo, + search_attributes=request_search_attributes, + header=request_header, + user_metadata=await _encode_user_metadata( + data_converter, input.static_summary, input.static_details + ), + start_delay=input.start_delay, + 
priority=input.priority, + versioning_override=input.versioning_override, + ) async def _build_update_with_start_start_workflow_execution_request( self, input: UpdateWithStartStartWorkflowInput @@ -8114,51 +8151,48 @@ async def _populate_start_workflow_execution_request( workflow_id=input.id, ) ) - req.namespace = self._client.namespace - req.workflow_id = input.id - req.workflow_type.name = input.workflow - req.task_queue.name = input.task_queue - if input.args: - req.input.payloads.extend(await data_converter.encode(input.args)) - if input.execution_timeout is not None: - req.workflow_execution_timeout.FromTimedelta(input.execution_timeout) - if input.run_timeout is not None: - req.workflow_run_timeout.FromTimedelta(input.run_timeout) - if input.task_timeout is not None: - req.workflow_task_timeout.FromTimedelta(input.task_timeout) - req.identity = self._client.identity - req.request_id = str(uuid.uuid4()) - req.workflow_id_reuse_policy = cast( - "temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", - int(input.id_reuse_policy), - ) - req.workflow_id_conflict_policy = cast( - "temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType", - int(input.id_conflict_policy), - ) - - if input.retry_policy is not None: - input.retry_policy.apply_to_proto(req.retry_policy) - req.cron_schedule = input.cron_schedule + request_memo = None if input.memo is not None: - await data_converter._encode_memo_existing(input.memo, req.memo) + request_memo = temporalio.api.common.v1.Memo() + await data_converter._encode_memo_existing(input.memo, request_memo) + request_search_attributes = None if input.search_attributes is not None: + request_search_attributes = temporalio.api.common.v1.SearchAttributes() temporalio.converter.encode_search_attributes( - input.search_attributes, req.search_attributes + input.search_attributes, request_search_attributes ) - metadata = await _encode_user_metadata( - data_converter, input.static_summary, input.static_details - ) - if metadata is not 
None: - req.user_metadata.CopyFrom(metadata) - if input.start_delay is not None: - req.workflow_start_delay.FromTimedelta(input.start_delay) + request_header = None if input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await self._apply_headers(input.headers, req.header.fields) - if input.priority is not None: # type:ignore[reportUnnecessaryComparison] - req.priority.CopyFrom(input.priority._to_proto()) - if input.versioning_override is not None: - req.versioning_override.CopyFrom(input.versioning_override._to_proto()) + request_header = temporalio.api.common.v1.Header() + await self._apply_headers(input.headers, request_header.fields) + temporalio._workflow_requests.populate_start_workflow_execution_request( + req, + namespace=self._client.namespace, + workflow_id=input.id, + workflow=input.workflow, + task_queue=input.task_queue, + input_payloads=await data_converter.encode(input.args) + if input.args + else (), + execution_timeout=input.execution_timeout, + run_timeout=input.run_timeout, + task_timeout=input.task_timeout, + identity=self._client.identity, + request_id=str(uuid.uuid4()), + id_reuse_policy=input.id_reuse_policy, + id_conflict_policy=input.id_conflict_policy, + retry_policy=input.retry_policy, + cron_schedule=input.cron_schedule, + memo=request_memo, + search_attributes=request_search_attributes, + header=request_header, + user_metadata=await _encode_user_metadata( + data_converter, input.static_summary, input.static_details + ), + start_delay=input.start_delay, + priority=input.priority, + versioning_override=input.versioning_override, + ) async def cancel_workflow(self, input: CancelWorkflowInput) -> None: await self._client.workflow_service.request_cancel_workflow_execution( diff --git a/temporalio/contrib/opentelemetry/_interceptor.py b/temporalio/contrib/opentelemetry/_interceptor.py index 69a2cfb0c..7e19749c9 100644 --- a/temporalio/contrib/opentelemetry/_interceptor.py +++ 
b/temporalio/contrib/opentelemetry/_interceptor.py @@ -772,6 +772,16 @@ async def signal_external_workflow( ) await super().signal_external_workflow(input) + async def signal_with_start_workflow( + self, input: temporalio.worker.SignalWithStartWorkflowInput + ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: + self.root._completed_span( + f"SignalWithStartWorkflow:{input.signal}", + add_to_outbound=input, + kind=opentelemetry.trace.SpanKind.CLIENT, + ) + return await super().signal_with_start_workflow(input) + def start_activity( self, input: temporalio.worker.StartActivityInput ) -> temporalio.workflow.ActivityHandle: diff --git a/temporalio/contrib/opentelemetry/_otel_interceptor.py b/temporalio/contrib/opentelemetry/_otel_interceptor.py index 1756f93e1..f60473c63 100644 --- a/temporalio/contrib/opentelemetry/_otel_interceptor.py +++ b/temporalio/contrib/opentelemetry/_otel_interceptor.py @@ -545,6 +545,16 @@ async def signal_external_workflow( input.headers = _context_to_headers(input.headers) await super().signal_external_workflow(input) + async def signal_with_start_workflow( + self, input: temporalio.worker.SignalWithStartWorkflowInput + ) -> temporalio.workflow.ExternalWorkflowHandle[Any]: + with self._workflow_maybe_span( + f"SignalWithStartWorkflow:{input.signal}", + kind=opentelemetry.trace.SpanKind.CLIENT, + ): + input.headers = _context_to_headers(input.headers) + return await super().signal_with_start_workflow(input) + def start_activity( self, input: temporalio.worker.StartActivityInput ) -> temporalio.workflow.ActivityHandle: diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py new file mode 100644 index 000000000..cf51e9714 --- /dev/null +++ b/temporalio/nexus/system/__init__.py @@ -0,0 +1,159 @@ +"""Generated system Nexus service models.""" + +from collections.abc import Mapping, Sequence +from datetime import timedelta +from typing import Any, cast + +import temporalio.api.common.v1 +import 
import temporalio.api.workflowservice.v1
import temporalio.common
import temporalio.converter

from ... import _workflow_requests
from ...bridge._visitor_functions import VisitorFunctions
from ...converter import BinaryProtoPayloadConverter, CompositePayloadConverter
from . import _workflow_service_generated as generated
from ._payload_visitor import PayloadVisitor

# Reserved endpoint name used for system Nexus calls (server-recognized).
_SYSTEM_NEXUS_ENDPOINT = "__temporal_system"


class SystemNexusPayloadConverter(CompositePayloadConverter):
    """Payload converter for system Nexus outer envelopes."""

    def __init__(self) -> None:
        """Create a payload converter for system Nexus outer envelopes."""
        # Only binary-proto conversion: the outer envelope is always a proto
        # request/response message and must not pass through user converters.
        super().__init__(BinaryProtoPayloadConverter())


def _build_user_metadata(
    payload_converter: temporalio.converter.PayloadConverter,
    static_summary: str | None,
    static_details: str | None,
) -> temporalio.api.sdk.v1.UserMetadata | None:
    # Returns None when there is nothing to attach so callers can leave the
    # user_metadata field unset on the request.
    if static_summary is None and static_details is None:
        return None
    metadata = temporalio.api.sdk.v1.UserMetadata()
    if static_summary is not None:
        metadata.summary.CopyFrom(payload_converter.to_payload(static_summary))
    if static_details is not None:
        metadata.details.CopyFrom(payload_converter.to_payload(static_details))
    return metadata


def build_signal_with_start_workflow_execution_input(
    *,
    namespace: str,
    workflow_id: str,
    workflow: str,
    workflow_args: Sequence[Any],
    signal: str,
    signal_args: Sequence[Any],
    task_queue: str,
    request_id: str | None,
    payload_converter: temporalio.converter.PayloadConverter,
    headers: Mapping[str, temporalio.api.common.v1.Payload],
    execution_timeout: timedelta | None = None,
    run_timeout: timedelta | None = None,
    task_timeout: timedelta | None = None,
    id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE,
    id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED,
    retry_policy: temporalio.common.RetryPolicy | None = None,
    cron_schedule: str = "",
    memo: Mapping[str, Any] | None = None,
    search_attributes: temporalio.common.TypedSearchAttributes | None = None,
    static_summary: str | None = None,
    static_details: str | None = None,
    start_delay: timedelta | None = None,
    priority: temporalio.common.Priority = temporalio.common.Priority.default,
    versioning_override: temporalio.common.VersioningOverride | None = None,
) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest:
    """Build the system Nexus signal-with-start request.

    All user-supplied values (args, memo, summary/details) are converted with
    the given payload converter before being placed on the proto request.
    """
    # Memo values are individually converted into payloads keyed by memo key.
    request_memo = None
    if memo is not None:
        request_memo = temporalio.api.common.v1.Memo()
        for key, value in memo.items():
            request_memo.fields[key].CopyFrom(payload_converter.to_payload(value))
    # Typed search attributes are encoded via the converter module's helper,
    # not the generic payload converter.
    request_search_attributes = None
    if search_attributes is not None:
        request_search_attributes = temporalio.api.common.v1.SearchAttributes()
        temporalio.converter.encode_search_attributes(
            search_attributes, request_search_attributes
        )
    # Headers arrive already converted to payloads; just copy them over.
    request_header = None
    if headers:
        request_header = temporalio.api.common.v1.Header()
        temporalio.common._apply_headers(headers, request_header.fields)
    return _workflow_requests.build_signal_with_start_workflow_execution_request(
        namespace=namespace,
        workflow_id=workflow_id,
        workflow=workflow,
        task_queue=task_queue,
        signal_name=signal,
        # Empty arg sequences map to no payloads at all rather than an empty
        # Payloads message.
        workflow_input_payloads=payload_converter.to_payloads(workflow_args)
        if workflow_args
        else (),
        signal_input_payloads=payload_converter.to_payloads(signal_args)
        if signal_args
        else (),
        execution_timeout=execution_timeout,
        run_timeout=run_timeout,
        task_timeout=task_timeout,
        request_id=request_id,
        id_reuse_policy=id_reuse_policy,
        id_conflict_policy=id_conflict_policy,
        retry_policy=retry_policy,
        cron_schedule=cron_schedule,
        memo=request_memo,
        search_attributes=request_search_attributes,
        header=request_header,
        user_metadata=_build_user_metadata(
            payload_converter, static_summary, static_details
        ),
        start_delay=start_delay,
        priority=priority,
        versioning_override=versioning_override,
    )


async def visit_payload(
    service: str,
    operation: str,
    payload: temporalio.api.common.v1.Payload,
    visitor_functions: VisitorFunctions,
    skip_search_attributes: bool,
) -> temporalio.api.common.v1.Payload | None:
    """Visit nested payloads inside a recognized system Nexus envelope.

    Returns:
        The re-serialized envelope payload after visiting its nested
        payloads, or ``None`` when the (service, operation) pair is not a
        known system operation (or its input type is not a concrete class).
    """
    operation_def = generated.__nexus_operation_registry__.get((service, operation))
    if operation_def is None:
        return None
    input_type = operation_def.input_type
    # Guard against generic/alias input types that can't be deserialized into.
    if not isinstance(input_type, type):
        return None
    payload_converter = get_payload_converter()
    value = payload_converter.from_payload(payload, input_type)
    # Visit nested payloads in place on the deserialized proto message.
    await PayloadVisitor(skip_search_attributes=skip_search_attributes).visit(
        cast(Any, visitor_functions), value
    )
    return payload_converter.to_payload(value)


def is_system_operation(service: str, operation: str) -> bool:
    """Return whether a Nexus operation uses the generated system envelope."""
    return (service, operation) in generated.__nexus_operation_registry__


def get_payload_converter() -> temporalio.converter.PayloadConverter:
    """Return the fixed payload converter for system Nexus outer envelopes."""
    return SystemNexusPayloadConverter()


__all__ = (
    "build_signal_with_start_workflow_execution_input",
    "generated",
    "get_payload_converter",
    "is_system_operation",
    "SystemNexusPayloadConverter",
    "visit_payload",
)
from google.protobuf.message import Message

import temporalio.api.common.v1.message_pb2
import temporalio.api.sdk.v1.user_metadata_pb2
import temporalio.api.workflowservice.v1.request_response_pb2
import temporalio.nexus.system
from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes
from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions


class PayloadVisitor:
    """A visitor for payloads.
    Applies a function to every payload in a tree of messages.
    """

    def __init__(
        self, *, skip_search_attributes: bool = False, skip_headers: bool = False
    ):
        """Create a new payload visitor.

        Args:
            skip_search_attributes: When true, search attribute payloads are
                not visited.
            skip_headers: Stored on the instance; no visible use in this
                generated subset (kept for interface parity with the full
                generated visitor).
        """
        self.skip_search_attributes = skip_search_attributes
        self.skip_headers = skip_headers

    async def visit(self, fs: VisitorFunctions, root: Message) -> None:
        """Visit the given root message with the given function set.

        Dispatches by the message's full proto name; raises ``ValueError``
        for message types this generated visitor does not know about.
        """
        method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
        method = getattr(self, method_name, None)
        if method is not None:
            await method(fs, root)
        else:
            raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")

    async def _visit_system_nexus_payload(
        self,
        fs: VisitorFunctions,
        service: str,
        operation: str,
        payload: Payload,
    ) -> None:
        # Try to treat the payload as a system Nexus envelope; visit_payload
        # returns None when (service, operation) is not a system operation.
        new_payload = await temporalio.nexus.system.visit_payload(
            service,
            operation,
            payload,
            fs,
            self.skip_search_attributes,
        )
        if new_payload is None:
            # Not a recognized envelope: visit as an ordinary payload.
            await self._visit_temporal_api_common_v1_Payload(fs, payload)
            return

        # Copy the visited envelope back in place only when a new message was
        # produced, then notify the envelope-level visitor function.
        if new_payload is not payload:
            payload.CopyFrom(new_payload)
        await fs.visit_system_nexus_envelope(payload)

    async def _visit_temporal_api_common_v1_Payload(
        self, fs: VisitorFunctions, o: Payload
    ):
        await fs.visit_payload(o)

    async def _visit_temporal_api_common_v1_Payloads(
        self, fs: VisitorFunctions, o: Payloads
    ):
        await fs.visit_payloads(o.payloads)

    async def _visit_payload_container(self, fs: VisitorFunctions, o: PayloadSequence):
        await fs.visit_payloads(o)

    async def _visit_temporal_api_common_v1_Memo(
        self, fs: VisitorFunctions, o: temporalio.api.common.v1.message_pb2.Memo
    ):
        for v in o.fields.values():
            await self._visit_temporal_api_common_v1_Payload(fs, v)

    async def _visit_temporal_api_common_v1_SearchAttributes(
        self, fs: VisitorFunctions, o: SearchAttributes
    ):
        # Search attributes may be intentionally skipped (e.g. they are
        # server-interpreted and not codec-encoded).
        if self.skip_search_attributes:
            return
        for v in o.indexed_fields.values():
            await self._visit_temporal_api_common_v1_Payload(fs, v)

    async def _visit_temporal_api_common_v1_Header(
        self, fs: VisitorFunctions, o: temporalio.api.common.v1.message_pb2.Header
    ):
        for v in o.fields.values():
            await self._visit_temporal_api_common_v1_Payload(fs, v)

    async def _visit_temporal_api_sdk_v1_UserMetadata(
        self,
        fs: VisitorFunctions,
        o: temporalio.api.sdk.v1.user_metadata_pb2.UserMetadata,
    ):
        if o.HasField("summary"):
            await self._visit_temporal_api_common_v1_Payload(fs, o.summary)
        if o.HasField("details"):
            await self._visit_temporal_api_common_v1_Payload(fs, o.details)

    async def _visit_temporal_api_workflowservice_v1_SignalWithStartWorkflowExecutionRequest(
        self,
        fs: VisitorFunctions,
        o: temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionRequest,
    ):
        # Visit every payload-bearing field of the signal-with-start request.
        if o.HasField("input"):
            await self._visit_temporal_api_common_v1_Payloads(fs, o.input)
        if o.HasField("signal_input"):
            await self._visit_temporal_api_common_v1_Payloads(fs, o.signal_input)
        if o.HasField("memo"):
            await self._visit_temporal_api_common_v1_Memo(fs, o.memo)
        if o.HasField("search_attributes"):
            await self._visit_temporal_api_common_v1_SearchAttributes(
                fs, o.search_attributes
            )
        if o.HasField("header"):
            await self._visit_temporal_api_common_v1_Header(fs, o.header)
        if o.HasField("user_metadata"):
            await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata)
# Generated by nexus-rpc-gen. DO NOT EDIT!

from __future__ import annotations

import typing

from nexusrpc import Operation, service

import temporalio.api.workflowservice.v1


@service
class WorkflowService:
    # System Nexus service mirroring the WorkflowService RPC of the same name.
    signal_with_start_workflow_execution: Operation[
        temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest,
        temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse,
    ] = Operation(name="SignalWithStartWorkflowExecution")


# Lookup from (service name, operation name) to the operation definition;
# consumers use membership here to detect system Nexus operations.
__nexus_operation_registry__: dict[
    tuple[str, str], Operation[typing.Any, typing.Any]
] = {
    (
        "WorkflowService",
        "SignalWithStartWorkflowExecution",
    ): WorkflowService.signal_with_start_workflow_execution,
}
@dataclass
class SignalWithStartWorkflowInput:
    """Input for :py:meth:`WorkflowOutboundInterceptor.signal_with_start_workflow`."""

    # Signal name and its converted-later arguments.
    signal: str
    signal_args: Sequence[Any]
    # Target workflow identity and start parameters.
    namespace: str
    workflow_id: str
    workflow: str
    workflow_args: Sequence[Any]
    task_queue: str
    # Timeouts; None means server default / unset.
    execution_timeout: timedelta | None
    run_timeout: timedelta | None
    task_timeout: timedelta | None
    # Conflict/reuse handling when the workflow ID already exists.
    id_reuse_policy: temporalio.common.WorkflowIDReusePolicy
    id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy
    retry_policy: temporalio.common.RetryPolicy | None
    cron_schedule: str
    memo: Mapping[str, Any] | None
    search_attributes: temporalio.common.TypedSearchAttributes | None
    # Fixed UI/CLI metadata for the started execution.
    static_summary: str | None
    static_details: str | None
    start_delay: timedelta | None
    # Optional idempotency key for the start request.
    request_id: str | None
    priority: temporalio.common.Priority
    versioning_override: temporalio.common.VersioningOverride | None
    # Already-converted header payloads to send with the request.
    headers: Mapping[str, temporalio.api.common.v1.Payload]
    async def workflow_signal_with_start_workflow(
        self,
        workflow_id: str,
        signal: str | Callable,
        workflow: Any,
        *,
        signal_args: Sequence[Any],
        workflow_args: Sequence[Any],
        task_queue: str,
        execution_timeout: timedelta | None,
        run_timeout: timedelta | None,
        task_timeout: timedelta | None,
        id_reuse_policy: temporalio.common.WorkflowIDReusePolicy,
        id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy,
        retry_policy: temporalio.common.RetryPolicy | None,
        cron_schedule: str,
        memo: Mapping[str, Any] | None,
        search_attributes: temporalio.common.TypedSearchAttributes | None,
        static_summary: str | None,
        static_details: str | None,
        start_delay: timedelta | None,
        request_id: str | None,
        priority: temporalio.common.Priority,
        versioning_override: temporalio.common.VersioningOverride | None,
    ) -> temporalio.workflow.ExternalWorkflowHandle[Any]:
        # Runtime entry point for workflow.signal_with_start_workflow: resolve
        # names, then route through the outbound interceptor chain.
        self._assert_not_read_only("signal with start workflow")
        workflow_name, _ = temporalio.workflow._Definition.get_name_and_result_type(
            workflow
        )
        return await self._outbound.signal_with_start_workflow(
            SignalWithStartWorkflowInput(
                # Accept either a signal name string or a decorated method.
                signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str(
                    signal
                ),
                signal_args=signal_args,
                namespace=self._info.namespace,
                workflow_id=workflow_id,
                workflow=workflow_name,
                workflow_args=workflow_args,
                task_queue=task_queue,
                execution_timeout=execution_timeout,
                run_timeout=run_timeout,
                task_timeout=task_timeout,
                id_reuse_policy=id_reuse_policy,
                id_conflict_policy=id_conflict_policy,
                retry_policy=retry_policy,
                cron_schedule=cron_schedule,
                memo=memo,
                search_attributes=search_attributes,
                static_summary=static_summary,
                static_details=static_details,
                start_delay=start_delay,
                request_id=request_id,
                priority=priority,
                versioning_override=versioning_override,
                headers={},
            )
        )

    async def _outbound_signal_with_start_workflow(
        self, input: SignalWithStartWorkflowInput
    ) -> temporalio.workflow.ExternalWorkflowHandle[Any]:
        # Terminal outbound implementation: encode the request with a
        # workflow-scoped serialization context, then execute it as a system
        # Nexus operation against the reserved system endpoint.
        payload_converter = self._payload_converter_with_context(
            temporalio.converter.WorkflowSerializationContext(
                namespace=input.namespace,
                workflow_id=input.workflow_id,
            )
        )
        request = (
            temporalio.nexus.system.build_signal_with_start_workflow_execution_input(
                namespace=input.namespace,
                workflow_id=input.workflow_id,
                workflow=input.workflow,
                workflow_args=input.workflow_args,
                signal=input.signal,
                signal_args=input.signal_args,
                task_queue=input.task_queue,
                request_id=input.request_id,
                payload_converter=payload_converter,
                execution_timeout=input.execution_timeout,
                run_timeout=input.run_timeout,
                task_timeout=input.task_timeout,
                id_reuse_policy=input.id_reuse_policy,
                id_conflict_policy=input.id_conflict_policy,
                retry_policy=input.retry_policy,
                cron_schedule=input.cron_schedule,
                memo=input.memo,
                search_attributes=input.search_attributes,
                static_summary=input.static_summary,
                static_details=input.static_details,
                start_delay=input.start_delay,
                priority=input.priority,
                versioning_override=input.versioning_override,
                headers=input.headers,
            )
        )
        handle = await self._outbound_start_nexus_operation(
            StartNexusOperationInput(
                endpoint=temporalio.nexus.system._SYSTEM_NEXUS_ENDPOINT,
                service=temporalio.nexus.system.generated.WorkflowService.__name__,
                operation=temporalio.nexus.system.generated.WorkflowService.signal_with_start_workflow_execution.name,
                input=request,
                schedule_to_close_timeout=None,
                schedule_to_start_timeout=None,
                start_to_close_timeout=None,
                cancellation_type=temporalio.workflow.NexusOperationCancellationType.WAIT_COMPLETED,
                headers=None,
                summary=None,
                output_type=temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse,
            )
        )
        # Wait for the operation result, then hand back an external workflow
        # handle for the (possibly newly started) run.
        result = await handle
        return self.workflow_get_external_workflow_handle(
            input.workflow_id, run_id=result.run_id
        )
instance self._seq = seq @@ -3311,7 +3449,7 @@ def __init__( self._task = asyncio.Task(fn) self._start_fut: asyncio.Future[str | None] = instance.create_future() self._result_fut: asyncio.Future[OutputT | None] = instance.create_future() - self._payload_converter = self._instance._context_free_payload_converter + self._payload_converter = payload_converter self._failure_converter = self._instance._context_free_failure_converter @property @@ -3345,14 +3483,13 @@ def _resolve_failure(self, err: BaseException) -> None: self._result_fut.set_result(None) def _apply_schedule_command(self) -> None: - payload = self._payload_converter.to_payload(self._input.input) command = self._instance._add_command() v = command.schedule_nexus_operation v.seq = self._seq v.endpoint = self._input.endpoint v.service = self._input.service v.operation = self._input.operation_name - v.input.CopyFrom(payload) + v.input.CopyFrom(self._payload_converter.to_payload(self._input.input)) if self._input.schedule_to_close_timeout is not None: v.schedule_to_close_timeout.FromTimedelta( self._input.schedule_to_close_timeout diff --git a/temporalio/workflow.py b/temporalio/workflow.py index dd8565f78..b8106cec2 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -847,6 +847,33 @@ async def workflow_start_child_workflow( priority: temporalio.common.Priority = temporalio.common.Priority.default, ) -> ChildWorkflowHandle[Any, Any]: ... 
@overload
async def signal_with_start_workflow(
    workflow_id: str,
    signal: str | Callable,
    workflow: MethodAsyncNoParam[SelfType, Any]
    | MethodAsyncSingleParam[SelfType, Any, Any]
    | Callable[Concatenate[SelfType, MultiParamSpec], Awaitable[Any]],
    *,
    signal_args: Sequence[Any] = [],
    workflow_args: Sequence[Any] = [],
    task_queue: str,
    execution_timeout: timedelta | None = None,
    run_timeout: timedelta | None = None,
    task_timeout: timedelta | None = None,
    id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE,
    id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED,
    retry_policy: temporalio.common.RetryPolicy | None = None,
    cron_schedule: str = "",
    memo: Mapping[str, Any] | None = None,
    search_attributes: temporalio.common.TypedSearchAttributes | None = None,
    static_summary: str | None = None,
    static_details: str | None = None,
    start_delay: timedelta | None = None,
    request_id: str | None = None,
    priority: temporalio.common.Priority = temporalio.common.Priority.default,
    versioning_override: temporalio.common.VersioningOverride | None = None,
) -> ExternalWorkflowHandle[SelfType]: ...


@overload
async def signal_with_start_workflow(
    workflow_id: str,
    signal: str | Callable,
    workflow: str,
    *,
    signal_args: Sequence[Any] = [],
    workflow_args: Sequence[Any] = [],
    task_queue: str,
    execution_timeout: timedelta | None = None,
    run_timeout: timedelta | None = None,
    task_timeout: timedelta | None = None,
    id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE,
    id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED,
    retry_policy: temporalio.common.RetryPolicy | None = None,
    cron_schedule: str = "",
    memo: Mapping[str, Any] | None = None,
    search_attributes: temporalio.common.TypedSearchAttributes | None = None,
    static_summary: str | None = None,
    static_details: str | None = None,
    start_delay: timedelta | None = None,
    request_id: str | None = None,
    priority: temporalio.common.Priority = temporalio.common.Priority.default,
    versioning_override: temporalio.common.VersioningOverride | None = None,
) -> ExternalWorkflowHandle[Any]: ...


async def signal_with_start_workflow(
    workflow_id: str,
    signal: str | Callable,
    workflow: str | Callable[..., Awaitable[Any]],
    *,
    signal_args: Sequence[Any] = [],
    workflow_args: Sequence[Any] = [],
    task_queue: str,
    execution_timeout: timedelta | None = None,
    run_timeout: timedelta | None = None,
    task_timeout: timedelta | None = None,
    id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE,
    id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED,
    retry_policy: temporalio.common.RetryPolicy | None = None,
    cron_schedule: str = "",
    memo: Mapping[str, Any] | None = None,
    search_attributes: temporalio.common.TypedSearchAttributes | None = None,
    static_summary: str | None = None,
    static_details: str | None = None,
    start_delay: timedelta | None = None,
    request_id: str | None = None,
    priority: temporalio.common.Priority = temporalio.common.Priority.default,
    versioning_override: temporalio.common.VersioningOverride | None = None,
) -> ExternalWorkflowHandle[Any]:
    """Signal a workflow, or start it and signal it if it is not running.

    This uses the system Nexus ``SignalWithStartWorkflowExecution`` operation
    under the hood.

    Args:
        workflow_id: Workflow ID to signal or start.
        signal: Name or method reference for the signal.
        workflow: String name or class method decorated with ``@workflow.run``
            for the workflow to start.
        signal_args: Arguments to the signal.
        workflow_args: Arguments to the workflow.
        task_queue: Task queue to run the workflow on if it is started.
        execution_timeout: Total workflow execution timeout including
            retries and continue as new.
        run_timeout: Timeout of a single workflow run.
        task_timeout: Timeout of a single workflow task.
        id_reuse_policy: How already-existing IDs are treated.
        id_conflict_policy: How already-running IDs are treated.
        retry_policy: Retry policy for the workflow.
        cron_schedule: See https://docs.temporal.io/docs/content/what-is-a-temporal-cron-job/
        memo: Memo for the workflow.
        search_attributes: Typed search attributes for the workflow.
        static_summary: A single-line fixed summary for this workflow
            execution that may appear in the UI/CLI.
        static_details: General fixed details for this workflow execution
            that may appear in UI/CLI.
        start_delay: Time to wait before dispatching the first workflow task.
        request_id: Optional idempotency request ID for the start request.
        priority: Priority to use for this workflow.
        versioning_override: Versioning override to apply if the workflow is
            started.

    Returns:
        A handle for the resulting workflow run.
    """
    # Thin delegation: all behavior lives in the current runtime
    # implementation (e.g. the workflow instance's outbound interceptor path).
    return await _Runtime.current().workflow_signal_with_start_workflow(
        workflow_id,
        signal,
        workflow,
        signal_args=signal_args,
        workflow_args=workflow_args,
        task_queue=task_queue,
        execution_timeout=execution_timeout,
        run_timeout=run_timeout,
        task_timeout=task_timeout,
        id_reuse_policy=id_reuse_policy,
        id_conflict_policy=id_conflict_policy,
        retry_policy=retry_policy,
        cron_schedule=cron_schedule,
        memo=memo,
        search_attributes=search_attributes,
        static_summary=static_summary,
        static_details=static_details,
        start_delay=start_delay,
        request_id=request_id,
        priority=priority,
        versioning_override=versioning_override,
    )
from __future__ import annotations

import dataclasses
import uuid
from collections.abc import Sequence
from datetime import timedelta
from typing import Any, cast

import nexusrpc.handler
import pytest
from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.message import Message

import temporalio.api.common.v1
import temporalio.api.workflowservice.v1
import temporalio.converter
import temporalio.nexus.system as nexus_system
from temporalio import workflow
from temporalio.client import Client
from temporalio.converter import ExternalStorage, PayloadCodec
from temporalio.nexus.system import generated
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import (
    Interceptor,
    SignalWithStartWorkflowInput,
    Worker,
    WorkflowInboundInterceptor,
    WorkflowInterceptorClassInput,
    WorkflowOutboundInterceptor,
)
from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner
from tests.helpers.nexus import make_nexus_endpoint_name
from tests.test_extstore import InMemoryTestDriver

# Module-level capture lists populated by the interceptor and handler below;
# individual tests pop/assert and thereby reset them.
interceptor_traces: list[tuple[str, object]] = []
received_requests: list[
    temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest
] = []


@nexusrpc.handler.service_handler(service=generated.WorkflowService)
class WorkflowServicePayloadHandler:
    # Fake system WorkflowService handler: records the received request and
    # returns a deterministic run ID derived from the workflow ID.
    @nexusrpc.handler.sync_operation
    async def signal_with_start_workflow_execution(
        self,
        _ctx: nexusrpc.handler.StartOperationContext,
        request: temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest,
    ) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse:
        assert request.workflow_id == "system-nexus-workflow-id"
        assert request.signal_name == "test-signal"
        # Copy before recording: the incoming message may be reused/mutated.
        received_request = (
            temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest()
        )
        received_request.CopyFrom(request)
        received_requests.append(received_request)
        return (
            temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse(
                run_id=f"{request.workflow_id}-run"
            )
        )


@workflow.defn
class ExternalHandleSignalWithStartWorkflowCaller:
    # Workflow under test: performs a signal-with-start and returns the run ID
    # of the handle it got back.
    @workflow.run
    async def run(self, task_queue: str) -> str:
        started_handle = await workflow.signal_with_start_workflow(
            "system-nexus-workflow-id",
            "test-signal",
            "test-workflow",
            signal_args=["signal-input"],
            workflow_args=["workflow-input"],
            task_queue=task_queue,
            memo={"memo-key": "memo-value"},
            static_summary="summary-value",
            static_details="details-value",
        )
        return cast(str, started_handle.run_id)


class RejectOuterSystemNexusCodec(PayloadCodec):
    """Codec that fails loudly if the outer system Nexus envelope reaches it.

    Inner payloads are tagged with a ``test-codec`` metadata marker on encode
    so tests can verify which payloads actually went through the codec.
    """

    def __init__(self) -> None:
        self.encode_count = 0

    async def encode(
        self, payloads: Sequence[temporalio.api.common.v1.Payload]
    ) -> list[temporalio.api.common.v1.Payload]:
        encoded: list[temporalio.api.common.v1.Payload] = []
        for payload in payloads:
            # The outer envelope is a binary-proto SignalWithStart request; it
            # must be visited for inner payloads, never codec-encoded itself.
            if (
                payload.metadata.get("encoding") == b"binary/protobuf"
                and payload.metadata.get("messageType")
                == b"temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest"
            ):
                raise RuntimeError(
                    "outer system nexus envelope should not be codec encoded"
                )
            self.encode_count += 1
            encoded.append(
                temporalio.api.common.v1.Payload(
                    metadata={**payload.metadata, "test-codec": b"true"},
                    data=payload.data,
                )
            )
        return encoded

    async def decode(
        self, payloads: Sequence[temporalio.api.common.v1.Payload]
    ) -> list[temporalio.api.common.v1.Payload]:
        decoded: list[temporalio.api.common.v1.Payload] = []
        for payload in payloads:
            if (
                payload.metadata.get("encoding") == b"binary/protobuf"
                and payload.metadata.get("messageType")
                == b"temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest"
            ):
                raise RuntimeError(
                    "outer system nexus envelope should not be codec decoded"
                )
            decoded.append(payload)
        return decoded


class TracingWorkflowInterceptor(Interceptor):
    # Installs the tracing inbound interceptor for every workflow.
    def workflow_interceptor_class(
        self, input: WorkflowInterceptorClassInput
    ) -> type[WorkflowInboundInterceptor] | None:
        return _TracingWorkflowInboundInterceptor


class _TracingWorkflowInboundInterceptor(WorkflowInboundInterceptor):
    def init(self, outbound: WorkflowOutboundInterceptor) -> None:
        super().init(_TracingWorkflowOutboundInterceptor(outbound))


class _TracingWorkflowOutboundInterceptor(WorkflowOutboundInterceptor):
    # Records every signal_with_start_workflow call before delegating.
    async def signal_with_start_workflow(
        self, input: SignalWithStartWorkflowInput
    ) -> workflow.ExternalWorkflowHandle[object]:
        interceptor_traces.append(("workflow.signal_with_start_workflow", input))
        return await super().signal_with_start_workflow(input)


def _pop_received_request() -> (
    temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest
):
    # Expects exactly one recorded request; popping resets module state.
    assert len(received_requests) == 1
    return received_requests.pop()


def _assert_request_payload_was_externally_stored(
    request: temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest,
    field_name: str,
) -> None:
    # Verifies the named Payloads field holds one externally-stored payload.
    payloads = getattr(request, field_name).payloads
    assert len(payloads) == 1
    assert payloads[0].external_payloads


def _assert_request_user_metadata_was_externally_stored(
    request: temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest,
) -> None:
    assert request.HasField("user_metadata")
    assert request.user_metadata.summary.external_payloads
    assert request.user_metadata.details.external_payloads
_assert_stored_payloads_include( + driver: InMemoryTestDriver, expected_payload_data: set[bytes] +) -> None: + stored_payload_data: set[bytes] = set() + for stored_payload_bytes in driver._storage.values(): + stored_payload = temporalio.api.common.v1.Payload() + stored_payload.ParseFromString(stored_payload_bytes) + assert stored_payload.metadata["test-codec"] == b"true" + stored_payload_data.add(stored_payload.data) + assert expected_payload_data.issubset(stored_payload_data) + + +def _assert_signal_with_start_interceptor_trace() -> None: + assert len(interceptor_traces) == 1 + trace_name, trace_value = interceptor_traces.pop() + assert trace_name == "workflow.signal_with_start_workflow" + trace_input = cast(SignalWithStartWorkflowInput, trace_value) + assert trace_input.workflow_id == "system-nexus-workflow-id" + assert trace_input.signal == "test-signal" + assert trace_input.workflow == "test-workflow" + + +def _build_proto_sample(message_type: type[Message]) -> Message: + message = message_type() + _populate_proto_sample(message) + return message + + +def _populate_proto_sample(message: Message, *, path: str = "value") -> None: + seen_oneofs: set[str] = set() + for field in message.DESCRIPTOR.fields: + if field.containing_oneof is not None: + if field.containing_oneof.name in seen_oneofs: + continue + seen_oneofs.add(field.containing_oneof.name) + if field.label == FieldDescriptor.LABEL_REPEATED: + if ( + field.message_type is not None + and field.message_type.GetOptions().map_entry + ): + _populate_proto_map_entry(message, field, path=path) + elif field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + getattr(message, field.name).add(), + path=f"{path}.{field.name}[0]", + ) + else: + getattr(message, field.name).append( + _proto_scalar_sample(field, path=f"{path}.{field.name}[0]") + ) + elif field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + getattr(message, field.name), + path=f"{path}.{field.name}", + ) + 
else: + setattr( + message, + field.name, + _proto_scalar_sample(field, path=f"{path}.{field.name}"), + ) + + +def _populate_proto_map_entry( + message: Message, + field: FieldDescriptor, + *, + path: str, +) -> None: + key_field = field.message_type.fields_by_name["key"] + value_field = field.message_type.fields_by_name["value"] + key = _proto_scalar_sample(key_field, path=f"{path}.{field.name}.key") + container = getattr(message, field.name) + if value_field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + container[key], + path=f"{path}.{field.name}[{key!r}]", + ) + else: + container[key] = _proto_scalar_sample( + value_field, + path=f"{path}.{field.name}[{key!r}]", + ) + + +def _proto_scalar_sample(field: FieldDescriptor, *, path: str) -> Any: + if field.type == FieldDescriptor.TYPE_BYTES: + return b"test" + if field.cpp_type == FieldDescriptor.CPPTYPE_STRING: + return f"{path}-value" + if field.cpp_type == FieldDescriptor.CPPTYPE_BOOL: + return True + if field.cpp_type in ( + FieldDescriptor.CPPTYPE_INT32, + FieldDescriptor.CPPTYPE_INT64, + FieldDescriptor.CPPTYPE_UINT32, + FieldDescriptor.CPPTYPE_UINT64, + ): + return 1 + if field.cpp_type in ( + FieldDescriptor.CPPTYPE_FLOAT, + FieldDescriptor.CPPTYPE_DOUBLE, + ): + return 1.5 + if field.cpp_type == FieldDescriptor.CPPTYPE_ENUM: + for enum_value in field.enum_type.values: + if enum_value.number != 0: + return enum_value.number + return field.enum_type.values[0].number + raise TypeError(f"Unhandled proto scalar sample at {path}: {field!r}") + + +@pytest.mark.parametrize( + "message_type", + [ + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest, + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse, + ], +) +def test_system_nexus_proto_roundtrip(message_type: type[Message]) -> None: + payload_converter = nexus_system.get_payload_converter() + proto_value = _build_proto_sample(message_type) + payload = 
payload_converter.to_payload(proto_value) + assert payload is not None + assert payload.metadata["encoding"] == b"binary/protobuf" + assert payload.metadata["messageType"] == message_type.DESCRIPTOR.full_name.encode() + roundtripped = payload_converter.from_payload(payload, message_type) + assert isinstance(roundtripped, message_type) + assert roundtripped == proto_value + + +async def test_external_workflow_handle_signal_with_start_workflow_uses_system_nexus( + env: WorkflowEnvironment, + monkeypatch: pytest.MonkeyPatch, +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with the Java test server") + + codec = RejectOuterSystemNexusCodec() + interceptor_traces.clear() + received_requests.clear() + driver = InMemoryTestDriver() + caller_config = env.client.config() + caller_config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_codec=codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1, + ), + ) + caller_client = Client(**caller_config) + handler_config = env.client.config() + handler_config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_converter_class=nexus_system.SystemNexusPayloadConverter, + ) + handler_client = Client(**handler_config) + caller_task_queue = str(uuid.uuid4()) + handler_task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(handler_task_queue) + monkeypatch.setattr(nexus_system, "_SYSTEM_NEXUS_ENDPOINT", endpoint_name) + + caller_worker = Worker( + caller_client, + task_queue=caller_task_queue, + workflows=[ExternalHandleSignalWithStartWorkflowCaller], + workflow_runner=UnsandboxedWorkflowRunner(), + interceptors=[TracingWorkflowInterceptor()], + ) + handler_worker = Worker( + handler_client, + task_queue=handler_task_queue, + nexus_service_handlers=[WorkflowServicePayloadHandler()], + ) + + async with caller_worker, handler_worker: + await env.create_nexus_endpoint(endpoint_name, 
handler_task_queue) + result = await caller_client.execute_workflow( + ExternalHandleSignalWithStartWorkflowCaller.run, + args=[handler_task_queue], + id=str(uuid.uuid4()), + task_queue=caller_task_queue, + execution_timeout=timedelta(seconds=5), + ) + + assert result == "system-nexus-workflow-id-run" + request = _pop_received_request() + _assert_request_payload_was_externally_stored(request, "input") + _assert_request_payload_was_externally_stored(request, "signal_input") + _assert_request_user_metadata_was_externally_stored(request) + assert codec.encode_count >= 5 + _assert_stored_payloads_include( + driver, + { + b'"workflow-input"', + b'"signal-input"', + b'"memo-value"', + b'"summary-value"', + b'"details-value"', + }, + ) + _assert_signal_with_start_interceptor_trace() diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 5604b8542..efe9a36aa 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -1,10 +1,13 @@ import dataclasses -from collections.abc import MutableSequence +from typing import get_type_hints +import pytest from google.protobuf.duration_pb2 import Duration +from google.protobuf.message import Message import temporalio.bridge.worker import temporalio.converter +import temporalio.nexus.system as nexus_system from temporalio.api.common.v1.message_pb2 import ( Payload, Payloads, @@ -12,7 +15,13 @@ SearchAttributes, ) from temporalio.api.sdk.v1.user_metadata_pb2 import UserMetadata -from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions +from temporalio.api.workflowservice.v1.request_response_pb2 import ( + SignalWithStartWorkflowExecutionRequest, +) +from temporalio.bridge._visitor import ( + PayloadVisitor, +) +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( InitializeWorkflow, WorkflowActivation, @@ -22,6 +31,7 @@ ContinueAsNewWorkflowExecution, ScheduleActivity, 
ScheduleLocalActivity, + ScheduleNexusOperation, SignalExternalWorkflowExecution, StartChildWorkflowExecution, UpdateResponse, @@ -31,6 +41,13 @@ Success, WorkflowActivationCompletion, ) +from temporalio.converter._payload_limits import ( + _PayloadSizeError, + _ServerPayloadErrorLimits, +) +from temporalio.nexus.system._payload_visitor import ( + PayloadVisitor as SystemPayloadVisitor, +) from tests.worker.test_workflow import SimpleCodec @@ -38,10 +55,41 @@ class Visitor(VisitorFunctions): async def visit_payload(self, payload: Payload) -> None: payload.metadata["visited"] = b"True" - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + async def visit_payloads(self, payloads: PayloadSequence) -> None: for payload in payloads: payload.metadata["visited"] = b"True" + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + payload.metadata["visited"] = b"True" + + +def _json_plain_payload(value: object) -> Payload: + return temporalio.converter.default().payload_converter.to_payload(value) + + +def test_generated_payload_visitor_type_annotations(): + assert get_type_hints(PayloadVisitor.visit)["root"] is Message + assert ( + get_type_hints(PayloadVisitor._visit_temporal_api_common_v1_Payload)["o"] + is Payload + ) + assert ( + get_type_hints(PayloadVisitor._visit_temporal_api_common_v1_Payloads)["o"] + is Payloads + ) + assert ( + get_type_hints( + PayloadVisitor._visit_coresdk_workflow_activation_WorkflowActivation + )["o"] + is WorkflowActivation + ) + assert ( + get_type_hints( + SystemPayloadVisitor._visit_temporal_api_workflowservice_v1_SignalWithStartWorkflowExecutionRequest + )["o"] + is SignalWithStartWorkflowExecutionRequest + ) + async def test_workflow_activation_completion(): comp = WorkflowActivationCompletion( @@ -244,3 +292,82 @@ async def test_bridge_encoding(): assert sa.arguments[0].metadata["simple-codec"] assert cmd.user_metadata.summary.metadata["simple-codec"] + + +async def 
test_visit_system_nexus_payloads_on_schedule_nexus_operation(): + envelope = SignalWithStartWorkflowExecutionRequest( + namespace="default", + workflow_id="workflow-id", + signal_name="signal-name", + ) + envelope.input.payloads.extend([_json_plain_payload("input-value")]) + envelope.signal_input.payloads.extend([_json_plain_payload("signal-value")]) + envelope.memo.fields["memo-key"].CopyFrom(_json_plain_payload("memo-value")) + envelope.search_attributes.indexed_fields["search-key"].CopyFrom( + _json_plain_payload("search-value") + ) + comp = WorkflowActivationCompletion( + run_id="1", + successful=Success( + commands=[ + WorkflowCommand( + schedule_nexus_operation=ScheduleNexusOperation( + seq=1, + service="WorkflowService", + operation="SignalWithStartWorkflowExecution", + input=nexus_system.get_payload_converter().to_payload(envelope), + ) + ) + ], + ), + ) + + await PayloadVisitor(skip_search_attributes=True).visit(Visitor(), comp) + + input_payload = comp.successful.commands[0].schedule_nexus_operation.input + assert input_payload.metadata["visited"] + rewritten = nexus_system.get_payload_converter().from_payload( + input_payload, SignalWithStartWorkflowExecutionRequest + ) + assert rewritten.input.payloads[0].metadata["visited"] == b"True" + assert rewritten.signal_input.payloads[0].metadata["visited"] == b"True" + assert rewritten.memo.fields["memo-key"].metadata["visited"] == b"True" + assert ( + "visited" + not in rewritten.search_attributes.indexed_fields["search-key"].metadata + ) + + +async def test_bridge_encoding_checks_system_nexus_envelope_size(): + envelope = SignalWithStartWorkflowExecutionRequest( + namespace="default", + workflow_id="workflow-id", + signal_name="signal-name", + request_id="x" * 2048, + ) + envelope.input.payloads.extend([_json_plain_payload("input-value")]) + comp = WorkflowActivationCompletion( + run_id="1", + successful=Success( + commands=[ + WorkflowCommand( + schedule_nexus_operation=ScheduleNexusOperation( + seq=1, + 
service="WorkflowService", + operation="SignalWithStartWorkflowExecution", + input=nexus_system.get_payload_converter().to_payload(envelope), + ) + ) + ], + ), + ) + + data_converter = temporalio.converter.default()._with_payload_error_limits( + _ServerPayloadErrorLimits( + memo_size_error=1024 * 1024, + payload_size_error=512, + ) + ) + + with pytest.raises(_PayloadSizeError, match="payloads with size that exceeded"): + await temporalio.bridge.worker.encode_completion(comp, data_converter, True)