diff --git a/packages/agent/src/adapters/claude/conversion/sdk-to-acp.ts b/packages/agent/src/adapters/claude/conversion/sdk-to-acp.ts index ff6e55a02..847bf6f6b 100644 --- a/packages/agent/src/adapters/claude/conversion/sdk-to-acp.ts +++ b/packages/agent/src/adapters/claude/conversion/sdk-to-acp.ts @@ -618,6 +618,32 @@ export type ResultMessageHandlerResult = { }; }; +export type AgentErrorClassification = + | "upstream_stream_terminated" + | "upstream_connection_error" + | "agent_error"; + +/** + * Classify an error string surfaced by the Claude CLI via `is_error: true` + * result messages. Transient upstream-stream terminations (e.g. the fetch body + * from the LLM gateway is torn down mid-stream) are retriable; most other + * errors are not. + */ +export function classifyAgentError( + result: string | undefined, +): AgentErrorClassification { + if (!result) return "agent_error"; + const text = result.trim(); + // Anthropic SDK surfaces an undici fetch abort as "API Error: terminated". + if (/API Error:\s*terminated\b/i.test(text)) { + return "upstream_stream_terminated"; + } + if (/API Error:\s*Connection error\b/i.test(text)) { + return "upstream_connection_error"; + } + return "agent_error"; +} + export function handleResultMessage( message: SDKResultMessage, ): ResultMessageHandlerResult { @@ -636,9 +662,13 @@ export function handleResultMessage( return { shouldStop: true, stopReason: "max_tokens", usage }; } if (message.is_error) { + const classification = classifyAgentError(message.result); return { shouldStop: true, - error: RequestError.internalError(undefined, message.result), + error: RequestError.internalError( + { classification, result: message.result }, + message.result, + ), usage, }; } diff --git a/packages/agent/src/server/agent-server.ts b/packages/agent/src/server/agent-server.ts index 8698d533c..5bf8ea373 100644 --- a/packages/agent/src/server/agent-server.ts +++ b/packages/agent/src/server/agent-server.ts @@ -14,12 +14,17 @@ import { import { type ServerType, serve } from "@hono/node-server"; import { getCurrentBranch } from "@posthog/git/queries"; import { Hono } from "hono"; +import { z } from "zod"; import packageJson from "../../package.json" with { type: "json" }; import { POSTHOG_METHODS, POSTHOG_NOTIFICATIONS } from "../acp-extensions"; import { createAcpConnection, type InProcessAcpConnection, } from "../adapters/acp-connection"; +import { + type AgentErrorClassification, + classifyAgentError, +} from "../adapters/claude/conversion/sdk-to-acp"; import { selectRecentTurns } from "../adapters/claude/session/jsonl-hydration"; import type { PermissionMode } from "../execution-mode"; import { DEFAULT_CODEX_MODEL } from "../gateway-models"; @@ -51,6 +56,16 @@ import { type JwtPayload, JwtValidationError, validateJwt } from "./jwt"; import { jsonRpcRequestSchema, validateCommandParams } from "./schemas"; import type { AgentServerConfig } from "./types"; +const agentErrorClassificationSchema = z.enum([ + "upstream_stream_terminated", + "upstream_connection_error", + "agent_error", +]) satisfies z.ZodType; + +const errorWithClassificationSchema = z.object({ + data: z.object({ classification: agentErrorClassificationSchema }), +}); + type MessageCallback = (message: unknown) => void; class NdJsonTap { @@ -973,6 +988,41 @@ export class AgentServer { await this.sendInitialTaskMessage(payload, preTaskRun); } + private extractErrorClassification(error: unknown): { + classification: AgentErrorClassification; + message: string; + } { + const message = + error instanceof Error ? error.message : String(error ?? ""); + + // Prefer the structured `data` carried on RequestError if present. + const parsed = errorWithClassificationSchema.safeParse(error); + if (parsed.success) { + return { classification: parsed.data.data.classification, message }; + } + + return { classification: classifyAgentError(message), message }; + } + + private classifyAndSignalFailure( + payload: JwtPayload, + phase: "initial" | "resume", + error: unknown, + ): Promise { + const { classification, message } = this.extractErrorClassification(error); + const errorMessage = + classification === "upstream_stream_terminated" + ? "Upstream LLM stream terminated" + : classification === "upstream_connection_error" + ? "Upstream LLM connection error" + : message || "Agent error"; + this.logger.error(`send_${phase}_task_message_failed`, { + classification, + message, + }); + return this.signalTaskComplete(payload, "error", errorMessage); + } + private async sendInitialTaskMessage( payload: JwtPayload, prefetchedRun?: TaskRun | null, @@ -1087,7 +1137,7 @@ export class AgentServer { if (this.session) { await this.session.logWriter.flushAll(); } - await this.signalTaskComplete(payload, "error"); + await this.classifyAndSignalFailure(payload, "initial", error); } } @@ -1176,7 +1226,7 @@ export class AgentServer { if (this.session) { await this.session.logWriter.flushAll(); } - await this.signalTaskComplete(payload, "error"); + await this.classifyAndSignalFailure(payload, "resume", error); } } @@ -1657,6 +1707,7 @@ ${attributionInstructions} private async signalTaskComplete( payload: JwtPayload, stopReason: string, + errorMessage?: string, ): Promise { if (this.session?.payload.run_id === payload.run_id) { try { @@ -1684,7 +1735,7 @@ ${attributionInstructions} try { await this.posthogAPI.updateTaskRun(payload.task_id, payload.run_id, { status, - error_message: stopReason === "error" ? "Agent error" : undefined, + error_message: errorMessage ?? "Agent error", }); this.logger.info("Task completion signaled", { status, stopReason }); } catch (error) { diff --git a/packages/agent/src/server/question-relay.test.ts b/packages/agent/src/server/question-relay.test.ts index e3865e309..daf597ec1 100644 --- a/packages/agent/src/server/question-relay.test.ts +++ b/packages/agent/src/server/question-relay.test.ts @@ -1,5 +1,6 @@ import { type SetupServerApi, setupServer } from "msw/node"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { classifyAgentError } from "../adapters/claude/conversion/sdk-to-acp"; import type { PostHogAPIClient } from "../posthog-api"; import { createTestRepo, type TestRepo } from "../test/fixtures/api"; import { createPostHogHandlers } from "../test/mocks/msw-handlers"; @@ -49,7 +50,42 @@ const QUESTION_META = { ], }; +function createTransientPromptError(): Error & { + data: { classification: string; result: string }; +} { + const error = new Error("API Error: terminated") as Error & { + data: { classification: string; result: string }; + }; + error.data = { + classification: "upstream_stream_terminated", + result: "API Error: terminated", + }; + return error; +} + +function createTransientConnectionError(): Error & { + data: { classification: string; result: string }; +} { + const error = new Error("fetch failed") as Error & { + data: { classification: string; result: string }; + }; + error.data = { + classification: "upstream_connection_error", + result: "fetch failed", + }; + return error; +} + describe("Question relay", () => { + it.each([ + ["API Error: terminated", "upstream_stream_terminated"], + ["API Error: Connection error", "upstream_connection_error"], + ["something else", "agent_error"], + [undefined, "agent_error"], + ])("classifies %p as %s", (message, expected) => { + expect(classifyAgentError(message)).toBe(expected); + }); + let repo: TestRepo; let server: TestableAgentServer; let mswServer: SetupServerApi; @@ -514,5 +550,93 @@ describe("Question relay", () => { prompt: [{ type: "text", text: "original task description" }], }); }); + + it("does not replay a transient upstream termination before any session activity", async () => { + vi.spyOn(server.posthogAPI, "getTask").mockResolvedValue({ + id: "test-task-id", + title: "t", + description: "original task description", + } as unknown as Task); + vi.spyOn(server.posthogAPI, "getTaskRun").mockResolvedValue({ + id: "test-run-id", + task: "test-task-id", + state: {}, + } as unknown as TaskRun); + + const promptSpy = vi + .fn() + .mockRejectedValueOnce(createTransientPromptError()); + const updateTaskRunSpy = vi + .spyOn(server.posthogAPI, "updateTaskRun") + .mockResolvedValue({} as TaskRun); + server.session = { + payload: TEST_PAYLOAD, + acpSessionId: "acp-session", + clientConnection: { prompt: promptSpy }, + logWriter: { + flushAll: vi.fn().mockResolvedValue(undefined), + getFullAgentResponse: vi.fn().mockReturnValue(null), + resetTurnMessages: vi.fn(), + flush: vi.fn().mockResolvedValue(undefined), + isRegistered: vi.fn().mockReturnValue(true), + }, + }; + + await server.sendInitialTaskMessage(TEST_PAYLOAD); + + expect(promptSpy).toHaveBeenCalledTimes(1); + expect(updateTaskRunSpy).toHaveBeenCalledWith( + "test-task-id", + "test-run-id", + { + status: "failed", + error_message: "Upstream LLM stream terminated", + }, + ); + }); + + it("surfaces upstream connection errors with the connection-specific message", async () => { + vi.spyOn(server.posthogAPI, "getTask").mockResolvedValue({ + id: "test-task-id", + title: "t", + description: "original task description", + } as unknown as Task); + vi.spyOn(server.posthogAPI, "getTaskRun").mockResolvedValue({ + id: "test-run-id", + task: "test-task-id", + state: {}, + } as unknown as TaskRun); + + const promptSpy = vi.fn().mockImplementationOnce(async () => { + throw createTransientConnectionError(); + }); + const updateTaskRunSpy = vi + .spyOn(server.posthogAPI, "updateTaskRun") + .mockResolvedValue({} as TaskRun); + server.session = { + payload: TEST_PAYLOAD, + acpSessionId: "acp-session", + clientConnection: { prompt: promptSpy }, + logWriter: { + flushAll: vi.fn().mockResolvedValue(undefined), + getFullAgentResponse: vi.fn().mockReturnValue(null), + resetTurnMessages: vi.fn(), + flush: vi.fn().mockResolvedValue(undefined), + isRegistered: vi.fn().mockReturnValue(true), + }, + }; + + await server.sendInitialTaskMessage(TEST_PAYLOAD); + + expect(promptSpy).toHaveBeenCalledTimes(1); + expect(updateTaskRunSpy).toHaveBeenCalledWith( + "test-task-id", + "test-run-id", + { + status: "failed", + error_message: "Upstream LLM connection error", + }, + ); + }); }); });