Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions sdks/python/agenta/sdk/middlewares/running/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ async def _normalize_request(
1. If parameter name is 'request': passes the entire WorkflowServiceRequest
2. If parameter name matches DATA_FIELDS (like 'inputs', 'outputs', 'parameters'):
extracts that field from request.data
3. If parameter is **kwargs: includes all unconsumed DATA_FIELDS
4. Otherwise: looks up the parameter name in request.data.inputs dict
3. If parameter name is a supported top-level request field like 'session_id':
extracts that field from the request envelope
4. If parameter is **kwargs: includes all unconsumed DATA_FIELDS
5. Otherwise: looks up the parameter name in request.data.inputs dict

Args:
request: The workflow service request containing inputs and data
Expand Down Expand Up @@ -95,6 +97,10 @@ async def _normalize_request(
)
consumed.add(name)

elif name == "session_id":

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sits before the VAR_KEYWORD branch, so a handler that declares an explicit session_id argument gets the envelope value, while a **kwargs handler does not receive it. That asymmetry is intentional and pinned by the two new normalizer tests.

normalized[name] = request.session_id
consumed.add(name)

elif param.kind == inspect.Parameter.VAR_KEYWORD:
if request.data:
for f in self.DATA_FIELDS - consumed:
Expand Down
30 changes: 30 additions & 0 deletions sdks/python/oss/tests/pytest/unit/test_normalizer_passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,36 @@ def handler(parameters):

assert kwargs["parameters"] == {"correct_answer_key": "answer"}

@pytest.mark.asyncio
async def test_session_id_is_passed_to_explicit_handler_argument(self):
    """An explicit ``session_id`` handler argument gets the request's session id."""
    expected_id = "sess_request"

    # The parameter name is what the normalizer matches on, so it must stay
    # spelled exactly ``session_id``.
    def handler(session_id):
        return session_id

    envelope = WorkflowServiceRequest(
        session_id=expected_id,
        data=WorkflowRequestData(),
    )

    normalizer = NormalizerMiddleware()
    normalized = await normalizer._normalize_request(envelope, handler)

    assert normalized["session_id"] == expected_id

@pytest.mark.asyncio
async def test_session_id_is_not_added_to_var_kwargs(self):
    """Normalized kwargs for a ``**kwargs``-only handler have no ``session_id`` key."""

    # Only a catch-all keyword parameter: no explicit ``session_id`` argument.
    def handler(**kwargs):
        return kwargs

    payload = WorkflowRequestData(inputs={"prompt": "hi"})
    envelope = WorkflowServiceRequest(session_id="sess_request", data=payload)

    normalizer = NormalizerMiddleware()
    normalized = await normalizer._normalize_request(envelope, handler)

    assert "session_id" not in normalized


class TestAsyncGenerator:
@pytest.mark.asyncio
Expand Down
5 changes: 3 additions & 2 deletions services/agent/src/engines/pi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {
type HarnessCapabilities,
type ResolvedToolSpec,
type ToolCallbackContext,
resolveRunSessionId,
resolvePromptText,
} from "../protocol.ts";
import { EMPTY_OBJECT_SCHEMA } from "../tools/callback.ts";
Expand Down Expand Up @@ -299,7 +300,8 @@ export async function runPi(
}));

// Hand the session id + model to the extension so spans carry them.
otel.config.sessionId = session.sessionId;
const sessionId = resolveRunSessionId(request, session.sessionId);

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous code stamped session.sessionId onto OTel config and returned it later. Moving the resolved sessionId up here is what actually fixes the trace: the span now carries the platform id, not Pi's ephemeral one.

otel.config.sessionId = sessionId;
otel.config.provider = model.provider;
otel.config.requestModel = model.id;

Expand Down Expand Up @@ -329,7 +331,6 @@ export async function runPi(
await session.prompt(prompt);

const output = streamed.trim() || extractAssistantText(session.messages);
const sessionId = session.sessionId;
const stopReason = lastStopReason(session.messages);
const usage = otel.usage();

Expand Down
6 changes: 4 additions & 2 deletions services/agent/src/engines/rivet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import {
type ToolCallbackContext,
messageText,
resolvePromptText,
resolveRunSessionId,
} from "../protocol.ts";

const require = createRequire(import.meta.url);
Expand Down Expand Up @@ -817,6 +818,7 @@ export async function runRivet(
cwd,
sessionInit: { cwd, mcpServers },
});
const sessionId = resolveRunSessionId(request, session.id);

// Resolve the model first: when the harness rejects the requested id and keeps its
// own default (e.g. Claude ignores "gpt-5.5"), `model` is undefined and the chat span
Expand All @@ -838,7 +840,7 @@ export async function runRivet(

run.start({
prompt,
sessionId: session.id,
sessionId,
messages: [...priorMessages(request), { role: "user", content: prompt }],
});

Expand Down Expand Up @@ -922,7 +924,7 @@ export async function runRivet(
// `streamingDeltas` advertises end-to-end live deltas, which is only true when a live

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirm both replaced session.id sites (here in the result and at run.start) use the same resolved sessionId, so the trace id and the returned metadata stay consistent for a continued conversation.

// sink is wired. The one-shot path reports false even when the harness produces deltas.
capabilities: { ...capabilities, streamingDeltas: !!emit && capabilities.streamingDeltas },
sessionId: session.id,
sessionId,
model: model ?? request.model,
traceId: run.traceId(),
};
Expand Down
7 changes: 6 additions & 1 deletion services/agent/src/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ export interface AgentRunRequest {
harness?: string;
/** Sandbox for the rivet backend ("local" / "daytona"). */
sandbox?: string;
/** Continue a prior run by replaying its history. */
/** External conversation id. The cold runtime still receives history in `messages`. */
sessionId?: string;
/** Provider API keys as env vars ({OPENAI_API_KEY,...}), resolved from the vault. */
secrets?: Record<string, string>;
Expand Down Expand Up @@ -288,3 +288,8 @@ export function resolvePromptText(request: AgentRunRequest): string {
}
return "";
}

/** Prefer the platform conversation id; fall back to the harness's ephemeral id when it is missing or blank. */
export function resolveRunSessionId(request: AgentRunRequest, fallback: string): string {
return request.sessionId && request.sessionId.trim() ? request.sessionId : fallback;

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The blank-id guard matters: because of the request.sessionId.trim() check, an empty or whitespace-only id from the envelope is treated as absent, so the harness fallback wins instead of stamping an empty session id onto every span.

}
13 changes: 12 additions & 1 deletion services/agent/test/continuation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import assert from "node:assert/strict";

import { messageTranscript, buildTurnText } from "../src/engines/rivet.ts";
import type { AgentRunRequest, ContentBlock } from "../src/protocol.ts";
import {
resolveRunSessionId,
type AgentRunRequest,
type ContentBlock,
} from "../src/protocol.ts";

// --- messageTranscript -------------------------------------------------------
assert.equal(messageTranscript("hello"), "hello");
Expand All @@ -29,6 +33,13 @@ assert.equal(
"[send error: boom]",
);

// --- session id metadata ------------------------------------------------------
assert.equal(
resolveRunSessionId({ sessionId: "sess_platform" }, "runner-ephemeral"),
"sess_platform",
);
assert.equal(resolveRunSessionId({}, "runner-ephemeral"), "runner-ephemeral");

// --- buildTurnText keeps a resolved tool turn in the replay ------------------
{
const req: AgentRunRequest = {
Expand Down
2 changes: 2 additions & 0 deletions services/oss/src/agent/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ async def _agent(
messages: Optional[List[Any]] = None,
parameters: Optional[Dict] = None,
stream: Optional[bool] = None,
session_id: Optional[str] = None,
):
params = parameters or {}

Expand All @@ -98,6 +99,7 @@ async def _agent(
secrets=await resolve_harness_secrets(),
permission_policy=selection.permission_policy,
trace=trace_context(),
session_id=session_id,
builtin_names=resources.tools.builtin_names,
tool_specs=resources.tools.tool_specs,
tool_callback=resources.tools.tool_callback,
Expand Down
2 changes: 2 additions & 0 deletions services/oss/tests/pytest/unit/agent/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
self.shutdown_calls = 0
# Every harness-shaped config that reached the backend boundary, in call order.
self.created_configs: list = []
self.created_session_ids: list[Optional[str]] = []

async def setup(self) -> None:
self.setup_calls += 1
Expand All @@ -99,6 +100,7 @@ async def create_session(
self, sandbox, config, *, harness, secrets=None, trace=None, session_id=None
) -> _FakeSession:
self.created_configs.append(config)
self.created_session_ids.append(session_id)
return _FakeSession(self._result)


Expand Down
12 changes: 12 additions & 0 deletions services/oss/tests/pytest/unit/agent/test_invoke_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ async def test_invoke_runs_backend_lifecycle(patched):
assert backend.shutdown_calls == 1 # cleanup() tears the backend down


async def test_messages_session_id_reaches_session_config(patched):
    """A ``session_id`` passed to ``app._agent`` is the only one the backend records."""
    recorder, _ = patched
    expected_id = "sess_request"

    user_turn = {"role": "user", "content": "hi"}
    agent_params = {"agent": {"harness": "pi"}}

    await app._agent(
        messages=[user_turn],
        parameters=agent_params,
        session_id=expected_id,
    )

    # Exactly one session was created, and it carried the caller's id.
    assert recorder.created_session_ids == [expected_id]


async def test_invoke_cross_harness_same_body_divergent_configs(
monkeypatch, fake_backend
):
Expand Down
Loading