From cb0af366a0add702d5a45820c7e9591c75853a5b Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 14 Apr 2026 14:45:48 -0700 Subject: [PATCH 01/47] add langgraph plugin --- pyproject.toml | 2 + temporalio/contrib/langgraph/README.md | 137 +++++++++++++++ temporalio/contrib/langgraph/__init__.py | 19 +++ temporalio/contrib/langgraph/activity.py | 70 ++++++++ .../contrib/langgraph/langgraph_config.py | 62 +++++++ .../contrib/langgraph/langgraph_plugin.py | 142 ++++++++++++++++ temporalio/contrib/langgraph/task_cache.py | 75 +++++++++ tests/contrib/langgraph/__init__.py | 0 tests/contrib/langgraph/conftest.py | 26 +++ .../langgraph/e2e_functional_entrypoints.py | 130 +++++++++++++++ .../langgraph/e2e_functional_workflows.py | 85 ++++++++++ .../contrib/langgraph/test_continue_as_new.py | 64 +++++++ .../langgraph/test_continue_as_new_cached.py | 123 ++++++++++++++ .../contrib/langgraph/test_e2e_functional.py | 157 ++++++++++++++++++ .../langgraph/test_e2e_functional_v2.py | 152 +++++++++++++++++ .../langgraph/test_execute_in_workflow.py | 43 +++++ tests/contrib/langgraph/test_interrupt.py | 61 +++++++ tests/contrib/langgraph/test_interrupt_v2.py | 71 ++++++++ tests/contrib/langgraph/test_streaming.py | 61 +++++++ .../langgraph/test_subgraph_activity.py | 56 +++++++ .../langgraph/test_subgraph_workflow.py | 56 +++++++ tests/contrib/langgraph/test_timeout.py | 54 ++++++ tests/contrib/langgraph/test_two_nodes.py | 58 +++++++ uv.lock | 139 +++++++++++++++- 24 files changed, 1842 insertions(+), 1 deletion(-) create mode 100644 temporalio/contrib/langgraph/README.md create mode 100644 temporalio/contrib/langgraph/__init__.py create mode 100644 temporalio/contrib/langgraph/activity.py create mode 100644 temporalio/contrib/langgraph/langgraph_config.py create mode 100644 temporalio/contrib/langgraph/langgraph_plugin.py create mode 100644 temporalio/contrib/langgraph/task_cache.py create mode 100644 tests/contrib/langgraph/__init__.py create mode 100644 
tests/contrib/langgraph/conftest.py create mode 100644 tests/contrib/langgraph/e2e_functional_entrypoints.py create mode 100644 tests/contrib/langgraph/e2e_functional_workflows.py create mode 100644 tests/contrib/langgraph/test_continue_as_new.py create mode 100644 tests/contrib/langgraph/test_continue_as_new_cached.py create mode 100644 tests/contrib/langgraph/test_e2e_functional.py create mode 100644 tests/contrib/langgraph/test_e2e_functional_v2.py create mode 100644 tests/contrib/langgraph/test_execute_in_workflow.py create mode 100644 tests/contrib/langgraph/test_interrupt.py create mode 100644 tests/contrib/langgraph/test_interrupt_v2.py create mode 100644 tests/contrib/langgraph/test_streaming.py create mode 100644 tests/contrib/langgraph/test_subgraph_activity.py create mode 100644 tests/contrib/langgraph/test_subgraph_workflow.py create mode 100644 tests/contrib/langgraph/test_timeout.py create mode 100644 tests/contrib/langgraph/test_two_nodes.py diff --git a/pyproject.toml b/pyproject.toml index 6ea339047..bfa09eadf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"] pydantic = ["pydantic>=2.0.0,<3"] openai-agents = ["openai-agents>=0.3,<0.7", "mcp>=1.9.4, <2"] google-adk = ["google-adk>=1.27.0,<2"] +langgraph = ["langgraph>=1.1.6"] langsmith = ["langsmith>=0.7.0,<0.8"] lambda-worker-otel = [ "opentelemetry-api>=1.11.1,<2", @@ -79,6 +80,7 @@ dev = [ "pytest-rerunfailures>=16.1", "pytest-xdist>=3.6,<4", "moto[s3,server]>=5", + "langgraph>=1.1.6", "langsmith>=0.7.0,<0.8", "setuptools<82", "opentelemetry-exporter-otlp-proto-grpc>=1.11.1,<2", diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md new file mode 100644 index 000000000..63911c570 --- /dev/null +++ b/temporalio/contrib/langgraph/README.md @@ -0,0 +1,137 @@ +# LangGraph Temporal Plugin + +A [Temporal](https://temporal.io) plugin that runs 
[LangGraph](https://www.langchain.com/langgraph) nodes and tasks as Temporal Activities, giving your AI workflows durable execution, automatic retries, and timeouts. + +## Installation + +```sh +uv add temporalio[langgraph] +``` + +or with pip: + +```sh +pip install temporalio[langgraph] +``` + +Requires `langgraph>=1.1.6` and `temporalio>=1.24.0`. + +## Plugin Initialization + +### Graph API + +```python +from temporalio.contrib.langgraph import LangGraphPlugin + +plugin = LangGraphPlugin(graphs={"my-graph": graph}) +``` + +### Functional API + +```python +import datetime +from temporalio.contrib.langgraph import LangGraphPlugin + +plugin = LangGraphPlugin( + entrypoints={"my_entrypoint": my_entrypoint}, + tasks=[my_task], + activity_options={ + "my_task": { + "start_to_close_timeout": datetime.timedelta(seconds=30), + }, + }, +) +``` + +## Checkpointer + +Use `InMemorySaver` as your checkpointer. Temporal handles durability, so third-party checkpointers (like PostgreSQL or Redis) are not needed. + +```python +import langgraph.checkpoint.memory +import typing + +from temporalio.contrib.langgraph import graph +from temporalio import workflow + +@workflow.defn +class MyWorkflow: + @workflow.run + async def run(self, input: str) -> typing.Any: + g = graph("my-graph").compile( + checkpointer=langgraph.checkpoint.memory.InMemorySaver(), + ) + + ... +``` + +## Activity Options + +Options are passed through to [`workflow.execute_activity()`](https://python.temporal.io/temporalio.workflow.html#execute_activity), which supports parameters like `start_to_close_timeout`, `retry_policy`, `schedule_to_close_timeout`, `heartbeat_timeout`, and more. 
+ +### Graph API + +Pass activity options as node `metadata` when calling `add_node`: + +```python +import datetime +from temporalio.common import RetryPolicy + +g = StateGraph(State) +g.add_node("my_node", my_node, metadata={ + "start_to_close_timeout": datetime.timedelta(seconds=30), + "retry_policy": RetryPolicy(maximum_attempts=3), +}) +``` + +### Functional API + +Pass activity options to the `Plugin` constructor, keyed by task function name: + +```python +import datetime +from temporalio.common import RetryPolicy +from temporalio.contrib.langgraph import LangGraphPlugin + +plugin = LangGraphPlugin( + entrypoints={"my_entrypoint": my_entrypoint}, + tasks=[my_task], + activity_options={ + "my_task": { + "start_to_close_timeout": datetime.timedelta(seconds=30), + "retry_policy": RetryPolicy(maximum_attempts=3), + }, + }, +) +``` + +### Running in the Workflow + +To skip the Activity wrapper and run a node or task directly in the Workflow, set `execute_in` to `"workflow"`: + +```python +# Graph API +graph.add_node("my_node", my_node, metadata={"execute_in": "workflow"}) + +# Functional API +plugin = LangGraphPlugin( + tasks=[my_task], + activity_options={"my_task": {"execute_in": "workflow"}}, +) +``` + +## Running Tests + +Install dependencies: + +```sh +uv sync +``` + +Run the test suite: + +```sh +uv run pytest +``` + +Tests start a local Temporal dev server automatically — no external server needed. diff --git a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py new file mode 100644 index 000000000..df32ca5ee --- /dev/null +++ b/temporalio/contrib/langgraph/__init__.py @@ -0,0 +1,19 @@ +"""Support for using LangGraph as part of Temporal workflows. + +This module provides compatibility between +`LangGraph `_ and Temporal workflows. 
+""" + +from temporalio.contrib.langgraph.langgraph_plugin import ( + LangGraphPlugin, + entrypoint, + cache, + graph, +) + +__all__ = [ + "LangGraphPlugin", + "entrypoint", + "cache", + "graph", +] diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py new file mode 100644 index 000000000..b30fdd4a1 --- /dev/null +++ b/temporalio/contrib/langgraph/activity.py @@ -0,0 +1,70 @@ +from dataclasses import dataclass +from inspect import iscoroutinefunction +from typing import Any, Callable + +from langgraph.errors import GraphInterrupt +from langgraph.types import Interrupt +from temporalio import workflow + +from temporalio.contrib.langgraph.langgraph_config import get_langgraph_config, set_langgraph_config + + +@dataclass +class ActivityInput: + args: tuple[Any, ...] + kwargs: dict[str, Any] + langgraph_config: dict[str, Any] + + +@dataclass +class ActivityOutput: + result: Any = None + langgraph_interrupts: tuple[Interrupt] | None = None + + +def wrap_activity(func: Callable) -> Callable: + async def wrapper(input: ActivityInput) -> ActivityOutput: + set_langgraph_config(input.langgraph_config) + try: + if iscoroutinefunction(func): + result = await func(*input.args, **input.kwargs) + else: + result = func(*input.args, **input.kwargs) + return ActivityOutput(result=result) + except GraphInterrupt as e: + return ActivityOutput(langgraph_interrupts=e.args[0]) + + return wrapper + + +def wrap_execute_activity( + afunc: Callable, + task_id: str = "", + **execute_activity_kwargs: dict[str, Any], +) -> Callable: + async def wrapper(*args: Any, **kwargs: dict[str, Any]) -> Any: + from temporalio.contrib.langgraph.task_cache import _cache_key, _cache_lookup, _cache_put + + # Check task result cache (for continue-as-new deduplication). 
+ key = _cache_key(task_id, args, kwargs) if task_id else "" + if task_id: + found, cached = _cache_lookup(key) + if found: + return cached + + input = ActivityInput( + args=args, kwargs=kwargs, langgraph_config=get_langgraph_config() + ) + output: ActivityOutput = await workflow.execute_activity( + afunc, input, result_type=ActivityOutput, **execute_activity_kwargs + ) + if output.langgraph_interrupts is not None: + raise GraphInterrupt(output.langgraph_interrupts) + + # Store in cache for future continue-as-new cycles. + if task_id: + _cache_put(key, output.result) + + return output.result + + return wrapper diff --git a/temporalio/contrib/langgraph/langgraph_config.py b/temporalio/contrib/langgraph/langgraph_config.py new file mode 100644 index 000000000..07e98363f --- /dev/null +++ b/temporalio/contrib/langgraph/langgraph_config.py @@ -0,0 +1,62 @@ +from typing import Any + +from langchain_core.runnables.config import var_child_runnable_config +from langgraph._internal._constants import ( + CONFIG_KEY_CHECKPOINT_NS, + CONFIG_KEY_SCRATCHPAD, + CONFIG_KEY_SEND, +) +from langgraph.graph.state import RunnableConfig +from langgraph.pregel._algo import LazyAtomicCounter, PregelScratchpad + + +def get_langgraph_config() -> dict[str, Any]: + config = var_child_runnable_config.get() or {} + configurable = config.get("configurable") or {} + scratchpad = configurable.get(CONFIG_KEY_SCRATCHPAD) + + return { + "configurable": { + CONFIG_KEY_CHECKPOINT_NS: configurable.get(CONFIG_KEY_CHECKPOINT_NS), + CONFIG_KEY_SCRATCHPAD: { + "step": getattr(scratchpad, "step", 0), + "stop": getattr(scratchpad, "stop", 0), + "resume": list(getattr(scratchpad, "resume", [])), + "null_resume": scratchpad.get_null_resume() if scratchpad else None, + }, + } + } + + +def set_langgraph_config(config: dict[str, Any]) -> None: + configurable = config.get("configurable") or {} + scratchpad = configurable.get(CONFIG_KEY_SCRATCHPAD) or {} + null_resume_box = [scratchpad.get("null_resume")] + + def 
get_null_resume(consume: bool = False) -> Any: + val = null_resume_box[0] + if consume and val is not None: + null_resume_box[0] = None + return val + + var_child_runnable_config.set( + RunnableConfig( + { + "configurable": { + CONFIG_KEY_CHECKPOINT_NS: configurable.get( + CONFIG_KEY_CHECKPOINT_NS + ), + CONFIG_KEY_SCRATCHPAD: PregelScratchpad( + step=scratchpad.get("step", 0), + stop=scratchpad.get("stop", 0), + call_counter=LazyAtomicCounter(), + interrupt_counter=LazyAtomicCounter(), + get_null_resume=get_null_resume, + resume=list(scratchpad.get("resume", [])), + subgraph_counter=LazyAtomicCounter(), + ), + CONFIG_KEY_SEND: lambda _: None, + }, + } + ) + ) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py new file mode 100644 index 000000000..f392845b4 --- /dev/null +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -0,0 +1,142 @@ +from dataclasses import replace +from typing import Any, Callable + +from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity +from langgraph._internal._runnable import RunnableCallable +from langgraph.graph import StateGraph +from langgraph.pregel import Pregel +from temporalio.contrib.langgraph.task_cache import _get_task_cache, _set_task_cache, _task_id + +from temporalio import activity +from temporalio.plugin import SimplePlugin +from temporalio.worker import WorkflowRunner +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner + +# Save registered graphs/entrypoints at the module level to avoid being refreshed by the sandbox. +_graph_registry: dict[str, StateGraph] = {} +_entrypoint_registry: dict[str, Pregel] = {} + + +class LangGraphPlugin(SimplePlugin): + def __init__( + self, + # Graph API + graphs: dict[str, StateGraph] | None = None, + # Functional API + entrypoints: dict[str, Pregel] | None = None, + tasks: list | None = None, + # TODO: Remove activity_options when we have support for @task(metadata=...) 
+ activity_options: dict[str, dict] | None = None, + # TODO: Add default_activity_options that apply to all nodes or tasks + ): + self.activities: list = [] + + # Graph API: Wrap graph nodes as Activities. + if graphs: + _graph_registry.update(graphs) + for graph in graphs.values(): + for name, node in graph.nodes.items(): + runnable = node.runnable + if ( + not isinstance(runnable, RunnableCallable) + or runnable.afunc is None + ): + raise ValueError(f"Node {name} must have an async function") + # Remove LangSmith-related callback functions that can't be serialized between the workflow and activity. + runnable.func_accepts = {} + runnable.afunc = self.execute(runnable.afunc, node.metadata) + + # Functional API: Register @entrypoint functions + if entrypoints: + _entrypoint_registry.update(entrypoints) + + # Functional API: Wrap @task functions as Activities. + if tasks: + for task in tasks: + name = task.func.__name__ + opts = (activity_options or {}).get(name, {}) + + task.func = self.execute(task.func, opts) + task.func.__name__ = name + task.func.__qualname__ = getattr(task.func, "__qualname__", name) + + def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: + if not runner: + raise ValueError("No WorkflowRunner provided to the LangGraph plugin.") + if isinstance(runner, SandboxedWorkflowRunner): + return replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + "langchain", + "langchain_core", + "langgraph", + "langsmith", + "numpy", # LangSmith uses numpy + ), + ) + return runner + + super().__init__( + "temporalio.LangGraphPlugin", + activities=self.activities, + workflow_runner=workflow_runner, + ) + + # Prepare a [node, @task] to execute as a [Activity, Workflow]. 
+ def execute(self, func: Callable, kwargs: dict[str, Any] | None = None) -> Callable: + execute_in = (kwargs or {}).pop("execute_in", "activity") + + if execute_in == "activity": + a = activity.defn(name=func.__name__)(wrap_activity(func)) + self.activities.append(a) + return wrap_execute_activity(a, task_id=_task_id(func), **(kwargs or {})) + elif execute_in == "workflow": + return func + else: + raise ValueError(f"Invalid execute_in value: {execute_in}") + + +def graph(name: str, cache: dict[str, Any] | None = None) -> StateGraph: + """Retrieve a registered graph by name. + + Args: + name: Graph name as registered with LangGraphPlugin. + cache: Optional task result cache from a previous cache() call. + Restores cached results so previously-completed nodes are + not re-executed after continue-as-new. + """ + _patch_event_loop() + _set_task_cache(cache or {}) + return _graph_registry[name] + + +def entrypoint(name: str, cache: dict[str, Any] | None = None) -> Pregel: + """Retrieve a registered entrypoint by name. + + Args: + name: Entrypoint name as registered with Plugin. + cache: Optional task result cache from a previous cache() call. + Restores cached results so previously-completed tasks are + not re-executed after continue-as-new. + """ + _patch_event_loop() + _set_task_cache(cache or {}) + return _entrypoint_registry[name] + + +def cache() -> dict[str, Any] | None: + """Return the task result cache as a serializable dict. + + Returns a dict suitable for passing to entrypoint(name, cache=...) to + restore cached task results across continue-as-new boundaries. + Returns None if the cache is empty. 
+ """ + return _get_task_cache() or None + + +def _patch_event_loop(): + """Patch the event loop so LangGraph detects it as running inside Temporal's sandbox.""" + from asyncio import get_event_loop + + loop = get_event_loop() + loop.is_running = lambda: True diff --git a/temporalio/contrib/langgraph/task_cache.py b/temporalio/contrib/langgraph/task_cache.py new file mode 100644 index 000000000..78a1b5b5c --- /dev/null +++ b/temporalio/contrib/langgraph/task_cache.py @@ -0,0 +1,75 @@ +"""Task result cache for continue-as-new support. + +Caches task results by (module.qualname, args, kwargs) hash so that previously +completed tasks are not re-executed after a continue-as-new. The cache state +is a plain dict that can travel through workflow.continue_as_new(). +""" + +from __future__ import annotations + +from contextvars import ContextVar +from hashlib import sha256 +from json import dumps +from typing import Any + +_task_cache: ContextVar[dict[str, Any] | None] = ContextVar( + "_temporal_task_cache", default=None +) + + +def _set_task_cache(cache: dict[str, Any] | None) -> None: + _task_cache.set(cache) + + +def _get_task_cache() -> dict[str, Any] | None: + return _task_cache.get() + + +def _task_id(func: Any) -> str: + """Return the fully-qualified module.qualname for a function. + + Raises ValueError for functions that cannot be identified unambiguously + (lambdas, closures, __main__ functions). + """ + module = getattr(func, "__module__", None) + qualname = getattr(func, "__qualname__", None) or getattr(func, "__name__", None) + + if module is None or qualname is None: + raise ValueError( + f"Cannot identify task {func}: missing __module__ or __qualname__. " + "Tasks must be defined at module level." + ) + if module == "__main__": + raise ValueError( + f"Cannot identify task {qualname}: defined in __main__. " + "Tasks must be importable from a named module." 
) + if "<locals>" in qualname: + raise ValueError( + f"Cannot identify task {qualname}: closures/local functions are not supported. " + "Tasks must be defined at module level." + ) + return f"{module}.{qualname}" + + +def _cache_key(task_id: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: + """Build a cache key from the full task identifier and arguments.""" + try: + key_str = dumps([task_id, args, kwargs], sort_keys=True, default=str) + except (TypeError, ValueError): + key_str = repr([task_id, args, kwargs]) + return sha256(key_str.encode()).hexdigest()[:32] + + +def _cache_lookup(key: str) -> tuple[bool, Any]: + """Return (True, value) if cached, (False, None) otherwise.""" + cache = _task_cache.get() + if cache is not None and key in cache: + return True, cache[key] + return False, None + + +def _cache_put(key: str, value: Any) -> None: + cache = _task_cache.get() + if cache is not None: + cache[key] = value diff --git a/tests/contrib/langgraph/__init__.py b/tests/contrib/langgraph/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/contrib/langgraph/conftest.py b/tests/contrib/langgraph/conftest.py new file mode 100644 index 000000000..f56b0ad96 --- /dev/null +++ b/tests/contrib/langgraph/conftest.py @@ -0,0 +1,26 @@ +from asyncio import get_event_loop_policy +from collections.abc import AsyncGenerator + +from pytest import fixture +from pytest_asyncio import fixture as async_fixture +from temporalio.client import Client +from temporalio.testing import WorkflowEnvironment + + +@fixture(scope="session") +def event_loop(): + loop = get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@async_fixture(scope="session") +async def env() -> AsyncGenerator[WorkflowEnvironment, None]: + env = await WorkflowEnvironment.start_local() + yield env + await env.shutdown() + + +@async_fixture +async def client(env: WorkflowEnvironment) -> Client: + return env.client diff --git 
a/tests/contrib/langgraph/e2e_functional_entrypoints.py b/tests/contrib/langgraph/e2e_functional_entrypoints.py new file mode 100644 index 000000000..5498bbc9c --- /dev/null +++ b/tests/contrib/langgraph/e2e_functional_entrypoints.py @@ -0,0 +1,130 @@ +"""Functional API entrypoint definitions for E2E tests. + +These define @task and @entrypoint functions used in functional API E2E tests. +""" + +from __future__ import annotations + +from langgraph.func import entrypoint, task +from langgraph.types import interrupt + + +@task +def double_value(x: int) -> int: + return x * 2 + + +@task +def add_ten(x: int) -> int: + return x + 10 + + +@entrypoint() +async def simple_functional_entrypoint(value: int) -> dict: + doubled = await double_value(value) + result = await add_ten(doubled) + return {"result": result} + + +# Track task execution count for continue-as-new testing +_task_execution_counts: dict[str, int] = {} + + +def get_task_execution_counts() -> dict[str, int]: + return _task_execution_counts.copy() + + +def reset_task_execution_counts() -> None: + _task_execution_counts.clear() + + +@task +def expensive_task_a(x: int) -> int: + _task_execution_counts["task_a"] = _task_execution_counts.get("task_a", 0) + 1 + return x * 3 + + +@task +def expensive_task_b(x: int) -> int: + _task_execution_counts["task_b"] = _task_execution_counts.get("task_b", 0) + 1 + return x + 100 + + +@task +def expensive_task_c(x: int) -> int: + _task_execution_counts["task_c"] = _task_execution_counts.get("task_c", 0) + 1 + return x * 2 + + +@entrypoint() +async def continue_as_new_entrypoint(value: int) -> dict: + """For input 10: 10 * 3 = 30 -> 30 + 100 = 130 -> 130 * 2 = 260""" + result_a = await expensive_task_a(value) + result_b = await expensive_task_b(result_a) + result_c = await expensive_task_c(result_b) + return {"result": result_c} + + +@task +def step_1(x: int) -> int: + _task_execution_counts["step_1"] = _task_execution_counts.get("step_1", 0) + 1 + return x * 2 + + +@task +def 
step_2(x: int) -> int: + _task_execution_counts["step_2"] = _task_execution_counts.get("step_2", 0) + 1 + return x + 5 + + +@task +def step_3(x: int) -> int: + _task_execution_counts["step_3"] = _task_execution_counts.get("step_3", 0) + 1 + return x * 3 + + +@task +def step_4(x: int) -> int: + _task_execution_counts["step_4"] = _task_execution_counts.get("step_4", 0) + 1 + return x - 10 + + +@task +def step_5(x: int) -> int: + _task_execution_counts["step_5"] = _task_execution_counts.get("step_5", 0) + 1 + return x + 100 + + +@entrypoint() +async def partial_execution_entrypoint(input_data: dict) -> dict: + """For value=10, all 5 tasks: 10*2=20 -> +5=25 -> *3=75 -> -10=65 -> +100=165""" + value = input_data["value"] + stop_after = input_data.get("stop_after", 5) + + result = value + result = await step_1(result) + if stop_after == 1: + return {"result": result, "completed_tasks": 1} + result = await step_2(result) + if stop_after == 2: + return {"result": result, "completed_tasks": 2} + result = await step_3(result) + if stop_after == 3: + return {"result": result, "completed_tasks": 3} + result = await step_4(result) + if stop_after == 4: + return {"result": result, "completed_tasks": 4} + result = await step_5(result) + return {"result": result, "completed_tasks": 5} + + +@task +def ask_human(question: str) -> str: + return interrupt(question) + + +@entrypoint() +async def interrupt_entrypoint(value: str) -> dict: + """Entrypoint that interrupts for human input, then returns the answer.""" + answer = await ask_human("Do you approve?") + return {"input": value, "answer": answer} diff --git a/tests/contrib/langgraph/e2e_functional_workflows.py b/tests/contrib/langgraph/e2e_functional_workflows.py new file mode 100644 index 000000000..7942e11fe --- /dev/null +++ b/tests/contrib/langgraph/e2e_functional_workflows.py @@ -0,0 +1,85 @@ +"""Workflow definitions for Functional API E2E tests.""" + +from __future__ import annotations + +from dataclasses import dataclass 
+from typing import Any + +from temporalio import workflow + +from temporalio.contrib.langgraph.langgraph_plugin import entrypoint, cache + + +@workflow.defn +class SimpleFunctionalE2EWorkflow: + @workflow.run + async def run(self, input_value: int) -> dict: + return await entrypoint("e2e_simple_functional").ainvoke(input_value) + + +@dataclass +class ContinueAsNewInput: + value: int + cache: dict[str, Any] | None = None + task_a_done: bool = False + task_b_done: bool = False + + +@workflow.defn +class ContinueAsNewFunctionalWorkflow: + """Continues-as-new after each phase, passing cache for task deduplication.""" + + @workflow.run + async def run(self, input_data: ContinueAsNewInput) -> dict[str, Any]: + result = await entrypoint( + "e2e_continue_as_new_functional", cache=input_data.cache + ).ainvoke(input_data.value) + + if not input_data.task_a_done: + workflow.continue_as_new( + ContinueAsNewInput( + value=input_data.value, + cache=cache(), + task_a_done=True, + ) + ) + + if not input_data.task_b_done: + workflow.continue_as_new( + ContinueAsNewInput( + value=input_data.value, + cache=cache(), + task_a_done=True, + task_b_done=True, + ) + ) + + return result + + +@dataclass +class PartialExecutionInput: + value: int + cache: dict[str, Any] | None = None + phase: int = 1 + + +@workflow.defn +class PartialExecutionWorkflow: + """Phase 1: 3 tasks + cache. 
Phase 2: all 5 (1-3 cached).""" + + @workflow.run + async def run(self, input_data: PartialExecutionInput) -> dict[str, Any]: + app = entrypoint("e2e_partial_execution", cache=input_data.cache) + + if input_data.phase == 1: + await app.ainvoke({"value": input_data.value, "stop_after": 3}) + workflow.continue_as_new( + PartialExecutionInput( + value=input_data.value, + cache=cache(), + phase=2, + ) + ) + + return await app.ainvoke({"value": input_data.value, "stop_after": 5}) diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py new file mode 100644 index 000000000..72f12ae96 --- /dev/null +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -0,0 +1,64 @@ +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import START, StateGraph +from langgraph.graph.state import RunnableConfig +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph + + +async def node(state: str) -> str: + return state + "a" + + +@workflow.defn +class ContinueAsNewWorkflow: + @workflow.run + async def run(self, values: str) -> Any: + g = graph("my-graph").compile( + checkpointer=InMemorySaver() + ) + config = RunnableConfig( + {"configurable": {"thread_id": "1"}} + ) + + await g.aupdate_state(config, values) + await g.ainvoke(values, config) + + if len(values) < 3: + state = await g.aget_state(config) + workflow.continue_as_new(state.values) + + return values + + +async def test_continue_as_new(client: Client): + g = StateGraph(str) + g.add_node( + "node", + node, + metadata={"start_to_close_timeout": timedelta(seconds=10)}, + ) + g.add_edge(START, "node") + + task_queue = f"my-graph-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[ContinueAsNewWorkflow], + 
plugins=[LangGraphPlugin(graphs={"my-graph": g})], + ): + result = await client.execute_workflow( + ContinueAsNewWorkflow.run, + "", + id=f"test-workflow-{uuid4()}", + task_queue=task_queue, + ) + + assert result == "aaa" diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py new file mode 100644 index 000000000..048626c0a --- /dev/null +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -0,0 +1,123 @@ +"""Test Graph API continue-as-new with task result caching. + +Verifies that node results are cached across continue-as-new boundaries, +so nodes don't re-execute when the graph is re-invoked with the same state. +""" + +from datetime import timedelta +from typing import Any +from uuid import uuid4 +from dataclasses import dataclass + +from langgraph.graph import START, StateGraph +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, cache, graph + +# Track execution counts to verify caching +_execution_counts: dict[str, int] = {} + + +def _reset(): + _execution_counts.clear() + + +async def multiply_by_3(state: int) -> int: + _execution_counts["multiply"] = _execution_counts.get("multiply", 0) + 1 + return state * 3 + + +async def add_100(state: int) -> int: + _execution_counts["add"] = _execution_counts.get("add", 0) + 1 + return state + 100 + + +async def double(state: int) -> int: + _execution_counts["double"] = _execution_counts.get("double", 0) + 1 + return state * 2 + + +@dataclass +class GraphContinueAsNewInput: + value: int + cache: dict[str, Any] | None = None + phase: int = 1 # 1, 2, 3 — continues-as-new after phases 1 and 2 + + +@workflow.defn +class GraphContinueAsNewWorkflow: + """Runs a 3-node graph, continuing-as-new after each phase. + + Phase 1: runs graph (all 3 nodes execute), continues-as-new with cache. 
+ Phase 2: runs graph again with same input (all 3 cached), continues-as-new. + Phase 3: runs graph again with same input (all 3 cached), returns result. + + Without caching: each node executes 3 times. + With caching: each node executes once (first run), cached for phases 2 & 3. + """ + + @workflow.run + async def run(self, input_data: GraphContinueAsNewInput) -> int: + g = graph("cached-graph", cache=input_data.cache).compile() + result = await g.ainvoke(input_data.value) + + if input_data.phase < 3: + workflow.continue_as_new( + GraphContinueAsNewInput( + value=input_data.value, + cache=cache(), + phase=input_data.phase + 1, + ) + ) + + return result + + +async def test_graph_continue_as_new_cached(client: Client): + """Each node executes once despite 3 continue-as-new cycles. + + Graph: multiply_by_3 -> add_100 -> double + Input 10: 10 * 3 = 30 -> 30 + 100 = 130 -> 130 * 2 = 260 + """ + _reset() + + timeout = {"start_to_close_timeout": timedelta(seconds=10)} + g = StateGraph(int) + g.add_node("multiply_by_3", multiply_by_3, metadata=timeout) + g.add_node("add_100", add_100, metadata=timeout) + g.add_node("double", double, metadata=timeout) + g.add_edge(START, "multiply_by_3") + g.add_edge("multiply_by_3", "add_100") + g.add_edge("add_100", "double") + + task_queue = f"graph-cached-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[GraphContinueAsNewWorkflow], + plugins=[LangGraphPlugin(graphs={"cached-graph": g})], + ): + result = await client.execute_workflow( + GraphContinueAsNewWorkflow.run, + GraphContinueAsNewInput(value=10), + id=f"graph-cached-{uuid4()}", + task_queue=task_queue, + execution_timeout=timedelta(seconds=60), + ) + + # 10 * 3 = 30 -> + 100 = 130 -> * 2 = 260 + assert result == 260 + + # Each node should execute exactly once — phases 2 and 3 use cached results. 
+ assert _execution_counts.get("multiply", 0) == 1, ( + f"multiply executed {_execution_counts.get('multiply', 0)} times, expected 1" + ) + assert _execution_counts.get("add", 0) == 1, ( + f"add executed {_execution_counts.get('add', 0)} times, expected 1" + ) + assert _execution_counts.get("double", 0) == 1, ( + f"double executed {_execution_counts.get('double', 0)} times, expected 1" + ) diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py new file mode 100644 index 000000000..09f85e90b --- /dev/null +++ b/tests/contrib/langgraph/test_e2e_functional.py @@ -0,0 +1,157 @@ +"""End-to-end tests for LangGraph Functional API integration. + +Requires a running Temporal test server (started by conftest.py). +""" + +from __future__ import annotations + +from datetime import timedelta +from uuid import uuid4 + +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from tests.contrib.langgraph.e2e_functional_entrypoints import ( + add_ten, + continue_as_new_entrypoint, + double_value, + expensive_task_a, + expensive_task_b, + expensive_task_c, + get_task_execution_counts, + partial_execution_entrypoint, + reset_task_execution_counts, + simple_functional_entrypoint, + step_1, + step_2, + step_3, + step_4, + step_5, +) +from tests.contrib.langgraph.e2e_functional_workflows import ( + ContinueAsNewFunctionalWorkflow, + ContinueAsNewInput, + PartialExecutionInput, + PartialExecutionWorkflow, + SimpleFunctionalE2EWorkflow, +) + + +def _activity_opts(*task_funcs) -> dict[str, dict]: + """Build activity_options dict giving every task the same 30s timeout.""" + return { + t.func.__name__: {"start_to_close_timeout": timedelta(seconds=30)} + for t in task_funcs + } + + +class TestFunctionalAPIBasicExecution: + async def test_simple_functional_entrypoint(self, client: Client) -> None: + """input 10 -> double (20) -> add 10 (30) -> 
result: 30""" + tasks = [double_value, add_ten] + task_queue = f"e2e-functional-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[SimpleFunctionalE2EWorkflow], + plugins=[ + LangGraphPlugin( + entrypoints={"e2e_simple_functional": simple_functional_entrypoint}, + tasks=tasks, + activity_options=_activity_opts(*tasks), + ) + ], + ): + result = await client.execute_workflow( + SimpleFunctionalE2EWorkflow.run, + 10, + id=f"e2e-functional-{uuid4()}", + task_queue=task_queue, + execution_timeout=timedelta(seconds=30), + ) + + assert result["result"] == 30 + + +class TestFunctionalAPIContinueAsNew: + async def test_continue_as_new_with_checkpoint(self, client: Client) -> None: + """10 * 3 = 30 -> + 100 = 130 -> * 2 = 260. Each task executes once.""" + reset_task_execution_counts() + + tasks = [expensive_task_a, expensive_task_b, expensive_task_c] + task_queue = f"e2e-continue-as-new-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[ContinueAsNewFunctionalWorkflow], + plugins=[ + LangGraphPlugin( + entrypoints={ + "e2e_continue_as_new_functional": continue_as_new_entrypoint + }, + tasks=tasks, + activity_options=_activity_opts(*tasks), + ) + ], + ): + result = await client.execute_workflow( + ContinueAsNewFunctionalWorkflow.run, + ContinueAsNewInput(value=10), + id=f"e2e-continue-as-new-{uuid4()}", + task_queue=task_queue, + execution_timeout=timedelta(seconds=60), + ) + + assert result["result"] == 260 + + counts = get_task_execution_counts() + assert counts.get("task_a", 0) == 1, ( + f"task_a executed {counts.get('task_a', 0)} times, expected 1" + ) + assert counts.get("task_b", 0) == 1, ( + f"task_b executed {counts.get('task_b', 0)} times, expected 1" + ) + assert counts.get("task_c", 0) == 1, ( + f"task_c executed {counts.get('task_c', 0)} times, expected 1" + ) + + +class TestFunctionalAPIPartialExecution: + async def test_partial_execution_five_tasks(self, client: Client) -> None: + """10*2=20 -> +5=25 
-> *3=75 -> -10=65 -> +100=165. Each task executes once.""" + reset_task_execution_counts() + + tasks = [step_1, step_2, step_3, step_4, step_5] + task_queue = f"e2e-partial-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[PartialExecutionWorkflow], + plugins=[ + LangGraphPlugin( + entrypoints={"e2e_partial_execution": partial_execution_entrypoint}, + tasks=tasks, + activity_options=_activity_opts(*tasks), + ) + ], + ): + result = await client.execute_workflow( + PartialExecutionWorkflow.run, + PartialExecutionInput(value=10), + id=f"e2e-partial-{uuid4()}", + task_queue=task_queue, + execution_timeout=timedelta(seconds=60), + ) + + assert result["result"] == 165 + assert result["completed_tasks"] == 5 + + counts = get_task_execution_counts() + for i in range(1, 6): + assert counts.get(f"step_{i}", 0) == 1, ( + f"step_{i} executed {counts.get(f'step_{i}', 0)} times, expected 1" + ) diff --git a/tests/contrib/langgraph/test_e2e_functional_v2.py b/tests/contrib/langgraph/test_e2e_functional_v2.py new file mode 100644 index 000000000..b9678bf74 --- /dev/null +++ b/tests/contrib/langgraph/test_e2e_functional_v2.py @@ -0,0 +1,152 @@ +"""Tests for LangGraph Functional API with version="v2". + +version="v2" changes ainvoke() to return a GraphOutput dataclass with +.value and .interrupts fields instead of a plain dict with __interrupt__ +mixed in. 
+""" + +from __future__ import annotations + +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.types import Command +from langchain_core.runnables import RunnableConfig +from langgraph.func import entrypoint as lg_entrypoint +from langgraph.func import task +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, entrypoint +from tests.contrib.langgraph.e2e_functional_entrypoints import ( + ask_human, + interrupt_entrypoint, +) + +# Define separate tasks to avoid sharing mutated _TaskFunction objects with +# other tests (Plugin wraps task.func in-place). + + +@task +def triple_value(x: int) -> int: + return x * 3 + + +@task +def add_five(x: int) -> int: + return x + 5 + + +@lg_entrypoint() +async def simple_v2_entrypoint(value: int) -> dict: + tripled = await triple_value(value) + result = await add_five(tripled) + return {"result": result} + + +# -- Workflows ---------------------------------------------------------------- + + +@workflow.defn +class SimpleV2Workflow: + @workflow.run + async def run(self, input_value: int) -> dict[str, Any]: + result = await entrypoint("v2_simple").ainvoke(input_value, version="v2") + # v2 returns GraphOutput — extract .value for Temporal serialization + return result.value + + +@workflow.defn +class InterruptV2FunctionalWorkflow: + @workflow.run + async def run(self, input_value: str) -> dict[str, Any]: + app = entrypoint("v2_interrupt") + app.checkpointer = InMemorySaver() + config = RunnableConfig( + {"configurable": {"thread_id": workflow.info().workflow_id}} + ) + + # First invoke — should get an interrupt + result = await app.ainvoke(input_value, config, version="v2") + + # v2: interrupts are on result.interrupts, value is clean + assert result.value == {} + assert len(result.interrupts) == 1 + 
assert result.interrupts[0].value == "Do you approve?" + + # Resume with approval + resumed = await app.ainvoke( + Command(resume="approved"), config, version="v2" + ) + return resumed.value + + +# -- Tests -------------------------------------------------------------------- + + +class TestFunctionalAPIV2: + async def test_simple_v2(self, client: Client) -> None: + """version='v2' returns GraphOutput with .value containing the result.""" + tasks = [triple_value, add_five] + task_queue = f"v2-simple-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[SimpleV2Workflow], + plugins=[ + LangGraphPlugin( + entrypoints={"v2_simple": simple_v2_entrypoint}, + tasks=tasks, + activity_options={ + "triple_value": { + "start_to_close_timeout": timedelta(seconds=30) + }, + "add_five": {"start_to_close_timeout": timedelta(seconds=30)}, + }, + ) + ], + ): + result = await client.execute_workflow( + SimpleV2Workflow.run, + 10, + id=f"v2-simple-{uuid4()}", + task_queue=task_queue, + execution_timeout=timedelta(seconds=30), + ) + + # 10 * 3 = 30, 30 + 5 = 35 + assert result["result"] == 35 + + async def test_interrupt_v2_functional(self, client: Client) -> None: + """version='v2' separates interrupts from value in functional API.""" + tasks = [ask_human] + task_queue = f"v2-interrupt-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[InterruptV2FunctionalWorkflow], + plugins=[ + LangGraphPlugin( + entrypoints={"v2_interrupt": interrupt_entrypoint}, + tasks=tasks, + activity_options={ + "ask_human": {"start_to_close_timeout": timedelta(seconds=30)}, + }, + ) + ], + ): + result = await client.execute_workflow( + InterruptV2FunctionalWorkflow.run, + "hello", + id=f"v2-interrupt-{uuid4()}", + task_queue=task_queue, + execution_timeout=timedelta(seconds=30), + ) + + assert result["input"] == "hello" + assert result["answer"] == "approved" diff --git a/tests/contrib/langgraph/test_execute_in_workflow.py 
b/tests/contrib/langgraph/test_execute_in_workflow.py new file mode 100644 index 000000000..6d2e38ad7 --- /dev/null +++ b/tests/contrib/langgraph/test_execute_in_workflow.py @@ -0,0 +1,43 @@ +from typing import Any +from uuid import uuid4 + +from langgraph.graph import START, StateGraph +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph + + +async def node(_: str) -> str: + return "done" + + +@workflow.defn +class ExecuteInWorkflowWorkflow: + @workflow.run + async def run(self, input: str) -> Any: + return await graph("my-graph").compile().ainvoke(input) + + +async def test_execute_in_workflow(client: Client): + g = StateGraph(str) + g.add_node("node", node, metadata={"execute_in": "workflow"}) + g.add_edge(START, "node") + + task_queue = f"my-graph-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[ExecuteInWorkflowWorkflow], + plugins=[LangGraphPlugin(graphs={"my-graph": g})], + ): + result = await client.execute_workflow( + ExecuteInWorkflowWorkflow.run, + "", + id=f"test-workflow-{uuid4()}", + task_queue=task_queue, + ) + + assert result == "done" diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py new file mode 100644 index 000000000..d953ab5f0 --- /dev/null +++ b/tests/contrib/langgraph/test_interrupt.py @@ -0,0 +1,61 @@ +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import START, StateGraph +from langgraph.graph.state import RunnableConfig +from langgraph.types import Command, interrupt +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph + + +async def node(_: str) -> str: + return 
interrupt("Continue?") + + +@workflow.defn +class InterruptWorkflow: + @workflow.run + async def run(self, input: str) -> Any: + g = graph("my-graph").compile( + checkpointer=InMemorySaver() + ) + config = RunnableConfig( + {"configurable": {"thread_id": "1"}} + ) + + result = await g.ainvoke(input, config) + assert result["__interrupt__"][0].value == "Continue?" + + return await g.ainvoke(Command(resume="yes"), config) + + +async def test_interrupt(client: Client): + g = StateGraph(str) + g.add_node( + "node", + node, + metadata={"start_to_close_timeout": timedelta(seconds=10)}, + ) + g.add_edge(START, "node") + + task_queue = f"my-graph-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[InterruptWorkflow], + plugins=[LangGraphPlugin(graphs={"my-graph": g})], + ): + result = await client.execute_workflow( + InterruptWorkflow.run, + "", + id=f"test-workflow-{uuid4()}", + task_queue=task_queue, + ) + + assert result == "yes" diff --git a/tests/contrib/langgraph/test_interrupt_v2.py b/tests/contrib/langgraph/test_interrupt_v2.py new file mode 100644 index 000000000..1167971ca --- /dev/null +++ b/tests/contrib/langgraph/test_interrupt_v2.py @@ -0,0 +1,71 @@ +"""Test Graph API interrupt handling with version="v2". + +With v2, ainvoke() returns a GraphOutput dataclass with .value and .interrupts +instead of mixing __interrupt__ into the state dict. 
+""" + +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import START, StateGraph +from langgraph.graph.state import RunnableConfig +from langgraph.types import Command, interrupt +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph + + +async def node(_: str) -> str: + return interrupt("Continue?") + + +@workflow.defn +class InterruptV2Workflow: + @workflow.run + async def run(self, input: str) -> Any: + g = graph("interrupt-v2-graph").compile( + checkpointer=InMemorySaver() + ) + config = RunnableConfig( + {"configurable": {"thread_id": "1"}} + ) + + result = await g.ainvoke(input, config, version="v2") + + # v2: interrupts are on result.interrupts, not result["__interrupt__"] + assert result.value == {} + assert len(result.interrupts) == 1 + assert result.interrupts[0].value == "Continue?" 
+ + return await g.ainvoke(Command(resume="yes"), config) + + +async def test_interrupt_v2(client: Client): + g = StateGraph(str) + g.add_node( + "node", + node, + metadata={"start_to_close_timeout": timedelta(seconds=10)}, + ) + g.add_edge(START, "node") + + task_queue = f"interrupt-v2-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[InterruptV2Workflow], + plugins=[LangGraphPlugin(graphs={"interrupt-v2-graph": g})], + ): + result = await client.execute_workflow( + InterruptV2Workflow.run, + "", + id=f"test-interrupt-v2-{uuid4()}", + task_queue=task_queue, + ) + + assert result == "yes" diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py new file mode 100644 index 000000000..a8959aec5 --- /dev/null +++ b/tests/contrib/langgraph/test_streaming.py @@ -0,0 +1,61 @@ +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.graph import START, StateGraph +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph + + +async def node_a(state: str) -> str: + return state + "a" + + +async def node_b(state: str) -> str: + return state + "b" + + +@workflow.defn +class StreamingWorkflow: + @workflow.run + async def run(self, input: str) -> Any: + chunks = [] + async for chunk in graph("streaming").compile().astream(input): + chunks.append(chunk) + return chunks + + +async def test_streaming(client: Client): + g = StateGraph(str) + g.add_node( + "node_a", + node_a, + metadata={"start_to_close_timeout": timedelta(seconds=10)}, + ) + g.add_node( + "node_b", + node_b, + metadata={"start_to_close_timeout": timedelta(seconds=10)}, + ) + g.add_edge(START, "node_a") + g.add_edge("node_a", "node_b") + + task_queue = f"streaming-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[StreamingWorkflow], + 
plugins=[LangGraphPlugin(graphs={"streaming": g})], + ): + chunks = await client.execute_workflow( + StreamingWorkflow.run, + "", + id=f"test-streaming-{uuid4()}", + task_queue=task_queue, + ) + + assert chunks == [{"node_a": "a"}, {"node_b": "ab"}] diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py new file mode 100644 index 000000000..8f603dc2b --- /dev/null +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -0,0 +1,56 @@ +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.graph import START, StateGraph +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph + + +async def child_node(_: str) -> str: + return "child" + + +async def parent_node(state: str) -> str: + child = StateGraph(str) + child.add_node("child_node", child_node) + child.add_edge(START, "child_node") + + return await child.compile().ainvoke(state) + + +@workflow.defn +class ActivitySubgraphWorkflow: + @workflow.run + async def run(self, input: str) -> Any: + return await graph("parent").compile().ainvoke(input) + + +async def test_activity_subgraph(client: Client): + parent = StateGraph(str) + parent.add_node( + "parent_node", + parent_node, + metadata={"start_to_close_timeout": timedelta(seconds=10)}, + ) + parent.add_edge(START, "parent_node") + + task_queue = f"subgraph-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[ActivitySubgraphWorkflow], + plugins=[LangGraphPlugin(graphs={"parent": parent})], + ): + result = await client.execute_workflow( + ActivitySubgraphWorkflow.run, + "", + id=f"test-workflow-{uuid4()}", + task_queue=task_queue, + ) + + assert result == "child" diff --git a/tests/contrib/langgraph/test_subgraph_workflow.py b/tests/contrib/langgraph/test_subgraph_workflow.py new file mode 100644 index 
000000000..5c9629d32 --- /dev/null +++ b/tests/contrib/langgraph/test_subgraph_workflow.py @@ -0,0 +1,56 @@ +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.graph import START, StateGraph +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph + + +async def child_node(_: str) -> str: + return "child" + + +async def parent_node(state: str) -> str: + return await graph("child").compile().ainvoke(state) + + +@workflow.defn +class WorkflowSubgraphWorkflow: + @workflow.run + async def run(self, input: str) -> Any: + return await graph("parent").compile().ainvoke(input) + + +async def test_workflow_subgraph(client: Client): + child = StateGraph(str) + child.add_node( + "child_node", + child_node, + metadata={"start_to_close_timeout": timedelta(seconds=10)}, + ) + child.add_edge(START, "child_node") + + parent = StateGraph(str) + parent.add_node("parent_node", parent_node, metadata={"execute_in": "workflow"}) + parent.add_edge(START, "parent_node") + + task_queue = f"subgraph-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[WorkflowSubgraphWorkflow], + plugins=[LangGraphPlugin(graphs={"parent": parent, "child": child})], + ): + result = await client.execute_workflow( + WorkflowSubgraphWorkflow.run, + "", + id=f"test-workflow-{uuid4()}", + task_queue=task_queue, + ) + + assert result == "child" diff --git a/tests/contrib/langgraph/test_timeout.py b/tests/contrib/langgraph/test_timeout.py new file mode 100644 index 000000000..b6b062e5f --- /dev/null +++ b/tests/contrib/langgraph/test_timeout.py @@ -0,0 +1,54 @@ +from asyncio import sleep +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.graph import START, StateGraph +from pytest import raises +from temporalio import workflow +from temporalio.client import 
Client, WorkflowFailureError +from temporalio.common import RetryPolicy +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph + + +async def node(_: str) -> str: + await sleep(1) # 1 second + return "done" + + +@workflow.defn +class TimeoutWorkflow: + @workflow.run + async def run(self, input: str) -> Any: + return await graph("my-graph").compile().ainvoke(input) + + +async def test_timeout(client: Client): + g = StateGraph(str) + g.add_node( + "node", + node, + metadata={ + "start_to_close_timeout": timedelta(milliseconds=100), + "retry_policy": RetryPolicy(maximum_attempts=1), + }, + ) + g.add_edge(START, "node") + + task_queue = f"my-graph-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[TimeoutWorkflow], + plugins=[LangGraphPlugin(graphs={"my-graph": g})], + ): + with raises(WorkflowFailureError): + await client.execute_workflow( + TimeoutWorkflow.run, + "", + id=f"test-workflow-{uuid4()}", + task_queue=task_queue, + ) diff --git a/tests/contrib/langgraph/test_two_nodes.py b/tests/contrib/langgraph/test_two_nodes.py new file mode 100644 index 000000000..06cc13071 --- /dev/null +++ b/tests/contrib/langgraph/test_two_nodes.py @@ -0,0 +1,58 @@ +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.graph import START, StateGraph +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph + + +async def node_a(state: str) -> str: + return state + "a" + + +async def node_b(state: str) -> str: + return state + "b" + + +@workflow.defn +class TwoNodesWorkflow: + @workflow.run + async def run(self, input: str) -> Any: + return await graph("my-graph").compile().ainvoke(input) + + +async def test_two_nodes(client: Client): + g = StateGraph(str) + g.add_node( + "node_a", + node_a, + 
metadata={"start_to_close_timeout": timedelta(seconds=10)}, + ) + g.add_node( + "node_b", + node_b, + metadata={"start_to_close_timeout": timedelta(seconds=10)}, + ) + g.add_edge(START, "node_a") + g.add_edge("node_a", "node_b") + + task_queue = f"my-graph-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[TwoNodesWorkflow], + plugins=[LangGraphPlugin(graphs={"my-graph": g})], + ): + result = await client.execute_workflow( + TwoNodesWorkflow.run, + "", + id=f"test-workflow-{uuid4()}", + task_queue=task_queue, + ) + + assert result == "ab" diff --git a/uv.lock b/uv.lock index 6d824cf92..053eae852 100644 --- a/uv.lock +++ b/uv.lock @@ -2471,6 +2471,81 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/db/e655086b7f3a705df045bf0933bdd9c2f79bb3c97bfef1384598bb79a217/keyring-25.7.0-py3-none-any.whl", hash = "sha256:be4a0b195f149690c166e850609a477c532ddbfbaed96a404d4e43f8d5e2689f", size = 39160, upload-time = "2025-11-16T16:26:08.402Z" }, ] +[[package]] +name = "langchain-core" +version = "1.2.29" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonpatch" }, + { name = "langsmith" }, + { name = "packaging" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "uuid-utils" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a0/d8/7bdf30e4bfc5175609201806e399506a0a78a48e14367dc8b776a9b4c89c/langchain_core-1.2.29.tar.gz", hash = "sha256:cfb89c92bca81ad083eafcdfe6ec40f9803c9abf7dd166d0f8a8de1d2de03ca6", size = 846121, upload-time = "2026-04-14T20:44:58.117Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/37/fed31f80436b1d7bb222f1f2345300a77a88215416acf8d1cb7c8fda7388/langchain_core-1.2.29-py3-none-any.whl", hash = "sha256:11f02e57ee1c24e6e0e6577acbd35df77b205d4692a3df956b03b5389cbe44a0", size = 508733, upload-time = "2026-04-14T20:44:56.712Z" }, +] + +[[package]] +name = "langgraph" +version = 
"1.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "langgraph-checkpoint" }, + { name = "langgraph-prebuilt" }, + { name = "langgraph-sdk" }, + { name = "pydantic" }, + { name = "xxhash" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/e5/d3f72ead3c7f15769d5a9c07e373628f1fbaf6cbe7735694d7085859acf6/langgraph-1.1.6.tar.gz", hash = "sha256:1783f764b08a607e9f288dbcf6da61caeb0dd40b337e5c9fb8b412341fbc0b60", size = 549634, upload-time = "2026-04-03T19:01:32.561Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/e6/b36ecdb3ff4ba9a290708d514bae89ebbe2f554b6abbe4642acf3fddbe51/langgraph-1.1.6-py3-none-any.whl", hash = "sha256:fdbf5f54fa5a5a4c4b09b7b5e537f1b2fa283d2f0f610d3457ddeecb479458b9", size = 169755, upload-time = "2026-04-03T19:01:30.686Z" }, +] + +[[package]] +name = "langgraph-checkpoint" +version = "4.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "ormsgpack" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/44/a8df45d1e8b4637e29789fa8bae1db022c953cc7ac80093cfc52e923547e/langgraph_checkpoint-4.0.1.tar.gz", hash = "sha256:b433123735df11ade28829e40ce25b9be614930cd50245ff2af60629234befd9", size = 158135, upload-time = "2026-02-27T21:06:16.092Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/4c/09a4a0c42f5d2fc38d6c4d67884788eff7fd2cfdf367fdf7033de908b4c0/langgraph_checkpoint-4.0.1-py3-none-any.whl", hash = "sha256:e3adcd7a0e0166f3b48b8cf508ce0ea366e7420b5a73aa81289888727769b034", size = 50453, upload-time = "2026-02-27T21:06:14.293Z" }, +] + +[[package]] +name = "langgraph-prebuilt" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "langgraph-checkpoint" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/99/4c/06dac899f4945bedb0c3a1583c19484c2cc894114ea30d9a538dd270086e/langgraph_prebuilt-1.0.9.tar.gz", hash = "sha256:93de7512e9caade4b77ead92428f6215c521fdb71b8ffda8cd55f0ad814e64de", size = 165850, upload-time = "2026-04-03T14:06:37.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/a2/8368ac187b75e7f9d938ca075d34f116683f5cfc48d924029ee79aea147b/langgraph_prebuilt-1.0.9-py3-none-any.whl", hash = "sha256:776c8e3154a5aef5ad0e5bf3f263f2dcaab3983786cc20014b7f955d99d2d1b2", size = 35958, upload-time = "2026-04-03T14:06:36.58Z" }, +] + +[[package]] +name = "langgraph-sdk" +version = "0.3.13" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "orjson" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/db/77a45127dddcfea5e4256ba916182903e4c31dc4cfca305b8c386f0a9e53/langgraph_sdk-0.3.13.tar.gz", hash = "sha256:419ca5663eec3cec192ad194ac0647c0c826866b446073eb40f384f950986cd5", size = 196360, upload-time = "2026-04-07T20:34:18.766Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ef/64d64e9f8eea47ce7b939aa6da6863b674c8d418647813c20111645fcc62/langgraph_sdk-0.3.13-py3-none-any.whl", hash = "sha256:aee09e345c90775f6de9d6f4c7b847cfc652e49055c27a2aed0d981af2af3bd0", size = 96668, upload-time = "2026-04-07T20:34:17.866Z" }, +] + [[package]] name = "langsmith" version = "0.7.26" @@ -3678,6 +3753,62 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c0/d1/facb5b5051fabb0ef9d26c6544d87ef19a939a9a001198655d0d891062dd/orjson-3.11.8-cp314-cp314-win_arm64.whl", hash = "sha256:6ccdea2c213cf9f3d9490cbd5d427693c870753df41e6cb375bd79bcbafc8817", size = 127330, upload-time = "2026-03-31T16:16:25.496Z" }, ] +[[package]] +name = "ormsgpack" +version = "1.12.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/12/0c/f1761e21486942ab9bb6feaebc610fa074f7c5e496e6962dea5873348077/ormsgpack-1.12.2.tar.gz", hash = "sha256:944a2233640273bee67521795a73cf1e959538e0dfb7ac635505010455e53b33", size = 39031, upload-time = "2026-01-18T20:55:28.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/fa/a91f70829ebccf6387c4946e0a1a109f6ba0d6a28d65f628bedfad94b890/ormsgpack-1.12.2-cp310-cp310-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:c1429217f8f4d7fcb053523bbbac6bed5e981af0b85ba616e6df7cce53c19657", size = 378262, upload-time = "2026-01-18T20:55:22.284Z" }, + { url = "https://files.pythonhosted.org/packages/5f/62/3698a9a0c487252b5c6a91926e5654e79e665708ea61f67a8bdeceb022bf/ormsgpack-1.12.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f13034dc6c84a6280c6c33db7ac420253852ea233fc3ee27c8875f8dd651163", size = 203034, upload-time = "2026-01-18T20:55:53.324Z" }, + { url = "https://files.pythonhosted.org/packages/66/3a/f716f64edc4aec2744e817660b317e2f9bb8de372338a95a96198efa1ac1/ormsgpack-1.12.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:59f5da97000c12bc2d50e988bdc8576b21f6ab4e608489879d35b2c07a8ab51a", size = 210538, upload-time = "2026-01-18T20:55:20.097Z" }, + { url = "https://files.pythonhosted.org/packages/72/30/a436be9ce27d693d4e19fa94900028067133779f09fc45776db3f689c822/ormsgpack-1.12.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e4459c3f27066beadb2b81ea48a076a417aafffff7df1d3c11c519190ed44f2", size = 212401, upload-time = "2026-01-18T20:55:46.447Z" }, + { url = "https://files.pythonhosted.org/packages/10/c5/cde98300fd33fee84ca71de4751b19aeeca675f0cf3c0ec4b043f40f3b76/ormsgpack-1.12.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7a1c460655d7288407ffa09065e322a7231997c0d62ce914bf3a96ad2dc6dedd", size = 387080, upload-time = "2026-01-18T20:56:00.884Z" }, + { url = 
"https://files.pythonhosted.org/packages/6a/31/30bf445ef827546747c10889dd254b3d84f92b591300efe4979d792f4c41/ormsgpack-1.12.2-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:458e4568be13d311ef7d8877275e7ccbe06c0e01b39baaac874caaa0f46d826c", size = 482346, upload-time = "2026-01-18T20:55:39.831Z" }, + { url = "https://files.pythonhosted.org/packages/2e/f5/e1745ddf4fa246c921b5ca253636c4c700ff768d78032f79171289159f6e/ormsgpack-1.12.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8cde5eaa6c6cbc8622db71e4a23de56828e3d876aeb6460ffbcb5b8aff91093b", size = 425178, upload-time = "2026-01-18T20:55:27.106Z" }, + { url = "https://files.pythonhosted.org/packages/8d/a2/e6532ed7716aed03dede8df2d0d0d4150710c2122647d94b474147ccd891/ormsgpack-1.12.2-cp310-cp310-win_amd64.whl", hash = "sha256:dc7a33be14c347893edbb1ceda89afbf14c467d593a5ee92c11de4f1666b4d4f", size = 117183, upload-time = "2026-01-18T20:55:55.52Z" }, + { url = "https://files.pythonhosted.org/packages/4b/08/8b68f24b18e69d92238aa8f258218e6dfeacf4381d9d07ab8df303f524a9/ormsgpack-1.12.2-cp311-cp311-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:bd5f4bf04c37888e864f08e740c5a573c4017f6fd6e99fa944c5c935fabf2dd9", size = 378266, upload-time = "2026-01-18T20:55:59.876Z" }, + { url = "https://files.pythonhosted.org/packages/0d/24/29fc13044ecb7c153523ae0a1972269fcd613650d1fa1a9cec1044c6b666/ormsgpack-1.12.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34d5b28b3570e9fed9a5a76528fc7230c3c76333bc214798958e58e9b79cc18a", size = 203035, upload-time = "2026-01-18T20:55:30.59Z" }, + { url = "https://files.pythonhosted.org/packages/ad/c2/00169fb25dd8f9213f5e8a549dfb73e4d592009ebc85fbbcd3e1dcac575b/ormsgpack-1.12.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3708693412c28f3538fb5a65da93787b6bbab3484f6bc6e935bfb77a62400ae5", size = 210539, upload-time = "2026-01-18T20:55:48.569Z" }, + { url = 
"https://files.pythonhosted.org/packages/1b/33/543627f323ff3c73091f51d6a20db28a1a33531af30873ea90c5ac95a9b5/ormsgpack-1.12.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43013a3f3e2e902e1d05e72c0f1aeb5bedbb8e09240b51e26792a3c89267e181", size = 212401, upload-time = "2026-01-18T20:56:10.101Z" }, + { url = "https://files.pythonhosted.org/packages/e8/5d/f70e2c3da414f46186659d24745483757bcc9adccb481a6eb93e2b729301/ormsgpack-1.12.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7c8b1667a72cbba74f0ae7ecf3105a5e01304620ed14528b2cb4320679d2869b", size = 387082, upload-time = "2026-01-18T20:56:12.047Z" }, + { url = "https://files.pythonhosted.org/packages/c0/d6/06e8dc920c7903e051f30934d874d4afccc9bb1c09dcaf0bc03a7de4b343/ormsgpack-1.12.2-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:df6961442140193e517303d0b5d7bc2e20e69a879c2d774316125350c4a76b92", size = 482346, upload-time = "2026-01-18T20:56:05.152Z" }, + { url = "https://files.pythonhosted.org/packages/66/c4/f337ac0905eed9c393ef990c54565cd33644918e0a8031fe48c098c71dbf/ormsgpack-1.12.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c6a4c34ddef109647c769d69be65fa1de7a6022b02ad45546a69b3216573eb4a", size = 425181, upload-time = "2026-01-18T20:55:37.83Z" }, + { url = "https://files.pythonhosted.org/packages/78/29/6d5758fabef3babdf4bbbc453738cc7de9cd3334e4c38dd5737e27b85653/ormsgpack-1.12.2-cp311-cp311-win_amd64.whl", hash = "sha256:73670ed0375ecc303858e3613f407628dd1fca18fe6ac57b7b7ce66cc7bb006c", size = 117182, upload-time = "2026-01-18T20:55:31.472Z" }, + { url = "https://files.pythonhosted.org/packages/c4/57/17a15549233c37e7fd054c48fe9207492e06b026dbd872b826a0b5f833b6/ormsgpack-1.12.2-cp311-cp311-win_arm64.whl", hash = "sha256:c2be829954434e33601ae5da328cccce3266b098927ca7a30246a0baec2ce7bd", size = 111464, upload-time = "2026-01-18T20:55:38.811Z" }, + { url = 
"https://files.pythonhosted.org/packages/4c/36/16c4b1921c308a92cef3bf6663226ae283395aa0ff6e154f925c32e91ff5/ormsgpack-1.12.2-cp312-cp312-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7a29d09b64b9694b588ff2f80e9826bdceb3a2b91523c5beae1fab27d5c940e7", size = 378618, upload-time = "2026-01-18T20:55:50.835Z" }, + { url = "https://files.pythonhosted.org/packages/c0/68/468de634079615abf66ed13bb5c34ff71da237213f29294363beeeca5306/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b39e629fd2e1c5b2f46f99778450b59454d1f901bc507963168985e79f09c5d", size = 203186, upload-time = "2026-01-18T20:56:11.163Z" }, + { url = "https://files.pythonhosted.org/packages/73/a9/d756e01961442688b7939bacd87ce13bfad7d26ce24f910f6028178b2cc8/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:958dcb270d30a7cb633a45ee62b9444433fa571a752d2ca484efdac07480876e", size = 210738, upload-time = "2026-01-18T20:56:09.181Z" }, + { url = "https://files.pythonhosted.org/packages/7b/ba/795b1036888542c9113269a3f5690ab53dd2258c6fb17676ac4bd44fcf94/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58d379d72b6c5e964851c77cfedfb386e474adee4fd39791c2c5d9efb53505cc", size = 212569, upload-time = "2026-01-18T20:56:06.135Z" }, + { url = "https://files.pythonhosted.org/packages/6c/aa/bff73c57497b9e0cba8837c7e4bcab584b1a6dbc91a5dd5526784a5030c8/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8463a3fc5f09832e67bdb0e2fda6d518dc4281b133166146a67f54c08496442e", size = 387166, upload-time = "2026-01-18T20:55:36.738Z" }, + { url = "https://files.pythonhosted.org/packages/d3/cf/f8283cba44bcb7b14f97b6274d449db276b3a86589bdb363169b51bc12de/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:eddffb77eff0bad4e67547d67a130604e7e2dfbb7b0cde0796045be4090f35c6", size = 482498, upload-time = "2026-01-18T20:55:29.626Z" }, + { url = 
"https://files.pythonhosted.org/packages/05/be/71e37b852d723dfcbe952ad04178c030df60d6b78eba26bfd14c9a40575e/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fcd55e5f6ba0dbce624942adf9f152062135f991a0126064889f68eb850de0dd", size = 425518, upload-time = "2026-01-18T20:55:49.556Z" }, + { url = "https://files.pythonhosted.org/packages/7a/0c/9803aa883d18c7ef197213cd2cbf73ba76472a11fe100fb7dab2884edf48/ormsgpack-1.12.2-cp312-cp312-win_amd64.whl", hash = "sha256:d024b40828f1dde5654faebd0d824f9cc29ad46891f626272dd5bfd7af2333a4", size = 117462, upload-time = "2026-01-18T20:55:47.726Z" }, + { url = "https://files.pythonhosted.org/packages/c8/9e/029e898298b2cc662f10d7a15652a53e3b525b1e7f07e21fef8536a09bb8/ormsgpack-1.12.2-cp312-cp312-win_arm64.whl", hash = "sha256:da538c542bac7d1c8f3f2a937863dba36f013108ce63e55745941dda4b75dbb6", size = 111559, upload-time = "2026-01-18T20:55:54.273Z" }, + { url = "https://files.pythonhosted.org/packages/eb/29/bb0eba3288c0449efbb013e9c6f58aea79cf5cb9ee1921f8865f04c1a9d7/ormsgpack-1.12.2-cp313-cp313-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5ea60cb5f210b1cfbad8c002948d73447508e629ec375acb82910e3efa8ff355", size = 378661, upload-time = "2026-01-18T20:55:57.765Z" }, + { url = "https://files.pythonhosted.org/packages/6e/31/5efa31346affdac489acade2926989e019e8ca98129658a183e3add7af5e/ormsgpack-1.12.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3601f19afdbea273ed70b06495e5794606a8b690a568d6c996a90d7255e51c1", size = 203194, upload-time = "2026-01-18T20:56:08.252Z" }, + { url = "https://files.pythonhosted.org/packages/eb/56/d0087278beef833187e0167f8527235ebe6f6ffc2a143e9de12a98b1ce87/ormsgpack-1.12.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:29a9f17a3dac6054c0dce7925e0f4995c727f7c41859adf9b5572180f640d172", size = 210778, upload-time = "2026-01-18T20:55:17.694Z" }, + { url = 
"https://files.pythonhosted.org/packages/1c/a2/072343e1413d9443e5a252a8eb591c2d5b1bffbe5e7bfc78c069361b92eb/ormsgpack-1.12.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39c1bd2092880e413902910388be8715f70b9f15f20779d44e673033a6146f2d", size = 212592, upload-time = "2026-01-18T20:55:32.747Z" }, + { url = "https://files.pythonhosted.org/packages/a2/8b/a0da3b98a91d41187a63b02dda14267eefc2a74fcb43cc2701066cf1510e/ormsgpack-1.12.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:50b7249244382209877deedeee838aef1542f3d0fc28b8fe71ca9d7e1896a0d7", size = 387164, upload-time = "2026-01-18T20:55:40.853Z" }, + { url = "https://files.pythonhosted.org/packages/19/bb/6d226bc4cf9fc20d8eb1d976d027a3f7c3491e8f08289a2e76abe96a65f3/ormsgpack-1.12.2-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:5af04800d844451cf102a59c74a841324868d3f1625c296a06cc655c542a6685", size = 482516, upload-time = "2026-01-18T20:55:42.033Z" }, + { url = "https://files.pythonhosted.org/packages/fb/f1/bb2c7223398543dedb3dbf8bb93aaa737b387de61c5feaad6f908841b782/ormsgpack-1.12.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:cec70477d4371cd524534cd16472d8b9cc187e0e3043a8790545a9a9b296c258", size = 425539, upload-time = "2026-01-18T20:55:24.727Z" }, + { url = "https://files.pythonhosted.org/packages/7b/e8/0fb45f57a2ada1fed374f7494c8cd55e2f88ccd0ab0a669aa3468716bf5f/ormsgpack-1.12.2-cp313-cp313-win_amd64.whl", hash = "sha256:21f4276caca5c03a818041d637e4019bc84f9d6ca8baa5ea03e5cc8bf56140e9", size = 117459, upload-time = "2026-01-18T20:55:56.876Z" }, + { url = "https://files.pythonhosted.org/packages/7a/d4/0cfeea1e960d550a131001a7f38a5132c7ae3ebde4c82af1f364ccc5d904/ormsgpack-1.12.2-cp313-cp313-win_arm64.whl", hash = "sha256:baca4b6773d20a82e36d6fd25f341064244f9f86a13dead95dd7d7f996f51709", size = 111577, upload-time = "2026-01-18T20:55:43.605Z" }, + { url = 
"https://files.pythonhosted.org/packages/94/16/24d18851334be09c25e87f74307c84950f18c324a4d3c0b41dabdbf19c29/ormsgpack-1.12.2-cp314-cp314-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:bc68dd5915f4acf66ff2010ee47c8906dc1cf07399b16f4089f8c71733f6e36c", size = 378717, upload-time = "2026-01-18T20:55:26.164Z" }, + { url = "https://files.pythonhosted.org/packages/b5/a2/88b9b56f83adae8032ac6a6fa7f080c65b3baf9b6b64fd3d37bd202991d4/ormsgpack-1.12.2-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46d084427b4132553940070ad95107266656cb646ea9da4975f85cb1a6676553", size = 203183, upload-time = "2026-01-18T20:55:18.815Z" }, + { url = "https://files.pythonhosted.org/packages/a9/80/43e4555963bf602e5bdc79cbc8debd8b6d5456c00d2504df9775e74b450b/ormsgpack-1.12.2-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c010da16235806cf1d7bc4c96bf286bfa91c686853395a299b3ddb49499a3e13", size = 210814, upload-time = "2026-01-18T20:55:33.973Z" }, + { url = "https://files.pythonhosted.org/packages/78/e1/7cfbf28de8bca6efe7e525b329c31277d1b64ce08dcba723971c241a9d60/ormsgpack-1.12.2-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18867233df592c997154ff942a6503df274b5ac1765215bceba7a231bea2745d", size = 212634, upload-time = "2026-01-18T20:55:28.634Z" }, + { url = "https://files.pythonhosted.org/packages/95/f8/30ae5716e88d792a4e879debee195653c26ddd3964c968594ddef0a3cc7e/ormsgpack-1.12.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b009049086ddc6b8f80c76b3955df1aa22a5fbd7673c525cd63bf91f23122ede", size = 387139, upload-time = "2026-01-18T20:56:02.013Z" }, + { url = "https://files.pythonhosted.org/packages/dc/81/aee5b18a3e3a0e52f718b37ab4b8af6fae0d9d6a65103036a90c2a8ffb5d/ormsgpack-1.12.2-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:1dcc17d92b6390d4f18f937cf0b99054824a7815818012ddca925d6e01c2e49e", size = 482578, upload-time = "2026-01-18T20:55:35.117Z" }, + { url = 
"https://files.pythonhosted.org/packages/bd/17/71c9ba472d5d45f7546317f467a5fc941929cd68fb32796ca3d13dcbaec2/ormsgpack-1.12.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f04b5e896d510b07c0ad733d7fce2d44b260c5e6c402d272128f8941984e4285", size = 425539, upload-time = "2026-01-18T20:56:04.009Z" }, + { url = "https://files.pythonhosted.org/packages/2e/a6/ac99cd7fe77e822fed5250ff4b86fa66dd4238937dd178d2299f10b69816/ormsgpack-1.12.2-cp314-cp314-win_amd64.whl", hash = "sha256:ae3aba7eed4ca7cb79fd3436eddd29140f17ea254b91604aa1eb19bfcedb990f", size = 117493, upload-time = "2026-01-18T20:56:07.343Z" }, + { url = "https://files.pythonhosted.org/packages/3a/67/339872846a1ae4592535385a1c1f93614138566d7af094200c9c3b45d1e5/ormsgpack-1.12.2-cp314-cp314-win_arm64.whl", hash = "sha256:118576ea6006893aea811b17429bfc561b4778fad393f5f538c84af70b01260c", size = 111579, upload-time = "2026-01-18T20:55:21.161Z" }, + { url = "https://files.pythonhosted.org/packages/49/c2/6feb972dc87285ad381749d3882d8aecbde9f6ecf908dd717d33d66df095/ormsgpack-1.12.2-cp314-cp314t-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7121b3d355d3858781dc40dafe25a32ff8a8242b9d80c692fd548a4b1f7fd3c8", size = 378721, upload-time = "2026-01-18T20:55:52.12Z" }, + { url = "https://files.pythonhosted.org/packages/a3/9a/900a6b9b413e0f8a471cf07830f9cf65939af039a362204b36bd5b581d8b/ormsgpack-1.12.2-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ee766d2e78251b7a63daf1cddfac36a73562d3ddef68cacfb41b2af64698033", size = 203170, upload-time = "2026-01-18T20:55:44.469Z" }, + { url = "https://files.pythonhosted.org/packages/87/4c/27a95466354606b256f24fad464d7c97ab62bce6cc529dd4673e1179b8fb/ormsgpack-1.12.2-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:292410a7d23de9b40444636b9b8f1e4e4b814af7f1ef476e44887e52a123f09d", size = 212816, upload-time = "2026-01-18T20:55:23.501Z" }, + { url = 
"https://files.pythonhosted.org/packages/73/cd/29cee6007bddf7a834e6cd6f536754c0535fcb939d384f0f37a38b1cddb8/ormsgpack-1.12.2-cp314-cp314t-win_amd64.whl", hash = "sha256:837dd316584485b72ef451d08dd3e96c4a11d12e4963aedb40e08f89685d8ec2", size = 117232, upload-time = "2026-01-18T20:55:45.448Z" }, +] + [[package]] name = "packaging" version = "26.0" @@ -5035,6 +5166,9 @@ lambda-worker-otel = [ { name = "opentelemetry-sdk-extension-aws" }, { name = "opentelemetry-semantic-conventions" }, ] +langgraph = [ + { name = "langgraph" }, +] langsmith = [ { name = "langsmith" }, ] @@ -5057,6 +5191,7 @@ dev = [ { name = "googleapis-common-protos" }, { name = "grpcio-tools" }, { name = "httpx" }, + { name = "langgraph" }, { name = "langsmith" }, { name = "litellm" }, { name = "maturin" }, @@ -5092,6 +5227,7 @@ requires-dist = [ { name = "aioboto3", marker = "extra == 'aioboto3'", specifier = ">=10.4.0" }, { name = "google-adk", marker = "extra == 'google-adk'", specifier = ">=1.27.0,<2" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, + { name = "langgraph", marker = "extra == 'langgraph'", specifier = ">=1.1.6" }, { name = "langsmith", marker = "extra == 'langsmith'", specifier = ">=0.7.0,<0.8" }, { name = "mcp", marker = "extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, { name = "nexus-rpc", specifier = "==1.4.0" }, @@ -5110,7 +5246,7 @@ requires-dist = [ { name = "types-protobuf", specifier = ">=3.20,<7.0.0" }, { name = "typing-extensions", specifier = ">=4.2.0,<5" }, ] -provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "langsmith", "lambda-worker-otel", "aioboto3"] +provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "langgraph", "langsmith", "lambda-worker-otel", "aioboto3"] [package.metadata.requires-dev] dev = [ @@ -5119,6 +5255,7 @@ dev = [ { name = "googleapis-common-protos", specifier = "==1.70.0" }, { name = "grpcio-tools", specifier = ">=1.48.2,<2" }, { name = 
"httpx", specifier = ">=0.28.1" }, + { name = "langgraph", specifier = ">=1.1.6" }, { name = "langsmith", specifier = ">=0.7.0,<0.8" }, { name = "litellm", specifier = ">=1.83.0" }, { name = "maturin", specifier = ">=1.8.2" }, From 8f30e56c3a9a8cf8623b5114c61ee07b33a96c5d Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 14 Apr 2026 14:54:39 -0700 Subject: [PATCH 02/47] add experimental package warnings --- temporalio/contrib/langgraph/README.md | 6 ++++-- temporalio/contrib/langgraph/__init__.py | 14 +++++++++---- .../contrib/langgraph/langgraph_plugin.py | 20 +++++++++++++++++-- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index 63911c570..f437fe476 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -1,6 +1,8 @@ -# LangGraph Temporal Plugin +# LangGraph Plugin for Temporal Python SDK -A [Temporal](https://temporal.io) plugin that runs [LangGraph](https://www.langchain.com/langgraph) nodes and tasks as Temporal Activities, giving your AI workflows durable execution, automatic retries, and timeouts. +⚠️ **This package is currently at an experimental release stage.** ⚠️ + +This Temporal [Plugin](https://docs.temporal.io/develop/plugins-guide) allows you to run [LangGraph](https://www.langchain.com/langgraph) nodes and tasks as Temporal Activities, giving your AI workflows durable execution, automatic retries, and timeouts. It supports both the LangGraph Graph API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). ## Installation diff --git a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py index df32ca5ee..50d8ca147 100644 --- a/temporalio/contrib/langgraph/__init__.py +++ b/temporalio/contrib/langgraph/__init__.py @@ -1,13 +1,19 @@ -"""Support for using LangGraph as part of Temporal workflows. +"""LangGraph plugin for Temporal SDK. 
-This module provides compatibility between -`LangGraph `_ and Temporal workflows. +.. warning:: + This package is experimental and may change in future versions. + Use with caution in production environments. + +This plugin runs `LangGraph `_ nodes +and tasks as Temporal Activities, giving your AI agent workflows durable +execution, automatic retries, and timeouts. It supports both the LangGraph Graph +API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). """ from temporalio.contrib.langgraph.langgraph_plugin import ( LangGraphPlugin, - entrypoint, cache, + entrypoint, graph, ) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index f392845b4..8610631b2 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -1,13 +1,17 @@ from dataclasses import replace from typing import Any, Callable -from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity from langgraph._internal._runnable import RunnableCallable from langgraph.graph import StateGraph from langgraph.pregel import Pregel -from temporalio.contrib.langgraph.task_cache import _get_task_cache, _set_task_cache, _task_id from temporalio import activity +from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity +from temporalio.contrib.langgraph.task_cache import ( + _get_task_cache, + _set_task_cache, + _task_id, +) from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner @@ -18,6 +22,18 @@ class LangGraphPlugin(SimplePlugin): + """LangGraph plugin for Temporal SDK. + + .. warning:: + This package is experimental and may change in future versions. + Use with caution in production environments. 
+ + This plugin runs `LangGraph `_ nodes + and tasks as Temporal Activities, giving your AI agent workflows durable + execution, automatic retries, and timeouts. It supports both the LangGraph Graph + API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). + """ + def __init__( self, # Graph API From d1470c7914b437a2f6b515cfc1266c00cc644055 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 14 Apr 2026 15:09:40 -0700 Subject: [PATCH 03/47] fix ruff lint --- temporalio/contrib/langgraph/activity.py | 13 ++++++--- tests/contrib/langgraph/conftest.py | 1 + .../langgraph/e2e_functional_workflows.py | 3 +-- .../contrib/langgraph/test_continue_as_new.py | 12 +++------ .../langgraph/test_continue_as_new_cached.py | 24 ++++++++--------- .../contrib/langgraph/test_e2e_functional.py | 27 +++++++++---------- .../langgraph/test_e2e_functional_v2.py | 12 ++++----- .../langgraph/test_execute_in_workflow.py | 4 +-- tests/contrib/langgraph/test_interrupt.py | 12 +++------ tests/contrib/langgraph/test_interrupt_v2.py | 12 +++------ tests/contrib/langgraph/test_streaming.py | 4 +-- .../langgraph/test_subgraph_activity.py | 4 +-- .../langgraph/test_subgraph_workflow.py | 4 +-- tests/contrib/langgraph/test_timeout.py | 4 +-- tests/contrib/langgraph/test_two_nodes.py | 4 +-- 15 files changed, 66 insertions(+), 74 deletions(-) diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py index b30fdd4a1..6ba828bf0 100644 --- a/temporalio/contrib/langgraph/activity.py +++ b/temporalio/contrib/langgraph/activity.py @@ -4,9 +4,12 @@ from langgraph.errors import GraphInterrupt from langgraph.types import Interrupt -from temporalio import workflow -from temporalio.contrib.langgraph.langgraph_config import get_langgraph_config, set_langgraph_config +from temporalio import workflow +from temporalio.contrib.langgraph.langgraph_config import ( + get_langgraph_config, + set_langgraph_config, +) @dataclass @@ -43,7 +46,11 @@ def 
wrap_execute_activity( **execute_activity_kwargs: dict[str, Any], ) -> Callable: async def wrapper(*args: Any, **kwargs: dict[str, Any]) -> Any: - from temporalio.contrib.langgraph.task_cache import _cache_key, _cache_lookup, _cache_put + from temporalio.contrib.langgraph.task_cache import ( + _cache_key, + _cache_lookup, + _cache_put, + ) # Check task result cache (for continue-as-new deduplication). key = _cache_key(task_id, args, kwargs) if task_id else "" diff --git a/tests/contrib/langgraph/conftest.py b/tests/contrib/langgraph/conftest.py index f56b0ad96..ef4469b0b 100644 --- a/tests/contrib/langgraph/conftest.py +++ b/tests/contrib/langgraph/conftest.py @@ -3,6 +3,7 @@ from pytest import fixture from pytest_asyncio import fixture as async_fixture + from temporalio.client import Client from temporalio.testing import WorkflowEnvironment diff --git a/tests/contrib/langgraph/e2e_functional_workflows.py b/tests/contrib/langgraph/e2e_functional_workflows.py index 7942e11fe..e44effcfd 100644 --- a/tests/contrib/langgraph/e2e_functional_workflows.py +++ b/tests/contrib/langgraph/e2e_functional_workflows.py @@ -6,8 +6,7 @@ from typing import Any from temporalio import workflow - -from temporalio.contrib.langgraph.langgraph_plugin import entrypoint, cache +from temporalio.contrib.langgraph.langgraph_plugin import cache, entrypoint @workflow.defn diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index 72f12ae96..51018b48f 100644 --- a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -5,11 +5,11 @@ from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import START, StateGraph from langgraph.graph.state import RunnableConfig + from temporalio import workflow from temporalio.client import Client -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from 
temporalio.worker import Worker async def node(state: str) -> str: @@ -20,12 +20,8 @@ async def node(state: str) -> str: class ContinueAsNewWorkflow: @workflow.run async def run(self, values: str) -> Any: - g = graph("my-graph").compile( - checkpointer=InMemorySaver() - ) - config = RunnableConfig( - {"configurable": {"thread_id": "1"}} - ) + g = graph("my-graph").compile(checkpointer=InMemorySaver()) + config = RunnableConfig({"configurable": {"thread_id": "1"}}) await g.aupdate_state(config, values) await g.ainvoke(values, config) diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index 048626c0a..304a3d87b 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -4,17 +4,17 @@ so nodes don't re-execute when the graph is re-invoked with the same state. """ +from dataclasses import dataclass from datetime import timedelta from typing import Any from uuid import uuid4 -from dataclasses import dataclass from langgraph.graph import START, StateGraph + from temporalio import workflow from temporalio.client import Client -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, cache, graph +from temporalio.worker import Worker # Track execution counts to verify caching _execution_counts: dict[str, int] = {} @@ -112,12 +112,12 @@ async def test_graph_continue_as_new_cached(client: Client): assert result == 260 # Each node should execute exactly once — phases 2 and 3 use cached results. 
- assert _execution_counts.get("multiply", 0) == 1, ( - f"multiply executed {_execution_counts.get('multiply', 0)} times, expected 1" - ) - assert _execution_counts.get("add", 0) == 1, ( - f"add executed {_execution_counts.get('add', 0)} times, expected 1" - ) - assert _execution_counts.get("double", 0) == 1, ( - f"double executed {_execution_counts.get('double', 0)} times, expected 1" - ) + assert ( + _execution_counts.get("multiply", 0) == 1 + ), f"multiply executed {_execution_counts.get('multiply', 0)} times, expected 1" + assert ( + _execution_counts.get("add", 0) == 1 + ), f"add executed {_execution_counts.get('add', 0)} times, expected 1" + assert ( + _execution_counts.get("double", 0) == 1 + ), f"double executed {_execution_counts.get('double', 0)} times, expected 1" diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index 09f85e90b..06d320f47 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ b/tests/contrib/langgraph/test_e2e_functional.py @@ -9,9 +9,8 @@ from uuid import uuid4 from temporalio.client import Client -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.worker import Worker from tests.contrib.langgraph.e2e_functional_entrypoints import ( add_ten, continue_as_new_entrypoint, @@ -108,15 +107,15 @@ async def test_continue_as_new_with_checkpoint(self, client: Client) -> None: assert result["result"] == 260 counts = get_task_execution_counts() - assert counts.get("task_a", 0) == 1, ( - f"task_a executed {counts.get('task_a', 0)} times, expected 1" - ) - assert counts.get("task_b", 0) == 1, ( - f"task_b executed {counts.get('task_b', 0)} times, expected 1" - ) - assert counts.get("task_c", 0) == 1, ( - f"task_c executed {counts.get('task_c', 0)} times, expected 1" - ) + assert ( + counts.get("task_a", 0) == 1 + ), f"task_a executed {counts.get('task_a', 0)} times, expected 1" + assert ( + 
counts.get("task_b", 0) == 1 + ), f"task_b executed {counts.get('task_b', 0)} times, expected 1" + assert ( + counts.get("task_c", 0) == 1 + ), f"task_c executed {counts.get('task_c', 0)} times, expected 1" class TestFunctionalAPIPartialExecution: @@ -152,6 +151,6 @@ async def test_partial_execution_five_tasks(self, client: Client) -> None: counts = get_task_execution_counts() for i in range(1, 6): - assert counts.get(f"step_{i}", 0) == 1, ( - f"step_{i} executed {counts.get(f'step_{i}', 0)} times, expected 1" - ) + assert ( + counts.get(f"step_{i}", 0) == 1 + ), f"step_{i} executed {counts.get(f'step_{i}', 0)} times, expected 1" diff --git a/tests/contrib/langgraph/test_e2e_functional_v2.py b/tests/contrib/langgraph/test_e2e_functional_v2.py index b9678bf74..579a5f59c 100644 --- a/tests/contrib/langgraph/test_e2e_functional_v2.py +++ b/tests/contrib/langgraph/test_e2e_functional_v2.py @@ -11,16 +11,16 @@ from typing import Any from uuid import uuid4 -from langgraph.checkpoint.memory import InMemorySaver -from langgraph.types import Command from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver from langgraph.func import entrypoint as lg_entrypoint from langgraph.func import task +from langgraph.types import Command + from temporalio import workflow from temporalio.client import Client -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, entrypoint +from temporalio.worker import Worker from tests.contrib.langgraph.e2e_functional_entrypoints import ( ask_human, interrupt_entrypoint, @@ -78,9 +78,7 @@ async def run(self, input_value: str) -> dict[str, Any]: assert result.interrupts[0].value == "Do you approve?" 
# Resume with approval - resumed = await app.ainvoke( - Command(resume="approved"), config, version="v2" - ) + resumed = await app.ainvoke(Command(resume="approved"), config, version="v2") return resumed.value diff --git a/tests/contrib/langgraph/test_execute_in_workflow.py b/tests/contrib/langgraph/test_execute_in_workflow.py index 6d2e38ad7..d64ad4bef 100644 --- a/tests/contrib/langgraph/test_execute_in_workflow.py +++ b/tests/contrib/langgraph/test_execute_in_workflow.py @@ -2,11 +2,11 @@ from uuid import uuid4 from langgraph.graph import START, StateGraph + from temporalio import workflow from temporalio.client import Client -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.worker import Worker async def node(_: str) -> str: diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index d953ab5f0..73442f243 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -6,11 +6,11 @@ from langgraph.graph import START, StateGraph from langgraph.graph.state import RunnableConfig from langgraph.types import Command, interrupt + from temporalio import workflow from temporalio.client import Client -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.worker import Worker async def node(_: str) -> str: @@ -21,12 +21,8 @@ async def node(_: str) -> str: class InterruptWorkflow: @workflow.run async def run(self, input: str) -> Any: - g = graph("my-graph").compile( - checkpointer=InMemorySaver() - ) - config = RunnableConfig( - {"configurable": {"thread_id": "1"}} - ) + g = graph("my-graph").compile(checkpointer=InMemorySaver()) + config = RunnableConfig({"configurable": {"thread_id": "1"}}) result = await g.ainvoke(input, config) assert result["__interrupt__"][0].value == "Continue?" 
diff --git a/tests/contrib/langgraph/test_interrupt_v2.py b/tests/contrib/langgraph/test_interrupt_v2.py index 1167971ca..88d13d767 100644 --- a/tests/contrib/langgraph/test_interrupt_v2.py +++ b/tests/contrib/langgraph/test_interrupt_v2.py @@ -12,11 +12,11 @@ from langgraph.graph import START, StateGraph from langgraph.graph.state import RunnableConfig from langgraph.types import Command, interrupt + from temporalio import workflow from temporalio.client import Client -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.worker import Worker async def node(_: str) -> str: @@ -27,12 +27,8 @@ async def node(_: str) -> str: class InterruptV2Workflow: @workflow.run async def run(self, input: str) -> Any: - g = graph("interrupt-v2-graph").compile( - checkpointer=InMemorySaver() - ) - config = RunnableConfig( - {"configurable": {"thread_id": "1"}} - ) + g = graph("interrupt-v2-graph").compile(checkpointer=InMemorySaver()) + config = RunnableConfig({"configurable": {"thread_id": "1"}}) result = await g.ainvoke(input, config, version="v2") diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index a8959aec5..0db41d2fa 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -3,11 +3,11 @@ from uuid import uuid4 from langgraph.graph import START, StateGraph + from temporalio import workflow from temporalio.client import Client -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.worker import Worker async def node_a(state: str) -> str: diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py index 8f603dc2b..24e250d07 100644 --- a/tests/contrib/langgraph/test_subgraph_activity.py +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -3,11 +3,11 @@ from 
uuid import uuid4 from langgraph.graph import START, StateGraph + from temporalio import workflow from temporalio.client import Client -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.worker import Worker async def child_node(_: str) -> str: diff --git a/tests/contrib/langgraph/test_subgraph_workflow.py b/tests/contrib/langgraph/test_subgraph_workflow.py index 5c9629d32..a8edfa541 100644 --- a/tests/contrib/langgraph/test_subgraph_workflow.py +++ b/tests/contrib/langgraph/test_subgraph_workflow.py @@ -3,11 +3,11 @@ from uuid import uuid4 from langgraph.graph import START, StateGraph + from temporalio import workflow from temporalio.client import Client -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.worker import Worker async def child_node(_: str) -> str: diff --git a/tests/contrib/langgraph/test_timeout.py b/tests/contrib/langgraph/test_timeout.py index b6b062e5f..d19307820 100644 --- a/tests/contrib/langgraph/test_timeout.py +++ b/tests/contrib/langgraph/test_timeout.py @@ -5,12 +5,12 @@ from langgraph.graph import START, StateGraph from pytest import raises + from temporalio import workflow from temporalio.client import Client, WorkflowFailureError from temporalio.common import RetryPolicy -from temporalio.worker import Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.worker import Worker async def node(_: str) -> str: diff --git a/tests/contrib/langgraph/test_two_nodes.py b/tests/contrib/langgraph/test_two_nodes.py index 06cc13071..47d7ff03d 100644 --- a/tests/contrib/langgraph/test_two_nodes.py +++ b/tests/contrib/langgraph/test_two_nodes.py @@ -3,11 +3,11 @@ from uuid import uuid4 from langgraph.graph import START, StateGraph + from temporalio import workflow from temporalio.client import Client -from temporalio.worker import 
Worker - from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.worker import Worker async def node_a(state: str) -> str: From 68a0eb42d3261316f7286055d2aa05d2c74d4fe3 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 14 Apr 2026 15:29:54 -0700 Subject: [PATCH 04/47] fix pyright lint errors --- temporalio/contrib/langgraph/activity.py | 17 +++++++------ .../contrib/langgraph/test_continue_as_new.py | 19 ++++++++------ .../langgraph/test_continue_as_new_cached.py | 25 +++++++++++-------- .../langgraph/test_execute_in_workflow.py | 15 +++++++---- tests/contrib/langgraph/test_interrupt.py | 15 +++++++---- tests/contrib/langgraph/test_interrupt_v2.py | 15 +++++++---- tests/contrib/langgraph/test_streaming.py | 19 ++++++++------ .../langgraph/test_subgraph_activity.py | 19 ++++++++------ .../langgraph/test_subgraph_workflow.py | 19 ++++++++------ tests/contrib/langgraph/test_timeout.py | 13 +++++++--- tests/contrib/langgraph/test_two_nodes.py | 19 ++++++++------ 11 files changed, 124 insertions(+), 71 deletions(-) diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py index 6ba828bf0..614940bc2 100644 --- a/temporalio/contrib/langgraph/activity.py +++ b/temporalio/contrib/langgraph/activity.py @@ -1,3 +1,4 @@ +from collections.abc import Awaitable from dataclasses import dataclass from inspect import iscoroutinefunction from typing import Any, Callable @@ -25,7 +26,9 @@ class ActivityOutput: langgraph_interrupts: tuple[Interrupt] | None = None -def wrap_activity(func: Callable) -> Callable: +def wrap_activity( + func: Callable, +) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]: async def wrapper(input: ActivityInput) -> ActivityOutput: set_langgraph_config(input.langgraph_config) try: @@ -41,11 +44,11 @@ async def wrapper(input: ActivityInput) -> ActivityOutput: def wrap_execute_activity( - afunc: Callable, + afunc: Callable[[ActivityInput], Awaitable[ActivityOutput]], 
task_id: str = "", - **execute_activity_kwargs: dict[str, Any], -) -> Callable: - async def wrapper(*args: Any, **kwargs: dict[str, Any]) -> Any: + **execute_activity_kwargs: Any, +) -> Callable[..., Any]: + async def wrapper(*args: Any, **kwargs: Any) -> Any: from temporalio.contrib.langgraph.task_cache import ( _cache_key, _cache_lookup, @@ -62,8 +65,8 @@ async def wrapper(*args: Any, **kwargs: dict[str, Any]) -> Any: input = ActivityInput( args=args, kwargs=kwargs, langgraph_config=get_langgraph_config() ) - output: ActivityOutput = await workflow.execute_activity( - afunc, input, result_type=ActivityOutput, **execute_activity_kwargs + output = await workflow.execute_activity( + afunc, input, **execute_activity_kwargs ) if output.langgraph_interrupts is not None: raise GraphInterrupt(output.langgraph_interrupts) diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index 51018b48f..8fab4cc91 100644 --- a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -5,6 +5,7 @@ from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import START, StateGraph from langgraph.graph.state import RunnableConfig +from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client @@ -12,21 +13,25 @@ from temporalio.worker import Worker -async def node(state: str) -> str: - return state + "a" +class State(TypedDict): + value: str + + +async def node(state: State) -> dict[str, str]: + return {"value": state["value"] + "a"} @workflow.defn class ContinueAsNewWorkflow: @workflow.run - async def run(self, values: str) -> Any: + async def run(self, values: dict[str, str]) -> Any: g = graph("my-graph").compile(checkpointer=InMemorySaver()) config = RunnableConfig({"configurable": {"thread_id": "1"}}) await g.aupdate_state(config, values) await g.ainvoke(values, config) - if len(values) < 3: + if len(values["value"]) < 
3: state = await g.aget_state(config) workflow.continue_as_new(state.values) @@ -34,7 +39,7 @@ async def run(self, values: str) -> Any: async def test_continue_as_new(client: Client): - g = StateGraph(str) + g = StateGraph(State) g.add_node( "node", node, @@ -52,9 +57,9 @@ async def test_continue_as_new(client: Client): ): result = await client.execute_workflow( ContinueAsNewWorkflow.run, - "", + {"value": ""}, id=f"test-workflow-{uuid4()}", task_queue=task_queue, ) - assert result == "aaa" + assert result == {"value": "aaa"} diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index 304a3d87b..40bdee7a9 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -10,6 +10,7 @@ from uuid import uuid4 from langgraph.graph import START, StateGraph +from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client @@ -24,19 +25,23 @@ def _reset(): _execution_counts.clear() -async def multiply_by_3(state: int) -> int: +class State(TypedDict): + value: int + + +async def multiply_by_3(state: State) -> dict[str, int]: _execution_counts["multiply"] = _execution_counts.get("multiply", 0) + 1 - return state * 3 + return {"value": state["value"] * 3} -async def add_100(state: int) -> int: +async def add_100(state: State) -> dict[str, int]: _execution_counts["add"] = _execution_counts.get("add", 0) + 1 - return state + 100 + return {"value": state["value"] + 100} -async def double(state: int) -> int: +async def double(state: State) -> dict[str, int]: _execution_counts["double"] = _execution_counts.get("double", 0) + 1 - return state * 2 + return {"value": state["value"] * 2} @dataclass @@ -59,9 +64,9 @@ class GraphContinueAsNewWorkflow: """ @workflow.run - async def run(self, input_data: GraphContinueAsNewInput) -> int: + async def run(self, input_data: GraphContinueAsNewInput) -> dict[str, 
int]: g = graph("cached-graph", cache=input_data.cache).compile() - result = await g.ainvoke(input_data.value) + result = await g.ainvoke({"value": input_data.value}) if input_data.phase < 3: workflow.continue_as_new( @@ -84,7 +89,7 @@ async def test_graph_continue_as_new_cached(client: Client): _reset() timeout = {"start_to_close_timeout": timedelta(seconds=10)} - g = StateGraph(int) + g = StateGraph(State) g.add_node("multiply_by_3", multiply_by_3, metadata=timeout) g.add_node("add_100", add_100, metadata=timeout) g.add_node("double", double, metadata=timeout) @@ -109,7 +114,7 @@ async def test_graph_continue_as_new_cached(client: Client): ) # 10 * 3 = 30 -> + 100 = 130 -> * 2 = 260 - assert result == 260 + assert result == {"value": 260} # Each node should execute exactly once — phases 2 and 3 use cached results. assert ( diff --git a/tests/contrib/langgraph/test_execute_in_workflow.py b/tests/contrib/langgraph/test_execute_in_workflow.py index d64ad4bef..42ea31dce 100644 --- a/tests/contrib/langgraph/test_execute_in_workflow.py +++ b/tests/contrib/langgraph/test_execute_in_workflow.py @@ -2,6 +2,7 @@ from uuid import uuid4 from langgraph.graph import START, StateGraph +from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client @@ -9,19 +10,23 @@ from temporalio.worker import Worker -async def node(_: str) -> str: - return "done" +class State(TypedDict): + value: str + + +async def node(state: State) -> dict[str, str]: + return {"value": "done"} @workflow.defn class ExecuteInWorkflowWorkflow: @workflow.run async def run(self, input: str) -> Any: - return await graph("my-graph").compile().ainvoke(input) + return await graph("my-graph").compile().ainvoke({"value": input}) async def test_execute_in_workflow(client: Client): - g = StateGraph(str) + g = StateGraph(State) g.add_node("node", node, metadata={"execute_in": "workflow"}) g.add_edge(START, "node") @@ -40,4 +45,4 @@ async def test_execute_in_workflow(client: 
Client): task_queue=task_queue, ) - assert result == "done" + assert result == {"value": "done"} diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index 73442f243..e82b958f3 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -6,6 +6,7 @@ from langgraph.graph import START, StateGraph from langgraph.graph.state import RunnableConfig from langgraph.types import Command, interrupt +from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client @@ -13,8 +14,12 @@ from temporalio.worker import Worker -async def node(_: str) -> str: - return interrupt("Continue?") +class State(TypedDict): + value: str + + +async def node(state: State) -> dict[str, str]: + return {"value": interrupt("Continue?")} @workflow.defn @@ -24,14 +29,14 @@ async def run(self, input: str) -> Any: g = graph("my-graph").compile(checkpointer=InMemorySaver()) config = RunnableConfig({"configurable": {"thread_id": "1"}}) - result = await g.ainvoke(input, config) + result = await g.ainvoke({"value": input}, config) assert result["__interrupt__"][0].value == "Continue?" 
return await g.ainvoke(Command(resume="yes"), config) async def test_interrupt(client: Client): - g = StateGraph(str) + g = StateGraph(State) g.add_node( "node", node, @@ -54,4 +59,4 @@ async def test_interrupt(client: Client): task_queue=task_queue, ) - assert result == "yes" + assert result == {"value": "yes"} diff --git a/tests/contrib/langgraph/test_interrupt_v2.py b/tests/contrib/langgraph/test_interrupt_v2.py index 88d13d767..6052d4144 100644 --- a/tests/contrib/langgraph/test_interrupt_v2.py +++ b/tests/contrib/langgraph/test_interrupt_v2.py @@ -12,6 +12,7 @@ from langgraph.graph import START, StateGraph from langgraph.graph.state import RunnableConfig from langgraph.types import Command, interrupt +from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client @@ -19,8 +20,12 @@ from temporalio.worker import Worker -async def node(_: str) -> str: - return interrupt("Continue?") +class State(TypedDict): + value: str + + +async def node(state: State) -> dict[str, str]: + return {"value": interrupt("Continue?")} @workflow.defn @@ -30,7 +35,7 @@ async def run(self, input: str) -> Any: g = graph("interrupt-v2-graph").compile(checkpointer=InMemorySaver()) config = RunnableConfig({"configurable": {"thread_id": "1"}}) - result = await g.ainvoke(input, config, version="v2") + result = await g.ainvoke({"value": input}, config, version="v2") # v2: interrupts are on result.interrupts, not result["__interrupt__"] assert result.value == {} @@ -41,7 +46,7 @@ async def run(self, input: str) -> Any: async def test_interrupt_v2(client: Client): - g = StateGraph(str) + g = StateGraph(State) g.add_node( "node", node, @@ -64,4 +69,4 @@ async def test_interrupt_v2(client: Client): task_queue=task_queue, ) - assert result == "yes" + assert result == {"value": "yes"} diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 0db41d2fa..6b2f66bc8 100644 --- 
a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -3,6 +3,7 @@ from uuid import uuid4 from langgraph.graph import START, StateGraph +from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client @@ -10,12 +11,16 @@ from temporalio.worker import Worker -async def node_a(state: str) -> str: - return state + "a" +class State(TypedDict): + value: str -async def node_b(state: str) -> str: - return state + "b" +async def node_a(state: State) -> dict[str, str]: + return {"value": state["value"] + "a"} + + +async def node_b(state: State) -> dict[str, str]: + return {"value": state["value"] + "b"} @workflow.defn @@ -23,13 +28,13 @@ class StreamingWorkflow: @workflow.run async def run(self, input: str) -> Any: chunks = [] - async for chunk in graph("streaming").compile().astream(input): + async for chunk in graph("streaming").compile().astream({"value": input}): chunks.append(chunk) return chunks async def test_streaming(client: Client): - g = StateGraph(str) + g = StateGraph(State) g.add_node( "node_a", node_a, @@ -58,4 +63,4 @@ async def test_streaming(client: Client): task_queue=task_queue, ) - assert chunks == [{"node_a": "a"}, {"node_b": "ab"}] + assert chunks == [{"node_a": {"value": "a"}}, {"node_b": {"value": "ab"}}] diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py index 24e250d07..1ce00bac3 100644 --- a/tests/contrib/langgraph/test_subgraph_activity.py +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -3,6 +3,7 @@ from uuid import uuid4 from langgraph.graph import START, StateGraph +from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client @@ -10,12 +11,16 @@ from temporalio.worker import Worker -async def child_node(_: str) -> str: - return "child" +class State(TypedDict): + value: str -async def parent_node(state: str) -> str: - child = 
StateGraph(str) +async def child_node(state: State) -> dict[str, str]: + return {"value": "child"} + + +async def parent_node(state: State) -> dict[str, str]: + child = StateGraph(State) child.add_node("child_node", child_node) child.add_edge(START, "child_node") @@ -26,11 +31,11 @@ async def parent_node(state: str) -> str: class ActivitySubgraphWorkflow: @workflow.run async def run(self, input: str) -> Any: - return await graph("parent").compile().ainvoke(input) + return await graph("parent").compile().ainvoke({"value": input}) async def test_activity_subgraph(client: Client): - parent = StateGraph(str) + parent = StateGraph(State) parent.add_node( "parent_node", parent_node, @@ -53,4 +58,4 @@ async def test_activity_subgraph(client: Client): task_queue=task_queue, ) - assert result == "child" + assert result == {"value": "child"} diff --git a/tests/contrib/langgraph/test_subgraph_workflow.py b/tests/contrib/langgraph/test_subgraph_workflow.py index a8edfa541..8166fd922 100644 --- a/tests/contrib/langgraph/test_subgraph_workflow.py +++ b/tests/contrib/langgraph/test_subgraph_workflow.py @@ -3,6 +3,7 @@ from uuid import uuid4 from langgraph.graph import START, StateGraph +from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client @@ -10,11 +11,15 @@ from temporalio.worker import Worker -async def child_node(_: str) -> str: - return "child" +class State(TypedDict): + value: str -async def parent_node(state: str) -> str: +async def child_node(state: State) -> dict[str, str]: + return {"value": "child"} + + +async def parent_node(state: State) -> dict[str, str]: return await graph("child").compile().ainvoke(state) @@ -22,11 +27,11 @@ async def parent_node(state: str) -> str: class WorkflowSubgraphWorkflow: @workflow.run async def run(self, input: str) -> Any: - return await graph("parent").compile().ainvoke(input) + return await graph("parent").compile().ainvoke({"value": input}) async def 
test_workflow_subgraph(client: Client): - child = StateGraph(str) + child = StateGraph(State) child.add_node( "child_node", child_node, @@ -34,7 +39,7 @@ async def test_workflow_subgraph(client: Client): ) child.add_edge(START, "child_node") - parent = StateGraph(str) + parent = StateGraph(State) parent.add_node("parent_node", parent_node, metadata={"execute_in": "workflow"}) parent.add_edge(START, "parent_node") @@ -53,4 +58,4 @@ async def test_workflow_subgraph(client: Client): task_queue=task_queue, ) - assert result == "child" + assert result == {"value": "child"} diff --git a/tests/contrib/langgraph/test_timeout.py b/tests/contrib/langgraph/test_timeout.py index d19307820..141549d4d 100644 --- a/tests/contrib/langgraph/test_timeout.py +++ b/tests/contrib/langgraph/test_timeout.py @@ -5,6 +5,7 @@ from langgraph.graph import START, StateGraph from pytest import raises +from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client, WorkflowFailureError @@ -13,20 +14,24 @@ from temporalio.worker import Worker -async def node(_: str) -> str: +class State(TypedDict): + value: str + + +async def node(state: State) -> dict[str, str]: await sleep(1) # 1 second - return "done" + return {"value": "done"} @workflow.defn class TimeoutWorkflow: @workflow.run async def run(self, input: str) -> Any: - return await graph("my-graph").compile().ainvoke(input) + return await graph("my-graph").compile().ainvoke({"value": input}) async def test_timeout(client: Client): - g = StateGraph(str) + g = StateGraph(State) g.add_node( "node", node, diff --git a/tests/contrib/langgraph/test_two_nodes.py b/tests/contrib/langgraph/test_two_nodes.py index 47d7ff03d..d40833795 100644 --- a/tests/contrib/langgraph/test_two_nodes.py +++ b/tests/contrib/langgraph/test_two_nodes.py @@ -3,6 +3,7 @@ from uuid import uuid4 from langgraph.graph import START, StateGraph +from typing_extensions import TypedDict from temporalio import workflow from 
temporalio.client import Client @@ -10,23 +11,27 @@ from temporalio.worker import Worker -async def node_a(state: str) -> str: - return state + "a" +class State(TypedDict): + value: str -async def node_b(state: str) -> str: - return state + "b" +async def node_a(state: State) -> dict[str, str]: + return {"value": state["value"] + "a"} + + +async def node_b(state: State) -> dict[str, str]: + return {"value": state["value"] + "b"} @workflow.defn class TwoNodesWorkflow: @workflow.run async def run(self, input: str) -> Any: - return await graph("my-graph").compile().ainvoke(input) + return await graph("my-graph").compile().ainvoke({"value": input}) async def test_two_nodes(client: Client): - g = StateGraph(str) + g = StateGraph(State) g.add_node( "node_a", node_a, @@ -55,4 +60,4 @@ async def test_two_nodes(client: Client): task_queue=task_queue, ) - assert result == "ab" + assert result == {"value": "ab"} From cbd066a4ff97d40f5576da1f303906fb4dc06834 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 14 Apr 2026 16:30:33 -0700 Subject: [PATCH 05/47] fixed some mypy lints --- temporalio/contrib/langgraph/activity.py | 17 +++++++------ .../contrib/langgraph/langgraph_config.py | 2 ++ .../contrib/langgraph/langgraph_plugin.py | 24 +++++++++++-------- temporalio/contrib/langgraph/task_cache.py | 12 +++++----- tests/contrib/langgraph/conftest.py | 17 +++++++------ .../contrib/langgraph/test_e2e_functional.py | 3 ++- .../langgraph/test_execute_in_workflow.py | 2 +- tests/contrib/langgraph/test_interrupt.py | 2 +- tests/contrib/langgraph/test_interrupt_v2.py | 4 ++-- .../langgraph/test_subgraph_activity.py | 4 ++-- .../langgraph/test_subgraph_workflow.py | 2 +- tests/contrib/langgraph/test_timeout.py | 2 +- 12 files changed, 48 insertions(+), 43 deletions(-) diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py index 614940bc2..291441ed7 100644 --- a/temporalio/contrib/langgraph/activity.py +++ 
b/temporalio/contrib/langgraph/activity.py @@ -11,6 +11,11 @@ get_langgraph_config, set_langgraph_config, ) +from temporalio.contrib.langgraph.task_cache import ( + cache_key, + cache_lookup, + cache_put, +) @dataclass @@ -49,16 +54,10 @@ def wrap_execute_activity( **execute_activity_kwargs: Any, ) -> Callable[..., Any]: async def wrapper(*args: Any, **kwargs: Any) -> Any: - from temporalio.contrib.langgraph.task_cache import ( - _cache_key, - _cache_lookup, - _cache_put, - ) - # Check task result cache (for continue-as-new deduplication). - key = _cache_key(task_id, args, kwargs) if task_id else "" + key = cache_key(task_id, args, kwargs) if task_id else "" if task_id: - found, cached = _cache_lookup(key) + found, cached = cache_lookup(key) if found: return cached @@ -73,7 +72,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: # Store in cache for future continue-as-new cycles. if task_id: - _cache_put(key, output.result) + cache_put(key, output.result) return output.result diff --git a/temporalio/contrib/langgraph/langgraph_config.py b/temporalio/contrib/langgraph/langgraph_config.py index 07e98363f..e9ab1c5aa 100644 --- a/temporalio/contrib/langgraph/langgraph_config.py +++ b/temporalio/contrib/langgraph/langgraph_config.py @@ -1,3 +1,5 @@ +# pyright: reportMissingTypeStubs=false + from typing import Any from langchain_core.runnables.config import var_child_runnable_config diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 8610631b2..ac19b9820 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -1,3 +1,5 @@ +# pyright: reportMissingTypeStubs=false + from dataclasses import replace from typing import Any, Callable @@ -8,16 +10,16 @@ from temporalio import activity from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity from temporalio.contrib.langgraph.task_cache import ( - _get_task_cache, - 
_set_task_cache, - _task_id, + get_task_cache, + set_task_cache, + task_id, ) from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner # Save registered graphs/entrypoints at the module level to avoid being refreshed by the sandbox. -_graph_registry: dict[str, StateGraph] = {} +_graph_registry: dict[str, StateGraph[Any]] = {} _entrypoint_registry: dict[str, Pregel] = {} @@ -105,14 +107,16 @@ def execute(self, func: Callable, kwargs: dict[str, Any] | None = None) -> Calla if execute_in == "activity": a = activity.defn(name=func.__name__)(wrap_activity(func)) self.activities.append(a) - return wrap_execute_activity(a, task_id=_task_id(func), **(kwargs or {})) + return wrap_execute_activity(a, task_id=task_id(func), **(kwargs or {})) elif execute_in == "workflow": return func else: raise ValueError(f"Invalid execute_in value: {execute_in}") -def graph(name: str, cache: dict[str, Any] | None = None) -> StateGraph: +def graph( + name: str, cache: dict[str, Any] | None = None +) -> StateGraph[Any, None, Any, Any]: """Retrieve a registered graph by name. Args: @@ -122,7 +126,7 @@ def graph(name: str, cache: dict[str, Any] | None = None) -> StateGraph: not re-executed after continue-as-new. """ _patch_event_loop() - _set_task_cache(cache or {}) + set_task_cache(cache or {}) return _graph_registry[name] @@ -136,7 +140,7 @@ def entrypoint(name: str, cache: dict[str, Any] | None = None) -> Pregel: not re-executed after continue-as-new. """ _patch_event_loop() - _set_task_cache(cache or {}) + set_task_cache(cache or {}) return _entrypoint_registry[name] @@ -147,7 +151,7 @@ def cache() -> dict[str, Any] | None: restore cached task results across continue-as-new boundaries. Returns None if the cache is empty. 
""" - return _get_task_cache() or None + return get_task_cache() or None def _patch_event_loop(): @@ -155,4 +159,4 @@ def _patch_event_loop(): from asyncio import get_event_loop loop = get_event_loop() - loop.is_running = lambda: True + setattr(loop, "is_running", lambda: True) diff --git a/temporalio/contrib/langgraph/task_cache.py b/temporalio/contrib/langgraph/task_cache.py index 78a1b5b5c..2f5c574c4 100644 --- a/temporalio/contrib/langgraph/task_cache.py +++ b/temporalio/contrib/langgraph/task_cache.py @@ -17,15 +17,15 @@ ) -def _set_task_cache(cache: dict[str, Any] | None) -> None: +def set_task_cache(cache: dict[str, Any] | None) -> None: _task_cache.set(cache) -def _get_task_cache() -> dict[str, Any] | None: +def get_task_cache() -> dict[str, Any] | None: return _task_cache.get() -def _task_id(func: Any) -> str: +def task_id(func: Any) -> str: """Return the fully-qualified module.qualname for a function. Raises ValueError for functions that cannot be identified unambiguously @@ -52,7 +52,7 @@ def _task_id(func: Any) -> str: return f"{module}.{qualname}" -def _cache_key(task_id: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: +def cache_key(task_id: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: """Build a cache key from the full task identifier and arguments.""" try: key_str = dumps([task_id, args, kwargs], sort_keys=True, default=str) @@ -61,7 +61,7 @@ def _cache_key(task_id: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> s return sha256(key_str.encode()).hexdigest()[:32] -def _cache_lookup(key: str) -> tuple[bool, Any]: +def cache_lookup(key: str) -> tuple[bool, Any]: """Return (True, value) if cached, (False, None) otherwise.""" cache = _task_cache.get() if cache is not None and key in cache: @@ -69,7 +69,7 @@ def _cache_lookup(key: str) -> tuple[bool, Any]: return False, None -def _cache_put(key: str, value: Any) -> None: +def cache_put(key: str, value: Any) -> None: cache = _task_cache.get() if cache is not None: 
cache[key] = value diff --git a/tests/contrib/langgraph/conftest.py b/tests/contrib/langgraph/conftest.py index ef4469b0b..dba39e1b2 100644 --- a/tests/contrib/langgraph/conftest.py +++ b/tests/contrib/langgraph/conftest.py @@ -1,27 +1,26 @@ -from asyncio import get_event_loop_policy -from collections.abc import AsyncGenerator +import asyncio +from collections.abc import AsyncGenerator, Iterator -from pytest import fixture -from pytest_asyncio import fixture as async_fixture +import pytest from temporalio.client import Client from temporalio.testing import WorkflowEnvironment -@fixture(scope="session") -def event_loop(): - loop = get_event_loop_policy().new_event_loop() +@pytest.fixture(scope="session") +def event_loop() -> Iterator[asyncio.AbstractEventLoop]: + loop = asyncio.new_event_loop() yield loop loop.close() -@async_fixture(scope="session") +@pytest.fixture(scope="session") async def env() -> AsyncGenerator[WorkflowEnvironment, None]: env = await WorkflowEnvironment.start_local() yield env await env.shutdown() -@async_fixture +@pytest.fixture async def client(env: WorkflowEnvironment) -> Client: return env.client diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index 06d320f47..73128b69a 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ b/tests/contrib/langgraph/test_e2e_functional.py @@ -6,6 +6,7 @@ from __future__ import annotations from datetime import timedelta +from typing import Any from uuid import uuid4 from temporalio.client import Client @@ -37,7 +38,7 @@ ) -def _activity_opts(*task_funcs) -> dict[str, dict]: +def _activity_opts(*task_funcs: Any) -> dict[str, dict]: """Build activity_options dict giving every task the same 30s timeout.""" return { t.func.__name__: {"start_to_close_timeout": timedelta(seconds=30)} diff --git a/tests/contrib/langgraph/test_execute_in_workflow.py b/tests/contrib/langgraph/test_execute_in_workflow.py index 42ea31dce..0b5767680 100644 --- 
a/tests/contrib/langgraph/test_execute_in_workflow.py +++ b/tests/contrib/langgraph/test_execute_in_workflow.py @@ -14,7 +14,7 @@ class State(TypedDict): value: str -async def node(state: State) -> dict[str, str]: +async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] return {"value": "done"} diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index e82b958f3..90bd218bf 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -18,7 +18,7 @@ class State(TypedDict): value: str -async def node(state: State) -> dict[str, str]: +async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] return {"value": interrupt("Continue?")} diff --git a/tests/contrib/langgraph/test_interrupt_v2.py b/tests/contrib/langgraph/test_interrupt_v2.py index 6052d4144..260b1369c 100644 --- a/tests/contrib/langgraph/test_interrupt_v2.py +++ b/tests/contrib/langgraph/test_interrupt_v2.py @@ -24,7 +24,7 @@ class State(TypedDict): value: str -async def node(state: State) -> dict[str, str]: +async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] return {"value": interrupt("Continue?")} @@ -38,7 +38,7 @@ async def run(self, input: str) -> Any: result = await g.ainvoke({"value": input}, config, version="v2") # v2: interrupts are on result.interrupts, not result["__interrupt__"] - assert result.value == {} + assert result.value == {"value": ""} assert len(result.interrupts) == 1 assert result.interrupts[0].value == "Continue?" 
diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py index 1ce00bac3..af3eebf88 100644 --- a/tests/contrib/langgraph/test_subgraph_activity.py +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -15,12 +15,12 @@ class State(TypedDict): value: str -async def child_node(state: State) -> dict[str, str]: +async def child_node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] return {"value": "child"} async def parent_node(state: State) -> dict[str, str]: - child = StateGraph(State) + child: StateGraph[State, None, State, State] = StateGraph(State) child.add_node("child_node", child_node) child.add_edge(START, "child_node") diff --git a/tests/contrib/langgraph/test_subgraph_workflow.py b/tests/contrib/langgraph/test_subgraph_workflow.py index 8166fd922..67cddc40e 100644 --- a/tests/contrib/langgraph/test_subgraph_workflow.py +++ b/tests/contrib/langgraph/test_subgraph_workflow.py @@ -15,7 +15,7 @@ class State(TypedDict): value: str -async def child_node(state: State) -> dict[str, str]: +async def child_node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] return {"value": "child"} diff --git a/tests/contrib/langgraph/test_timeout.py b/tests/contrib/langgraph/test_timeout.py index 141549d4d..989d323ce 100644 --- a/tests/contrib/langgraph/test_timeout.py +++ b/tests/contrib/langgraph/test_timeout.py @@ -18,7 +18,7 @@ class State(TypedDict): value: str -async def node(state: State) -> dict[str, str]: +async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] await sleep(1) # 1 second return {"value": "done"} From 5d1f18278cec63f1409cccd0804c9b2ce20c4d7f Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 14 Apr 2026 16:48:55 -0700 Subject: [PATCH 06/47] fix docstring lints --- temporalio/contrib/langgraph/activity.py | 10 ++++++++++ temporalio/contrib/langgraph/langgraph_config.py | 4 ++++ 
temporalio/contrib/langgraph/langgraph_plugin.py | 5 ++++- temporalio/contrib/langgraph/task_cache.py | 3 +++ .../contrib/langgraph/e2e_functional_entrypoints.py | 6 +++--- tests/contrib/langgraph/test_continue_as_new.py | 6 ++++-- .../contrib/langgraph/test_continue_as_new_cached.py | 2 +- tests/contrib/langgraph/test_e2e_functional_v2.py | 6 ++++-- tests/contrib/langgraph/test_execute_in_workflow.py | 2 +- tests/contrib/langgraph/test_interrupt.py | 12 +++++++----- tests/contrib/langgraph/test_interrupt_v2.py | 12 +++++++----- tests/contrib/langgraph/test_streaming.py | 2 +- tests/contrib/langgraph/test_subgraph_activity.py | 2 +- tests/contrib/langgraph/test_subgraph_workflow.py | 2 +- tests/contrib/langgraph/test_timeout.py | 2 +- tests/contrib/langgraph/test_two_nodes.py | 2 +- 16 files changed, 53 insertions(+), 25 deletions(-) diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py index 291441ed7..04cdd8d4a 100644 --- a/temporalio/contrib/langgraph/activity.py +++ b/temporalio/contrib/langgraph/activity.py @@ -1,3 +1,5 @@ +"""Activity wrappers for executing LangGraph nodes and tasks.""" + from collections.abc import Awaitable from dataclasses import dataclass from inspect import iscoroutinefunction @@ -20,6 +22,8 @@ @dataclass class ActivityInput: + """Input for a LangGraph activity, containing args, kwargs, and config.""" + args: tuple[Any, ...] 
kwargs: dict[str, Any] langgraph_config: dict[str, Any] @@ -27,6 +31,8 @@ class ActivityInput: @dataclass class ActivityOutput: + """Output from a LangGraph activity, containing result or interrupts.""" + result: Any = None langgraph_interrupts: tuple[Interrupt] | None = None @@ -34,6 +40,8 @@ class ActivityOutput: def wrap_activity( func: Callable, ) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]: + """Wrap a function as a Temporal activity that handles LangGraph config and interrupts.""" + async def wrapper(input: ActivityInput) -> ActivityOutput: set_langgraph_config(input.langgraph_config) try: @@ -53,6 +61,8 @@ def wrap_execute_activity( task_id: str = "", **execute_activity_kwargs: Any, ) -> Callable[..., Any]: + """Wrap an activity function to be called via workflow.execute_activity with caching.""" + async def wrapper(*args: Any, **kwargs: Any) -> Any: # Check task result cache (for continue-as-new deduplication). key = cache_key(task_id, args, kwargs) if task_id else "" diff --git a/temporalio/contrib/langgraph/langgraph_config.py b/temporalio/contrib/langgraph/langgraph_config.py index e9ab1c5aa..9663679af 100644 --- a/temporalio/contrib/langgraph/langgraph_config.py +++ b/temporalio/contrib/langgraph/langgraph_config.py @@ -1,3 +1,5 @@ +"""LangGraph configuration management for Temporal workflows.""" + # pyright: reportMissingTypeStubs=false from typing import Any @@ -13,6 +15,7 @@ def get_langgraph_config() -> dict[str, Any]: + """Get the current LangGraph runnable config as a serializable dict.""" config = var_child_runnable_config.get() or {} configurable = config.get("configurable") or {} scratchpad = configurable.get(CONFIG_KEY_SCRATCHPAD) @@ -31,6 +34,7 @@ def get_langgraph_config() -> dict[str, Any]: def set_langgraph_config(config: dict[str, Any]) -> None: + """Restore a LangGraph runnable config from a serialized dict.""" configurable = config.get("configurable") or {} scratchpad = configurable.get(CONFIG_KEY_SCRATCHPAD) or {} 
null_resume_box = [scratchpad.get("null_resume")] diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index ac19b9820..3b7677917 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -1,3 +1,5 @@ +"""LangGraph plugin for running LangGraph nodes and tasks as Temporal activities.""" + # pyright: reportMissingTypeStubs=false from dataclasses import replace @@ -47,6 +49,7 @@ def __init__( activity_options: dict[str, dict] | None = None, # TODO: Add default_activity_options that apply to all nodes or tasks ): + """Initialize the LangGraph plugin with graphs, entrypoints, and tasks.""" self.activities: list = [] # Graph API: Wrap graph nodes as Activities. @@ -100,8 +103,8 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: workflow_runner=workflow_runner, ) - # Prepare a [node, @task] to execute as a [Activity, Workflow]. def execute(self, func: Callable, kwargs: dict[str, Any] | None = None) -> Callable: + """Prepare a node or task to execute as an activity or inline in the workflow.""" execute_in = (kwargs or {}).pop("execute_in", "activity") if execute_in == "activity": diff --git a/temporalio/contrib/langgraph/task_cache.py b/temporalio/contrib/langgraph/task_cache.py index 2f5c574c4..d4053c808 100644 --- a/temporalio/contrib/langgraph/task_cache.py +++ b/temporalio/contrib/langgraph/task_cache.py @@ -18,10 +18,12 @@ def set_task_cache(cache: dict[str, Any] | None) -> None: + """Set the task result cache for the current context.""" _task_cache.set(cache) def get_task_cache() -> dict[str, Any] | None: + """Get the task result cache for the current context.""" return _task_cache.get() @@ -70,6 +72,7 @@ def cache_lookup(key: str) -> tuple[bool, Any]: def cache_put(key: str, value: Any) -> None: + """Store a value in the task result cache.""" cache = _task_cache.get() if cache is not None: cache[key] = value diff --git 
a/tests/contrib/langgraph/e2e_functional_entrypoints.py b/tests/contrib/langgraph/e2e_functional_entrypoints.py index 5498bbc9c..01f871b8f 100644 --- a/tests/contrib/langgraph/e2e_functional_entrypoints.py +++ b/tests/contrib/langgraph/e2e_functional_entrypoints.py @@ -5,8 +5,8 @@ from __future__ import annotations -from langgraph.func import entrypoint, task -from langgraph.types import interrupt +import langgraph.types +from langgraph.func import entrypoint, task # pyright: ignore[reportMissingTypeStubs] @task @@ -120,7 +120,7 @@ async def partial_execution_entrypoint(input_data: dict) -> dict: @task def ask_human(question: str) -> str: - return interrupt(question) + return langgraph.types.interrupt(question) @entrypoint() diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index 8fab4cc91..55596f92d 100644 --- a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -3,8 +3,10 @@ from uuid import uuid4 from langgraph.checkpoint.memory import InMemorySaver -from langgraph.graph import START, StateGraph -from langgraph.graph.state import RunnableConfig +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from langgraph.graph.state import ( # pyright: ignore[reportMissingTypeStubs] + RunnableConfig, +) from typing_extensions import TypedDict from temporalio import workflow diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index 40bdee7a9..b30f0c7dc 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -9,7 +9,7 @@ from typing import Any from uuid import uuid4 -from langgraph.graph import START, StateGraph +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict from temporalio import workflow diff 
--git a/tests/contrib/langgraph/test_e2e_functional_v2.py b/tests/contrib/langgraph/test_e2e_functional_v2.py index 579a5f59c..13679676e 100644 --- a/tests/contrib/langgraph/test_e2e_functional_v2.py +++ b/tests/contrib/langgraph/test_e2e_functional_v2.py @@ -13,8 +13,10 @@ from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.memory import InMemorySaver -from langgraph.func import entrypoint as lg_entrypoint -from langgraph.func import task +from langgraph.func import ( # pyright: ignore[reportMissingTypeStubs] + entrypoint as lg_entrypoint, +) +from langgraph.func import task # pyright: ignore[reportMissingTypeStubs] from langgraph.types import Command from temporalio import workflow diff --git a/tests/contrib/langgraph/test_execute_in_workflow.py b/tests/contrib/langgraph/test_execute_in_workflow.py index 0b5767680..58e3fdba0 100644 --- a/tests/contrib/langgraph/test_execute_in_workflow.py +++ b/tests/contrib/langgraph/test_execute_in_workflow.py @@ -1,7 +1,7 @@ from typing import Any from uuid import uuid4 -from langgraph.graph import START, StateGraph +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict from temporalio import workflow diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index 90bd218bf..d91e6e495 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -2,10 +2,12 @@ from typing import Any from uuid import uuid4 +import langgraph.types from langgraph.checkpoint.memory import InMemorySaver -from langgraph.graph import START, StateGraph -from langgraph.graph.state import RunnableConfig -from langgraph.types import Command, interrupt +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from langgraph.graph.state import ( # pyright: ignore[reportMissingTypeStubs] + RunnableConfig, +) from typing_extensions import 
TypedDict from temporalio import workflow @@ -19,7 +21,7 @@ class State(TypedDict): async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] - return {"value": interrupt("Continue?")} + return {"value": langgraph.types.interrupt("Continue?")} @workflow.defn @@ -32,7 +34,7 @@ async def run(self, input: str) -> Any: result = await g.ainvoke({"value": input}, config) assert result["__interrupt__"][0].value == "Continue?" - return await g.ainvoke(Command(resume="yes"), config) + return await g.ainvoke(langgraph.types.Command(resume="yes"), config) async def test_interrupt(client: Client): diff --git a/tests/contrib/langgraph/test_interrupt_v2.py b/tests/contrib/langgraph/test_interrupt_v2.py index 260b1369c..348466f62 100644 --- a/tests/contrib/langgraph/test_interrupt_v2.py +++ b/tests/contrib/langgraph/test_interrupt_v2.py @@ -8,10 +8,12 @@ from typing import Any from uuid import uuid4 +import langgraph.types from langgraph.checkpoint.memory import InMemorySaver -from langgraph.graph import START, StateGraph -from langgraph.graph.state import RunnableConfig -from langgraph.types import Command, interrupt +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from langgraph.graph.state import ( # pyright: ignore[reportMissingTypeStubs] + RunnableConfig, +) from typing_extensions import TypedDict from temporalio import workflow @@ -25,7 +27,7 @@ class State(TypedDict): async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] - return {"value": interrupt("Continue?")} + return {"value": langgraph.types.interrupt("Continue?")} @workflow.defn @@ -42,7 +44,7 @@ async def run(self, input: str) -> Any: assert len(result.interrupts) == 1 assert result.interrupts[0].value == "Continue?" 
- return await g.ainvoke(Command(resume="yes"), config) + return await g.ainvoke(langgraph.types.Command(resume="yes"), config) async def test_interrupt_v2(client: Client): diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 6b2f66bc8..c14cecec9 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -2,7 +2,7 @@ from typing import Any from uuid import uuid4 -from langgraph.graph import START, StateGraph +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict from temporalio import workflow diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py index af3eebf88..fd802ff9a 100644 --- a/tests/contrib/langgraph/test_subgraph_activity.py +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -2,7 +2,7 @@ from typing import Any from uuid import uuid4 -from langgraph.graph import START, StateGraph +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict from temporalio import workflow diff --git a/tests/contrib/langgraph/test_subgraph_workflow.py b/tests/contrib/langgraph/test_subgraph_workflow.py index 67cddc40e..c055a2a7a 100644 --- a/tests/contrib/langgraph/test_subgraph_workflow.py +++ b/tests/contrib/langgraph/test_subgraph_workflow.py @@ -2,7 +2,7 @@ from typing import Any from uuid import uuid4 -from langgraph.graph import START, StateGraph +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict from temporalio import workflow diff --git a/tests/contrib/langgraph/test_timeout.py b/tests/contrib/langgraph/test_timeout.py index 989d323ce..2521e3a00 100644 --- a/tests/contrib/langgraph/test_timeout.py +++ b/tests/contrib/langgraph/test_timeout.py @@ -3,7 +3,7 @@ from typing import Any from 
uuid import uuid4 -from langgraph.graph import START, StateGraph +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from pytest import raises from typing_extensions import TypedDict diff --git a/tests/contrib/langgraph/test_two_nodes.py b/tests/contrib/langgraph/test_two_nodes.py index d40833795..7a8affcac 100644 --- a/tests/contrib/langgraph/test_two_nodes.py +++ b/tests/contrib/langgraph/test_two_nodes.py @@ -2,7 +2,7 @@ from typing import Any from uuid import uuid4 -from langgraph.graph import START, StateGraph +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict from temporalio import workflow From ea24a137228c09e7b0d1f423d38566d4a2a65a54 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 14 Apr 2026 16:57:42 -0700 Subject: [PATCH 07/47] copilot code review --- temporalio/contrib/langgraph/README.md | 2 -- temporalio/contrib/langgraph/langgraph_plugin.py | 5 +++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index f437fe476..e8a4da74b 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -16,8 +16,6 @@ or with pip: pip install temporalio[langgraph] ``` -Requires `langgraph==1.1.3` and `temporalio>=1.24.0`. 
- ## Plugin Initialization ### Graph API diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 3b7677917..c213d2182 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -105,12 +105,13 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: def execute(self, func: Callable, kwargs: dict[str, Any] | None = None) -> Callable: """Prepare a node or task to execute as an activity or inline in the workflow.""" - execute_in = (kwargs or {}).pop("execute_in", "activity") + opts = kwargs or {} + execute_in = opts.pop("execute_in", "activity") if execute_in == "activity": a = activity.defn(name=func.__name__)(wrap_activity(func)) self.activities.append(a) - return wrap_execute_activity(a, task_id=task_id(func), **(kwargs or {})) + return wrap_execute_activity(a, task_id=task_id(func), **opts) elif execute_in == "workflow": return func else: From 19a5052fe43e1741c79db3ccc7f028ddf2fc995f Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 15 Apr 2026 10:08:11 -0700 Subject: [PATCH 08/47] fix mypy lint --- temporalio/contrib/langgraph/langgraph_plugin.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index c213d2182..edd9a91d9 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -22,7 +22,7 @@ # Save registered graphs/entrypoints at the module level to avoid being refreshed by the sandbox. 
_graph_registry: dict[str, StateGraph[Any]] = {} -_entrypoint_registry: dict[str, Pregel] = {} +_entrypoint_registry: dict[str, Pregel[Any, Any, Any, Any]] = {} class LangGraphPlugin(SimplePlugin): @@ -43,7 +43,7 @@ def __init__( # Graph API graphs: dict[str, StateGraph] | None = None, # Functional API - entrypoints: dict[str, Pregel] | None = None, + entrypoints: dict[str, Pregel[Any, Any, Any, Any]] | None = None, tasks: list | None = None, # TODO: Remove activity_options when we have support for @task(metadata=...) activity_options: dict[str, dict] | None = None, @@ -134,7 +134,9 @@ def graph( return _graph_registry[name] -def entrypoint(name: str, cache: dict[str, Any] | None = None) -> Pregel: +def entrypoint( + name: str, cache: dict[str, Any] | None = None +) -> Pregel[Any, Any, Any, Any]: """Retrieve a registered entrypoint by name. Args: From 0201c897cb1d39571262a2c5d493ad4220ce4f85 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 15 Apr 2026 13:03:19 -0700 Subject: [PATCH 09/47] separate graphs and entrypoints by task queue to avoid concurrent write bug --- .../contrib/langgraph/langgraph_plugin.py | 51 ++++++++++++++----- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index edd9a91d9..bcb355230 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -9,7 +9,7 @@ from langgraph.graph import StateGraph from langgraph.pregel import Pregel -from temporalio import activity +from temporalio import activity, workflow from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity from temporalio.contrib.langgraph.task_cache import ( get_task_cache, @@ -17,12 +17,13 @@ task_id, ) from temporalio.plugin import SimplePlugin -from temporalio.worker import WorkflowRunner +from temporalio.worker import WorkerConfig, WorkflowRunner from 
temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner # Save registered graphs/entrypoints at the module level to avoid being refreshed by the sandbox. -_graph_registry: dict[str, StateGraph[Any]] = {} -_entrypoint_registry: dict[str, Pregel[Any, Any, Any, Any]] = {} +# Keyed by task queue to isolate concurrent Workers/Plugins in the same process. +_graph_registry: dict[str, dict[str, StateGraph[Any]]] = {} +_entrypoint_registry: dict[str, dict[str, Pregel[Any, Any, Any, Any]]] = {} class LangGraphPlugin(SimplePlugin): @@ -51,10 +52,11 @@ def __init__( ): """Initialize the LangGraph plugin with graphs, entrypoints, and tasks.""" self.activities: list = [] + self._graphs: dict[str, StateGraph[Any]] = graphs or {} + self._entrypoints: dict[str, Pregel[Any, Any, Any, Any]] = entrypoints or {} - # Graph API: Wrap graph nodes as Activities. + # Graph API: Wrap graph nodes as Temporal Activities. if graphs: - _graph_registry.update(graphs) for graph in graphs.values(): for name, node in graph.nodes.items(): runnable = node.runnable @@ -67,11 +69,7 @@ def __init__( runnable.func_accepts = {} runnable.afunc = self.execute(runnable.afunc, node.metadata) - # Functional API: Register @entrypoint functions - if entrypoints: - _entrypoint_registry.update(entrypoints) - - # Functional API: Wrap @task functions as Activities. + # Functional API: Wrap @task functions as Temporal Activities. 
if tasks: for task in tasks: name = task.func.__name__ @@ -103,6 +101,19 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: workflow_runner=workflow_runner, ) + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + """Register graphs/entrypoints scoped to the worker's task queue.""" + task_queue = config.get("task_queue") + if not task_queue: + raise ValueError( + "Worker config must include a task_queue for LangGraphPlugin" + ) + if self._graphs: + _graph_registry.setdefault(task_queue, {}).update(self._graphs) + if self._entrypoints: + _entrypoint_registry.setdefault(task_queue, {}).update(self._entrypoints) + return super().configure_worker(config) + def execute(self, func: Callable, kwargs: dict[str, Any] | None = None) -> Callable: """Prepare a node or task to execute as an activity or inline in the workflow.""" opts = kwargs or {} @@ -131,7 +142,14 @@ def graph( """ _patch_event_loop() set_task_cache(cache or {}) - return _graph_registry[name] + task_queue = workflow.info().task_queue + registry = _graph_registry.get(task_queue, {}) + if name not in registry: + raise KeyError( + f"Graph {name!r} not found for task queue {task_queue!r}. " + f"Available graphs: {list(registry.keys())}" + ) + return registry[name] def entrypoint( @@ -147,7 +165,14 @@ def entrypoint( """ _patch_event_loop() set_task_cache(cache or {}) - return _entrypoint_registry[name] + task_queue = workflow.info().task_queue + registry = _entrypoint_registry.get(task_queue, {}) + if name not in registry: + raise KeyError( + f"Entrypoint {name!r} not found for task queue {task_queue!r}. 
" + f"Available entrypoints: {list(registry.keys())}" + ) + return registry[name] def cache() -> dict[str, Any] | None: From e98dd9678a2141a77bc2525a7884675fb77ad383 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 15 Apr 2026 13:33:48 -0700 Subject: [PATCH 10/47] use graph.node or task_id for activity names to avoid collisions --- .../contrib/langgraph/langgraph_plugin.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index bcb355230..0546cdf22 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -57,17 +57,21 @@ def __init__( # Graph API: Wrap graph nodes as Temporal Activities. if graphs: - for graph in graphs.values(): - for name, node in graph.nodes.items(): + for graph_name, graph in graphs.items(): + for node_name, node in graph.nodes.items(): runnable = node.runnable if ( not isinstance(runnable, RunnableCallable) or runnable.afunc is None ): - raise ValueError(f"Node {name} must have an async function") + raise ValueError( + f"Node {node_name} must have an async function" + ) # Remove LangSmith-related callback functions that can't be serialized between the workflow and activity. runnable.func_accepts = {} - runnable.afunc = self.execute(runnable.afunc, node.metadata) + runnable.afunc = self.execute( + f"{graph_name}.{node_name}", runnable.afunc, node.metadata + ) # Functional API: Wrap @task functions as Temporal Activities. 
if tasks: @@ -75,7 +79,9 @@ def __init__( name = task.func.__name__ opts = (activity_options or {}).get(name, {}) - task.func = self.execute(task.func, opts) + task.func = self.execute( + task_id(task.func), task.func, opts + ) task.func.__name__ = name task.func.__qualname__ = getattr(task.func, "__qualname__", name) @@ -114,13 +120,18 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: _entrypoint_registry.setdefault(task_queue, {}).update(self._entrypoints) return super().configure_worker(config) - def execute(self, func: Callable, kwargs: dict[str, Any] | None = None) -> Callable: + def execute( + self, + activity_name: str, + func: Callable, + kwargs: dict[str, Any] | None = None, + ) -> Callable: """Prepare a node or task to execute as an activity or inline in the workflow.""" opts = kwargs or {} execute_in = opts.pop("execute_in", "activity") if execute_in == "activity": - a = activity.defn(name=func.__name__)(wrap_activity(func)) + a = activity.defn(name=activity_name)(wrap_activity(func)) self.activities.append(a) return wrap_execute_activity(a, task_id=task_id(func), **opts) elif execute_in == "workflow": From e1a93fcceccd871181a03ae28c9e72067c5ce498 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 15 Apr 2026 14:29:39 -0700 Subject: [PATCH 11/47] rm local conftest in favor of global and fix lint --- .../contrib/langgraph/langgraph_plugin.py | 4 +-- tests/contrib/langgraph/conftest.py | 26 ------------------- 2 files changed, 1 insertion(+), 29 deletions(-) delete mode 100644 tests/contrib/langgraph/conftest.py diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 0546cdf22..c2ece6603 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -79,9 +79,7 @@ def __init__( name = task.func.__name__ opts = (activity_options or {}).get(name, {}) - task.func = self.execute( - task_id(task.func), task.func, opts - 
) + task.func = self.execute(task_id(task.func), task.func, opts) task.func.__name__ = name task.func.__qualname__ = getattr(task.func, "__qualname__", name) diff --git a/tests/contrib/langgraph/conftest.py b/tests/contrib/langgraph/conftest.py deleted file mode 100644 index dba39e1b2..000000000 --- a/tests/contrib/langgraph/conftest.py +++ /dev/null @@ -1,26 +0,0 @@ -import asyncio -from collections.abc import AsyncGenerator, Iterator - -import pytest - -from temporalio.client import Client -from temporalio.testing import WorkflowEnvironment - - -@pytest.fixture(scope="session") -def event_loop() -> Iterator[asyncio.AbstractEventLoop]: - loop = asyncio.new_event_loop() - yield loop - loop.close() - - -@pytest.fixture(scope="session") -async def env() -> AsyncGenerator[WorkflowEnvironment, None]: - env = await WorkflowEnvironment.start_local() - yield env - await env.shutdown() - - -@pytest.fixture -async def client(env: WorkflowEnvironment) -> Client: - return env.client From 7d925b00a350e7777c165f9063f6143d6257dda6 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 16 Apr 2026 13:41:47 -0700 Subject: [PATCH 12/47] allow langgraph 1.1 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f741a8d5a..cf20412ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"] pydantic = ["pydantic>=2.0.0,<3"] openai-agents = ["openai-agents>=0.14.0", "mcp>=1.9.4, <2"] google-adk = ["google-adk>=1.27.0,<2"] -langgraph = ["langgraph>=1.1.6"] +langgraph = ["langgraph>=1.1.0"] langsmith = ["langsmith>=0.7.0,<0.8"] lambda-worker-otel = [ "opentelemetry-api>=1.11.1,<2", @@ -80,7 +80,7 @@ dev = [ "pytest-rerunfailures>=16.1", "pytest-xdist>=3.6,<4", "moto[s3,server]>=5", - "langgraph>=1.1.6", + "langgraph>=1.1.0", "langsmith>=0.7.0,<0.8", "setuptools<82", "opentelemetry-exporter-otlp-proto-grpc>=1.11.1,<2", From 
05bd7aababfbe9300de5c6695379aa597f6e32c6 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 16 Apr 2026 14:10:02 -0700 Subject: [PATCH 13/47] uv lock --- uv.lock | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/uv.lock b/uv.lock index f12e6bfec..a9ad989ae 100644 --- a/uv.lock +++ b/uv.lock @@ -1812,7 +1812,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/3f/9859f655d11901e7b2996c6e3d33e0caa9a1d4572c3bc61ed0faa64b2f4c/greenlet-3.3.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9bc885b89709d901859cf95179ec9f6bb67a3d2bb1f0e88456461bd4b7f8fd0d", size = 277747, upload-time = "2026-02-20T20:16:21.325Z" }, { url = "https://files.pythonhosted.org/packages/fb/07/cb284a8b5c6498dbd7cba35d31380bb123d7dceaa7907f606c8ff5993cbf/greenlet-3.3.2-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b568183cf65b94919be4438dc28416b234b678c608cafac8874dfeeb2a9bbe13", size = 579202, upload-time = "2026-02-20T20:47:28.955Z" }, { url = "https://files.pythonhosted.org/packages/ed/45/67922992b3a152f726163b19f890a85129a992f39607a2a53155de3448b8/greenlet-3.3.2-cp310-cp310-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:527fec58dc9f90efd594b9b700662ed3fb2493c2122067ac9c740d98080a620e", size = 590620, upload-time = "2026-02-20T20:55:55.581Z" }, - { url = "https://files.pythonhosted.org/packages/03/5f/6e2a7d80c353587751ef3d44bb947f0565ec008a2e0927821c007e96d3a7/greenlet-3.3.2-cp310-cp310-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:508c7f01f1791fbc8e011bd508f6794cb95397fdb198a46cb6635eb5b78d85a7", size = 602132, upload-time = "2026-02-20T21:02:43.261Z" }, { url = "https://files.pythonhosted.org/packages/ad/55/9f1ebb5a825215fadcc0f7d5073f6e79e3007e3282b14b22d6aba7ca6cb8/greenlet-3.3.2-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad0c8917dd42a819fe77e6bdfcb84e3379c0de956469301d9fd36427a1ca501f", size = 591729, upload-time = 
"2026-02-20T20:20:58.395Z" }, { url = "https://files.pythonhosted.org/packages/24/b4/21f5455773d37f94b866eb3cf5caed88d6cea6dd2c6e1f9c34f463cba3ec/greenlet-3.3.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:97245cc10e5515dbc8c3104b2928f7f02b6813002770cfaffaf9a6e0fc2b94ef", size = 1551946, upload-time = "2026-02-20T20:49:31.102Z" }, { url = "https://files.pythonhosted.org/packages/00/68/91f061a926abead128fe1a87f0b453ccf07368666bd59ffa46016627a930/greenlet-3.3.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8c1fdd7d1b309ff0da81d60a9688a8bd044ac4e18b250320a96fc68d31c209ca", size = 1618494, upload-time = "2026-02-20T20:21:06.541Z" }, @@ -1820,7 +1819,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f3/47/16400cb42d18d7a6bb46f0626852c1718612e35dcb0dffa16bbaffdf5dd2/greenlet-3.3.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:c56692189a7d1c7606cb794be0a8381470d95c57ce5be03fb3d0ef57c7853b86", size = 278890, upload-time = "2026-02-20T20:19:39.263Z" }, { url = "https://files.pythonhosted.org/packages/a3/90/42762b77a5b6aa96cd8c0e80612663d39211e8ae8a6cd47c7f1249a66262/greenlet-3.3.2-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1ebd458fa8285960f382841da585e02201b53a5ec2bac6b156fc623b5ce4499f", size = 581120, upload-time = "2026-02-20T20:47:30.161Z" }, { url = "https://files.pythonhosted.org/packages/bf/6f/f3d64f4fa0a9c7b5c5b3c810ff1df614540d5aa7d519261b53fba55d4df9/greenlet-3.3.2-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a443358b33c4ec7b05b79a7c8b466f5d275025e750298be7340f8fc63dff2a55", size = 594363, upload-time = "2026-02-20T20:55:56.965Z" }, - { url = "https://files.pythonhosted.org/packages/9c/8b/1430a04657735a3f23116c2e0d5eb10220928846e4537a938a41b350bed6/greenlet-3.3.2-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4375a58e49522698d3e70cc0b801c19433021b5c37686f7ce9c65b0d5c8677d2", size = 605046, upload-time = 
"2026-02-20T21:02:45.234Z" }, { url = "https://files.pythonhosted.org/packages/72/83/3e06a52aca8128bdd4dcd67e932b809e76a96ab8c232a8b025b2850264c5/greenlet-3.3.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8e2cd90d413acbf5e77ae41e5d3c9b3ac1d011a756d7284d7f3f2b806bbd6358", size = 594156, upload-time = "2026-02-20T20:20:59.955Z" }, { url = "https://files.pythonhosted.org/packages/70/79/0de5e62b873e08fe3cef7dbe84e5c4bc0e8ed0c7ff131bccb8405cd107c8/greenlet-3.3.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:442b6057453c8cb29b4fb36a2ac689382fc71112273726e2423f7f17dc73bf99", size = 1554649, upload-time = "2026-02-20T20:49:32.293Z" }, { url = "https://files.pythonhosted.org/packages/5a/00/32d30dee8389dc36d42170a9c66217757289e2afb0de59a3565260f38373/greenlet-3.3.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:45abe8eb6339518180d5a7fa47fa01945414d7cca5ecb745346fc6a87d2750be", size = 1619472, upload-time = "2026-02-20T20:21:07.966Z" }, @@ -1829,7 +1827,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ea/ab/1608e5a7578e62113506740b88066bf09888322a311cff602105e619bd87/greenlet-3.3.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:ac8d61d4343b799d1e526db579833d72f23759c71e07181c2d2944e429eb09cd", size = 280358, upload-time = "2026-02-20T20:17:43.971Z" }, { url = "https://files.pythonhosted.org/packages/a5/23/0eae412a4ade4e6623ff7626e38998cb9b11e9ff1ebacaa021e4e108ec15/greenlet-3.3.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ceec72030dae6ac0c8ed7591b96b70410a8be370b6a477b1dbc072856ad02bd", size = 601217, upload-time = "2026-02-20T20:47:31.462Z" }, { url = "https://files.pythonhosted.org/packages/f8/16/5b1678a9c07098ecb9ab2dd159fafaf12e963293e61ee8d10ecb55273e5e/greenlet-3.3.2-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a2a5be83a45ce6188c045bcc44b0ee037d6a518978de9a5d97438548b953a1ac", size = 611792, upload-time = 
"2026-02-20T20:55:58.423Z" }, - { url = "https://files.pythonhosted.org/packages/5c/c5/cc09412a29e43406eba18d61c70baa936e299bc27e074e2be3806ed29098/greenlet-3.3.2-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ae9e21c84035c490506c17002f5c8ab25f980205c3e61ddb3a2a2a2e6c411fcb", size = 626250, upload-time = "2026-02-20T21:02:46.596Z" }, { url = "https://files.pythonhosted.org/packages/50/1f/5155f55bd71cabd03765a4aac9ac446be129895271f73872c36ebd4b04b6/greenlet-3.3.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43e99d1749147ac21dde49b99c9abffcbc1e2d55c67501465ef0930d6e78e070", size = 613875, upload-time = "2026-02-20T20:21:01.102Z" }, { url = "https://files.pythonhosted.org/packages/fc/dd/845f249c3fcd69e32df80cdab059b4be8b766ef5830a3d0aa9d6cad55beb/greenlet-3.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4c956a19350e2c37f2c48b336a3afb4bff120b36076d9d7fb68cb44e05d95b79", size = 1571467, upload-time = "2026-02-20T20:49:33.495Z" }, { url = "https://files.pythonhosted.org/packages/2a/50/2649fe21fcc2b56659a452868e695634722a6655ba245d9f77f5656010bf/greenlet-3.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6c6f8ba97d17a1e7d664151284cb3315fc5f8353e75221ed4324f84eb162b395", size = 1640001, upload-time = "2026-02-20T20:21:09.154Z" }, @@ -1838,7 +1835,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ac/48/f8b875fa7dea7dd9b33245e37f065af59df6a25af2f9561efa8d822fde51/greenlet-3.3.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:aa6ac98bdfd716a749b84d4034486863fd81c3abde9aa3cf8eff9127981a4ae4", size = 279120, upload-time = "2026-02-20T20:19:01.9Z" }, { url = "https://files.pythonhosted.org/packages/49/8d/9771d03e7a8b1ee456511961e1b97a6d77ae1dea4a34a5b98eee706689d3/greenlet-3.3.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab0c7e7901a00bc0a7284907273dc165b32e0d109a6713babd04471327ff7986", size = 603238, upload-time = "2026-02-20T20:47:32.873Z" 
}, { url = "https://files.pythonhosted.org/packages/59/0e/4223c2bbb63cd5c97f28ffb2a8aee71bdfb30b323c35d409450f51b91e3e/greenlet-3.3.2-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d248d8c23c67d2291ffd47af766e2a3aa9fa1c6703155c099feb11f526c63a92", size = 614219, upload-time = "2026-02-20T20:55:59.817Z" }, - { url = "https://files.pythonhosted.org/packages/94/2b/4d012a69759ac9d77210b8bfb128bc621125f5b20fc398bce3940d036b1c/greenlet-3.3.2-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ccd21bb86944ca9be6d967cf7691e658e43417782bce90b5d2faeda0ff78a7dd", size = 628268, upload-time = "2026-02-20T21:02:48.024Z" }, { url = "https://files.pythonhosted.org/packages/7a/34/259b28ea7a2a0c904b11cd36c79b8cef8019b26ee5dbe24e73b469dea347/greenlet-3.3.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b6997d360a4e6a4e936c0f9625b1c20416b8a0ea18a8e19cabbefc712e7397ab", size = 616774, upload-time = "2026-02-20T20:21:02.454Z" }, { url = "https://files.pythonhosted.org/packages/0a/03/996c2d1689d486a6e199cb0f1cf9e4aa940c500e01bdf201299d7d61fa69/greenlet-3.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:64970c33a50551c7c50491671265d8954046cb6e8e2999aacdd60e439b70418a", size = 1571277, upload-time = "2026-02-20T20:49:34.795Z" }, { url = "https://files.pythonhosted.org/packages/d9/c4/2570fc07f34a39f2caf0bf9f24b0a1a0a47bc2e8e465b2c2424821389dfc/greenlet-3.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1a9172f5bf6bd88e6ba5a84e0a68afeac9dc7b6b412b245dd64f52d83c81e55b", size = 1640455, upload-time = "2026-02-20T20:21:10.261Z" }, @@ -1847,7 +1843,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/ae/8bffcbd373b57a5992cd077cbe8858fff39110480a9d50697091faea6f39/greenlet-3.3.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:8d1658d7291f9859beed69a776c10822a0a799bc4bfe1bd4272bb60e62507dab", size = 279650, upload-time = "2026-02-20T20:18:00.783Z" }, { url = 
"https://files.pythonhosted.org/packages/d1/c0/45f93f348fa49abf32ac8439938726c480bd96b2a3c6f4d949ec0124b69f/greenlet-3.3.2-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:18cb1b7337bca281915b3c5d5ae19f4e76d35e1df80f4ad3c1a7be91fadf1082", size = 650295, upload-time = "2026-02-20T20:47:34.036Z" }, { url = "https://files.pythonhosted.org/packages/b3/de/dd7589b3f2b8372069ab3e4763ea5329940fc7ad9dcd3e272a37516d7c9b/greenlet-3.3.2-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2e47408e8ce1c6f1ceea0dffcdf6ebb85cc09e55c7af407c99f1112016e45e9", size = 662163, upload-time = "2026-02-20T20:56:01.295Z" }, - { url = "https://files.pythonhosted.org/packages/cd/ac/85804f74f1ccea31ba518dcc8ee6f14c79f73fe36fa1beba38930806df09/greenlet-3.3.2-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e3cb43ce200f59483eb82949bf1835a99cf43d7571e900d7c8d5c62cdf25d2f9", size = 675371, upload-time = "2026-02-20T21:02:49.664Z" }, { url = "https://files.pythonhosted.org/packages/d2/d8/09bfa816572a4d83bccd6750df1926f79158b1c36c5f73786e26dbe4ee38/greenlet-3.3.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63d10328839d1973e5ba35e98cccbca71b232b14051fd957b6f8b6e8e80d0506", size = 664160, upload-time = "2026-02-20T20:21:04.015Z" }, { url = "https://files.pythonhosted.org/packages/48/cf/56832f0c8255d27f6c35d41b5ec91168d74ec721d85f01a12131eec6b93c/greenlet-3.3.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8e4ab3cfb02993c8cc248ea73d7dae6cec0253e9afa311c9b37e603ca9fad2ce", size = 1619181, upload-time = "2026-02-20T20:49:36.052Z" }, { url = "https://files.pythonhosted.org/packages/0a/23/b90b60a4aabb4cec0796e55f25ffbfb579a907c3898cd2905c8918acaa16/greenlet-3.3.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:94ad81f0fd3c0c0681a018a976e5c2bd2ca2d9d94895f23e7bb1af4e8af4e2d5", size = 1687713, upload-time = "2026-02-20T20:21:11.684Z" }, @@ -1856,7 +1851,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/98/6d/8f2ef704e614bcf58ed43cfb8d87afa1c285e98194ab2cfad351bf04f81e/greenlet-3.3.2-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:e26e72bec7ab387ac80caa7496e0f908ff954f31065b0ffc1f8ecb1338b11b54", size = 286617, upload-time = "2026-02-20T20:19:29.856Z" }, { url = "https://files.pythonhosted.org/packages/5e/0d/93894161d307c6ea237a43988f27eba0947b360b99ac5239ad3fe09f0b47/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b466dff7a4ffda6ca975979bab80bdadde979e29fc947ac3be4451428d8b0e4", size = 655189, upload-time = "2026-02-20T20:47:35.742Z" }, { url = "https://files.pythonhosted.org/packages/f5/2c/d2d506ebd8abcb57386ec4f7ba20f4030cbe56eae541bc6fd6ef399c0b41/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b8bddc5b73c9720bea487b3bffdb1840fe4e3656fba3bd40aa1489e9f37877ff", size = 658225, upload-time = "2026-02-20T20:56:02.527Z" }, - { url = "https://files.pythonhosted.org/packages/d1/67/8197b7e7e602150938049d8e7f30de1660cfb87e4c8ee349b42b67bdb2e1/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:59b3e2c40f6706b05a9cd299c836c6aa2378cabe25d021acd80f13abf81181cf", size = 666581, upload-time = "2026-02-20T21:02:51.526Z" }, { url = "https://files.pythonhosted.org/packages/8e/30/3a09155fbf728673a1dea713572d2d31159f824a37c22da82127056c44e4/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b26b0f4428b871a751968285a1ac9648944cea09807177ac639b030bddebcea4", size = 657907, upload-time = "2026-02-20T20:21:05.259Z" }, { url = "https://files.pythonhosted.org/packages/f3/fd/d05a4b7acd0154ed758797f0a43b4c0962a843bedfe980115e842c5b2d08/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1fb39a11ee2e4d94be9a76671482be9398560955c9e568550de0224e41104727", size = 1618857, upload-time = "2026-02-20T20:49:37.309Z" }, { url = 
"https://files.pythonhosted.org/packages/6f/e1/50ee92a5db521de8f35075b5eff060dd43d39ebd46c2181a2042f7070385/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:20154044d9085151bc309e7689d6f7ba10027f8f5a8c0676ad398b951913d89e", size = 1680010, upload-time = "2026-02-20T20:21:13.427Z" }, @@ -5231,7 +5225,7 @@ requires-dist = [ { name = "aioboto3", marker = "extra == 'aioboto3'", specifier = ">=10.4.0" }, { name = "google-adk", marker = "extra == 'google-adk'", specifier = ">=1.27.0,<2" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, - { name = "langgraph", marker = "extra == 'langgraph'", specifier = ">=1.1.6" }, + { name = "langgraph", marker = "extra == 'langgraph'", specifier = ">=1.1.0" }, { name = "langsmith", marker = "extra == 'langsmith'", specifier = ">=0.7.0,<0.8" }, { name = "mcp", marker = "extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, { name = "nexus-rpc", specifier = "==1.4.0" }, @@ -5259,7 +5253,7 @@ dev = [ { name = "googleapis-common-protos", specifier = "==1.70.0" }, { name = "grpcio-tools", specifier = ">=1.48.2,<2" }, { name = "httpx", specifier = ">=0.28.1" }, - { name = "langgraph", specifier = ">=1.1.6" }, + { name = "langgraph", specifier = ">=1.1.0" }, { name = "langsmith", specifier = ">=0.7.0,<0.8" }, { name = "litellm", specifier = ">=1.83.0" }, { name = "maturin", specifier = ">=1.8.2" }, From 2253a5fdaf7d39d23a81592233c111476c213156 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 16 Apr 2026 14:38:10 -0700 Subject: [PATCH 14/47] add default_activity_options --- .../contrib/langgraph/langgraph_plugin.py | 7 ++++--- .../contrib/langgraph/test_continue_as_new.py | 15 +++++++------ .../contrib/langgraph/test_e2e_functional.py | 15 ++++--------- .../langgraph/test_e2e_functional_v2.py | 11 ++++------ tests/contrib/langgraph/test_interrupt.py | 15 +++++++------ tests/contrib/langgraph/test_interrupt_v2.py | 15 +++++++------ tests/contrib/langgraph/test_streaming.py | 21 
+++++++++---------- .../langgraph/test_subgraph_activity.py | 15 +++++++------ tests/contrib/langgraph/test_timeout.py | 19 +++++++++-------- tests/contrib/langgraph/test_two_nodes.py | 21 +++++++++---------- 10 files changed, 78 insertions(+), 76 deletions(-) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index c2ece6603..8c85ac8e6 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -48,7 +48,7 @@ def __init__( tasks: list | None = None, # TODO: Remove activity_options when we have support for @task(metadata=...) activity_options: dict[str, dict] | None = None, - # TODO: Add default_activity_options that apply to all nodes or tasks + default_activity_options: dict[str, Any] | None = None, ): """Initialize the LangGraph plugin with graphs, entrypoints, and tasks.""" self.activities: list = [] @@ -69,15 +69,16 @@ def __init__( ) # Remove LangSmith-related callback functions that can't be serialized between the workflow and activity. runnable.func_accepts = {} + opts = {**(default_activity_options or {}), **(node.metadata or {})} runnable.afunc = self.execute( - f"{graph_name}.{node_name}", runnable.afunc, node.metadata + f"{graph_name}.{node_name}", runnable.afunc, opts ) # Functional API: Wrap @task functions as Temporal Activities. 
if tasks: for task in tasks: name = task.func.__name__ - opts = (activity_options or {}).get(name, {}) + opts = {**(default_activity_options or {}), **(activity_options or {}).get(name, {})} task.func = self.execute(task_id(task.func), task.func, opts) task.func.__name__ = name diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index 55596f92d..9f09df8fe 100644 --- a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -42,11 +42,7 @@ async def run(self, values: dict[str, str]) -> Any: async def test_continue_as_new(client: Client): g = StateGraph(State) - g.add_node( - "node", - node, - metadata={"start_to_close_timeout": timedelta(seconds=10)}, - ) + g.add_node("node", node) g.add_edge(START, "node") task_queue = f"my-graph-{uuid4()}" @@ -55,7 +51,14 @@ async def test_continue_as_new(client: Client): client, task_queue=task_queue, workflows=[ContinueAsNewWorkflow], - plugins=[LangGraphPlugin(graphs={"my-graph": g})], + plugins=[ + LangGraphPlugin( + graphs={"my-graph": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], ): result = await client.execute_workflow( ContinueAsNewWorkflow.run, diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index 73128b69a..c84b03b5d 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ b/tests/contrib/langgraph/test_e2e_functional.py @@ -6,7 +6,6 @@ from __future__ import annotations from datetime import timedelta -from typing import Any from uuid import uuid4 from temporalio.client import Client @@ -37,13 +36,7 @@ SimpleFunctionalE2EWorkflow, ) - -def _activity_opts(*task_funcs: Any) -> dict[str, dict]: - """Build activity_options dict giving every task the same 30s timeout.""" - return { - t.func.__name__: {"start_to_close_timeout": timedelta(seconds=30)} - for t in task_funcs - } +_DEFAULT_ACTIVITY_OPTIONS 
= {"start_to_close_timeout": timedelta(seconds=30)} class TestFunctionalAPIBasicExecution: @@ -60,7 +53,7 @@ async def test_simple_functional_entrypoint(self, client: Client) -> None: LangGraphPlugin( entrypoints={"e2e_simple_functional": simple_functional_entrypoint}, tasks=tasks, - activity_options=_activity_opts(*tasks), + default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) ], ): @@ -93,7 +86,7 @@ async def test_continue_as_new_with_checkpoint(self, client: Client) -> None: "e2e_continue_as_new_functional": continue_as_new_entrypoint }, tasks=tasks, - activity_options=_activity_opts(*tasks), + default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) ], ): @@ -135,7 +128,7 @@ async def test_partial_execution_five_tasks(self, client: Client) -> None: LangGraphPlugin( entrypoints={"e2e_partial_execution": partial_execution_entrypoint}, tasks=tasks, - activity_options=_activity_opts(*tasks), + default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) ], ): diff --git a/tests/contrib/langgraph/test_e2e_functional_v2.py b/tests/contrib/langgraph/test_e2e_functional_v2.py index 13679676e..40f2618f6 100644 --- a/tests/contrib/langgraph/test_e2e_functional_v2.py +++ b/tests/contrib/langgraph/test_e2e_functional_v2.py @@ -101,11 +101,8 @@ async def test_simple_v2(self, client: Client) -> None: LangGraphPlugin( entrypoints={"v2_simple": simple_v2_entrypoint}, tasks=tasks, - activity_options={ - "triple_value": { - "start_to_close_timeout": timedelta(seconds=30) - }, - "add_five": {"start_to_close_timeout": timedelta(seconds=30)}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=30) }, ) ], @@ -134,8 +131,8 @@ async def test_interrupt_v2_functional(self, client: Client) -> None: LangGraphPlugin( entrypoints={"v2_interrupt": interrupt_entrypoint}, tasks=tasks, - activity_options={ - "ask_human": {"start_to_close_timeout": timedelta(seconds=30)}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=30) }, ) ], diff --git 
a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index d91e6e495..f4e0a5cd8 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -39,11 +39,7 @@ async def run(self, input: str) -> Any: async def test_interrupt(client: Client): g = StateGraph(State) - g.add_node( - "node", - node, - metadata={"start_to_close_timeout": timedelta(seconds=10)}, - ) + g.add_node("node", node) g.add_edge(START, "node") task_queue = f"my-graph-{uuid4()}" @@ -52,7 +48,14 @@ async def test_interrupt(client: Client): client, task_queue=task_queue, workflows=[InterruptWorkflow], - plugins=[LangGraphPlugin(graphs={"my-graph": g})], + plugins=[ + LangGraphPlugin( + graphs={"my-graph": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], ): result = await client.execute_workflow( InterruptWorkflow.run, diff --git a/tests/contrib/langgraph/test_interrupt_v2.py b/tests/contrib/langgraph/test_interrupt_v2.py index 348466f62..d7ff3957b 100644 --- a/tests/contrib/langgraph/test_interrupt_v2.py +++ b/tests/contrib/langgraph/test_interrupt_v2.py @@ -49,11 +49,7 @@ async def run(self, input: str) -> Any: async def test_interrupt_v2(client: Client): g = StateGraph(State) - g.add_node( - "node", - node, - metadata={"start_to_close_timeout": timedelta(seconds=10)}, - ) + g.add_node("node", node) g.add_edge(START, "node") task_queue = f"interrupt-v2-{uuid4()}" @@ -62,7 +58,14 @@ async def test_interrupt_v2(client: Client): client, task_queue=task_queue, workflows=[InterruptV2Workflow], - plugins=[LangGraphPlugin(graphs={"interrupt-v2-graph": g})], + plugins=[ + LangGraphPlugin( + graphs={"interrupt-v2-graph": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], ): result = await client.execute_workflow( InterruptV2Workflow.run, diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 
c14cecec9..9617224a1 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -35,16 +35,8 @@ async def run(self, input: str) -> Any: async def test_streaming(client: Client): g = StateGraph(State) - g.add_node( - "node_a", - node_a, - metadata={"start_to_close_timeout": timedelta(seconds=10)}, - ) - g.add_node( - "node_b", - node_b, - metadata={"start_to_close_timeout": timedelta(seconds=10)}, - ) + g.add_node("node_a", node_a) + g.add_node("node_b", node_b) g.add_edge(START, "node_a") g.add_edge("node_a", "node_b") @@ -54,7 +46,14 @@ async def test_streaming(client: Client): client, task_queue=task_queue, workflows=[StreamingWorkflow], - plugins=[LangGraphPlugin(graphs={"streaming": g})], + plugins=[ + LangGraphPlugin( + graphs={"streaming": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], ): chunks = await client.execute_workflow( StreamingWorkflow.run, diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py index fd802ff9a..e73a75865 100644 --- a/tests/contrib/langgraph/test_subgraph_activity.py +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -36,11 +36,7 @@ async def run(self, input: str) -> Any: async def test_activity_subgraph(client: Client): parent = StateGraph(State) - parent.add_node( - "parent_node", - parent_node, - metadata={"start_to_close_timeout": timedelta(seconds=10)}, - ) + parent.add_node("parent_node", parent_node) parent.add_edge(START, "parent_node") task_queue = f"subgraph-{uuid4()}" @@ -49,7 +45,14 @@ async def test_activity_subgraph(client: Client): client, task_queue=task_queue, workflows=[ActivitySubgraphWorkflow], - plugins=[LangGraphPlugin(graphs={"parent": parent})], + plugins=[ + LangGraphPlugin( + graphs={"parent": parent}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], ): result = await client.execute_workflow( 
ActivitySubgraphWorkflow.run, diff --git a/tests/contrib/langgraph/test_timeout.py b/tests/contrib/langgraph/test_timeout.py index 2521e3a00..6bcdaa8ba 100644 --- a/tests/contrib/langgraph/test_timeout.py +++ b/tests/contrib/langgraph/test_timeout.py @@ -32,14 +32,7 @@ async def run(self, input: str) -> Any: async def test_timeout(client: Client): g = StateGraph(State) - g.add_node( - "node", - node, - metadata={ - "start_to_close_timeout": timedelta(milliseconds=100), - "retry_policy": RetryPolicy(maximum_attempts=1), - }, - ) + g.add_node("node", node) g.add_edge(START, "node") task_queue = f"my-graph-{uuid4()}" @@ -48,7 +41,15 @@ async def test_timeout(client: Client): client, task_queue=task_queue, workflows=[TimeoutWorkflow], - plugins=[LangGraphPlugin(graphs={"my-graph": g})], + plugins=[ + LangGraphPlugin( + graphs={"my-graph": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(milliseconds=100), + "retry_policy": RetryPolicy(maximum_attempts=1), + }, + ) + ], ): with raises(WorkflowFailureError): await client.execute_workflow( diff --git a/tests/contrib/langgraph/test_two_nodes.py b/tests/contrib/langgraph/test_two_nodes.py index 7a8affcac..d5b7e06a1 100644 --- a/tests/contrib/langgraph/test_two_nodes.py +++ b/tests/contrib/langgraph/test_two_nodes.py @@ -32,16 +32,8 @@ async def run(self, input: str) -> Any: async def test_two_nodes(client: Client): g = StateGraph(State) - g.add_node( - "node_a", - node_a, - metadata={"start_to_close_timeout": timedelta(seconds=10)}, - ) - g.add_node( - "node_b", - node_b, - metadata={"start_to_close_timeout": timedelta(seconds=10)}, - ) + g.add_node("node_a", node_a) + g.add_node("node_b", node_b) g.add_edge(START, "node_a") g.add_edge("node_a", "node_b") @@ -51,7 +43,14 @@ async def test_two_nodes(client: Client): client, task_queue=task_queue, workflows=[TwoNodesWorkflow], - plugins=[LangGraphPlugin(graphs={"my-graph": g})], + plugins=[ + LangGraphPlugin( + graphs={"my-graph": g}, + 
default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], ): result = await client.execute_workflow( TwoNodesWorkflow.run, From 86837df6f875660ddc13e1fd64e036fc902a123b Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 16 Apr 2026 15:54:30 -0700 Subject: [PATCH 15/47] add replay test --- .../contrib/langgraph/langgraph_plugin.py | 48 +++++++------------ 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 8c85ac8e6..bf17421cc 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -9,7 +9,7 @@ from langgraph.graph import StateGraph from langgraph.pregel import Pregel -from temporalio import activity, workflow +from temporalio import activity from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity from temporalio.contrib.langgraph.task_cache import ( get_task_cache, @@ -17,13 +17,12 @@ task_id, ) from temporalio.plugin import SimplePlugin -from temporalio.worker import WorkerConfig, WorkflowRunner +from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner # Save registered graphs/entrypoints at the module level to avoid being refreshed by the sandbox. -# Keyed by task queue to isolate concurrent Workers/Plugins in the same process. 
-_graph_registry: dict[str, dict[str, StateGraph[Any]]] = {} -_entrypoint_registry: dict[str, dict[str, Pregel[Any, Any, Any, Any]]] = {} +_graph_registry: dict[str, StateGraph[Any]] = {} +_entrypoint_registry: dict[str, Pregel[Any, Any, Any, Any]] = {} class LangGraphPlugin(SimplePlugin): @@ -52,11 +51,10 @@ def __init__( ): """Initialize the LangGraph plugin with graphs, entrypoints, and tasks.""" self.activities: list = [] - self._graphs: dict[str, StateGraph[Any]] = graphs or {} - self._entrypoints: dict[str, Pregel[Any, Any, Any, Any]] = entrypoints or {} # Graph API: Wrap graph nodes as Temporal Activities. if graphs: + _graph_registry.update(graphs) for graph_name, graph in graphs.items(): for node_name, node in graph.nodes.items(): runnable = node.runnable @@ -74,6 +72,9 @@ def __init__( f"{graph_name}.{node_name}", runnable.afunc, opts ) + if entrypoints: + _entrypoint_registry.update(entrypoints) + # Functional API: Wrap @task functions as Temporal Activities. if tasks: for task in tasks: @@ -106,19 +107,6 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: workflow_runner=workflow_runner, ) - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - """Register graphs/entrypoints scoped to the worker's task queue.""" - task_queue = config.get("task_queue") - if not task_queue: - raise ValueError( - "Worker config must include a task_queue for LangGraphPlugin" - ) - if self._graphs: - _graph_registry.setdefault(task_queue, {}).update(self._graphs) - if self._entrypoints: - _entrypoint_registry.setdefault(task_queue, {}).update(self._entrypoints) - return super().configure_worker(config) - def execute( self, activity_name: str, @@ -152,14 +140,12 @@ def graph( """ _patch_event_loop() set_task_cache(cache or {}) - task_queue = workflow.info().task_queue - registry = _graph_registry.get(task_queue, {}) - if name not in registry: + if name not in _graph_registry: raise KeyError( - f"Graph {name!r} not found for task queue 
{task_queue!r}. " - f"Available graphs: {list(registry.keys())}" + f"Graph {name!r} not found. " + f"Available graphs: {list(_graph_registry.keys())}" ) - return registry[name] + return _graph_registry[name] def entrypoint( @@ -175,14 +161,12 @@ def entrypoint( """ _patch_event_loop() set_task_cache(cache or {}) - task_queue = workflow.info().task_queue - registry = _entrypoint_registry.get(task_queue, {}) - if name not in registry: + if name not in _entrypoint_registry: raise KeyError( - f"Entrypoint {name!r} not found for task queue {task_queue!r}. " - f"Available entrypoints: {list(registry.keys())}" + f"Entrypoint {name!r} not found. " + f"Available entrypoints: {list(_entrypoint_registry.keys())}" ) - return registry[name] + return _entrypoint_registry[name] def cache() -> dict[str, Any] | None: From dc2ed5998aceb11cad3ffaeb04128076912b0185 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 16 Apr 2026 16:38:05 -0700 Subject: [PATCH 16/47] fix gaps in missing tests --- .../contrib/langgraph/langgraph_plugin.py | 5 +- .../langgraph/e2e_functional_entrypoints.py | 14 ++ .../langgraph/e2e_functional_workflows.py | 7 + .../contrib/langgraph/test_e2e_functional.py | 176 +++++++++++++++++- .../langgraph/test_e2e_functional_v2.py | 149 --------------- tests/contrib/langgraph/test_interrupt.py | 28 ++- tests/contrib/langgraph/test_interrupt_v2.py | 77 -------- .../langgraph/test_plugin_validation.py | 43 +++++ tests/contrib/langgraph/test_replay.py | 87 +++++++++ 9 files changed, 345 insertions(+), 241 deletions(-) delete mode 100644 tests/contrib/langgraph/test_e2e_functional_v2.py delete mode 100644 tests/contrib/langgraph/test_interrupt_v2.py create mode 100644 tests/contrib/langgraph/test_plugin_validation.py create mode 100644 tests/contrib/langgraph/test_replay.py diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index bf17421cc..271eb5e06 100644 --- 
a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -79,7 +79,10 @@ def __init__( if tasks: for task in tasks: name = task.func.__name__ - opts = {**(default_activity_options or {}), **(activity_options or {}).get(name, {})} + opts = { + **(default_activity_options or {}), + **(activity_options or {}).get(name, {}), + } task.func = self.execute(task_id(task.func), task.func, opts) task.func.__name__ = name diff --git a/tests/contrib/langgraph/e2e_functional_entrypoints.py b/tests/contrib/langgraph/e2e_functional_entrypoints.py index 01f871b8f..7f16abe8d 100644 --- a/tests/contrib/langgraph/e2e_functional_entrypoints.py +++ b/tests/contrib/langgraph/e2e_functional_entrypoints.py @@ -5,6 +5,8 @@ from __future__ import annotations +import asyncio + import langgraph.types from langgraph.func import entrypoint, task # pyright: ignore[reportMissingTypeStubs] @@ -128,3 +130,15 @@ async def interrupt_entrypoint(value: str) -> dict: """Entrypoint that interrupts for human input, then returns the answer.""" answer = await ask_human("Do you approve?") return {"input": value, "answer": answer} + + +@task +async def slow_task(x: int) -> int: + await asyncio.sleep(1) + return x + + +@entrypoint() +async def slow_entrypoint(value: int) -> dict: + result = await slow_task(value) + return {"result": result} diff --git a/tests/contrib/langgraph/e2e_functional_workflows.py b/tests/contrib/langgraph/e2e_functional_workflows.py index e44effcfd..c0494d91f 100644 --- a/tests/contrib/langgraph/e2e_functional_workflows.py +++ b/tests/contrib/langgraph/e2e_functional_workflows.py @@ -16,6 +16,13 @@ async def run(self, input_value: int) -> dict: return await entrypoint("e2e_simple_functional").ainvoke(input_value) +@workflow.defn +class SlowFunctionalWorkflow: + @workflow.run + async def run(self, input_value: int) -> dict: + return await entrypoint("e2e_slow_functional").ainvoke(input_value) + + @dataclass class ContinueAsNewInput: 
value: int diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index c84b03b5d..7c475402b 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ b/tests/contrib/langgraph/test_e2e_functional.py @@ -1,4 +1,4 @@ -"""End-to-end tests for LangGraph Functional API integration. +"""End-to-end tests for LangGraph Functional API integration (v1 and v2). Requires a running Temporal test server (started by conftest.py). """ @@ -6,22 +6,39 @@ from __future__ import annotations from datetime import timedelta +from typing import Any from uuid import uuid4 -from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +import pytest +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.func import ( # pyright: ignore[reportMissingTypeStubs] + entrypoint as lg_entrypoint, +) +from langgraph.func import task # pyright: ignore[reportMissingTypeStubs] +from langgraph.types import Command +from pytest import raises + +from temporalio import workflow +from temporalio.client import Client, WorkflowFailureError +from temporalio.common import RetryPolicy +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, entrypoint from temporalio.worker import Worker from tests.contrib.langgraph.e2e_functional_entrypoints import ( add_ten, + ask_human, continue_as_new_entrypoint, double_value, expensive_task_a, expensive_task_b, expensive_task_c, get_task_execution_counts, + interrupt_entrypoint, partial_execution_entrypoint, reset_task_execution_counts, simple_functional_entrypoint, + slow_entrypoint, + slow_task, step_1, step_2, step_3, @@ -34,38 +51,114 @@ PartialExecutionInput, PartialExecutionWorkflow, SimpleFunctionalE2EWorkflow, + SlowFunctionalWorkflow, ) _DEFAULT_ACTIVITY_OPTIONS = {"start_to_close_timeout": timedelta(seconds=30)} +# V2-only tasks defined here to avoid sharing mutated 
_TaskFunction objects +# (Plugin wraps task.func in-place). + + +@task +def triple_value(x: int) -> int: + return x * 3 + + +@task +def add_five(x: int) -> int: + return x + 5 + + +@lg_entrypoint() +async def simple_v2_entrypoint(value: int) -> dict: + tripled = await triple_value(value) + result = await add_five(tripled) + return {"result": result} + + +@workflow.defn +class SimpleV2Workflow: + @workflow.run + async def run(self, input_value: int) -> dict[str, Any]: + result = await entrypoint("v2_simple").ainvoke(input_value, version="v2") + return result.value + + +@workflow.defn +class InterruptV2FunctionalWorkflow: + @workflow.run + async def run(self, input_value: str) -> dict[str, Any]: + app = entrypoint("v2_interrupt") + app.checkpointer = InMemorySaver() + config = RunnableConfig( + {"configurable": {"thread_id": workflow.info().workflow_id}} + ) + + result = await app.ainvoke(input_value, config, version="v2") + + assert result.value == {} + assert len(result.interrupts) == 1 + assert result.interrupts[0].value == "Do you approve?" 
+ + resumed = await app.ainvoke(Command(resume="approved"), config, version="v2") + return resumed.value + + class TestFunctionalAPIBasicExecution: - async def test_simple_functional_entrypoint(self, client: Client) -> None: - """input 10 -> double (20) -> add 10 (30) -> result: 30""" - tasks = [double_value, add_ten] + @pytest.mark.parametrize( + "workflow_cls,entrypoint_func,entrypoint_name,tasks,expected_result", + [ + ( + SimpleFunctionalE2EWorkflow, + simple_functional_entrypoint, + "e2e_simple_functional", + [double_value, add_ten], + 30, + ), + ( + SimpleV2Workflow, + simple_v2_entrypoint, + "v2_simple", + [triple_value, add_five], + 35, + ), + ], + ids=["v1", "v2"], + ) + async def test_simple_entrypoint( + self, + client: Client, + workflow_cls: Any, + entrypoint_func: Any, + entrypoint_name: str, + tasks: list, + expected_result: int, + ) -> None: task_queue = f"e2e-functional-{uuid4()}" async with Worker( client, task_queue=task_queue, - workflows=[SimpleFunctionalE2EWorkflow], + workflows=[workflow_cls], plugins=[ LangGraphPlugin( - entrypoints={"e2e_simple_functional": simple_functional_entrypoint}, + entrypoints={entrypoint_name: entrypoint_func}, tasks=tasks, default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) ], ): result = await client.execute_workflow( - SimpleFunctionalE2EWorkflow.run, + workflow_cls.run, 10, id=f"e2e-functional-{uuid4()}", task_queue=task_queue, execution_timeout=timedelta(seconds=30), ) - assert result["result"] == 30 + assert result["result"] == expected_result class TestFunctionalAPIContinueAsNew: @@ -148,3 +241,66 @@ async def test_partial_execution_five_tasks(self, client: Client) -> None: assert ( counts.get(f"step_{i}", 0) == 1 ), f"step_{i} executed {counts.get(f'step_{i}', 0)} times, expected 1" + + +class TestFunctionalAPIInterruptV2: + async def test_interrupt_v2_functional(self, client: Client) -> None: + """version='v2' separates interrupts from value in functional API.""" + tasks = [ask_human] + task_queue = 
f"v2-interrupt-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[InterruptV2FunctionalWorkflow], + plugins=[ + LangGraphPlugin( + entrypoints={"v2_interrupt": interrupt_entrypoint}, + tasks=tasks, + default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, + ) + ], + ): + result = await client.execute_workflow( + InterruptV2FunctionalWorkflow.run, + "hello", + id=f"v2-interrupt-{uuid4()}", + task_queue=task_queue, + execution_timeout=timedelta(seconds=30), + ) + + assert result["input"] == "hello" + assert result["answer"] == "approved" + + +class TestFunctionalAPIPerTaskOptions: + async def test_per_task_activity_options_override(self, client: Client) -> None: + """activity_options[task_name] overrides default_activity_options for that task.""" + task_queue = f"e2e-per-task-options-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[SlowFunctionalWorkflow], + plugins=[ + LangGraphPlugin( + entrypoints={"e2e_slow_functional": slow_entrypoint}, + tasks=[slow_task], + default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, + activity_options={ + "slow_task": { + "start_to_close_timeout": timedelta(milliseconds=100), + "retry_policy": RetryPolicy(maximum_attempts=1), + } + }, + ) + ], + ): + with raises(WorkflowFailureError): + await client.execute_workflow( + SlowFunctionalWorkflow.run, + 1, + id=f"e2e-per-task-options-{uuid4()}", + task_queue=task_queue, + execution_timeout=timedelta(seconds=30), + ) diff --git a/tests/contrib/langgraph/test_e2e_functional_v2.py b/tests/contrib/langgraph/test_e2e_functional_v2.py deleted file mode 100644 index 40f2618f6..000000000 --- a/tests/contrib/langgraph/test_e2e_functional_v2.py +++ /dev/null @@ -1,149 +0,0 @@ -"""Tests for LangGraph Functional API with version="v2". - -version="v2" changes ainvoke() to return a GraphOutput dataclass with -.value and .interrupts fields instead of a plain dict with __interrupt__ -mixed in. 
-""" - -from __future__ import annotations - -from datetime import timedelta -from typing import Any -from uuid import uuid4 - -from langchain_core.runnables import RunnableConfig -from langgraph.checkpoint.memory import InMemorySaver -from langgraph.func import ( # pyright: ignore[reportMissingTypeStubs] - entrypoint as lg_entrypoint, -) -from langgraph.func import task # pyright: ignore[reportMissingTypeStubs] -from langgraph.types import Command - -from temporalio import workflow -from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, entrypoint -from temporalio.worker import Worker -from tests.contrib.langgraph.e2e_functional_entrypoints import ( - ask_human, - interrupt_entrypoint, -) - -# Define separate tasks to avoid sharing mutated _TaskFunction objects with -# other tests (Plugin wraps task.func in-place). - - -@task -def triple_value(x: int) -> int: - return x * 3 - - -@task -def add_five(x: int) -> int: - return x + 5 - - -@lg_entrypoint() -async def simple_v2_entrypoint(value: int) -> dict: - tripled = await triple_value(value) - result = await add_five(tripled) - return {"result": result} - - -# -- Workflows ---------------------------------------------------------------- - - -@workflow.defn -class SimpleV2Workflow: - @workflow.run - async def run(self, input_value: int) -> dict[str, Any]: - result = await entrypoint("v2_simple").ainvoke(input_value, version="v2") - # v2 returns GraphOutput — extract .value for Temporal serialization - return result.value - - -@workflow.defn -class InterruptV2FunctionalWorkflow: - @workflow.run - async def run(self, input_value: str) -> dict[str, Any]: - app = entrypoint("v2_interrupt") - app.checkpointer = InMemorySaver() - config = RunnableConfig( - {"configurable": {"thread_id": workflow.info().workflow_id}} - ) - - # First invoke — should get an interrupt - result = await app.ainvoke(input_value, config, version="v2") - - # v2: interrupts are on 
result.interrupts, value is clean - assert result.value == {} - assert len(result.interrupts) == 1 - assert result.interrupts[0].value == "Do you approve?" - - # Resume with approval - resumed = await app.ainvoke(Command(resume="approved"), config, version="v2") - return resumed.value - - -# -- Tests -------------------------------------------------------------------- - - -class TestFunctionalAPIV2: - async def test_simple_v2(self, client: Client) -> None: - """version='v2' returns GraphOutput with .value containing the result.""" - tasks = [triple_value, add_five] - task_queue = f"v2-simple-{uuid4()}" - - async with Worker( - client, - task_queue=task_queue, - workflows=[SimpleV2Workflow], - plugins=[ - LangGraphPlugin( - entrypoints={"v2_simple": simple_v2_entrypoint}, - tasks=tasks, - default_activity_options={ - "start_to_close_timeout": timedelta(seconds=30) - }, - ) - ], - ): - result = await client.execute_workflow( - SimpleV2Workflow.run, - 10, - id=f"v2-simple-{uuid4()}", - task_queue=task_queue, - execution_timeout=timedelta(seconds=30), - ) - - # 10 * 3 = 30, 30 + 5 = 35 - assert result["result"] == 35 - - async def test_interrupt_v2_functional(self, client: Client) -> None: - """version='v2' separates interrupts from value in functional API.""" - tasks = [ask_human] - task_queue = f"v2-interrupt-{uuid4()}" - - async with Worker( - client, - task_queue=task_queue, - workflows=[InterruptV2FunctionalWorkflow], - plugins=[ - LangGraphPlugin( - entrypoints={"v2_interrupt": interrupt_entrypoint}, - tasks=tasks, - default_activity_options={ - "start_to_close_timeout": timedelta(seconds=30) - }, - ) - ], - ): - result = await client.execute_workflow( - InterruptV2FunctionalWorkflow.run, - "hello", - id=f"v2-interrupt-{uuid4()}", - task_queue=task_queue, - execution_timeout=timedelta(seconds=30), - ) - - assert result["input"] == "hello" - assert result["answer"] == "approved" diff --git a/tests/contrib/langgraph/test_interrupt.py 
b/tests/contrib/langgraph/test_interrupt.py index f4e0a5cd8..91deecab5 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -3,6 +3,7 @@ from uuid import uuid4 import langgraph.types +import pytest from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from langgraph.graph.state import ( # pyright: ignore[reportMissingTypeStubs] @@ -37,17 +38,36 @@ async def run(self, input: str) -> Any: return await g.ainvoke(langgraph.types.Command(resume="yes"), config) -async def test_interrupt(client: Client): +@workflow.defn +class InterruptV2Workflow: + @workflow.run + async def run(self, input: str) -> Any: + g = graph("my-graph").compile(checkpointer=InMemorySaver()) + config = RunnableConfig({"configurable": {"thread_id": "1"}}) + + result = await g.ainvoke({"value": input}, config, version="v2") + + assert result.value == {"value": ""} + assert len(result.interrupts) == 1 + assert result.interrupts[0].value == "Continue?" 
+ + return await g.ainvoke(langgraph.types.Command(resume="yes"), config) + + +@pytest.mark.parametrize( + "workflow_cls", [InterruptWorkflow, InterruptV2Workflow], ids=["v1", "v2"] +) +async def test_interrupt(client: Client, workflow_cls: Any) -> None: g = StateGraph(State) g.add_node("node", node) g.add_edge(START, "node") - task_queue = f"my-graph-{uuid4()}" + task_queue = f"interrupt-{uuid4()}" async with Worker( client, task_queue=task_queue, - workflows=[InterruptWorkflow], + workflows=[workflow_cls], plugins=[ LangGraphPlugin( graphs={"my-graph": g}, @@ -58,7 +78,7 @@ async def test_interrupt(client: Client): ], ): result = await client.execute_workflow( - InterruptWorkflow.run, + workflow_cls.run, "", id=f"test-workflow-{uuid4()}", task_queue=task_queue, diff --git a/tests/contrib/langgraph/test_interrupt_v2.py b/tests/contrib/langgraph/test_interrupt_v2.py deleted file mode 100644 index d7ff3957b..000000000 --- a/tests/contrib/langgraph/test_interrupt_v2.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Test Graph API interrupt handling with version="v2". - -With v2, ainvoke() returns a GraphOutput dataclass with .value and .interrupts -instead of mixing __interrupt__ into the state dict. 
-""" - -from datetime import timedelta -from typing import Any -from uuid import uuid4 - -import langgraph.types -from langgraph.checkpoint.memory import InMemorySaver -from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] -from langgraph.graph.state import ( # pyright: ignore[reportMissingTypeStubs] - RunnableConfig, -) -from typing_extensions import TypedDict - -from temporalio import workflow -from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph -from temporalio.worker import Worker - - -class State(TypedDict): - value: str - - -async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] - return {"value": langgraph.types.interrupt("Continue?")} - - -@workflow.defn -class InterruptV2Workflow: - @workflow.run - async def run(self, input: str) -> Any: - g = graph("interrupt-v2-graph").compile(checkpointer=InMemorySaver()) - config = RunnableConfig({"configurable": {"thread_id": "1"}}) - - result = await g.ainvoke({"value": input}, config, version="v2") - - # v2: interrupts are on result.interrupts, not result["__interrupt__"] - assert result.value == {"value": ""} - assert len(result.interrupts) == 1 - assert result.interrupts[0].value == "Continue?" 
- - return await g.ainvoke(langgraph.types.Command(resume="yes"), config) - - -async def test_interrupt_v2(client: Client): - g = StateGraph(State) - g.add_node("node", node) - g.add_edge(START, "node") - - task_queue = f"interrupt-v2-{uuid4()}" - - async with Worker( - client, - task_queue=task_queue, - workflows=[InterruptV2Workflow], - plugins=[ - LangGraphPlugin( - graphs={"interrupt-v2-graph": g}, - default_activity_options={ - "start_to_close_timeout": timedelta(seconds=10) - }, - ) - ], - ): - result = await client.execute_workflow( - InterruptV2Workflow.run, - "", - id=f"test-interrupt-v2-{uuid4()}", - task_queue=task_queue, - ) - - assert result == {"value": "yes"} diff --git a/tests/contrib/langgraph/test_plugin_validation.py b/tests/contrib/langgraph/test_plugin_validation.py new file mode 100644 index 000000000..4fa33c317 --- /dev/null +++ b/tests/contrib/langgraph/test_plugin_validation.py @@ -0,0 +1,43 @@ +"""Tests for LangGraphPlugin validation and registry lookup error paths.""" + +from __future__ import annotations + +from uuid import uuid4 + +from langchain_core.runnables import RunnableLambda +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from pytest import raises +from typing_extensions import TypedDict + +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin + + +class State(TypedDict): + value: str + + +async def async_node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] + return {"value": "done"} + + +def sync_node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedParameter] + return {"value": "done"} + + +def test_non_runnable_callable_node_raises() -> None: + """Nodes whose runnable isn't a RunnableCallable can't be wrapped as activities.""" + g = StateGraph(State) + g.add_node("node", RunnableLambda(sync_node)) + g.add_edge(START, "node") + + with raises(ValueError, match="must have an async function"): + 
LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) + + +def test_invalid_execute_in_raises() -> None: + g = StateGraph(State) + g.add_node("node", async_node, metadata={"execute_in": "bogus"}) + g.add_edge(START, "node") + + with raises(ValueError, match="Invalid execute_in value"): + LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) diff --git a/tests/contrib/langgraph/test_replay.py b/tests/contrib/langgraph/test_replay.py new file mode 100644 index 000000000..f7ba60e07 --- /dev/null +++ b/tests/contrib/langgraph/test_replay.py @@ -0,0 +1,87 @@ +from datetime import timedelta +from uuid import uuid4 + +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] + +from temporalio.client import Client +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.worker import Replayer, Worker +from tests.contrib.langgraph.test_interrupt import ( + InterruptWorkflow, +) +from tests.contrib.langgraph.test_interrupt import ( + State as InterruptState, +) +from tests.contrib.langgraph.test_interrupt import ( + node as interrupt_node, +) +from tests.contrib.langgraph.test_two_nodes import ( + State, + TwoNodesWorkflow, + node_a, + node_b, +) + + +async def test_replay(client: Client): + g = StateGraph(State) + g.add_node("node_a", node_a) + g.add_node("node_b", node_b) + g.add_edge(START, "node_a") + g.add_edge("node_a", "node_b") + + task_queue = f"my-graph-{uuid4()}" + plugin = LangGraphPlugin( + graphs={"my-graph": g}, + default_activity_options={"start_to_close_timeout": timedelta(seconds=10)}, + ) + + async with Worker( + client, + task_queue=task_queue, + workflows=[TwoNodesWorkflow], + plugins=[plugin], + ): + handle = await client.start_workflow( + TwoNodesWorkflow.run, + "", + id=f"test-workflow-{uuid4()}", + task_queue=task_queue, + ) + await handle.result() + + await Replayer( + workflows=[TwoNodesWorkflow], + plugins=[plugin], + ).replay_workflow(await handle.fetch_history()) + + +async def 
test_replay_interrupt(client: Client): + g = StateGraph(InterruptState) + g.add_node("node", interrupt_node) + g.add_edge(START, "node") + + task_queue = f"interrupt-replay-{uuid4()}" + plugin = LangGraphPlugin( + graphs={"my-graph": g}, + default_activity_options={"start_to_close_timeout": timedelta(seconds=10)}, + ) + + async with Worker( + client, + task_queue=task_queue, + workflows=[InterruptWorkflow], + plugins=[plugin], + ): + handle = await client.start_workflow( + InterruptWorkflow.run, + "", + id=f"test-interrupt-replay-{uuid4()}", + task_queue=task_queue, + ) + await handle.result() + + await Replayer( + workflows=[InterruptWorkflow], + plugins=[plugin], + ).replay_workflow(await handle.fetch_history()) From c0525cc0bc4547788e376fe22ebae2e9d84019cf Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 16 Apr 2026 17:43:13 -0700 Subject: [PATCH 17/47] introduce an interceptor to patch is_running only in the workflow, then use init functions for graph compilation --- .../contrib/langgraph/langgraph_plugin.py | 12 ++---------- .../langgraph/e2e_functional_workflows.py | 16 +++++++++++----- tests/contrib/langgraph/test_continue_as_new.py | 10 ++++++---- .../langgraph/test_continue_as_new_cached.py | 4 ++-- tests/contrib/langgraph/test_e2e_functional.py | 17 ++++++++++++----- .../langgraph/test_execute_in_workflow.py | 5 ++++- tests/contrib/langgraph/test_interrupt.py | 16 ++++++++++------ .../contrib/langgraph/test_plugin_validation.py | 16 +++++++++++++++- tests/contrib/langgraph/test_streaming.py | 5 ++++- .../contrib/langgraph/test_subgraph_activity.py | 5 ++++- .../contrib/langgraph/test_subgraph_workflow.py | 5 ++++- tests/contrib/langgraph/test_timeout.py | 5 ++++- tests/contrib/langgraph/test_two_nodes.py | 5 ++++- 13 files changed, 82 insertions(+), 39 deletions(-) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 271eb5e06..6ce1ddd91 100644 --- 
a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -11,6 +11,7 @@ from temporalio import activity from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity +from temporalio.contrib.langgraph.langgraph_interceptor import LangGraphInterceptor from temporalio.contrib.langgraph.task_cache import ( get_task_cache, set_task_cache, @@ -108,6 +109,7 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: "temporalio.LangGraphPlugin", activities=self.activities, workflow_runner=workflow_runner, + interceptors=[LangGraphInterceptor()], ) def execute( @@ -141,7 +143,6 @@ def graph( Restores cached results so previously-completed nodes are not re-executed after continue-as-new. """ - _patch_event_loop() set_task_cache(cache or {}) if name not in _graph_registry: raise KeyError( @@ -162,7 +163,6 @@ def entrypoint( Restores cached results so previously-completed tasks are not re-executed after continue-as-new. """ - _patch_event_loop() set_task_cache(cache or {}) if name not in _entrypoint_registry: raise KeyError( @@ -180,11 +180,3 @@ def cache() -> dict[str, Any] | None: Returns None if the cache is empty. 
""" return get_task_cache() or None - - -def _patch_event_loop(): - """Patch the event loop so LangGraph detects it as running inside Temporal's sandbox.""" - from asyncio import get_event_loop - - loop = get_event_loop() - setattr(loop, "is_running", lambda: True) diff --git a/tests/contrib/langgraph/e2e_functional_workflows.py b/tests/contrib/langgraph/e2e_functional_workflows.py index c0494d91f..f467d1758 100644 --- a/tests/contrib/langgraph/e2e_functional_workflows.py +++ b/tests/contrib/langgraph/e2e_functional_workflows.py @@ -11,16 +11,22 @@ @workflow.defn class SimpleFunctionalE2EWorkflow: + def __init__(self) -> None: + self.app = entrypoint("e2e_simple_functional") + @workflow.run async def run(self, input_value: int) -> dict: - return await entrypoint("e2e_simple_functional").ainvoke(input_value) + return await self.app.ainvoke(input_value) @workflow.defn class SlowFunctionalWorkflow: + def __init__(self) -> None: + self.app = entrypoint("e2e_slow_functional") + @workflow.run async def run(self, input_value: int) -> dict: - return await entrypoint("e2e_slow_functional").ainvoke(input_value) + return await self.app.ainvoke(input_value) @dataclass @@ -37,9 +43,9 @@ class ContinueAsNewFunctionalWorkflow: @workflow.run async def run(self, input_data: ContinueAsNewInput) -> dict[str, Any]: - result = await entrypoint( - "e2e_continue_as_new_functional", cache=input_data.cache - ).ainvoke(input_data.value) + app = entrypoint("e2e_continue_as_new_functional", cache=input_data.cache) + + result = await app.ainvoke(input_data.value) if not input_data.task_a_done: workflow.continue_as_new( diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index 9f09df8fe..cb73c8ca2 100644 --- a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -25,16 +25,18 @@ async def node(state: State) -> dict[str, str]: @workflow.defn class ContinueAsNewWorkflow: + def 
__init__(self) -> None: + self.app = graph("my-graph").compile(checkpointer=InMemorySaver()) + @workflow.run async def run(self, values: dict[str, str]) -> Any: - g = graph("my-graph").compile(checkpointer=InMemorySaver()) config = RunnableConfig({"configurable": {"thread_id": "1"}}) - await g.aupdate_state(config, values) - await g.ainvoke(values, config) + await self.app.aupdate_state(config, values) + await self.app.ainvoke(values, config) if len(values["value"]) < 3: - state = await g.aget_state(config) + state = await self.app.aget_state(config) workflow.continue_as_new(state.values) return values diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index b30f0c7dc..444fdcd4a 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -65,8 +65,8 @@ class GraphContinueAsNewWorkflow: @workflow.run async def run(self, input_data: GraphContinueAsNewInput) -> dict[str, int]: - g = graph("cached-graph", cache=input_data.cache).compile() - result = await g.ainvoke({"value": input_data.value}) + app = graph("cached-graph", cache=input_data.cache).compile() + result = await app.ainvoke({"value": input_data.value}) if input_data.phase < 3: workflow.continue_as_new( diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index 7c475402b..696280e8f 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ b/tests/contrib/langgraph/test_e2e_functional.py @@ -80,29 +80,36 @@ async def simple_v2_entrypoint(value: int) -> dict: @workflow.defn class SimpleV2Workflow: + def __init__(self) -> None: + self.app = entrypoint("v2_simple") + @workflow.run async def run(self, input_value: int) -> dict[str, Any]: - result = await entrypoint("v2_simple").ainvoke(input_value, version="v2") + result = await self.app.ainvoke(input_value, version="v2") return result.value @workflow.defn 
class InterruptV2FunctionalWorkflow: + def __init__(self) -> None: + self.app = entrypoint("v2_interrupt") + self.app.checkpointer = InMemorySaver() + @workflow.run async def run(self, input_value: str) -> dict[str, Any]: - app = entrypoint("v2_interrupt") - app.checkpointer = InMemorySaver() config = RunnableConfig( {"configurable": {"thread_id": workflow.info().workflow_id}} ) - result = await app.ainvoke(input_value, config, version="v2") + result = await self.app.ainvoke(input_value, config, version="v2") assert result.value == {} assert len(result.interrupts) == 1 assert result.interrupts[0].value == "Do you approve?" - resumed = await app.ainvoke(Command(resume="approved"), config, version="v2") + resumed = await self.app.ainvoke( + Command(resume="approved"), config, version="v2" + ) return resumed.value diff --git a/tests/contrib/langgraph/test_execute_in_workflow.py b/tests/contrib/langgraph/test_execute_in_workflow.py index 58e3fdba0..76037c037 100644 --- a/tests/contrib/langgraph/test_execute_in_workflow.py +++ b/tests/contrib/langgraph/test_execute_in_workflow.py @@ -20,9 +20,12 @@ async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedP @workflow.defn class ExecuteInWorkflowWorkflow: + def __init__(self) -> None: + self.app = graph("my-graph").compile() + @workflow.run async def run(self, input: str) -> Any: - return await graph("my-graph").compile().ainvoke({"value": input}) + return await self.app.ainvoke({"value": input}) async def test_execute_in_workflow(client: Client): diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index 91deecab5..cede90e53 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -27,31 +27,35 @@ async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedP @workflow.defn class InterruptWorkflow: + def __init__(self) -> None: + self.app = 
graph("my-graph").compile(checkpointer=InMemorySaver()) + @workflow.run async def run(self, input: str) -> Any: - g = graph("my-graph").compile(checkpointer=InMemorySaver()) config = RunnableConfig({"configurable": {"thread_id": "1"}}) - result = await g.ainvoke({"value": input}, config) + result = await self.app.ainvoke({"value": input}, config) assert result["__interrupt__"][0].value == "Continue?" - return await g.ainvoke(langgraph.types.Command(resume="yes"), config) + return await self.app.ainvoke(langgraph.types.Command(resume="yes"), config) @workflow.defn class InterruptV2Workflow: + def __init__(self) -> None: + self.app = graph("my-graph").compile(checkpointer=InMemorySaver()) + @workflow.run async def run(self, input: str) -> Any: - g = graph("my-graph").compile(checkpointer=InMemorySaver()) config = RunnableConfig({"configurable": {"thread_id": "1"}}) - result = await g.ainvoke({"value": input}, config, version="v2") + result = await self.app.ainvoke({"value": input}, config, version="v2") assert result.value == {"value": ""} assert len(result.interrupts) == 1 assert result.interrupts[0].value == "Continue?" 
- return await g.ainvoke(langgraph.types.Command(resume="yes"), config) + return await self.app.ainvoke(langgraph.types.Command(resume="yes"), config) @pytest.mark.parametrize( diff --git a/tests/contrib/langgraph/test_plugin_validation.py b/tests/contrib/langgraph/test_plugin_validation.py index 4fa33c317..56a2c4f15 100644 --- a/tests/contrib/langgraph/test_plugin_validation.py +++ b/tests/contrib/langgraph/test_plugin_validation.py @@ -9,7 +9,11 @@ from pytest import raises from typing_extensions import TypedDict -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import ( + LangGraphPlugin, + entrypoint, + graph, +) class State(TypedDict): @@ -41,3 +45,13 @@ def test_invalid_execute_in_raises() -> None: with raises(ValueError, match="Invalid execute_in value"): LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) + + +async def test_unknown_graph_raises() -> None: + with raises(KeyError, match="not found"): + graph(f"not-registered-{uuid4()}") + + +async def test_unknown_entrypoint_raises() -> None: + with raises(KeyError, match="not found"): + entrypoint(f"not-registered-{uuid4()}") diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 9617224a1..1c4b19132 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -25,10 +25,13 @@ async def node_b(state: State) -> dict[str, str]: @workflow.defn class StreamingWorkflow: + def __init__(self) -> None: + self.app = graph("streaming").compile() + @workflow.run async def run(self, input: str) -> Any: chunks = [] - async for chunk in graph("streaming").compile().astream({"value": input}): + async for chunk in self.app.astream({"value": input}): chunks.append(chunk) return chunks diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py index e73a75865..e752719bb 100644 --- 
a/tests/contrib/langgraph/test_subgraph_activity.py +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -29,9 +29,12 @@ async def parent_node(state: State) -> dict[str, str]: @workflow.defn class ActivitySubgraphWorkflow: + def __init__(self) -> None: + self.app = graph("parent").compile() + @workflow.run async def run(self, input: str) -> Any: - return await graph("parent").compile().ainvoke({"value": input}) + return await self.app.ainvoke({"value": input}) async def test_activity_subgraph(client: Client): diff --git a/tests/contrib/langgraph/test_subgraph_workflow.py b/tests/contrib/langgraph/test_subgraph_workflow.py index c055a2a7a..d85ce25a1 100644 --- a/tests/contrib/langgraph/test_subgraph_workflow.py +++ b/tests/contrib/langgraph/test_subgraph_workflow.py @@ -25,9 +25,12 @@ async def parent_node(state: State) -> dict[str, str]: @workflow.defn class WorkflowSubgraphWorkflow: + def __init__(self) -> None: + self.app = graph("parent").compile() + @workflow.run async def run(self, input: str) -> Any: - return await graph("parent").compile().ainvoke({"value": input}) + return await self.app.ainvoke({"value": input}) async def test_workflow_subgraph(client: Client): diff --git a/tests/contrib/langgraph/test_timeout.py b/tests/contrib/langgraph/test_timeout.py index 6bcdaa8ba..22c2930bc 100644 --- a/tests/contrib/langgraph/test_timeout.py +++ b/tests/contrib/langgraph/test_timeout.py @@ -25,9 +25,12 @@ async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedP @workflow.defn class TimeoutWorkflow: + def __init__(self) -> None: + self.app = graph("my-graph").compile() + @workflow.run async def run(self, input: str) -> Any: - return await graph("my-graph").compile().ainvoke({"value": input}) + return await self.app.ainvoke({"value": input}) async def test_timeout(client: Client): diff --git a/tests/contrib/langgraph/test_two_nodes.py b/tests/contrib/langgraph/test_two_nodes.py index d5b7e06a1..992e30dcd 100644 --- 
a/tests/contrib/langgraph/test_two_nodes.py +++ b/tests/contrib/langgraph/test_two_nodes.py @@ -25,9 +25,12 @@ async def node_b(state: State) -> dict[str, str]: @workflow.defn class TwoNodesWorkflow: + def __init__(self) -> None: + self.app = graph("my-graph").compile() + @workflow.run async def run(self, input: str) -> Any: - return await graph("my-graph").compile().ainvoke({"value": input}) + return await self.app.ainvoke({"value": input}) async def test_two_nodes(client: Client): From b78d063e9fdea4ea8e1aa88ea6da2f370f992ee1 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 16 Apr 2026 17:43:54 -0700 Subject: [PATCH 18/47] add interceptor --- .../langgraph/langgraph_interceptor.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 temporalio/contrib/langgraph/langgraph_interceptor.py diff --git a/temporalio/contrib/langgraph/langgraph_interceptor.py b/temporalio/contrib/langgraph/langgraph_interceptor.py new file mode 100644 index 000000000..99d285c53 --- /dev/null +++ b/temporalio/contrib/langgraph/langgraph_interceptor.py @@ -0,0 +1,28 @@ +"""Workflow interceptor for the LangGraph plugin.""" + +import asyncio +from typing import Any + +from temporalio.worker import ( + ExecuteWorkflowInput, + Interceptor, + WorkflowInboundInterceptor, + WorkflowInterceptorClassInput, +) + + +class LangGraphInterceptor(Interceptor): + def workflow_interceptor_class( + self, input: WorkflowInterceptorClassInput + ) -> type[WorkflowInboundInterceptor]: + return _LangGraphWorkflowInboundInterceptor + + +class _LangGraphWorkflowInboundInterceptor(WorkflowInboundInterceptor): + """Patches the workflow event loop so LangGraph's `asyncio.eager_task_factory` + (which calls `loop.is_running()`) works inside Temporal's sandbox.""" + + async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: + loop = asyncio.get_event_loop() + setattr(loop, "is_running", lambda: True) + return await super().execute_workflow(input) From 
8ef609f09704a5c001006ec9af5ff96a18203ad1 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 17 Apr 2026 15:41:49 -0700 Subject: [PATCH 19/47] remove graph and entrypoint functions in favor of direct graph usage --- temporalio/contrib/langgraph/README.md | 92 ++++++++++-- temporalio/contrib/langgraph/__init__.py | 6 +- .../contrib/langgraph/langgraph_plugin.py | 140 ++++++++---------- .../langgraph/e2e_functional_workflows.py | 32 ++-- .../contrib/langgraph/test_continue_as_new.py | 17 ++- .../langgraph/test_continue_as_new_cached.py | 30 ++-- .../contrib/langgraph/test_e2e_functional.py | 48 ++---- .../langgraph/test_execute_in_workflow.py | 15 +- tests/contrib/langgraph/test_interrupt.py | 17 ++- .../langgraph/test_plugin_validation.py | 28 +--- tests/contrib/langgraph/test_replay.py | 33 +---- tests/contrib/langgraph/test_streaming.py | 19 +-- .../langgraph/test_subgraph_activity.py | 15 +- .../langgraph/test_subgraph_workflow.py | 34 +++-- tests/contrib/langgraph/test_timeout.py | 15 +- tests/contrib/langgraph/test_two_nodes.py | 19 +-- 16 files changed, 282 insertions(+), 278 deletions(-) diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index e8a4da74b..d4250d304 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -16,24 +16,66 @@ or with pip: pip install temporalio[langgraph] ``` -## Plugin Initialization +## Module layout -### Graph API +Define your graphs, tasks, and entrypoints in a module **separate** from your `@workflow.defn` classes — the standard Temporal split. The plugin adds the graph/task modules to the workflow sandbox's passthrough list so its in-place rewrites are visible to the workflow. Workflow modules stay sandboxed. 
+ +## Graph API ```python +# graphs.py +from langgraph.graph import START, StateGraph + +my_graph = StateGraph(State) +my_graph.add_node("my_node", my_node) +my_graph.add_edge(START, "my_node") + +# workflow.py +from temporalio import workflow +from myapp.graphs import my_graph + +@workflow.defn +class MyWorkflow: + @workflow.run + async def run(self, input): + return await my_graph.compile().ainvoke(input) + +# worker.py from temporalio.contrib.langgraph import LangGraphPlugin +from myapp.graphs import my_graph -plugin = LangGraphPlugin(graphs={"my-graph": graph}) +plugin = LangGraphPlugin(graphs=[my_graph]) ``` -### Functional API +## Functional API ```python +# flows.py +from langgraph.func import entrypoint, task + +@task +async def my_task(x): ... + +@entrypoint() +async def my_flow(inputs): + return await my_task(inputs) + +# workflow.py +from temporalio import workflow +from myapp.flows import my_flow + +@workflow.defn +class MyWorkflow: + @workflow.run + async def run(self, input): + return await my_flow.ainvoke(input) + +# worker.py import datetime from temporalio.contrib.langgraph import LangGraphPlugin +from myapp.flows import my_task plugin = LangGraphPlugin( - entrypoints={"my_entrypoint": my_entrypoint}, tasks=[my_task], activity_options={ "my_task": { @@ -49,19 +91,17 @@ Use `InMemorySaver` as your checkpointer. Temporal handles durability, so third- ```python import langgraph.checkpoint.memory -import typing -from temporalio.contrib.langgraph import graph from temporalio import workflow +from myapp.graphs import my_graph @workflow.defn class MyWorkflow: @workflow.run - async def run(self, input: str) -> typing.Any: - g = graph("my-graph").compile( + async def run(self, input): + app = my_graph.compile( checkpointer=langgraph.checkpoint.memory.InMemorySaver(), ) - ... 
``` @@ -71,7 +111,7 @@ Options are passed through to [`workflow.execute_activity()`](https://python.tem ### Graph API -Pass activity options as node `metadata` when calling `add_node`: +Pass per-node options as node `metadata`, or plugin-wide defaults via `default_activity_options`: ```python import datetime @@ -82,11 +122,16 @@ g.add_node("my_node", my_node, metadata={ "start_to_close_timeout": datetime.timedelta(seconds=30), "retry_policy": RetryPolicy(maximum_attempts=3), }) + +plugin = LangGraphPlugin( + graphs=[g], + default_activity_options={"start_to_close_timeout": datetime.timedelta(seconds=60)}, +) ``` ### Functional API -Pass activity options to the `Plugin` constructor, keyed by task function name: +Pass activity options to the plugin, keyed by task function name: ```python import datetime @@ -94,7 +139,6 @@ from temporalio.common import RetryPolicy from temporalio.contrib.langgraph import LangGraphPlugin plugin = LangGraphPlugin( - entrypoints={"my_entrypoint": my_entrypoint}, tasks=[my_task], activity_options={ "my_task": { @@ -111,7 +155,7 @@ To skip the Activity wrapper and run a node or task directly in the Workflow, se ```python # Graph API -graph.add_node("my_node", my_node, metadata={"execute_in": "workflow"}) +g.add_node("my_node", my_node, metadata={"execute_in": "workflow"}) # Functional API plugin = LangGraphPlugin( @@ -120,6 +164,26 @@ plugin = LangGraphPlugin( ) ``` +## Continue-As-New + +To carry cached task results across a continue-as-new boundary, pass the cache to your next run and restore it with `set_cache`: + +```python +from temporalio import workflow +from temporalio.contrib.langgraph import cache, set_cache +from myapp.graphs import my_graph + +@workflow.defn +class MyWorkflow: + @workflow.run + async def run(self, input, prev_cache=None): + set_cache(prev_cache) + result = await my_graph.compile().ainvoke(input) + if should_continue(result): + workflow.continue_as_new(next_input, cache()) + return result +``` + ## Running Tests 
Install dependencies: diff --git a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py index 50d8ca147..48a7931ed 100644 --- a/temporalio/contrib/langgraph/__init__.py +++ b/temporalio/contrib/langgraph/__init__.py @@ -13,13 +13,11 @@ from temporalio.contrib.langgraph.langgraph_plugin import ( LangGraphPlugin, cache, - entrypoint, - graph, + set_cache, ) __all__ = [ "LangGraphPlugin", - "entrypoint", "cache", - "graph", + "set_cache", ] diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 6ce1ddd91..e8105bb52 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -2,12 +2,13 @@ # pyright: reportMissingTypeStubs=false +from __future__ import annotations + from dataclasses import replace -from typing import Any, Callable +from typing import Any from langgraph._internal._runnable import RunnableCallable from langgraph.graph import StateGraph -from langgraph.pregel import Pregel from temporalio import activity from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity @@ -21,10 +22,6 @@ from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner -# Save registered graphs/entrypoints at the module level to avoid being refreshed by the sandbox. -_graph_registry: dict[str, StateGraph[Any]] = {} -_entrypoint_registry: dict[str, Pregel[Any, Any, Any, Any]] = {} - class LangGraphPlugin(SimplePlugin): """LangGraph plugin for Temporal SDK. @@ -37,26 +34,28 @@ class LangGraphPlugin(SimplePlugin): and tasks as Temporal Activities, giving your AI agent workflows durable execution, automatic retries, and timeouts. It supports both the LangGraph Graph API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). 
+ + Pass your graphs and tasks to the plugin; the plugin mutates them in place so + node/task invocations dispatch to Temporal activities. The modules those + graphs and tasks are defined in are automatically added to the workflow + sandbox's passthrough list, so the mutation is visible inside the sandbox. + Keep your ``@workflow.defn`` classes in a module separate from your graphs + and tasks (the standard Temporal convention). """ def __init__( self, - # Graph API - graphs: dict[str, StateGraph] | None = None, - # Functional API - entrypoints: dict[str, Pregel[Any, Any, Any, Any]] | None = None, + graphs: list[StateGraph] | None = None, tasks: list | None = None, - # TODO: Remove activity_options when we have support for @task(metadata=...) activity_options: dict[str, dict] | None = None, default_activity_options: dict[str, Any] | None = None, ): - """Initialize the LangGraph plugin with graphs, entrypoints, and tasks.""" + """Register activities for graphs and tasks.""" self.activities: list = [] + passthrough_modules: set[str] = set() - # Graph API: Wrap graph nodes as Temporal Activities. if graphs: - _graph_registry.update(graphs) - for graph_name, graph in graphs.items(): + for graph in graphs: for node_name, node in graph.nodes.items(): runnable = node.runnable if ( @@ -66,28 +65,25 @@ def __init__( raise ValueError( f"Node {node_name} must have an async function" ) - # Remove LangSmith-related callback functions that can't be serialized between the workflow and activity. + # Remove LangSmith-related callback functions that can't be + # serialized between the workflow and activity. 
runnable.func_accepts = {} opts = {**(default_activity_options or {}), **(node.metadata or {})} - runnable.afunc = self.execute( - f"{graph_name}.{node_name}", runnable.afunc, opts + runnable.afunc = self._wrap( + runnable.afunc, opts, passthrough_modules ) - if entrypoints: - _entrypoint_registry.update(entrypoints) - - # Functional API: Wrap @task functions as Temporal Activities. if tasks: - for task in tasks: - name = task.func.__name__ + for t in tasks: + name = t.func.__name__ + qualname = getattr(t.func, "__qualname__", name) opts = { **(default_activity_options or {}), **(activity_options or {}).get(name, {}), } - - task.func = self.execute(task_id(task.func), task.func, opts) - task.func.__name__ = name - task.func.__qualname__ = getattr(task.func, "__qualname__", name) + t.func = self._wrap(t.func, opts, passthrough_modules) + t.func.__name__ = name + t.func.__qualname__ = qualname def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: if not runner: @@ -101,6 +97,7 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: "langgraph", "langsmith", "numpy", # LangSmith uses numpy + *passthrough_modules, ), ) return runner @@ -112,71 +109,64 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: interceptors=[LangGraphInterceptor()], ) - def execute( + def _wrap( self, - activity_name: str, - func: Callable, - kwargs: dict[str, Any] | None = None, - ) -> Callable: - """Prepare a node or task to execute as an activity or inline in the workflow.""" - opts = kwargs or {} - execute_in = opts.pop("execute_in", "activity") + func: Any, + opts: dict[str, Any], + passthrough_modules: set[str], + ) -> Any: + """Wrap a node afunc or task func as an activity. Idempotent across plugins. + + Records the activity defn on ``self.activities`` and the function's + origin module on ``passthrough_modules``. 
If ``func`` is already wrapped + (e.g., a second plugin sharing the same graph), reuses the cached + activity defn and module — no double-wrap. + """ + meta = getattr(func, "_temporal_meta", None) + if meta is not None: + a, module = meta + if a is not None: + self.activities.append(a) + if module: + passthrough_modules.add(module) + return func + module = getattr(func, "__module__", None) + execute_in = opts.pop("execute_in", "activity") if execute_in == "activity": + activity_name = task_id(func) a = activity.defn(name=activity_name)(wrap_activity(func)) self.activities.append(a) - return wrap_execute_activity(a, task_id=task_id(func), **opts) + wrapped = wrap_execute_activity(a, task_id=activity_name, **opts) elif execute_in == "workflow": - return func + a = None + wrapped = func else: raise ValueError(f"Invalid execute_in value: {execute_in}") + if module: + passthrough_modules.add(module) + try: + setattr(wrapped, "_temporal_meta", (a, module)) + except (AttributeError, TypeError): + pass + return wrapped -def graph( - name: str, cache: dict[str, Any] | None = None -) -> StateGraph[Any, None, Any, Any]: - """Retrieve a registered graph by name. - Args: - name: Graph name as registered with LangGraphPlugin. - cache: Optional task result cache from a previous cache() call. - Restores cached results so previously-completed nodes are - not re-executed after continue-as-new. - """ - set_task_cache(cache or {}) - if name not in _graph_registry: - raise KeyError( - f"Graph {name!r} not found. " - f"Available graphs: {list(_graph_registry.keys())}" - ) - return _graph_registry[name] +def set_cache(cache: dict[str, Any] | None) -> None: + """Restore a task result cache returned by a previous :func:`cache` call. - -def entrypoint( - name: str, cache: dict[str, Any] | None = None -) -> Pregel[Any, Any, Any, Any]: - """Retrieve a registered entrypoint by name. - - Args: - name: Entrypoint name as registered with Plugin. 
- cache: Optional task result cache from a previous cache() call. - Restores cached results so previously-completed tasks are - not re-executed after continue-as-new. + Use at the top of a workflow run that resumes from continue-as-new so + already-completed nodes/tasks are not re-executed. """ set_task_cache(cache or {}) - if name not in _entrypoint_registry: - raise KeyError( - f"Entrypoint {name!r} not found. " - f"Available entrypoints: {list(_entrypoint_registry.keys())}" - ) - return _entrypoint_registry[name] def cache() -> dict[str, Any] | None: """Return the task result cache as a serializable dict. - Returns a dict suitable for passing to entrypoint(name, cache=...) to - restore cached task results across continue-as-new boundaries. - Returns None if the cache is empty. + Returns a dict suitable for passing to :func:`set_cache` on the next + workflow run to restore cached task results across continue-as-new + boundaries. Returns None if the cache is empty. """ return get_task_cache() or None diff --git a/tests/contrib/langgraph/e2e_functional_workflows.py b/tests/contrib/langgraph/e2e_functional_workflows.py index f467d1758..2526a1fda 100644 --- a/tests/contrib/langgraph/e2e_functional_workflows.py +++ b/tests/contrib/langgraph/e2e_functional_workflows.py @@ -6,27 +6,27 @@ from typing import Any from temporalio import workflow -from temporalio.contrib.langgraph.langgraph_plugin import cache, entrypoint +from temporalio.contrib.langgraph.langgraph_plugin import cache, set_cache +from tests.contrib.langgraph.e2e_functional_entrypoints import ( + continue_as_new_entrypoint, + partial_execution_entrypoint, + simple_functional_entrypoint, + slow_entrypoint, +) @workflow.defn class SimpleFunctionalE2EWorkflow: - def __init__(self) -> None: - self.app = entrypoint("e2e_simple_functional") - @workflow.run async def run(self, input_value: int) -> dict: - return await self.app.ainvoke(input_value) + return await simple_functional_entrypoint.ainvoke(input_value) 
@workflow.defn class SlowFunctionalWorkflow: - def __init__(self) -> None: - self.app = entrypoint("e2e_slow_functional") - @workflow.run async def run(self, input_value: int) -> dict: - return await self.app.ainvoke(input_value) + return await slow_entrypoint.ainvoke(input_value) @dataclass @@ -43,9 +43,9 @@ class ContinueAsNewFunctionalWorkflow: @workflow.run async def run(self, input_data: ContinueAsNewInput) -> dict[str, Any]: - app = entrypoint("e2e_continue_as_new_functional", cache=input_data.cache) + set_cache(input_data.cache) - result = await app.ainvoke(input_data.value) + result = await continue_as_new_entrypoint.ainvoke(input_data.value) if not input_data.task_a_done: workflow.continue_as_new( @@ -82,10 +82,12 @@ class PartialExecutionWorkflow: @workflow.run async def run(self, input_data: PartialExecutionInput) -> dict[str, Any]: - app = entrypoint("e2e_partial_execution", cache=input_data.cache) + set_cache(input_data.cache) if input_data.phase == 1: - await app.ainvoke({"value": input_data.value, "stop_after": 3}) + await partial_execution_entrypoint.ainvoke( + {"value": input_data.value, "stop_after": 3} + ) workflow.continue_as_new( PartialExecutionInput( value=input_data.value, @@ -94,4 +96,6 @@ async def run(self, input_data: PartialExecutionInput) -> dict[str, Any]: ) ) - return await app.ainvoke({"value": input_data.value, "stop_after": 5}) + return await partial_execution_entrypoint.ainvoke( + {"value": input_data.value, "stop_after": 5} + ) diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index cb73c8ca2..e5d27db78 100644 --- a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -11,7 +11,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from 
temporalio.worker import Worker @@ -23,13 +23,18 @@ async def node(state: State) -> dict[str, str]: return {"value": state["value"] + "a"} +my_graph: StateGraph[State, None, State, State] = StateGraph(State) +my_graph.add_node("node", node) +my_graph.add_edge(START, "node") + + @workflow.defn class ContinueAsNewWorkflow: def __init__(self) -> None: - self.app = graph("my-graph").compile(checkpointer=InMemorySaver()) + self.app = my_graph.compile(checkpointer=InMemorySaver()) @workflow.run - async def run(self, values: dict[str, str]) -> Any: + async def run(self, values: State) -> Any: config = RunnableConfig({"configurable": {"thread_id": "1"}}) await self.app.aupdate_state(config, values) @@ -43,10 +48,6 @@ async def run(self, values: dict[str, str]) -> Any: async def test_continue_as_new(client: Client): - g = StateGraph(State) - g.add_node("node", node) - g.add_edge(START, "node") - task_queue = f"my-graph-{uuid4()}" async with Worker( @@ -55,7 +56,7 @@ async def test_continue_as_new(client: Client): workflows=[ContinueAsNewWorkflow], plugins=[ LangGraphPlugin( - graphs={"my-graph": g}, + graphs=[my_graph], default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index 444fdcd4a..b9bfd4f8d 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -14,7 +14,11 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, cache, graph +from temporalio.contrib.langgraph.langgraph_plugin import ( + LangGraphPlugin, + cache, + set_cache, +) from temporalio.worker import Worker # Track execution counts to verify caching @@ -44,6 +48,16 @@ async def double(state: State) -> dict[str, int]: return {"value": state["value"] * 2} +_timeout = {"start_to_close_timeout": 
timedelta(seconds=10)} +cached_graph: StateGraph[State, None, State, State] = StateGraph(State) +cached_graph.add_node("multiply_by_3", multiply_by_3, metadata=_timeout) +cached_graph.add_node("add_100", add_100, metadata=_timeout) +cached_graph.add_node("double", double, metadata=_timeout) +cached_graph.add_edge(START, "multiply_by_3") +cached_graph.add_edge("multiply_by_3", "add_100") +cached_graph.add_edge("add_100", "double") + + @dataclass class GraphContinueAsNewInput: value: int @@ -65,7 +79,8 @@ class GraphContinueAsNewWorkflow: @workflow.run async def run(self, input_data: GraphContinueAsNewInput) -> dict[str, int]: - app = graph("cached-graph", cache=input_data.cache).compile() + set_cache(input_data.cache) + app = cached_graph.compile() result = await app.ainvoke({"value": input_data.value}) if input_data.phase < 3: @@ -88,22 +103,13 @@ async def test_graph_continue_as_new_cached(client: Client): """ _reset() - timeout = {"start_to_close_timeout": timedelta(seconds=10)} - g = StateGraph(State) - g.add_node("multiply_by_3", multiply_by_3, metadata=timeout) - g.add_node("add_100", add_100, metadata=timeout) - g.add_node("double", double, metadata=timeout) - g.add_edge(START, "multiply_by_3") - g.add_edge("multiply_by_3", "add_100") - g.add_edge("add_100", "double") - task_queue = f"graph-cached-{uuid4()}" async with Worker( client, task_queue=task_queue, workflows=[GraphContinueAsNewWorkflow], - plugins=[LangGraphPlugin(graphs={"cached-graph": g})], + plugins=[LangGraphPlugin(graphs=[cached_graph])], ): result = await client.execute_workflow( GraphContinueAsNewWorkflow.run, diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index 696280e8f..ceee62ac0 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ b/tests/contrib/langgraph/test_e2e_functional.py @@ -22,22 +22,18 @@ from temporalio import workflow from temporalio.client import Client, WorkflowFailureError from temporalio.common import 
RetryPolicy -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, entrypoint +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from temporalio.worker import Worker from tests.contrib.langgraph.e2e_functional_entrypoints import ( add_ten, ask_human, - continue_as_new_entrypoint, double_value, expensive_task_a, expensive_task_b, expensive_task_c, get_task_execution_counts, interrupt_entrypoint, - partial_execution_entrypoint, reset_task_execution_counts, - simple_functional_entrypoint, - slow_entrypoint, slow_task, step_1, step_2, @@ -80,20 +76,16 @@ async def simple_v2_entrypoint(value: int) -> dict: @workflow.defn class SimpleV2Workflow: - def __init__(self) -> None: - self.app = entrypoint("v2_simple") - @workflow.run async def run(self, input_value: int) -> dict[str, Any]: - result = await self.app.ainvoke(input_value, version="v2") + result = await simple_v2_entrypoint.ainvoke(input_value, version="v2") return result.value @workflow.defn class InterruptV2FunctionalWorkflow: def __init__(self) -> None: - self.app = entrypoint("v2_interrupt") - self.app.checkpointer = InMemorySaver() + interrupt_entrypoint.checkpointer = InMemorySaver() @workflow.run async def run(self, input_value: str) -> dict[str, Any]: @@ -101,13 +93,13 @@ async def run(self, input_value: str) -> dict[str, Any]: {"configurable": {"thread_id": workflow.info().workflow_id}} ) - result = await self.app.ainvoke(input_value, config, version="v2") + result = await interrupt_entrypoint.ainvoke(input_value, config, version="v2") assert result.value == {} assert len(result.interrupts) == 1 assert result.interrupts[0].value == "Do you approve?" 
- resumed = await self.app.ainvoke( + resumed = await interrupt_entrypoint.ainvoke( Command(resume="approved"), config, version="v2" ) return resumed.value @@ -115,22 +107,10 @@ async def run(self, input_value: str) -> dict[str, Any]: class TestFunctionalAPIBasicExecution: @pytest.mark.parametrize( - "workflow_cls,entrypoint_func,entrypoint_name,tasks,expected_result", + "workflow_cls,tasks,expected_result", [ - ( - SimpleFunctionalE2EWorkflow, - simple_functional_entrypoint, - "e2e_simple_functional", - [double_value, add_ten], - 30, - ), - ( - SimpleV2Workflow, - simple_v2_entrypoint, - "v2_simple", - [triple_value, add_five], - 35, - ), + (SimpleFunctionalE2EWorkflow, [double_value, add_ten], 30), + (SimpleV2Workflow, [triple_value, add_five], 35), ], ids=["v1", "v2"], ) @@ -138,8 +118,6 @@ async def test_simple_entrypoint( self, client: Client, workflow_cls: Any, - entrypoint_func: Any, - entrypoint_name: str, tasks: list, expected_result: int, ) -> None: @@ -151,7 +129,6 @@ async def test_simple_entrypoint( workflows=[workflow_cls], plugins=[ LangGraphPlugin( - entrypoints={entrypoint_name: entrypoint_func}, tasks=tasks, default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) @@ -182,9 +159,6 @@ async def test_continue_as_new_with_checkpoint(self, client: Client) -> None: workflows=[ContinueAsNewFunctionalWorkflow], plugins=[ LangGraphPlugin( - entrypoints={ - "e2e_continue_as_new_functional": continue_as_new_entrypoint - }, tasks=tasks, default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) @@ -226,7 +200,6 @@ async def test_partial_execution_five_tasks(self, client: Client) -> None: workflows=[PartialExecutionWorkflow], plugins=[ LangGraphPlugin( - entrypoints={"e2e_partial_execution": partial_execution_entrypoint}, tasks=tasks, default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) @@ -253,7 +226,6 @@ async def test_partial_execution_five_tasks(self, client: Client) -> None: class TestFunctionalAPIInterruptV2: async def test_interrupt_v2_functional(self, 
client: Client) -> None: """version='v2' separates interrupts from value in functional API.""" - tasks = [ask_human] task_queue = f"v2-interrupt-{uuid4()}" async with Worker( @@ -262,8 +234,7 @@ async def test_interrupt_v2_functional(self, client: Client) -> None: workflows=[InterruptV2FunctionalWorkflow], plugins=[ LangGraphPlugin( - entrypoints={"v2_interrupt": interrupt_entrypoint}, - tasks=tasks, + tasks=[ask_human], default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) ], @@ -291,7 +262,6 @@ async def test_per_task_activity_options_override(self, client: Client) -> None: workflows=[SlowFunctionalWorkflow], plugins=[ LangGraphPlugin( - entrypoints={"e2e_slow_functional": slow_entrypoint}, tasks=[slow_task], default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, activity_options={ diff --git a/tests/contrib/langgraph/test_execute_in_workflow.py b/tests/contrib/langgraph/test_execute_in_workflow.py index 76037c037..bcc06ae16 100644 --- a/tests/contrib/langgraph/test_execute_in_workflow.py +++ b/tests/contrib/langgraph/test_execute_in_workflow.py @@ -6,7 +6,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from temporalio.worker import Worker @@ -18,10 +18,15 @@ async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedP return {"value": "done"} +inline_graph: StateGraph[State, None, State, State] = StateGraph(State) +inline_graph.add_node("node", node, metadata={"execute_in": "workflow"}) +inline_graph.add_edge(START, "node") + + @workflow.defn class ExecuteInWorkflowWorkflow: def __init__(self) -> None: - self.app = graph("my-graph").compile() + self.app = inline_graph.compile() @workflow.run async def run(self, input: str) -> Any: @@ -29,17 +34,13 @@ async def run(self, input: str) -> Any: async def test_execute_in_workflow(client: Client): - g = StateGraph(State) - 
g.add_node("node", node, metadata={"execute_in": "workflow"}) - g.add_edge(START, "node") - task_queue = f"my-graph-{uuid4()}" async with Worker( client, task_queue=task_queue, workflows=[ExecuteInWorkflowWorkflow], - plugins=[LangGraphPlugin(graphs={"my-graph": g})], + plugins=[LangGraphPlugin(graphs=[inline_graph])], ): result = await client.execute_workflow( ExecuteInWorkflowWorkflow.run, diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index cede90e53..19440ad68 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -13,7 +13,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from temporalio.worker import Worker @@ -25,10 +25,15 @@ async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedP return {"value": langgraph.types.interrupt("Continue?")} +interrupt_graph: StateGraph[State, None, State, State] = StateGraph(State) +interrupt_graph.add_node("node", node) +interrupt_graph.add_edge(START, "node") + + @workflow.defn class InterruptWorkflow: def __init__(self) -> None: - self.app = graph("my-graph").compile(checkpointer=InMemorySaver()) + self.app = interrupt_graph.compile(checkpointer=InMemorySaver()) @workflow.run async def run(self, input: str) -> Any: @@ -43,7 +48,7 @@ async def run(self, input: str) -> Any: @workflow.defn class InterruptV2Workflow: def __init__(self) -> None: - self.app = graph("my-graph").compile(checkpointer=InMemorySaver()) + self.app = interrupt_graph.compile(checkpointer=InMemorySaver()) @workflow.run async def run(self, input: str) -> Any: @@ -62,10 +67,6 @@ async def run(self, input: str) -> Any: "workflow_cls", [InterruptWorkflow, InterruptV2Workflow], ids=["v1", "v2"] ) async def test_interrupt(client: Client, workflow_cls: Any) -> 
None: - g = StateGraph(State) - g.add_node("node", node) - g.add_edge(START, "node") - task_queue = f"interrupt-{uuid4()}" async with Worker( @@ -74,7 +75,7 @@ async def test_interrupt(client: Client, workflow_cls: Any) -> None: workflows=[workflow_cls], plugins=[ LangGraphPlugin( - graphs={"my-graph": g}, + graphs=[interrupt_graph], default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, diff --git a/tests/contrib/langgraph/test_plugin_validation.py b/tests/contrib/langgraph/test_plugin_validation.py index 56a2c4f15..c693917f0 100644 --- a/tests/contrib/langgraph/test_plugin_validation.py +++ b/tests/contrib/langgraph/test_plugin_validation.py @@ -1,19 +1,13 @@ -"""Tests for LangGraphPlugin validation and registry lookup error paths.""" +"""Tests for LangGraphPlugin validation.""" from __future__ import annotations -from uuid import uuid4 - from langchain_core.runnables import RunnableLambda from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from pytest import raises from typing_extensions import TypedDict -from temporalio.contrib.langgraph.langgraph_plugin import ( - LangGraphPlugin, - entrypoint, - graph, -) +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin class State(TypedDict): @@ -30,28 +24,18 @@ def sync_node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedPa def test_non_runnable_callable_node_raises() -> None: """Nodes whose runnable isn't a RunnableCallable can't be wrapped as activities.""" - g = StateGraph(State) + g: StateGraph[State, None, State, State] = StateGraph(State) g.add_node("node", RunnableLambda(sync_node)) g.add_edge(START, "node") with raises(ValueError, match="must have an async function"): - LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) + LangGraphPlugin(graphs=[g]) def test_invalid_execute_in_raises() -> None: - g = StateGraph(State) + g: StateGraph[State, None, State, State] = StateGraph(State) g.add_node("node", async_node, 
metadata={"execute_in": "bogus"}) g.add_edge(START, "node") with raises(ValueError, match="Invalid execute_in value"): - LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) - - -async def test_unknown_graph_raises() -> None: - with raises(KeyError, match="not found"): - graph(f"not-registered-{uuid4()}") - - -async def test_unknown_entrypoint_raises() -> None: - with raises(KeyError, match="not found"): - entrypoint(f"not-registered-{uuid4()}") + LangGraphPlugin(graphs=[g]) diff --git a/tests/contrib/langgraph/test_replay.py b/tests/contrib/langgraph/test_replay.py index f7ba60e07..bddcc4d72 100644 --- a/tests/contrib/langgraph/test_replay.py +++ b/tests/contrib/langgraph/test_replay.py @@ -1,40 +1,24 @@ from datetime import timedelta from uuid import uuid4 -from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] - from temporalio.client import Client from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from temporalio.worker import Replayer, Worker from tests.contrib.langgraph.test_interrupt import ( InterruptWorkflow, -) -from tests.contrib.langgraph.test_interrupt import ( - State as InterruptState, -) -from tests.contrib.langgraph.test_interrupt import ( - node as interrupt_node, + interrupt_graph, ) from tests.contrib.langgraph.test_two_nodes import ( - State, TwoNodesWorkflow, - node_a, - node_b, + my_graph, ) +_DEFAULTS = {"start_to_close_timeout": timedelta(seconds=10)} -async def test_replay(client: Client): - g = StateGraph(State) - g.add_node("node_a", node_a) - g.add_node("node_b", node_b) - g.add_edge(START, "node_a") - g.add_edge("node_a", "node_b") +async def test_replay(client: Client): task_queue = f"my-graph-{uuid4()}" - plugin = LangGraphPlugin( - graphs={"my-graph": g}, - default_activity_options={"start_to_close_timeout": timedelta(seconds=10)}, - ) + plugin = LangGraphPlugin(graphs=[my_graph], default_activity_options=_DEFAULTS) async with Worker( client, @@ -57,14 +41,9 @@ async def 
test_replay(client: Client): async def test_replay_interrupt(client: Client): - g = StateGraph(InterruptState) - g.add_node("node", interrupt_node) - g.add_edge(START, "node") - task_queue = f"interrupt-replay-{uuid4()}" plugin = LangGraphPlugin( - graphs={"my-graph": g}, - default_activity_options={"start_to_close_timeout": timedelta(seconds=10)}, + graphs=[interrupt_graph], default_activity_options=_DEFAULTS ) async with Worker( diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 1c4b19132..60dd833fa 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from temporalio.worker import Worker @@ -23,10 +23,17 @@ async def node_b(state: State) -> dict[str, str]: return {"value": state["value"] + "b"} +streaming_graph: StateGraph[State, None, State, State] = StateGraph(State) +streaming_graph.add_node("node_a", node_a) +streaming_graph.add_node("node_b", node_b) +streaming_graph.add_edge(START, "node_a") +streaming_graph.add_edge("node_a", "node_b") + + @workflow.defn class StreamingWorkflow: def __init__(self) -> None: - self.app = graph("streaming").compile() + self.app = streaming_graph.compile() @workflow.run async def run(self, input: str) -> Any: @@ -37,12 +44,6 @@ async def run(self, input: str) -> Any: async def test_streaming(client: Client): - g = StateGraph(State) - g.add_node("node_a", node_a) - g.add_node("node_b", node_b) - g.add_edge(START, "node_a") - g.add_edge("node_a", "node_b") - task_queue = f"streaming-{uuid4()}" async with Worker( @@ -51,7 +52,7 @@ async def test_streaming(client: Client): workflows=[StreamingWorkflow], plugins=[ LangGraphPlugin( - graphs={"streaming": g}, + graphs=[streaming_graph], 
default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py index e752719bb..f7e024914 100644 --- a/tests/contrib/langgraph/test_subgraph_activity.py +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from temporalio.worker import Worker @@ -27,10 +27,15 @@ async def parent_node(state: State) -> dict[str, str]: return await child.compile().ainvoke(state) +parent_graph: StateGraph[State, None, State, State] = StateGraph(State) +parent_graph.add_node("parent_node", parent_node) +parent_graph.add_edge(START, "parent_node") + + @workflow.defn class ActivitySubgraphWorkflow: def __init__(self) -> None: - self.app = graph("parent").compile() + self.app = parent_graph.compile() @workflow.run async def run(self, input: str) -> Any: @@ -38,10 +43,6 @@ async def run(self, input: str) -> Any: async def test_activity_subgraph(client: Client): - parent = StateGraph(State) - parent.add_node("parent_node", parent_node) - parent.add_edge(START, "parent_node") - task_queue = f"subgraph-{uuid4()}" async with Worker( @@ -50,7 +51,7 @@ async def test_activity_subgraph(client: Client): workflows=[ActivitySubgraphWorkflow], plugins=[ LangGraphPlugin( - graphs={"parent": parent}, + graphs=[parent_graph], default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, diff --git a/tests/contrib/langgraph/test_subgraph_workflow.py b/tests/contrib/langgraph/test_subgraph_workflow.py index d85ce25a1..114df1a3a 100644 --- a/tests/contrib/langgraph/test_subgraph_workflow.py +++ b/tests/contrib/langgraph/test_subgraph_workflow.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import 
Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from temporalio.worker import Worker @@ -19,14 +19,28 @@ async def child_node(state: State) -> dict[str, str]: # pyright: ignore[reportU return {"value": "child"} +child_graph: StateGraph[State, None, State, State] = StateGraph(State) +child_graph.add_node( + "child_node", + child_node, + metadata={"start_to_close_timeout": timedelta(seconds=10)}, +) +child_graph.add_edge(START, "child_node") + + async def parent_node(state: State) -> dict[str, str]: - return await graph("child").compile().ainvoke(state) + return await child_graph.compile().ainvoke(state) + + +parent_graph: StateGraph[State, None, State, State] = StateGraph(State) +parent_graph.add_node("parent_node", parent_node, metadata={"execute_in": "workflow"}) +parent_graph.add_edge(START, "parent_node") @workflow.defn class WorkflowSubgraphWorkflow: def __init__(self) -> None: - self.app = graph("parent").compile() + self.app = parent_graph.compile() @workflow.run async def run(self, input: str) -> Any: @@ -34,25 +48,13 @@ async def run(self, input: str) -> Any: async def test_workflow_subgraph(client: Client): - child = StateGraph(State) - child.add_node( - "child_node", - child_node, - metadata={"start_to_close_timeout": timedelta(seconds=10)}, - ) - child.add_edge(START, "child_node") - - parent = StateGraph(State) - parent.add_node("parent_node", parent_node, metadata={"execute_in": "workflow"}) - parent.add_edge(START, "parent_node") - task_queue = f"subgraph-{uuid4()}" async with Worker( client, task_queue=task_queue, workflows=[WorkflowSubgraphWorkflow], - plugins=[LangGraphPlugin(graphs={"parent": parent, "child": child})], + plugins=[LangGraphPlugin(graphs=[parent_graph, child_graph])], ): result = await client.execute_workflow( WorkflowSubgraphWorkflow.run, diff --git a/tests/contrib/langgraph/test_timeout.py 
b/tests/contrib/langgraph/test_timeout.py index 22c2930bc..41a78e557 100644 --- a/tests/contrib/langgraph/test_timeout.py +++ b/tests/contrib/langgraph/test_timeout.py @@ -10,7 +10,7 @@ from temporalio import workflow from temporalio.client import Client, WorkflowFailureError from temporalio.common import RetryPolicy -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from temporalio.worker import Worker @@ -23,10 +23,15 @@ async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedP return {"value": "done"} +timeout_graph: StateGraph[State, None, State, State] = StateGraph(State) +timeout_graph.add_node("node", node) +timeout_graph.add_edge(START, "node") + + @workflow.defn class TimeoutWorkflow: def __init__(self) -> None: - self.app = graph("my-graph").compile() + self.app = timeout_graph.compile() @workflow.run async def run(self, input: str) -> Any: @@ -34,10 +39,6 @@ async def run(self, input: str) -> Any: async def test_timeout(client: Client): - g = StateGraph(State) - g.add_node("node", node) - g.add_edge(START, "node") - task_queue = f"my-graph-{uuid4()}" async with Worker( @@ -46,7 +47,7 @@ async def test_timeout(client: Client): workflows=[TimeoutWorkflow], plugins=[ LangGraphPlugin( - graphs={"my-graph": g}, + graphs=[timeout_graph], default_activity_options={ "start_to_close_timeout": timedelta(milliseconds=100), "retry_policy": RetryPolicy(maximum_attempts=1), diff --git a/tests/contrib/langgraph/test_two_nodes.py b/tests/contrib/langgraph/test_two_nodes.py index 992e30dcd..1dbaf4d88 100644 --- a/tests/contrib/langgraph/test_two_nodes.py +++ b/tests/contrib/langgraph/test_two_nodes.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from 
temporalio.worker import Worker @@ -23,10 +23,17 @@ async def node_b(state: State) -> dict[str, str]: return {"value": state["value"] + "b"} +my_graph: StateGraph[State, None, State, State] = StateGraph(State) +my_graph.add_node("node_a", node_a) +my_graph.add_node("node_b", node_b) +my_graph.add_edge(START, "node_a") +my_graph.add_edge("node_a", "node_b") + + @workflow.defn class TwoNodesWorkflow: def __init__(self) -> None: - self.app = graph("my-graph").compile() + self.app = my_graph.compile() @workflow.run async def run(self, input: str) -> Any: @@ -34,12 +41,6 @@ async def run(self, input: str) -> Any: async def test_two_nodes(client: Client): - g = StateGraph(State) - g.add_node("node_a", node_a) - g.add_node("node_b", node_b) - g.add_edge(START, "node_a") - g.add_edge("node_a", "node_b") - task_queue = f"my-graph-{uuid4()}" async with Worker( @@ -48,7 +49,7 @@ async def test_two_nodes(client: Client): workflows=[TwoNodesWorkflow], plugins=[ LangGraphPlugin( - graphs={"my-graph": g}, + graphs=[my_graph], default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, From ec8244ccc9b0871e49c254d88eb4f955312a450f Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 17 Apr 2026 15:45:07 -0700 Subject: [PATCH 20/47] rename cache() to get_cache() --- temporalio/contrib/langgraph/README.md | 4 ++-- temporalio/contrib/langgraph/__init__.py | 4 ++-- temporalio/contrib/langgraph/langgraph_plugin.py | 14 +++++--------- temporalio/contrib/langgraph/task_cache.py | 10 ---------- .../contrib/langgraph/e2e_functional_workflows.py | 8 ++++---- .../langgraph/test_continue_as_new_cached.py | 4 ++-- uv.lock | 13 ++++++++++--- 7 files changed, 25 insertions(+), 32 deletions(-) diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index d4250d304..bc8efb89b 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -170,7 +170,7 @@ To carry cached task results across a 
continue-as-new boundary, pass the cache t ```python from temporalio import workflow -from temporalio.contrib.langgraph import cache, set_cache +from temporalio.contrib.langgraph import get_cache, set_cache from myapp.graphs import my_graph @workflow.defn @@ -180,7 +180,7 @@ class MyWorkflow: set_cache(prev_cache) result = await my_graph.compile().ainvoke(input) if should_continue(result): - workflow.continue_as_new(next_input, cache()) + workflow.continue_as_new(next_input, get_cache()) return result ``` diff --git a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py index 48a7931ed..e47bd0794 100644 --- a/temporalio/contrib/langgraph/__init__.py +++ b/temporalio/contrib/langgraph/__init__.py @@ -12,12 +12,12 @@ from temporalio.contrib.langgraph.langgraph_plugin import ( LangGraphPlugin, - cache, + get_cache, set_cache, ) __all__ = [ "LangGraphPlugin", - "cache", + "get_cache", "set_cache", ] diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index e8105bb52..5e3089358 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -13,11 +13,7 @@ from temporalio import activity from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity from temporalio.contrib.langgraph.langgraph_interceptor import LangGraphInterceptor -from temporalio.contrib.langgraph.task_cache import ( - get_task_cache, - set_task_cache, - task_id, -) +from temporalio.contrib.langgraph.task_cache import _task_cache, task_id from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner @@ -154,19 +150,19 @@ def _wrap( def set_cache(cache: dict[str, Any] | None) -> None: - """Restore a task result cache returned by a previous :func:`cache` call. + """Restore a task result cache returned by a previous :func:`get_cache` call. 
Use at the top of a workflow run that resumes from continue-as-new so already-completed nodes/tasks are not re-executed. """ - set_task_cache(cache or {}) + _task_cache.set(cache or {}) -def cache() -> dict[str, Any] | None: +def get_cache() -> dict[str, Any] | None: """Return the task result cache as a serializable dict. Returns a dict suitable for passing to :func:`set_cache` on the next workflow run to restore cached task results across continue-as-new boundaries. Returns None if the cache is empty. """ - return get_task_cache() or None + return _task_cache.get() or None diff --git a/temporalio/contrib/langgraph/task_cache.py b/temporalio/contrib/langgraph/task_cache.py index d4053c808..2f9d9aab0 100644 --- a/temporalio/contrib/langgraph/task_cache.py +++ b/temporalio/contrib/langgraph/task_cache.py @@ -17,16 +17,6 @@ ) -def set_task_cache(cache: dict[str, Any] | None) -> None: - """Set the task result cache for the current context.""" - _task_cache.set(cache) - - -def get_task_cache() -> dict[str, Any] | None: - """Get the task result cache for the current context.""" - return _task_cache.get() - - def task_id(func: Any) -> str: """Return the fully-qualified module.qualname for a function. 
diff --git a/tests/contrib/langgraph/e2e_functional_workflows.py b/tests/contrib/langgraph/e2e_functional_workflows.py index 2526a1fda..f1a927f56 100644 --- a/tests/contrib/langgraph/e2e_functional_workflows.py +++ b/tests/contrib/langgraph/e2e_functional_workflows.py @@ -6,7 +6,7 @@ from typing import Any from temporalio import workflow -from temporalio.contrib.langgraph.langgraph_plugin import cache, set_cache +from temporalio.contrib.langgraph.langgraph_plugin import get_cache, set_cache from tests.contrib.langgraph.e2e_functional_entrypoints import ( continue_as_new_entrypoint, partial_execution_entrypoint, @@ -51,7 +51,7 @@ async def run(self, input_data: ContinueAsNewInput) -> dict[str, Any]: workflow.continue_as_new( ContinueAsNewInput( value=input_data.value, - cache=cache(), + cache=get_cache(), task_a_done=True, ) ) @@ -60,7 +60,7 @@ async def run(self, input_data: ContinueAsNewInput) -> dict[str, Any]: workflow.continue_as_new( ContinueAsNewInput( value=input_data.value, - cache=cache(), + cache=get_cache(), task_a_done=True, task_b_done=True, ) @@ -91,7 +91,7 @@ async def run(self, input_data: PartialExecutionInput) -> dict[str, Any]: workflow.continue_as_new( PartialExecutionInput( value=input_data.value, - cache=cache(), + cache=get_cache(), phase=2, ) ) diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index b9bfd4f8d..78d593589 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -16,7 +16,7 @@ from temporalio.client import Client from temporalio.contrib.langgraph.langgraph_plugin import ( LangGraphPlugin, - cache, + get_cache, set_cache, ) from temporalio.worker import Worker @@ -87,7 +87,7 @@ async def run(self, input_data: GraphContinueAsNewInput) -> dict[str, int]: workflow.continue_as_new( GraphContinueAsNewInput( value=input_data.value, - cache=cache(), + cache=get_cache(), 
phase=input_data.phase + 1, ) ) diff --git a/uv.lock b/uv.lock index a9ad989ae..6201351c2 100644 --- a/uv.lock +++ b/uv.lock @@ -8,6 +8,13 @@ resolution-markers = [ "python_full_version < '3.11'", ] +[options] +exclude-newer = "2026-04-10T22:44:27.30202Z" +exclude-newer-span = "P1W" + +[options.exclude-newer-package] +openai-agents = false + [[package]] name = "aioboto3" version = "15.5.0" @@ -2470,7 +2477,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.2.29" +version = "1.2.28" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jsonpatch" }, @@ -2482,9 +2489,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "uuid-utils" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/d8/7bdf30e4bfc5175609201806e399506a0a78a48e14367dc8b776a9b4c89c/langchain_core-1.2.29.tar.gz", hash = "sha256:cfb89c92bca81ad083eafcdfe6ec40f9803c9abf7dd166d0f8a8de1d2de03ca6", size = 846121, upload-time = "2026-04-14T20:44:58.117Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/a4/317a1a3ac1df33a64adb3670bf88bbe3b3d5baa274db6863a979db472897/langchain_core-1.2.28.tar.gz", hash = "sha256:271a3d8bd618f795fdeba112b0753980457fc90537c46a0c11998516a74dc2cb", size = 846119, upload-time = "2026-04-08T18:19:34.867Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/72/37/fed31f80436b1d7bb222f1f2345300a77a88215416acf8d1cb7c8fda7388/langchain_core-1.2.29-py3-none-any.whl", hash = "sha256:11f02e57ee1c24e6e0e6577acbd35df77b205d4692a3df956b03b5389cbe44a0", size = 508733, upload-time = "2026-04-14T20:44:56.712Z" }, + { url = "https://files.pythonhosted.org/packages/a8/92/32f785f077c7e898da97064f113c73fbd9ad55d1e2169cf3a391b183dedb/langchain_core-1.2.28-py3-none-any.whl", hash = "sha256:80764232581eaf8057bcefa71dbf8adc1f6a28d257ebd8b95ba9b8b452e8c6ac", size = 508727, upload-time = "2026-04-08T18:19:32.823Z" }, ] [[package]] From 1da61a8d8ef7bb34c23119b206fe4e3e974023db Mon Sep 17 00:00:00 2001 From: Brian Strauch 
Date: Fri, 17 Apr 2026 15:53:52 -0700 Subject: [PATCH 21/47] remove interceptor --- .../langgraph/langgraph_interceptor.py | 28 ------------------- .../contrib/langgraph/langgraph_plugin.py | 2 -- 2 files changed, 30 deletions(-) delete mode 100644 temporalio/contrib/langgraph/langgraph_interceptor.py diff --git a/temporalio/contrib/langgraph/langgraph_interceptor.py b/temporalio/contrib/langgraph/langgraph_interceptor.py deleted file mode 100644 index 99d285c53..000000000 --- a/temporalio/contrib/langgraph/langgraph_interceptor.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Workflow interceptor for the LangGraph plugin.""" - -import asyncio -from typing import Any - -from temporalio.worker import ( - ExecuteWorkflowInput, - Interceptor, - WorkflowInboundInterceptor, - WorkflowInterceptorClassInput, -) - - -class LangGraphInterceptor(Interceptor): - def workflow_interceptor_class( - self, input: WorkflowInterceptorClassInput - ) -> type[WorkflowInboundInterceptor]: - return _LangGraphWorkflowInboundInterceptor - - -class _LangGraphWorkflowInboundInterceptor(WorkflowInboundInterceptor): - """Patches the workflow event loop so LangGraph's `asyncio.eager_task_factory` - (which calls `loop.is_running()`) works inside Temporal's sandbox.""" - - async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: - loop = asyncio.get_event_loop() - setattr(loop, "is_running", lambda: True) - return await super().execute_workflow(input) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 5e3089358..2687be29b 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -12,7 +12,6 @@ from temporalio import activity from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity -from temporalio.contrib.langgraph.langgraph_interceptor import LangGraphInterceptor from temporalio.contrib.langgraph.task_cache import _task_cache, task_id 
from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner @@ -102,7 +101,6 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: "temporalio.LangGraphPlugin", activities=self.activities, workflow_runner=workflow_runner, - interceptors=[LangGraphInterceptor()], ) def _wrap( From c84c22f06f8e8d2132cd035531cfbd6578a3aa7a Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 17 Apr 2026 16:23:50 -0700 Subject: [PATCH 22/47] allow metadata to be accessed from node func and test --- temporalio/contrib/langgraph/activity.py | 10 +++ .../contrib/langgraph/langgraph_plugin.py | 36 +++++++++-- tests/contrib/langgraph/test_node_metadata.py | 62 +++++++++++++++++++ 3 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 tests/contrib/langgraph/test_node_metadata.py diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py index 04cdd8d4a..7690ee7c5 100644 --- a/temporalio/contrib/langgraph/activity.py +++ b/temporalio/contrib/langgraph/activity.py @@ -64,6 +64,16 @@ def wrap_execute_activity( """Wrap an activity function to be called via workflow.execute_activity with caching.""" async def wrapper(*args: Any, **kwargs: Any) -> Any: + # LangGraph may inject a RunnableConfig as the 'config' kwarg. Strip it + # down to a serializable subset (metadata + tags) so it can cross the + # activity boundary; callbacks, stores, etc. aren't serializable. + if "config" in kwargs: + orig = kwargs["config"] or {} + kwargs["config"] = { + "metadata": dict(orig.get("metadata") or {}), + "tags": list(orig.get("tags") or []), + } + # Check task result cache (for continue-as-new deduplication). 
key = cache_key(task_id, args, kwargs) if task_id else "" if task_id: diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 2687be29b..ae7cc393e 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -4,19 +4,24 @@ from __future__ import annotations +import inspect from dataclasses import replace from typing import Any from langgraph._internal._runnable import RunnableCallable from langgraph.graph import StateGraph -from temporalio import activity +from temporalio import activity, workflow from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity from temporalio.contrib.langgraph.task_cache import _task_cache, task_id from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner +_ACTIVITY_OPTION_KEYS: frozenset[str] = frozenset( + {"execute_in", *inspect.signature(workflow.execute_activity).parameters} +) + class LangGraphPlugin(SimplePlugin): """LangGraph plugin for Temporal SDK. @@ -60,10 +65,31 @@ def __init__( raise ValueError( f"Node {node_name} must have an async function" ) - # Remove LangSmith-related callback functions that can't be - # serialized between the workflow and activity. - runnable.func_accepts = {} - opts = {**(default_activity_options or {}), **(node.metadata or {})} + # Keep only 'config' injection so node functions can read + # metadata/tags. Drop writer/store/runtime/etc., which hold + # non-serializable objects that can't cross the activity + # boundary. The wrapper serializes config down to its + # portable subset before handing off to the activity. + runnable.func_accepts = { + k: v + for k, v in runnable.func_accepts.items() + if k == "config" + } + # Split node.metadata into activity options vs. user + # metadata. Activity-option keys (timeouts, retry policy, + # etc.) 
become kwargs to workflow.execute_activity; user + # keys stay on node.metadata so LangGraph exposes them to + # the node function via config["metadata"]. + node_meta = node.metadata or {} + node_opts = { + k: v for k, v in node_meta.items() if k in _ACTIVITY_OPTION_KEYS + } + node.metadata = { + k: v + for k, v in node_meta.items() + if k not in _ACTIVITY_OPTION_KEYS + } + opts = {**(default_activity_options or {}), **node_opts} runnable.afunc = self._wrap( runnable.afunc, opts, passthrough_modules ) diff --git a/tests/contrib/langgraph/test_node_metadata.py b/tests/contrib/langgraph/test_node_metadata.py new file mode 100644 index 000000000..28ee1f477 --- /dev/null +++ b/tests/contrib/langgraph/test_node_metadata.py @@ -0,0 +1,62 @@ +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langchain_core.runnables import RunnableConfig # pyright: ignore[reportMissingTypeStubs] +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from typing_extensions import TypedDict + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.worker import Worker + + +class State(TypedDict): + value: str + + +async def node(state: State, config: RunnableConfig) -> dict[str, str]: + metadata = config.get("metadata") or {} + return {"value": state["value"] + str(metadata.get("my_key", "NOT_FOUND"))} + + +metadata_graph: StateGraph[State, None, State, State] = StateGraph(State) +metadata_graph.add_node( + "node", + node, + metadata={ + "start_to_close_timeout": timedelta(seconds=10), + "my_key": "my_value", + }, +) +metadata_graph.add_edge(START, "node") + + +@workflow.defn +class NodeMetadataWorkflow: + def __init__(self) -> None: + self.app = metadata_graph.compile() + + @workflow.run + async def run(self, input: str) -> Any: + return await self.app.ainvoke({"value": input}) + + +async def 
test_node_metadata_readable_in_node(client: Client): + task_queue = f"my-graph-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[NodeMetadataWorkflow], + plugins=[LangGraphPlugin(graphs=[metadata_graph])], + ): + result = await client.execute_workflow( + NodeMetadataWorkflow.run, + "prefix-", + id=f"test-workflow-{uuid4()}", + task_queue=task_queue, + ) + + assert result == {"value": "prefix-my_value"} From 077452d58ced5b0da973c0cdca22acb5b633b483 Mon Sep 17 00:00:00 2001 From: DABH Date: Sat, 18 Apr 2026 16:15:32 -0500 Subject: [PATCH 23/47] Fix import sorting in test_node_metadata.py --- tests/contrib/langgraph/test_node_metadata.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/contrib/langgraph/test_node_metadata.py b/tests/contrib/langgraph/test_node_metadata.py index 28ee1f477..a809805cf 100644 --- a/tests/contrib/langgraph/test_node_metadata.py +++ b/tests/contrib/langgraph/test_node_metadata.py @@ -2,7 +2,9 @@ from typing import Any from uuid import uuid4 -from langchain_core.runnables import RunnableConfig # pyright: ignore[reportMissingTypeStubs] +from langchain_core.runnables import ( + RunnableConfig, # pyright: ignore[reportMissingTypeStubs] +) from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict From 7716e470e4caee95233dc5308cf5c7d55758f6b6 Mon Sep 17 00:00:00 2001 From: DABH Date: Sat, 18 Apr 2026 16:23:07 -0500 Subject: [PATCH 24/47] Fix formatting in langgraph_plugin.py --- temporalio/contrib/langgraph/langgraph_plugin.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index ae7cc393e..07e69b448 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -71,9 +71,7 @@ def __init__( # boundary. 
The wrapper serializes config down to its # portable subset before handing off to the activity. runnable.func_accepts = { - k: v - for k, v in runnable.func_accepts.items() - if k == "config" + k: v for k, v in runnable.func_accepts.items() if k == "config" } # Split node.metadata into activity options vs. user # metadata. Activity-option keys (timeouts, retry policy, From 30094dcfc0d78d536482a94c043e3e404eec6965 Mon Sep 17 00:00:00 2001 From: DABH Date: Sat, 18 Apr 2026 16:31:45 -0500 Subject: [PATCH 25/47] Fix mypy errors: add type params to StateGraph and use State() constructor --- temporalio/contrib/langgraph/langgraph_plugin.py | 2 +- tests/contrib/langgraph/test_continue_as_new.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 07e69b448..73b157be3 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -45,7 +45,7 @@ class LangGraphPlugin(SimplePlugin): def __init__( self, - graphs: list[StateGraph] | None = None, + graphs: list[StateGraph[Any, Any, Any, Any]] | None = None, tasks: list | None = None, activity_options: dict[str, dict] | None = None, default_activity_options: dict[str, Any] | None = None, diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index e5d27db78..7d2c3d0ad 100644 --- a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -65,7 +65,7 @@ async def test_continue_as_new(client: Client): ): result = await client.execute_workflow( ContinueAsNewWorkflow.run, - {"value": ""}, + State(value=""), id=f"test-workflow-{uuid4()}", task_queue=task_queue, ) From e0d766c46e368dc4dbb4a4c64cd7a7d1e4f70ec9 Mon Sep 17 00:00:00 2001 From: DABH Date: Sat, 18 Apr 2026 18:41:22 -0500 Subject: [PATCH 26/47] Fix langsmith sandbox crash when langchain_core is 
installed The LangSmithPlugin passes langsmith through the workflow sandbox, but langsmith conditionally imports langchain_core when it's available. With the langgraph extra adding langchain_core as a transitive dep, this conditional import now fires inside the sandbox, where langchain_core's lazy module loading triggers a restricted access on concurrent.futures.ThreadPoolExecutor. Three fixes: - Pre-import langchain_core.runnables.config at plugin load time so it's in sys.modules before the sandbox starts - Add langchain_core to the sandbox passthrough list - Add a timeout to _poll_query in langsmith tests to prevent infinite hangs if a workflow never reaches the expected state --- temporalio/contrib/langsmith/_plugin.py | 12 +++++++++++- tests/contrib/langsmith/test_integration.py | 5 ++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/temporalio/contrib/langsmith/_plugin.py b/temporalio/contrib/langsmith/_plugin.py index 6e9fba0ee..3d4b65879 100644 --- a/temporalio/contrib/langsmith/_plugin.py +++ b/temporalio/contrib/langsmith/_plugin.py @@ -9,6 +9,15 @@ import langsmith +# langsmith conditionally imports langchain_core when it is installed. +# Pre-import the lazily-loaded submodule so it is in sys.modules before the +# workflow sandbox starts; otherwise the sandbox's __getattr__-triggered +# import hits restrictions on concurrent.futures.ThreadPoolExecutor. 
+try: + import langchain_core.runnables.config # noqa: F401 +except ImportError: + pass + from temporalio.contrib.langsmith._interceptor import LangSmithInterceptor from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner @@ -62,7 +71,8 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: return dataclasses.replace( runner, restrictions=runner.restrictions.with_passthrough_modules( - "langsmith" + "langsmith", + "langchain_core", ), ) return runner diff --git a/tests/contrib/langsmith/test_integration.py b/tests/contrib/langsmith/test_integration.py index 78d48c71e..29cc28ce9 100644 --- a/tests/contrib/langsmith/test_integration.py +++ b/tests/contrib/langsmith/test_integration.py @@ -318,9 +318,11 @@ async def _poll_query( query: Callable[..., Any], *, expected: Any = True, + timeout: float = 45, ) -> bool: """Poll a workflow query until it returns the expected value.""" - while True: + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: try: result = await handle.query(query) if result == expected: @@ -328,6 +330,7 @@ async def _poll_query( except (WorkflowQueryFailedError, RPCError): pass # Query not yet available (workflow hasn't started) await asyncio.sleep(1) + return False # --------------------------------------------------------------------------- From 38f2a18d7297543cbde0fa439d875a0d8c4ef9a1 Mon Sep 17 00:00:00 2001 From: DABH Date: Sat, 18 Apr 2026 20:50:48 -0500 Subject: [PATCH 27/47] Suppress basedpyright unused import warning for langchain_core preload --- temporalio/contrib/langsmith/_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/contrib/langsmith/_plugin.py b/temporalio/contrib/langsmith/_plugin.py index 3d4b65879..789c93414 100644 --- a/temporalio/contrib/langsmith/_plugin.py +++ b/temporalio/contrib/langsmith/_plugin.py @@ -14,7 +14,7 @@ # workflow sandbox starts; otherwise the sandbox's 
__getattr__-triggered # import hits restrictions on concurrent.futures.ThreadPoolExecutor. try: - import langchain_core.runnables.config # noqa: F401 + import langchain_core.runnables.config # noqa: F401 # pyright: ignore[reportUnusedImport] except ImportError: pass From aba75a3cfbb63007a3bd0fe7cb033bddd77df962 Mon Sep 17 00:00:00 2001 From: DABH Date: Sat, 18 Apr 2026 22:58:46 -0500 Subject: [PATCH 28/47] Skip langgraph async tests on Python < 3.11 and warn plugin users LangGraph's Functional API (@task/@entrypoint) and interrupt() require Python >= 3.11 for async context variable propagation via asyncio.create_task(context=...). On older versions, get_config() raises "Called get_config outside of a runnable context". This is a documented LangGraph limitation: https://reference.langchain.com/python/langgraph/config/get_store/ - Skip test_e2e_functional.py, test_interrupt.py, and test_replay_interrupt on Python < 3.11 - Add a runtime warning in LangGraphPlugin.__init__ on Python < 3.11 --- temporalio/contrib/langgraph/langgraph_plugin.py | 13 +++++++++++++ tests/contrib/langgraph/test_e2e_functional.py | 8 ++++++++ tests/contrib/langgraph/test_interrupt.py | 7 +++++++ tests/contrib/langgraph/test_replay.py | 7 +++++++ 4 files changed, 35 insertions(+) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 73b157be3..66cd9c831 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -5,6 +5,8 @@ from __future__ import annotations import inspect +import sys +import warnings from dataclasses import replace from typing import Any @@ -51,6 +53,17 @@ def __init__( default_activity_options: dict[str, Any] | None = None, ): """Register activities for graphs and tasks.""" + if sys.version_info < (3, 11): + warnings.warn( + "LangGraphPlugin requires Python >= 3.11 for full async support. 
" + "On older versions, the Functional API (@task/@entrypoint) and " + "interrupt() will not work because LangGraph relies on " + "contextvars propagation through asyncio.create_task(), which is " + "only available in Python 3.11+. See " + "https://reference.langchain.com/python/langgraph/config/get_store/", + stacklevel=2, + ) + self.activities: list = [] passthrough_modules: set[str] = set() diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index ceee62ac0..a9e457f51 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ b/tests/contrib/langgraph/test_e2e_functional.py @@ -1,15 +1,23 @@ """End-to-end tests for LangGraph Functional API integration (v1 and v2). Requires a running Temporal test server (started by conftest.py). +LangGraph's Functional API requires Python >= 3.11 for async context +variable propagation (see langgraph.config.get_config). """ from __future__ import annotations +import sys from datetime import timedelta from typing import Any from uuid import uuid4 import pytest + +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 11), + reason="LangGraph Functional API requires Python >= 3.11 for async context propagation", +) from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.memory import InMemorySaver from langgraph.func import ( # pyright: ignore[reportMissingTypeStubs] diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index 19440ad68..c3f5711c6 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -1,9 +1,16 @@ +import sys from datetime import timedelta from typing import Any from uuid import uuid4 import langgraph.types import pytest + +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 11), + reason="langgraph.types.interrupt() requires Python >= 3.11 for async context propagation", +) +import pytest from langgraph.checkpoint.memory import 
InMemorySaver from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from langgraph.graph.state import ( # pyright: ignore[reportMissingTypeStubs] diff --git a/tests/contrib/langgraph/test_replay.py b/tests/contrib/langgraph/test_replay.py index bddcc4d72..a5821f910 100644 --- a/tests/contrib/langgraph/test_replay.py +++ b/tests/contrib/langgraph/test_replay.py @@ -1,6 +1,9 @@ +import sys from datetime import timedelta from uuid import uuid4 +import pytest + from temporalio.client import Client from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from temporalio.worker import Replayer, Worker @@ -40,6 +43,10 @@ async def test_replay(client: Client): ).replay_workflow(await handle.fetch_history()) +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="langgraph.types.interrupt() requires Python >= 3.11 for async context propagation", +) async def test_replay_interrupt(client: Client): task_queue = f"interrupt-replay-{uuid4()}" plugin = LangGraphPlugin( From 6f19fa8b2f56930987821f118158f529e0c330ac Mon Sep 17 00:00:00 2001 From: DABH Date: Sat, 18 Apr 2026 23:10:01 -0500 Subject: [PATCH 29/47] Remove duplicate pytest import in test_interrupt.py --- tests/contrib/langgraph/test_interrupt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index c3f5711c6..ecf6068e8 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -10,7 +10,6 @@ sys.version_info < (3, 11), reason="langgraph.types.interrupt() requires Python >= 3.11 for async context propagation", ) -import pytest from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from langgraph.graph.state import ( # pyright: ignore[reportMissingTypeStubs] From 567f423f057acf5f7060ff4da21e1fcf2c585ab2 Mon Sep 17 00:00:00 2001 From: DABH Date: 
Sat, 18 Apr 2026 23:18:02 -0500 Subject: [PATCH 30/47] Fix basedpyright reportUnreachable warning on version check --- temporalio/contrib/langgraph/langgraph_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 66cd9c831..c9654f4c6 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -54,7 +54,7 @@ def __init__( ): """Register activities for graphs and tasks.""" if sys.version_info < (3, 11): - warnings.warn( + warnings.warn( # type: ignore[reportUnreachable] "LangGraphPlugin requires Python >= 3.11 for full async support. " "On older versions, the Functional API (@task/@entrypoint) and " "interrupt() will not work because LangGraph relies on " From 319191240c7ea40777771dbfedbc294b19fc11df Mon Sep 17 00:00:00 2001 From: DABH Date: Sun, 19 Apr 2026 13:04:18 -0500 Subject: [PATCH 31/47] Increase execution_timeout for OpenAI tests that call the real API test_hello_world_agent[False] had a 5s execution timeout and test_input_guardrail[False] had a 10s timeout, but both use a 30s activity start_to_close_timeout. The workflow times out before the OpenAI API call can complete on slower CI runners. Bump both to 60s. 
--- tests/contrib/openai_agents/test_openai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 578eb4d77..c3895a72c 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -141,7 +141,7 @@ async def test_hello_world_agent(client: Client, use_local_model: bool): "Tell me about recursion in programming.", id=f"hello-workflow-{uuid.uuid4()}", task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=5), + execution_timeout=timedelta(seconds=60), ) if use_local_model: assert result == "test" @@ -1243,7 +1243,7 @@ async def test_input_guardrail(client: Client, use_local_model: bool): ], id=f"input-guardrail-{uuid.uuid4()}", task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), + execution_timeout=timedelta(seconds=60), ) result = await workflow_handle.result() From d13d9122bd4a55ee380a6a35720399ce61be0871 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 20 Apr 2026 14:09:39 -0700 Subject: [PATCH 32/47] Revert "allow metadata to be accessed from node func and test" This reverts commit c84c22f06f8e8d2132cd035531cfbd6578a3aa7a. 
--- temporalio/contrib/langgraph/activity.py | 10 --- .../contrib/langgraph/langgraph_plugin.py | 34 ++-------- tests/contrib/langgraph/test_node_metadata.py | 64 ------------------- 3 files changed, 5 insertions(+), 103 deletions(-) delete mode 100644 tests/contrib/langgraph/test_node_metadata.py diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py index 7690ee7c5..04cdd8d4a 100644 --- a/temporalio/contrib/langgraph/activity.py +++ b/temporalio/contrib/langgraph/activity.py @@ -64,16 +64,6 @@ def wrap_execute_activity( """Wrap an activity function to be called via workflow.execute_activity with caching.""" async def wrapper(*args: Any, **kwargs: Any) -> Any: - # LangGraph may inject a RunnableConfig as the 'config' kwarg. Strip it - # down to a serializable subset (metadata + tags) so it can cross the - # activity boundary; callbacks, stores, etc. aren't serializable. - if "config" in kwargs: - orig = kwargs["config"] or {} - kwargs["config"] = { - "metadata": dict(orig.get("metadata") or {}), - "tags": list(orig.get("tags") or []), - } - # Check task result cache (for continue-as-new deduplication). 
key = cache_key(task_id, args, kwargs) if task_id else "" if task_id: diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index c9654f4c6..6550c9b13 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -4,7 +4,6 @@ from __future__ import annotations -import inspect import sys import warnings from dataclasses import replace @@ -13,17 +12,13 @@ from langgraph._internal._runnable import RunnableCallable from langgraph.graph import StateGraph -from temporalio import activity, workflow +from temporalio import activity from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity from temporalio.contrib.langgraph.task_cache import _task_cache, task_id from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner -_ACTIVITY_OPTION_KEYS: frozenset[str] = frozenset( - {"execute_in", *inspect.signature(workflow.execute_activity).parameters} -) - class LangGraphPlugin(SimplePlugin): """LangGraph plugin for Temporal SDK. @@ -78,29 +73,10 @@ def __init__( raise ValueError( f"Node {node_name} must have an async function" ) - # Keep only 'config' injection so node functions can read - # metadata/tags. Drop writer/store/runtime/etc., which hold - # non-serializable objects that can't cross the activity - # boundary. The wrapper serializes config down to its - # portable subset before handing off to the activity. - runnable.func_accepts = { - k: v for k, v in runnable.func_accepts.items() if k == "config" - } - # Split node.metadata into activity options vs. user - # metadata. Activity-option keys (timeouts, retry policy, - # etc.) become kwargs to workflow.execute_activity; user - # keys stay on node.metadata so LangGraph exposes them to - # the node function via config["metadata"]. 
- node_meta = node.metadata or {} - node_opts = { - k: v for k, v in node_meta.items() if k in _ACTIVITY_OPTION_KEYS - } - node.metadata = { - k: v - for k, v in node_meta.items() - if k not in _ACTIVITY_OPTION_KEYS - } - opts = {**(default_activity_options or {}), **node_opts} + # Remove LangSmith-related callback functions that can't be + # serialized between the workflow and activity. + runnable.func_accepts = {} + opts = {**(default_activity_options or {}), **(node.metadata or {})} runnable.afunc = self._wrap( runnable.afunc, opts, passthrough_modules ) diff --git a/tests/contrib/langgraph/test_node_metadata.py b/tests/contrib/langgraph/test_node_metadata.py deleted file mode 100644 index a809805cf..000000000 --- a/tests/contrib/langgraph/test_node_metadata.py +++ /dev/null @@ -1,64 +0,0 @@ -from datetime import timedelta -from typing import Any -from uuid import uuid4 - -from langchain_core.runnables import ( - RunnableConfig, # pyright: ignore[reportMissingTypeStubs] -) -from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] -from typing_extensions import TypedDict - -from temporalio import workflow -from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin -from temporalio.worker import Worker - - -class State(TypedDict): - value: str - - -async def node(state: State, config: RunnableConfig) -> dict[str, str]: - metadata = config.get("metadata") or {} - return {"value": state["value"] + str(metadata.get("my_key", "NOT_FOUND"))} - - -metadata_graph: StateGraph[State, None, State, State] = StateGraph(State) -metadata_graph.add_node( - "node", - node, - metadata={ - "start_to_close_timeout": timedelta(seconds=10), - "my_key": "my_value", - }, -) -metadata_graph.add_edge(START, "node") - - -@workflow.defn -class NodeMetadataWorkflow: - def __init__(self) -> None: - self.app = metadata_graph.compile() - - @workflow.run - async def run(self, input: str) -> Any: - return await 
self.app.ainvoke({"value": input}) - - -async def test_node_metadata_readable_in_node(client: Client): - task_queue = f"my-graph-{uuid4()}" - - async with Worker( - client, - task_queue=task_queue, - workflows=[NodeMetadataWorkflow], - plugins=[LangGraphPlugin(graphs=[metadata_graph])], - ): - result = await client.execute_workflow( - NodeMetadataWorkflow.run, - "prefix-", - id=f"test-workflow-{uuid4()}", - task_queue=task_queue, - ) - - assert result == {"value": "prefix-my_value"} From 63e3d0c7efd0115537f4ffbcf6b037b23a60a21b Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 20 Apr 2026 14:11:01 -0700 Subject: [PATCH 33/47] Revert "rename cache() to get_cache()" This reverts commit ec8244ccc9b0871e49c254d88eb4f955312a450f. --- temporalio/contrib/langgraph/README.md | 4 ++-- temporalio/contrib/langgraph/__init__.py | 4 ++-- temporalio/contrib/langgraph/langgraph_plugin.py | 14 +++++++++----- temporalio/contrib/langgraph/task_cache.py | 10 ++++++++++ .../contrib/langgraph/e2e_functional_workflows.py | 8 ++++---- .../langgraph/test_continue_as_new_cached.py | 4 ++-- uv.lock | 13 +++---------- 7 files changed, 32 insertions(+), 25 deletions(-) diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index bc8efb89b..d4250d304 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -170,7 +170,7 @@ To carry cached task results across a continue-as-new boundary, pass the cache t ```python from temporalio import workflow -from temporalio.contrib.langgraph import get_cache, set_cache +from temporalio.contrib.langgraph import cache, set_cache from myapp.graphs import my_graph @workflow.defn @@ -180,7 +180,7 @@ class MyWorkflow: set_cache(prev_cache) result = await my_graph.compile().ainvoke(input) if should_continue(result): - workflow.continue_as_new(next_input, get_cache()) + workflow.continue_as_new(next_input, cache()) return result ``` diff --git 
a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py index e47bd0794..48a7931ed 100644 --- a/temporalio/contrib/langgraph/__init__.py +++ b/temporalio/contrib/langgraph/__init__.py @@ -12,12 +12,12 @@ from temporalio.contrib.langgraph.langgraph_plugin import ( LangGraphPlugin, - get_cache, + cache, set_cache, ) __all__ = [ "LangGraphPlugin", - "get_cache", + "cache", "set_cache", ] diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 6550c9b13..2a2e54904 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -14,7 +14,11 @@ from temporalio import activity from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity -from temporalio.contrib.langgraph.task_cache import _task_cache, task_id +from temporalio.contrib.langgraph.task_cache import ( + get_task_cache, + set_task_cache, + task_id, +) from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner @@ -161,19 +165,19 @@ def _wrap( def set_cache(cache: dict[str, Any] | None) -> None: - """Restore a task result cache returned by a previous :func:`get_cache` call. + """Restore a task result cache returned by a previous :func:`cache` call. Use at the top of a workflow run that resumes from continue-as-new so already-completed nodes/tasks are not re-executed. """ - _task_cache.set(cache or {}) + set_task_cache(cache or {}) -def get_cache() -> dict[str, Any] | None: +def cache() -> dict[str, Any] | None: """Return the task result cache as a serializable dict. Returns a dict suitable for passing to :func:`set_cache` on the next workflow run to restore cached task results across continue-as-new boundaries. Returns None if the cache is empty. 
""" - return _task_cache.get() or None + return get_task_cache() or None diff --git a/temporalio/contrib/langgraph/task_cache.py b/temporalio/contrib/langgraph/task_cache.py index 2f9d9aab0..d4053c808 100644 --- a/temporalio/contrib/langgraph/task_cache.py +++ b/temporalio/contrib/langgraph/task_cache.py @@ -17,6 +17,16 @@ ) +def set_task_cache(cache: dict[str, Any] | None) -> None: + """Set the task result cache for the current context.""" + _task_cache.set(cache) + + +def get_task_cache() -> dict[str, Any] | None: + """Get the task result cache for the current context.""" + return _task_cache.get() + + def task_id(func: Any) -> str: """Return the fully-qualified module.qualname for a function. diff --git a/tests/contrib/langgraph/e2e_functional_workflows.py b/tests/contrib/langgraph/e2e_functional_workflows.py index f1a927f56..2526a1fda 100644 --- a/tests/contrib/langgraph/e2e_functional_workflows.py +++ b/tests/contrib/langgraph/e2e_functional_workflows.py @@ -6,7 +6,7 @@ from typing import Any from temporalio import workflow -from temporalio.contrib.langgraph.langgraph_plugin import get_cache, set_cache +from temporalio.contrib.langgraph.langgraph_plugin import cache, set_cache from tests.contrib.langgraph.e2e_functional_entrypoints import ( continue_as_new_entrypoint, partial_execution_entrypoint, @@ -51,7 +51,7 @@ async def run(self, input_data: ContinueAsNewInput) -> dict[str, Any]: workflow.continue_as_new( ContinueAsNewInput( value=input_data.value, - cache=get_cache(), + cache=cache(), task_a_done=True, ) ) @@ -60,7 +60,7 @@ async def run(self, input_data: ContinueAsNewInput) -> dict[str, Any]: workflow.continue_as_new( ContinueAsNewInput( value=input_data.value, - cache=get_cache(), + cache=cache(), task_a_done=True, task_b_done=True, ) @@ -91,7 +91,7 @@ async def run(self, input_data: PartialExecutionInput) -> dict[str, Any]: workflow.continue_as_new( PartialExecutionInput( value=input_data.value, - cache=get_cache(), + cache=cache(), phase=2, ) ) diff 
--git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index 78d593589..b9bfd4f8d 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -16,7 +16,7 @@ from temporalio.client import Client from temporalio.contrib.langgraph.langgraph_plugin import ( LangGraphPlugin, - get_cache, + cache, set_cache, ) from temporalio.worker import Worker @@ -87,7 +87,7 @@ async def run(self, input_data: GraphContinueAsNewInput) -> dict[str, int]: workflow.continue_as_new( GraphContinueAsNewInput( value=input_data.value, - cache=get_cache(), + cache=cache(), phase=input_data.phase + 1, ) ) diff --git a/uv.lock b/uv.lock index 6201351c2..a9ad989ae 100644 --- a/uv.lock +++ b/uv.lock @@ -8,13 +8,6 @@ resolution-markers = [ "python_full_version < '3.11'", ] -[options] -exclude-newer = "2026-04-10T22:44:27.30202Z" -exclude-newer-span = "P1W" - -[options.exclude-newer-package] -openai-agents = false - [[package]] name = "aioboto3" version = "15.5.0" @@ -2477,7 +2470,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.2.28" +version = "1.2.29" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jsonpatch" }, @@ -2489,9 +2482,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "uuid-utils" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f8/a4/317a1a3ac1df33a64adb3670bf88bbe3b3d5baa274db6863a979db472897/langchain_core-1.2.28.tar.gz", hash = "sha256:271a3d8bd618f795fdeba112b0753980457fc90537c46a0c11998516a74dc2cb", size = 846119, upload-time = "2026-04-08T18:19:34.867Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/d8/7bdf30e4bfc5175609201806e399506a0a78a48e14367dc8b776a9b4c89c/langchain_core-1.2.29.tar.gz", hash = "sha256:cfb89c92bca81ad083eafcdfe6ec40f9803c9abf7dd166d0f8a8de1d2de03ca6", size = 846121, upload-time = "2026-04-14T20:44:58.117Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/a8/92/32f785f077c7e898da97064f113c73fbd9ad55d1e2169cf3a391b183dedb/langchain_core-1.2.28-py3-none-any.whl", hash = "sha256:80764232581eaf8057bcefa71dbf8adc1f6a28d257ebd8b95ba9b8b452e8c6ac", size = 508727, upload-time = "2026-04-08T18:19:32.823Z" }, + { url = "https://files.pythonhosted.org/packages/72/37/fed31f80436b1d7bb222f1f2345300a77a88215416acf8d1cb7c8fda7388/langchain_core-1.2.29-py3-none-any.whl", hash = "sha256:11f02e57ee1c24e6e0e6577acbd35df77b205d4692a3df956b03b5389cbe44a0", size = 508733, upload-time = "2026-04-14T20:44:56.712Z" }, ] [[package]] From 21acfea50981bf75a9e816b324d050af834ff053 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 20 Apr 2026 14:14:55 -0700 Subject: [PATCH 34/47] Revert "remove graph and entrypoint functions in favor of direct graph usage" This reverts commit 8ef609f09704a5c001006ec9af5ff96a18203ad1. --- temporalio/contrib/langgraph/README.md | 92 ++---------- temporalio/contrib/langgraph/__init__.py | 6 +- .../contrib/langgraph/langgraph_plugin.py | 138 ++++++++++-------- .../langgraph/e2e_functional_workflows.py | 32 ++-- .../contrib/langgraph/test_continue_as_new.py | 17 +-- .../langgraph/test_continue_as_new_cached.py | 30 ++-- .../contrib/langgraph/test_e2e_functional.py | 48 ++++-- .../langgraph/test_execute_in_workflow.py | 15 +- tests/contrib/langgraph/test_interrupt.py | 17 +-- .../langgraph/test_plugin_validation.py | 28 +++- tests/contrib/langgraph/test_replay.py | 32 +++- tests/contrib/langgraph/test_streaming.py | 19 ++- .../langgraph/test_subgraph_activity.py | 15 +- .../langgraph/test_subgraph_workflow.py | 34 ++--- tests/contrib/langgraph/test_timeout.py | 15 +- tests/contrib/langgraph/test_two_nodes.py | 19 ++- 16 files changed, 277 insertions(+), 280 deletions(-) diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index d4250d304..e8a4da74b 100644 --- a/temporalio/contrib/langgraph/README.md +++ 
b/temporalio/contrib/langgraph/README.md @@ -16,66 +16,24 @@ or with pip: pip install temporalio[langgraph] ``` -## Module layout +## Plugin Initialization -Define your graphs, tasks, and entrypoints in a module **separate** from your `@workflow.defn` classes — the standard Temporal split. The plugin adds the graph/task modules to the workflow sandbox's passthrough list so its in-place rewrites are visible to the workflow. Workflow modules stay sandboxed. - -## Graph API +### Graph API ```python -# graphs.py -from langgraph.graph import START, StateGraph - -my_graph = StateGraph(State) -my_graph.add_node("my_node", my_node) -my_graph.add_edge(START, "my_node") - -# workflow.py -from temporalio import workflow -from myapp.graphs import my_graph - -@workflow.defn -class MyWorkflow: - @workflow.run - async def run(self, input): - return await my_graph.compile().ainvoke(input) - -# worker.py from temporalio.contrib.langgraph import LangGraphPlugin -from myapp.graphs import my_graph -plugin = LangGraphPlugin(graphs=[my_graph]) +plugin = LangGraphPlugin(graphs={"my-graph": graph}) ``` -## Functional API +### Functional API ```python -# flows.py -from langgraph.func import entrypoint, task - -@task -async def my_task(x): ... - -@entrypoint() -async def my_flow(inputs): - return await my_task(inputs) - -# workflow.py -from temporalio import workflow -from myapp.flows import my_flow - -@workflow.defn -class MyWorkflow: - @workflow.run - async def run(self, input): - return await my_flow.ainvoke(input) - -# worker.py import datetime from temporalio.contrib.langgraph import LangGraphPlugin -from myapp.flows import my_task plugin = LangGraphPlugin( + entrypoints={"my_entrypoint": my_entrypoint}, tasks=[my_task], activity_options={ "my_task": { @@ -91,17 +49,19 @@ Use `InMemorySaver` as your checkpointer. 
Temporal handles durability, so third- ```python import langgraph.checkpoint.memory +import typing +from temporalio.contrib.langgraph import graph from temporalio import workflow -from myapp.graphs import my_graph @workflow.defn class MyWorkflow: @workflow.run - async def run(self, input): - app = my_graph.compile( + async def run(self, input: str) -> typing.Any: + g = graph("my-graph").compile( checkpointer=langgraph.checkpoint.memory.InMemorySaver(), ) + ... ``` @@ -111,7 +71,7 @@ Options are passed through to [`workflow.execute_activity()`](https://python.tem ### Graph API -Pass per-node options as node `metadata`, or plugin-wide defaults via `default_activity_options`: +Pass activity options as node `metadata` when calling `add_node`: ```python import datetime @@ -122,16 +82,11 @@ g.add_node("my_node", my_node, metadata={ "start_to_close_timeout": datetime.timedelta(seconds=30), "retry_policy": RetryPolicy(maximum_attempts=3), }) - -plugin = LangGraphPlugin( - graphs=[g], - default_activity_options={"start_to_close_timeout": datetime.timedelta(seconds=60)}, -) ``` ### Functional API -Pass activity options to the plugin, keyed by task function name: +Pass activity options to the `Plugin` constructor, keyed by task function name: ```python import datetime @@ -139,6 +94,7 @@ from temporalio.common import RetryPolicy from temporalio.contrib.langgraph import LangGraphPlugin plugin = LangGraphPlugin( + entrypoints={"my_entrypoint": my_entrypoint}, tasks=[my_task], activity_options={ "my_task": { @@ -155,7 +111,7 @@ To skip the Activity wrapper and run a node or task directly in the Workflow, se ```python # Graph API -g.add_node("my_node", my_node, metadata={"execute_in": "workflow"}) +graph.add_node("my_node", my_node, metadata={"execute_in": "workflow"}) # Functional API plugin = LangGraphPlugin( @@ -164,26 +120,6 @@ plugin = LangGraphPlugin( ) ``` -## Continue-As-New - -To carry cached task results across a continue-as-new boundary, pass the cache to your next run 
and restore it with `set_cache`: - -```python -from temporalio import workflow -from temporalio.contrib.langgraph import cache, set_cache -from myapp.graphs import my_graph - -@workflow.defn -class MyWorkflow: - @workflow.run - async def run(self, input, prev_cache=None): - set_cache(prev_cache) - result = await my_graph.compile().ainvoke(input) - if should_continue(result): - workflow.continue_as_new(next_input, cache()) - return result -``` - ## Running Tests Install dependencies: diff --git a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py index 48a7931ed..50d8ca147 100644 --- a/temporalio/contrib/langgraph/__init__.py +++ b/temporalio/contrib/langgraph/__init__.py @@ -13,11 +13,13 @@ from temporalio.contrib.langgraph.langgraph_plugin import ( LangGraphPlugin, cache, - set_cache, + entrypoint, + graph, ) __all__ = [ "LangGraphPlugin", + "entrypoint", "cache", - "set_cache", + "graph", ] diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 2a2e54904..432da67eb 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -7,10 +7,11 @@ import sys import warnings from dataclasses import replace -from typing import Any +from typing import Any, Callable from langgraph._internal._runnable import RunnableCallable from langgraph.graph import StateGraph +from langgraph.pregel import Pregel from temporalio import activity from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity @@ -23,6 +24,10 @@ from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner +# Save registered graphs/entrypoints at the module level to avoid being refreshed by the sandbox. 
+_graph_registry: dict[str, StateGraph[Any]] = {} +_entrypoint_registry: dict[str, Pregel[Any, Any, Any, Any]] = {} + class LangGraphPlugin(SimplePlugin): """LangGraph plugin for Temporal SDK. @@ -35,23 +40,20 @@ class LangGraphPlugin(SimplePlugin): and tasks as Temporal Activities, giving your AI agent workflows durable execution, automatic retries, and timeouts. It supports both the LangGraph Graph API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). - - Pass your graphs and tasks to the plugin; the plugin mutates them in place so - node/task invocations dispatch to Temporal activities. The modules those - graphs and tasks are defined in are automatically added to the workflow - sandbox's passthrough list, so the mutation is visible inside the sandbox. - Keep your ``@workflow.defn`` classes in a module separate from your graphs - and tasks (the standard Temporal convention). """ def __init__( self, - graphs: list[StateGraph[Any, Any, Any, Any]] | None = None, + # Graph API + graphs: dict[str, StateGraph[Any, Any, Any, Any]] | None = None, + # Functional API + entrypoints: dict[str, Pregel[Any, Any, Any, Any]] | None = None, tasks: list | None = None, + # TODO: Remove activity_options when we have support for @task(metadata=...) activity_options: dict[str, dict] | None = None, default_activity_options: dict[str, Any] | None = None, ): - """Register activities for graphs and tasks.""" + """Initialize the LangGraph plugin with graphs, entrypoints, and tasks.""" if sys.version_info < (3, 11): warnings.warn( # type: ignore[reportUnreachable] "LangGraphPlugin requires Python >= 3.11 for full async support. " @@ -64,10 +66,11 @@ def __init__( ) self.activities: list = [] - passthrough_modules: set[str] = set() + # Graph API: Wrap graph nodes as Temporal Activities. 
if graphs: - for graph in graphs: + _graph_registry.update(graphs) + for graph_name, graph in graphs.items(): for node_name, node in graph.nodes.items(): runnable = node.runnable if ( @@ -77,25 +80,28 @@ def __init__( raise ValueError( f"Node {node_name} must have an async function" ) - # Remove LangSmith-related callback functions that can't be - # serialized between the workflow and activity. + # Remove LangSmith-related callback functions that can't be serialized between the workflow and activity. runnable.func_accepts = {} opts = {**(default_activity_options or {}), **(node.metadata or {})} - runnable.afunc = self._wrap( - runnable.afunc, opts, passthrough_modules + runnable.afunc = self.execute( + f"{graph_name}.{node_name}", runnable.afunc, opts ) + if entrypoints: + _entrypoint_registry.update(entrypoints) + + # Functional API: Wrap @task functions as Temporal Activities. if tasks: - for t in tasks: - name = t.func.__name__ - qualname = getattr(t.func, "__qualname__", name) + for task in tasks: + name = task.func.__name__ opts = { **(default_activity_options or {}), **(activity_options or {}).get(name, {}), } - t.func = self._wrap(t.func, opts, passthrough_modules) - t.func.__name__ = name - t.func.__qualname__ = qualname + + task.func = self.execute(task_id(task.func), task.func, opts) + task.func.__name__ = name + task.func.__qualname__ = getattr(task.func, "__qualname__", name) def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: if not runner: @@ -109,7 +115,6 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: "langgraph", "langsmith", "numpy", # LangSmith uses numpy - *passthrough_modules, ), ) return runner @@ -120,64 +125,71 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: workflow_runner=workflow_runner, ) - def _wrap( + def execute( self, - func: Any, - opts: dict[str, Any], - passthrough_modules: set[str], - ) -> Any: - """Wrap a node afunc or task func as an activity. 
Idempotent across plugins. - - Records the activity defn on ``self.activities`` and the function's - origin module on ``passthrough_modules``. If ``func`` is already wrapped - (e.g., a second plugin sharing the same graph), reuses the cached - activity defn and module — no double-wrap. - """ - meta = getattr(func, "_temporal_meta", None) - if meta is not None: - a, module = meta - if a is not None: - self.activities.append(a) - if module: - passthrough_modules.add(module) - return func - - module = getattr(func, "__module__", None) + activity_name: str, + func: Callable, + kwargs: dict[str, Any] | None = None, + ) -> Callable: + """Prepare a node or task to execute as an activity or inline in the workflow.""" + opts = kwargs or {} execute_in = opts.pop("execute_in", "activity") + if execute_in == "activity": - activity_name = task_id(func) a = activity.defn(name=activity_name)(wrap_activity(func)) self.activities.append(a) - wrapped = wrap_execute_activity(a, task_id=activity_name, **opts) + return wrap_execute_activity(a, task_id=task_id(func), **opts) elif execute_in == "workflow": - a = None - wrapped = func + return func else: raise ValueError(f"Invalid execute_in value: {execute_in}") - if module: - passthrough_modules.add(module) - try: - setattr(wrapped, "_temporal_meta", (a, module)) - except (AttributeError, TypeError): - pass - return wrapped +def graph( + name: str, cache: dict[str, Any] | None = None +) -> StateGraph[Any, None, Any, Any]: + """Retrieve a registered graph by name. -def set_cache(cache: dict[str, Any] | None) -> None: - """Restore a task result cache returned by a previous :func:`cache` call. + Args: + name: Graph name as registered with LangGraphPlugin. + cache: Optional task result cache from a previous cache() call. + Restores cached results so previously-completed nodes are + not re-executed after continue-as-new. + """ + set_task_cache(cache or {}) + if name not in _graph_registry: + raise KeyError( + f"Graph {name!r} not found. 
" + f"Available graphs: {list(_graph_registry.keys())}" + ) + return _graph_registry[name] - Use at the top of a workflow run that resumes from continue-as-new so - already-completed nodes/tasks are not re-executed. + +def entrypoint( + name: str, cache: dict[str, Any] | None = None +) -> Pregel[Any, Any, Any, Any]: + """Retrieve a registered entrypoint by name. + + Args: + name: Entrypoint name as registered with Plugin. + cache: Optional task result cache from a previous cache() call. + Restores cached results so previously-completed tasks are + not re-executed after continue-as-new. """ set_task_cache(cache or {}) + if name not in _entrypoint_registry: + raise KeyError( + f"Entrypoint {name!r} not found. " + f"Available entrypoints: {list(_entrypoint_registry.keys())}" + ) + return _entrypoint_registry[name] def cache() -> dict[str, Any] | None: """Return the task result cache as a serializable dict. - Returns a dict suitable for passing to :func:`set_cache` on the next - workflow run to restore cached task results across continue-as-new - boundaries. Returns None if the cache is empty. + Returns a dict suitable for passing to entrypoint(name, cache=...) to + restore cached task results across continue-as-new boundaries. + Returns None if the cache is empty. 
""" return get_task_cache() or None diff --git a/tests/contrib/langgraph/e2e_functional_workflows.py b/tests/contrib/langgraph/e2e_functional_workflows.py index 2526a1fda..f467d1758 100644 --- a/tests/contrib/langgraph/e2e_functional_workflows.py +++ b/tests/contrib/langgraph/e2e_functional_workflows.py @@ -6,27 +6,27 @@ from typing import Any from temporalio import workflow -from temporalio.contrib.langgraph.langgraph_plugin import cache, set_cache -from tests.contrib.langgraph.e2e_functional_entrypoints import ( - continue_as_new_entrypoint, - partial_execution_entrypoint, - simple_functional_entrypoint, - slow_entrypoint, -) +from temporalio.contrib.langgraph.langgraph_plugin import cache, entrypoint @workflow.defn class SimpleFunctionalE2EWorkflow: + def __init__(self) -> None: + self.app = entrypoint("e2e_simple_functional") + @workflow.run async def run(self, input_value: int) -> dict: - return await simple_functional_entrypoint.ainvoke(input_value) + return await self.app.ainvoke(input_value) @workflow.defn class SlowFunctionalWorkflow: + def __init__(self) -> None: + self.app = entrypoint("e2e_slow_functional") + @workflow.run async def run(self, input_value: int) -> dict: - return await slow_entrypoint.ainvoke(input_value) + return await self.app.ainvoke(input_value) @dataclass @@ -43,9 +43,9 @@ class ContinueAsNewFunctionalWorkflow: @workflow.run async def run(self, input_data: ContinueAsNewInput) -> dict[str, Any]: - set_cache(input_data.cache) + app = entrypoint("e2e_continue_as_new_functional", cache=input_data.cache) - result = await continue_as_new_entrypoint.ainvoke(input_data.value) + result = await app.ainvoke(input_data.value) if not input_data.task_a_done: workflow.continue_as_new( @@ -82,12 +82,10 @@ class PartialExecutionWorkflow: @workflow.run async def run(self, input_data: PartialExecutionInput) -> dict[str, Any]: - set_cache(input_data.cache) + app = entrypoint("e2e_partial_execution", cache=input_data.cache) if input_data.phase == 1: - 
await partial_execution_entrypoint.ainvoke( - {"value": input_data.value, "stop_after": 3} - ) + await app.ainvoke({"value": input_data.value, "stop_after": 3}) workflow.continue_as_new( PartialExecutionInput( value=input_data.value, @@ -96,6 +94,4 @@ async def run(self, input_data: PartialExecutionInput) -> dict[str, Any]: ) ) - return await partial_execution_entrypoint.ainvoke( - {"value": input_data.value, "stop_after": 5} - ) + return await app.ainvoke({"value": input_data.value, "stop_after": 5}) diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index 7d2c3d0ad..2458c07a5 100644 --- a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -11,7 +11,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph from temporalio.worker import Worker @@ -23,18 +23,13 @@ async def node(state: State) -> dict[str, str]: return {"value": state["value"] + "a"} -my_graph: StateGraph[State, None, State, State] = StateGraph(State) -my_graph.add_node("node", node) -my_graph.add_edge(START, "node") - - @workflow.defn class ContinueAsNewWorkflow: def __init__(self) -> None: - self.app = my_graph.compile(checkpointer=InMemorySaver()) + self.app = graph("my-graph").compile(checkpointer=InMemorySaver()) @workflow.run - async def run(self, values: State) -> Any: + async def run(self, values: dict[str, str]) -> Any: config = RunnableConfig({"configurable": {"thread_id": "1"}}) await self.app.aupdate_state(config, values) @@ -48,6 +43,10 @@ async def run(self, values: State) -> Any: async def test_continue_as_new(client: Client): + g = StateGraph(State) + g.add_node("node", node) + g.add_edge(START, "node") + task_queue = f"my-graph-{uuid4()}" async with Worker( @@ -56,7 +55,7 @@ async def 
test_continue_as_new(client: Client): workflows=[ContinueAsNewWorkflow], plugins=[ LangGraphPlugin( - graphs=[my_graph], + graphs={"my-graph": g}, default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index b9bfd4f8d..444fdcd4a 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -14,11 +14,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import ( - LangGraphPlugin, - cache, - set_cache, -) +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, cache, graph from temporalio.worker import Worker # Track execution counts to verify caching @@ -48,16 +44,6 @@ async def double(state: State) -> dict[str, int]: return {"value": state["value"] * 2} -_timeout = {"start_to_close_timeout": timedelta(seconds=10)} -cached_graph: StateGraph[State, None, State, State] = StateGraph(State) -cached_graph.add_node("multiply_by_3", multiply_by_3, metadata=_timeout) -cached_graph.add_node("add_100", add_100, metadata=_timeout) -cached_graph.add_node("double", double, metadata=_timeout) -cached_graph.add_edge(START, "multiply_by_3") -cached_graph.add_edge("multiply_by_3", "add_100") -cached_graph.add_edge("add_100", "double") - - @dataclass class GraphContinueAsNewInput: value: int @@ -79,8 +65,7 @@ class GraphContinueAsNewWorkflow: @workflow.run async def run(self, input_data: GraphContinueAsNewInput) -> dict[str, int]: - set_cache(input_data.cache) - app = cached_graph.compile() + app = graph("cached-graph", cache=input_data.cache).compile() result = await app.ainvoke({"value": input_data.value}) if input_data.phase < 3: @@ -103,13 +88,22 @@ async def test_graph_continue_as_new_cached(client: Client): """ _reset() + timeout = {"start_to_close_timeout": 
timedelta(seconds=10)} + g = StateGraph(State) + g.add_node("multiply_by_3", multiply_by_3, metadata=timeout) + g.add_node("add_100", add_100, metadata=timeout) + g.add_node("double", double, metadata=timeout) + g.add_edge(START, "multiply_by_3") + g.add_edge("multiply_by_3", "add_100") + g.add_edge("add_100", "double") + task_queue = f"graph-cached-{uuid4()}" async with Worker( client, task_queue=task_queue, workflows=[GraphContinueAsNewWorkflow], - plugins=[LangGraphPlugin(graphs=[cached_graph])], + plugins=[LangGraphPlugin(graphs={"cached-graph": g})], ): result = await client.execute_workflow( GraphContinueAsNewWorkflow.run, diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index a9e457f51..8edbfd922 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ b/tests/contrib/langgraph/test_e2e_functional.py @@ -30,18 +30,22 @@ from temporalio import workflow from temporalio.client import Client, WorkflowFailureError from temporalio.common import RetryPolicy -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, entrypoint from temporalio.worker import Worker from tests.contrib.langgraph.e2e_functional_entrypoints import ( add_ten, ask_human, + continue_as_new_entrypoint, double_value, expensive_task_a, expensive_task_b, expensive_task_c, get_task_execution_counts, interrupt_entrypoint, + partial_execution_entrypoint, reset_task_execution_counts, + simple_functional_entrypoint, + slow_entrypoint, slow_task, step_1, step_2, @@ -84,16 +88,20 @@ async def simple_v2_entrypoint(value: int) -> dict: @workflow.defn class SimpleV2Workflow: + def __init__(self) -> None: + self.app = entrypoint("v2_simple") + @workflow.run async def run(self, input_value: int) -> dict[str, Any]: - result = await simple_v2_entrypoint.ainvoke(input_value, version="v2") + result = await self.app.ainvoke(input_value, version="v2") 
return result.value @workflow.defn class InterruptV2FunctionalWorkflow: def __init__(self) -> None: - interrupt_entrypoint.checkpointer = InMemorySaver() + self.app = entrypoint("v2_interrupt") + self.app.checkpointer = InMemorySaver() @workflow.run async def run(self, input_value: str) -> dict[str, Any]: @@ -101,13 +109,13 @@ async def run(self, input_value: str) -> dict[str, Any]: {"configurable": {"thread_id": workflow.info().workflow_id}} ) - result = await interrupt_entrypoint.ainvoke(input_value, config, version="v2") + result = await self.app.ainvoke(input_value, config, version="v2") assert result.value == {} assert len(result.interrupts) == 1 assert result.interrupts[0].value == "Do you approve?" - resumed = await interrupt_entrypoint.ainvoke( + resumed = await self.app.ainvoke( Command(resume="approved"), config, version="v2" ) return resumed.value @@ -115,10 +123,22 @@ async def run(self, input_value: str) -> dict[str, Any]: class TestFunctionalAPIBasicExecution: @pytest.mark.parametrize( - "workflow_cls,tasks,expected_result", + "workflow_cls,entrypoint_func,entrypoint_name,tasks,expected_result", [ - (SimpleFunctionalE2EWorkflow, [double_value, add_ten], 30), - (SimpleV2Workflow, [triple_value, add_five], 35), + ( + SimpleFunctionalE2EWorkflow, + simple_functional_entrypoint, + "e2e_simple_functional", + [double_value, add_ten], + 30, + ), + ( + SimpleV2Workflow, + simple_v2_entrypoint, + "v2_simple", + [triple_value, add_five], + 35, + ), ], ids=["v1", "v2"], ) @@ -126,6 +146,8 @@ async def test_simple_entrypoint( self, client: Client, workflow_cls: Any, + entrypoint_func: Any, + entrypoint_name: str, tasks: list, expected_result: int, ) -> None: @@ -137,6 +159,7 @@ async def test_simple_entrypoint( workflows=[workflow_cls], plugins=[ LangGraphPlugin( + entrypoints={entrypoint_name: entrypoint_func}, tasks=tasks, default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) @@ -167,6 +190,9 @@ async def test_continue_as_new_with_checkpoint(self, client: 
Client) -> None: workflows=[ContinueAsNewFunctionalWorkflow], plugins=[ LangGraphPlugin( + entrypoints={ + "e2e_continue_as_new_functional": continue_as_new_entrypoint + }, tasks=tasks, default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) @@ -208,6 +234,7 @@ async def test_partial_execution_five_tasks(self, client: Client) -> None: workflows=[PartialExecutionWorkflow], plugins=[ LangGraphPlugin( + entrypoints={"e2e_partial_execution": partial_execution_entrypoint}, tasks=tasks, default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) @@ -234,6 +261,7 @@ async def test_partial_execution_five_tasks(self, client: Client) -> None: class TestFunctionalAPIInterruptV2: async def test_interrupt_v2_functional(self, client: Client) -> None: """version='v2' separates interrupts from value in functional API.""" + tasks = [ask_human] task_queue = f"v2-interrupt-{uuid4()}" async with Worker( @@ -242,7 +270,8 @@ async def test_interrupt_v2_functional(self, client: Client) -> None: workflows=[InterruptV2FunctionalWorkflow], plugins=[ LangGraphPlugin( - tasks=[ask_human], + entrypoints={"v2_interrupt": interrupt_entrypoint}, + tasks=tasks, default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, ) ], @@ -270,6 +299,7 @@ async def test_per_task_activity_options_override(self, client: Client) -> None: workflows=[SlowFunctionalWorkflow], plugins=[ LangGraphPlugin( + entrypoints={"e2e_slow_functional": slow_entrypoint}, tasks=[slow_task], default_activity_options=_DEFAULT_ACTIVITY_OPTIONS, activity_options={ diff --git a/tests/contrib/langgraph/test_execute_in_workflow.py b/tests/contrib/langgraph/test_execute_in_workflow.py index bcc06ae16..76037c037 100644 --- a/tests/contrib/langgraph/test_execute_in_workflow.py +++ b/tests/contrib/langgraph/test_execute_in_workflow.py @@ -6,7 +6,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import 
LangGraphPlugin, graph from temporalio.worker import Worker @@ -18,15 +18,10 @@ async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedP return {"value": "done"} -inline_graph: StateGraph[State, None, State, State] = StateGraph(State) -inline_graph.add_node("node", node, metadata={"execute_in": "workflow"}) -inline_graph.add_edge(START, "node") - - @workflow.defn class ExecuteInWorkflowWorkflow: def __init__(self) -> None: - self.app = inline_graph.compile() + self.app = graph("my-graph").compile() @workflow.run async def run(self, input: str) -> Any: @@ -34,13 +29,17 @@ async def run(self, input: str) -> Any: async def test_execute_in_workflow(client: Client): + g = StateGraph(State) + g.add_node("node", node, metadata={"execute_in": "workflow"}) + g.add_edge(START, "node") + task_queue = f"my-graph-{uuid4()}" async with Worker( client, task_queue=task_queue, workflows=[ExecuteInWorkflowWorkflow], - plugins=[LangGraphPlugin(graphs=[inline_graph])], + plugins=[LangGraphPlugin(graphs={"my-graph": g})], ): result = await client.execute_workflow( ExecuteInWorkflowWorkflow.run, diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index ecf6068e8..838539853 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -19,7 +19,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph from temporalio.worker import Worker @@ -31,15 +31,10 @@ async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedP return {"value": langgraph.types.interrupt("Continue?")} -interrupt_graph: StateGraph[State, None, State, State] = StateGraph(State) -interrupt_graph.add_node("node", node) -interrupt_graph.add_edge(START, "node") - - @workflow.defn class InterruptWorkflow: def 
__init__(self) -> None: - self.app = interrupt_graph.compile(checkpointer=InMemorySaver()) + self.app = graph("my-graph").compile(checkpointer=InMemorySaver()) @workflow.run async def run(self, input: str) -> Any: @@ -54,7 +49,7 @@ async def run(self, input: str) -> Any: @workflow.defn class InterruptV2Workflow: def __init__(self) -> None: - self.app = interrupt_graph.compile(checkpointer=InMemorySaver()) + self.app = graph("my-graph").compile(checkpointer=InMemorySaver()) @workflow.run async def run(self, input: str) -> Any: @@ -73,6 +68,10 @@ async def run(self, input: str) -> Any: "workflow_cls", [InterruptWorkflow, InterruptV2Workflow], ids=["v1", "v2"] ) async def test_interrupt(client: Client, workflow_cls: Any) -> None: + g = StateGraph(State) + g.add_node("node", node) + g.add_edge(START, "node") + task_queue = f"interrupt-{uuid4()}" async with Worker( @@ -81,7 +80,7 @@ async def test_interrupt(client: Client, workflow_cls: Any) -> None: workflows=[workflow_cls], plugins=[ LangGraphPlugin( - graphs=[interrupt_graph], + graphs={"my-graph": g}, default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, diff --git a/tests/contrib/langgraph/test_plugin_validation.py b/tests/contrib/langgraph/test_plugin_validation.py index c693917f0..56a2c4f15 100644 --- a/tests/contrib/langgraph/test_plugin_validation.py +++ b/tests/contrib/langgraph/test_plugin_validation.py @@ -1,13 +1,19 @@ -"""Tests for LangGraphPlugin validation.""" +"""Tests for LangGraphPlugin validation and registry lookup error paths.""" from __future__ import annotations +from uuid import uuid4 + from langchain_core.runnables import RunnableLambda from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from pytest import raises from typing_extensions import TypedDict -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import ( + LangGraphPlugin, + entrypoint, + graph, +) class 
State(TypedDict): @@ -24,18 +30,28 @@ def sync_node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedPa def test_non_runnable_callable_node_raises() -> None: """Nodes whose runnable isn't a RunnableCallable can't be wrapped as activities.""" - g: StateGraph[State, None, State, State] = StateGraph(State) + g = StateGraph(State) g.add_node("node", RunnableLambda(sync_node)) g.add_edge(START, "node") with raises(ValueError, match="must have an async function"): - LangGraphPlugin(graphs=[g]) + LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) def test_invalid_execute_in_raises() -> None: - g: StateGraph[State, None, State, State] = StateGraph(State) + g = StateGraph(State) g.add_node("node", async_node, metadata={"execute_in": "bogus"}) g.add_edge(START, "node") with raises(ValueError, match="Invalid execute_in value"): - LangGraphPlugin(graphs=[g]) + LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) + + +async def test_unknown_graph_raises() -> None: + with raises(KeyError, match="not found"): + graph(f"not-registered-{uuid4()}") + + +async def test_unknown_entrypoint_raises() -> None: + with raises(KeyError, match="not found"): + entrypoint(f"not-registered-{uuid4()}") diff --git a/tests/contrib/langgraph/test_replay.py b/tests/contrib/langgraph/test_replay.py index a5821f910..d246fe3a0 100644 --- a/tests/contrib/langgraph/test_replay.py +++ b/tests/contrib/langgraph/test_replay.py @@ -3,25 +3,40 @@ from uuid import uuid4 import pytest +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from temporalio.client import Client from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin from temporalio.worker import Replayer, Worker from tests.contrib.langgraph.test_interrupt import ( InterruptWorkflow, - interrupt_graph, +) +from tests.contrib.langgraph.test_interrupt import ( + State as InterruptState, +) +from tests.contrib.langgraph.test_interrupt import ( + node as interrupt_node, ) from 
tests.contrib.langgraph.test_two_nodes import ( + State, TwoNodesWorkflow, - my_graph, + node_a, + node_b, ) -_DEFAULTS = {"start_to_close_timeout": timedelta(seconds=10)} - async def test_replay(client: Client): + g = StateGraph(State) + g.add_node("node_a", node_a) + g.add_node("node_b", node_b) + g.add_edge(START, "node_a") + g.add_edge("node_a", "node_b") + task_queue = f"my-graph-{uuid4()}" - plugin = LangGraphPlugin(graphs=[my_graph], default_activity_options=_DEFAULTS) + plugin = LangGraphPlugin( + graphs={"my-graph": g}, + default_activity_options={"start_to_close_timeout": timedelta(seconds=10)}, + ) async with Worker( client, @@ -48,9 +63,14 @@ async def test_replay(client: Client): reason="langgraph.types.interrupt() requires Python >= 3.11 for async context propagation", ) async def test_replay_interrupt(client: Client): + g = StateGraph(InterruptState) + g.add_node("node", interrupt_node) + g.add_edge(START, "node") + task_queue = f"interrupt-replay-{uuid4()}" plugin = LangGraphPlugin( - graphs=[interrupt_graph], default_activity_options=_DEFAULTS + graphs={"my-graph": g}, + default_activity_options={"start_to_close_timeout": timedelta(seconds=10)}, ) async with Worker( diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 60dd833fa..1c4b19132 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph from temporalio.worker import Worker @@ -23,17 +23,10 @@ async def node_b(state: State) -> dict[str, str]: return {"value": state["value"] + "b"} -streaming_graph: StateGraph[State, None, State, State] = StateGraph(State) -streaming_graph.add_node("node_a", node_a) -streaming_graph.add_node("node_b", node_b) 
-streaming_graph.add_edge(START, "node_a") -streaming_graph.add_edge("node_a", "node_b") - - @workflow.defn class StreamingWorkflow: def __init__(self) -> None: - self.app = streaming_graph.compile() + self.app = graph("streaming").compile() @workflow.run async def run(self, input: str) -> Any: @@ -44,6 +37,12 @@ async def run(self, input: str) -> Any: async def test_streaming(client: Client): + g = StateGraph(State) + g.add_node("node_a", node_a) + g.add_node("node_b", node_b) + g.add_edge(START, "node_a") + g.add_edge("node_a", "node_b") + task_queue = f"streaming-{uuid4()}" async with Worker( @@ -52,7 +51,7 @@ async def test_streaming(client: Client): workflows=[StreamingWorkflow], plugins=[ LangGraphPlugin( - graphs=[streaming_graph], + graphs={"streaming": g}, default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py index f7e024914..e752719bb 100644 --- a/tests/contrib/langgraph/test_subgraph_activity.py +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph from temporalio.worker import Worker @@ -27,15 +27,10 @@ async def parent_node(state: State) -> dict[str, str]: return await child.compile().ainvoke(state) -parent_graph: StateGraph[State, None, State, State] = StateGraph(State) -parent_graph.add_node("parent_node", parent_node) -parent_graph.add_edge(START, "parent_node") - - @workflow.defn class ActivitySubgraphWorkflow: def __init__(self) -> None: - self.app = parent_graph.compile() + self.app = graph("parent").compile() @workflow.run async def run(self, input: str) -> Any: @@ -43,6 +38,10 @@ async def run(self, input: str) -> Any: async def test_activity_subgraph(client: Client): 
+ parent = StateGraph(State) + parent.add_node("parent_node", parent_node) + parent.add_edge(START, "parent_node") + task_queue = f"subgraph-{uuid4()}" async with Worker( @@ -51,7 +50,7 @@ async def test_activity_subgraph(client: Client): workflows=[ActivitySubgraphWorkflow], plugins=[ LangGraphPlugin( - graphs=[parent_graph], + graphs={"parent": parent}, default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, diff --git a/tests/contrib/langgraph/test_subgraph_workflow.py b/tests/contrib/langgraph/test_subgraph_workflow.py index 114df1a3a..d85ce25a1 100644 --- a/tests/contrib/langgraph/test_subgraph_workflow.py +++ b/tests/contrib/langgraph/test_subgraph_workflow.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph from temporalio.worker import Worker @@ -19,28 +19,14 @@ async def child_node(state: State) -> dict[str, str]: # pyright: ignore[reportU return {"value": "child"} -child_graph: StateGraph[State, None, State, State] = StateGraph(State) -child_graph.add_node( - "child_node", - child_node, - metadata={"start_to_close_timeout": timedelta(seconds=10)}, -) -child_graph.add_edge(START, "child_node") - - async def parent_node(state: State) -> dict[str, str]: - return await child_graph.compile().ainvoke(state) - - -parent_graph: StateGraph[State, None, State, State] = StateGraph(State) -parent_graph.add_node("parent_node", parent_node, metadata={"execute_in": "workflow"}) -parent_graph.add_edge(START, "parent_node") + return await graph("child").compile().ainvoke(state) @workflow.defn class WorkflowSubgraphWorkflow: def __init__(self) -> None: - self.app = parent_graph.compile() + self.app = graph("parent").compile() @workflow.run async def run(self, input: str) -> Any: @@ -48,13 +34,25 @@ async def run(self, input: str) -> Any: async def 
test_workflow_subgraph(client: Client): + child = StateGraph(State) + child.add_node( + "child_node", + child_node, + metadata={"start_to_close_timeout": timedelta(seconds=10)}, + ) + child.add_edge(START, "child_node") + + parent = StateGraph(State) + parent.add_node("parent_node", parent_node, metadata={"execute_in": "workflow"}) + parent.add_edge(START, "parent_node") + task_queue = f"subgraph-{uuid4()}" async with Worker( client, task_queue=task_queue, workflows=[WorkflowSubgraphWorkflow], - plugins=[LangGraphPlugin(graphs=[parent_graph, child_graph])], + plugins=[LangGraphPlugin(graphs={"parent": parent, "child": child})], ): result = await client.execute_workflow( WorkflowSubgraphWorkflow.run, diff --git a/tests/contrib/langgraph/test_timeout.py b/tests/contrib/langgraph/test_timeout.py index 41a78e557..22c2930bc 100644 --- a/tests/contrib/langgraph/test_timeout.py +++ b/tests/contrib/langgraph/test_timeout.py @@ -10,7 +10,7 @@ from temporalio import workflow from temporalio.client import Client, WorkflowFailureError from temporalio.common import RetryPolicy -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph from temporalio.worker import Worker @@ -23,15 +23,10 @@ async def node(state: State) -> dict[str, str]: # pyright: ignore[reportUnusedP return {"value": "done"} -timeout_graph: StateGraph[State, None, State, State] = StateGraph(State) -timeout_graph.add_node("node", node) -timeout_graph.add_edge(START, "node") - - @workflow.defn class TimeoutWorkflow: def __init__(self) -> None: - self.app = timeout_graph.compile() + self.app = graph("my-graph").compile() @workflow.run async def run(self, input: str) -> Any: @@ -39,6 +34,10 @@ async def run(self, input: str) -> Any: async def test_timeout(client: Client): + g = StateGraph(State) + g.add_node("node", node) + g.add_edge(START, "node") + task_queue = f"my-graph-{uuid4()}" async with Worker( @@ -47,7 
+46,7 @@ async def test_timeout(client: Client): workflows=[TimeoutWorkflow], plugins=[ LangGraphPlugin( - graphs=[timeout_graph], + graphs={"my-graph": g}, default_activity_options={ "start_to_close_timeout": timedelta(milliseconds=100), "retry_policy": RetryPolicy(maximum_attempts=1), diff --git a/tests/contrib/langgraph/test_two_nodes.py b/tests/contrib/langgraph/test_two_nodes.py index 1dbaf4d88..992e30dcd 100644 --- a/tests/contrib/langgraph/test_two_nodes.py +++ b/tests/contrib/langgraph/test_two_nodes.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph from temporalio.worker import Worker @@ -23,17 +23,10 @@ async def node_b(state: State) -> dict[str, str]: return {"value": state["value"] + "b"} -my_graph: StateGraph[State, None, State, State] = StateGraph(State) -my_graph.add_node("node_a", node_a) -my_graph.add_node("node_b", node_b) -my_graph.add_edge(START, "node_a") -my_graph.add_edge("node_a", "node_b") - - @workflow.defn class TwoNodesWorkflow: def __init__(self) -> None: - self.app = my_graph.compile() + self.app = graph("my-graph").compile() @workflow.run async def run(self, input: str) -> Any: @@ -41,6 +34,12 @@ async def run(self, input: str) -> Any: async def test_two_nodes(client: Client): + g = StateGraph(State) + g.add_node("node_a", node_a) + g.add_node("node_b", node_b) + g.add_edge(START, "node_a") + g.add_edge("node_a", "node_b") + task_queue = f"my-graph-{uuid4()}" async with Worker( @@ -49,7 +48,7 @@ async def test_two_nodes(client: Client): workflows=[TwoNodesWorkflow], plugins=[ LangGraphPlugin( - graphs=[my_graph], + graphs={"my-graph": g}, default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, From 4eb67e5a27817e8a87b5a7add071b0b16b294255 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 20 Apr 2026 14:50:25 -0700 
Subject: [PATCH 35/47] reimplement node metadata fixes --- temporalio/contrib/langgraph/activity.py | 10 +++ .../contrib/langgraph/langgraph_plugin.py | 35 +++++++++-- tests/contrib/langgraph/test_node_metadata.py | 62 +++++++++++++++++++ uv.lock | 13 +++- 4 files changed, 113 insertions(+), 7 deletions(-) create mode 100644 tests/contrib/langgraph/test_node_metadata.py diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py index 04cdd8d4a..7690ee7c5 100644 --- a/temporalio/contrib/langgraph/activity.py +++ b/temporalio/contrib/langgraph/activity.py @@ -64,6 +64,16 @@ def wrap_execute_activity( """Wrap an activity function to be called via workflow.execute_activity with caching.""" async def wrapper(*args: Any, **kwargs: Any) -> Any: + # LangGraph may inject a RunnableConfig as the 'config' kwarg. Strip it + # down to a serializable subset (metadata + tags) so it can cross the + # activity boundary; callbacks, stores, etc. aren't serializable. + if "config" in kwargs: + orig = kwargs["config"] or {} + kwargs["config"] = { + "metadata": dict(orig.get("metadata") or {}), + "tags": list(orig.get("tags") or []), + } + # Check task result cache (for continue-as-new deduplication). 
key = cache_key(task_id, args, kwargs) if task_id else "" if task_id: diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/langgraph_plugin.py index 432da67eb..d97c613bc 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/langgraph_plugin.py @@ -4,6 +4,7 @@ from __future__ import annotations +import inspect import sys import warnings from dataclasses import replace @@ -13,7 +14,7 @@ from langgraph.graph import StateGraph from langgraph.pregel import Pregel -from temporalio import activity +from temporalio import activity, workflow from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity from temporalio.contrib.langgraph.task_cache import ( get_task_cache, @@ -24,6 +25,10 @@ from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner +_ACTIVITY_OPTION_KEYS: frozenset[str] = frozenset( + {"execute_in", *inspect.signature(workflow.execute_activity).parameters} +) + # Save registered graphs/entrypoints at the module level to avoid being refreshed by the sandbox. _graph_registry: dict[str, StateGraph[Any]] = {} _entrypoint_registry: dict[str, Pregel[Any, Any, Any, Any]] = {} @@ -80,9 +85,31 @@ def __init__( raise ValueError( f"Node {node_name} must have an async function" ) - # Remove LangSmith-related callback functions that can't be serialized between the workflow and activity. - runnable.func_accepts = {} - opts = {**(default_activity_options or {}), **(node.metadata or {})} + # Keep only 'config' injection so node functions can read + # metadata/tags. Drop writer/store/runtime/etc., which hold + # non-serializable objects that can't cross the activity + # boundary. The wrapper serializes config down to its + # portable subset before handing off to the activity. 
+ runnable.func_accepts = { + k: v + for k, v in runnable.func_accepts.items() + if k == "config" + } + # Split node.metadata into activity options vs. user + # metadata. Activity-option keys (timeouts, retry policy, + # etc.) become kwargs to workflow.execute_activity; user + # keys stay on node.metadata so LangGraph exposes them to + # the node function via config["metadata"]. + node_meta = node.metadata or {} + node_opts = { + k: v for k, v in node_meta.items() if k in _ACTIVITY_OPTION_KEYS + } + node.metadata = { + k: v + for k, v in node_meta.items() + if k not in _ACTIVITY_OPTION_KEYS + } + opts = {**(default_activity_options or {}), **node_opts} runnable.afunc = self.execute( f"{graph_name}.{node_name}", runnable.afunc, opts ) diff --git a/tests/contrib/langgraph/test_node_metadata.py b/tests/contrib/langgraph/test_node_metadata.py new file mode 100644 index 000000000..974976718 --- /dev/null +++ b/tests/contrib/langgraph/test_node_metadata.py @@ -0,0 +1,62 @@ +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langchain_core.runnables import RunnableConfig # pyright: ignore[reportMissingTypeStubs] +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from typing_extensions import TypedDict + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.worker import Worker + + +class State(TypedDict): + value: str + + +async def node(state: State, config: RunnableConfig) -> dict[str, str]: + metadata = config.get("metadata") or {} + return {"value": state["value"] + str(metadata.get("my_key", "NOT_FOUND"))} + + +metadata_graph: StateGraph[State, None, State, State] = StateGraph(State) +metadata_graph.add_node( + "node", + node, + metadata={ + "start_to_close_timeout": timedelta(seconds=10), + "my_key": "my_value", + }, +) +metadata_graph.add_edge(START, "node") + + +@workflow.defn +class 
NodeMetadataWorkflow: + def __init__(self) -> None: + self.app = metadata_graph.compile() + + @workflow.run + async def run(self, input: str) -> Any: + return await self.app.ainvoke({"value": input}) + + +async def test_node_metadata_readable_in_node(client: Client): + task_queue = f"my-graph-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[NodeMetadataWorkflow], + plugins=[LangGraphPlugin(graphs={"my-graph": metadata_graph})], + ): + result = await client.execute_workflow( + NodeMetadataWorkflow.run, + "prefix-", + id=f"test-workflow-{uuid4()}", + task_queue=task_queue, + ) + + assert result == {"value": "prefix-my_value"} diff --git a/uv.lock b/uv.lock index a9ad989ae..4fba27fc0 100644 --- a/uv.lock +++ b/uv.lock @@ -8,6 +8,13 @@ resolution-markers = [ "python_full_version < '3.11'", ] +[options] +exclude-newer = "2026-04-13T21:30:54.856039Z" +exclude-newer-span = "P1W" + +[options.exclude-newer-package] +openai-agents = false + [[package]] name = "aioboto3" version = "15.5.0" @@ -2470,7 +2477,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.2.29" +version = "1.2.28" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jsonpatch" }, @@ -2482,9 +2489,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "uuid-utils" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/d8/7bdf30e4bfc5175609201806e399506a0a78a48e14367dc8b776a9b4c89c/langchain_core-1.2.29.tar.gz", hash = "sha256:cfb89c92bca81ad083eafcdfe6ec40f9803c9abf7dd166d0f8a8de1d2de03ca6", size = 846121, upload-time = "2026-04-14T20:44:58.117Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/a4/317a1a3ac1df33a64adb3670bf88bbe3b3d5baa274db6863a979db472897/langchain_core-1.2.28.tar.gz", hash = "sha256:271a3d8bd618f795fdeba112b0753980457fc90537c46a0c11998516a74dc2cb", size = 846119, upload-time = "2026-04-08T18:19:34.867Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/72/37/fed31f80436b1d7bb222f1f2345300a77a88215416acf8d1cb7c8fda7388/langchain_core-1.2.29-py3-none-any.whl", hash = "sha256:11f02e57ee1c24e6e0e6577acbd35df77b205d4692a3df956b03b5389cbe44a0", size = 508733, upload-time = "2026-04-14T20:44:56.712Z" }, + { url = "https://files.pythonhosted.org/packages/a8/92/32f785f077c7e898da97064f113c73fbd9ad55d1e2169cf3a391b183dedb/langchain_core-1.2.28-py3-none-any.whl", hash = "sha256:80764232581eaf8057bcefa71dbf8adc1f6a28d257ebd8b95ba9b8b452e8c6ac", size = 508727, upload-time = "2026-04-08T18:19:32.823Z" }, ] [[package]] From f62bc09724d1d4a5d3540227a9d59e9e5e0ab024 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 20 Apr 2026 15:32:52 -0700 Subject: [PATCH 36/47] scope graphs and entrypoints to workflow, rename files --- temporalio/contrib/langgraph/__init__.py | 2 +- temporalio/contrib/langgraph/interceptor.py | 55 +++++++++++++++++++ .../{langgraph_plugin.py => plugin.py} | 42 +++++++------- .../langgraph/e2e_functional_workflows.py | 2 +- .../contrib/langgraph/test_continue_as_new.py | 2 +- .../langgraph/test_continue_as_new_cached.py | 2 +- .../contrib/langgraph/test_e2e_functional.py | 2 +- .../langgraph/test_execute_in_workflow.py | 2 +- tests/contrib/langgraph/test_interrupt.py | 2 +- tests/contrib/langgraph/test_node_metadata.py | 2 +- .../langgraph/test_plugin_validation.py | 18 +----- tests/contrib/langgraph/test_replay.py | 2 +- tests/contrib/langgraph/test_streaming.py | 2 +- .../langgraph/test_subgraph_activity.py | 2 +- .../langgraph/test_subgraph_workflow.py | 2 +- tests/contrib/langgraph/test_timeout.py | 2 +- tests/contrib/langgraph/test_two_nodes.py | 2 +- 17 files changed, 93 insertions(+), 50 deletions(-) create mode 100644 temporalio/contrib/langgraph/interceptor.py rename temporalio/contrib/langgraph/{langgraph_plugin.py => plugin.py} (88%) diff --git a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py index 
50d8ca147..56ae647b9 100644 --- a/temporalio/contrib/langgraph/__init__.py +++ b/temporalio/contrib/langgraph/__init__.py @@ -10,7 +10,7 @@ API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). """ -from temporalio.contrib.langgraph.langgraph_plugin import ( +from temporalio.contrib.langgraph.plugin import ( LangGraphPlugin, cache, entrypoint, diff --git a/temporalio/contrib/langgraph/interceptor.py b/temporalio/contrib/langgraph/interceptor.py new file mode 100644 index 000000000..94108e507 --- /dev/null +++ b/temporalio/contrib/langgraph/interceptor.py @@ -0,0 +1,55 @@ +"""Workflow interceptor that scopes LangGraph graphs/entrypoints to the workflow run.""" + +# pyright: reportMissingTypeStubs=false + +from __future__ import annotations + +from typing import Any + +from langgraph.graph import StateGraph +from langgraph.pregel import Pregel + +from temporalio import workflow +from temporalio.worker import ( + ExecuteWorkflowInput, + Interceptor, + WorkflowInboundInterceptor, + WorkflowInterceptorClassInput, + WorkflowOutboundInterceptor, +) + +_workflow_graphs: dict[str, dict[str, StateGraph[Any, Any, Any, Any]]] = {} +_workflow_entrypoints: dict[str, dict[str, Pregel[Any, Any, Any, Any]]] = {} + + +class LangGraphInterceptor(Interceptor): + def __init__( + self, + graphs: dict[str, StateGraph[Any, Any, Any, Any]], + entrypoints: dict[str, Pregel[Any, Any, Any, Any]], + ) -> None: + self._graphs = graphs + self._entrypoints = entrypoints + + def workflow_interceptor_class( + self, input: WorkflowInterceptorClassInput + ) -> type[WorkflowInboundInterceptor]: + graphs = self._graphs + entrypoints = self._entrypoints + + class Inbound(WorkflowInboundInterceptor): + def init(self, outbound: WorkflowOutboundInterceptor) -> None: + run_id = outbound.info().run_id + _workflow_graphs[run_id] = graphs + _workflow_entrypoints[run_id] = entrypoints + super().init(outbound) + + async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: + try: + 
return await self.next.execute_workflow(input) + finally: + run_id = workflow.info().run_id + _workflow_graphs.pop(run_id, None) + _workflow_entrypoints.pop(run_id, None) + + return Inbound diff --git a/temporalio/contrib/langgraph/langgraph_plugin.py b/temporalio/contrib/langgraph/plugin.py similarity index 88% rename from temporalio/contrib/langgraph/langgraph_plugin.py rename to temporalio/contrib/langgraph/plugin.py index d97c613bc..d438a393c 100644 --- a/temporalio/contrib/langgraph/langgraph_plugin.py +++ b/temporalio/contrib/langgraph/plugin.py @@ -16,6 +16,11 @@ from temporalio import activity, workflow from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity +from temporalio.contrib.langgraph.interceptor import ( + LangGraphInterceptor, + _workflow_entrypoints, + _workflow_graphs, +) from temporalio.contrib.langgraph.task_cache import ( get_task_cache, set_task_cache, @@ -29,10 +34,6 @@ {"execute_in", *inspect.signature(workflow.execute_activity).parameters} ) -# Save registered graphs/entrypoints at the module level to avoid being refreshed by the sandbox. -_graph_registry: dict[str, StateGraph[Any]] = {} -_entrypoint_registry: dict[str, Pregel[Any, Any, Any, Any]] = {} - class LangGraphPlugin(SimplePlugin): """LangGraph plugin for Temporal SDK. @@ -74,7 +75,6 @@ def __init__( # Graph API: Wrap graph nodes as Temporal Activities. if graphs: - _graph_registry.update(graphs) for graph_name, graph in graphs.items(): for node_name, node in graph.nodes.items(): runnable = node.runnable @@ -91,9 +91,7 @@ def __init__( # boundary. The wrapper serializes config down to its # portable subset before handing off to the activity. runnable.func_accepts = { - k: v - for k, v in runnable.func_accepts.items() - if k == "config" + k: v for k, v in runnable.func_accepts.items() if k == "config" } # Split node.metadata into activity options vs. user # metadata. 
Activity-option keys (timeouts, retry policy, @@ -114,9 +112,6 @@ def __init__( f"{graph_name}.{node_name}", runnable.afunc, opts ) - if entrypoints: - _entrypoint_registry.update(entrypoints) - # Functional API: Wrap @task functions as Temporal Activities. if tasks: for task in tasks: @@ -150,6 +145,7 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: "temporalio.LangGraphPlugin", activities=self.activities, workflow_runner=workflow_runner, + interceptors=[LangGraphInterceptor(graphs or {}, entrypoints or {})], ) def execute( @@ -184,12 +180,14 @@ def graph( not re-executed after continue-as-new. """ set_task_cache(cache or {}) - if name not in _graph_registry: - raise KeyError( - f"Graph {name!r} not found. " - f"Available graphs: {list(_graph_registry.keys())}" + graphs = _workflow_graphs.get(workflow.info().run_id) + if graphs is None: + raise RuntimeError( + "graph() must be called from inside a workflow running under LangGraphPlugin" ) - return _graph_registry[name] + if name not in graphs: + raise KeyError(f"Graph {name!r} not found. Available graphs: {list(graphs)}") + return graphs[name] def entrypoint( @@ -204,12 +202,16 @@ def entrypoint( not re-executed after continue-as-new. """ set_task_cache(cache or {}) - if name not in _entrypoint_registry: + entrypoints = _workflow_entrypoints.get(workflow.info().run_id) + if entrypoints is None: + raise RuntimeError( + "entrypoint() must be called from inside a workflow running under LangGraphPlugin" + ) + if name not in entrypoints: raise KeyError( - f"Entrypoint {name!r} not found. " - f"Available entrypoints: {list(_entrypoint_registry.keys())}" + f"Entrypoint {name!r} not found. 
Available entrypoints: {list(entrypoints)}" ) - return _entrypoint_registry[name] + return entrypoints[name] def cache() -> dict[str, Any] | None: diff --git a/tests/contrib/langgraph/e2e_functional_workflows.py b/tests/contrib/langgraph/e2e_functional_workflows.py index f467d1758..852384c90 100644 --- a/tests/contrib/langgraph/e2e_functional_workflows.py +++ b/tests/contrib/langgraph/e2e_functional_workflows.py @@ -6,7 +6,7 @@ from typing import Any from temporalio import workflow -from temporalio.contrib.langgraph.langgraph_plugin import cache, entrypoint +from temporalio.contrib.langgraph.plugin import cache, entrypoint @workflow.defn diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index 2458c07a5..f75d4556b 100644 --- a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -11,7 +11,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index 444fdcd4a..c24d6f544 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -14,7 +14,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, cache, graph +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, cache, graph from temporalio.worker import Worker # Track execution counts to verify caching diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index 8edbfd922..f9e8654c8 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ 
b/tests/contrib/langgraph/test_e2e_functional.py @@ -30,7 +30,7 @@ from temporalio import workflow from temporalio.client import Client, WorkflowFailureError from temporalio.common import RetryPolicy -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, entrypoint +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, entrypoint from temporalio.worker import Worker from tests.contrib.langgraph.e2e_functional_entrypoints import ( add_ten, diff --git a/tests/contrib/langgraph/test_execute_in_workflow.py b/tests/contrib/langgraph/test_execute_in_workflow.py index 76037c037..11c22fa6b 100644 --- a/tests/contrib/langgraph/test_execute_in_workflow.py +++ b/tests/contrib/langgraph/test_execute_in_workflow.py @@ -6,7 +6,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index 838539853..f95590bff 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -19,7 +19,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_node_metadata.py b/tests/contrib/langgraph/test_node_metadata.py index 974976718..1e4dbe9a2 100644 --- a/tests/contrib/langgraph/test_node_metadata.py +++ b/tests/contrib/langgraph/test_node_metadata.py @@ -8,7 +8,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.plugin import LangGraphPlugin 
from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_plugin_validation.py b/tests/contrib/langgraph/test_plugin_validation.py index 56a2c4f15..40f391a0d 100644 --- a/tests/contrib/langgraph/test_plugin_validation.py +++ b/tests/contrib/langgraph/test_plugin_validation.py @@ -1,4 +1,4 @@ -"""Tests for LangGraphPlugin validation and registry lookup error paths.""" +"""Tests for LangGraphPlugin validation.""" from __future__ import annotations @@ -9,11 +9,7 @@ from pytest import raises from typing_extensions import TypedDict -from temporalio.contrib.langgraph.langgraph_plugin import ( - LangGraphPlugin, - entrypoint, - graph, -) +from temporalio.contrib.langgraph.plugin import LangGraphPlugin class State(TypedDict): @@ -45,13 +41,3 @@ def test_invalid_execute_in_raises() -> None: with raises(ValueError, match="Invalid execute_in value"): LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) - - -async def test_unknown_graph_raises() -> None: - with raises(KeyError, match="not found"): - graph(f"not-registered-{uuid4()}") - - -async def test_unknown_entrypoint_raises() -> None: - with raises(KeyError, match="not found"): - entrypoint(f"not-registered-{uuid4()}") diff --git a/tests/contrib/langgraph/test_replay.py b/tests/contrib/langgraph/test_replay.py index d246fe3a0..e83bb1a16 100644 --- a/tests/contrib/langgraph/test_replay.py +++ b/tests/contrib/langgraph/test_replay.py @@ -6,7 +6,7 @@ from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin +from temporalio.contrib.langgraph.plugin import LangGraphPlugin from temporalio.worker import Replayer, Worker from tests.contrib.langgraph.test_interrupt import ( InterruptWorkflow, diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 1c4b19132..ba0283479 100644 --- a/tests/contrib/langgraph/test_streaming.py 
+++ b/tests/contrib/langgraph/test_streaming.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py index e752719bb..956867e0b 100644 --- a/tests/contrib/langgraph/test_subgraph_activity.py +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_subgraph_workflow.py b/tests/contrib/langgraph/test_subgraph_workflow.py index d85ce25a1..bafb70b82 100644 --- a/tests/contrib/langgraph/test_subgraph_workflow.py +++ b/tests/contrib/langgraph/test_subgraph_workflow.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_timeout.py b/tests/contrib/langgraph/test_timeout.py index 22c2930bc..36331f65c 100644 --- a/tests/contrib/langgraph/test_timeout.py +++ b/tests/contrib/langgraph/test_timeout.py @@ -10,7 +10,7 @@ from temporalio import workflow from temporalio.client import Client, WorkflowFailureError from temporalio.common import RetryPolicy -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph from temporalio.worker import Worker diff --git 
a/tests/contrib/langgraph/test_two_nodes.py b/tests/contrib/langgraph/test_two_nodes.py index 992e30dcd..b5cdd62bb 100644 --- a/tests/contrib/langgraph/test_two_nodes.py +++ b/tests/contrib/langgraph/test_two_nodes.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.langgraph_plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph from temporalio.worker import Worker From aa8489289b64cecb332e4dc0ac16f93541216730 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 21 Apr 2026 15:08:15 -0700 Subject: [PATCH 37/47] test sync nodes and tasks, send --- temporalio/contrib/langgraph/plugin.py | 18 ++--- .../langgraph/test_plugin_validation.py | 2 +- tests/contrib/langgraph/test_send.py | 72 +++++++++++++++++++ tests/contrib/langgraph/test_sync_node.py | 59 +++++++++++++++ tests/contrib/langgraph/test_sync_task.py | 68 ++++++++++++++++++ 5 files changed, 210 insertions(+), 9 deletions(-) create mode 100644 tests/contrib/langgraph/test_send.py create mode 100644 tests/contrib/langgraph/test_sync_node.py create mode 100644 tests/contrib/langgraph/test_sync_task.py diff --git a/temporalio/contrib/langgraph/plugin.py b/temporalio/contrib/langgraph/plugin.py index d438a393c..e6f6547f4 100644 --- a/temporalio/contrib/langgraph/plugin.py +++ b/temporalio/contrib/langgraph/plugin.py @@ -78,13 +78,11 @@ def __init__( for graph_name, graph in graphs.items(): for node_name, node in graph.nodes.items(): runnable = node.runnable - if ( - not isinstance(runnable, RunnableCallable) - or runnable.afunc is None - ): - raise ValueError( - f"Node {node_name} must have an async function" - ) + if not isinstance(runnable, RunnableCallable): + raise ValueError(f"Node {node_name} must be a RunnableCallable") + user_func = runnable.afunc or runnable.func + if user_func is None: + raise ValueError(f"Node {node_name} must have a function") # Keep only 'config' injection so 
node functions can read # metadata/tags. Drop writer/store/runtime/etc., which hold # non-serializable objects that can't cross the activity @@ -108,9 +106,13 @@ def __init__( if k not in _ACTIVITY_OPTION_KEYS } opts = {**(default_activity_options or {}), **node_opts} + # Route all LangGraph node calls through afunc so the async + # activity wrapper is always used. wrap_activity handles + # sync vs. async user functions inside the activity itself. runnable.afunc = self.execute( - f"{graph_name}.{node_name}", runnable.afunc, opts + f"{graph_name}.{node_name}", user_func, opts ) + runnable.func = None # Functional API: Wrap @task functions as Temporal Activities. if tasks: diff --git a/tests/contrib/langgraph/test_plugin_validation.py b/tests/contrib/langgraph/test_plugin_validation.py index 40f391a0d..0b73557e6 100644 --- a/tests/contrib/langgraph/test_plugin_validation.py +++ b/tests/contrib/langgraph/test_plugin_validation.py @@ -30,7 +30,7 @@ def test_non_runnable_callable_node_raises() -> None: g.add_node("node", RunnableLambda(sync_node)) g.add_edge(START, "node") - with raises(ValueError, match="must have an async function"): + with raises(ValueError, match="must be a RunnableCallable"): LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) diff --git a/tests/contrib/langgraph/test_send.py b/tests/contrib/langgraph/test_send.py new file mode 100644 index 000000000..7b54fb1ca --- /dev/null +++ b/tests/contrib/langgraph/test_send.py @@ -0,0 +1,72 @@ +import operator +from datetime import timedelta +from typing import Annotated, Any +from uuid import uuid4 + +from langgraph.graph import END, START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from langgraph.types import Send +from typing_extensions import TypedDict + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.worker import Worker + + +class State(TypedDict): + items: list[str] + results: 
Annotated[list[str], operator.add] + + +class WorkerState(TypedDict): + item: str + + +def worker(state: WorkerState) -> dict[str, list[str]]: + return {"results": [state["item"].upper()]} + + +def fan_out(state: State) -> list[Send]: + return [Send("worker", {"item": item}) for item in state["items"]] + + +@workflow.defn +class SendWorkflow: + def __init__(self) -> None: + self.app = graph("my-graph").compile() + + @workflow.run + async def run(self, items: list[str]) -> Any: + return await self.app.ainvoke({"items": items, "results": []}) + + +async def test_send(client: Client): + g = StateGraph(State) + g.add_node("worker", worker) + g.add_conditional_edges(START, fan_out, ["worker"]) + g.add_edge("worker", END) + + task_queue = f"send-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[SendWorkflow], + plugins=[ + LangGraphPlugin( + graphs={"my-graph": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], + ): + result = await client.execute_workflow( + SendWorkflow.run, + ["a", "b", "c"], + id=f"test-send-{uuid4()}", + task_queue=task_queue, + ) + + assert result["items"] == ["a", "b", "c"] + assert sorted(result["results"]) == ["A", "B", "C"] diff --git a/tests/contrib/langgraph/test_sync_node.py b/tests/contrib/langgraph/test_sync_node.py new file mode 100644 index 000000000..bb99f2858 --- /dev/null +++ b/tests/contrib/langgraph/test_sync_node.py @@ -0,0 +1,59 @@ +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from typing_extensions import TypedDict + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.worker import Worker + + +class State(TypedDict): + value: str + + +def sync_node(state: State) -> dict[str, str]: + return {"value": state["value"] + "!"} + 
+ +@workflow.defn +class SyncNodeWorkflow: + def __init__(self) -> None: + self.app = graph("my-graph").compile() + + @workflow.run + async def run(self, input: str) -> Any: + return await self.app.ainvoke({"value": input}) + + +async def test_sync_node(client: Client): + g = StateGraph(State) + g.add_node("sync_node", sync_node) + g.add_edge(START, "sync_node") + + task_queue = f"sync-node-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[SyncNodeWorkflow], + plugins=[ + LangGraphPlugin( + graphs={"my-graph": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], + ): + result = await client.execute_workflow( + SyncNodeWorkflow.run, + "hello", + id=f"test-sync-node-{uuid4()}", + task_queue=task_queue, + ) + + assert result == {"value": "hello!"} diff --git a/tests/contrib/langgraph/test_sync_task.py b/tests/contrib/langgraph/test_sync_task.py new file mode 100644 index 000000000..cf8dfe2f8 --- /dev/null +++ b/tests/contrib/langgraph/test_sync_task.py @@ -0,0 +1,68 @@ +import sys +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 11), + reason="LangGraph Functional API requires Python >= 3.11 for async context propagation", +) +from langgraph.func import ( # pyright: ignore[reportMissingTypeStubs] + entrypoint as lg_entrypoint, +) +from langgraph.func import task # pyright: ignore[reportMissingTypeStubs] + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, entrypoint +from temporalio.worker import Worker + + +@task +def sync_task(x: int) -> int: + return x + 1 + + +@lg_entrypoint() +async def sync_task_entrypoint(value: int) -> dict[str, int]: + result = await sync_task(value) + return {"result": result} + + +@workflow.defn +class SyncTaskWorkflow: + def __init__(self) -> None: + self.app = 
entrypoint("sync-task") + + @workflow.run + async def run(self, input: int) -> Any: + return await self.app.ainvoke(input) + + +async def test_sync_task(client: Client): + task_queue = f"sync-task-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[SyncTaskWorkflow], + plugins=[ + LangGraphPlugin( + entrypoints={"sync-task": sync_task_entrypoint}, + tasks=[sync_task], + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], + ): + result = await client.execute_workflow( + SyncTaskWorkflow.run, + 41, + id=f"test-sync-task-{uuid4()}", + task_queue=task_queue, + ) + + assert result == {"result": 42} From 4b915b10feecbfd3dfe96b69626d440ad1061dce Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 21 Apr 2026 15:29:47 -0700 Subject: [PATCH 38/47] support command goto/update --- temporalio/contrib/langgraph/activity.py | 21 +++++++++++++++++---- tests/contrib/langgraph/test_send.py | 2 +- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py index 7690ee7c5..50d93c4be 100644 --- a/temporalio/contrib/langgraph/activity.py +++ b/temporalio/contrib/langgraph/activity.py @@ -6,7 +6,7 @@ from typing import Any, Callable from langgraph.errors import GraphInterrupt -from langgraph.types import Interrupt +from langgraph.types import Command, Interrupt from temporalio import workflow from temporalio.contrib.langgraph.langgraph_config import ( @@ -31,9 +31,10 @@ class ActivityInput: @dataclass class ActivityOutput: - """Output from a LangGraph activity, containing result or interrupts.""" + """Output from an Activity, containing result, command, or interrupts.""" result: Any = None + langgraph_command: Any = None langgraph_interrupts: tuple[Interrupt] | None = None @@ -49,6 +50,8 @@ async def wrapper(input: ActivityInput) -> ActivityOutput: result = await func(*input.args, **input.kwargs) else: result = func(*input.args, 
**input.kwargs) + if isinstance(result, Command): + return ActivityOutput(langgraph_command=result) return ActivityOutput(result=result) except GraphInterrupt as e: return ActivityOutput(langgraph_interrupts=e.args[0]) @@ -90,10 +93,20 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: if output.langgraph_interrupts is not None: raise GraphInterrupt(output.langgraph_interrupts) + result = output.result + if output.langgraph_command is not None: + cmd = output.langgraph_command + result = Command( + graph=cmd["graph"], + update=cmd["update"], + resume=cmd["resume"], + goto=cmd["goto"], + ) + # Store in cache for future continue-as-new cycles. if task_id: - cache_put(key, output.result) + cache_put(key, result) - return output.result + return result return wrapper diff --git a/tests/contrib/langgraph/test_send.py b/tests/contrib/langgraph/test_send.py index 7b54fb1ca..570f8517b 100644 --- a/tests/contrib/langgraph/test_send.py +++ b/tests/contrib/langgraph/test_send.py @@ -26,7 +26,7 @@ def worker(state: WorkerState) -> dict[str, list[str]]: return {"results": [state["item"].upper()]} -def fan_out(state: State) -> list[Send]: +async def fan_out(state: State) -> list[Send]: return [Send("worker", {"item": item}) for item in state["items"]] From c36423b320bb5490d832e00f0bb158da73a509ff Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 21 Apr 2026 15:30:15 -0700 Subject: [PATCH 39/47] add test for command --- tests/contrib/langgraph/test_command.py | 65 +++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 tests/contrib/langgraph/test_command.py diff --git a/tests/contrib/langgraph/test_command.py b/tests/contrib/langgraph/test_command.py new file mode 100644 index 000000000..c983e6728 --- /dev/null +++ b/tests/contrib/langgraph/test_command.py @@ -0,0 +1,65 @@ +from datetime import timedelta +from typing import Any, Literal +from uuid import uuid4 + +from langgraph.graph import END, START, StateGraph # pyright: 
ignore[reportMissingTypeStubs] +from langgraph.types import Command +from typing_extensions import TypedDict + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.worker import Worker + + +class State(TypedDict): + value: str + + +def node_a(state: State) -> Command[Literal["node_b"]]: + return Command(update={"value": state["value"] + "a"}, goto="node_b") + + +def node_b(state: State) -> Command[Literal["__end__"]]: + return Command(update={"value": state["value"] + "b"}, goto=END) + + +@workflow.defn +class CommandWorkflow: + def __init__(self) -> None: + self.app = graph("my-graph").compile() + + @workflow.run + async def run(self, input: str) -> Any: + return await self.app.ainvoke({"value": input}) + + +async def test_command_goto_and_update(client: Client): + g = StateGraph(State) + g.add_node("node_a", node_a) + g.add_node("node_b", node_b) + g.add_edge(START, "node_a") + + task_queue = f"command-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[CommandWorkflow], + plugins=[ + LangGraphPlugin( + graphs={"my-graph": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], + ): + result = await client.execute_workflow( + CommandWorkflow.run, + "", + id=f"test-command-{uuid4()}", + task_queue=task_queue, + ) + + assert result == {"value": "ab"} From 04b56e7f8667f39482f6728a14fc1f02ac39898f Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 21 Apr 2026 15:45:16 -0700 Subject: [PATCH 40/47] raise error if node or task has a retry policy --- temporalio/contrib/langgraph/plugin.py | 15 ++++++++++++++ .../langgraph/test_plugin_validation.py | 20 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/temporalio/contrib/langgraph/plugin.py b/temporalio/contrib/langgraph/plugin.py index e6f6547f4..009647eaf 100644 --- a/temporalio/contrib/langgraph/plugin.py +++ 
b/temporalio/contrib/langgraph/plugin.py @@ -77,6 +77,14 @@ def __init__( if graphs: for graph_name, graph in graphs.items(): for node_name, node in graph.nodes.items(): + if node.retry_policy: + raise ValueError( + f"Node {graph_name}.{node_name} has a LangGraph " + f"retry_policy set. Use Temporal activity options " + f"instead, e.g. pass retry_policy=RetryPolicy(...) " + f"via default_activity_options or in the node's " + f"metadata dict." + ) runnable = node.runnable if not isinstance(runnable, RunnableCallable): raise ValueError(f"Node {node_name} must be a RunnableCallable") @@ -118,6 +126,13 @@ def __init__( if tasks: for task in tasks: name = task.func.__name__ + if task.retry_policy: + raise ValueError( + f"Task {name} has a LangGraph retry_policy set. " + f"Use Temporal activity options instead, e.g. pass " + f"retry_policy=RetryPolicy(...) via " + f"default_activity_options or activity_options[{name!r}]." + ) opts = { **(default_activity_options or {}), **(activity_options or {}).get(name, {}), diff --git a/tests/contrib/langgraph/test_plugin_validation.py b/tests/contrib/langgraph/test_plugin_validation.py index 0b73557e6..35be0e3d5 100644 --- a/tests/contrib/langgraph/test_plugin_validation.py +++ b/tests/contrib/langgraph/test_plugin_validation.py @@ -5,7 +5,9 @@ from uuid import uuid4 from langchain_core.runnables import RunnableLambda +from langgraph.func import task # pyright: ignore[reportMissingTypeStubs] from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from langgraph.types import RetryPolicy # pyright: ignore[reportMissingTypeStubs] from pytest import raises from typing_extensions import TypedDict @@ -41,3 +43,21 @@ def test_invalid_execute_in_raises() -> None: with raises(ValueError, match="Invalid execute_in value"): LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) + + +def test_node_retry_policy_raises() -> None: + g = StateGraph(State) + g.add_node("node", async_node, 
retry_policy=RetryPolicy(max_attempts=3)) + g.add_edge(START, "node") + + with raises(ValueError, match="retry_policy"): + LangGraphPlugin(graphs={f"validation-{uuid4()}": g}) + + +def test_task_retry_policy_raises() -> None: + @task(retry_policy=RetryPolicy(max_attempts=3)) + def my_task() -> str: + return "done" + + with raises(ValueError, match="retry_policy"): + LangGraphPlugin(tasks=[my_task]) From ef96ba41157a9f7cd88be5351ea8551fdf7bd6b4 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 21 Apr 2026 16:29:53 -0700 Subject: [PATCH 41/47] support runtime context --- temporalio/contrib/langgraph/README.md | 31 +++++++++- temporalio/contrib/langgraph/activity.py | 56 +++++++++++++++++-- temporalio/contrib/langgraph/interceptor.py | 2 + .../contrib/langgraph/langgraph_config.py | 21 ++++++- temporalio/contrib/langgraph/plugin.py | 15 ++--- temporalio/contrib/langgraph/task_cache.py | 13 +++-- 6 files changed, 117 insertions(+), 21 deletions(-) diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index e8a4da74b..56ebcc0aa 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -105,7 +105,36 @@ plugin = LangGraphPlugin( ) ``` -### Running in the Workflow +### Runtime Context + +LangGraph's run-scoped context (`context_schema`) is reconstructed on the activity side, so nodes and tasks can read `runtime.context` (or call `get_runtime()`) without changing anything at the call site: + +```python +from langgraph.runtime import Runtime +from typing_extensions import TypedDict + +from temporalio.contrib.langgraph import graph + +class Context(TypedDict): + user_id: str + +async def my_node(state: State, runtime: Runtime[Context]) -> dict: + return {"user": runtime.context["user_id"]} + +# In the workflow: +g = graph("my-graph").compile() +await g.ainvoke({...}, context=Context(user_id="alice")) +``` + +Your `context` object must be serializable by the configured Temporal payload 
converter, since it crosses the activity boundary. + +## Stores are not supported + +LangGraph's `Store` (e.g. `InMemoryStore` passed via `graph.compile(store=...)` or `@entrypoint(store=...)`) isn't accessible inside activity-wrapped nodes: the Store holds live state that can't cross the activity boundary, and activities may run on a different worker than the workflow. If you pass a store, the plugin logs a warning on first use and `runtime.store` is `None` inside nodes. + +Use workflow state for per-run memory, or a backend-backed store (Postgres/Redis/etc.) configured on each worker if you need shared memory across runs. + +## Running in the Workflow To skip the Activity wrapper and run a node or task directly in the Workflow, set `execute_in` to `"workflow"`: diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py index 50d93c4be..92bbf3639 100644 --- a/temporalio/contrib/langgraph/activity.py +++ b/temporalio/contrib/langgraph/activity.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable from dataclasses import dataclass -from inspect import iscoroutinefunction +from inspect import iscoroutinefunction, signature from typing import Any, Callable from langgraph.errors import GraphInterrupt @@ -20,6 +20,17 @@ ) +# Per-run dedupe so we only warn once when a user passes a Store via +# graph.compile(store=...) / @entrypoint(store=...). Cleared by +# LangGraphInterceptor.execute_workflow on workflow exit. 
+_warned_store_runs: set[str] = set() + + +def clear_store_warning(run_id: str) -> None: + """Drop the store-warning dedupe entry for a workflow run.""" + _warned_store_runs.discard(run_id) + + @dataclass class ActivityInput: """Input for a LangGraph activity, containing args, kwargs, and config.""" @@ -42,14 +53,21 @@ def wrap_activity( func: Callable, ) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]: """Wrap a function as a Temporal activity that handles LangGraph config and interrupts.""" + # Graph nodes declare `runtime: Runtime[Ctx]` in their signature; tasks + # don't and instead reach for Runtime via get_runtime(). We re-inject the + # reconstructed Runtime only when the user function asks. + accepts_runtime = "runtime" in signature(func).parameters async def wrapper(input: ActivityInput) -> ActivityOutput: - set_langgraph_config(input.langgraph_config) + runtime = set_langgraph_config(input.langgraph_config) + kwargs = dict(input.kwargs) + if accepts_runtime: + kwargs["runtime"] = runtime try: if iscoroutinefunction(func): - result = await func(*input.args, **input.kwargs) + result = await func(*input.args, **kwargs) else: - result = func(*input.args, **input.kwargs) + result = func(*input.args, **kwargs) if isinstance(result, Command): return ActivityOutput(langgraph_command=result) return ActivityOutput(result=result) @@ -77,15 +95,41 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: "tags": list(orig.get("tags") or []), } + # LangGraph may inject a Runtime as the 'runtime' kwarg. It's + # reconstructed on the activity side from the serialized langgraph + # config, so drop the live Runtime from the kwargs that cross the + # activity boundary (it holds non-serializable stream_writer, store). 
+ runtime = kwargs.pop("runtime", None) + run_id = workflow.info().run_id + if ( + getattr(runtime, "store", None) is not None + and run_id not in _warned_store_runs + ): + _warned_store_runs.add(run_id) + workflow.logger.warning( + "LangGraph Store passed via compile(store=...) / @entrypoint(store=...) " + "is not accessible inside activity-wrapped nodes and tasks: the Store " + "object isn't serializable across the activity boundary, and activities " + "may run on a different worker than the workflow. Use a backend-backed " + "store (Postgres/Redis) configured on each worker if you need shared " + "memory, or use workflow state for per-run memory." + ) + + langgraph_config = get_langgraph_config() + # Check task result cache (for continue-as-new deduplication). - key = cache_key(task_id, args, kwargs) if task_id else "" + key = ( + cache_key(task_id, args, kwargs, langgraph_config.get("context")) + if task_id + else "" + ) if task_id: found, cached = cache_lookup(key) if found: return cached input = ActivityInput( - args=args, kwargs=kwargs, langgraph_config=get_langgraph_config() + args=args, kwargs=kwargs, langgraph_config=langgraph_config ) output = await workflow.execute_activity( afunc, input, **execute_activity_kwargs diff --git a/temporalio/contrib/langgraph/interceptor.py b/temporalio/contrib/langgraph/interceptor.py index 94108e507..46fe0c315 100644 --- a/temporalio/contrib/langgraph/interceptor.py +++ b/temporalio/contrib/langgraph/interceptor.py @@ -10,6 +10,7 @@ from langgraph.pregel import Pregel from temporalio import workflow +from temporalio.contrib.langgraph.activity import clear_store_warning from temporalio.worker import ( ExecuteWorkflowInput, Interceptor, @@ -51,5 +52,6 @@ async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: run_id = workflow.info().run_id _workflow_graphs.pop(run_id, None) _workflow_entrypoints.pop(run_id, None) + clear_store_warning(run_id) return Inbound diff --git 
a/temporalio/contrib/langgraph/langgraph_config.py b/temporalio/contrib/langgraph/langgraph_config.py index 9663679af..23a2cd9ae 100644 --- a/temporalio/contrib/langgraph/langgraph_config.py +++ b/temporalio/contrib/langgraph/langgraph_config.py @@ -7,11 +7,13 @@ from langchain_core.runnables.config import var_child_runnable_config from langgraph._internal._constants import ( CONFIG_KEY_CHECKPOINT_NS, + CONFIG_KEY_RUNTIME, CONFIG_KEY_SCRATCHPAD, CONFIG_KEY_SEND, ) from langgraph.graph.state import RunnableConfig from langgraph.pregel._algo import LazyAtomicCounter, PregelScratchpad +from langgraph.runtime import Runtime def get_langgraph_config() -> dict[str, Any]: @@ -19,6 +21,7 @@ def get_langgraph_config() -> dict[str, Any]: config = var_child_runnable_config.get() or {} configurable = config.get("configurable") or {} scratchpad = configurable.get(CONFIG_KEY_SCRATCHPAD) + runtime = configurable.get(CONFIG_KEY_RUNTIME) return { "configurable": { @@ -29,12 +32,17 @@ def get_langgraph_config() -> dict[str, Any]: "resume": list(getattr(scratchpad, "resume", [])), "null_resume": scratchpad.get_null_resume() if scratchpad else None, }, - } + }, + "context": getattr(runtime, "context", None), } -def set_langgraph_config(config: dict[str, Any]) -> None: - """Restore a LangGraph runnable config from a serialized dict.""" +def set_langgraph_config(config: dict[str, Any]) -> Runtime: + """Restore a LangGraph runnable config from a serialized dict. + + Returns the reconstructed Runtime so callers can re-inject it into the + user function's kwargs without needing to know the configurable layout. 
+ """ configurable = config.get("configurable") or {} scratchpad = configurable.get(CONFIG_KEY_SCRATCHPAD) or {} null_resume_box = [scratchpad.get("null_resume")] @@ -45,6 +53,11 @@ def get_null_resume(consume: bool = False) -> Any: null_resume_box[0] = None return val + runtime = Runtime( + context=config.get("context"), + stream_writer=lambda _: None, + ) + var_child_runnable_config.set( RunnableConfig( { @@ -62,7 +75,9 @@ def get_null_resume(consume: bool = False) -> Any: subgraph_counter=LazyAtomicCounter(), ), CONFIG_KEY_SEND: lambda _: None, + CONFIG_KEY_RUNTIME: runtime, }, } ) ) + return runtime diff --git a/temporalio/contrib/langgraph/plugin.py b/temporalio/contrib/langgraph/plugin.py index 009647eaf..3c2b69b4a 100644 --- a/temporalio/contrib/langgraph/plugin.py +++ b/temporalio/contrib/langgraph/plugin.py @@ -91,13 +91,14 @@ def __init__( user_func = runnable.afunc or runnable.func if user_func is None: raise ValueError(f"Node {node_name} must have a function") - # Keep only 'config' injection so node functions can read - # metadata/tags. Drop writer/store/runtime/etc., which hold - # non-serializable objects that can't cross the activity - # boundary. The wrapper serializes config down to its - # portable subset before handing off to the activity. + # Keep 'config' (for metadata/tags) and 'runtime' (for + # context + store — reconstructed on the activity side). + # Drop writer/etc., which hold non-serializable objects + # that can't cross the activity boundary. runnable.func_accepts = { - k: v for k, v in runnable.func_accepts.items() if k == "config" + k: v + for k, v in runnable.func_accepts.items() + if k in ("config", "runtime") } # Split node.metadata into activity options vs. user # metadata. Activity-option keys (timeouts, retry policy, @@ -187,7 +188,7 @@ def execute( def graph( name: str, cache: dict[str, Any] | None = None -) -> StateGraph[Any, None, Any, Any]: +) -> StateGraph[Any, Any, Any, Any]: """Retrieve a registered graph by name. 
Args: diff --git a/temporalio/contrib/langgraph/task_cache.py b/temporalio/contrib/langgraph/task_cache.py index d4053c808..ab3e683d7 100644 --- a/temporalio/contrib/langgraph/task_cache.py +++ b/temporalio/contrib/langgraph/task_cache.py @@ -54,12 +54,17 @@ def task_id(func: Any) -> str: return f"{module}.{qualname}" -def cache_key(task_id: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: - """Build a cache key from the full task identifier and arguments.""" +def cache_key( + task_id: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], + context: Any = None, +) -> str: + """Build a cache key from the full task identifier, arguments, and runtime context.""" try: - key_str = dumps([task_id, args, kwargs], sort_keys=True, default=str) + key_str = dumps([task_id, args, kwargs, context], sort_keys=True, default=str) except (TypeError, ValueError): - key_str = repr([task_id, args, kwargs]) + key_str = repr([task_id, args, kwargs, context]) return sha256(key_str.encode()).hexdigest()[:32] From b742951b3245e39fc7aeebc015de3fedc4e0f496 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 22 Apr 2026 12:25:08 -0700 Subject: [PATCH 42/47] fix lint --- temporalio/contrib/langgraph/activity.py | 1 - temporalio/contrib/langgraph/interceptor.py | 4 ++++ tests/contrib/langgraph/test_command.py | 7 +++++-- tests/contrib/langgraph/test_continue_as_new.py | 2 +- tests/contrib/langgraph/test_node_metadata.py | 4 +++- tests/contrib/langgraph/test_plugin_validation.py | 9 ++++++--- tests/contrib/langgraph/test_send.py | 6 +++++- 7 files changed, 24 insertions(+), 9 deletions(-) diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/activity.py index 92bbf3639..ba938d445 100644 --- a/temporalio/contrib/langgraph/activity.py +++ b/temporalio/contrib/langgraph/activity.py @@ -19,7 +19,6 @@ cache_put, ) - # Per-run dedupe so we only warn once when a user passes a Store via # graph.compile(store=...) / @entrypoint(store=...). 
Cleared by # LangGraphInterceptor.execute_workflow on workflow exit. diff --git a/temporalio/contrib/langgraph/interceptor.py b/temporalio/contrib/langgraph/interceptor.py index 46fe0c315..b3d28e1bf 100644 --- a/temporalio/contrib/langgraph/interceptor.py +++ b/temporalio/contrib/langgraph/interceptor.py @@ -24,17 +24,21 @@ class LangGraphInterceptor(Interceptor): + """Interceptor that registers a workflow's graphs and entrypoints for the run.""" + def __init__( self, graphs: dict[str, StateGraph[Any, Any, Any, Any]], entrypoints: dict[str, Pregel[Any, Any, Any, Any]], ) -> None: + """Initialize with the graphs and entrypoints to scope to each workflow run.""" self._graphs = graphs self._entrypoints = entrypoints def workflow_interceptor_class( self, input: WorkflowInterceptorClassInput ) -> type[WorkflowInboundInterceptor]: + """Return the inbound interceptor class used to scope graphs per run.""" graphs = self._graphs entrypoints = self._entrypoints diff --git a/tests/contrib/langgraph/test_command.py b/tests/contrib/langgraph/test_command.py index c983e6728..21d73df38 100644 --- a/tests/contrib/langgraph/test_command.py +++ b/tests/contrib/langgraph/test_command.py @@ -2,7 +2,10 @@ from typing import Any, Literal from uuid import uuid4 -from langgraph.graph import END, START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from langgraph.graph import ( # pyright: ignore[reportMissingTypeStubs] + START, + StateGraph, +) from langgraph.types import Command from typing_extensions import TypedDict @@ -21,7 +24,7 @@ def node_a(state: State) -> Command[Literal["node_b"]]: def node_b(state: State) -> Command[Literal["__end__"]]: - return Command(update={"value": state["value"] + "b"}, goto=END) + return Command(update={"value": state["value"] + "b"}, goto="__end__") @workflow.defn diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index f75d4556b..2cf955ac3 100644 --- 
a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -29,7 +29,7 @@ def __init__(self) -> None: self.app = graph("my-graph").compile(checkpointer=InMemorySaver()) @workflow.run - async def run(self, values: dict[str, str]) -> Any: + async def run(self, values: State) -> Any: config = RunnableConfig({"configurable": {"thread_id": "1"}}) await self.app.aupdate_state(config, values) diff --git a/tests/contrib/langgraph/test_node_metadata.py b/tests/contrib/langgraph/test_node_metadata.py index 1e4dbe9a2..e682e4bbc 100644 --- a/tests/contrib/langgraph/test_node_metadata.py +++ b/tests/contrib/langgraph/test_node_metadata.py @@ -2,7 +2,9 @@ from typing import Any from uuid import uuid4 -from langchain_core.runnables import RunnableConfig # pyright: ignore[reportMissingTypeStubs] +from langchain_core.runnables import ( + RunnableConfig, # pyright: ignore[reportMissingTypeStubs] +) from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict diff --git a/tests/contrib/langgraph/test_plugin_validation.py b/tests/contrib/langgraph/test_plugin_validation.py index 35be0e3d5..dbb1613ea 100644 --- a/tests/contrib/langgraph/test_plugin_validation.py +++ b/tests/contrib/langgraph/test_plugin_validation.py @@ -2,6 +2,7 @@ from __future__ import annotations +from typing import Any from uuid import uuid4 from langchain_core.runnables import RunnableLambda @@ -55,9 +56,11 @@ def test_node_retry_policy_raises() -> None: def test_task_retry_policy_raises() -> None: - @task(retry_policy=RetryPolicy(max_attempts=3)) - def my_task() -> str: - return "done" + decorator: Any = task(retry_policy=RetryPolicy(max_attempts=3)) + + @decorator + def my_task(x: int) -> int: + return x + 1 with raises(ValueError, match="retry_policy"): LangGraphPlugin(tasks=[my_task]) diff --git a/tests/contrib/langgraph/test_send.py b/tests/contrib/langgraph/test_send.py index 
570f8517b..a03247fbc 100644 --- a/tests/contrib/langgraph/test_send.py +++ b/tests/contrib/langgraph/test_send.py @@ -3,7 +3,11 @@ from typing import Annotated, Any from uuid import uuid4 -from langgraph.graph import END, START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from langgraph.graph import ( # pyright: ignore[reportMissingTypeStubs] + END, + START, + StateGraph, +) from langgraph.types import Send from typing_extensions import TypedDict From 6135fe09e4d9a7198ddda12bba200c76833154ee Mon Sep 17 00:00:00 2001 From: DABH Date: Thu, 23 Apr 2026 13:35:13 -0500 Subject: [PATCH 43/47] Revert changes to langsmith test_integration.py --- tests/contrib/langsmith/test_integration.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/contrib/langsmith/test_integration.py b/tests/contrib/langsmith/test_integration.py index 29cc28ce9..78d48c71e 100644 --- a/tests/contrib/langsmith/test_integration.py +++ b/tests/contrib/langsmith/test_integration.py @@ -318,11 +318,9 @@ async def _poll_query( query: Callable[..., Any], *, expected: Any = True, - timeout: float = 45, ) -> bool: """Poll a workflow query until it returns the expected value.""" - deadline = asyncio.get_event_loop().time() + timeout - while asyncio.get_event_loop().time() < deadline: + while True: try: result = await handle.query(query) if result == expected: @@ -330,7 +328,6 @@ async def _poll_query( except (WorkflowQueryFailedError, RPCError): pass # Query not yet available (workflow hasn't started) await asyncio.sleep(1) - return False # --------------------------------------------------------------------------- From 65036a0a1526085a115ef4aa3d622e3de4790e7f Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 23 Apr 2026 11:56:13 -0700 Subject: [PATCH 44/47] code review --- temporalio/contrib/langgraph/README.md | 42 ++++++++++---------------- temporalio/contrib/langgraph/plugin.py | 4 +-- 2 files changed, 18 insertions(+), 28 deletions(-) diff --git 
a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index 56ebcc0aa..0189092a6 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -10,12 +10,6 @@ This Temporal [Plugin](https://docs.temporal.io/develop/plugins-guide) allows yo uv add temporalio[langgraph] ``` -or with pip: - -```sh -pip install temporalio[langgraph] -``` - ## Plugin Initialization ### Graph API @@ -35,17 +29,13 @@ from temporalio.contrib.langgraph import LangGraphPlugin plugin = LangGraphPlugin( entrypoints={"my_entrypoint": my_entrypoint}, tasks=[my_task], - activity_options={ - "my_task": { - "start_to_close_timeout": datetime.timedelta(seconds=30), - }, - }, ) ``` ## Checkpointer -Use `InMemorySaver` as your checkpointer. Temporal handles durability, so third-party checkpointers (like PostgreSQL or Redis) are not needed. +If your LangGraph code requires a checkpointer (for example, if you're using interrupts), use `InMemorySaver`. +Temporal handles durability, so third-party checkpointers (like PostgreSQL or Redis) are not needed. 
```python import langgraph.checkpoint.memory @@ -71,25 +61,25 @@ Options are passed through to [`workflow.execute_activity()`](https://python.tem ### Graph API -Pass activity options as node `metadata` when calling `add_node`: +Pass Activity options as node `metadata` when calling `add_node`: ```python -import datetime +from datetime import timedelta from temporalio.common import RetryPolicy g = StateGraph(State) g.add_node("my_node", my_node, metadata={ - "start_to_close_timeout": datetime.timedelta(seconds=30), + "start_to_close_timeout": timedelta(seconds=30), "retry_policy": RetryPolicy(maximum_attempts=3), }) ``` ### Functional API -Pass activity options to the `Plugin` constructor, keyed by task function name: +Pass Activity options to the `LangGraphPlugin` constructor, keyed by task function name: ```python -import datetime +from datetime import timedelta from temporalio.common import RetryPolicy from temporalio.contrib.langgraph import LangGraphPlugin @@ -98,7 +88,7 @@ plugin = LangGraphPlugin( tasks=[my_task], activity_options={ "my_task": { - "start_to_close_timeout": datetime.timedelta(seconds=30), + "start_to_close_timeout": timedelta(seconds=30), "retry_policy": RetryPolicy(maximum_attempts=3), }, }, @@ -107,7 +97,7 @@ plugin = LangGraphPlugin( ### Runtime Context -LangGraph's run-scoped context (`context_schema`) is reconstructed on the activity side, so nodes and tasks can read `runtime.context` (or call `get_runtime()`) without changing anything at the call site: +LangGraph's run-scoped context (`context_schema`) is reconstructed on the Activity side, so nodes and tasks can read from and write to `runtime.context`: ```python from langgraph.runtime import Runtime @@ -121,22 +111,22 @@ class Context(TypedDict): async def my_node(state: State, runtime: Runtime[Context]) -> dict: return {"user": runtime.context["user_id"]} -# In the workflow: +# In the Workflow: g = graph("my-graph").compile() await g.ainvoke({...}, context=Context(user_id="alice")) ``` 
-Your `context` object must be serializable by the configured Temporal payload converter, since it crosses the activity boundary. +Your `context` object must be serializable by the configured Temporal payload converter, since it crosses the Activity boundary. ## Stores are not supported -LangGraph's `Store` (e.g. `InMemoryStore` passed via `graph.compile(store=...)` or `@entrypoint(store=...)`) isn't accessible inside activity-wrapped nodes: the Store holds live state that can't cross the activity boundary, and activities may run on a different worker than the workflow. If you pass a store, the plugin logs a warning on first use and `runtime.store` is `None` inside nodes. +LangGraph's `Store` (e.g. `InMemoryStore` passed via `graph.compile(store=...)` or `@entrypoint(store=...)`) isn't accessible inside Activity-wrapped nodes: the Store holds live state that can't cross the Activity boundary, and Activities may run on a different worker than the Workflow. If you pass a store, the plugin logs a warning on first use and `runtime.store` is `None` inside nodes. -Use workflow state for per-run memory, or a backend-backed store (Postgres/Redis/etc.) configured on each worker if you need shared memory across runs. +Use Workflow state for per-run memory, or an external database (Postgres/Redis/etc.) configured on each worker if you need shared memory across runs. ## Running in the Workflow -To skip the Activity wrapper and run a node or task directly in the Workflow, set `execute_in` to `"workflow"`: +To run a node or task directly in the Workflow, set `execute_in` to `"workflow"`: ```python # Graph API @@ -154,13 +144,13 @@ plugin = LangGraphPlugin( Install dependencies: ```sh -uv sync +uv sync --all-extras ``` Run the test suite: ```sh -uv run pytest +uv run pytest tests/contrib/langgraph ``` Tests start a local Temporal dev server automatically — no external server needed. 
diff --git a/temporalio/contrib/langgraph/plugin.py b/temporalio/contrib/langgraph/plugin.py index 3c2b69b4a..598718790 100644 --- a/temporalio/contrib/langgraph/plugin.py +++ b/temporalio/contrib/langgraph/plugin.py @@ -56,7 +56,7 @@ def __init__( entrypoints: dict[str, Pregel[Any, Any, Any, Any]] | None = None, tasks: list | None = None, # TODO: Remove activity_options when we have support for @task(metadata=...) - activity_options: dict[str, dict] | None = None, + activity_options: dict[str, dict[str, Any]] | None = None, default_activity_options: dict[str, Any] | None = None, ): """Initialize the LangGraph plugin with graphs, entrypoints, and tasks.""" @@ -160,7 +160,7 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: return runner super().__init__( - "temporalio.LangGraphPlugin", + "langchain.LangGraphPlugin", activities=self.activities, workflow_runner=workflow_runner, interceptors=[LangGraphInterceptor(graphs or {}, entrypoints or {})], From 44c7b885bb12344c8e2e1a8296f8c7874d985fa3 Mon Sep 17 00:00:00 2001 From: DABH Date: Thu, 23 Apr 2026 14:36:37 -0500 Subject: [PATCH 45/47] Remove langchain_core from LangSmith plugin sandbox passthroughs (CI experiment) --- temporalio/contrib/langsmith/_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/temporalio/contrib/langsmith/_plugin.py b/temporalio/contrib/langsmith/_plugin.py index 789c93414..e7923c9c1 100644 --- a/temporalio/contrib/langsmith/_plugin.py +++ b/temporalio/contrib/langsmith/_plugin.py @@ -72,7 +72,6 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: runner, restrictions=runner.restrictions.with_passthrough_modules( "langsmith", - "langchain_core", ), ) return runner From d9187c874f04a971f24e99cbfa8ae1e0fb413220 Mon Sep 17 00:00:00 2001 From: DABH Date: Thu, 23 Apr 2026 15:04:01 -0500 Subject: [PATCH 46/47] Restore langchain_core to LangSmith plugin sandbox passthroughs --- temporalio/contrib/langsmith/_plugin.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/temporalio/contrib/langsmith/_plugin.py b/temporalio/contrib/langsmith/_plugin.py index e7923c9c1..789c93414 100644 --- a/temporalio/contrib/langsmith/_plugin.py +++ b/temporalio/contrib/langsmith/_plugin.py @@ -72,6 +72,7 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: runner, restrictions=runner.restrictions.with_passthrough_modules( "langsmith", + "langchain_core", ), ) return runner From 872adf121cb81e22d54663dbb61b8a9c6851daa9 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 23 Apr 2026 15:58:34 -0700 Subject: [PATCH 47/47] underscore py files in langgraph plugin dir --- temporalio/contrib/langgraph/__init__.py | 2 +- temporalio/contrib/langgraph/{activity.py => _activity.py} | 4 ++-- .../contrib/langgraph/{interceptor.py => _interceptor.py} | 2 +- .../langgraph/{langgraph_config.py => _langgraph_config.py} | 0 temporalio/contrib/langgraph/{plugin.py => _plugin.py} | 6 +++--- .../contrib/langgraph/{task_cache.py => _task_cache.py} | 0 tests/contrib/langgraph/e2e_functional_workflows.py | 2 +- tests/contrib/langgraph/test_command.py | 2 +- tests/contrib/langgraph/test_continue_as_new.py | 2 +- tests/contrib/langgraph/test_continue_as_new_cached.py | 2 +- tests/contrib/langgraph/test_e2e_functional.py | 2 +- tests/contrib/langgraph/test_execute_in_workflow.py | 2 +- tests/contrib/langgraph/test_interrupt.py | 2 +- tests/contrib/langgraph/test_node_metadata.py | 2 +- tests/contrib/langgraph/test_plugin_validation.py | 2 +- tests/contrib/langgraph/test_replay.py | 2 +- tests/contrib/langgraph/test_send.py | 2 +- tests/contrib/langgraph/test_streaming.py | 2 +- tests/contrib/langgraph/test_subgraph_activity.py | 2 +- tests/contrib/langgraph/test_subgraph_workflow.py | 2 +- tests/contrib/langgraph/test_sync_node.py | 2 +- tests/contrib/langgraph/test_sync_task.py | 2 +- tests/contrib/langgraph/test_timeout.py | 2 +- tests/contrib/langgraph/test_two_nodes.py | 2 +- 24 files changed, 25 insertions(+), 25 
deletions(-) rename temporalio/contrib/langgraph/{activity.py => _activity.py} (97%) rename temporalio/contrib/langgraph/{interceptor.py => _interceptor.py} (96%) rename temporalio/contrib/langgraph/{langgraph_config.py => _langgraph_config.py} (100%) rename temporalio/contrib/langgraph/{plugin.py => _plugin.py} (98%) rename temporalio/contrib/langgraph/{task_cache.py => _task_cache.py} (100%) diff --git a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py index 56ae647b9..c12d459a6 100644 --- a/temporalio/contrib/langgraph/__init__.py +++ b/temporalio/contrib/langgraph/__init__.py @@ -10,7 +10,7 @@ API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). """ -from temporalio.contrib.langgraph.plugin import ( +from temporalio.contrib.langgraph._plugin import ( LangGraphPlugin, cache, entrypoint, diff --git a/temporalio/contrib/langgraph/activity.py b/temporalio/contrib/langgraph/_activity.py similarity index 97% rename from temporalio/contrib/langgraph/activity.py rename to temporalio/contrib/langgraph/_activity.py index ba938d445..c9be05849 100644 --- a/temporalio/contrib/langgraph/activity.py +++ b/temporalio/contrib/langgraph/_activity.py @@ -9,11 +9,11 @@ from langgraph.types import Command, Interrupt from temporalio import workflow -from temporalio.contrib.langgraph.langgraph_config import ( +from temporalio.contrib.langgraph._langgraph_config import ( get_langgraph_config, set_langgraph_config, ) -from temporalio.contrib.langgraph.task_cache import ( +from temporalio.contrib.langgraph._task_cache import ( cache_key, cache_lookup, cache_put, diff --git a/temporalio/contrib/langgraph/interceptor.py b/temporalio/contrib/langgraph/_interceptor.py similarity index 96% rename from temporalio/contrib/langgraph/interceptor.py rename to temporalio/contrib/langgraph/_interceptor.py index b3d28e1bf..fd583c052 100644 --- a/temporalio/contrib/langgraph/interceptor.py +++ b/temporalio/contrib/langgraph/_interceptor.py @@ -10,7 
+10,7 @@ from langgraph.pregel import Pregel from temporalio import workflow -from temporalio.contrib.langgraph.activity import clear_store_warning +from temporalio.contrib.langgraph._activity import clear_store_warning from temporalio.worker import ( ExecuteWorkflowInput, Interceptor, diff --git a/temporalio/contrib/langgraph/langgraph_config.py b/temporalio/contrib/langgraph/_langgraph_config.py similarity index 100% rename from temporalio/contrib/langgraph/langgraph_config.py rename to temporalio/contrib/langgraph/_langgraph_config.py diff --git a/temporalio/contrib/langgraph/plugin.py b/temporalio/contrib/langgraph/_plugin.py similarity index 98% rename from temporalio/contrib/langgraph/plugin.py rename to temporalio/contrib/langgraph/_plugin.py index 598718790..e7cde6e56 100644 --- a/temporalio/contrib/langgraph/plugin.py +++ b/temporalio/contrib/langgraph/_plugin.py @@ -15,13 +15,13 @@ from langgraph.pregel import Pregel from temporalio import activity, workflow -from temporalio.contrib.langgraph.activity import wrap_activity, wrap_execute_activity -from temporalio.contrib.langgraph.interceptor import ( +from temporalio.contrib.langgraph._activity import wrap_activity, wrap_execute_activity +from temporalio.contrib.langgraph._interceptor import ( LangGraphInterceptor, _workflow_entrypoints, _workflow_graphs, ) -from temporalio.contrib.langgraph.task_cache import ( +from temporalio.contrib.langgraph._task_cache import ( get_task_cache, set_task_cache, task_id, diff --git a/temporalio/contrib/langgraph/task_cache.py b/temporalio/contrib/langgraph/_task_cache.py similarity index 100% rename from temporalio/contrib/langgraph/task_cache.py rename to temporalio/contrib/langgraph/_task_cache.py diff --git a/tests/contrib/langgraph/e2e_functional_workflows.py b/tests/contrib/langgraph/e2e_functional_workflows.py index 852384c90..d355bdb28 100644 --- a/tests/contrib/langgraph/e2e_functional_workflows.py +++ b/tests/contrib/langgraph/e2e_functional_workflows.py @@ -6,7 
+6,7 @@ from typing import Any from temporalio import workflow -from temporalio.contrib.langgraph.plugin import cache, entrypoint +from temporalio.contrib.langgraph import cache, entrypoint @workflow.defn diff --git a/tests/contrib/langgraph/test_command.py b/tests/contrib/langgraph/test_command.py index 21d73df38..2bbb5bb0c 100644 --- a/tests/contrib/langgraph/test_command.py +++ b/tests/contrib/langgraph/test_command.py @@ -11,7 +11,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_continue_as_new.py b/tests/contrib/langgraph/test_continue_as_new.py index 2cf955ac3..a2eec3834 100644 --- a/tests/contrib/langgraph/test_continue_as_new.py +++ b/tests/contrib/langgraph/test_continue_as_new.py @@ -11,7 +11,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index c24d6f544..a35574bb4 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -14,7 +14,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, cache, graph +from temporalio.contrib.langgraph import LangGraphPlugin, cache, graph from temporalio.worker import Worker # Track execution counts to verify caching diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index f9e8654c8..fc55fc81c 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ 
b/tests/contrib/langgraph/test_e2e_functional.py @@ -30,7 +30,7 @@ from temporalio import workflow from temporalio.client import Client, WorkflowFailureError from temporalio.common import RetryPolicy -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, entrypoint +from temporalio.contrib.langgraph import LangGraphPlugin, entrypoint from temporalio.worker import Worker from tests.contrib.langgraph.e2e_functional_entrypoints import ( add_ten, diff --git a/tests/contrib/langgraph/test_execute_in_workflow.py b/tests/contrib/langgraph/test_execute_in_workflow.py index 11c22fa6b..15b44f5a9 100644 --- a/tests/contrib/langgraph/test_execute_in_workflow.py +++ b/tests/contrib/langgraph/test_execute_in_workflow.py @@ -6,7 +6,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_interrupt.py b/tests/contrib/langgraph/test_interrupt.py index f95590bff..eeff1e846 100644 --- a/tests/contrib/langgraph/test_interrupt.py +++ b/tests/contrib/langgraph/test_interrupt.py @@ -19,7 +19,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_node_metadata.py b/tests/contrib/langgraph/test_node_metadata.py index e682e4bbc..9d9537355 100644 --- a/tests/contrib/langgraph/test_node_metadata.py +++ b/tests/contrib/langgraph/test_node_metadata.py @@ -10,7 +10,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin +from temporalio.contrib.langgraph import LangGraphPlugin from temporalio.worker import Worker diff --git 
a/tests/contrib/langgraph/test_plugin_validation.py b/tests/contrib/langgraph/test_plugin_validation.py index dbb1613ea..d80413b6c 100644 --- a/tests/contrib/langgraph/test_plugin_validation.py +++ b/tests/contrib/langgraph/test_plugin_validation.py @@ -12,7 +12,7 @@ from pytest import raises from typing_extensions import TypedDict -from temporalio.contrib.langgraph.plugin import LangGraphPlugin +from temporalio.contrib.langgraph import LangGraphPlugin class State(TypedDict): diff --git a/tests/contrib/langgraph/test_replay.py b/tests/contrib/langgraph/test_replay.py index e83bb1a16..04bba96bf 100644 --- a/tests/contrib/langgraph/test_replay.py +++ b/tests/contrib/langgraph/test_replay.py @@ -6,7 +6,7 @@ from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin +from temporalio.contrib.langgraph import LangGraphPlugin from temporalio.worker import Replayer, Worker from tests.contrib.langgraph.test_interrupt import ( InterruptWorkflow, diff --git a/tests/contrib/langgraph/test_send.py b/tests/contrib/langgraph/test_send.py index a03247fbc..445d20648 100644 --- a/tests/contrib/langgraph/test_send.py +++ b/tests/contrib/langgraph/test_send.py @@ -13,7 +13,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index ba0283479..474a853c8 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, 
graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_subgraph_activity.py b/tests/contrib/langgraph/test_subgraph_activity.py index 956867e0b..1cd6b5e96 100644 --- a/tests/contrib/langgraph/test_subgraph_activity.py +++ b/tests/contrib/langgraph/test_subgraph_activity.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_subgraph_workflow.py b/tests/contrib/langgraph/test_subgraph_workflow.py index bafb70b82..5e77f6dd8 100644 --- a/tests/contrib/langgraph/test_subgraph_workflow.py +++ b/tests/contrib/langgraph/test_subgraph_workflow.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_sync_node.py b/tests/contrib/langgraph/test_sync_node.py index bb99f2858..edc79200b 100644 --- a/tests/contrib/langgraph/test_sync_node.py +++ b/tests/contrib/langgraph/test_sync_node.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_sync_task.py b/tests/contrib/langgraph/test_sync_task.py index cf8dfe2f8..a6e135b21 100644 --- a/tests/contrib/langgraph/test_sync_task.py +++ b/tests/contrib/langgraph/test_sync_task.py @@ -16,7 +16,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, entrypoint +from temporalio.contrib.langgraph 
import LangGraphPlugin, entrypoint from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_timeout.py b/tests/contrib/langgraph/test_timeout.py index 36331f65c..23e65caa5 100644 --- a/tests/contrib/langgraph/test_timeout.py +++ b/tests/contrib/langgraph/test_timeout.py @@ -10,7 +10,7 @@ from temporalio import workflow from temporalio.client import Client, WorkflowFailureError from temporalio.common import RetryPolicy -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.worker import Worker diff --git a/tests/contrib/langgraph/test_two_nodes.py b/tests/contrib/langgraph/test_two_nodes.py index b5cdd62bb..5d9978ffe 100644 --- a/tests/contrib/langgraph/test_two_nodes.py +++ b/tests/contrib/langgraph/test_two_nodes.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph.plugin import LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.worker import Worker