From d735d3f8587b7f20afa0106d281c5997676267c2 Mon Sep 17 00:00:00 2001 From: Mahmoud Mabrouk Date: Fri, 19 Jun 2026 16:46:03 +0200 Subject: [PATCH] fix(agent): propagate messages session ids to runner traces --- .../sdk/middlewares/running/normalizer.py | 10 +++++-- .../unit/test_normalizer_passthrough.py | 30 +++++++++++++++++++ services/agent/src/engines/pi.ts | 5 ++-- services/agent/src/engines/rivet.ts | 6 ++-- services/agent/src/protocol.ts | 7 ++++- services/agent/test/continuation.test.ts | 13 +++++++- services/oss/src/agent/app.py | 2 ++ .../oss/tests/pytest/unit/agent/conftest.py | 2 ++ .../pytest/unit/agent/test_invoke_handler.py | 12 ++++++++ 9 files changed, 79 insertions(+), 8 deletions(-) diff --git a/sdks/python/agenta/sdk/middlewares/running/normalizer.py b/sdks/python/agenta/sdk/middlewares/running/normalizer.py index cdbe389b33..44c5b791e4 100644 --- a/sdks/python/agenta/sdk/middlewares/running/normalizer.py +++ b/sdks/python/agenta/sdk/middlewares/running/normalizer.py @@ -66,8 +66,10 @@ async def _normalize_request( 1. If parameter name is 'request': passes the entire WorkflowServiceRequest 2. If parameter name matches DATA_FIELDS (like 'inputs', 'outputs', 'parameters'): extracts that field from request.data - 3. If parameter is **kwargs: includes all unconsumed DATA_FIELDS - 4. Otherwise: looks up the parameter name in request.data.inputs dict + 3. If parameter name is a supported top-level request field like 'session_id': + extracts that field from the request envelope + 4. If parameter is **kwargs: includes all unconsumed DATA_FIELDS + 5. Otherwise: looks up the parameter name in request.data.inputs dict Args: request: The workflow service request containing inputs and data @@ -95,6 +97,10 @@ async def _normalize_request( ) consumed.add(name) + elif name == "session_id": + normalized[name] = request.session_id + consumed.add(name) + elif param.kind == inspect.Parameter.VAR_KEYWORD: if request.data: for f in self.DATA_FIELDS - consumed: diff --git a/sdks/python/oss/tests/pytest/unit/test_normalizer_passthrough.py b/sdks/python/oss/tests/pytest/unit/test_normalizer_passthrough.py index b796680685..94d99e0fdf 100644 --- a/sdks/python/oss/tests/pytest/unit/test_normalizer_passthrough.py +++ b/sdks/python/oss/tests/pytest/unit/test_normalizer_passthrough.py @@ -79,6 +79,36 @@ def handler(parameters): assert kwargs["parameters"] == {"correct_answer_key": "answer"} + @pytest.mark.asyncio + async def test_session_id_is_passed_to_explicit_handler_argument(self): + def handler(session_id): + return session_id + + request = WorkflowServiceRequest( + session_id="sess_request", + data=WorkflowRequestData(), + ) + + mw = NormalizerMiddleware() + kwargs = await mw._normalize_request(request, handler) + + assert kwargs["session_id"] == "sess_request" + + @pytest.mark.asyncio + async def test_session_id_is_not_added_to_var_kwargs(self): + def handler(**kwargs): + return kwargs + + request = WorkflowServiceRequest( + session_id="sess_request", + data=WorkflowRequestData(inputs={"prompt": "hi"}), + ) + + mw = NormalizerMiddleware() + kwargs = await mw._normalize_request(request, handler) + + assert "session_id" not in kwargs + class TestAsyncGenerator: @pytest.mark.asyncio diff --git a/services/agent/src/engines/pi.ts b/services/agent/src/engines/pi.ts index cd1ce93505..f26981ace0 100644 --- a/services/agent/src/engines/pi.ts +++ b/services/agent/src/engines/pi.ts @@ -41,6 +41,7 @@ import { type HarnessCapabilities, type ResolvedToolSpec, type ToolCallbackContext, + resolveRunSessionId, resolvePromptText, } from "../protocol.ts"; import { EMPTY_OBJECT_SCHEMA } from "../tools/callback.ts"; @@ -299,7 +300,8 @@ export async function runPi( })); // Hand the session id + model to the extension so spans carry them. - otel.config.sessionId = session.sessionId; + const sessionId = resolveRunSessionId(request, session.sessionId); + otel.config.sessionId = sessionId; otel.config.provider = model.provider; otel.config.requestModel = model.id; @@ -329,7 +331,6 @@ export async function runPi( await session.prompt(prompt); const output = streamed.trim() || extractAssistantText(session.messages); - const sessionId = session.sessionId; const stopReason = lastStopReason(session.messages); const usage = otel.usage(); diff --git a/services/agent/src/engines/rivet.ts b/services/agent/src/engines/rivet.ts index 6437d7502c..876020b9b0 100644 --- a/services/agent/src/engines/rivet.ts +++ b/services/agent/src/engines/rivet.ts @@ -65,6 +65,7 @@ import { type ToolCallbackContext, messageText, resolvePromptText, + resolveRunSessionId, } from "../protocol.ts"; const require = createRequire(import.meta.url); @@ -817,6 +818,7 @@ export async function runRivet( cwd, sessionInit: { cwd, mcpServers }, }); + const sessionId = resolveRunSessionId(request, session.id); // Resolve the model first: when the harness rejects the requested id and keeps its // own default (e.g. Claude ignores "gpt-5.5"), `model` is undefined and the chat span @@ -838,7 +840,7 @@ export async function runRivet( run.start({ prompt, - sessionId: session.id, + sessionId, messages: [...priorMessages(request), { role: "user", content: prompt }], }); @@ -922,7 +924,7 @@ export async function runRivet( // `streamingDeltas` advertises end-to-end live deltas, which is only true when a live // sink is wired. The one-shot path reports false even when the harness produces deltas. capabilities: { ...capabilities, streamingDeltas: !!emit && capabilities.streamingDeltas }, - sessionId: session.id, + sessionId, model: model ?? request.model, traceId: run.traceId(), }; diff --git a/services/agent/src/protocol.ts b/services/agent/src/protocol.ts index 086e0654b1..bee8a7a496 100644 --- a/services/agent/src/protocol.ts +++ b/services/agent/src/protocol.ts @@ -189,7 +189,7 @@ export interface AgentRunRequest { harness?: string; /** Sandbox for the rivet backend ("local" / "daytona"). */ sandbox?: string; - /** Continue a prior run by replaying its history. */ + /** External conversation id. The cold runtime still receives history in `messages`. */ sessionId?: string; /** Provider API keys as env vars ({OPENAI_API_KEY,...}), resolved from the vault. */ secrets?: Record; @@ -288,3 +288,8 @@ export function resolvePromptText(request: AgentRunRequest): string { } return ""; } + +/** Prefer the platform conversation id, falling back to the harness's ephemeral id. */ +export function resolveRunSessionId(request: AgentRunRequest, fallback: string): string { + return request.sessionId && request.sessionId.trim() ? request.sessionId : fallback; +} diff --git a/services/agent/test/continuation.test.ts b/services/agent/test/continuation.test.ts index 0d4af7948f..c9f9d4356c 100644 --- a/services/agent/test/continuation.test.ts +++ b/services/agent/test/continuation.test.ts @@ -11,7 +11,11 @@ import assert from "node:assert/strict"; import { messageTranscript, buildTurnText } from "../src/engines/rivet.ts"; -import type { AgentRunRequest, ContentBlock } from "../src/protocol.ts"; +import { + resolveRunSessionId, + type AgentRunRequest, + type ContentBlock, +} from "../src/protocol.ts"; // --- messageTranscript ------------------------------------------------------- assert.equal(messageTranscript("hello"), "hello"); @@ -29,6 +33,13 @@ assert.equal( "[send error: boom]", ); +// --- session id metadata ------------------------------------------------------ +assert.equal( + resolveRunSessionId({ sessionId: "sess_platform" }, "runner-ephemeral"), + "sess_platform", +); +assert.equal(resolveRunSessionId({}, "runner-ephemeral"), "runner-ephemeral"); + // --- buildTurnText keeps a resolved tool turn in the replay ------------------ { const req: AgentRunRequest = { diff --git a/services/oss/src/agent/app.py b/services/oss/src/agent/app.py index 6409065d55..33d1986b59 100644 --- a/services/oss/src/agent/app.py +++ b/services/oss/src/agent/app.py @@ -77,6 +77,7 @@ async def _agent( messages: Optional[List[Any]] = None, parameters: Optional[Dict] = None, stream: Optional[bool] = None, + session_id: Optional[str] = None, ): params = parameters or {} @@ -98,6 +99,7 @@ async def _agent( secrets=await resolve_harness_secrets(), permission_policy=selection.permission_policy, trace=trace_context(), + session_id=session_id, builtin_names=resources.tools.builtin_names, tool_specs=resources.tools.tool_specs, tool_callback=resources.tools.tool_callback, diff --git a/services/oss/tests/pytest/unit/agent/conftest.py b/services/oss/tests/pytest/unit/agent/conftest.py index af83b70d93..f84c7b29df 100644 --- a/services/oss/tests/pytest/unit/agent/conftest.py +++ b/services/oss/tests/pytest/unit/agent/conftest.py @@ -85,6 +85,7 @@ def __init__( self.shutdown_calls = 0 # Every harness-shaped config that reached the backend boundary, in call order. self.created_configs: list = [] + self.created_session_ids: list[Optional[str]] = [] async def setup(self) -> None: self.setup_calls += 1 @@ -99,6 +100,7 @@ async def create_session( self, sandbox, config, *, harness, secrets=None, trace=None, session_id=None ) -> _FakeSession: self.created_configs.append(config) + self.created_session_ids.append(session_id) return _FakeSession(self._result) diff --git a/services/oss/tests/pytest/unit/agent/test_invoke_handler.py b/services/oss/tests/pytest/unit/agent/test_invoke_handler.py index 59643e5754..a562eff74f 100644 --- a/services/oss/tests/pytest/unit/agent/test_invoke_handler.py +++ b/services/oss/tests/pytest/unit/agent/test_invoke_handler.py @@ -85,6 +85,18 @@ async def test_invoke_runs_backend_lifecycle(patched): assert backend.shutdown_calls == 1 # cleanup() tears the backend down +async def test_messages_session_id_reaches_session_config(patched): + backend, _ = patched + + await app._agent( + messages=[{"role": "user", "content": "hi"}], + parameters={"agent": {"harness": "pi"}}, + session_id="sess_request", + ) + + assert backend.created_session_ids == ["sess_request"] + + async def test_invoke_cross_harness_same_body_divergent_configs( monkeypatch, fake_backend ):