-
Notifications
You must be signed in to change notification settings - Fork 552
feat(sdk): advertise Vercel messages protocol headers #4769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,20 @@ | |
| # An opaque, project-scoped session id (RFC §4.1): bounded length, restricted charset. | ||
| _SESSION_ID_RE = re.compile(r"^[A-Za-z0-9._:-]{1,128}$") | ||
|
|
||
| VERCEL_MESSAGE_PROTOCOL = "vercel" | ||
| VERCEL_MESSAGE_PROTOCOL_VERSION = "v1" | ||
| VERCEL_MESSAGE_PROTOCOL_HEADERS = { | ||
| "x-ag-messages-format": VERCEL_MESSAGE_PROTOCOL, | ||
| "x-ag-messages-version": VERCEL_MESSAGE_PROTOCOL_VERSION, | ||
| } | ||
|
|
||
|
|
||
| def set_vercel_message_protocol_headers(response: Response) -> Response: | ||
| """Stamp the default agent ``/messages`` protocol identity on an HTTP response.""" | ||
| for key, value in VERCEL_MESSAGE_PROTOCOL_HEADERS.items(): | ||
| response.headers.setdefault(key, value) | ||
| return response | ||
|
|
||
|
|
||
| def resolve_session_id(session_id: Optional[str]) -> Optional[str]: | ||
| """Mint a new id when absent, echo a valid one, or return ``None`` when invalid.""" | ||
|
|
@@ -69,9 +83,13 @@ async def messages_endpoint(req: Request, request: WorkflowInvokeRequest): | |
|
|
||
| session_id = resolve_session_id(request.session_id) | ||
| if session_id is None: | ||
| return JSONResponse( | ||
| status_code=400, | ||
| content={"detail": "session_id violates the allowed charset/length"}, | ||
| return set_vercel_message_protocol_headers( | ||
| JSONResponse( | ||
| status_code=400, | ||
| content={ | ||
| "detail": "session_id violates the allowed charset/length" | ||
| }, | ||
| ) | ||
| ) | ||
|
|
||
| try: | ||
|
|
@@ -104,22 +122,28 @@ async def messages_endpoint(req: Request, request: WorkflowInvokeRequest): | |
| and response.status.code is not None | ||
| and response.status.code >= 400 | ||
| ): | ||
| return make_json_response(response) | ||
| return set_vercel_message_protocol_headers(make_json_response(response)) | ||
|
|
||
| if want_stream: | ||
| if not isinstance(response, WorkflowStreamingResponse): | ||
| return make_not_acceptable_response(str(requested), response) | ||
| return set_vercel_message_protocol_headers( | ||
| make_not_acceptable_response(str(requested), response) | ||
| ) | ||
| inject_stream_session_id(response, session_id) | ||
| return make_stream_response(response, "vercel") | ||
| return set_vercel_message_protocol_headers( | ||
| make_stream_response(response, "vercel") | ||
| ) | ||
|
|
||
| if not isinstance(response, WorkflowBatchResponse): | ||
| return make_not_acceptable_response( | ||
| requested or "application/json", response | ||
| return set_vercel_message_protocol_headers( | ||
| make_not_acceptable_response( | ||
| requested or "application/json", response | ||
| ) | ||
| ) | ||
| return make_json_response(response) | ||
| return set_vercel_message_protocol_headers(make_json_response(response)) | ||
|
|
||
| except Exception as exception: | ||
| return await handle_failure(exception) | ||
| return set_vercel_message_protocol_headers(await handle_failure(exception)) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The exception path is the easiest exit to miss: |
||
|
|
||
| return messages_endpoint | ||
|
|
||
|
|
@@ -140,8 +164,8 @@ async def load_session_endpoint(req: Request, request: LoadSessionRequest): | |
| for idx, message in enumerate(messages, start=1) | ||
| ], | ||
| ) | ||
| return JSONResponse( | ||
| content=response.model_dump(mode="json"), | ||
| return set_vercel_message_protocol_headers( | ||
| JSONResponse(content=response.model_dump(mode="json")) | ||
| ) | ||
|
|
||
| return load_session_endpoint | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,8 @@ | |
|
|
||
| from agenta.sdk.agents import Message | ||
| from agenta.sdk.agents.adapters.vercel.routing import ( | ||
| VERCEL_MESSAGE_PROTOCOL, | ||
| VERCEL_MESSAGE_PROTOCOL_VERSION, | ||
| inject_stream_session_id, | ||
| make_load_session_endpoint, | ||
| resolve_session_id, | ||
|
|
@@ -67,6 +69,11 @@ async def base(): | |
| _UI_MESSAGE = {"role": "user", "parts": [{"type": "text", "text": "hello"}]} | ||
|
|
||
|
|
||
| def _assert_vercel_message_protocol(response): | ||
| assert response.headers["x-ag-messages-format"] == VERCEL_MESSAGE_PROTOCOL | ||
| assert response.headers["x-ag-messages-version"] == VERCEL_MESSAGE_PROTOCOL_VERSION | ||
|
|
||
|
|
||
| def _build_client() -> TestClient: | ||
| app = FastAPI() | ||
|
|
||
|
|
@@ -78,7 +85,13 @@ async def _fake_auth(request, call_next): | |
| return await call_next(request) | ||
|
|
||
| @route("/", app=app, flags={"is_agent": True}) | ||
| async def agent(messages=None, inputs=None, parameters=None, stream=None): | ||
| async def agent( | ||
| messages=None, | ||
| inputs=None, | ||
| parameters=None, | ||
| stream=None, | ||
| session_id=None, | ||
| ): | ||
| if stream: | ||
|
|
||
| async def gen(): | ||
|
|
@@ -89,7 +102,12 @@ async def gen(): | |
| yield {"type": "finish"} | ||
|
|
||
| return gen() | ||
| return {"role": "assistant", "content": "hi", "echoed": messages} | ||
| return { | ||
| "role": "assistant", | ||
| "content": "hi", | ||
| "echoed": messages, | ||
| "session_id": session_id, | ||
| } | ||
|
|
||
| return TestClient(app) | ||
|
|
||
|
|
@@ -145,9 +163,11 @@ def client(): | |
| def test_messages_json_mints_session_and_folds_conversation(client): | ||
| res = client.post("/messages", json={"data": {"messages": [_UI_MESSAGE]}}) | ||
| assert res.status_code == 200 | ||
| _assert_vercel_message_protocol(res) | ||
| body = res.json() | ||
| assert body["session_id"].startswith("sess_") | ||
| assert body["data"]["outputs"]["content"] == "hi" | ||
| assert body["data"]["outputs"]["session_id"] == body["session_id"] | ||
| # The Vercel UIMessage was folded to a neutral {role, content} message for the handler. | ||
| assert body["data"]["outputs"]["echoed"] == [{"role": "user", "content": "hello"}] | ||
|
|
||
|
|
@@ -158,7 +178,9 @@ def test_messages_echoes_supplied_session_id(client): | |
| json={"session_id": "sess_keep", "data": {"messages": [_UI_MESSAGE]}}, | ||
| ) | ||
| assert res.status_code == 200 | ||
| _assert_vercel_message_protocol(res) | ||
| assert res.json()["session_id"] == "sess_keep" | ||
| assert res.json()["data"]["outputs"]["session_id"] == "sess_keep" | ||
|
|
||
|
|
||
| def test_messages_sse_streams_with_done_and_session_in_start(client): | ||
|
|
@@ -168,6 +190,7 @@ def test_messages_sse_streams_with_done_and_session_in_start(client): | |
| json={"session_id": "sess_abc", "data": {"messages": [_UI_MESSAGE]}}, | ||
| ) | ||
| assert res.status_code == 200 | ||
| _assert_vercel_message_protocol(res) | ||
| assert res.headers["x-vercel-ai-ui-message-stream"] == "v1" | ||
| text = res.text | ||
| assert '"sessionId": "sess_abc"' in text # stamped onto the start part | ||
|
|
@@ -208,6 +231,7 @@ def test_messages_sse_preserves_json_error_before_stream(): | |
| ) | ||
|
|
||
| assert response.status_code == 500 | ||
| _assert_vercel_message_protocol(response) | ||
| assert response.headers["content-type"].startswith("application/json") | ||
| assert "x-vercel-ai-ui-message-stream" not in response.headers | ||
| body = response.json() | ||
|
|
@@ -222,11 +246,13 @@ def test_messages_rejects_invalid_session_id(client): | |
| "/messages", json={"session_id": "bad id!", "data": {"messages": []}} | ||
| ) | ||
| assert res.status_code == 400 | ||
| _assert_vercel_message_protocol(res) | ||
|
|
||
|
|
||
| def test_load_session_returns_stub_history(client): | ||
| res = client.post("/load-session", json={"session_id": "sess_abc"}) | ||
| assert res.status_code == 200 | ||
| _assert_vercel_message_protocol(res) | ||
| assert res.json() == {"session_id": "sess_abc", "messages": []} | ||
|
|
||
|
|
||
|
|
@@ -244,6 +270,8 @@ async def save_turn(self, session_id, *, messages, result=None): | |
| response = await endpoint(None, LoadSessionRequest(session_id="sess_abc")) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert response.headers["x-ag-messages-format"] == VERCEL_MESSAGE_PROTOCOL | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good that the direct load-session test asserts the literal header strings rather than the constants. That pins the wire values, so a rename of the constant can't silently change the contract a client depends on. |
||
| assert response.headers["x-ag-messages-version"] == VERCEL_MESSAGE_PROTOCOL_VERSION | ||
| assert json.loads(response.body) == { | ||
| "session_id": "sess_abc", | ||
| "messages": [ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
setdefaultis the load-bearing choice here: it lets a future path stamp a different format/version (or a non-Vercel adapter reuse this helper) without this overwriting it. Worth a one-line comment so nobody 'fixes' it to a plain assignment.