Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sdks/python/agenta/sdk/agents/adapters/vercel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
vercel_ui_messages_to_messages,
)
from .routing import (
VERCEL_MESSAGE_PROTOCOL,
VERCEL_MESSAGE_PROTOCOL_HEADERS,
VERCEL_MESSAGE_PROTOCOL_VERSION,
inject_stream_session_id,
register_agent_message_routes,
resolve_session_id,
set_vercel_message_protocol_headers,
)
from .sse import VERCEL_UI_MESSAGE_STREAM_HEADERS, vercel_sse_stream
from .stream import agent_run_to_vercel_parts, ui_message_stream
Expand All @@ -27,6 +31,10 @@
"vercel_sse_stream",
"resolve_session_id",
"inject_stream_session_id",
"VERCEL_MESSAGE_PROTOCOL",
"VERCEL_MESSAGE_PROTOCOL_VERSION",
"VERCEL_MESSAGE_PROTOCOL_HEADERS",
"set_vercel_message_protocol_headers",
"register_agent_message_routes",
# Former flat-module names.
"from_ui_messages",
Expand Down
48 changes: 36 additions & 12 deletions sdks/python/agenta/sdk/agents/adapters/vercel/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@
# An opaque, project-scoped session id (RFC §4.1): bounded length, restricted charset.
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9._:-]{1,128}$")

VERCEL_MESSAGE_PROTOCOL = "vercel"
VERCEL_MESSAGE_PROTOCOL_VERSION = "v1"
VERCEL_MESSAGE_PROTOCOL_HEADERS = {
"x-ag-messages-format": VERCEL_MESSAGE_PROTOCOL,
"x-ag-messages-version": VERCEL_MESSAGE_PROTOCOL_VERSION,
}


def set_vercel_message_protocol_headers(response: Response) -> Response:
"""Stamp the default agent ``/messages`` protocol identity on an HTTP response."""
for key, value in VERCEL_MESSAGE_PROTOCOL_HEADERS.items():
response.headers.setdefault(key, value)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setdefault is the load-bearing choice here: it lets a future path stamp a different format/version (or a non-Vercel adapter reuse this helper) without this overwriting it. Worth a one-line comment so nobody 'fixes' it to a plain assignment.

return response


def resolve_session_id(session_id: Optional[str]) -> Optional[str]:
"""Mint a new id when absent, echo a valid one, or return ``None`` when invalid."""
Expand Down Expand Up @@ -69,9 +83,13 @@ async def messages_endpoint(req: Request, request: WorkflowInvokeRequest):

session_id = resolve_session_id(request.session_id)
if session_id is None:
return JSONResponse(
status_code=400,
content={"detail": "session_id violates the allowed charset/length"},
return set_vercel_message_protocol_headers(
JSONResponse(
status_code=400,
content={
"detail": "session_id violates the allowed charset/length"
},
)
)

try:
Expand Down Expand Up @@ -104,22 +122,28 @@ async def messages_endpoint(req: Request, request: WorkflowInvokeRequest):
and response.status.code is not None
and response.status.code >= 400
):
return make_json_response(response)
return set_vercel_message_protocol_headers(make_json_response(response))

if want_stream:
if not isinstance(response, WorkflowStreamingResponse):
return make_not_acceptable_response(str(requested), response)
return set_vercel_message_protocol_headers(
make_not_acceptable_response(str(requested), response)
)
inject_stream_session_id(response, session_id)
return make_stream_response(response, "vercel")
return set_vercel_message_protocol_headers(
make_stream_response(response, "vercel")
)

if not isinstance(response, WorkflowBatchResponse):
return make_not_acceptable_response(
requested or "application/json", response
return set_vercel_message_protocol_headers(
make_not_acceptable_response(
requested or "application/json", response
)
)
return make_json_response(response)
return set_vercel_message_protocol_headers(make_json_response(response))

except Exception as exception:
return await handle_failure(exception)
return set_vercel_message_protocol_headers(await handle_failure(exception))

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The exception path is the easiest exit to miss: handle_failure already builds the response, so wrapping it here is the only way these headers reach an error reply. A new early-return added above this except would silently ship without the protocol identity.


return messages_endpoint

Expand All @@ -140,8 +164,8 @@ async def load_session_endpoint(req: Request, request: LoadSessionRequest):
for idx, message in enumerate(messages, start=1)
],
)
return JSONResponse(
content=response.model_dump(mode="json"),
return set_vercel_message_protocol_headers(
JSONResponse(content=response.model_dump(mode="json"))
)

return load_session_endpoint
Expand Down
32 changes: 30 additions & 2 deletions sdks/python/oss/tests/pytest/utils/test_messages_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from agenta.sdk.agents import Message
from agenta.sdk.agents.adapters.vercel.routing import (
VERCEL_MESSAGE_PROTOCOL,
VERCEL_MESSAGE_PROTOCOL_VERSION,
inject_stream_session_id,
make_load_session_endpoint,
resolve_session_id,
Expand Down Expand Up @@ -67,6 +69,11 @@ async def base():
_UI_MESSAGE = {"role": "user", "parts": [{"type": "text", "text": "hello"}]}


def _assert_vercel_message_protocol(response):
assert response.headers["x-ag-messages-format"] == VERCEL_MESSAGE_PROTOCOL
assert response.headers["x-ag-messages-version"] == VERCEL_MESSAGE_PROTOCOL_VERSION


def _build_client() -> TestClient:
app = FastAPI()

Expand All @@ -78,7 +85,13 @@ async def _fake_auth(request, call_next):
return await call_next(request)

@route("/", app=app, flags={"is_agent": True})
async def agent(messages=None, inputs=None, parameters=None, stream=None):
async def agent(
messages=None,
inputs=None,
parameters=None,
stream=None,
session_id=None,
):
if stream:

async def gen():
Expand All @@ -89,7 +102,12 @@ async def gen():
yield {"type": "finish"}

return gen()
return {"role": "assistant", "content": "hi", "echoed": messages}
return {
"role": "assistant",
"content": "hi",
"echoed": messages,
"session_id": session_id,
}

return TestClient(app)

Expand Down Expand Up @@ -145,9 +163,11 @@ def client():
def test_messages_json_mints_session_and_folds_conversation(client):
res = client.post("/messages", json={"data": {"messages": [_UI_MESSAGE]}})
assert res.status_code == 200
_assert_vercel_message_protocol(res)
body = res.json()
assert body["session_id"].startswith("sess_")
assert body["data"]["outputs"]["content"] == "hi"
assert body["data"]["outputs"]["session_id"] == body["session_id"]
# The Vercel UIMessage was folded to a neutral {role, content} message for the handler.
assert body["data"]["outputs"]["echoed"] == [{"role": "user", "content": "hello"}]

Expand All @@ -158,7 +178,9 @@ def test_messages_echoes_supplied_session_id(client):
json={"session_id": "sess_keep", "data": {"messages": [_UI_MESSAGE]}},
)
assert res.status_code == 200
_assert_vercel_message_protocol(res)
assert res.json()["session_id"] == "sess_keep"
assert res.json()["data"]["outputs"]["session_id"] == "sess_keep"


def test_messages_sse_streams_with_done_and_session_in_start(client):
Expand All @@ -168,6 +190,7 @@ def test_messages_sse_streams_with_done_and_session_in_start(client):
json={"session_id": "sess_abc", "data": {"messages": [_UI_MESSAGE]}},
)
assert res.status_code == 200
_assert_vercel_message_protocol(res)
assert res.headers["x-vercel-ai-ui-message-stream"] == "v1"
text = res.text
assert '"sessionId": "sess_abc"' in text # stamped onto the start part
Expand Down Expand Up @@ -208,6 +231,7 @@ def test_messages_sse_preserves_json_error_before_stream():
)

assert response.status_code == 500
_assert_vercel_message_protocol(response)
assert response.headers["content-type"].startswith("application/json")
assert "x-vercel-ai-ui-message-stream" not in response.headers
body = response.json()
Expand All @@ -222,11 +246,13 @@ def test_messages_rejects_invalid_session_id(client):
"/messages", json={"session_id": "bad id!", "data": {"messages": []}}
)
assert res.status_code == 400
_assert_vercel_message_protocol(res)


def test_load_session_returns_stub_history(client):
res = client.post("/load-session", json={"session_id": "sess_abc"})
assert res.status_code == 200
_assert_vercel_message_protocol(res)
assert res.json() == {"session_id": "sess_abc", "messages": []}


Expand All @@ -244,6 +270,8 @@ async def save_turn(self, session_id, *, messages, result=None):
response = await endpoint(None, LoadSessionRequest(session_id="sess_abc"))

assert response.status_code == 200
assert response.headers["x-ag-messages-format"] == VERCEL_MESSAGE_PROTOCOL

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good that the direct load-session test asserts the literal header strings rather than the constants. That pins the wire values, so a rename of the constant can't silently change the contract a client depends on.

assert response.headers["x-ag-messages-version"] == VERCEL_MESSAGE_PROTOCOL_VERSION
assert json.loads(response.body) == {
"session_id": "sess_abc",
"messages": [
Expand Down
Loading