diff --git a/CHANGELOG.md b/CHANGELOG.md index 03a1e921..b32a1916 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,24 @@ All notable changes to this project will be documented in this file. # Changelog +# Changelog + +# Changelog + +## [0.25.1](https://github.com/databricks/appkit/compare/v0.25.0...v0.25.1) (2026-04-27) + +### appkit + +* **appkit:** check isRetryable before retrying in interceptor ([#276](https://github.com/databricks/appkit/issues/276)) ([1c994a6](https://github.com/databricks/appkit/commit/1c994a6d99f397b56e90f1b53df06a61f02b9e82)) + + +## [0.25.0](https://github.com/databricks/appkit/compare/v0.24.0...v0.25.0) (2026-04-23) + +### files + +* **files:** per-volume in-app policy enforcement ([#197](https://github.com/databricks/appkit/issues/197)) ([f54dca5](https://github.com/databricks/appkit/commit/f54dca5da5af5368c7bcb18745715b54a99d47e9)) + + ## [0.24.0](https://github.com/databricks/appkit/compare/v0.23.0...v0.24.0) (2026-04-20) * add AST extraction to serving type generator and move types to shared/ ([#279](https://github.com/databricks/appkit/issues/279)) ([422afb3](https://github.com/databricks/appkit/commit/422afb38aa73f8adb94e091225dc3381bd92cfcd)) diff --git a/apps/dev-playground/server/index.ts b/apps/dev-playground/server/index.ts index 8a77b76c..94f1cc12 100644 --- a/apps/dev-playground/server/index.ts +++ b/apps/dev-playground/server/index.ts @@ -50,7 +50,7 @@ const adminOnly: FilePolicy = (action, _resource, user) => { createApp({ plugins: [ - server({ autoStart: false }), + server(), reconnect(), telemetryExamples(), analytics({}), @@ -95,9 +95,8 @@ createApp({ // }), ], ...(process.env.APPKIT_E2E_TEST && { client: createMockClient() }), -}).then((appkit) => { - appkit.server - .extend((app) => { + onPluginsReady(appkit) { + appkit.server.extend((app) => { app.get("/sp", (_req, res) => { appkit.analytics .query("SELECT * FROM samples.nyctaxi.trips;") @@ -195,9 +194,9 @@ createApp({ results, }); }); - }) - .start(); -}); + 
}); + }, +}).catch(console.error); type ProbeResult = { volume: string; diff --git a/apps/dev-playground/shared/appkit-types/analytics.d.ts b/apps/dev-playground/shared/appkit-types/analytics.d.ts index 0e0ae0b0..43666dd0 100644 --- a/apps/dev-playground/shared/appkit-types/analytics.d.ts +++ b/apps/dev-playground/shared/appkit-types/analytics.d.ts @@ -119,10 +119,10 @@ declare module "@databricks/appkit-ui/react" { result: Array<{ /** @sqlType STRING */ string_value: string; - /** @sqlType STRING */ - number_value: string; - /** @sqlType STRING */ - boolean_value: string; + /** @sqlType INT */ + number_value: number; + /** @sqlType BOOLEAN */ + boolean_value: boolean; /** @sqlType STRING */ date_value: string; /** @sqlType STRING */ diff --git a/docs/docs/api/appkit/Class.ServerError.md b/docs/docs/api/appkit/Class.ServerError.md index d3dce68e..cce86cad 100644 --- a/docs/docs/api/appkit/Class.ServerError.md +++ b/docs/docs/api/appkit/Class.ServerError.md @@ -6,7 +6,6 @@ Use for server start/stop issues, configuration conflicts, etc. 
## Example ```typescript -throw new ServerError("Cannot get server when autoStart is true"); throw new ServerError("Server not started"); ``` @@ -151,26 +150,6 @@ Create a human-readable string representation *** -### autoStartConflict() - -```ts -static autoStartConflict(operation: string): ServerError; -``` - -Create a server error for autoStart conflict - -#### Parameters - -| Parameter | Type | -| ------ | ------ | -| `operation` | `string` | - -#### Returns - -`ServerError` - -*** - ### clientDirectoryNotFound() ```ts diff --git a/docs/docs/api/appkit/Function.createApp.md b/docs/docs/api/appkit/Function.createApp.md index cb703386..6a0b7cb2 100644 --- a/docs/docs/api/appkit/Function.createApp.md +++ b/docs/docs/api/appkit/Function.createApp.md @@ -4,6 +4,7 @@ function createApp(config: { cache?: CacheConfig; client?: WorkspaceClient; + onPluginsReady?: (appkit: PluginMap) => void | Promise; plugins?: T; telemetry?: TelemetryConfig; }): Promise>; @@ -13,6 +14,9 @@ Bootstraps AppKit with the provided configuration. Initializes telemetry, cache, and service context, then registers plugins in phase order (core, normal, deferred) and awaits their setup. +If a `onPluginsReady` callback is provided it runs after plugin setup but +before the server starts, giving you access to the full appkit handle +for registering custom routes or performing async setup. The returned object maps each plugin name to its `exports()` API, with an `asUser(req)` method for user-scoped execution. @@ -26,9 +30,10 @@ with an `asUser(req)` method for user-scoped execution. 
| Parameter | Type | | ------ | ------ | -| `config` | \{ `cache?`: [`CacheConfig`](Interface.CacheConfig.md); `client?`: `WorkspaceClient`; `plugins?`: `T`; `telemetry?`: [`TelemetryConfig`](Interface.TelemetryConfig.md); \} | +| `config` | \{ `cache?`: [`CacheConfig`](Interface.CacheConfig.md); `client?`: `WorkspaceClient`; `onPluginsReady?`: (`appkit`: `PluginMap`\<`T`\>) => `void` \| `Promise`\<`void`\>; `plugins?`: `T`; `telemetry?`: [`TelemetryConfig`](Interface.TelemetryConfig.md); \} | | `config.cache?` | [`CacheConfig`](Interface.CacheConfig.md) | | `config.client?` | `WorkspaceClient` | +| `config.onPluginsReady?` | (`appkit`: `PluginMap`\<`T`\>) => `void` \| `Promise`\<`void`\> | | `config.plugins?` | `T` | | `config.telemetry?` | [`TelemetryConfig`](Interface.TelemetryConfig.md) | @@ -51,12 +56,12 @@ await createApp({ ```ts import { createApp, server, analytics } from "@databricks/appkit"; -const appkit = await createApp({ - plugins: [server({ autoStart: false }), analytics({})], -}); - -appkit.server.extend((app) => { - app.get("/custom", (_req, res) => res.json({ ok: true })); +await createApp({ + plugins: [server(), analytics({})], + onPluginsReady(appkit) { + appkit.server.extend((app) => { + app.get("/custom", (_req, res) => res.json({ ok: true })); + }); + }, }); -await appkit.server.start(); ``` diff --git a/docs/docs/faq.md b/docs/docs/faq.md index b3fd50e1..41667cae 100644 --- a/docs/docs/faq.md +++ b/docs/docs/faq.md @@ -16,7 +16,7 @@ AppKit provides built-in integrations with the following Databricks services via | [Lakebase](./plugins/lakebase) | Lakebase Autoscaling (PostgreSQL) | Relational database access via standard pg.Pool with automatic OAuth token refresh | | [Genie](./plugins/genie) | AI/BI Genie Spaces | Natural language data queries with conversation management and streaming | | [Files](./plugins/files) | Unity Catalog Volumes | Multi-volume file operations (list, read, upload, download, delete, preview) | -| 
[Serving](./plugins/serving) | Model Serving | Authenticated proxy to Model Serving endpoints with invoke and streaming support | +| [Model Serving](./plugins/model-serving) | Model Serving | Authenticated proxy to Model Serving endpoints with invoke and streaming support | | [Server](./plugins/server) | N/A | Express HTTP server with static file serving, Vite dev mode, and plugin route injection | Stay tuned for new plugins as we constantly expand integrations! diff --git a/docs/docs/plugins/serving.md b/docs/docs/plugins/model-serving.md similarity index 99% rename from docs/docs/plugins/serving.md rename to docs/docs/plugins/model-serving.md index e2ee052b..00f60245 100644 --- a/docs/docs/plugins/serving.md +++ b/docs/docs/plugins/model-serving.md @@ -2,7 +2,7 @@ sidebar_position: 7 --- -# Serving plugin +# Model Serving plugin Provides an authenticated proxy to [Databricks Model Serving](https://docs.databricks.com/aws/en/machine-learning/model-serving) endpoints, with invoke and streaming support. 
diff --git a/docs/docs/plugins/server.md b/docs/docs/plugins/server.md index 389828dc..6cfaa7b7 100644 --- a/docs/docs/plugins/server.md +++ b/docs/docs/plugins/server.md @@ -36,22 +36,38 @@ await createApp({ }); ``` -## Manual server start example +## Custom routes example -When you need to extend Express with custom routes: +Use the `onPluginsReady` callback to extend Express with custom routes before the server starts: ```ts import { createApp, server } from "@databricks/appkit"; -const appkit = await createApp({ - plugins: [server({ autoStart: false })], +await createApp({ + plugins: [server()], + onPluginsReady(appkit) { + appkit.server.extend((app) => { + app.get("/custom", (_req, res) => res.json({ ok: true })); + }); + }, }); +``` -appkit.server.extend((app) => { - app.get("/custom", (_req, res) => res.json({ ok: true })); -}); +The `onPluginsReady` callback also supports async operations: -await appkit.server.start(); +```ts +await createApp({ + plugins: [server()], + async onPluginsReady(appkit) { + const pool = await initializeDatabase(); + appkit.server.extend((app) => { + app.get("/data", async (_req, res) => { + const result = await pool.query("SELECT 1"); + res.json(result); + }); + }); + }, +}); ``` ## Configuration options @@ -64,7 +80,6 @@ await createApp({ server({ port: 8000, // default: Number(process.env.DATABRICKS_APP_PORT) || 8000 host: "0.0.0.0", // default: process.env.FLASK_RUN_HOST || "0.0.0.0" - autoStart: true, // default: true staticPath: "dist", // optional: force a specific static directory }), ], diff --git a/knip.json b/knip.json index b777d8c2..878dd3f5 100644 --- a/knip.json +++ b/knip.json @@ -7,7 +7,9 @@ "docs" ], "workspaces": { - "packages/appkit": {}, + "packages/appkit": { + "ignoreDependencies": ["@langchain/core", "ai"] + }, "packages/appkit-ui": { "ignoreDependencies": ["tailwindcss", "tw-animate-css"] } @@ -17,6 +19,12 @@ "**/*.example.tsx", "**/*.css", "packages/appkit/src/plugins/vector-search/**", + 
"packages/appkit/src/plugin/index.ts", + "packages/appkit/src/plugin/to-plugin.ts", + "packages/appkit/src/plugins/agents/index.ts", + "packages/appkit/src/plugins/agents/tools/index.ts", + "packages/appkit/src/plugins/agents/from-plugin.ts", + "packages/appkit/src/plugins/agents/load-agents.ts", "template/**", "tools/**", "docs/**" diff --git a/packages/appkit-ui/package.json b/packages/appkit-ui/package.json index fa2953b1..beaca8b9 100644 --- a/packages/appkit-ui/package.json +++ b/packages/appkit-ui/package.json @@ -1,7 +1,7 @@ { "name": "@databricks/appkit-ui", "type": "module", - "version": "0.24.0", + "version": "0.25.1", "license": "Apache-2.0", "repository": { "type": "git", diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 146be5a9..83e62814 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -1,7 +1,7 @@ { "name": "@databricks/appkit", "type": "module", - "version": "0.24.0", + "version": "0.25.1", "main": "./dist/index.js", "types": "./dist/index.d.ts", "packageManager": "pnpm@10.21.0", @@ -29,6 +29,18 @@ "development": "./src/index.ts", "default": "./dist/index.js" }, + "./agents/vercel-ai": { + "development": "./src/agents/vercel-ai.ts", + "default": "./dist/agents/vercel-ai.js" + }, + "./agents/langchain": { + "development": "./src/agents/langchain.ts", + "default": "./dist/agents/langchain.js" + }, + "./agents/databricks": { + "development": "./src/agents/databricks.ts", + "default": "./dist/agents/databricks.js" + }, "./type-generator": { "types": "./dist/type-generator/index.d.ts", "development": "./src/type-generator/index.ts", @@ -71,20 +83,38 @@ "@types/semver": "7.7.1", "dotenv": "16.6.1", "express": "4.22.0", + "js-yaml": "^4.1.1", "obug": "2.1.1", "pg": "8.18.0", "picocolors": "1.1.1", "semver": "7.7.3", "shared": "workspace:*", "vite": "npm:rolldown-vite@7.1.14", - "ws": "8.18.3" + "ws": "8.18.3", + "zod": "^4.0.0" + }, + "peerDependencies": { + "@langchain/core": ">=0.3.0", + "ai": 
">=4.0.0" + }, + "peerDependenciesMeta": { + "ai": { + "optional": true + }, + "@langchain/core": { + "optional": true + } }, "devDependencies": { + "@ai-sdk/openai": "4.0.0-beta.27", + "@langchain/core": "^1.1.39", "@types/express": "4.17.25", + "@types/js-yaml": "^4.0.9", "@types/json-schema": "7.0.15", "@types/pg": "8.16.0", "@types/ws": "8.18.1", - "@vitejs/plugin-react": "5.1.1" + "@vitejs/plugin-react": "5.1.1", + "ai": "7.0.0-beta.76" }, "overrides": { "vite": "npm:rolldown-vite@7.1.14" @@ -93,6 +123,9 @@ "publishConfig": { "exports": { ".": "./dist/index.js", + "./agents/vercel-ai": "./dist/agents/vercel-ai.js", + "./agents/langchain": "./dist/agents/langchain.js", + "./agents/databricks": "./dist/agents/databricks.js", "./dist/shared/src/plugin": "./dist/shared/src/plugin.d.ts", "./type-generator": "./dist/type-generator/index.js", "./package.json": "./package.json" diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts new file mode 100644 index 00000000..6cc98ca4 --- /dev/null +++ b/packages/appkit/src/agents/databricks.ts @@ -0,0 +1,775 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, +} from "shared"; +import { stream as servingStream } from "../connectors/serving/client"; + +/** + * Transport shim: given an OpenAI-compatible request body, returns the raw + * SSE byte stream from the serving endpoint. Injected at construction time so + * callers can swap in the workspace SDK (factory paths), a bare `fetch` + * (the raw constructor), or a test fake. + */ +type StreamBody = ( + body: Record, + signal?: AbortSignal, +) => Promise>; + +/** + * Escape-hatch options: provide an `endpointUrl` + `authenticate()` and the + * adapter uses a bare `fetch()` to call it. Useful for tests and for pointing + * the adapter at non-workspace endpoints (reverse proxies, mocks). 
+ */ +interface RawFetchAdapterOptions { + endpointUrl: string; + authenticate: () => Promise>; + maxSteps?: number; + maxTokens?: number; +} + +/** + * Preferred options: caller provides the transport function directly. + * The `fromServingEndpoint` / `fromModelServing` factories use this to route + * through `connectors/serving/stream`, which centralises URL encoding, auth + * via the SDK's `apiClient.request`, and any future retries/telemetry. + */ +interface StreamBodyAdapterOptions { + streamBody: StreamBody; + maxSteps?: number; + maxTokens?: number; +} + +type DatabricksAdapterOptions = + | RawFetchAdapterOptions + | StreamBodyAdapterOptions; + +function isStreamBodyOptions( + o: DatabricksAdapterOptions, +): o is StreamBodyAdapterOptions { + return "streamBody" in o; +} + +/** + * Minimal structural shape consumed by `connectors/serving/stream`. We avoid + * importing the concrete `WorkspaceClient` type to keep the adapter free of a + * compile-time dependency on the SDK. + */ +interface WorkspaceClientLike { + apiClient: { + request(options: Record): Promise; + }; +} + +interface ServingEndpointOptions { + workspaceClient: WorkspaceClientLike; + endpointName: string; + maxSteps?: number; + maxTokens?: number; +} + +interface ModelServingOptions { + maxSteps?: number; + maxTokens?: number; + workspaceClient?: WorkspaceClientLike; +} + +/** + * Structural shape for {@link createDatabricksModel}. The Vercel AI helper + * builds its own `fetch` override and so needs the workspace config surface + * (host, authenticate, ensureResolved) rather than the `apiClient` used by + * the adapter factories. 
+ */ +interface WorkspaceConfig { + host?: string; + authenticate(headers: Headers): Promise; + ensureResolved(): Promise; +} + +interface VercelDatabricksModelOptions { + workspaceClient: { config: WorkspaceConfig }; + endpointName: string; +} + +interface OpenAIMessage { + role: "system" | "user" | "assistant" | "tool"; + content: string | null; + tool_calls?: OpenAIToolCall[]; + tool_call_id?: string; +} + +interface OpenAIToolCall { + id: string; + type: "function"; + function: { name: string; arguments: string }; +} + +interface OpenAITool { + type: "function"; + function: { + name: string; + description: string; + parameters: unknown; + }; +} + +interface DeltaToolCall { + index: number; + id?: string; + type?: string; + function?: { name?: string; arguments?: string }; +} + +/** + * Adapter that talks directly to Databricks Model Serving `/invocations` endpoint. + * + * No dependency on the Vercel AI SDK or LangChain. Uses raw `fetch()` to POST + * OpenAI-compatible payloads and parses the SSE stream itself. Calls + * `authenticate()` per-request so tokens are always fresh. + * + * Handles both structured `tool_calls` responses and text-based tool call + * fallback parsing for models that output tool calls as text. 
+ * + * @example Using the factory (recommended) + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { DatabricksAdapter } from "@databricks/appkit/agents/databricks"; + * import { WorkspaceClient } from "@databricks/sdk-experimental"; + * + * const adapter = DatabricksAdapter.fromServingEndpoint({ + * workspaceClient: new WorkspaceClient({}), + * endpointName: "my-endpoint", + * }); + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: adapter, + * }), + * }, + * }), + * ], + * }); + * ``` + * + * @example Using the raw constructor + * ```ts + * const adapter = new DatabricksAdapter({ + * endpointUrl: "https://host/serving-endpoints/my-endpoint/invocations", + * authenticate: async () => ({ Authorization: `Bearer ${token}` }), + * }); + * ``` + */ +export class DatabricksAdapter implements AgentAdapter { + private streamBody: StreamBody; + private maxSteps: number; + private maxTokens: number; + + constructor(options: DatabricksAdapterOptions) { + this.maxSteps = options.maxSteps ?? 10; + this.maxTokens = options.maxTokens ?? 4096; + + if (isStreamBodyOptions(options)) { + this.streamBody = options.streamBody; + } else { + const { endpointUrl, authenticate } = options; + this.streamBody = async (body, signal) => { + const authHeaders = await authenticate(); + const response = await fetch(endpointUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + ...authHeaders, + }, + body: JSON.stringify(body), + signal, + }); + if (!response.ok) { + const errorText = await response.text().catch(() => "Unknown error"); + throw new Error( + `Databricks API error (${response.status}): ${errorText}`, + ); + } + if (!response.body) throw new Error("No response body"); + return response.body; + }; + } + } + + /** + * Creates a DatabricksAdapter for a Databricks Model Serving endpoint. 
+ * + * Routes through the shared `connectors/serving/stream` helper, which + * delegates to the SDK's `apiClient.request({ raw: true })`. That gives the + * adapter centralised URL encoding + authentication with the rest of the + * serving surface — no bespoke `fetch()` + `authenticate()` plumbing. + */ + static async fromServingEndpoint( + options: ServingEndpointOptions, + ): Promise { + const { workspaceClient, endpointName, maxSteps, maxTokens } = options; + return new DatabricksAdapter({ + streamBody: (body) => + // Cast through the structural shape: the connector types + // `workspaceClient` as the SDK's concrete `WorkspaceClient`, but we + // only need `apiClient.request`. + servingStream( + workspaceClient as unknown as Parameters[0], + endpointName, + body, + ), + maxSteps, + maxTokens, + }); + } + + /** + * Creates a DatabricksAdapter from a Model Serving endpoint name. + * Auto-creates a WorkspaceClient internally. Reads the endpoint name + * from the argument or the `DATABRICKS_AGENT_ENDPOINT` env var. + * + * @example + * ```ts + * // Reads endpoint from DATABRICKS_AGENT_ENDPOINT env var + * const adapter = await DatabricksAdapter.fromModelServing(); + * + * // Explicit endpoint + * const adapter = await DatabricksAdapter.fromModelServing("my-endpoint"); + * + * // With options + * const adapter = await DatabricksAdapter.fromModelServing("my-endpoint", { + * maxSteps: 5, + * maxTokens: 2048, + * }); + * ``` + */ + static async fromModelServing( + endpointName?: string, + options?: ModelServingOptions, + ): Promise { + const resolvedEndpoint = + endpointName ?? process.env.DATABRICKS_AGENT_ENDPOINT; + + if (!resolvedEndpoint) { + throw new Error( + "No endpoint name provided and DATABRICKS_AGENT_ENDPOINT env var is not set. 
" + + "Pass an endpoint name or set the environment variable.", + ); + } + + let workspaceClient: WorkspaceClientLike | undefined = + options?.workspaceClient; + if (!workspaceClient) { + const sdk = await import("@databricks/sdk-experimental"); + workspaceClient = new sdk.WorkspaceClient( + {}, + ) as unknown as WorkspaceClientLike; + } + + return DatabricksAdapter.fromServingEndpoint({ + workspaceClient, + endpointName: resolvedEndpoint, + maxSteps: options?.maxSteps, + maxTokens: options?.maxTokens, + }); + } + + async *run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + // Databricks API requires tool names to match [a-zA-Z0-9_-]. + // Our tool names use dots (e.g. "analytics.query"), so we swap dots + // for double-underscores in the wire format and map back on receipt. + const nameToWire = new Map(); + const wireToName = new Map(); + for (const tool of input.tools) { + const wire = tool.name.replace(/\./g, "__"); + nameToWire.set(tool.name, wire); + wireToName.set(wire, tool.name); + } + + const tools = this.buildTools(input.tools, nameToWire); + const messages = this.buildMessages(input.messages); + + yield { type: "status", status: "running" }; + + for (let step = 0; step < this.maxSteps; step++) { + if (context.signal?.aborted) break; + + const { text, toolCalls } = yield* this.streamCompletion( + messages, + tools, + context, + ); + + if (toolCalls.length === 0) { + const parsed = parseTextToolCalls(text); + if (parsed.length > 0) { + yield* this.executeToolCalls(parsed, messages, context); + continue; + } + break; + } + + messages.push({ + role: "assistant", + content: text || null, + tool_calls: toolCalls, + }); + + for (const tc of toolCalls) { + const wireName = tc.function.name; + const originalName = wireToName.get(wireName) ?? 
wireName; + let args: unknown; + try { + args = JSON.parse(tc.function.arguments); + } catch { + args = {}; + } + + yield { type: "tool_call", callId: tc.id, name: originalName, args }; + + try { + const result = await context.executeTool(originalName, args); + const resultStr = + typeof result === "string" ? result : JSON.stringify(result); + + yield { type: "tool_result", callId: tc.id, result }; + + messages.push({ + role: "tool", + content: resultStr, + tool_call_id: tc.id, + }); + } catch (error) { + const errMsg = + error instanceof Error ? error.message : "Tool execution failed"; + + yield { + type: "tool_result", + callId: tc.id, + result: null, + error: errMsg, + }; + + messages.push({ + role: "tool", + content: JSON.stringify({ error: errMsg }), + tool_call_id: tc.id, + }); + } + } + } + } + + private async *streamCompletion( + messages: OpenAIMessage[], + tools: OpenAITool[], + context: AgentRunContext, + ): AsyncGenerator< + AgentEvent, + { text: string; toolCalls: OpenAIToolCall[] }, + unknown + > { + const body: Record = { + messages, + stream: true, + max_tokens: this.maxTokens, + }; + + if (tools.length > 0) { + body.tools = tools; + } + + const responseBody = await this.streamBody(body, context.signal); + const reader = responseBody.getReader(); + + const decoder = new TextDecoder(); + let buffer = ""; + let fullText = ""; + const toolCallAccumulator = new Map< + number, + { id: string; name: string; arguments: string } + >(); + + try { + while (true) { + if (context.signal?.aborted) break; + + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() ?? 
""; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed.startsWith("data: ")) continue; + const data = trimmed.slice(6); + if (data === "[DONE]") continue; + + let parsed: any; + try { + parsed = JSON.parse(data); + } catch { + continue; + } + + const delta = parsed.choices?.[0]?.delta; + if (!delta) continue; + + if (delta.content) { + fullText += delta.content; + yield { type: "message_delta" as const, content: delta.content }; + } + + if (delta.tool_calls) { + for (const tc of delta.tool_calls as DeltaToolCall[]) { + const existing = toolCallAccumulator.get(tc.index); + if (existing) { + if (tc.function?.arguments) { + existing.arguments += tc.function.arguments; + } + } else { + toolCallAccumulator.set(tc.index, { + id: tc.id ?? `call_${tc.index}`, + name: tc.function?.name ?? "", + arguments: tc.function?.arguments ?? "", + }); + } + } + } + } + } + } finally { + reader.releaseLock(); + } + + const toolCalls: OpenAIToolCall[] = Array.from( + toolCallAccumulator.values(), + ).map((tc) => ({ + id: tc.id, + type: "function" as const, + function: { name: tc.name, arguments: tc.arguments || "{}" }, + })); + + return { text: fullText, toolCalls }; + } + + private async *executeToolCalls( + calls: Array<{ name: string; args: unknown }>, + messages: OpenAIMessage[], + context: AgentRunContext, + ): AsyncGenerator { + const toolCallObjs: OpenAIToolCall[] = calls.map((c, i) => ({ + id: `text_call_${i}`, + type: "function" as const, + function: { + name: c.name, + arguments: JSON.stringify(c.args), + }, + })); + + messages.push({ + role: "assistant", + content: null, + tool_calls: toolCallObjs, + }); + + for (const tc of toolCallObjs) { + const name = tc.function.name; + let args: unknown; + try { + args = JSON.parse(tc.function.arguments); + } catch { + args = {}; + } + + yield { type: "tool_call", callId: tc.id, name, args }; + + try { + const result = await context.executeTool(name, args); + const resultStr = + typeof result === "string" ? 
result : JSON.stringify(result); + + yield { type: "tool_result", callId: tc.id, result }; + + messages.push({ + role: "tool", + content: resultStr, + tool_call_id: tc.id, + }); + } catch (error) { + const errMsg = + error instanceof Error ? error.message : "Tool execution failed"; + + yield { + type: "tool_result", + callId: tc.id, + result: null, + error: errMsg, + }; + + messages.push({ + role: "tool", + content: JSON.stringify({ error: errMsg }), + tool_call_id: tc.id, + }); + } + } + } + + private buildMessages(messages: AgentInput["messages"]): OpenAIMessage[] { + return messages.map((m) => ({ + role: m.role as OpenAIMessage["role"], + content: m.content, + })); + } + + private buildTools( + definitions: AgentToolDefinition[], + nameToWire: Map, + ): OpenAITool[] { + return definitions.map((def) => ({ + type: "function" as const, + function: { + name: nameToWire.get(def.name) ?? def.name, + description: def.description, + parameters: def.parameters, + }, + })); + } +} + +// --------------------------------------------------------------------------- +// Vercel AI SDK helper +// --------------------------------------------------------------------------- + +/** + * Creates a Vercel AI-compatible model backed by a Databricks Model Serving endpoint. + * + * Use with `VercelAIAdapter` to get the Vercel AI SDK ecosystem (useChat, etc.) + * while targeting a Databricks `/invocations` endpoint. + * + * Handles URL rewriting (`/chat/completions` -> `/invocations`), per-request + * auth refresh, and tool name sanitization (dots -> double-underscores). + * + * Requires the `ai` and `@ai-sdk/openai` packages as peer dependencies. 
+ * + * @example + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { createDatabricksModel } from "@databricks/appkit/agents/databricks"; + * import { VercelAIAdapter } from "@databricks/appkit/agents/vercel-ai"; + * import { WorkspaceClient } from "@databricks/sdk-experimental"; + * + * const model = await createDatabricksModel({ + * workspaceClient: new WorkspaceClient({}), + * endpointName: "my-endpoint", + * }); + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: new VercelAIAdapter({ model }), + * }), + * }, + * }), + * ], + * }); + * ``` + */ +export async function createDatabricksModel( + options: VercelDatabricksModelOptions, +): Promise { + let createOpenAI: any; + try { + const mod = await import("@ai-sdk/openai"); + createOpenAI = mod.createOpenAI; + } catch { + throw new Error( + "createDatabricksModel requires '@ai-sdk/openai' as a dependency. 
Install it with: npm install @ai-sdk/openai ai", + ); + } + + const config = options.workspaceClient.config; + await config.ensureResolved(); + + const baseURL = `${config.host}/serving-endpoints/${options.endpointName}`; + + const provider = createOpenAI({ + baseURL, + apiKey: "databricks", + fetch: async (url: string | URL | Request, init?: RequestInit) => { + const rewritten = String(url).replace( + "/chat/completions", + "/invocations", + ); + + const headers = new Headers(init?.headers); + await config.authenticate(headers); + + let body = init?.body; + if (typeof body === "string") { + body = rewriteToolNamesOutbound(body); + } + + const response = await globalThis.fetch(rewritten, { + ...init, + headers, + body, + }); + + if ( + !response.body || + !response.headers.get("content-type")?.includes("text/event-stream") + ) { + return response; + } + + const transformed = response.body.pipeThrough( + createToolNameRewriteStream(), + ); + + return new Response(transformed, { + status: response.status, + statusText: response.statusText, + headers: response.headers, + }); + }, + }); + + return provider(options.endpointName); +} + +/** + * Rewrites tool names in outbound request body (dots -> double-underscores). + */ +function rewriteToolNamesOutbound(body: string): string { + try { + const parsed = JSON.parse(body); + if (parsed.tools) { + for (const tool of parsed.tools) { + if (tool.function?.name) { + tool.function.name = tool.function.name.replace(/\./g, "__"); + } + } + } + return JSON.stringify(parsed); + } catch { + return body; + } +} + +/** + * Creates a TransformStream that rewrites tool names in SSE response chunks + * (double-underscores -> dots). 
+ */ +function createToolNameRewriteStream(): TransformStream< + Uint8Array, + Uint8Array +> { + const decoder = new TextDecoder(); + const encoder = new TextEncoder(); + + return new TransformStream({ + transform(chunk, controller) { + const text = decoder.decode(chunk, { stream: true }); + const rewritten = text.replace( + /"name"\s*:\s*"([a-zA-Z0-9_-]+)"/g, + (match, name: string) => { + if (name.includes("__")) { + return match.replace(name, name.replace(/__/g, ".")); + } + return match; + }, + ); + controller.enqueue(encoder.encode(rewritten)); + }, + }); +} + +// --------------------------------------------------------------------------- +// Text-based tool call parsing (fallback) +// --------------------------------------------------------------------------- + +/** + * Parses text-based tool calls from model output. + * + * Handles two formats: + * 1. Llama native: `[{"name": "tool_name", "parameters": {"arg": "val"}}]` + * 2. Python-style: `[tool_name(arg1='val1', arg2='val2')]` + */ +export function parseTextToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const trimmed = text.trim(); + + const jsonResult = tryParseLlamaJsonToolCalls(trimmed); + if (jsonResult.length > 0) return jsonResult; + + const pyResult = tryParsePythonStyleToolCalls(trimmed); + if (pyResult.length > 0) return pyResult; + + return []; +} + +function tryParseLlamaJsonToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const match = text.match(/\[\s*\{[\s\S]*\}\s*\]/); + if (!match) return []; + + try { + const parsed = JSON.parse(match[0]); + if (!Array.isArray(parsed)) return []; + + return parsed + .filter( + (item: any) => + typeof item === "object" && + item !== null && + typeof item.name === "string", + ) + .map((item: any) => ({ + name: item.name, + args: item.parameters ?? item.arguments ?? item.args ?? 
{}, + })); + } catch { + return []; + } +} + +function tryParsePythonStyleToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const pattern = /\[?([a-zA-Z_][\w.]*)\(([^)]*)\)\]?/g; + const results: Array<{ name: string; args: unknown }> = []; + + for (const match of text.matchAll(pattern)) { + const name = match[1]; + const argsStr = match[2]; + + const args: Record = {}; + const argPattern = /(\w+)\s*=\s*(?:'([^']*)'|"([^"]*)"|(\S+))/g; + for (const argMatch of argsStr.matchAll(argPattern)) { + const key = argMatch[1]; + const value = argMatch[2] ?? argMatch[3] ?? argMatch[4]; + args[key] = value; + } + + results.push({ name, args }); + } + + return results; +} diff --git a/packages/appkit/src/agents/langchain.ts b/packages/appkit/src/agents/langchain.ts new file mode 100644 index 00000000..77961bcf --- /dev/null +++ b/packages/appkit/src/agents/langchain.ts @@ -0,0 +1,292 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, +} from "shared"; + +/** + * Adapter bridging LangChain/LangGraph to the AppKit agent protocol. + * + * Accepts any LangChain `Runnable` (e.g. AgentExecutor, compiled LangGraph) + * and maps `streamEvents` v2 to `AgentEvent`. + * + * Requires `@langchain/core` as an optional peer dependency. 
+ * + * @example + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { LangChainAdapter } from "@databricks/appkit/agents/langchain"; + * import { ChatOpenAI } from "@langchain/openai"; + * import { createReactAgent } from "@langchain/langgraph/prebuilt"; + * + * const model = new ChatOpenAI({ model: "gpt-4o" }); + * const agentExecutor = createReactAgent({ llm: model, tools: [] }); + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: new LangChainAdapter({ runnable: agentExecutor }), + * }), + * }, + * }), + * ], + * }); + * ``` + */ +export class LangChainAdapter implements AgentAdapter { + private runnable: any; + + constructor(options: { runnable: any }) { + this.runnable = options.runnable; + } + + async *run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + const lcTools = await import("@langchain/core/tools"); + const DynamicStructuredTool = lcTools.DynamicStructuredTool; + const zodModule: any = await import("zod"); + const z = zodModule.z; + + const tools = this.buildTools( + input.tools, + context, + DynamicStructuredTool, + z, + ); + + const messages = input.messages.map((m) => ({ + role: m.role, + content: m.content, + })); + + yield { type: "status", status: "running" }; + + const runnableWithTools = + tools.length > 0 && typeof this.runnable.bindTools === "function" + ? this.runnable.bindTools(tools) + : this.runnable; + + const stream = await runnableWithTools.streamEvents( + { messages }, + { + version: "v2", + signal: input.signal, + }, + ); + + // Tool-call chunks from `on_chat_model_stream` come in fragments keyed by + // the model's `index`. We accumulate them and flush on `on_tool_start`. 
+ const toolCallAccumulator = new Map< + number, + { id: string; name: string; arguments: string } + >(); + // LangChain's `on_tool_end` reports the tool via `event.run_id` (its own + // internal identifier), not the model-provided tool_call id. To keep the + // `call_id` on `tool_call` and `tool_result` matching (so clients can + // correlate them via the Responses API `call_id` field), we record the + // mapping from `run_id` to the original model `tc.id` at `on_tool_start` + // and look it up at `on_tool_end`. + const runIdToCallId = new Map(); + // Counter for fallback callIds when the model does not provide `tc.id`. + let fallbackIdx = 0; + + for await (const event of stream) { + if (context.signal?.aborted) break; + + switch (event.event) { + case "on_chat_model_stream": { + const chunk = event.data?.chunk; + if (chunk?.content && typeof chunk.content === "string") { + yield { type: "message_delta", content: chunk.content }; + } + if (chunk?.tool_call_chunks) { + for (const tc of chunk.tool_call_chunks) { + const idx = tc.index ?? 0; + const existing = toolCallAccumulator.get(idx); + if (existing) { + if (tc.args) existing.arguments += tc.args; + // Later chunks for the same tool call may carry the id/name + // that the first chunk lacked. + if (tc.id && !existing.id.startsWith("lc_")) + existing.id = tc.id; + if (tc.name && !existing.name) existing.name = tc.name; + } else if (tc.name || tc.id) { + toolCallAccumulator.set(idx, { + // Use a deterministic fallback that cannot collide if the + // same tool is called twice in one turn without a model id. + id: + tc.id ?? `lc_${tc.name ?? "tool"}_${idx}_${++fallbackIdx}`, + name: tc.name ?? "", + arguments: tc.args ?? "", + }); + } + } + } + break; + } + + case "on_tool_start": { + // Find the accumulated tool_call that matches this tool invocation + // by name so we can record the run_id → callId mapping and yield + // the `tool_call` event with a callId that will match the + // subsequent `tool_result`. 
+ const toolName = event.name; + let matched: { id: string; name: string; arguments: string } | null = + null; + let matchedKey: number | null = null; + for (const [key, tc] of toolCallAccumulator) { + if (tc.name === toolName) { + matched = tc; + matchedKey = key; + break; + } + } + + if (matched) { + const runId = event.run_id; + if (typeof runId === "string" && runId.length > 0) { + runIdToCallId.set(runId, matched.id); + } + let args: unknown; + try { + args = JSON.parse(matched.arguments || "{}"); + } catch { + args = {}; + } + yield { + type: "tool_call" as const, + callId: matched.id, + name: matched.name, + args, + }; + if (matchedKey !== null) toolCallAccumulator.delete(matchedKey); + } else { + // Fallback: no accumulated tool_call matched this name. Emit a + // tool_call anyway with run_id as the correlating key so the + // client at least sees a call/result pair. + const runId = event.run_id ?? `lc_${toolName}_${++fallbackIdx}`; + runIdToCallId.set(runId, runId); + yield { + type: "tool_call" as const, + callId: runId, + name: toolName ?? "", + args: event.data?.input ?? {}, + }; + } + break; + } + + case "on_tool_end": { + const output = event.data?.output; + const runId = event.run_id; + const callId = + (typeof runId === "string" && runIdToCallId.get(runId)) || runId; + if (typeof runId === "string") runIdToCallId.delete(runId); + yield { + type: "tool_result", + callId, + result: output?.content ?? output, + }; + break; + } + + case "on_chain_end": { + const output = event.data?.output; + if (output?.content && typeof output.content === "string") { + yield { type: "message", content: output.content }; + } + break; + } + } + } + } + + /** + * Converts AgentToolDefinitions into LangChain DynamicStructuredTool instances. + * + * JSON Schema properties are mapped to Zod schemas using a lightweight + * recursive converter for the subset of JSON Schema types that tools use. 
+ */ + private buildTools( + definitions: AgentToolDefinition[], + context: AgentRunContext, + DynamicStructuredTool: any, + z: any, + ): any[] { + return definitions.map( + (def) => + new DynamicStructuredTool({ + name: def.name, + description: def.description, + schema: jsonSchemaToZod(def.parameters, z), + func: async (args: unknown) => { + try { + const result = await context.executeTool(def.name, args); + return typeof result === "string" + ? result + : JSON.stringify(result); + } catch (error) { + return `Error: ${error instanceof Error ? error.message : "Tool execution failed"}`; + } + }, + }), + ); + } +} + +/** + * Lightweight JSON Schema (subset) to Zod converter. + * Handles the types commonly used in tool parameters. + */ +function jsonSchemaToZod(schema: any, z: any): any { + if (!schema) return z.object({}); + + switch (schema.type) { + case "object": { + const shape: Record = {}; + const properties = schema.properties ?? {}; + const required = new Set(schema.required ?? []); + + for (const [key, prop] of Object.entries(properties)) { + let field = jsonSchemaToZod(prop, z); + if (!required.has(key)) { + field = field.optional(); + } + if ((prop as any).description) { + field = field.describe((prop as any).description); + } + shape[key] = field; + } + return z.object(shape); + } + + case "array": + return z.array(jsonSchemaToZod(schema.items ?? 
{}, z)); + + case "string": { + let s = z.string(); + if (schema.enum) s = z.enum(schema.enum); + return s; + } + + case "number": + case "integer": + return z.number(); + + case "boolean": + return z.boolean(); + + case "null": + return z.null(); + + default: + return z.any(); + } +} diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts new file mode 100644 index 00000000..8a835094 --- /dev/null +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -0,0 +1,486 @@ +import type { AgentEvent, AgentToolDefinition, Message } from "shared"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { DatabricksAdapter, parseTextToolCalls } from "../databricks"; + +const mockAuthenticate = vi + .fn() + .mockResolvedValue({ Authorization: "Bearer test-token" }); + +function sseChunk(data: string): string { + return `data: ${data}\n\n`; +} + +function textDelta(content: string): string { + return sseChunk( + JSON.stringify({ + choices: [{ delta: { content } }], + }), + ); +} + +function toolCallDelta( + index: number, + id: string | undefined, + name: string | undefined, + args: string, +): string { + return sseChunk( + JSON.stringify({ + choices: [ + { + delta: { + tool_calls: [ + { + index, + ...(id && { id }), + ...(name && { type: "function" }), + function: { + ...(name && { name }), + arguments: args, + }, + }, + ], + }, + }, + ], + }), + ); +} + +function createReadableStream(chunks: string[]): ReadableStream { + const encoder = new TextEncoder(); + let i = 0; + return new ReadableStream({ + pull(controller) { + if (i < chunks.length) { + controller.enqueue(encoder.encode(chunks[i])); + i++; + } else { + controller.close(); + } + }, + }); +} + +function mockFetch(chunks: string[]): typeof globalThis.fetch { + return vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream(chunks), + text: () => Promise.resolve(""), + }); +} + +function createTestMessages(): 
Message[] { + return [{ id: "1", role: "user", content: "Hello", createdAt: new Date() }]; +} + +function createTestTools(): AgentToolDefinition[] { + return [ + { + name: "analytics.query", + description: "Run SQL", + parameters: { + type: "object", + properties: { query: { type: "string" } }, + required: ["query"], + }, + }, + ]; +} + +function createAdapter(overrides?: { + endpointUrl?: string; + authenticate?: () => Promise>; + maxSteps?: number; + maxTokens?: number; +}) { + return new DatabricksAdapter({ + endpointUrl: + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + authenticate: mockAuthenticate, + ...overrides, + }); +} + +describe("DatabricksAdapter", () => { + const originalFetch = globalThis.fetch; + + afterEach(() => { + globalThis.fetch = originalFetch; + mockAuthenticate.mockClear(); + }); + + test("streams text deltas from the model", async () => { + globalThis.fetch = mockFetch([ + textDelta("Hello"), + textDelta(" world"), + sseChunk("[DONE]"), + ]); + + const adapter = createAdapter(); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events[0]).toEqual({ type: "status", status: "running" }); + expect(events[1]).toEqual({ type: "message_delta", content: "Hello" }); + expect(events[2]).toEqual({ type: "message_delta", content: " world" }); + }); + + test("calls authenticate() per request for fresh headers", async () => { + globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(mockAuthenticate).toHaveBeenCalledTimes(1); + + const [, init] = (globalThis.fetch as any).mock.calls[0]; + expect(init.headers.Authorization).toBe("Bearer test-token"); + 
}); + + test("handles structured tool calls and executes them", async () => { + const executeTool = vi.fn().mockResolvedValue([{ trip_id: 1 }]); + + let callCount = 0; + globalThis.fetch = vi.fn().mockImplementation(() => { + callCount++; + if (callCount === 1) { + return Promise.resolve({ + ok: true, + body: createReadableStream([ + toolCallDelta(0, "call_1", "analytics__query", ""), + toolCallDelta(0, undefined, undefined, '{"query":'), + toolCallDelta(0, undefined, undefined, '"SELECT 1"}'), + sseChunk("[DONE]"), + ]), + }); + } + return Promise.resolve({ + ok: true, + body: createReadableStream([ + textDelta("Here are the results"), + sseChunk("[DONE]"), + ]), + }); + }); + + const adapter = createAdapter(); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool }, + )) { + events.push(event); + } + + expect(events).toContainEqual({ + type: "tool_call", + callId: "call_1", + name: "analytics.query", + args: { query: "SELECT 1" }, + }); + + expect(executeTool).toHaveBeenCalledWith("analytics.query", { + query: "SELECT 1", + }); + + expect(events).toContainEqual( + expect.objectContaining({ + type: "tool_result", + callId: "call_1", + result: [{ trip_id: 1 }], + }), + ); + + expect(events).toContainEqual({ + type: "message_delta", + content: "Here are the results", + }); + + // authenticate() called once per streamCompletion + expect(mockAuthenticate).toHaveBeenCalledTimes(2); + }); + + test("respects maxSteps limit", async () => { + globalThis.fetch = vi.fn().mockImplementation(() => + Promise.resolve({ + ok: true, + body: createReadableStream([ + toolCallDelta( + 0, + "call_loop", + "analytics__query", + '{"query":"SELECT 1"}', + ), + sseChunk("[DONE]"), + ]), + }), + ); + + const adapter = createAdapter({ maxSteps: 2 }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), 
+ tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn().mockResolvedValue("ok") }, + )) { + events.push(event); + } + + expect(globalThis.fetch).toHaveBeenCalledTimes(2); + }); + + test("sends correct request to endpoint URL", async () => { + globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [url, init] = (globalThis.fetch as any).mock.calls[0]; + expect(url).toBe( + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + ); + + const body = JSON.parse(init.body); + expect(body.stream).toBe(true); + expect(body.tools).toHaveLength(1); + expect(body.tools[0].function.name).toBe("analytics__query"); + expect(body.messages[0]).toEqual({ + role: "user", + content: "Hello", + }); + }); + + test("throws on non-ok response", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 401, + text: () => Promise.resolve("Unauthorized"), + }); + + const adapter = createAdapter(); + + await expect(async () => { + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow("Databricks API error (401): Unauthorized"); + }); +}); + +describe("DatabricksAdapter.fromServingEndpoint", () => { + test("routes tool-free chat through apiClient.request with a streaming payload", async () => { + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromServingEndpoint({ + workspaceClient: { apiClient }, + endpointName: "my-model", + }); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { 
executeTool: vi.fn() }, + )) { + // drain + } + + expect(apiClient.request).toHaveBeenCalledTimes(1); + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe("/serving-endpoints/my-model/invocations"); + expect(requestArgs.method).toBe("POST"); + expect(requestArgs.raw).toBe(true); + expect(requestArgs.payload.stream).toBe(true); + // Auth + url encoding are the connector's (and the SDK's) concerns — the + // adapter no longer reaches into the workspace config. + }); + + test("URL-encodes endpoint names with special characters", async () => { + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromServingEndpoint({ + workspaceClient: { apiClient }, + endpointName: "my model/with spaces", + }); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe( + "/serving-endpoints/my%20model%2Fwith%20spaces/invocations", + ); + }); +}); + +describe("DatabricksAdapter.fromModelServing", () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + test("reads endpoint from DATABRICKS_AGENT_ENDPOINT env var", async () => { + process.env.DATABRICKS_AGENT_ENDPOINT = "my-model"; + + vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: vi.fn().mockImplementation(() => ({ + apiClient: { request: vi.fn() }, + })), + })); + + const adapter = await DatabricksAdapter.fromModelServing(); + expect(adapter).toBeInstanceOf(DatabricksAdapter); + }); + + test("throws when no endpoint name and no env var", async () => { + delete process.env.DATABRICKS_AGENT_ENDPOINT; + + await 
expect(DatabricksAdapter.fromModelServing()).rejects.toThrow( + "No endpoint name provided", + ); + }); + + test("explicit endpoint name takes precedence over env var", async () => { + process.env.DATABRICKS_AGENT_ENDPOINT = "env-model"; + + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromModelServing("explicit-model", { + workspaceClient: { apiClient }, + }); + + expect(adapter).toBeInstanceOf(DatabricksAdapter); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe( + "/serving-endpoints/explicit-model/invocations", + ); + }); +}); + +describe("parseTextToolCalls", () => { + test("parses Llama JSON format", () => { + const text = + '[{"name": "analytics.query", "parameters": {"query": "SELECT 1"}}]'; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { name: "analytics.query", args: { query: "SELECT 1" } }, + ]); + }); + + test("parses multiple Llama JSON tool calls", () => { + const text = + '[{"name": "analytics.query", "parameters": {"query": "SELECT 1"}}, {"name": "files.uploads.list", "parameters": {}}]'; + const result = parseTextToolCalls(text); + + expect(result).toHaveLength(2); + expect(result[0].name).toBe("analytics.query"); + expect(result[1].name).toBe("files.uploads.list"); + }); + + test("parses Python-style tool calls", () => { + const text = + "[analytics.query(query='SELECT * FROM trips ORDER BY date DESC LIMIT 10')]"; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { + name: "analytics.query", + args: { + query: "SELECT * FROM trips ORDER BY date DESC LIMIT 10", + }, + }, + ]); + }); + + test("parses Python-style with multiple args", () => { + const text = + 
"[files.uploads.read(path='/data/file.csv', encoding='utf-8')]"; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { + name: "files.uploads.read", + args: { path: "/data/file.csv", encoding: "utf-8" }, + }, + ]); + }); + + test("returns empty array for plain text", () => { + expect(parseTextToolCalls("Hello, how can I help?")).toEqual([]); + expect(parseTextToolCalls("")).toEqual([]); + expect(parseTextToolCalls("The answer is 42")).toEqual([]); + }); + + test("handles Llama format with 'arguments' key", () => { + const text = + '[{"name": "lakebase.query", "arguments": {"text": "SELECT 1"}}]'; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { name: "lakebase.query", args: { text: "SELECT 1" } }, + ]); + }); +}); diff --git a/packages/appkit/src/agents/tests/langchain.test.ts b/packages/appkit/src/agents/tests/langchain.test.ts new file mode 100644 index 00000000..3bf1a471 --- /dev/null +++ b/packages/appkit/src/agents/tests/langchain.test.ts @@ -0,0 +1,366 @@ +import type { AgentEvent, AgentToolDefinition, Message } from "shared"; +import { describe, expect, test, vi } from "vitest"; +import { LangChainAdapter } from "../langchain"; + +vi.mock("@langchain/core/tools", () => ({ + DynamicStructuredTool: vi.fn().mockImplementation((config: any) => ({ + name: config.name, + description: config.description, + schema: config.schema, + func: config.func, + })), +})); + +vi.mock("zod", () => { + const createChainable = (base: Record = {}): any => { + const obj: any = { ...base }; + obj.optional = () => createChainable({ ...obj, _optional: true }); + obj.describe = (d: string) => createChainable({ ...obj, _description: d }); + return obj; + }; + + return { + z: { + object: (shape: any) => createChainable({ type: "object", shape }), + string: () => createChainable({ type: "string" }), + number: () => createChainable({ type: "number" }), + boolean: () => createChainable({ type: "boolean" }), + array: (item: any) => 
createChainable({ type: "array", item }), + enum: (vals: any) => createChainable({ type: "enum", values: vals }), + any: () => createChainable({ type: "any" }), + null: () => createChainable({ type: "null" }), + }, + }; +}); + +function createTestMessages(): Message[] { + return [{ id: "1", role: "user", content: "Hello", createdAt: new Date() }]; +} + +function createTestTools(): AgentToolDefinition[] { + return [ + { + name: "lakebase.query", + description: "Run SQL", + parameters: { + type: "object", + properties: { + text: { type: "string", description: "SQL query" }, + values: { type: "array", items: {} }, + }, + required: ["text"], + }, + }, + ]; +} + +describe("LangChainAdapter", () => { + test("yields status running on start and maps chat_model_stream", async () => { + async function* mockStreamEvents() { + yield { + event: "on_chat_model_stream", + data: { chunk: { content: "Hello" } }, + }; + yield { + event: "on_chat_model_stream", + data: { chunk: { content: " world" } }, + }; + } + + const mockRunnable = { + bindTools: vi.fn().mockReturnValue({ + streamEvents: vi.fn().mockResolvedValue(mockStreamEvents()), + }), + }; + + const adapter = new LangChainAdapter({ runnable: mockRunnable }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events[0]).toEqual({ type: "status", status: "running" }); + expect(events[1]).toEqual({ type: "message_delta", content: "Hello" }); + expect(events[2]).toEqual({ type: "message_delta", content: " world" }); + }); + + test("maps on_tool_end events to tool_result", async () => { + async function* mockStreamEvents() { + yield { + event: "on_tool_end", + run_id: "run-1", + data: { output: { content: "42 rows" } }, + }; + } + + const mockRunnable = { + bindTools: vi.fn().mockReturnValue({ + streamEvents: 
vi.fn().mockResolvedValue(mockStreamEvents()), + }), + }; + + const adapter = new LangChainAdapter({ runnable: mockRunnable }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events).toContainEqual({ + type: "tool_result", + callId: "run-1", + result: "42 rows", + }); + }); + + test("calls bindTools when tools are provided", async () => { + const streamEvents = vi.fn().mockResolvedValue((async function* () {})()); + const bindTools = vi.fn().mockReturnValue({ streamEvents }); + + const adapter = new LangChainAdapter({ + runnable: { bindTools }, + }); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(bindTools).toHaveBeenCalledTimes(1); + expect(bindTools.mock.calls[0][0]).toHaveLength(1); + expect(bindTools.mock.calls[0][0][0].name).toBe("lakebase.query"); + }); + + test("does not call bindTools when no tools provided", async () => { + const streamEvents = vi.fn().mockResolvedValue((async function* () {})()); + const bindTools = vi.fn().mockReturnValue({ streamEvents }); + + const adapter = new LangChainAdapter({ + runnable: { bindTools, streamEvents }, + }); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: [], + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(bindTools).not.toHaveBeenCalled(); + }); + + test("callId on tool_call and tool_result match across on_tool_start / on_tool_end", async () => { + // Simulates the realistic LangChain stream: first the chat-model emits a + // tool_call_chunk carrying the model-provided id and name; then + // on_tool_start fires with LangChain's own run_id; then on_tool_end + // fires with the same run_id. 
The adapter must yield tool_call with + // callId = model id, then tool_result with the SAME callId. + async function* mockStreamEvents() { + yield { + event: "on_chat_model_stream", + data: { + chunk: { + content: "", + tool_call_chunks: [ + { + index: 0, + id: "call_abc123", + name: "lakebase.query", + args: '{"text":"SELECT 1"}', + }, + ], + }, + }, + }; + yield { + event: "on_tool_start", + name: "lakebase.query", + run_id: "run-uuid-xyz", + data: { input: { text: "SELECT 1" } }, + }; + yield { + event: "on_tool_end", + name: "lakebase.query", + run_id: "run-uuid-xyz", + data: { output: { content: "42 rows" } }, + }; + } + + const mockRunnable = { + bindTools: vi.fn().mockReturnValue({ + streamEvents: vi.fn().mockResolvedValue(mockStreamEvents()), + }), + }; + + const adapter = new LangChainAdapter({ runnable: mockRunnable }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + const toolCall = events.find((e) => e.type === "tool_call"); + const toolResult = events.find((e) => e.type === "tool_result"); + expect(toolCall).toBeDefined(); + expect(toolResult).toBeDefined(); + // Critical invariant: same callId on the pair, using the model-provided id + // (not LangChain's internal run_id). + expect(toolCall && "callId" in toolCall ? toolCall.callId : undefined).toBe( + "call_abc123", + ); + expect( + toolResult && "callId" in toolResult ? toolResult.callId : undefined, + ).toBe("call_abc123"); + }); + + test("same tool invoked twice in one turn gets unique callIds when model omits id", async () => { + async function* mockStreamEvents() { + yield { + event: "on_chat_model_stream", + data: { + chunk: { + tool_call_chunks: [ + // Two calls to the same tool, no model-provided id on either. 
+ { index: 0, name: "search", args: '{"q":"a"}' }, + { index: 1, name: "search", args: '{"q":"b"}' }, + ], + }, + }, + }; + yield { + event: "on_tool_start", + name: "search", + run_id: "run-1", + data: { input: { q: "a" } }, + }; + yield { + event: "on_tool_end", + name: "search", + run_id: "run-1", + data: { output: { content: "A-result" } }, + }; + yield { + event: "on_tool_start", + name: "search", + run_id: "run-2", + data: { input: { q: "b" } }, + }; + yield { + event: "on_tool_end", + name: "search", + run_id: "run-2", + data: { output: { content: "B-result" } }, + }; + } + + const mockRunnable = { + bindTools: vi.fn().mockReturnValue({ + streamEvents: vi.fn().mockResolvedValue(mockStreamEvents()), + }), + }; + + const adapter = new LangChainAdapter({ runnable: mockRunnable }); + const events: AgentEvent[] = []; + for await (const ev of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(ev); + } + + const calls = events.filter( + (e): e is AgentEvent & { type: "tool_call"; callId: string } => + e.type === "tool_call", + ); + const results = events.filter( + (e): e is AgentEvent & { type: "tool_result"; callId: string } => + e.type === "tool_result", + ); + expect(calls).toHaveLength(2); + expect(results).toHaveLength(2); + expect(calls[0].callId).not.toBe(calls[1].callId); + // Each result correlates with its call. 
+ expect(results[0].callId).toBe(calls[0].callId); + expect(results[1].callId).toBe(calls[1].callId); + }); + + test("falls back to run_id as callId when on_tool_start has no accumulated match", async () => { + async function* mockStreamEvents() { + yield { + event: "on_tool_start", + name: "orphan_tool", + run_id: "run-orphan", + data: { input: { x: 1 } }, + }; + yield { + event: "on_tool_end", + run_id: "run-orphan", + data: { output: { content: "ok" } }, + }; + } + + const mockRunnable = { + bindTools: vi.fn().mockReturnValue({ + streamEvents: vi.fn().mockResolvedValue(mockStreamEvents()), + }), + }; + + const adapter = new LangChainAdapter({ runnable: mockRunnable }); + const events: AgentEvent[] = []; + for await (const ev of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(ev); + } + + const toolCall = events.find((e) => e.type === "tool_call"); + const toolResult = events.find((e) => e.type === "tool_result"); + expect(toolCall && "callId" in toolCall ? toolCall.callId : undefined).toBe( + "run-orphan", + ); + expect( + toolResult && "callId" in toolResult ? 
toolResult.callId : undefined, + ).toBe("run-orphan"); + }); +}); diff --git a/packages/appkit/src/agents/tests/vercel-ai.test.ts b/packages/appkit/src/agents/tests/vercel-ai.test.ts new file mode 100644 index 00000000..7280c9aa --- /dev/null +++ b/packages/appkit/src/agents/tests/vercel-ai.test.ts @@ -0,0 +1,190 @@ +import type { AgentEvent, AgentToolDefinition, Message } from "shared"; +import { describe, expect, test, vi } from "vitest"; +import { VercelAIAdapter } from "../vercel-ai"; + +vi.mock("ai", () => ({ + streamText: vi.fn(), + jsonSchema: vi.fn((schema: any) => schema), +})); + +function createTestMessages(): Message[] { + return [ + { + id: "1", + role: "user", + content: "Hello", + createdAt: new Date(), + }, + ]; +} + +function createTestTools(): AgentToolDefinition[] { + return [ + { + name: "analytics.query", + description: "Run SQL", + parameters: { + type: "object", + properties: { + query: { type: "string" }, + }, + required: ["query"], + }, + }, + ]; +} + +describe("VercelAIAdapter", () => { + test("yields status running on start", async () => { + const { streamText } = await import("ai"); + + async function* mockStream() { + yield { type: "text-delta", textDelta: "Hi" }; + } + + (streamText as any).mockReturnValue({ + fullStream: mockStream(), + }); + + const adapter = new VercelAIAdapter({ model: {} }); + const events: AgentEvent[] = []; + + const stream = adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { + executeTool: vi.fn(), + }, + ); + + for await (const event of stream) { + events.push(event); + } + + expect(events[0]).toEqual({ type: "status", status: "running" }); + expect(events[1]).toEqual({ type: "message_delta", content: "Hi" }); + }); + + test("maps tool-call and tool-result events", async () => { + const { streamText } = await import("ai"); + + async function* mockStream() { + yield { + type: "tool-call", + toolCallId: "c1", + toolName: "analytics.query", + args: { query: 
"SELECT 1" }, + }; + yield { + type: "tool-result", + toolCallId: "c1", + result: [{ value: 1 }], + }; + } + + (streamText as any).mockReturnValue({ + fullStream: mockStream(), + }); + + const adapter = new VercelAIAdapter({ model: {} }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events).toContainEqual({ + type: "tool_call", + callId: "c1", + name: "analytics.query", + args: { query: "SELECT 1" }, + }); + + expect(events).toContainEqual({ + type: "tool_result", + callId: "c1", + result: [{ value: 1 }], + }); + }); + + test("maps error events", async () => { + const { streamText } = await import("ai"); + + async function* mockStream() { + yield { type: "error", error: "API rate limited" }; + } + + (streamText as any).mockReturnValue({ + fullStream: mockStream(), + }); + + const adapter = new VercelAIAdapter({ model: {} }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: [], + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "API rate limited", + }); + }); + + test("builds tools with execute functions that delegate to executeTool", async () => { + const { streamText } = await import("ai"); + + let capturedTools: Record = {}; + + (streamText as any).mockImplementation((opts: any) => { + capturedTools = opts.tools; + return { + fullStream: (async function* () {})(), + }; + }); + + const executeTool = vi.fn().mockResolvedValue({ count: 42 }); + const adapter = new VercelAIAdapter({ model: {} }); + + // Consume the stream to trigger streamText + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool }, + 
)) { + // drain + } + + expect(capturedTools["analytics.query"]).toBeDefined(); + expect(capturedTools["analytics.query"].description).toBe("Run SQL"); + + const result = await capturedTools["analytics.query"].execute({ + query: "SELECT 1", + }); + expect(executeTool).toHaveBeenCalledWith("analytics.query", { + query: "SELECT 1", + }); + expect(result).toEqual({ count: 42 }); + }); +}); diff --git a/packages/appkit/src/agents/vercel-ai.ts b/packages/appkit/src/agents/vercel-ai.ts new file mode 100644 index 00000000..ea77771a --- /dev/null +++ b/packages/appkit/src/agents/vercel-ai.ts @@ -0,0 +1,138 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, +} from "shared"; + +/** + * Adapter bridging the Vercel AI SDK (`ai` package) to the AppKit agent protocol. + * + * Converts `AgentToolDefinition[]` to Vercel AI tool format and maps + * `streamText().fullStream` events to `AgentEvent`. + * + * Requires `ai` as an optional peer dependency. 
+ * + * @example + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { VercelAIAdapter } from "@databricks/appkit/agents/vercel-ai"; + * import { openai } from "@ai-sdk/openai"; + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: new VercelAIAdapter({ model: openai("gpt-4o") }), + * }), + * }, + * }), + * ], + * }); + * ``` + */ +export class VercelAIAdapter implements AgentAdapter { + private model: any; + + constructor(options: { model: any }) { + this.model = options.model; + } + + async *run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + const { streamText } = await import("ai"); + const { jsonSchema } = await import("ai"); + + const tools = this.buildTools(input.tools, context, jsonSchema); + + const messages = input.messages.map((m) => ({ + role: m.role as "user" | "assistant" | "system", + content: m.content, + })); + + yield { type: "status", status: "running" }; + + const result = streamText({ + model: this.model, + messages, + tools, + maxSteps: 10 as any, + abortSignal: input.signal, + } as any); + + for await (const part of (result as any).fullStream) { + if (context.signal?.aborted) break; + + switch (part.type) { + case "text-delta": + yield { type: "message_delta", content: part.textDelta }; + break; + + case "tool-call": + yield { + type: "tool_call", + callId: part.toolCallId, + name: part.toolName, + args: part.args, + }; + break; + + case "tool-result": + yield { + type: "tool_result", + callId: part.toolCallId, + result: part.result, + }; + break; + + case "reasoning": + if (part.textDelta) { + yield { type: "thinking", content: part.textDelta }; + } + break; + + case "error": + yield { + type: "status", + status: "error", + error: String(part.error), + }; + break; + } + } + } + + private buildTools( + definitions: AgentToolDefinition[], + context: AgentRunContext, + 
jsonSchema: any, + ): Record { + const tools: Record = {}; + + for (const def of definitions) { + tools[def.name] = { + description: def.description, + parameters: jsonSchema(def.parameters), + execute: async (args: unknown) => { + try { + return await context.executeTool(def.name, args); + } catch (error) { + return { + error: + error instanceof Error + ? error.message + : "Tool execution failed", + }; + } + }, + }; + } + + return tools; + } +} diff --git a/packages/appkit/src/connectors/index.ts b/packages/appkit/src/connectors/index.ts index 54a24fa4..daae9439 100644 --- a/packages/appkit/src/connectors/index.ts +++ b/packages/appkit/src/connectors/index.ts @@ -2,5 +2,6 @@ export * from "./files"; export * from "./genie"; export * from "./lakebase"; export * from "./lakebase-v1"; +export * from "./mcp"; export * from "./sql-warehouse"; export * from "./vector-search"; diff --git a/packages/appkit/src/connectors/mcp/client.ts b/packages/appkit/src/connectors/mcp/client.ts new file mode 100644 index 00000000..4c8d058b --- /dev/null +++ b/packages/appkit/src/connectors/mcp/client.ts @@ -0,0 +1,404 @@ +/** + * Custom MCP over HTTP (Streamable) — not `@modelcontextprotocol/sdk` + * + * This module implements a tiny JSON-RPC 2.0 client on `fetch` for the subset + * of MCP we need: `initialize`, `notifications/initialized`, `tools/list`, + * `tools/call` over a single JSON request/response. We do not use the official + * SDK because: + * + * - **Policy and auth are the product** — every outbound URL is checked with + * {@link McpHostPolicy} (allowlist, DNS, private/blocked IP ranges) before + * the first byte is sent, and workspace tokens are only forwarded when + * `forwardWorkspaceAuth` is true for that destination. A generic transport + * from the SDK would still need the same hooks; re-wrapping it would be + * about as much code, with a larger third-party surface to audit. 
+ * - **Narrow scope** — we only target Databricks-hosted MCP over Streamable + * HTTP, not stdio, full SSE sessions, or the rest of the protocol. A + * hand-rolled path keeps the call graph obvious in code review. + * - **Zero extra runtime dependency** for this path, consistent with other + * small, security-sensitive AppKit pieces. + * + * Revisit if we add more transports, or if the SDK ships a first-class way to + * inject our host policy and per-URL auth without fighting the default + * transport. + */ +import type { AgentToolDefinition } from "shared"; +import { createLogger } from "../../logging/logger"; +import { + assertResolvedHostSafe, + checkMcpUrl, + type DnsLookup, + type McpHostPolicy, +} from "./host-policy"; +import type { McpEndpointConfig } from "./types"; + +const logger = createLogger("connector:mcp"); + +interface JsonRpcRequest { + jsonrpc: "2.0"; + id: number; + method: string; + params?: Record; +} + +interface JsonRpcResponse { + jsonrpc: "2.0"; + id: number; + result?: unknown; + error?: { code: number; message: string; data?: unknown }; +} + +interface McpToolSchema { + name: string; + description?: string; + inputSchema?: Record; +} + +interface McpToolCallResult { + content: Array<{ type: string; text?: string }>; + isError?: boolean; +} + +interface McpServerConnection { + config: McpEndpointConfig; + resolvedUrl: string; + /** + * Whether workspace auth (SP / OBO) may be forwarded to this endpoint's URL. + * Decided at `connect()` time via {@link McpHostPolicy} and cached for the + * lifetime of the connection. + */ + forwardWorkspaceAuth: boolean; + tools: Map; +} + +/** + * Lightweight MCP client for Databricks-hosted MCP servers. + * + * Uses raw fetch() with JSON-RPC 2.0 over HTTP — no @modelcontextprotocol/sdk + * or LangChain dependency. Supports the Streamable HTTP transport only + * (POST with JSON-RPC request, single JSON-RPC response). 
Implements exactly + * four methods: `initialize`, `notifications/initialized`, `tools/list`, + * `tools/call`. No prompts/resources/completion/sampling. + * + * All outbound URLs are gated by an {@link McpHostPolicy}: unallowlisted hosts + * are rejected before the first byte is sent, and workspace credentials are + * only forwarded to the same-origin workspace. See `mcp-host-policy.ts`. + * + * Rationale for hand-rolling JSON-RPC instead of `@modelcontextprotocol/sdk`: + * see the file-level comment at the top of this module. + */ +export class AppKitMcpClient { + private connections = new Map(); + private sessionIds = new Map(); + private requestId = 0; + private closed = false; + + constructor( + private workspaceHost: string, + private authenticate: () => Promise>, + private policy: McpHostPolicy, + private options: { dnsLookup?: DnsLookup; fetchImpl?: typeof fetch } = {}, + ) {} + + async connectAll(endpoints: McpEndpointConfig[]): Promise { + const results = await Promise.allSettled( + endpoints.map((ep) => this.connect(ep)), + ); + for (let i = 0; i < results.length; i++) { + if (results[i].status === "rejected") { + logger.error( + "Failed to connect MCP server %s: %O", + endpoints[i].name, + (results[i] as PromiseRejectedResult).reason, + ); + } + } + } + + private resolveUrl(endpoint: McpEndpointConfig): string { + if ( + endpoint.url.startsWith("http://") || + endpoint.url.startsWith("https://") + ) { + return endpoint.url; + } + return `${this.workspaceHost}${endpoint.url}`; + } + + async connect(endpoint: McpEndpointConfig): Promise { + const resolvedUrl = this.resolveUrl(endpoint); + const check = checkMcpUrl(resolvedUrl, this.policy); + if (!check.ok) { + throw new Error( + `MCP endpoint '${endpoint.name}' refused at connect: ${check.reason}`, + ); + } + await assertResolvedHostSafe( + check.url.hostname, + this.policy, + this.options.dnsLookup, + ); + + logger.info( + "Connecting to MCP server: %s at %s (forwardWorkspaceAuth=%s)", + endpoint.name, 
+ resolvedUrl, + check.forwardWorkspaceAuth, + ); + + const initResponse = await this.sendRpc( + resolvedUrl, + "initialize", + { + protocolVersion: "2025-03-26", + capabilities: {}, + clientInfo: { name: "appkit-agent", version: "0.1.0" }, + }, + { forwardWorkspaceAuth: check.forwardWorkspaceAuth }, + ); + + if (initResponse.sessionId) { + this.sessionIds.set(endpoint.name, initResponse.sessionId); + } + const sessionId = this.sessionIds.get(endpoint.name); + + await this.sendNotification(resolvedUrl, "notifications/initialized", { + sessionId, + forwardWorkspaceAuth: check.forwardWorkspaceAuth, + }); + + const listResponse = await this.sendRpc( + resolvedUrl, + "tools/list", + {}, + { sessionId, forwardWorkspaceAuth: check.forwardWorkspaceAuth }, + ); + const toolList = + (listResponse.result as { tools?: McpToolSchema[] })?.tools ?? []; + + const tools = new Map(); + for (const tool of toolList) { + tools.set(tool.name, tool); + } + + this.connections.set(endpoint.name, { + config: endpoint, + resolvedUrl, + forwardWorkspaceAuth: check.forwardWorkspaceAuth, + tools, + }); + logger.info( + "Connected to MCP server %s: %d tools available", + endpoint.name, + tools.size, + ); + } + + getAllToolDefinitions(): AgentToolDefinition[] { + const defs: AgentToolDefinition[] = []; + for (const [serverName, conn] of this.connections) { + for (const [toolName, schema] of conn.tools) { + defs.push({ + name: `mcp.${serverName}.${toolName}`, + description: schema.description ?? toolName, + parameters: + (schema.inputSchema as AgentToolDefinition["parameters"]) ?? { + type: "object", + properties: {}, + }, + }); + } + } + return defs; + } + + /** + * Whether the named MCP server may receive workspace-scoped auth headers + * (e.g., an OBO bearer token from an end-user request). Callers should gate + * auth-forwarding decisions on this to prevent credential exfiltration to + * non-workspace hosts. 
+ */ + canForwardWorkspaceAuth(serverName: string): boolean { + return this.connections.get(serverName)?.forwardWorkspaceAuth ?? false; + } + + async callTool( + qualifiedName: string, + args: unknown, + authHeaders?: Record, + callerSignal?: AbortSignal, + ): Promise { + const parts = qualifiedName.split("."); + if (parts.length < 3 || parts[0] !== "mcp") { + throw new Error(`Invalid MCP tool name: ${qualifiedName}`); + } + const serverName = parts[1]; + const toolName = parts.slice(2).join("."); + + const conn = this.connections.get(serverName); + if (!conn) { + throw new Error(`MCP server not connected: ${serverName}`); + } + + const sessionId = this.sessionIds.get(serverName); + // authHeaders are caller-supplied credentials (typically the OBO token). + // Only honor them if the destination URL was admitted with + // forwardWorkspaceAuth=true at connect time. + const scopedAuthOverride = conn.forwardWorkspaceAuth + ? authHeaders + : undefined; + + const rpcResult = await this.sendRpc( + conn.resolvedUrl, + "tools/call", + { name: toolName, arguments: args }, + { + authOverride: scopedAuthOverride, + sessionId, + forwardWorkspaceAuth: conn.forwardWorkspaceAuth, + callerSignal, + }, + ); + const result = rpcResult.result as McpToolCallResult; + + if (result.isError) { + const errText = (result.content ?? []) + .filter((c) => c.type === "text") + .map((c) => c.text) + .join("\n"); + throw new Error(errText || "MCP tool call failed"); + } + + return (result.content ?? []) + .filter((c) => c.type === "text") + .map((c) => c.text) + .join("\n"); + } + + async close(): Promise { + this.closed = true; + this.connections.clear(); + this.sessionIds.clear(); + } + + private async sendRpc( + url: string, + method: string, + params?: Record, + options?: { + authOverride?: Record; + sessionId?: string; + forwardWorkspaceAuth?: boolean; + /** + * Optional external abort signal (typically the agent's stream signal). 
+ * Composed with the built-in 30 s timeout so `/cancel` or agent-run + * shutdown immediately propagates to the MCP fetch rather than waiting + * for the remote server to respond. + */ + callerSignal?: AbortSignal; + }, + ): Promise<{ result: unknown; sessionId?: string }> { + if (this.closed) throw new Error("MCP client is closed"); + + const request: JsonRpcRequest = { + jsonrpc: "2.0", + id: ++this.requestId, + method, + ...(params && { params }), + }; + + const authHeaders = await this.resolveAuthHeaders(options); + const headers: Record = { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + ...authHeaders, + }; + if (options?.sessionId) { + headers["Mcp-Session-Id"] = options.sessionId; + } + + const fetchImpl = this.options.fetchImpl ?? fetch; + const signals: AbortSignal[] = [AbortSignal.timeout(30_000)]; + if (options?.callerSignal) signals.push(options.callerSignal); + const response = await fetchImpl(url, { + method: "POST", + headers, + body: JSON.stringify(request), + signal: signals.length > 1 ? AbortSignal.any(signals) : signals[0], + }); + + if (!response.ok) { + throw new Error( + `MCP request to ${method} failed: ${response.status} ${response.statusText}`, + ); + } + + const contentType = response.headers.get("content-type") ?? ""; + let json: JsonRpcResponse; + + if (contentType.includes("text/event-stream")) { + const text = await response.text(); + const lastData = text + .split("\n") + .filter((line) => line.startsWith("data: ")) + .map((line) => line.slice(6)) + .pop(); + if (!lastData) { + throw new Error(`MCP SSE response for ${method} contained no data`); + } + json = JSON.parse(lastData) as JsonRpcResponse; + } else { + json = (await response.json()) as JsonRpcResponse; + } + + if (json.error) { + throw new Error(`MCP error (${json.error.code}): ${json.error.message}`); + } + + const sid = response.headers.get("mcp-session-id") ?? 
undefined; + return { result: json.result, sessionId: sid }; + } + + private async sendNotification( + url: string, + method: string, + options?: { + sessionId?: string; + forwardWorkspaceAuth?: boolean; + }, + ): Promise { + if (this.closed) return; + + const authHeaders = await this.resolveAuthHeaders(options); + const headers: Record = { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + ...authHeaders, + }; + if (options?.sessionId) { + headers["Mcp-Session-Id"] = options.sessionId; + } + + const fetchImpl = this.options.fetchImpl ?? fetch; + await fetchImpl(url, { + method: "POST", + headers, + body: JSON.stringify({ jsonrpc: "2.0", method }), + signal: AbortSignal.timeout(30_000), + }); + } + + /** + * Return the auth headers to send on an outbound request. Workspace auth + * (SP or OBO) is only resolved when `forwardWorkspaceAuth` is true; for + * non-workspace hosts no bearer token is attached. + */ + private async resolveAuthHeaders(options?: { + authOverride?: Record; + forwardWorkspaceAuth?: boolean; + }): Promise> { + if (!options?.forwardWorkspaceAuth) return {}; + if (options.authOverride) return options.authOverride; + return this.authenticate(); + } +} diff --git a/packages/appkit/src/connectors/mcp/host-policy.ts b/packages/appkit/src/connectors/mcp/host-policy.ts new file mode 100644 index 00000000..d970c83a --- /dev/null +++ b/packages/appkit/src/connectors/mcp/host-policy.ts @@ -0,0 +1,299 @@ +import { lookup as defaultLookup } from "node:dns/promises"; +import { isIP, isIPv4 } from "node:net"; + +/** + * DNS lookup function compatible with `dns/promises.lookup(host, { all: true })`. + * Exposed as an injection point so callers (tests, custom DNS resolvers) can + * override the default resolver. 
+ */ +export type DnsLookup = ( + hostname: string, + options: { all: true }, +) => Promise>; + +/** + * Policy that decides whether a given MCP endpoint URL is allowed and whether + * Databricks workspace credentials (SP or OBO) may be forwarded to it. + * + * The default posture is zero-trust: only same-origin workspace URLs receive + * workspace credentials, and all other destinations must be explicitly + * allowlisted by the application developer. Private / link-local IP ranges + * are blocked outright to prevent SSRF into cloud metadata services. + */ +export interface McpHostPolicy { + /** Lowercased hostname of the Databricks workspace (same-origin target). */ + readonly workspaceHostname: string; + /** Additional allowlisted hostnames (lowercased). Workspace auth is NEVER forwarded to these. */ + readonly trustedHosts: ReadonlySet; + /** Permit `http://localhost`, `127.0.0.1`, `::1` URLs. Typically true only in development. */ + readonly allowLocalhost: boolean; +} + +/** + * Config shape accepted by {@link buildMcpHostPolicy}, matching the + * `mcp` field on `AgentsPluginConfig`. + */ +export interface McpHostPolicyConfig { + /** + * Additional hostnames that may host custom MCP servers beyond the same-origin + * workspace. Compared case-insensitively; bare hostnames only (no scheme or + * path). Workspace credentials (SP / OBO) are never forwarded to these hosts — + * they must handle authentication themselves. + */ + trustedHosts?: string[]; + /** + * Allow `http://localhost`, `127.0.0.1`, and `::1` MCP URLs for local + * development. Defaults to `true` when `NODE_ENV !== "production"`, + * otherwise `false`. Workspace credentials are never forwarded to localhost. + */ + allowLocalhost?: boolean; +} + +/** Build an {@link McpHostPolicy} from user config + the resolved workspace URL. 
*/ +export function buildMcpHostPolicy( + config: McpHostPolicyConfig | undefined, + workspaceHost: string, +): McpHostPolicy { + const workspaceHostname = safeHostname(workspaceHost); + if (!workspaceHostname) { + throw new Error( + `Invalid workspace host for MCP policy: ${JSON.stringify(workspaceHost)}`, + ); + } + const trustedHosts = new Set( + (config?.trustedHosts ?? []).map((h) => h.trim().toLowerCase()), + ); + const allowLocalhost = + config?.allowLocalhost ?? process.env.NODE_ENV !== "production"; + return { workspaceHostname, trustedHosts, allowLocalhost }; +} + +type McpUrlCheck = + | { + readonly ok: true; + /** Whether it is safe to forward workspace-scoped credentials (SP/OBO) to this URL. */ + readonly forwardWorkspaceAuth: boolean; + /** Parsed URL for reuse by the caller. */ + readonly url: URL; + } + | { readonly ok: false; readonly reason: string }; + +/** + * Synchronously decide whether an MCP URL is allowed under the given policy + * and whether workspace credentials may be forwarded to it. + * + * Hard rejections: + * - Non-`http(s)` schemes. + * - `http://` unless the host is localhost AND `allowLocalhost` is true. + * - Hosts that are neither same-origin workspace, localhost (if allowed), + * nor in `trustedHosts`. + */ +export function checkMcpUrl( + rawUrl: string, + policy: McpHostPolicy, +): McpUrlCheck { + let url: URL; + try { + url = new URL(rawUrl); + } catch { + return { + ok: false, + reason: `MCP URL is not a valid absolute URL: ${rawUrl}`, + }; + } + + if (url.protocol !== "http:" && url.protocol !== "https:") { + return { + ok: false, + reason: `MCP URL scheme '${url.protocol}' is not allowed (http(s) only): ${rawUrl}`, + }; + } + + const host = url.hostname.toLowerCase(); + const isLoopback = isLoopbackHost(host); + + if (url.protocol === "http:" && !(isLoopback && policy.allowLocalhost)) { + return { + ok: false, + reason: `MCP URL uses plaintext http:// which forwards bearer tokens in cleartext: ${rawUrl}. 
Use https:// or enable allowLocalhost for a localhost dev server.`, + }; + } + + if (host === policy.workspaceHostname) { + return { ok: true, forwardWorkspaceAuth: true, url }; + } + + if (isLoopback) { + if (!policy.allowLocalhost) { + return { + ok: false, + reason: `MCP URL points to localhost but allowLocalhost is disabled: ${rawUrl}`, + }; + } + return { ok: true, forwardWorkspaceAuth: false, url }; + } + + if (policy.trustedHosts.has(host)) { + return { ok: true, forwardWorkspaceAuth: false, url }; + } + + return { + ok: false, + reason: `MCP host '${host}' is not allowed. Either use a same-origin workspace URL (${policy.workspaceHostname}) or add it to agents({ mcp: { trustedHosts: ['${host}'] } }).`, + }; +} + +/** + * Resolve `hostname` via DNS and assert that none of its addresses fall in a + * blocked IP range (loopback, RFC1918, link-local, CGNAT, cloud metadata). + * + * Throws with a descriptive error if any resolved address is blocked. Pass + * `allowLocalhost: true` to permit `127.0.0.1` / `::1` specifically. + * + * Note: this only guards against hosts that statically resolve to private + * ranges. Full SSRF protection requires socket-level IP pinning after + * resolution (DNS rebinding defense), which is out of scope here. 
+ */ +export async function assertResolvedHostSafe( + hostname: string, + policy: McpHostPolicy, + lookup: DnsLookup = defaultLookup, +): Promise { + const lowered = hostname.toLowerCase(); + + if (isIP(lowered)) { + if (isBlockedIp(lowered, policy.allowLocalhost)) { + throw new Error(`MCP host ${lowered} is in a blocked IP range`); + } + return; + } + + if (lowered === "localhost") { + if (!policy.allowLocalhost) { + throw new Error( + `MCP host localhost is not allowed under the current policy`, + ); + } + return; + } + + let resolved: Array<{ address: string }>; + try { + resolved = await lookup(hostname, { all: true }); + } catch (cause) { + throw new Error( + `MCP host ${hostname} could not be resolved via DNS: ${cause instanceof Error ? cause.message : String(cause)}`, + ); + } + + if (resolved.length === 0) { + throw new Error(`MCP host ${hostname} returned no DNS addresses`); + } + + for (const { address } of resolved) { + if (isBlockedIp(address, policy.allowLocalhost)) { + throw new Error( + `MCP host ${hostname} resolved to blocked address ${address} (private / link-local ranges are not allowed)`, + ); + } + } +} + +/** Whether a raw hostname literal is one of the recognised loopback aliases. */ +export function isLoopbackHost(host: string): boolean { + const lowered = host.toLowerCase(); + return ( + lowered === "localhost" || + lowered === "127.0.0.1" || + lowered === "::1" || + lowered === "[::1]" || + lowered === "0:0:0:0:0:0:0:1" + ); +} + +/** + * Check whether a resolved IP address is in a range that should never receive + * workspace credentials. `allowLocalhost` carves out 127.0.0.0/8 and ::1. + */ +export function isBlockedIp(address: string, allowLocalhost: boolean): boolean { + if (isIPv4(address)) { + return isBlockedIpv4(address, allowLocalhost); + } + if (isIP(address) === 6) { + return isBlockedIpv6(address, allowLocalhost); + } + // Not a recognisable IP literal — fail-closed. 
+ return true; +} + +function isBlockedIpv4(addr: string, allowLocalhost: boolean): boolean { + const parts = addr.split(".").map((p) => Number.parseInt(p, 10)); + if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) { + return true; + } + const [a, b] = parts; + if (a === 0) return true; + if (a === 127) return !allowLocalhost; + if (a === 10) return true; + if (a === 172 && b >= 16 && b <= 31) return true; + if (a === 192 && b === 168) return true; + if (a === 169 && b === 254) return true; + if (a === 100 && b >= 64 && b <= 127) return true; + if (a >= 224) return true; + return false; +} + +function isBlockedIpv6(addr: string, allowLocalhost: boolean): boolean { + const lowered = addr.toLowerCase().replace(/^\[|\]$/g, ""); + + if (lowered === "::") return true; + if (lowered === "::1" || lowered === "0:0:0:0:0:0:0:1") + return !allowLocalhost; + + // IPv4-mapped IPv6: `::ffff:` may be written in dotted form + // (`::ffff:169.254.169.254`) or colon-hex form (`::ffff:a9fe:a9fe`). Both + // route to the same destination, so we must normalise before delegating + // to the IPv4 blocklist. + if (lowered.startsWith("::ffff:")) { + const tail = lowered.slice("::ffff:".length); + if (isIPv4(tail)) return isBlockedIpv4(tail, allowLocalhost); + const hexV4 = hexPairToDottedIpv4(tail); + if (hexV4) return isBlockedIpv4(hexV4, allowLocalhost); + } + + // Unique Local Addresses (fc00::/7) — `fc` and `fd` only. + if (/^f[cd][0-9a-f]{2}:/.test(lowered)) return true; + // Link-local fe80::/10 — the first 10 bits are 1111111010, i.e. the + // second hex nibble must be 8-b. Matches fe80:..–febf:.. + if (/^fe[89ab][0-9a-f]:/.test(lowered)) return true; + // Multicast ff00::/8. + if (lowered.startsWith("ff")) return true; + return false; +} + +/** + * Parse the trailing two hex groups of an IPv4-mapped IPv6 address written + * in colon-hex form (e.g. `a9fe:a9fe`) into the equivalent dotted-quad IPv4 + * representation (`169.254.169.254`). 
Returns null for anything else. + */ +function hexPairToDottedIpv4(tail: string): string | null { + const match = tail.match(/^([0-9a-f]{1,4}):([0-9a-f]{1,4})$/); + if (!match) return null; + const hi = Number.parseInt(match[1], 16); + const lo = Number.parseInt(match[2], 16); + if (!Number.isFinite(hi) || !Number.isFinite(lo)) return null; + if (hi < 0 || hi > 0xffff || lo < 0 || lo > 0xffff) return null; + const a = (hi >> 8) & 0xff; + const b = hi & 0xff; + const c = (lo >> 8) & 0xff; + const d = lo & 0xff; + return `${a}.${b}.${c}.${d}`; +} + +function safeHostname(rawUrl: string): string | null { + try { + return new URL(rawUrl).hostname.toLowerCase(); + } catch { + return null; + } +} diff --git a/packages/appkit/src/connectors/mcp/index.ts b/packages/appkit/src/connectors/mcp/index.ts new file mode 100644 index 00000000..f9f32a41 --- /dev/null +++ b/packages/appkit/src/connectors/mcp/index.ts @@ -0,0 +1,6 @@ +export { AppKitMcpClient } from "./client"; +export { + buildMcpHostPolicy, + type McpHostPolicyConfig, +} from "./host-policy"; +export type { McpEndpointConfig } from "./types"; diff --git a/packages/appkit/src/connectors/mcp/tests/client.test.ts b/packages/appkit/src/connectors/mcp/tests/client.test.ts new file mode 100644 index 00000000..0cdffa29 --- /dev/null +++ b/packages/appkit/src/connectors/mcp/tests/client.test.ts @@ -0,0 +1,402 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { AppKitMcpClient } from "../client"; +import type { DnsLookup, McpHostPolicy } from "../host-policy"; + +const WORKSPACE = "https://test-workspace.cloud.databricks.com"; + +const workspacePolicy: McpHostPolicy = { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(), + allowLocalhost: false, +}; + +const trustedExternalPolicy: McpHostPolicy = { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(["mcp.example.com"]), + allowLocalhost: false, +}; + +const publicDnsLookup: 
DnsLookup = async () => [ + { address: "203.0.113.42", family: 4 }, +]; + +const workspaceAuth = async (): Promise> => ({ + Authorization: "Bearer SP-TOKEN", +}); + +type FetchCall = { + url: string; + init: RequestInit; +}; + +function recordingFetch( + responders: Array<(call: FetchCall) => Response | Promise>, +) { + const calls: FetchCall[] = []; + let n = 0; + const fetchImpl: typeof fetch = async (input, init) => { + const url = typeof input === "string" ? input : (input as URL).toString(); + const call: FetchCall = { url, init: init ?? {} }; + calls.push(call); + const responder = responders[n++] ?? responders[responders.length - 1]; + return Promise.resolve(responder(call)); + }; + return { fetchImpl, calls }; +} + +function jsonResponse(body: unknown, headers: Record = {}) { + return new Response(JSON.stringify(body), { + status: 200, + headers: { "content-type": "application/json", ...headers }, + }); +} + +describe("AppKitMcpClient — host allowlist", () => { + let authSpy: ReturnType; + + beforeEach(() => { + authSpy = vi.fn(workspaceAuth); + }); + + test("connect rejects a URL whose host is not allowlisted without making any fetch", async () => { + const { fetchImpl, calls } = recordingFetch([() => jsonResponse({})]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, workspacePolicy, { + fetchImpl, + dnsLookup: publicDnsLookup, + }); + await expect( + client.connect({ name: "evil", url: "https://attacker.example.com/mcp" }), + ).rejects.toThrow(/attacker\.example\.com/); + expect(calls).toHaveLength(0); + expect(authSpy).not.toHaveBeenCalled(); + }); + + test("connect rejects plaintext http:// for remote hosts", async () => { + const { fetchImpl, calls } = recordingFetch([() => jsonResponse({})]); + const client = new AppKitMcpClient( + WORKSPACE, + authSpy, + trustedExternalPolicy, + { fetchImpl, dnsLookup: publicDnsLookup }, + ); + await expect( + client.connect({ name: "plain", url: "http://mcp.example.com/mcp" }), + 
).rejects.toThrow(/plaintext http/); + expect(calls).toHaveLength(0); + expect(authSpy).not.toHaveBeenCalled(); + }); + + test("connect rejects a URL whose DNS resolves to a blocked IP and never sends SP token", async () => { + const ssrfLookup: DnsLookup = async () => [ + { address: "169.254.169.254", family: 4 }, + ]; + const policy: McpHostPolicy = { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(["evil.example.com"]), + allowLocalhost: false, + }; + const { fetchImpl, calls } = recordingFetch([() => jsonResponse({})]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, policy, { + fetchImpl, + dnsLookup: ssrfLookup, + }); + await expect( + client.connect({ name: "evil", url: "https://evil.example.com/mcp" }), + ).rejects.toThrow(/169\.254\.169\.254/); + expect(calls).toHaveLength(0); + expect(authSpy).not.toHaveBeenCalled(); + }); + + test("connect to same-origin workspace forwards SP token on initialize + tools/list", async () => { + const { fetchImpl, calls } = recordingFetch([ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "echo", description: "Echo" }] }, + }), + ]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, workspacePolicy, { + fetchImpl, + dnsLookup: publicDnsLookup, + }); + + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + // initialize + notifications/initialized + tools/list all carry SP token + expect(calls.map((c) => c.url)).toEqual([ + `${WORKSPACE}/api/2.0/mcp/genie/abc`, + `${WORKSPACE}/api/2.0/mcp/genie/abc`, + `${WORKSPACE}/api/2.0/mcp/genie/abc`, + ]); + for (const call of calls) { + const headers = call.init.headers as Record; + expect(headers.Authorization).toBe("Bearer SP-TOKEN"); + } + 
expect(client.canForwardWorkspaceAuth("genie-1")).toBe(true); + }); + + test("connect to trusted external host does NOT forward SP token on any RPC", async () => { + const { fetchImpl, calls } = recordingFetch([ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "help" }] }, + }), + ]); + const client = new AppKitMcpClient( + WORKSPACE, + authSpy, + trustedExternalPolicy, + { fetchImpl, dnsLookup: publicDnsLookup }, + ); + + await client.connect({ name: "ext", url: "https://mcp.example.com/mcp" }); + + for (const call of calls) { + const headers = call.init.headers as Record; + expect(headers.Authorization).toBeUndefined(); + } + expect(authSpy).not.toHaveBeenCalled(); + expect(client.canForwardWorkspaceAuth("ext")).toBe(false); + }); +}); + +describe("AppKitMcpClient — callTool auth scoping", () => { + test("drops caller-supplied OBO token when destination is not workspace-origin", async () => { + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "do" }] }, + }), + ]; + const callResponder = () => + jsonResponse({ + jsonrpc: "2.0", + id: 4, + result: { content: [{ type: "text", text: "ok" }] }, + }); + const { fetchImpl, calls } = recordingFetch([ + ...connectResponders, + callResponder, + ]); + const client = new AppKitMcpClient( + WORKSPACE, + workspaceAuth, + trustedExternalPolicy, + { fetchImpl, dnsLookup: publicDnsLookup }, + ); + await client.connect({ name: "ext", url: "https://mcp.example.com/mcp" }); + + const output = await client.callTool( + "mcp.ext.do", + { x: 1 }, + { + Authorization: "Bearer OBO-USER-TOKEN", + }, + ); + 
expect(output).toBe("ok"); + + const toolCall = calls[calls.length - 1]; + const headers = toolCall.init.headers as Record; + expect(headers.Authorization).toBeUndefined(); + }); + + test("forwards caller-supplied OBO token when destination is workspace-origin", async () => { + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "do" }] }, + }), + ]; + const callResponder = () => + jsonResponse({ + jsonrpc: "2.0", + id: 4, + result: { content: [{ type: "text", text: "ok" }] }, + }); + const { fetchImpl, calls } = recordingFetch([ + ...connectResponders, + callResponder, + ]); + const client = new AppKitMcpClient( + WORKSPACE, + workspaceAuth, + workspacePolicy, + { + fetchImpl, + dnsLookup: publicDnsLookup, + }, + ); + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + await client.callTool( + "mcp.genie-1.do", + {}, + { + Authorization: "Bearer OBO-USER-TOKEN", + }, + ); + + const toolCall = calls[calls.length - 1]; + const headers = toolCall.init.headers as Record; + expect(headers.Authorization).toBe("Bearer OBO-USER-TOKEN"); + }); + + test("falls back to SP auth when no OBO override is provided and destination is workspace", async () => { + const authSpy = vi.fn(workspaceAuth); + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "do" }] }, + }), + ]; + const callResponder = () => + jsonResponse({ + jsonrpc: "2.0", + id: 4, + result: { content: [{ type: "text", text: "ok" }] }, + }); + const { fetchImpl, calls } = recordingFetch([ + ...connectResponders, + callResponder, + ]); + 
const client = new AppKitMcpClient(WORKSPACE, authSpy, workspacePolicy, { + fetchImpl, + dnsLookup: publicDnsLookup, + }); + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + await client.callTool("mcp.genie-1.do", {}, undefined); + + const toolCall = calls[calls.length - 1]; + const headers = toolCall.init.headers as Record; + expect(headers.Authorization).toBe("Bearer SP-TOKEN"); + }); +}); + +describe("AppKitMcpClient — caller abort signal composition", () => { + test("callTool's fetch aborts when the caller signal fires", async () => { + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { "mcp-session-id": "sess-1" }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "slow" }] }, + }), + ]; + const callResponder = (call: FetchCall): Promise => { + const signal = call.init.signal as AbortSignal | undefined; + return new Promise((_, reject) => { + if (signal?.aborted) { + reject( + new DOMException( + signal.reason?.toString() ?? "aborted", + "AbortError", + ), + ); + return; + } + signal?.addEventListener( + "abort", + () => { + reject( + new DOMException( + signal.reason?.toString() ?? "aborted", + "AbortError", + ), + ); + }, + { once: true }, + ); + }); + }; + const { fetchImpl } = recordingFetch([...connectResponders, callResponder]); + const client = new AppKitMcpClient( + WORKSPACE, + workspaceAuth, + workspacePolicy, + { + fetchImpl, + dnsLookup: publicDnsLookup, + }, + ); + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + const controller = new AbortController(); + const pending = client + .callTool("mcp.genie-1.slow", {}, undefined, controller.signal) + .catch((e) => e); + // Let the fetch start + register its abort listener before we abort. 
+ await new Promise((r) => setTimeout(r, 10)); + controller.abort(new Error("user cancelled")); + const error = (await pending) as Error; + expect(error).toBeInstanceOf(Error); + expect(error.name).toBe("AbortError"); + }); +}); diff --git a/packages/appkit/src/connectors/mcp/tests/host-policy.test.ts b/packages/appkit/src/connectors/mcp/tests/host-policy.test.ts new file mode 100644 index 00000000..451536ed --- /dev/null +++ b/packages/appkit/src/connectors/mcp/tests/host-policy.test.ts @@ -0,0 +1,354 @@ +import { describe, expect, test, vi } from "vitest"; +import { + assertResolvedHostSafe, + buildMcpHostPolicy, + checkMcpUrl, + type DnsLookup, + isBlockedIp, + isLoopbackHost, + type McpHostPolicy, + type McpHostPolicyConfig, +} from "../host-policy"; + +function stubLookup( + addresses: Array<{ address: string; family?: number }>, +): DnsLookup { + return vi + .fn() + .mockResolvedValue(addresses.map((a) => ({ family: 4, ...a }))); +} + +function failingLookup(message: string): DnsLookup { + return vi.fn().mockRejectedValue(new Error(message)); +} + +const WORKSPACE = "https://test-workspace.cloud.databricks.com"; + +function policy(overrides: Partial = {}): McpHostPolicy { + return { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(), + allowLocalhost: false, + ...overrides, + }; +} + +describe("buildMcpHostPolicy", () => { + test("extracts hostname from workspace URL", () => { + const p = buildMcpHostPolicy(undefined, WORKSPACE); + expect(p.workspaceHostname).toBe("test-workspace.cloud.databricks.com"); + }); + + test("lowercases and trims trustedHosts", () => { + const p = buildMcpHostPolicy( + { trustedHosts: ["Example.COM", " corp.internal ", "mcp.example.com"] }, + WORKSPACE, + ); + expect(p.trustedHosts).toEqual( + new Set(["example.com", "corp.internal", "mcp.example.com"]), + ); + }); + + test("allowLocalhost defaults to false in production", () => { + const prev = process.env.NODE_ENV; + process.env.NODE_ENV = 
"production"; + try { + const p = buildMcpHostPolicy(undefined, WORKSPACE); + expect(p.allowLocalhost).toBe(false); + } finally { + process.env.NODE_ENV = prev; + } + }); + + test("allowLocalhost defaults to true outside production", () => { + const prev = process.env.NODE_ENV; + process.env.NODE_ENV = "development"; + try { + const p = buildMcpHostPolicy(undefined, WORKSPACE); + expect(p.allowLocalhost).toBe(true); + } finally { + process.env.NODE_ENV = prev; + } + }); + + test("allowLocalhost respects explicit override", () => { + const prev = process.env.NODE_ENV; + process.env.NODE_ENV = "production"; + try { + const cfg: McpHostPolicyConfig = { allowLocalhost: true }; + const p = buildMcpHostPolicy(cfg, WORKSPACE); + expect(p.allowLocalhost).toBe(true); + } finally { + process.env.NODE_ENV = prev; + } + }); + + test("throws on invalid workspace host", () => { + expect(() => buildMcpHostPolicy(undefined, "not-a-url")).toThrow( + /Invalid workspace host/, + ); + }); +}); + +describe("checkMcpUrl", () => { + test("admits same-origin workspace https URL and forwards auth", () => { + const result = checkMcpUrl(`${WORKSPACE}/api/2.0/mcp/genie/abc`, policy()); + expect(result.ok).toBe(true); + if (result.ok) expect(result.forwardWorkspaceAuth).toBe(true); + }); + + test("admits trusted host but does NOT forward workspace auth", () => { + const p = policy({ trustedHosts: new Set(["mcp.example.com"]) }); + const result = checkMcpUrl("https://mcp.example.com/mcp", p); + expect(result.ok).toBe(true); + if (result.ok) expect(result.forwardWorkspaceAuth).toBe(false); + }); + + test("rejects host that is neither workspace nor trusted", () => { + const result = checkMcpUrl("https://attacker.example.com/mcp", policy()); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.reason).toMatch(/attacker\.example\.com/); + expect(result.reason).toMatch(/trustedHosts/); + } + }); + + test("rejects plaintext http:// for remote hosts even when trusted", () => { + const 
p = policy({ trustedHosts: new Set(["mcp.example.com"]) }); + const result = checkMcpUrl("http://mcp.example.com/mcp", p); + expect(result.ok).toBe(false); + if (!result.ok) expect(result.reason).toMatch(/plaintext http/); + }); + + test("rejects plaintext http://localhost when allowLocalhost is false", () => { + const result = checkMcpUrl("http://localhost:4000/mcp", policy()); + expect(result.ok).toBe(false); + }); + + test("admits http://localhost when allowLocalhost is true, no workspace auth", () => { + const p = policy({ allowLocalhost: true }); + const result = checkMcpUrl("http://localhost:4000/mcp", p); + expect(result.ok).toBe(true); + if (result.ok) expect(result.forwardWorkspaceAuth).toBe(false); + }); + + test("admits http://127.0.0.1 when allowLocalhost is true", () => { + const p = policy({ allowLocalhost: true }); + const result = checkMcpUrl("http://127.0.0.1:4000/mcp", p); + expect(result.ok).toBe(true); + if (result.ok) expect(result.forwardWorkspaceAuth).toBe(false); + }); + + test("rejects non-http(s) schemes", () => { + for (const url of [ + "file:///etc/passwd", + "ftp://host/x", + "gopher://host/x", + "javascript:alert(1)", + ]) { + const result = checkMcpUrl(url, policy()); + expect(result.ok).toBe(false); + } + }); + + test("rejects obviously invalid URLs", () => { + const result = checkMcpUrl("not-a-url", policy()); + expect(result.ok).toBe(false); + }); + + test("hostname comparison is case-insensitive", () => { + const result = checkMcpUrl( + "https://TEST-Workspace.CLOUD.Databricks.com/mcp", + policy(), + ); + expect(result.ok).toBe(true); + if (result.ok) expect(result.forwardWorkspaceAuth).toBe(true); + }); + + test("rejects same hostname on different scheme (http) even for workspace", () => { + const result = checkMcpUrl( + "http://test-workspace.cloud.databricks.com/mcp", + policy(), + ); + expect(result.ok).toBe(false); + }); +}); + +describe("isBlockedIp", () => { + test("blocks RFC1918 IPv4 ranges", () => { + for (const addr of 
[ + "10.0.0.1", + "10.255.255.255", + "172.16.0.1", + "172.31.255.255", + "192.168.0.1", + "192.168.255.255", + ]) { + expect(isBlockedIp(addr, true)).toBe(true); + } + }); + + test("blocks link-local 169.254.0.0/16 (covers cloud metadata 169.254.169.254)", () => { + expect(isBlockedIp("169.254.169.254", true)).toBe(true); + expect(isBlockedIp("169.254.0.1", true)).toBe(true); + }); + + test("blocks CGNAT 100.64.0.0/10", () => { + expect(isBlockedIp("100.64.0.1", true)).toBe(true); + expect(isBlockedIp("100.127.255.255", true)).toBe(true); + }); + + test("blocks 0.0.0.0/8 and multicast/reserved (>= 224.0.0.0)", () => { + expect(isBlockedIp("0.0.0.0", true)).toBe(true); + expect(isBlockedIp("0.1.2.3", true)).toBe(true); + expect(isBlockedIp("224.0.0.1", true)).toBe(true); + expect(isBlockedIp("255.255.255.255", true)).toBe(true); + }); + + test("blocks loopback when allowLocalhost is false", () => { + expect(isBlockedIp("127.0.0.1", false)).toBe(true); + expect(isBlockedIp("127.1.2.3", false)).toBe(true); + expect(isBlockedIp("::1", false)).toBe(true); + }); + + test("permits loopback when allowLocalhost is true", () => { + expect(isBlockedIp("127.0.0.1", true)).toBe(false); + expect(isBlockedIp("::1", true)).toBe(false); + }); + + test("blocks ULA (fc00::/7) and link-local (fe80::/10) IPv6", () => { + expect(isBlockedIp("fc00::1", true)).toBe(true); + expect(isBlockedIp("fd00::1", true)).toBe(true); + expect(isBlockedIp("fe80::1", true)).toBe(true); + }); + + test("blocks the full link-local /10 range fe80::–febf:: (regression: fea0/feb0)", () => { + // fe80::/10 spans 1111 1110 10.. — first hex pair `fe` + second nibble 8..b. 
+ for (const addr of [ + "fe80::1", + "fe90::1", + "fea0::1", // regression: was passing the filter before + "feaf::1", // regression + "feb0::1", // regression + "febf::1", // regression + ]) { + expect(isBlockedIp(addr, true)).toBe(true); + } + // Outside /10 must not be blocked by this rule (belongs to routable-ish + // experimental ranges; nothing else in the module should match either). + expect(isBlockedIp("fec0::1", true)).toBe(false); + }); + + test("blocks IPv4-mapped IPv6 addresses in blocked ranges (dotted form)", () => { + expect(isBlockedIp("::ffff:169.254.169.254", true)).toBe(true); + expect(isBlockedIp("::ffff:10.0.0.1", true)).toBe(true); + }); + + test("blocks IPv4-mapped IPv6 addresses in colon-hex form (regression)", () => { + // ::ffff:a9fe:a9fe is the same destination as ::ffff:169.254.169.254. + // Before the fix this form slipped past the IPv4-mapped branch because + // isIPv4("a9fe:a9fe") is false and no other v6 rule matched. + expect(isBlockedIp("::ffff:a9fe:a9fe", true)).toBe(true); // 169.254.169.254 + expect(isBlockedIp("::ffff:0a00:0001", true)).toBe(true); // 10.0.0.1 + expect(isBlockedIp("::ffff:c0a8:0001", true)).toBe(true); // 192.168.0.1 + // A public IPv4 mapped to colon-hex must still pass through: 8.8.8.8 = 0808:0808 + expect(isBlockedIp("::ffff:0808:0808", true)).toBe(false); + }); + + test("allows public IPv4 and IPv6 addresses", () => { + expect(isBlockedIp("8.8.8.8", false)).toBe(false); + expect(isBlockedIp("1.1.1.1", false)).toBe(false); + expect(isBlockedIp("2001:4860:4860::8888", false)).toBe(false); + }); + + test("treats malformed IP strings as blocked (fail-closed)", () => { + expect(isBlockedIp("10.0.0", true)).toBe(true); + expect(isBlockedIp("abc.def.ghi.jkl", true)).toBe(true); + }); +}); + +describe("isLoopbackHost", () => { + test.each([ + "localhost", + "LOCALHOST", + "127.0.0.1", + "::1", + "[::1]", + "0:0:0:0:0:0:0:1", + ])("recognises %s as loopback", (host) => { + expect(isLoopbackHost(host)).toBe(true); 
+ }); + + test("does not match other hosts", () => { + expect(isLoopbackHost("example.com")).toBe(false); + expect(isLoopbackHost("10.0.0.1")).toBe(false); + }); +}); + +describe("assertResolvedHostSafe", () => { + test("passes workspace hostname when resolved address is public", async () => { + const lookup = stubLookup([{ address: "203.0.113.42" }]); + await expect( + assertResolvedHostSafe( + "test-workspace.cloud.databricks.com", + policy(), + lookup, + ), + ).resolves.toBeUndefined(); + expect(lookup).toHaveBeenCalledWith("test-workspace.cloud.databricks.com", { + all: true, + }); + }); + + test("rejects hostname that resolves to link-local cloud metadata IP", async () => { + const lookup = stubLookup([{ address: "169.254.169.254" }]); + await expect( + assertResolvedHostSafe("evil.example.com", policy(), lookup), + ).rejects.toThrow(/169\.254\.169\.254/); + }); + + test("rejects hostname that resolves to RFC1918 IP", async () => { + const lookup = stubLookup([{ address: "10.0.0.1" }]); + await expect( + assertResolvedHostSafe("internal.example.com", policy(), lookup), + ).rejects.toThrow(/10\.0\.0\.1/); + }); + + test("rejects IP literal in blocked range without DNS lookup", async () => { + const lookup = stubLookup([{ address: "8.8.8.8" }]); + await expect( + assertResolvedHostSafe("169.254.169.254", policy(), lookup), + ).rejects.toThrow(/blocked IP range/); + expect(lookup).not.toHaveBeenCalled(); + }); + + test("rejects plain 'localhost' when allowLocalhost is false", async () => { + await expect(assertResolvedHostSafe("localhost", policy())).rejects.toThrow( + /localhost is not allowed/, + ); + }); + + test("surfaces DNS resolution failures", async () => { + const lookup = failingLookup("ENOTFOUND"); + await expect( + assertResolvedHostSafe("nonexistent.example.com", policy(), lookup), + ).rejects.toThrow(/could not be resolved/); + }); + + test("rejects if any resolved address is blocked (defense against split DNS)", async () => { + const lookup = 
stubLookup([ + { address: "8.8.8.8" }, + { address: "169.254.169.254" }, + ]); + await expect( + assertResolvedHostSafe("mixed.example.com", policy(), lookup), + ).rejects.toThrow(/169\.254\.169\.254/); + }); + + test("rejects hostname that resolves to empty DNS result", async () => { + const lookup = stubLookup([]); + await expect( + assertResolvedHostSafe("empty.example.com", policy(), lookup), + ).rejects.toThrow(/no DNS addresses/); + }); +}); diff --git a/packages/appkit/src/connectors/mcp/types.ts b/packages/appkit/src/connectors/mcp/types.ts new file mode 100644 index 00000000..d74f0a46 --- /dev/null +++ b/packages/appkit/src/connectors/mcp/types.ts @@ -0,0 +1,12 @@ +/** + * Input shape consumed by {@link AppKitMcpClient.connect}. Produced by the + * agents plugin from user-facing `HostedTool` declarations (see + * `plugins/agents/tools/hosted-tools.ts`) and accepted directly by the + * connector to keep its surface free of agent-layer concepts. + */ +export interface McpEndpointConfig { + /** Stable logical name used as the `mcp..*` tool prefix and in logs. */ + name: string; + /** Absolute URL (`https://…`) or workspace-relative path (`/api/2.0/mcp/…`). 
*/ + url: string; +} diff --git a/packages/appkit/src/core/appkit.ts b/packages/appkit/src/core/appkit.ts index a2cba994..5d1dd455 100644 --- a/packages/appkit/src/core/appkit.ts +++ b/packages/appkit/src/core/appkit.ts @@ -10,17 +10,24 @@ import type { } from "shared"; import { CacheManager } from "../cache"; import { ServiceContext } from "../context"; +import { createLogger } from "../logging/logger"; import { ResourceRegistry, ResourceType } from "../registry"; import type { TelemetryConfig } from "../telemetry"; import { TelemetryManager } from "../telemetry"; +import { isToolProvider, PluginContext } from "./plugin-context"; + +const logger = createLogger("appkit"); export class AppKit { #pluginInstances: Record = {}; #setupPromises: Promise[] = []; + #context: PluginContext; private constructor(config: { plugins: TPlugins }) { const { plugins, ...globalConfig } = config; + this.#context = new PluginContext(); + const pluginEntries = Object.entries(plugins); const corePlugins = pluginEntries.filter(([_, p]) => { @@ -35,20 +42,24 @@ export class AppKit { for (const [name, pluginData] of corePlugins) { if (pluginData) { - this.createAndRegisterPlugin(globalConfig, name, pluginData); + this.createAndRegisterPlugin(globalConfig, name, pluginData, { + context: this.#context, + }); } } for (const [name, pluginData] of normalPlugins) { if (pluginData) { - this.createAndRegisterPlugin(globalConfig, name, pluginData); + this.createAndRegisterPlugin(globalConfig, name, pluginData, { + context: this.#context, + }); } } for (const [name, pluginData] of deferredPlugins) { if (pluginData) { this.createAndRegisterPlugin(globalConfig, name, pluginData, { - plugins: this.#pluginInstances, + context: this.#context, }); } } @@ -70,8 +81,20 @@ export class AppKit { }; const pluginInstance = new Plugin(baseConfig); + if (typeof pluginInstance.attachContext === "function") { + pluginInstance.attachContext({ + context: this.#context, + telemetryConfig: baseConfig.telemetry, + }); + 
} + this.#pluginInstances[name] = pluginInstance; + this.#context.registerPlugin(name, pluginInstance); + if (isToolProvider(pluginInstance)) { + this.#context.registerToolProvider(name, pluginInstance); + } + this.#setupPromises.push(pluginInstance.setup()); const self = this; @@ -167,6 +190,7 @@ export class AppKit { telemetry?: TelemetryConfig; cache?: CacheConfig; client?: WorkspaceClient; + onPluginsReady?: (appkit: PluginMap) => void | Promise; } = {}, ): Promise> { // Initialize core services @@ -199,8 +223,22 @@ export class AppKit { const instance = new AppKit(mergedConfig); await Promise.all(instance.#setupPromises); + await instance.#context.emitLifecycle("setup:complete"); + + const handle = instance as unknown as PluginMap; - return instance as unknown as PluginMap; + if (config.onPluginsReady) { + logger.debug("Running onPluginsReady hook"); + await config.onPluginsReady(handle); + logger.debug("onPluginsReady hook completed"); + } + + const serverPlugin = instance.#pluginInstances.server; + if (serverPlugin && typeof (serverPlugin as any).start === "function") { + await (serverPlugin as any).start(); + } + + return handle; } private static preparePlugins( @@ -222,6 +260,9 @@ export class AppKit { * * Initializes telemetry, cache, and service context, then registers plugins * in phase order (core, normal, deferred) and awaits their setup. + * If a `onPluginsReady` callback is provided it runs after plugin setup but + * before the server starts, giving you access to the full appkit handle + * for registering custom routes or performing async setup. * The returned object maps each plugin name to its `exports()` API, * with an `asUser(req)` method for user-scoped execution. 
* @@ -236,18 +277,18 @@ export class AppKit { * }); * ``` * - * @example Extended Server with analytics and custom endpoint + * @example Server with custom routes via onPluginsReady * ```ts * import { createApp, server, analytics } from "@databricks/appkit"; * - * const appkit = await createApp({ - * plugins: [server({ autoStart: false }), analytics({})], - * }); - * - * appkit.server.extend((app) => { - * app.get("/custom", (_req, res) => res.json({ ok: true })); + * await createApp({ + * plugins: [server(), analytics({})], + * onPluginsReady(appkit) { + * appkit.server.extend((app) => { + * app.get("/custom", (_req, res) => res.json({ ok: true })); + * }); + * }, * }); - * await appkit.server.start(); * ``` */ export async function createApp< @@ -258,6 +299,7 @@ export async function createApp< telemetry?: TelemetryConfig; cache?: CacheConfig; client?: WorkspaceClient; + onPluginsReady?: (appkit: PluginMap) => void | Promise; } = {}, ): Promise> { return AppKit._createApp(config); diff --git a/packages/appkit/src/core/create-agent-def.ts b/packages/appkit/src/core/create-agent-def.ts new file mode 100644 index 00000000..3e93371d --- /dev/null +++ b/packages/appkit/src/core/create-agent-def.ts @@ -0,0 +1,53 @@ +import { ConfigurationError } from "../errors"; +import type { AgentDefinition } from "../plugins/agents/types"; + +/** + * Pure factory for agent definitions. Returns the passed-in definition after + * cycle-detecting the sub-agent graph. Accepts the full `AgentDefinition` shape + * and is safe to call at module top-level. + * + * The returned value is a plain `AgentDefinition` — no adapter construction, + * no side effects. Register it with `agents({ agents: { name: def } })` or run + * it standalone via `runAgent(def, input)`. + * + * @example + * ```ts + * const support = createAgent({ + * instructions: "You help customers.", + * model: "databricks-claude-sonnet-4-5", + * tools: { + * get_weather: tool({ ... 
}), + * }, + * }); + * ``` + */ +export function createAgent(def: AgentDefinition): AgentDefinition { + detectCycles(def); + return def; +} + +/** + * Walks the `agents: { ... }` sub-agent tree via DFS and throws if a cycle is + * found. Cycles would cause infinite recursion at tool-invocation time. + */ +function detectCycles(def: AgentDefinition): void { + const visiting = new Set(); + const visited = new Set(); + + const walk = (current: AgentDefinition, path: string[]): void => { + if (visited.has(current)) return; + if (visiting.has(current)) { + throw new ConfigurationError( + `Agent sub-agent cycle detected: ${path.join(" -> ")}`, + ); + } + visiting.add(current); + for (const [childKey, child] of Object.entries(current.agents ?? {})) { + walk(child, [...path, childKey]); + } + visiting.delete(current); + visited.add(current); + }; + + walk(def, [def.name ?? "(root)"]); +} diff --git a/packages/appkit/src/core/plugin-context.ts b/packages/appkit/src/core/plugin-context.ts new file mode 100644 index 00000000..c2801585 --- /dev/null +++ b/packages/appkit/src/core/plugin-context.ts @@ -0,0 +1,287 @@ +import type express from "express"; +import type { BasePlugin, ToolProvider } from "shared"; +import { createLogger } from "../logging/logger"; +import { TelemetryManager } from "../telemetry"; + +const logger = createLogger("plugin-context"); + +interface BufferedRoute { + method: string; + path: string; + handlers: express.RequestHandler[]; +} + +interface RouteTarget { + addExtension(fn: (app: express.Application) => void): void; +} + +interface ToolProviderEntry { + plugin: BasePlugin & ToolProvider; + name: string; +} + +type LifecycleEvent = "setup:complete" | "server:ready" | "shutdown"; + +/** + * Mediator for inter-plugin communication. + * + * Created by AppKit core and passed to every plugin. Plugins request + * capabilities from the context instead of holding direct references + * to sibling plugin instances. 
+ * + * Capabilities: + * - Route mounting with buffering (order-independent) + * - Typed ToolProvider registry (live, not snapshot-based) + * - User-scoped tool execution with automatic telemetry + * - Lifecycle hooks for plugin coordination + */ +export class PluginContext { + private routeBuffer: BufferedRoute[] = []; + private routeTarget: RouteTarget | null = null; + private toolProviders = new Map(); + private plugins = new Map(); + private lifecycleHooks = new Map< + LifecycleEvent, + Set<() => void | Promise> + >(); + private telemetry = TelemetryManager.getProvider("plugin-context"); + + /** + * Register a route on the root Express application. + * + * If a route target (server plugin) has registered, the route is applied + * immediately. Otherwise it is buffered and flushed when a route target + * becomes available. + */ + addRoute( + method: string, + path: string, + ...handlers: express.RequestHandler[] + ): void { + if (this.routeTarget) { + this.applyRoute({ method, path, handlers }); + } else { + this.routeBuffer.push({ method, path, handlers }); + } + } + + /** + * Register middleware on the root Express application. + * + * Same buffering semantics as `addRoute`. + */ + addMiddleware(path: string, ...handlers: express.RequestHandler[]): void { + if (this.routeTarget) { + this.applyMiddleware(path, handlers); + } else { + this.routeBuffer.push({ method: "use", path, handlers }); + } + } + + /** + * Called by the server plugin to opt in as the route target. + * Flushes all buffered routes via the server's `addExtension`. + */ + registerAsRouteTarget(target: RouteTarget): void { + this.routeTarget = target; + + for (const route of this.routeBuffer) { + if (route.method === "use") { + this.applyMiddleware(route.path, route.handlers); + } else { + this.applyRoute(route); + } + } + this.routeBuffer = []; + } + + /** + * Register a plugin that implements the ToolProvider interface. + * Called by AppKit core after constructing each plugin. 
+ */ + registerToolProvider(name: string, plugin: BasePlugin & ToolProvider): void { + this.toolProviders.set(name, { plugin, name }); + } + + /** + * Register a plugin instance. + * Called by AppKit core after constructing each plugin. + */ + registerPlugin(name: string, instance: BasePlugin): void { + this.plugins.set(name, instance); + } + + /** + * Returns all registered plugin instances keyed by name. + * Used by the server plugin for route injection, client config, + * and shutdown coordination. + */ + getPlugins(): Map { + return this.plugins; + } + + /** + * Returns all registered ToolProvider plugins. + * Always returns the current set — not a frozen snapshot. + */ + getToolProviders(): Array<{ name: string; provider: ToolProvider }> { + return Array.from(this.toolProviders.values()).map((entry) => ({ + name: entry.name, + provider: entry.plugin, + })); + } + + /** + * Execute a tool on a ToolProvider plugin with automatic user scoping + * and telemetry. + * + * The context: + * 1. Resolves the plugin by name + * 2. Calls `asUser(req)` for user-scoped execution + * 3. Wraps the call in a telemetry span with a 30s timeout + */ + async executeTool( + req: express.Request, + pluginName: string, + toolName: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + const entry = this.toolProviders.get(pluginName); + if (!entry) { + throw new Error( + `PluginContext: unknown plugin "${pluginName}". Available: ${Array.from(this.toolProviders.keys()).join(", ")}`, + ); + } + + const tracer = this.telemetry.getTracer(); + const operationName = `executeTool:${pluginName}.${toolName}`; + + return tracer.startActiveSpan(operationName, async (span) => { + const timeout = 30_000; + const timeoutSignal = AbortSignal.timeout(timeout); + const combinedSignal = signal + ? 
AbortSignal.any([signal, timeoutSignal]) + : timeoutSignal; + + try { + const userPlugin = (entry.plugin as any).asUser(req); + const result = await (userPlugin as ToolProvider).executeAgentTool( + toolName, + args, + combinedSignal, + ); + span.setStatus({ code: 0 }); + return result; + } catch (error) { + span.setStatus({ + code: 2, + message: + error instanceof Error ? error.message : "Tool execution failed", + }); + span.recordException( + error instanceof Error ? error : new Error(String(error)), + ); + throw error; + } finally { + span.end(); + } + }); + } + + /** + * Register a lifecycle hook callback. + */ + onLifecycle(event: LifecycleEvent, fn: () => void | Promise): void { + let hooks = this.lifecycleHooks.get(event); + if (!hooks) { + hooks = new Set(); + this.lifecycleHooks.set(event, hooks); + } + hooks.add(fn); + } + + /** + * Emit a lifecycle event, calling all registered callbacks. + * Errors in individual callbacks are logged but do not prevent + * other callbacks from running. + * + * @internal Called by AppKit core only. + */ + async emitLifecycle(event: LifecycleEvent): Promise { + const hooks = this.lifecycleHooks.get(event); + if (!hooks) return; + + if ( + event === "setup:complete" && + this.routeBuffer.length > 0 && + !this.routeTarget + ) { + logger.warn( + "%d buffered routes were never applied — no server plugin registered as route target", + this.routeBuffer.length, + ); + } + + for (const fn of hooks) { + try { + await fn(); + } catch (error) { + logger.error("Lifecycle hook '%s' failed: %O", event, error); + } + } + } + + /** + * Returns all registered plugin names. + */ + getPluginNames(): string[] { + return Array.from(this.plugins.keys()); + } + + /** + * Check if a plugin with the given name is registered. 
+ */ + hasPlugin(name: string): boolean { + return this.plugins.has(name); + } + + private applyRoute(route: BufferedRoute): void { + if (!this.routeTarget) return; + this.routeTarget.addExtension((app) => { + const method = route.method.toLowerCase() as keyof express.Application; + if (typeof app[method] === "function") { + (app[method] as (...a: unknown[]) => void)( + route.path, + ...route.handlers, + ); + } + }); + } + + private applyMiddleware( + path: string, + handlers: express.RequestHandler[], + ): void { + if (!this.routeTarget) return; + this.routeTarget.addExtension((app) => { + app.use(path, ...handlers); + }); + } +} + +/** + * Type guard: checks whether a plugin implements the ToolProvider interface. + */ +export function isToolProvider( + plugin: unknown, +): plugin is BasePlugin & ToolProvider { + return ( + typeof plugin === "object" && + plugin !== null && + "getAgentTools" in plugin && + typeof (plugin as ToolProvider).getAgentTools === "function" && + "executeAgentTool" in plugin && + typeof (plugin as ToolProvider).executeAgentTool === "function" + ); +} diff --git a/packages/appkit/src/core/run-agent.ts b/packages/appkit/src/core/run-agent.ts new file mode 100644 index 00000000..e83c2c9c --- /dev/null +++ b/packages/appkit/src/core/run-agent.ts @@ -0,0 +1,226 @@ +import { randomUUID } from "node:crypto"; +import type { + AgentAdapter, + AgentEvent, + AgentToolDefinition, + Message, +} from "shared"; +import { + type FunctionTool, + functionToolToDefinition, + isFunctionTool, +} from "../plugins/agents/tools/function-tool"; +import { isHostedTool } from "../plugins/agents/tools/hosted-tools"; +import type { + AgentDefinition, + AgentTool, + ToolkitEntry, +} from "../plugins/agents/types"; +import { isToolkitEntry } from "../plugins/agents/types"; + +export interface RunAgentInput { + /** Seed messages for the run. Either a single user string or a full message list. */ + messages: string | Message[]; + /** Abort signal for cancellation. 
*/ + signal?: AbortSignal; +} + +export interface RunAgentResult { + /** Aggregated text output from all `message_delta` events. */ + text: string; + /** Every event the adapter yielded, in order. Useful for inspection/tests. */ + events: AgentEvent[]; +} + +/** + * Standalone agent execution without `createApp`. Resolves the adapter, binds + * inline tools, and drives the adapter's `run()` loop to completion. + * + * Limitations vs. running through the agents() plugin: + * - No OBO: there is no HTTP request, so plugin tools run as the service + * principal (when they work at all). + * - Plugin tools (`ToolkitEntry`) are not supported — they require a live + * `PluginContext` that only exists when registered in a `createApp` + * instance. This function throws a clear error if encountered. + * - Sub-agents (`agents: { ... }` on the def) are executed as nested + * `runAgent` calls with no shared thread state. + */ +export async function runAgent( + def: AgentDefinition, + input: RunAgentInput, +): Promise { + const adapter = await resolveAdapter(def); + const messages = normalizeMessages(input.messages, def.instructions); + const toolIndex = buildStandaloneToolIndex(def); + const tools = Array.from(toolIndex.values()).map((e) => e.def); + + const signal = input.signal; + + const executeTool = async (name: string, args: unknown): Promise => { + const entry = toolIndex.get(name); + if (!entry) throw new Error(`Unknown tool: ${name}`); + if (entry.kind === "function") { + return entry.tool.execute(args as Record); + } + if (entry.kind === "subagent") { + const subInput: RunAgentInput = { + messages: + typeof args === "object" && + args !== null && + typeof (args as { input?: unknown }).input === "string" + ? (args as { input: string }).input + : JSON.stringify(args), + signal, + }; + const res = await runAgent(entry.agentDef, subInput); + return res.text; + } + throw new Error( + `runAgent: tool "${name}" is a ${entry.kind} tool. 
` + + "Plugin toolkits and MCP tools are only usable via createApp({ plugins: [..., agents(...)] }).", + ); + }; + + const events: AgentEvent[] = []; + let text = ""; + + const stream = adapter.run( + { + messages, + tools, + threadId: randomUUID(), + signal, + }, + { executeTool, signal }, + ); + + for await (const event of stream) { + if (signal?.aborted) break; + events.push(event); + if (event.type === "message_delta") { + text += event.content; + } else if (event.type === "message") { + text = event.content; + } + } + + return { text, events }; +} + +async function resolveAdapter(def: AgentDefinition): Promise { + const { model } = def; + if (!model) { + const { DatabricksAdapter } = await import("../agents/databricks"); + return DatabricksAdapter.fromModelServing(); + } + if (typeof model === "string") { + const { DatabricksAdapter } = await import("../agents/databricks"); + return DatabricksAdapter.fromModelServing(model); + } + return await model; +} + +function normalizeMessages( + input: string | Message[], + instructions: string, +): Message[] { + const systemMessage: Message = { + id: "system", + role: "system", + content: instructions, + createdAt: new Date(), + }; + if (typeof input === "string") { + return [ + systemMessage, + { + id: randomUUID(), + role: "user", + content: input, + createdAt: new Date(), + }, + ]; + } + return [systemMessage, ...input]; +} + +type StandaloneEntry = + | { + kind: "function"; + def: AgentToolDefinition; + tool: FunctionTool; + } + | { + kind: "subagent"; + def: AgentToolDefinition; + agentDef: AgentDefinition; + } + | { + kind: "toolkit"; + def: AgentToolDefinition; + entry: ToolkitEntry; + } + | { + kind: "hosted"; + def: AgentToolDefinition; + }; + +function buildStandaloneToolIndex( + def: AgentDefinition, +): Map { + const index = new Map(); + + for (const [key, tool] of Object.entries(def.tools ?? 
{})) { + index.set(key, classifyTool(key, tool)); + } + + for (const [childKey, child] of Object.entries(def.agents ?? {})) { + const toolName = `agent-${childKey}`; + index.set(toolName, { + kind: "subagent", + agentDef: { ...child, name: child.name ?? childKey }, + def: { + name: toolName, + description: + child.instructions.slice(0, 120) || + `Delegate to the ${childKey} sub-agent`, + parameters: { + type: "object", + properties: { + input: { + type: "string", + description: "Message to send to the sub-agent.", + }, + }, + required: ["input"], + }, + }, + }); + } + + return index; +} + +function classifyTool(key: string, tool: AgentTool): StandaloneEntry { + if (isToolkitEntry(tool)) { + return { kind: "toolkit", def: { ...tool.def, name: key }, entry: tool }; + } + if (isFunctionTool(tool)) { + return { + kind: "function", + tool, + def: { ...functionToolToDefinition(tool), name: key }, + }; + } + if (isHostedTool(tool)) { + return { + kind: "hosted", + def: { + name: key, + description: `Hosted tool: ${tool.type}`, + parameters: { type: "object", properties: {} }, + }, + }; + } + throw new Error(`runAgent: unrecognized tool shape at key "${key}"`); +} diff --git a/packages/appkit/src/core/tests/databricks.test.ts b/packages/appkit/src/core/tests/databricks.test.ts index c05345a6..9d3fe5f8 100644 --- a/packages/appkit/src/core/tests/databricks.test.ts +++ b/packages/appkit/src/core/tests/databricks.test.ts @@ -109,11 +109,11 @@ class DeferredTestPlugin implements BasePlugin { name = "deferredTest"; setupCalled = false; injectedConfig: any; - injectedPlugins: any; + injectedContext: any; constructor(config: any) { this.injectedConfig = config; - this.injectedPlugins = config.plugins; + this.injectedContext = config.context; } async setup() { @@ -130,7 +130,7 @@ class DeferredTestPlugin implements BasePlugin { return { setupCalled: this.setupCalled, injectedConfig: this.injectedConfig, - injectedPlugins: this.injectedPlugins, + injectedContext: 
this.injectedContext, }; } } @@ -276,7 +276,7 @@ describe("AppKit", () => { expect(setupOrder).toEqual(["core", "normal", "deferred"]); }); - test("should provide plugin instances to deferred plugins", async () => { + test("should provide PluginContext to deferred plugins", async () => { const pluginData = [ { plugin: CoreTestPlugin, config: {}, name: "coreTest" }, { plugin: DeferredTestPlugin, config: {}, name: "deferredTest" }, @@ -284,10 +284,9 @@ describe("AppKit", () => { const instance = (await createApp({ plugins: pluginData })) as any; - // Deferred plugins receive plugin instances (not SDKs) for internal use - expect(instance.deferredTest.injectedPlugins).toBeDefined(); - expect(instance.deferredTest.injectedPlugins.coreTest).toBeInstanceOf( - CoreTestPlugin, + expect(instance.deferredTest.injectedContext).toBeDefined(); + expect(instance.deferredTest.injectedContext.hasPlugin("coreTest")).toBe( + true, ); }); diff --git a/packages/appkit/src/core/tests/plugin-context.test.ts b/packages/appkit/src/core/tests/plugin-context.test.ts new file mode 100644 index 00000000..276c5502 --- /dev/null +++ b/packages/appkit/src/core/tests/plugin-context.test.ts @@ -0,0 +1,325 @@ +import type { AgentToolDefinition } from "shared"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { isToolProvider, PluginContext } from "../plugin-context"; + +vi.mock("../../telemetry", () => ({ + TelemetryManager: { + getProvider: () => ({ + getTracer: () => ({ + startActiveSpan: (_name: string, fn: (span: any) => any) => { + const span = { + setStatus: vi.fn(), + recordException: vi.fn(), + end: vi.fn(), + }; + return fn(span); + }, + }), + }), + }, +})); + +vi.mock("../../logging/logger", () => ({ + createLogger: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), +})); + +function createMockToolProvider(tools: AgentToolDefinition[] = []) { + const mock = { + name: "mock-plugin", + setup: vi.fn().mockResolvedValue(undefined), + 
injectRoutes: vi.fn(), + getEndpoints: vi.fn().mockReturnValue({}), + getAgentTools: vi.fn().mockReturnValue(tools), + executeAgentTool: vi.fn().mockResolvedValue("tool-result"), + asUser: vi.fn().mockReturnThis(), + }; + return mock as any; +} + +describe("PluginContext", () => { + let ctx: PluginContext; + + beforeEach(() => { + ctx = new PluginContext(); + }); + + describe("route buffering", () => { + test("addRoute buffers when no route target exists", () => { + const handler = vi.fn(); + ctx.addRoute("post", "/invocations", handler); + + expect(ctx.getPluginNames()).toEqual([]); + }); + + test("flushRoutes applies buffered routes via addExtension", () => { + const handler = vi.fn(); + ctx.addRoute("post", "/invocations", handler); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { post: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.post).toHaveBeenCalledWith("/invocations", handler); + }); + + test("addRoute called after registerAsRouteTarget applies immediately", () => { + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + const handler = vi.fn(); + ctx.addRoute("get", "/health", handler); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { get: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.get).toHaveBeenCalledWith("/health", handler); + }); + + test("addRoute supports middleware chains", () => { + const auth = vi.fn(); + const handler = vi.fn(); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + ctx.addRoute("post", "/api", auth, handler); + + const extensionFn = addExtension.mock.calls[0][0]; + const mockApp = { post: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.post).toHaveBeenCalledWith("/api", auth, handler); + }); + + test("addMiddleware buffers 
and applies via use()", () => { + const handler = vi.fn(); + ctx.addMiddleware("/api", handler); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { use: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.use).toHaveBeenCalledWith("/api", handler); + }); + + test("multiple buffered routes are all applied on registration", () => { + const h1 = vi.fn(); + const h2 = vi.fn(); + ctx.addRoute("post", "/a", h1); + ctx.addRoute("get", "/b", h2); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(2); + }); + }); + + describe("ToolProvider registry", () => { + test("registerToolProvider makes provider visible via getToolProviders", () => { + const provider = createMockToolProvider([ + { + name: "query", + description: "Run query", + parameters: { type: "object" }, + }, + ]); + + ctx.registerToolProvider("analytics", provider); + + const providers = ctx.getToolProviders(); + expect(providers).toHaveLength(1); + expect(providers[0].name).toBe("analytics"); + expect(providers[0].provider.getAgentTools()).toHaveLength(1); + }); + + test("getToolProviders returns all registered providers", () => { + ctx.registerToolProvider("analytics", createMockToolProvider()); + ctx.registerToolProvider("files", createMockToolProvider()); + ctx.registerToolProvider("genie", createMockToolProvider()); + + expect(ctx.getToolProviders()).toHaveLength(3); + }); + + test("getToolProviders returns current set, not snapshot", () => { + const before = ctx.getToolProviders(); + expect(before).toHaveLength(0); + + ctx.registerToolProvider("analytics", createMockToolProvider()); + + const after = ctx.getToolProviders(); + expect(after).toHaveLength(1); + }); + }); + + describe("executeTool", () => { + test("calls asUser(req).executeAgentTool on the correct plugin", async () 
=> { + const provider = createMockToolProvider(); + ctx.registerToolProvider("analytics", provider); + + const mockReq = { headers: {} } as any; + await ctx.executeTool(mockReq, "analytics", "query", { sql: "SELECT 1" }); + + expect(provider.asUser).toHaveBeenCalledWith(mockReq); + expect(provider.executeAgentTool).toHaveBeenCalledWith( + "query", + { sql: "SELECT 1" }, + expect.any(Object), + ); + }); + + test("throws for unknown plugin name", async () => { + const mockReq = { headers: {} } as any; + + await expect( + ctx.executeTool(mockReq, "nonexistent", "query", {}), + ).rejects.toThrow('unknown plugin "nonexistent"'); + }); + + test("propagates tool execution errors", async () => { + const provider = createMockToolProvider(); + (provider.executeAgentTool as any).mockRejectedValue( + new Error("Query failed"), + ); + ctx.registerToolProvider("analytics", provider); + + const mockReq = { headers: {} } as any; + + await expect( + ctx.executeTool(mockReq, "analytics", "query", {}), + ).rejects.toThrow("Query failed"); + }); + + test("passes abort signal to executeAgentTool", async () => { + const provider = createMockToolProvider(); + ctx.registerToolProvider("analytics", provider); + + const controller = new AbortController(); + const mockReq = { headers: {} } as any; + + await ctx.executeTool( + mockReq, + "analytics", + "query", + {}, + controller.signal, + ); + + const callArgs = (provider.executeAgentTool as any).mock.calls[0]; + expect(callArgs[2]).toBeDefined(); + }); + }); + + describe("lifecycle hooks", () => { + test("onLifecycle registers callback, emitLifecycle invokes it", async () => { + const fn = vi.fn(); + ctx.onLifecycle("setup:complete", fn); + + await ctx.emitLifecycle("setup:complete"); + + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("multiple callbacks for the same event all fire", async () => { + const fn1 = vi.fn(); + const fn2 = vi.fn(); + ctx.onLifecycle("setup:complete", fn1); + ctx.onLifecycle("setup:complete", fn2); + + await 
ctx.emitLifecycle("setup:complete"); + + expect(fn1).toHaveBeenCalledTimes(1); + expect(fn2).toHaveBeenCalledTimes(1); + }); + + test("callback error does not prevent other callbacks from running", async () => { + const fn1 = vi.fn().mockRejectedValue(new Error("fail")); + const fn2 = vi.fn(); + ctx.onLifecycle("shutdown", fn1); + ctx.onLifecycle("shutdown", fn2); + + await ctx.emitLifecycle("shutdown"); + + expect(fn1).toHaveBeenCalled(); + expect(fn2).toHaveBeenCalled(); + }); + + test("emitLifecycle with no registered hooks does nothing", async () => { + await expect(ctx.emitLifecycle("server:ready")).resolves.toBeUndefined(); + }); + }); + + describe("plugin metadata", () => { + const stubPlugin = { name: "stub" } as any; + + test("getPluginNames returns all registered names", () => { + ctx.registerPlugin("analytics", stubPlugin); + ctx.registerPlugin("server", stubPlugin); + ctx.registerPlugin("agent", stubPlugin); + + const names = ctx.getPluginNames(); + expect(names).toContain("analytics"); + expect(names).toContain("server"); + expect(names).toContain("agent"); + expect(names).toHaveLength(3); + }); + + test("hasPlugin returns true for registered plugins", () => { + ctx.registerPlugin("analytics", stubPlugin); + + expect(ctx.hasPlugin("analytics")).toBe(true); + expect(ctx.hasPlugin("nonexistent")).toBe(false); + }); + + test("getPlugins returns all registered instances", () => { + const p1 = { name: "analytics" } as any; + const p2 = { name: "server" } as any; + ctx.registerPlugin("analytics", p1); + ctx.registerPlugin("server", p2); + + const plugins = ctx.getPlugins(); + expect(plugins.size).toBe(2); + expect(plugins.get("analytics")).toBe(p1); + expect(plugins.get("server")).toBe(p2); + }); + }); +}); + +describe("isToolProvider", () => { + test("returns true for objects with getAgentTools and executeAgentTool", () => { + const provider = createMockToolProvider(); + expect(isToolProvider(provider)).toBe(true); + }); + + test("returns false for null", 
() => { + expect(isToolProvider(null)).toBe(false); + }); + + test("returns false for objects missing executeAgentTool", () => { + expect(isToolProvider({ getAgentTools: vi.fn() })).toBe(false); + }); + + test("returns false for objects missing getAgentTools", () => { + expect(isToolProvider({ executeAgentTool: vi.fn() })).toBe(false); + }); + + test("returns false for non-objects", () => { + expect(isToolProvider("string")).toBe(false); + expect(isToolProvider(42)).toBe(false); + expect(isToolProvider(undefined)).toBe(false); + }); +}); diff --git a/packages/appkit/src/errors/server.ts b/packages/appkit/src/errors/server.ts index 6af5b59f..d45148d8 100644 --- a/packages/appkit/src/errors/server.ts +++ b/packages/appkit/src/errors/server.ts @@ -6,7 +6,6 @@ import { AppKitError } from "./base"; * * @example * ```typescript - * throw new ServerError("Cannot get server when autoStart is true"); * throw new ServerError("Server not started"); * ``` */ @@ -15,15 +14,6 @@ export class ServerError extends AppKitError { readonly statusCode = 500; readonly isRetryable = false; - /** - * Create a server error for autoStart conflict - */ - static autoStartConflict(operation: string): ServerError { - return new ServerError(`Cannot ${operation} when autoStart is true`, { - context: { operation }, - }); - } - /** * Create a server error for server not started */ diff --git a/packages/appkit/src/errors/tests/errors.test.ts b/packages/appkit/src/errors/tests/errors.test.ts index c404a18f..347ce1c0 100644 --- a/packages/appkit/src/errors/tests/errors.test.ts +++ b/packages/appkit/src/errors/tests/errors.test.ts @@ -348,12 +348,6 @@ describe("ServerError", () => { expect(error.isRetryable).toBe(false); }); - test("autoStartConflict should create proper error", () => { - const error = ServerError.autoStartConflict("get server"); - expect(error.message).toBe("Cannot get server when autoStart is true"); - expect(error.context?.operation).toBe("get server"); - }); - test("notStarted 
should create proper error", () => { const error = ServerError.notStarted(); expect(error.message).toContain("Server not started"); diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index a4666a49..ba4110b3 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -7,11 +7,20 @@ // Types from shared export type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, BasePluginConfig, CacheConfig, IAppRouter, + Message, PluginData, StreamExecutionSettings, + Thread, + ThreadStore, + ToolProvider, } from "shared"; export { isSQLTypeMarker, sql } from "shared"; export { CacheManager } from "./cache"; @@ -34,6 +43,12 @@ export { } from "./connectors/lakebase"; export { getExecutionContext } from "./context"; export { createApp } from "./core"; +export { createAgent } from "./core/create-agent-def"; +export { + type RunAgentInput, + type RunAgentResult, + runAgent, +} from "./core/run-agent"; // Errors export { AppKitError, @@ -54,6 +69,29 @@ export { toPlugin, } from "./plugin"; export { analytics, files, genie, lakebase, server, serving } from "./plugins"; +export { + type AgentDefinition, + type AgentsPluginConfig, + type AgentTool, + agentIdFromMarkdownPath, + agents, + type BaseSystemPromptOption, + isToolkitEntry, + loadAgentFromFile, + loadAgentsFromDir, + type PromptContext, + type ToolkitEntry, + type ToolkitOptions, +} from "./plugins/agents"; +export { + type FunctionTool, + type HostedTool, + isFunctionTool, + isHostedTool, + mcpServer, + type ToolConfig, + tool, +} from "./plugins/agents/tools"; // Files plugin types (for custom policy authoring) export type { FileAction, diff --git a/packages/appkit/src/plugin/index.ts b/packages/appkit/src/plugin/index.ts index 93765219..46a4eb94 100644 --- a/packages/appkit/src/plugin/index.ts +++ b/packages/appkit/src/plugin/index.ts @@ -1,4 +1,4 @@ export type { ToPlugin } from "shared"; export type { ExecutionResult } from 
"./execution-result"; export { Plugin } from "./plugin"; -export { toPlugin } from "./to-plugin"; +export { type NamedPluginFactory, toPlugin } from "./to-plugin"; diff --git a/packages/appkit/src/plugin/interceptors/retry.ts b/packages/appkit/src/plugin/interceptors/retry.ts index 435e0fde..aeddf3d5 100644 --- a/packages/appkit/src/plugin/interceptors/retry.ts +++ b/packages/appkit/src/plugin/interceptors/retry.ts @@ -1,10 +1,38 @@ import type { RetryConfig } from "shared"; +import { AppKitError } from "../../errors/base"; import { createLogger } from "../../logging/logger"; import type { ExecutionInterceptor, InterceptorContext } from "./types"; const logger = createLogger("interceptors:retry"); -// interceptor to handle retry logic +/** + * Determines whether an error is safe to retry. + * + * Priority: + * 1. AppKitError — reads the `isRetryable` boolean property. + * 2. Databricks SDK ApiError (duck-typed) — calls `isRetryable()` method, + * or falls back to status-code heuristic (5xx / 429 → retryable). + * 3. Unknown errors — treated as retryable to preserve backward compatibility. 
+ */ +function isRetryableError(error: unknown): boolean { + if (error instanceof AppKitError) { + return error.isRetryable; + } + + if (error instanceof Error && "statusCode" in error) { + const record = error as Record; + if (typeof record.statusCode !== "number") { + return true; + } + if (typeof record.isRetryable === "function") { + return (record.isRetryable as () => boolean)(); + } + return record.statusCode >= 500 || record.statusCode === 429; + } + + return true; +} + export class RetryInterceptor implements ExecutionInterceptor { private attempts: number; private initialDelay: number; @@ -36,7 +64,6 @@ export class RetryInterceptor implements ExecutionInterceptor { } catch (error) { lastError = error; - // last attempt, rethrow the error if (attempt === this.attempts) { logger.event()?.setExecution({ retry_attempts: attempt - 1, @@ -44,17 +71,19 @@ export class RetryInterceptor implements ExecutionInterceptor { throw error; } - // don't retry if was already aborted if (context.signal?.aborted) { throw error; } + if (!isRetryableError(error)) { + throw error; + } + const delay = this.calculateDelay(attempt); await this.sleep(delay); } } - // type guard throw lastError; } diff --git a/packages/appkit/src/plugin/plugin.ts b/packages/appkit/src/plugin/plugin.ts index 5173cb61..4c9a0e64 100644 --- a/packages/appkit/src/plugin/plugin.ts +++ b/packages/appkit/src/plugin/plugin.ts @@ -19,6 +19,7 @@ import { ServiceContext, type UserContext, } from "../context"; +import type { PluginContext } from "../core/plugin-context"; import { AppKitError, AuthenticationError } from "../errors"; import { createLogger } from "../logging/logger"; import { StreamManager } from "../stream"; @@ -163,11 +164,12 @@ export abstract class Plugin< > implements BasePlugin { protected isReady = false; - protected cache: CacheManager; + protected cache!: CacheManager; protected app: AppManager; protected devFileReader: DevFileReader; protected streamManager: StreamManager; - protected 
telemetry: ITelemetry; + protected telemetry!: ITelemetry; + protected context?: PluginContext; /** Registered endpoints for this plugin */ private registeredEndpoints: PluginEndpointMap = {}; @@ -193,12 +195,58 @@ export abstract class Plugin< config.name ?? (this.constructor as { manifest?: { name: string } }).manifest?.name ?? "plugin"; - this.telemetry = TelemetryManager.getProvider(this.name, config.telemetry); this.streamManager = new StreamManager(); - this.cache = CacheManager.getInstanceSync(); this.app = new AppManager(); this.devFileReader = DevFileReader.getInstance(); + this.context = (config as Record).context as + | PluginContext + | undefined; + + // Eagerly bind telemetry + cache if the core services have already been + // initialized (normal createApp path, or tests that mock CacheManager). + // If they haven't, we leave these undefined and rely on `attachContext` + // being called later — this lets factories eagerly construct plugin + // instances at module top-level before `createApp` has run. + this.tryAttachContext(); + } + + private tryAttachContext(): void { + try { + this.cache = CacheManager.getInstanceSync(); + } catch { + return; + } + this.telemetry = TelemetryManager.getProvider( + this.name, + this.config.telemetry, + ); + this.isReady = true; + } + /** + * Binds runtime dependencies (telemetry provider, cache, plugin context) to + * this plugin. Called by `AppKit._createApp` after construction and before + * `setup()`. Idempotent: safe to call if the constructor already bound them + * eagerly. Kept separate so factories can eagerly construct plugin instances + * without running this before `TelemetryManager.initialize()` / + * `CacheManager.getInstance()` have run. 
+ */ + attachContext( + deps: { + context?: unknown; + telemetryConfig?: BasePluginConfig["telemetry"]; + } = {}, + ): void { + if (!this.cache) { + this.cache = CacheManager.getInstanceSync(); + } + this.telemetry = TelemetryManager.getProvider( + this.name, + deps.telemetryConfig ?? this.config.telemetry, + ); + if (deps.context !== undefined) { + this.context = deps.context as PluginContext; + } this.isReady = true; } diff --git a/packages/appkit/src/plugin/tests/retry.test.ts b/packages/appkit/src/plugin/tests/retry.test.ts index b897c585..f6976f11 100644 --- a/packages/appkit/src/plugin/tests/retry.test.ts +++ b/packages/appkit/src/plugin/tests/retry.test.ts @@ -1,5 +1,9 @@ import type { RetryConfig } from "shared"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { AuthenticationError } from "../../errors/authentication"; +import { ConnectionError } from "../../errors/connection"; +import { ExecutionError } from "../../errors/execution"; +import { ValidationError } from "../../errors/validation"; import { RetryInterceptor } from "../interceptors/retry"; import type { InterceptorContext } from "../interceptors/types"; @@ -241,4 +245,191 @@ describe("RetryInterceptor", () => { vi.spyOn(Math, "random").mockRestore(); }); + + test("should not retry AuthenticationError (isRetryable=false)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi.fn().mockRejectedValue(AuthenticationError.missingToken()); + + await expect(interceptor.intercept(fn, context)).rejects.toThrow( + AuthenticationError, + ); + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("should not retry ValidationError (isRetryable=false)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = 
vi.fn().mockRejectedValue(ValidationError.missingField("name")); + + await expect(interceptor.intercept(fn, context)).rejects.toThrow( + ValidationError, + ); + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("should not retry ExecutionError (isRetryable=false)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi + .fn() + .mockRejectedValue(ExecutionError.statementFailed("syntax error")); + + await expect(interceptor.intercept(fn, context)).rejects.toThrow( + ExecutionError, + ); + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("should retry ConnectionError (isRetryable=true)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi + .fn() + .mockRejectedValueOnce(ConnectionError.queryFailed()) + .mockResolvedValue("recovered"); + + const promise = interceptor.intercept(fn, context); + await vi.runAllTimersAsync(); + + expect(await promise).toBe("recovered"); + expect(fn).toHaveBeenCalledTimes(2); + }); + + test("should not retry errors with 4xx statusCode", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const error = Object.assign(new Error("bad request"), { statusCode: 400 }); + const fn = vi.fn().mockRejectedValue(error); + + await expect(interceptor.intercept(fn, context)).rejects.toThrow( + "bad request", + ); + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("should retry errors with 5xx statusCode", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi + .fn() + .mockRejectedValueOnce( + Object.assign(new Error("internal"), { statusCode: 500 }), + ) + 
.mockResolvedValue("recovered"); + + const promise = interceptor.intercept(fn, context); + await vi.runAllTimersAsync(); + + expect(await promise).toBe("recovered"); + expect(fn).toHaveBeenCalledTimes(2); + }); + + test("should retry errors with 429 statusCode (rate limit)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi + .fn() + .mockRejectedValueOnce( + Object.assign(new Error("rate limited"), { statusCode: 429 }), + ) + .mockResolvedValue("recovered"); + + const promise = interceptor.intercept(fn, context); + await vi.runAllTimersAsync(); + + expect(await promise).toBe("recovered"); + expect(fn).toHaveBeenCalledTimes(2); + }); + + test("should use isRetryable() method when available on error", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + + const nonRetryable = Object.assign(new Error("not found"), { + statusCode: 404, + isRetryable: () => false, + }); + const fn = vi.fn().mockRejectedValue(nonRetryable); + + await expect(interceptor.intercept(fn, context)).rejects.toThrow( + "not found", + ); + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("should respect isRetryable() returning true even for 4xx", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + + const retryableClientError = Object.assign(new Error("conflict"), { + statusCode: 409, + isRetryable: () => true, + }); + const fn = vi + .fn() + .mockRejectedValueOnce(retryableClientError) + .mockResolvedValue("ok"); + + const promise = interceptor.intercept(fn, context); + await vi.runAllTimersAsync(); + + expect(await promise).toBe("ok"); + expect(fn).toHaveBeenCalledTimes(2); + }); + + test("should still retry plain Error (backward compatibility)", async () => { 
+ const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi + .fn() + .mockRejectedValueOnce(new Error("transient")) + .mockResolvedValue("ok"); + + const promise = interceptor.intercept(fn, context); + await vi.runAllTimersAsync(); + + expect(await promise).toBe("ok"); + expect(fn).toHaveBeenCalledTimes(2); + }); }); diff --git a/packages/appkit/src/plugin/to-plugin.ts b/packages/appkit/src/plugin/to-plugin.ts index 77725027..c882f300 100644 --- a/packages/appkit/src/plugin/to-plugin.ts +++ b/packages/appkit/src/plugin/to-plugin.ts @@ -1,19 +1,41 @@ import type { PluginConstructor, PluginData, ToPlugin } from "shared"; /** - * Wraps a plugin class so it can be passed to createApp with optional config. - * Infers config type from the constructor and plugin name from the static `name` property. + * Factory function produced by {@link toPlugin}. Carries a static + * `pluginName` field so tooling (e.g. `fromPlugin`) can identify which + * plugin a factory references without constructing an instance. + */ +export type NamedPluginFactory = { + readonly pluginName: Name; +}; + +/** + * Wraps a plugin class so it can be passed to `createApp` with optional + * config. Infers the config type from the constructor and the plugin name + * from the static `manifest.name` property, and stamps `pluginName` onto + * the returned factory function so `fromPlugin` can identify the plugin + * without needing to construct it. 
* * @internal */ export function toPlugin( plugin: T, -): ToPlugin[0], T["manifest"]["name"]> { +): ToPlugin[0], T["manifest"]["name"]> & + NamedPluginFactory { type Config = ConstructorParameters[0]; type Name = T["manifest"]["name"]; - return (config: Config = {} as Config): PluginData => ({ + const pluginName = plugin.manifest.name as Name; + const factory = ( + config: Config = {} as Config, + ): PluginData => ({ plugin: plugin as T, config: config as Config, - name: plugin.manifest.name as Name, + name: pluginName, + }); + Object.defineProperty(factory, "pluginName", { + value: pluginName, + writable: false, + enumerable: true, }); + return factory as ToPlugin & NamedPluginFactory; } diff --git a/packages/appkit/src/plugins/agents/agents.ts b/packages/appkit/src/plugins/agents/agents.ts new file mode 100644 index 00000000..ceed66e6 --- /dev/null +++ b/packages/appkit/src/plugins/agents/agents.ts @@ -0,0 +1,1246 @@ +import { randomUUID } from "node:crypto"; +import path from "node:path"; +import type express from "express"; +import pc from "picocolors"; +import type { + AgentAdapter, + AgentRunContext, + AgentToolDefinition, + IAppRouter, + Message, + PluginPhase, + ResponseStreamEvent, + Thread, + ToolAnnotations, + ToolProvider, +} from "shared"; +import { AppKitMcpClient, buildMcpHostPolicy } from "../../connectors/mcp"; +import { getWorkspaceClient } from "../../context"; +import { createLogger } from "../../logging/logger"; +import { Plugin, toPlugin } from "../../plugin"; +import type { PluginManifest } from "../../registry"; +import { consumeAdapterStream } from "./consume-adapter-stream"; +import { agentStreamDefaults } from "./defaults"; +import { EventChannel } from "./event-channel"; +import { AgentEventTranslator } from "./event-translator"; +import { loadAgentsFromDir } from "./load-agents"; +import manifest from "./manifest.json"; +import { normalizeToolResult } from "./normalize-result"; +import { + approvalRequestSchema, + chatRequestSchema, + 
invocationsRequestSchema, +} from "./schemas"; +import { buildBaseSystemPrompt, composeSystemPrompt } from "./system-prompt"; +import { InMemoryThreadStore } from "./thread-store"; +import { ToolApprovalGate } from "./tool-approval-gate"; +import { dispatchToolCall } from "./tool-dispatch"; +import { + functionToolToDefinition, + isFunctionTool, + isHostedTool, + resolveHostedTools, +} from "./tools"; +import type { + AgentDefinition, + AgentsPluginConfig, + BaseSystemPromptOption, + PromptContext, + RegisteredAgent, + ResolvedToolEntry, +} from "./types"; +import { isToolkitEntry } from "./types"; + +const logger = createLogger("agents"); + +const DEFAULT_AGENTS_DIR = "./config/agents"; + +/** + * Context flag recorded on the in-memory AgentDefinition to indicate whether + * it came from markdown (file) or from user code. Drives the asymmetric + * `autoInheritTools` default. + */ +interface AgentSource { + origin: "file" | "code"; +} + +export class AgentsPlugin extends Plugin implements ToolProvider { + static manifest = manifest as PluginManifest; + static phase: PluginPhase = "deferred"; + + protected declare config: AgentsPluginConfig; + + private agents = new Map(); + private defaultAgentName: string | null = null; + private activeStreams = new Map< + string, + { controller: AbortController; userId: string } + >(); + private mcpClient: AppKitMcpClient | null = null; + private threadStore; + private approvalGate = new ToolApprovalGate(); + + constructor(config: AgentsPluginConfig) { + super(config); + this.config = config; + if (config.threadStore) { + this.threadStore = config.threadStore; + } else { + this.threadStore = new InMemoryThreadStore(); + if (process.env.NODE_ENV === "production") { + logger.warn( + "InMemoryThreadStore is in use in a production build (NODE_ENV=production). " + + "Thread history is unbounded and lost on restart. 
" + + "Pass agents({ threadStore: }) for real deployments.", + ); + } else { + logger.info( + "Using default InMemoryThreadStore (dev-only — threads are lost on restart and grow without bound).", + ); + } + } + } + + /** Effective approval policy with defaults applied. */ + private get resolvedApprovalPolicy(): { + requireForDestructive: boolean; + timeoutMs: number; + } { + const cfg = this.config.approval ?? {}; + return { + requireForDestructive: cfg.requireForDestructive ?? true, + timeoutMs: cfg.timeoutMs ?? 60_000, + }; + } + + /** Effective DoS limits with defaults applied. */ + private get resolvedLimits(): { + maxConcurrentStreamsPerUser: number; + maxToolCalls: number; + maxSubAgentDepth: number; + } { + const cfg = this.config.limits ?? {}; + return { + maxConcurrentStreamsPerUser: cfg.maxConcurrentStreamsPerUser ?? 5, + maxToolCalls: cfg.maxToolCalls ?? 50, + maxSubAgentDepth: cfg.maxSubAgentDepth ?? 3, + }; + } + + /** Count active streams owned by a given user. */ + private countUserStreams(userId: string): number { + let n = 0; + for (const entry of this.activeStreams.values()) { + if (entry.userId === userId) n++; + } + return n; + } + + async setup() { + await this.loadAgents(); + this.mountInvocationsRoute(); + this.printRegistry(); + } + + /** + * Reload agents from the configured directory, preserving code-defined + * agents. Swaps the registry atomically at the end. + */ + async reload(): Promise { + this.agents.clear(); + this.defaultAgentName = null; + if (this.mcpClient) { + await this.mcpClient.close(); + this.mcpClient = null; + } + await this.loadAgents(); + } + + private async loadAgents() { + const { defs: fileDefs, defaultAgent: fileDefault } = + await this.loadFileDefinitions(); + + const codeDefs = this.config.agents ?? {}; + + for (const name of Object.keys(fileDefs)) { + if (codeDefs[name]) { + logger.warn( + "Agent '%s' defined in both code and a markdown file. 
Code definition takes precedence.", + name, + ); + } + } + + const merged: Record = + {}; + for (const [name, def] of Object.entries(fileDefs)) { + merged[name] = { def, src: { origin: "file" } }; + } + for (const [name, def] of Object.entries(codeDefs)) { + merged[name] = { def, src: { origin: "code" } }; + } + + if (Object.keys(merged).length === 0) { + logger.info( + "No agents registered (no files in %s, no code-defined agents)", + this.resolvedAgentsDir() ?? "", + ); + return; + } + + for (const [name, { def, src }] of Object.entries(merged)) { + try { + const registered = await this.buildRegisteredAgent(name, def, src); + this.agents.set(name, registered); + if (!this.defaultAgentName) this.defaultAgentName = name; + } catch (err) { + throw new Error( + `Failed to register agent '${name}' (${src.origin}): ${ + err instanceof Error ? err.message : String(err) + }`, + { cause: err instanceof Error ? err : undefined }, + ); + } + } + + if (this.config.defaultAgent) { + if (!this.agents.has(this.config.defaultAgent)) { + throw new Error( + `defaultAgent '${this.config.defaultAgent}' is not registered. Available: ${Array.from(this.agents.keys()).join(", ")}`, + ); + } + this.defaultAgentName = this.config.defaultAgent; + } else if (fileDefault && this.agents.has(fileDefault)) { + this.defaultAgentName = fileDefault; + } + } + + private resolvedAgentsDir(): string | null { + if (this.config.dir === false) return null; + const dir = this.config.dir ?? DEFAULT_AGENTS_DIR; + return path.isAbsolute(dir) ? dir : path.resolve(process.cwd(), dir); + } + + private async loadFileDefinitions(): Promise<{ + defs: Record; + defaultAgent: string | null; + }> { + const dir = this.resolvedAgentsDir(); + if (!dir) return { defs: {}, defaultAgent: null }; + + const pluginToolProviders = this.pluginProviderIndex(); + const ambient = this.config.tools ?? 
{}; + + const result = await loadAgentsFromDir(dir, { + defaultModel: this.config.defaultModel, + availableTools: ambient, + plugins: pluginToolProviders, + codeAgents: this.config.agents, + }); + + return result; + } + + /** + * Builds the map of plugin-name → toolkit that the markdown loader consults + * when resolving `toolkits:` frontmatter entries. + */ + private pluginProviderIndex(): Map< + string, + { toolkit: (opts?: unknown) => Record } + > { + const out = new Map(); + if (!this.context) return out; + for (const { name, provider } of this.context.getToolProviders()) { + const withToolkit = provider as ToolProvider & { + toolkit?: (opts?: unknown) => Record; + }; + if (typeof withToolkit.toolkit === "function") { + out.set(name, { + toolkit: withToolkit.toolkit.bind(withToolkit), + }); + } + } + return out; + } + + private async buildRegisteredAgent( + name: string, + def: AgentDefinition, + src: AgentSource, + ): Promise { + const adapter = await this.resolveAdapter(def, name); + const toolIndex = await this.buildToolIndex(name, def, src); + + return { + name, + instructions: def.instructions, + adapter, + toolIndex, + baseSystemPrompt: def.baseSystemPrompt, + maxSteps: def.maxSteps, + maxTokens: def.maxTokens, + ephemeral: def.ephemeral, + }; + } + + private async resolveAdapter( + def: AgentDefinition, + name: string, + ): Promise { + const source = def.model ?? this.config.defaultModel; + // Per-agent adapter knobs from `AgentDefinition` / markdown frontmatter. + // Only applied when AppKit builds the adapter itself (string or omitted + // model). Users who pass a pre-built `AgentAdapter` own these settings. 
+ const adapterOptions: { maxSteps?: number; maxTokens?: number } = {}; + if (def.maxSteps !== undefined) adapterOptions.maxSteps = def.maxSteps; + if (def.maxTokens !== undefined) adapterOptions.maxTokens = def.maxTokens; + + if (!source) { + const { DatabricksAdapter } = await import("../../agents/databricks"); + try { + return await DatabricksAdapter.fromModelServing( + undefined, + adapterOptions, + ); + } catch (err) { + throw new Error( + `Agent '${name}' has no model configured and no DATABRICKS_AGENT_ENDPOINT default available`, + { cause: err instanceof Error ? err : undefined }, + ); + } + } + if (typeof source === "string") { + const { DatabricksAdapter } = await import("../../agents/databricks"); + return DatabricksAdapter.fromModelServing(source, adapterOptions); + } + return await source; + } + + /** + * Resolves an agent's tool record into a per-agent dispatch index. Connects + * hosted tools via MCP client. Applies `autoInheritTools` defaults when the + * definition has no declared tools/agents. + */ + private async buildToolIndex( + agentName: string, + def: AgentDefinition, + src: AgentSource, + ): Promise> { + const index = new Map(); + const hasExplicitTools = def.tools && Object.keys(def.tools).length > 0; + const hasExplicitSubAgents = + def.agents && Object.keys(def.agents).length > 0; + + const inheritDefaults = normalizeAutoInherit(this.config.autoInheritTools); + const shouldInherit = + !hasExplicitTools && + !hasExplicitSubAgents && + (src.origin === "file" ? inheritDefaults.file : inheritDefaults.code); + + if (shouldInherit) { + await this.applyAutoInherit(agentName, index); + } + + // 1. Sub-agents → agent- + for (const [childKey, childDef] of Object.entries(def.agents ?? {})) { + const toolName = `agent-${childKey}`; + index.set(toolName, { + source: "subagent", + agentName: childDef.name ?? 
childKey, + def: { + name: toolName, + description: + childDef.instructions.slice(0, 120) || + `Delegate to the ${childKey} sub-agent`, + parameters: { + type: "object", + properties: { + input: { + type: "string", + description: "Message to send to the sub-agent.", + }, + }, + required: ["input"], + }, + }, + }); + } + + // 2. Explicit tools (toolkit entries, function tools, hosted tools) + const hostedToCollect: import("./tools/hosted-tools").HostedTool[] = []; + for (const [key, tool] of Object.entries(def.tools ?? {})) { + if (isToolkitEntry(tool)) { + index.set(key, { + source: "toolkit", + pluginName: tool.pluginName, + localName: tool.localName, + def: { ...tool.def, name: key }, + }); + continue; + } + if (isFunctionTool(tool)) { + index.set(key, { + source: "function", + functionTool: tool, + def: { ...functionToolToDefinition(tool), name: key }, + }); + continue; + } + if (isHostedTool(tool)) { + hostedToCollect.push(tool); + continue; + } + throw new Error( + `Agent '${agentName}' tool '${key}' has an unrecognized shape`, + ); + } + + if (hostedToCollect.length > 0) { + await this.connectHostedTools(hostedToCollect, index); + } + + return index; + } + + private async applyAutoInherit( + agentName: string, + index: Map, + ): Promise { + if (!this.context) return; + const inherited: string[] = []; + const skippedByPlugin = new Map(); + const recordSkip = (pluginName: string, localName: string) => { + const list = skippedByPlugin.get(pluginName) ?? 
[]; + list.push(localName); + skippedByPlugin.set(pluginName, list); + }; + + for (const { + name: pluginName, + provider, + } of this.context.getToolProviders()) { + if (pluginName === this.name) continue; + const withToolkit = provider as ToolProvider & { + toolkit?: (opts?: unknown) => Record; + }; + if (typeof withToolkit.toolkit === "function") { + const entries = withToolkit.toolkit() as Record; + for (const [key, maybeEntry] of Object.entries(entries)) { + if (!isToolkitEntry(maybeEntry)) continue; + if (maybeEntry.autoInheritable !== true) { + recordSkip(maybeEntry.pluginName, maybeEntry.localName); + continue; + } + index.set(key, { + source: "toolkit", + pluginName: maybeEntry.pluginName, + localName: maybeEntry.localName, + def: { ...maybeEntry.def, name: key }, + }); + inherited.push(key); + } + continue; + } + // Fallback: providers without a toolkit() still expose getAgentTools(). + // These cannot be selectively opted in per tool, so we conservatively + // skip them during auto-inherit and require explicit `tools:` wiring. + for (const tool of provider.getAgentTools()) { + recordSkip(pluginName, tool.name); + } + } + + if (inherited.length > 0) { + logger.info( + "[agent %s] auto-inherited %d tool(s): %s", + agentName, + inherited.length, + inherited.join(", "), + ); + } + if (skippedByPlugin.size > 0) { + const summary = Array.from(skippedByPlugin.entries()) + .map(([p, tools]) => `${p}(${tools.length})`) + .join(", "); + logger.info( + "[agent %s] auto-inherit skipped %d tool(s) not marked autoInheritable: %s. 
Wire them explicitly via `tools:` if needed.", + agentName, + Array.from(skippedByPlugin.values()).reduce( + (n, list) => n + list.length, + 0, + ), + summary, + ); + } + } + + private async connectHostedTools( + hostedTools: import("./tools/hosted-tools").HostedTool[], + index: Map, + ): Promise { + const wsClient = await this.resolveWorkspaceClient(); + await wsClient.config.ensureResolved(); + const host = wsClient.config.host; + + if (!host) { + logger.warn( + "No Databricks host available — skipping %d hosted tool(s). " + + "Set DATABRICKS_HOST or configure a profile in ~/.databrickscfg.", + hostedTools.length, + ); + return; + } + + const authenticate = async (): Promise> => { + const headers = new Headers(); + await wsClient.config.authenticate(headers); + return Object.fromEntries(headers.entries()); + }; + + if (!this.mcpClient) { + const policy = buildMcpHostPolicy(this.config.mcp, host); + this.mcpClient = new AppKitMcpClient(host, authenticate, policy); + } + + const endpoints = resolveHostedTools(hostedTools); + await this.mcpClient.connectAll(endpoints); + + for (const def of this.mcpClient.getAllToolDefinitions()) { + index.set(def.name, { + source: "mcp", + mcpToolName: def.name, + def, + }); + } + } + + /** + * Return the ambient workspace client from {@link getWorkspaceClient} when + * `ServiceContext` is initialized (the normal `createApp` path). Fall back + * to a fresh `WorkspaceClient()` that walks the SDK's credential chain — + * `DATABRICKS_HOST` / `DATABRICKS_TOKEN`, `~/.databrickscfg` profiles, + * DAB auth, OAuth, metadata service — for test rigs and manual embeds + * that never ran through `createApp`. 
+ */ + private async resolveWorkspaceClient() { + try { + return getWorkspaceClient(); + } catch { + const { WorkspaceClient } = await import("@databricks/sdk-experimental"); + return new WorkspaceClient({}); + } + } + + // ----------------- ToolProvider (no tools of our own) -------------------- + + getAgentTools(): AgentToolDefinition[] { + return []; + } + + async executeAgentTool(): Promise { + throw new Error("AgentsPlugin does not expose executeAgentTool directly"); + } + + // ----------------- Route mounting and handlers --------------------------- + + private mountInvocationsRoute() { + if (!this.context) return; + this.context.addRoute( + "post", + "/invocations", + (req: express.Request, res: express.Response) => { + this._handleInvocations(req, res); + }, + ); + } + + injectRoutes(router: IAppRouter) { + this.route(router, { + name: "chat", + method: "post", + path: "/chat", + handler: async (req, res) => this._handleChat(req, res), + }); + this.route(router, { + name: "cancel", + method: "post", + path: "/cancel", + handler: async (req, res) => this._handleCancel(req, res), + }); + this.route(router, { + name: "approve", + method: "post", + path: "/approve", + handler: async (req, res) => this._handleApprove(req, res), + }); + this.route(router, { + name: "threads", + method: "get", + path: "/threads", + handler: async (req, res) => this._handleListThreads(req, res), + }); + this.route(router, { + name: "thread", + method: "get", + path: "/threads/:threadId", + handler: async (req, res) => this._handleGetThread(req, res), + }); + this.route(router, { + name: "deleteThread", + method: "delete", + path: "/threads/:threadId", + handler: async (req, res) => this._handleDeleteThread(req, res), + }); + this.route(router, { + name: "info", + method: "get", + path: "/info", + handler: async (_req, res) => { + res.json({ + agents: Array.from(this.agents.keys()), + defaultAgent: this.defaultAgentName, + }); + }, + }); + } + + clientConfig(): Record { + return { + 
agents: Array.from(this.agents.keys()), + defaultAgent: this.defaultAgentName, + }; + } + + private async _handleChat(req: express.Request, res: express.Response) { + const parsed = chatRequestSchema.safeParse(req.body); + if (!parsed.success) { + res.status(400).json({ + error: "Invalid request", + details: parsed.error.flatten().fieldErrors, + }); + return; + } + const { message, threadId, agent: agentName } = parsed.data; + + const registered = this.resolveAgent(agentName); + if (!registered) { + res.status(400).json({ + error: agentName + ? `Agent "${agentName}" not found` + : "No agent registered", + }); + return; + } + + const userId = this.resolveUserId(req); + + // Reject early (before allocating a thread) when the user is already at + // their concurrent-stream limit. Prevents a misbehaving client from + // churning thread rows while being denied elsewhere. + const limits = this.resolvedLimits; + if (this.countUserStreams(userId) >= limits.maxConcurrentStreamsPerUser) { + res.setHeader("Retry-After", "5"); + res.status(429).json({ + error: `Too many concurrent streams for this user (limit ${limits.maxConcurrentStreamsPerUser}). Wait for an existing stream to complete before starting another.`, + }); + return; + } + + let thread = threadId ? 
await this.threadStore.get(threadId, userId) : null; + if (threadId && !thread) { + res.status(404).json({ error: `Thread ${threadId} not found` }); + return; + } + if (!thread) { + thread = await this.threadStore.create(userId); + } + + const userMessage: Message = { + id: randomUUID(), + role: "user", + content: message, + createdAt: new Date(), + }; + await this.threadStore.addMessage(thread.id, userId, userMessage); + return this._streamAgent(req, res, registered, thread, userId); + } + + private async _handleInvocations( + req: express.Request, + res: express.Response, + ) { + const parsed = invocationsRequestSchema.safeParse(req.body); + if (!parsed.success) { + res.status(400).json({ + error: "Invalid request", + details: parsed.error.flatten().fieldErrors, + }); + return; + } + const { input } = parsed.data; + const registered = this.resolveAgent(); + if (!registered) { + res.status(400).json({ error: "No agent registered" }); + return; + } + const userId = this.resolveUserId(req); + const thread = await this.threadStore.create(userId); + + if (typeof input === "string") { + await this.threadStore.addMessage(thread.id, userId, { + id: randomUUID(), + role: "user", + content: input, + createdAt: new Date(), + }); + } else { + for (const item of input) { + const role = (item.role ?? "user") as Message["role"]; + const content = + typeof item.content === "string" + ? item.content + : JSON.stringify(item.content ?? 
""); + if (!content) continue; + await this.threadStore.addMessage(thread.id, userId, { + id: randomUUID(), + role, + content, + createdAt: new Date(), + }); + } + } + + return this._streamAgent(req, res, registered, thread, userId); + } + + private async _streamAgent( + req: express.Request, + res: express.Response, + registered: RegisteredAgent, + thread: Thread, + userId: string, + ): Promise { + const abortController = new AbortController(); + const signal = abortController.signal; + const requestId = randomUUID(); + this.activeStreams.set(requestId, { controller: abortController, userId }); + + const tools = Array.from(registered.toolIndex.values()).map((e) => e.def); + const approvalPolicy = this.resolvedApprovalPolicy; + const limits = this.resolvedLimits; + const outboundEvents = new EventChannel(); + const translator = new AgentEventTranslator(); + // Per-run tool-call budget (shared across the top-level adapter and any + // sub-agents it delegates to). Counted pre-dispatch so a prompt-injected + // agent cannot drain the budget silently via denied calls. + let toolCallsUsed = 0; + + const executeTool = async ( + name: string, + args: unknown, + ): Promise => { + if (toolCallsUsed >= limits.maxToolCalls) { + abortController.abort( + new Error( + `Tool-call budget exhausted (limit ${limits.maxToolCalls}).`, + ), + ); + throw new Error( + `Tool-call budget exhausted (limit ${limits.maxToolCalls}). 
Raise agents({ limits: { maxToolCalls } }) or review the agent's tool-selection logic.`, + ); + } + toolCallsUsed++; + + const entry = registered.toolIndex.get(name); + if (!entry) throw new Error(`Unknown tool: ${name}`); + + if ( + approvalPolicy.requireForDestructive && + isDestructiveToolEntry(entry) + ) { + const approvalId = randomUUID(); + for (const ev of translator.translate({ + type: "approval_pending", + approvalId, + streamId: requestId, + toolName: name, + args, + annotations: combinedToolAnnotations(entry), + })) { + outboundEvents.push(ev); + } + const decision = await this.approvalGate.wait({ + approvalId, + streamId: requestId, + userId, + timeoutMs: approvalPolicy.timeoutMs, + }); + if (decision === "deny") { + return `Tool execution denied by user approval gate (tool: ${name}).`; + } + } + + const raw = await dispatchToolCall(entry, args, { + req, + signal, + pluginContext: this.context, + mcpClient: this.mcpClient, + runSubAgent: (agentName, subArgs) => { + const childAgent = this.agents.get(agentName); + if (!childAgent) throw new Error(`Sub-agent not found: ${agentName}`); + return this.runSubAgent(req, childAgent, subArgs, signal, 1); + }, + }); + return normalizeToolResult(raw); + }; + + // Drive the adapter and the approval-event side-channel concurrently. + // Outbound events from both sources flow through `outboundEvents`; the + // generator below drains the channel in order. executeTool pushes + // approval-pending events into the same channel before awaiting the gate. + const driver = (async () => { + try { + for (const evt of translator.translate({ + type: "metadata", + data: { threadId: thread.id }, + })) { + outboundEvents.push(evt); + } + + const pluginNames = this.context + ? 
this.context + .getPluginNames() + .filter((n) => n !== this.name && n !== "server") + : []; + const fullPrompt = composePromptForAgent( + registered, + this.config.baseSystemPrompt, + { + agentName: registered.name, + pluginNames, + toolNames: tools.map((t) => t.name), + }, + ); + + const messagesWithSystem: Message[] = [ + { + id: "system", + role: "system", + content: fullPrompt, + createdAt: new Date(), + }, + ...thread.messages, + ]; + + const stream = registered.adapter.run( + { + messages: messagesWithSystem, + tools, + threadId: thread.id, + signal, + }, + { executeTool, signal }, + ); + + const fullContent = await consumeAdapterStream(stream, { + signal, + onEvent: (event) => { + for (const translated of translator.translate(event)) { + outboundEvents.push(translated); + } + }, + }); + + if (fullContent) { + await this.threadStore.addMessage(thread.id, userId, { + id: randomUUID(), + role: "assistant", + content: fullContent, + createdAt: new Date(), + }); + } + + for (const evt of translator.finalize()) outboundEvents.push(evt); + } catch (error) { + if (signal.aborted) { + outboundEvents.close(); + return; + } + logger.error("Agent chat error: %O", error); + outboundEvents.close(error); + return; + } finally { + // Any pending approval gates for this stream are auto-denied so the + // adapter can unwind if it was still waiting. + this.approvalGate.abortStream(requestId); + this.activeStreams.delete(requestId); + // Stateless agents (e.g. autocomplete) don't persist history; drop + // the thread so `InMemoryThreadStore` doesn't accumulate one record + // per request. Swallow delete errors — the stream has already + // finished and the client has the response. 
+ if (registered.ephemeral) { + try { + await this.threadStore.delete(thread.id, userId); + } catch (err) { + logger.warn( + "Failed to delete ephemeral thread %s: %O", + thread.id, + err, + ); + } + } + } + outboundEvents.close(); + })(); + + await this.executeStream( + res, + async function* () { + try { + for await (const ev of outboundEvents) { + yield ev; + } + } finally { + await driver.catch(() => undefined); + } + }, + { + ...agentStreamDefaults, + stream: { ...agentStreamDefaults.stream, streamId: requestId }, + }, + ); + } + + /** + * Runs a sub-agent in response to an `agent-` tool call. Returns the + * concatenated text output to hand back to the parent adapter as the tool + * result. + * + * `depth` starts at 1 for a top-level sub-agent invocation (i.e. the + * outer `_streamAgent` calls `runSubAgent(..., 1)`) and increments on + * each nested `runSubAgent` call. Depths exceeding + * `limits.maxSubAgentDepth` are rejected before any adapter work. + */ + private async runSubAgent( + req: express.Request, + child: RegisteredAgent, + args: unknown, + signal: AbortSignal, + depth: number, + ): Promise { + const limits = this.resolvedLimits; + if (depth > limits.maxSubAgentDepth) { + throw new Error( + `Sub-agent depth exceeded (limit ${limits.maxSubAgentDepth}). ` + + `Raise agents({ limits: { maxSubAgentDepth } }) or break the delegation cycle.`, + ); + } + + const input = + typeof args === "object" && + args !== null && + typeof (args as { input?: unknown }).input === "string" + ? 
(args as { input: string }).input + : JSON.stringify(args); + const childTools = Array.from(child.toolIndex.values()).map((e) => e.def); + + const childExecute = async ( + name: string, + childArgs: unknown, + ): Promise => { + const entry = child.toolIndex.get(name); + if (!entry) throw new Error(`Unknown tool in sub-agent: ${name}`); + return dispatchToolCall(entry, childArgs, { + req, + signal, + pluginContext: this.context, + mcpClient: this.mcpClient, + runSubAgent: (agentName, args) => { + const grandchild = this.agents.get(agentName); + if (!grandchild) throw new Error(`Sub-agent not found: ${agentName}`); + return this.runSubAgent(req, grandchild, args, signal, depth + 1); + }, + }); + }; + + const runContext: AgentRunContext = { executeTool: childExecute, signal }; + + const pluginNames = this.context + ? this.context + .getPluginNames() + .filter((n) => n !== this.name && n !== "server") + : []; + const systemPrompt = composePromptForAgent( + child, + this.config.baseSystemPrompt, + { + agentName: child.name, + pluginNames, + toolNames: childTools.map((t) => t.name), + }, + ); + + const messages: Message[] = [ + { + id: "system", + role: "system", + content: systemPrompt, + createdAt: new Date(), + }, + { + id: randomUUID(), + role: "user", + content: input, + createdAt: new Date(), + }, + ]; + + return consumeAdapterStream( + child.adapter.run( + { messages, tools: childTools, threadId: randomUUID(), signal }, + runContext, + ), + { signal }, + ); + } + + private async _handleCancel(req: express.Request, res: express.Response) { + const { streamId } = req.body as { streamId?: string }; + if (!streamId) { + res.status(400).json({ error: "streamId is required" }); + return; + } + const entry = this.activeStreams.get(streamId); + if (!entry) { + // Stream is unknown or already completed — idempotent no-op. 
+ res.json({ cancelled: true }); + return; + } + const userId = this.resolveUserId(req); + if (entry.userId !== userId) { + res.status(403).json({ error: "Forbidden" }); + return; + } + entry.controller.abort("Cancelled by user"); + this.activeStreams.delete(streamId); + this.approvalGate.abortStream(streamId); + res.json({ cancelled: true }); + } + + private async _handleApprove(req: express.Request, res: express.Response) { + const parsed = approvalRequestSchema.safeParse(req.body); + if (!parsed.success) { + res.status(400).json({ + error: "Invalid request", + details: parsed.error.flatten().fieldErrors, + }); + return; + } + const { streamId, approvalId, decision } = parsed.data; + + const streamEntry = this.activeStreams.get(streamId); + if (!streamEntry) { + // Stream has already completed or never existed. Return 404 so the UI + // knows the approval token is no longer valid (the waiter, if any, has + // already been timed out or aborted). + res.status(404).json({ error: "Stream not found or already completed" }); + return; + } + + const userId = this.resolveUserId(req); + if (streamEntry.userId !== userId) { + res.status(403).json({ error: "Forbidden" }); + return; + } + + const result = this.approvalGate.submit({ approvalId, userId, decision }); + if (!result.ok) { + if (result.reason === "forbidden") { + res.status(403).json({ error: "Forbidden" }); + return; + } + res.status(404).json({ error: "Approval not found or already settled" }); + return; + } + + res.json({ decision }); + } + + private async _handleListThreads( + req: express.Request, + res: express.Response, + ) { + const userId = this.resolveUserId(req); + const threads = await this.threadStore.list(userId); + res.json({ threads }); + } + + private async _handleGetThread(req: express.Request, res: express.Response) { + const userId = this.resolveUserId(req); + const thread = await this.threadStore.get(req.params.threadId, userId); + if (!thread) { + res.status(404).json({ error: "Thread not 
found" }); + return; + } + res.json(thread); + } + + private async _handleDeleteThread( + req: express.Request, + res: express.Response, + ) { + const userId = this.resolveUserId(req); + const deleted = await this.threadStore.delete(req.params.threadId, userId); + if (!deleted) { + res.status(404).json({ error: "Thread not found" }); + return; + } + res.json({ deleted: true }); + } + + private resolveAgent(name?: string): RegisteredAgent | null { + if (name) return this.agents.get(name) ?? null; + if (this.defaultAgentName) { + return this.agents.get(this.defaultAgentName) ?? null; + } + const first = this.agents.values().next(); + return first.done ? null : first.value; + } + + private printRegistry(): void { + if (this.agents.size === 0) return; + console.log(""); + console.log(` ${pc.bold("Agents")} ${pc.dim(`(${this.agents.size})`)}`); + console.log(` ${pc.dim("─".repeat(60))}`); + for (const [name, reg] of this.agents) { + const tools = reg.toolIndex.size; + const marker = name === this.defaultAgentName ? pc.green("●") : " "; + console.log( + ` ${marker} ${pc.bold(name.padEnd(24))} ${pc.dim(`${tools} tools`)}`, + ); + } + console.log(` ${pc.dim("─".repeat(60))}`); + console.log(""); + } + + async shutdown(): Promise { + this.approvalGate.abortAll(); + if (this.mcpClient) { + await this.mcpClient.close(); + this.mcpClient = null; + } + } + + exports() { + return { + register: (name: string, def: AgentDefinition) => + this.registerCodeAgent(name, def), + list: () => Array.from(this.agents.keys()), + get: (name: string) => this.agents.get(name) ?? 
null, + reload: () => this.reload(), + getDefault: () => this.defaultAgentName, + getThreads: (userId: string) => this.threadStore.list(userId), + }; + } + + private async registerCodeAgent( + name: string, + def: AgentDefinition, + ): Promise { + const registered = await this.buildRegisteredAgent(name, def, { + origin: "code", + }); + this.agents.set(name, registered); + if (!this.defaultAgentName) this.defaultAgentName = name; + } +} + +/** + * True when the tool should go through the approval gate. Historically + * scoped to `destructive: true` — hence the name — but now also fires for + * the semantic `effect` enum on {@link ToolAnnotations}. Any effect that + * mutates the world (`write` | `update` | `destructive`) gates; `read` and + * unannotated tools do not. `def.annotations` is the normal path; for + * `function` tools we also read `functionTool.annotations` so a mismatch + * between the spread def and the original {@link FunctionTool} cannot drop + * the hint. + */ +function isDestructiveToolEntry(entry: ResolvedToolEntry): boolean { + const defAnn = entry.def.annotations; + const fnAnn = + entry.source === "function" ? entry.functionTool.annotations : undefined; + + const effect = defAnn?.effect ?? fnAnn?.effect; + if (effect === "write" || effect === "update" || effect === "destructive") { + return true; + } + if (defAnn?.destructive === true) return true; + if (fnAnn?.destructive === true) return true; + return false; +} + +/** Merged annotations for the approval SSE payload (client UI + debugging). */ +function combinedToolAnnotations( + entry: ResolvedToolEntry, +): ToolAnnotations | undefined { + if (entry.source === "function") { + const merged: ToolAnnotations = { + ...entry.functionTool.annotations, + ...entry.def.annotations, + }; + return Object.keys(merged).length > 0 ? 
merged : undefined; + } + return entry.def.annotations; +} + +function normalizeAutoInherit(value: AgentsPluginConfig["autoInheritTools"]): { + file: boolean; + code: boolean; +} { + // Default is opt-out for both origins. A markdown agent or code-defined + // agent with no declared `tools:` gets an empty tool index unless the + // developer explicitly flips `autoInheritTools` on. Even then, only tools + // whose plugin author marked `autoInheritable: true` are spread — see + // `applyAutoInherit` for the filter. + if (value === undefined) return { file: false, code: false }; + if (typeof value === "boolean") return { file: value, code: value }; + return { file: value.file ?? false, code: value.code ?? false }; +} + +function composePromptForAgent( + registered: RegisteredAgent, + pluginLevel: BaseSystemPromptOption | undefined, + ctx: PromptContext, +): string { + const perAgent = registered.baseSystemPrompt; + const resolved = perAgent !== undefined ? perAgent : pluginLevel; + + let base = ""; + if (resolved === false) { + base = ""; + } else if (typeof resolved === "string") { + base = resolved; + } else if (typeof resolved === "function") { + base = resolved(ctx); + } else { + base = buildBaseSystemPrompt(ctx); + } + + return composeSystemPrompt(base, registered.instructions); +} + +/** + * Plugin factory for the agents plugin. Reads `config/agents//agent.md` by default, + * resolves toolkits/tools from registered plugins, exposes `appkit.agents.*` + * runtime API and mounts `/invocations`. 
+ * + * @example + * ```ts + * import { agents, analytics, createApp, server } from "@databricks/appkit"; + * + * await createApp({ + * plugins: [server(), analytics(), agents()], + * }); + * ``` + */ +export const agents = toPlugin(AgentsPlugin); diff --git a/packages/appkit/src/plugins/agents/build-toolkit.ts b/packages/appkit/src/plugins/agents/build-toolkit.ts new file mode 100644 index 00000000..0140425d --- /dev/null +++ b/packages/appkit/src/plugins/agents/build-toolkit.ts @@ -0,0 +1,63 @@ +import type { AgentToolDefinition } from "shared"; +import type { ToolRegistry } from "./tools/define-tool"; +import { toToolJSONSchema } from "./tools/json-schema"; +import type { ToolkitEntry, ToolkitOptions } from "./types"; + +/** + * Converts a plugin's internal `ToolRegistry` into a keyed record of + * `ToolkitEntry` markers suitable for spreading into an `AgentDefinition.tools` + * record. + * + * The `opts` record controls shape and filtering: + * - `prefix` — overrides the default `${pluginName}.` prefix; `""` drops it. + * - `only` — allowlist of local tool names to include (post-prefix). + * - `except` — denylist of local names. + * - `rename` — per-tool key remapping (applied after prefix/filter). + * + * Each entry carries `pluginName` + `localName` so the agents plugin can + * dispatch back through `PluginContext.executeTool` for OBO + telemetry. + */ +export function buildToolkitEntries( + pluginName: string, + registry: ToolRegistry, + opts: ToolkitOptions = {}, +): Record { + const prefix = opts.prefix ?? `${pluginName}.`; + const only = opts.only ? new Set(opts.only) : null; + const except = opts.except ? new Set(opts.except) : null; + const rename = opts.rename ?? {}; + + const out: Record = {}; + + for (const [localName, entry] of Object.entries(registry)) { + if (only && !only.has(localName)) continue; + if (except?.has(localName)) continue; + + const keyAfterPrefix = `${prefix}${localName}`; + const key = rename[localName] ?? 
keyAfterPrefix; + + const parameters = toToolJSONSchema( + entry.schema, + ) as unknown as AgentToolDefinition["parameters"]; + + const def: AgentToolDefinition = { + name: key, + description: entry.description, + parameters, + }; + if (entry.annotations) { + def.annotations = entry.annotations; + } + + out[key] = { + __toolkitRef: true, + pluginName, + localName, + def, + annotations: entry.annotations, + autoInheritable: entry.autoInheritable, + }; + } + + return out; +} diff --git a/packages/appkit/src/plugins/agents/consume-adapter-stream.ts b/packages/appkit/src/plugins/agents/consume-adapter-stream.ts new file mode 100644 index 00000000..c4f3d07e --- /dev/null +++ b/packages/appkit/src/plugins/agents/consume-adapter-stream.ts @@ -0,0 +1,52 @@ +import type { AgentEvent } from "shared"; + +interface ConsumeAdapterStreamOptions { + /** + * Optional abort signal. When aborted, the loop stops consuming (the caller + * is expected to have forwarded the same signal to `adapter.run` to stop + * upstream work). `undefined` is valid — standalone `runAgent` runs without + * a signal. + */ + signal?: AbortSignal; + /** + * Side-effect callback invoked once per adapter event, after the content + * accumulator has been updated. Use to fan events out to SSE translators, + * collect a raw event list for tests, or emit telemetry. + */ + onEvent?: (event: AgentEvent) => void; +} + +/** + * Consume an adapter's event stream and aggregate the assistant's final text. + * + * Accumulation rule (shared across all agent-execution paths in AppKit): + * + * - `message_delta` events append their `content` to the running text. + * - A `message` event *replaces* the running text with its `content`. + * + * The two branches coexist because different adapters emit different shapes: + * streaming adapters (Databricks, Vercel AI) emit deltas chunk-by-chunk, + * while `LangChain`'s `on_chain_end` path emits a single final `message`. 
+ * Without the replace branch, LangChain conversations silently dropped the + * assistant turn from thread history. + * + * Kept pure (no I/O, no mutable external state beyond the caller's `onEvent` + * side effect) so each execution path — HTTP streaming, sub-agents, and the + * standalone `runAgent` — can share one loop. + */ +export async function consumeAdapterStream( + stream: AsyncIterable, + opts: ConsumeAdapterStreamOptions = {}, +): Promise { + let text = ""; + for await (const event of stream) { + if (opts.signal?.aborted) break; + if (event.type === "message_delta") { + text += event.content; + } else if (event.type === "message") { + text = event.content; + } + opts.onEvent?.(event); + } + return text; +} diff --git a/packages/appkit/src/plugins/agents/defaults.ts b/packages/appkit/src/plugins/agents/defaults.ts new file mode 100644 index 00000000..4da11bef --- /dev/null +++ b/packages/appkit/src/plugins/agents/defaults.ts @@ -0,0 +1,12 @@ +import type { StreamExecutionSettings } from "shared"; + +export const agentStreamDefaults: StreamExecutionSettings = { + default: { + cache: { enabled: false }, + retry: { enabled: false }, + timeout: 300_000, + }, + stream: { + bufferSize: 200, + }, +}; diff --git a/packages/appkit/src/plugins/agents/event-channel.ts b/packages/appkit/src/plugins/agents/event-channel.ts new file mode 100644 index 00000000..c5b60463 --- /dev/null +++ b/packages/appkit/src/plugins/agents/event-channel.ts @@ -0,0 +1,70 @@ +/** + * Single-producer/single-consumer async queue used by the agents plugin to + * merge streams of SSE events from two concurrent sources: the adapter's + * `run()` generator, and out-of-band events emitted by `executeTool` (e.g. + * human-approval requests). + * + * The consumer drains the channel as an async iterable; the producer pushes + * events synchronously and closes the channel when the source has completed + * or errored. 
+ */ +interface Waiter { + resolve: (value: IteratorResult) => void; + reject: (error: unknown) => void; +} + +export class EventChannel { + private queue: T[] = []; + private waiters: Array> = []; + private closed = false; + private error: unknown = undefined; + + /** Synchronously enqueue an event. Safe to call from non-async contexts. */ + push(value: T): void { + if (this.closed) return; + const waiter = this.waiters.shift(); + if (waiter) { + waiter.resolve({ value, done: false }); + } else { + this.queue.push(value); + } + } + + /** + * Close the channel. Any pending `next()` calls resolve with `done: true`. + * If `error` is supplied, pending `next()` calls reject with it and future + * calls do the same. + */ + close(error?: unknown): void { + if (this.closed) return; + this.closed = true; + this.error = error; + while (this.waiters.length > 0) { + const waiter = this.waiters.shift(); + if (!waiter) break; + if (error) { + waiter.reject(error); + } else { + waiter.resolve({ value: undefined as never, done: true }); + } + } + } + + [Symbol.asyncIterator](): AsyncIterator { + return { + next: (): Promise> => { + if (this.queue.length > 0) { + const value = this.queue.shift() as T; + return Promise.resolve({ value, done: false }); + } + if (this.closed) { + if (this.error) return Promise.reject(this.error); + return Promise.resolve({ value: undefined as never, done: true }); + } + return new Promise((resolve, reject) => { + this.waiters.push({ resolve, reject }); + }); + }, + }; + } +} diff --git a/packages/appkit/src/plugins/agents/event-translator.ts b/packages/appkit/src/plugins/agents/event-translator.ts new file mode 100644 index 00000000..54d749fb --- /dev/null +++ b/packages/appkit/src/plugins/agents/event-translator.ts @@ -0,0 +1,291 @@ +import { randomUUID } from "node:crypto"; +import type { + AgentEvent, + ResponseFunctionCallOutput, + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseStreamEvent, +} from "shared"; + +/** + * Translates 
internal `AgentEvent` stream into Responses API SSE events. + * + * Stateful: one instance per streaming request. Tracks sequence numbers and + * allocates `output_index` strictly monotonically — each emitted output + * item (message, function call, function call output) claims the next + * available index and, once claimed, never reuses an earlier one. This is a + * Responses-API contract that OpenAI's own SDK parsers enforce. + * + * A message is opened lazily on the first `message_delta` or `message` + * event. If a `tool_call` or `tool_result` arrives while a message is open + * (common ReAct flow: partial text → tool call → more text), the open + * message is closed (`response.output_item.done`) BEFORE the tool item is + * added, so subsequent text resumes as a new message item at a strictly + * later index. + */ +export class AgentEventTranslator { + private seqNum = 0; + private nextOutputIndex = 0; + private currentMessage: { + id: string; + text: string; + outputIndex: number; + } | null = null; + private finalized = false; + + translate(event: AgentEvent): ResponseStreamEvent[] { + switch (event.type) { + case "message_delta": + return this.handleMessageDelta(event.content); + case "message": + return this.handleFullMessage(event.content); + case "tool_call": + return this.handleToolCall(event.callId, event.name, event.args); + case "tool_result": + return this.handleToolResult(event.callId, event.result, event.error); + case "thinking": + return [ + { + type: "appkit.thinking", + content: event.content, + sequence_number: this.seqNum++, + }, + ]; + case "metadata": + return [ + { + type: "appkit.metadata", + data: event.data, + sequence_number: this.seqNum++, + }, + ]; + case "approval_pending": + return [ + { + type: "appkit.approval_pending", + approval_id: event.approvalId, + stream_id: event.streamId, + tool_name: event.toolName, + args: event.args, + annotations: event.annotations, + sequence_number: this.seqNum++, + }, + ]; + case "status": + 
return this.handleStatus(event.status, event.error); + } + } + + finalize(): ResponseStreamEvent[] { + if (this.finalized) return []; + this.finalized = true; + + const events: ResponseStreamEvent[] = []; + const closeEvent = this.closeCurrentMessage(); + if (closeEvent) events.push(closeEvent); + + events.push({ + type: "response.completed", + sequence_number: this.seqNum++, + response: {}, + }); + + return events; + } + + private handleMessageDelta(content: string): ResponseStreamEvent[] { + const events: ResponseStreamEvent[] = []; + + if (!this.currentMessage) { + const id = `msg_${randomUUID()}`; + const outputIndex = this.nextOutputIndex++; + this.currentMessage = { id, text: content, outputIndex }; + const item: ResponseOutputMessage = { + type: "message", + id, + status: "in_progress", + role: "assistant", + content: [], + }; + events.push({ + type: "response.output_item.added", + output_index: outputIndex, + item, + sequence_number: this.seqNum++, + }); + } else { + this.currentMessage.text += content; + } + + events.push({ + type: "response.output_text.delta", + item_id: this.currentMessage.id, + output_index: this.currentMessage.outputIndex, + content_index: 0, + delta: content, + sequence_number: this.seqNum++, + }); + + return events; + } + + private handleFullMessage(content: string): ResponseStreamEvent[] { + const events: ResponseStreamEvent[] = []; + + if (!this.currentMessage) { + // No prior deltas — open and immediately close. 
+ const id = `msg_${randomUUID()}`; + const outputIndex = this.nextOutputIndex++; + this.currentMessage = { id, text: content, outputIndex }; + const addedItem: ResponseOutputMessage = { + type: "message", + id, + status: "in_progress", + role: "assistant", + content: [], + }; + events.push({ + type: "response.output_item.added", + output_index: outputIndex, + item: addedItem, + sequence_number: this.seqNum++, + }); + } else { + // Deltas already opened the item; `message` overrides the accumulated + // text (per adapter contract) and closes it. + this.currentMessage.text = content; + } + + const closeEvent = this.closeCurrentMessage(); + if (closeEvent) events.push(closeEvent); + return events; + } + + private handleToolCall( + callId: string, + name: string, + args: unknown, + ): ResponseStreamEvent[] { + const events: ResponseStreamEvent[] = []; + const closeEvent = this.closeCurrentMessage(); + if (closeEvent) events.push(closeEvent); + + const outputIndex = this.nextOutputIndex++; + const item: ResponseFunctionToolCall = { + type: "function_call", + id: `fc_${randomUUID()}`, + call_id: callId, + name, + arguments: typeof args === "string" ? args : JSON.stringify(args), + }; + + events.push( + { + type: "response.output_item.added", + output_index: outputIndex, + item, + sequence_number: this.seqNum++, + }, + { + type: "response.output_item.done", + output_index: outputIndex, + item, + sequence_number: this.seqNum++, + }, + ); + return events; + } + + private handleToolResult( + callId: string, + result: unknown, + error?: string, + ): ResponseStreamEvent[] { + const events: ResponseStreamEvent[] = []; + const closeEvent = this.closeCurrentMessage(); + if (closeEvent) events.push(closeEvent); + + const outputIndex = this.nextOutputIndex++; + // Coalesce `undefined` → "" so the wire shape is always a string (the + // Responses API contract). Non-string results are JSON-serialised. 
+ let output: string; + if (error !== undefined) { + output = error; + } else if (typeof result === "string") { + output = result; + } else if (result === undefined) { + output = ""; + } else { + output = JSON.stringify(result); + } + const item: ResponseFunctionCallOutput = { + type: "function_call_output", + id: `fc_output_${randomUUID()}`, + call_id: callId, + output, + }; + + events.push( + { + type: "response.output_item.added", + output_index: outputIndex, + item, + sequence_number: this.seqNum++, + }, + { + type: "response.output_item.done", + output_index: outputIndex, + item, + sequence_number: this.seqNum++, + }, + ); + return events; + } + + /** + * Emit an `response.output_item.done` for the currently-open message, if + * any, and clear the state. Returns the event to the caller so it can be + * pushed at the right moment in the sequence. Returns `null` when there + * is no open message. + */ + private closeCurrentMessage(): ResponseStreamEvent | null { + if (!this.currentMessage) return null; + const { id, text, outputIndex } = this.currentMessage; + this.currentMessage = null; + const doneItem: ResponseOutputMessage = { + type: "message", + id, + status: "completed", + role: "assistant", + content: [{ type: "output_text", text }], + }; + return { + type: "response.output_item.done", + output_index: outputIndex, + item: doneItem, + sequence_number: this.seqNum++, + }; + } + + private handleStatus(status: string, error?: string): ResponseStreamEvent[] { + if (status === "error") { + return [ + { + type: "error", + error: error ?? 
"Unknown error", + sequence_number: this.seqNum++, + }, + { + type: "response.failed", + sequence_number: this.seqNum++, + }, + ]; + } + + if (status === "complete") { + return this.finalize(); + } + + return []; + } +} diff --git a/packages/appkit/src/plugins/agents/index.ts b/packages/appkit/src/plugins/agents/index.ts new file mode 100644 index 00000000..377a8776 --- /dev/null +++ b/packages/appkit/src/plugins/agents/index.ts @@ -0,0 +1,23 @@ +export { AgentsPlugin, agents } from "./agents"; +export { buildToolkitEntries } from "./build-toolkit"; +export { + agentIdFromMarkdownPath, + type LoadContext, + type LoadResult, + loadAgentFromFile, + loadAgentsFromDir, + parseFrontmatter, +} from "./load-agents"; +export { + type AgentDefinition, + type AgentsPluginConfig, + type AgentTool, + type AutoInheritToolsConfig, + type BaseSystemPromptOption, + isToolkitEntry, + type PromptContext, + type RegisteredAgent, + type ResolvedToolEntry, + type ToolkitEntry, + type ToolkitOptions, +} from "./types"; diff --git a/packages/appkit/src/plugins/agents/load-agents.ts b/packages/appkit/src/plugins/agents/load-agents.ts new file mode 100644 index 00000000..5b2999ca --- /dev/null +++ b/packages/appkit/src/plugins/agents/load-agents.ts @@ -0,0 +1,407 @@ +import fs from "node:fs"; +import path from "node:path"; +import yaml from "js-yaml"; +import type { AgentAdapter } from "shared"; +import { createLogger } from "../../logging/logger"; +import type { + AgentDefinition, + AgentTool, + BaseSystemPromptOption, + ToolkitEntry, + ToolkitOptions, +} from "./types"; +import { isToolkitEntry } from "./types"; + +const logger = createLogger("agents:loader"); + +interface ToolkitProvider { + toolkit: (opts?: ToolkitOptions) => Record; +} + +export interface LoadContext { + /** Default model when frontmatter has no `endpoint` and the def has no `model`. */ + defaultModel?: AgentAdapter | Promise | string; + /** Ambient tool library referenced by frontmatter `tools: [key1, key2]`. 
*/ + availableTools?: Record; + /** Registered plugin toolkits referenced by frontmatter `toolkits: [...]`. */ + plugins?: Map; + /** + * Code-defined agents contributed by `agents({ agents: { ... } })`. The + * directory loader resolves `agents:` frontmatter references against + * these alongside sibling markdown files, so a markdown parent can + * delegate to a code-defined child. Code-defined names win on collision + * with markdown names, matching the plugin's top-level merge precedence. + */ + codeAgents?: Record; +} + +export interface LoadResult { + /** Agent definitions keyed by agent id (directory name under `dir`). */ + defs: Record; + /** First agent with `default: true` frontmatter (sorted id order), or `null`. */ + defaultAgent: string | null; +} + +interface Frontmatter { + endpoint?: string; + model?: string; + toolkits?: ToolkitSpec[]; + tools?: string[]; + /** + * Other agent ids to expose as sub-agents. Each becomes an `agent-` + * tool at runtime. Resolution happens at directory-load time in + * {@link loadAgentsFromDir}; the single-file {@link loadAgentFromFile} path + * rejects non-empty values since there are no siblings to resolve against. + */ + agents?: string[]; + maxSteps?: number; + maxTokens?: number; + default?: boolean; + baseSystemPrompt?: false | string; + ephemeral?: boolean; +} + +type ToolkitSpec = string | { [pluginName: string]: ToolkitOptions | string[] }; + +/** + * Derives the logical agent id from a markdown path. When the file is named + * `agent.md`, the id is the parent directory name (folder-based layout); + * otherwise the id is the file stem (e.g. legacy single-file paths). + */ +export function agentIdFromMarkdownPath(filePath: string): string { + const normalized = path.normalize(filePath); + const base = path.basename(normalized); + const parent = path.basename(path.dirname(normalized)); + if (base === "agent.md" && parent && parent !== "." 
&& parent !== "..") { + return parent; + } + return path.basename(normalized, ".md"); +} + +const ALLOWED_KEYS = new Set([ + "endpoint", + "model", + "toolkits", + "tools", + "agents", + "maxSteps", + "maxTokens", + "default", + "baseSystemPrompt", + "ephemeral", +]); + +/** + * Loads a single markdown agent file and resolves its frontmatter against + * registered plugin toolkits + ambient tool library. + * + * Rejects non-empty `agents:` frontmatter because single-file loads have + * no siblings to resolve sub-agent references against — callers must use + * {@link loadAgentsFromDir} when markdown agents delegate to one another. + */ +export async function loadAgentFromFile( + filePath: string, + ctx: LoadContext, +): Promise { + const raw = fs.readFileSync(filePath, "utf-8"); + const name = agentIdFromMarkdownPath(filePath); + const { data } = parseFrontmatter(raw, filePath); + if (Array.isArray(data?.agents) && data.agents.length > 0) { + throw new Error( + `Agent '${name}' (${filePath}) declares 'agents:' in frontmatter, ` + + `which requires loadAgentsFromDir to resolve sibling references. ` + + `Use loadAgentsFromDir, or wire sub-agents in code via createAgent({ agents: { ... } }).`, + ); + } + return buildDefinition(name, raw, filePath, ctx); +} + +/** + * Scans a directory for one subdirectory per agent, each containing + * `agent.md` (frontmatter + body). Produces an `AgentDefinition` record keyed + * by agent id (folder name). Throws on frontmatter errors or unresolved + * references. Returns an empty map if the directory does not exist. + * + * Legacy top-level `*.md` files are rejected with an error — migrate each to + * `/agent.md` under a sibling folder named for the agent id. + * + * Runs in two passes so sub-agent references in frontmatter (`agents: [...]`) + * can be resolved regardless of directory iteration order: + * + * 1. Build every agent's definition from its own `agent.md`. + * 2. 
Walk `agents:` references and wire `def.agents = { child: childDef }` + * by looking them up in the complete map. Dangling names and + * self-references fail loudly; mutual delegation is allowed and bounded + * at runtime by `limits.maxSubAgentDepth`. + */ +export async function loadAgentsFromDir( + dir: string, + ctx: LoadContext, +): Promise { + if (!fs.existsSync(dir)) { + return { defs: {}, defaultAgent: null }; + } + + const entries = fs.readdirSync(dir, { withFileTypes: true }); + const orphanMd = entries + .filter((e) => e.isFile() && e.name.endsWith(".md")) + .map((e) => e.name) + .sort(); + + if (orphanMd.length > 0) { + const hint = orphanMd + .map((f) => `${path.basename(f, ".md")}/agent.md`) + .join(", "); + throw new Error( + `Agents directory contains unsupported top-level markdown file(s): ${orphanMd.join(", ")}. ` + + `Use one folder per agent with a fixed entry file, e.g. ${hint}.`, + ); + } + + /** Reserved folder name until per-agent skills land; not an agent package. */ + const RESERVED_DIRS = new Set(["skills"]); + + const agentIds = entries + .filter((e) => e.isDirectory()) + .map((e) => e.name) + .filter((name) => !RESERVED_DIRS.has(name)) + .sort(); + + const defs: Record = {}; + const subAgentRefs: Record = {}; + let defaultAgent: string | null = null; + + // Pass 1: build every agent's definition; collect sub-agent refs. 
+ for (const id of agentIds) { + const agentPath = path.join(dir, id, "agent.md"); + if (!fs.existsSync(agentPath)) { + throw new Error( + `Agents subdirectory '${path.join(dir, id)}' must contain agent.md.`, + ); + } + const raw = fs.readFileSync(agentPath, "utf-8"); + defs[id] = buildDefinition(id, raw, agentPath, ctx); + const { data } = parseFrontmatter(raw, agentPath); + if (data?.agents !== undefined) { + subAgentRefs[id] = normalizeAgentsFrontmatter(data.agents, id, agentPath); + } + if (data?.default === true && !defaultAgent) { + defaultAgent = id; + } + } + + // Pass 2: resolve sub-agent references against the complete defs map. + // Code-defined agents (ctx.codeAgents) take precedence over markdown ones + // with the same name, matching the plugin's top-level merge behaviour. + for (const [name, refs] of Object.entries(subAgentRefs)) { + if (refs.length === 0) continue; + const children: Record = {}; + const missing: string[] = []; + for (const ref of refs) { + if (ref === name) { + throw new Error( + `Agent '${name}' (${path.join(dir, name, "agent.md")}) cannot reference itself in 'agents:'.`, + ); + } + const sibling = ctx.codeAgents?.[ref] ?? defs[ref]; + if (!sibling) { + missing.push(ref); + continue; + } + children[ref] = sibling; + } + if (missing.length > 0) { + const available = + [...Object.keys(ctx.codeAgents ?? {}), ...Object.keys(defs)] + .sort() + .join(", ") || ""; + throw new Error( + `Agent '${name}' references sub-agent(s) '${missing.join(", ")}' in 'agents:', ` + + `but no markdown or code agent(s) with those names exist. ` + + `Available: ${available}.`, + ); + } + defs[name].agents = children; + } + + return { defs, defaultAgent }; +} + +/** + * Validates that `agents:` frontmatter is an array of non-empty strings and + * returns it with duplicates removed. Throws with a clear per-file message + * on malformed input rather than silently ignoring. 
+ */ +function normalizeAgentsFrontmatter( + value: unknown, + agentName: string, + filePath: string, +): string[] { + if (!Array.isArray(value)) { + throw new Error( + `Agent '${agentName}' (${filePath}) has invalid 'agents:' frontmatter: ` + + `expected an array of sibling agent ids, got ${typeof value}.`, + ); + } + const out: string[] = []; + const seen = new Set(); + for (const item of value) { + if (typeof item !== "string" || item.trim() === "") { + throw new Error( + `Agent '${agentName}' (${filePath}) has invalid 'agents:' entry: ` + + `expected non-empty string, got ${JSON.stringify(item)}.`, + ); + } + if (seen.has(item)) continue; + seen.add(item); + out.push(item); + } + return out; +} + +/** Exposed for tests. Parses `--- yaml ---\nbody` and validates frontmatter keys. */ +export function parseFrontmatter( + raw: string, + sourcePath?: string, +): { data: Frontmatter | null; content: string } { + const match = raw.match(/^---\r?\n([\s\S]*?)\r?\n---\r?\n?([\s\S]*)$/); + if (!match) { + return { data: null, content: raw.trim() }; + } + let parsed: unknown; + try { + parsed = yaml.load(match[1]); + } catch (err) { + const src = sourcePath ? ` (${sourcePath})` : ""; + throw new Error( + `Invalid YAML frontmatter${src}: ${err instanceof Error ? err.message : String(err)}`, + ); + } + if (parsed === null || parsed === undefined) { + return { data: {}, content: match[2].trim() }; + } + if (typeof parsed !== "object" || Array.isArray(parsed)) { + const src = sourcePath ? ` (${sourcePath})` : ""; + throw new Error(`Frontmatter must be a YAML object${src}`); + } + const data = parsed as Record; + for (const key of Object.keys(data)) { + if (!ALLOWED_KEYS.has(key)) { + logger.warn( + "Ignoring unknown frontmatter key '%s' in %s", + key, + sourcePath ?? 
"", + ); + } + } + return { data: data as Frontmatter, content: match[2].trim() }; +} + +function buildDefinition( + name: string, + raw: string, + filePath: string, + ctx: LoadContext, +): AgentDefinition { + const { data, content } = parseFrontmatter(raw, filePath); + const fm: Frontmatter = data ?? {}; + + const tools = resolveFrontmatterTools(name, fm, filePath, ctx); + const model = fm.model ?? fm.endpoint ?? ctx.defaultModel; + + let baseSystemPrompt: BaseSystemPromptOption | undefined; + if (fm.baseSystemPrompt === false) baseSystemPrompt = false; + else if (typeof fm.baseSystemPrompt === "string") + baseSystemPrompt = fm.baseSystemPrompt; + + return { + name, + instructions: content, + model, + tools: Object.keys(tools).length > 0 ? tools : undefined, + maxSteps: typeof fm.maxSteps === "number" ? fm.maxSteps : undefined, + maxTokens: typeof fm.maxTokens === "number" ? fm.maxTokens : undefined, + baseSystemPrompt, + ephemeral: typeof fm.ephemeral === "boolean" ? fm.ephemeral : undefined, + }; +} + +function resolveFrontmatterTools( + agentName: string, + fm: Frontmatter, + filePath: string, + ctx: LoadContext, +): Record { + const out: Record = {}; + const pluginIdx = ctx.plugins ?? new Map(); + + for (const spec of fm.toolkits ?? []) { + const [pluginName, opts] = parseToolkitSpec(spec, filePath, agentName); + const provider = pluginIdx.get(pluginName); + if (!provider) { + throw new Error( + `Agent '${agentName}' (${filePath}) references toolkit '${pluginName}', but plugin '${pluginName}' is not registered. Available: ${ + pluginIdx.size > 0 + ? Array.from(pluginIdx.keys()).join(", ") + : "" + }`, + ); + } + const entries = provider.toolkit(opts) as Record; + for (const [key, entry] of Object.entries(entries)) { + if (!isToolkitEntry(entry)) { + throw new Error( + `Plugin '${pluginName}'.toolkit() returned a value at key '${key}' that is not a ToolkitEntry`, + ); + } + out[key] = entry as ToolkitEntry; + } + } + + for (const key of fm.tools ?? 
[]) { + const tool = ctx.availableTools?.[key]; + if (!tool) { + const available = ctx.availableTools + ? Object.keys(ctx.availableTools).join(", ") + : ""; + throw new Error( + `Agent '${agentName}' (${filePath}) references tool '${key}', which is not in the agents() plugin's tools field. Available: ${available}`, + ); + } + out[key] = tool; + } + + return out; +} + +function parseToolkitSpec( + spec: ToolkitSpec, + filePath: string, + agentName: string, +): [string, ToolkitOptions | undefined] { + if (typeof spec === "string") { + return [spec, undefined]; + } + if (typeof spec !== "object" || spec === null) { + throw new Error( + `Agent '${agentName}' (${filePath}) has invalid toolkit entry: ${JSON.stringify(spec)}`, + ); + } + const keys = Object.keys(spec); + if (keys.length !== 1) { + throw new Error( + `Agent '${agentName}' (${filePath}) toolkit entry must have exactly one key, got: ${keys.join(", ")}`, + ); + } + const pluginName = keys[0]; + const value = spec[pluginName]; + if (Array.isArray(value)) { + return [pluginName, { only: value }]; + } + if (typeof value === "object" && value !== null) { + return [pluginName, value as ToolkitOptions]; + } + throw new Error( + `Agent '${agentName}' (${filePath}) toolkit '${pluginName}' options must be an array of tool names or an options object`, + ); +} diff --git a/packages/appkit/src/plugins/agents/manifest.json b/packages/appkit/src/plugins/agents/manifest.json new file mode 100644 index 00000000..f3986c83 --- /dev/null +++ b/packages/appkit/src/plugins/agents/manifest.json @@ -0,0 +1,24 @@ +{ + "$schema": "https://databricks.github.io/appkit/schemas/plugin-manifest.schema.json", + "name": "agents", + "displayName": "Agents Plugin", + "description": "AI agents driven by markdown configs or code, with auto-tool-discovery from registered plugins", + "resources": { + "required": [], + "optional": [ + { + "type": "serving_endpoint", + "alias": "Model Serving (agents)", + "resourceKey": "agents-serving-endpoint", + 
"description": "Databricks Model Serving endpoint for agents using workspace-hosted models (`DatabricksAdapter.fromModelServing`). Wire the same endpoint name AppKit reads from `DATABRICKS_AGENT_ENDPOINT` when no per-agent model is configured. Omit when agents use only external adapters.", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_AGENT_ENDPOINT", + "description": "Endpoint name passed to Model Serving when agents default to `DatabricksAdapter.fromModelServing()`" + } + } + } + ] + } +} diff --git a/packages/appkit/src/plugins/agents/normalize-result.ts b/packages/appkit/src/plugins/agents/normalize-result.ts new file mode 100644 index 00000000..6fe2362c --- /dev/null +++ b/packages/appkit/src/plugins/agents/normalize-result.ts @@ -0,0 +1,33 @@ +/** + * Maximum serialized length of a tool result before we truncate with a + * human-readable marker. 50k chars is roughly ~12k tokens — enough for + * reasonable SQL result sets and JSON blobs, well short of the per-call + * context limits on current frontier models. + */ +export const MAX_TOOL_RESULT_CHARS = 50_000; + +/** + * Normalise a raw tool-execution result for the LLM: + * + * - `undefined` → empty string. A `void` return is a legitimate outcome for + * side-effecting tools ("send notification"); surfacing `undefined` to the + * adapter would otherwise read as "execution failed". + * - strings are returned as-is. + * - everything else is JSON-stringified. + * - results longer than {@link MAX_TOOL_RESULT_CHARS} are truncated and + * annotated so the model sees the cut rather than silent data loss. + * + * Pure function; safe to unit-test in isolation. + */ +export function normalizeToolResult( + result: unknown, + maxChars: number = MAX_TOOL_RESULT_CHARS, +): unknown { + if (result === undefined) return ""; + const serialized = + typeof result === "string" ? 
result : JSON.stringify(result); + if (serialized.length > maxChars) { + return `${serialized.slice(0, maxChars)}\n\n[Result truncated: ${serialized.length} chars exceeds ${maxChars} limit]`; + } + return result; +} diff --git a/packages/appkit/src/plugins/agents/schemas.ts b/packages/appkit/src/plugins/agents/schemas.ts new file mode 100644 index 00000000..cea6c6d6 --- /dev/null +++ b/packages/appkit/src/plugins/agents/schemas.ts @@ -0,0 +1,69 @@ +import { z } from "zod"; + +/** + * Static body cap for the `message` field on `POST /chat`. 64 000 characters + * is well above any legitimate chat turn (~16k tokens at 4 chars/token) and + * bounds the per-request cost of appending to `InMemoryThreadStore` without + * requiring per-deployment configuration. + */ +const MAX_MESSAGE_CHARS = 64_000; + +/** Cap applied to `/invocations` when `input` is a raw string. */ +const MAX_INVOCATIONS_INPUT_CHARS = 64_000; + +/** + * Cap on the number of items accepted in an `/invocations` `input` array + * (one element per seeded message). Protects against a single request + * seeding hundreds of messages into the thread store. + */ +const MAX_INVOCATIONS_INPUT_ITEMS = 100; + +/** Per-message `content` size cap (string form). */ +const MAX_INVOCATIONS_ITEM_CHARS = 64_000; + +/** Per-message `content` size cap (array form). 
*/ +const MAX_INVOCATIONS_ITEM_ARRAY_ITEMS = 100; + +export const chatRequestSchema = z.object({ + message: z + .string() + .min(1, "message must not be empty") + .max( + MAX_MESSAGE_CHARS, + `message exceeds the ${MAX_MESSAGE_CHARS}-character limit`, + ), + threadId: z.string().optional(), + agent: z.string().optional(), +}); + +const messageItemSchema = z.object({ + role: z.enum(["user", "assistant", "system"]).optional(), + content: z + .union([ + z.string().max(MAX_INVOCATIONS_ITEM_CHARS), + z.array(z.any()).max(MAX_INVOCATIONS_ITEM_ARRAY_ITEMS), + ]) + .optional(), + type: z.string().optional(), +}); + +export const invocationsRequestSchema = z.object({ + input: z.union([ + z.string().min(1).max(MAX_INVOCATIONS_INPUT_CHARS), + z + .array(messageItemSchema) + .min(1) + .max( + MAX_INVOCATIONS_INPUT_ITEMS, + `input array exceeds the ${MAX_INVOCATIONS_INPUT_ITEMS}-item limit`, + ), + ]), + stream: z.boolean().optional().default(true), + model: z.string().optional(), +}); + +export const approvalRequestSchema = z.object({ + streamId: z.string().min(1, "streamId is required"), + approvalId: z.string().min(1, "approvalId is required"), + decision: z.enum(["approve", "deny"]), +}); diff --git a/packages/appkit/src/plugins/agents/system-prompt.ts b/packages/appkit/src/plugins/agents/system-prompt.ts new file mode 100644 index 00000000..01f3fe9b --- /dev/null +++ b/packages/appkit/src/plugins/agents/system-prompt.ts @@ -0,0 +1,52 @@ +import type { PromptContext } from "./types"; + +/** + * Default base system prompt: product identity, active AppKit plugins, and + * tool-agnostic behavior hints. + * + * Individual tool definitions and JSON Schemas are still sent through the + * model's `tools` / function-calling channel — this string is not a second + * copy of that list. `ctx.toolNames` is available for custom + * `baseSystemPrompt` callbacks; the default text stays short and does not + * enumerate tools to avoid drift and token bloat. 
+ */ +export function buildBaseSystemPrompt(ctx: PromptContext): string { + const { pluginNames } = ctx; + const lines: string[] = [ + "You are an AI assistant running on Databricks AppKit.", + ]; + + if (pluginNames.length > 0) { + lines.push(""); + lines.push(`Active AppKit plugins: ${pluginNames.join(", ")}`); + } + + lines.push(""); + lines.push("Guidelines:"); + lines.push( + "- Be concise: for large or noisy tool output, summarize what matters and how to go deeper instead of pasting everything.", + ); + lines.push( + "- Use each tool as defined: pass required arguments and use the syntax, dialect, or path rules the target system expects (see each tool’s description and schema).", + ); + lines.push( + "- If a tool call fails, explain the error in plain language and suggest a fix or next step.", + ); + lines.push( + "- Respect tool metadata and app policy: read-only vs destructive tools, user/identity context, and any approval or safety flows the app provides.", + ); + + return lines.join("\n"); +} + +/** + * Compose the full system prompt from the base prompt and an optional + * per-agent user prompt. 
+ */ +export function composeSystemPrompt( + basePrompt: string, + agentPrompt?: string, +): string { + if (!agentPrompt) return basePrompt; + return `${basePrompt}\n\n${agentPrompt}`; +} diff --git a/packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts b/packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts new file mode 100644 index 00000000..747ada48 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts @@ -0,0 +1,367 @@ +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import type { + AgentAdapter, + AgentInput, + AgentRunContext, + AgentToolDefinition, + ToolProvider, +} from "shared"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { z } from "zod"; +import { CacheManager } from "../../../cache"; +// Import the class directly so we can construct it without a createApp +import { AgentsPlugin } from "../agents"; +import { buildToolkitEntries } from "../build-toolkit"; +import { defineTool, type ToolRegistry } from "../tools/define-tool"; +import type { AgentsPluginConfig, ToolkitEntry } from "../types"; +import { isToolkitEntry } from "../types"; + +interface FakeContext { + providers: Array<{ name: string; provider: ToolProvider }>; + getToolProviders(): Array<{ name: string; provider: ToolProvider }>; + getPluginNames(): string[]; + addRoute(): void; + executeTool: ( + req: unknown, + pluginName: string, + localName: string, + args: unknown, + ) => Promise; +} + +function fakeContext( + providers: Array<{ name: string; provider: ToolProvider }>, +): FakeContext { + return { + providers, + getToolProviders: () => providers, + getPluginNames: () => providers.map((p) => p.name), + addRoute: vi.fn(), + executeTool: vi.fn(async (_req, p, n, args) => ({ + plugin: p, + tool: n, + args, + })), + }; +} + +function stubAdapter(): AgentAdapter { + return { + async *run(_input: AgentInput, _ctx: AgentRunContext) { + yield { type: "message_delta", 
content: "" }; + }, + }; +} + +function makeToolProvider( + pluginName: string, + registry: ToolRegistry, +): ToolProvider & { + toolkit: (opts?: unknown) => Record; +} { + return { + getAgentTools(): AgentToolDefinition[] { + return Object.entries(registry).map(([name, entry]) => ({ + name, + description: entry.description, + parameters: { type: "object", properties: {} }, + })); + }, + async executeAgentTool(name, args) { + return { callFrom: pluginName, name, args }; + }, + toolkit: (opts) => buildToolkitEntries(pluginName, registry, opts as never), + }; +} + +let tmpDir: string; + +beforeEach(async () => { + tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "agents-plugin-")); + const storage = { + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + keys: vi.fn(), + healthCheck: vi.fn(async () => true), + close: vi.fn(async () => {}), + }; + // biome-ignore lint/suspicious/noExplicitAny: test-only CacheManager wiring + await CacheManager.getInstance({ storage: storage as any }); +}); + +afterEach(() => { + fs.rmSync(tmpDir, { recursive: true, force: true }); +}); + +function instantiate(config: AgentsPluginConfig, ctx?: FakeContext) { + const plugin = new AgentsPlugin({ ...config, name: "agent" }); + plugin.attachContext({ context: ctx as unknown as object }); + return plugin; +} + +function writeMarkdownAgent(dir: string, id: string, content: string) { + const folder = path.join(dir, id); + fs.mkdirSync(folder, { recursive: true }); + fs.writeFileSync(path.join(folder, "agent.md"), content, "utf-8"); +} + +describe("AgentsPlugin", () => { + test("registers code-defined agents and exposes them via exports", async () => { + const plugin = instantiate({ + dir: false, + agents: { + support: { + instructions: "You help customers.", + model: stubAdapter(), + }, + }, + }); + await plugin.setup(); + + const api = plugin.exports() as { + list: () => string[]; + getDefault: () => string | null; + }; + expect(api.list()).toEqual(["support"]); + 
expect(api.getDefault()).toBe("support"); + }); + + test("loads markdown agents from a directory", async () => { + writeMarkdownAgent( + tmpDir, + "assistant", + "---\ndefault: true\n---\nYou are helpful.", + ); + const plugin = instantiate({ + dir: tmpDir, + defaultModel: stubAdapter(), + }); + await plugin.setup(); + + const api = plugin.exports() as { + list: () => string[]; + getDefault: () => string | null; + }; + expect(api.list()).toEqual(["assistant"]); + expect(api.getDefault()).toBe("assistant"); + }); + + test("code definitions override markdown on key collision", async () => { + writeMarkdownAgent(tmpDir, "support", "---\n---\nFrom markdown."); + const plugin = instantiate({ + dir: tmpDir, + defaultModel: stubAdapter(), + agents: { + support: { + instructions: "From code", + model: stubAdapter(), + }, + }, + }); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { instructions: string } | null; + }; + expect(api.get("support")?.instructions).toBe("From code"); + }); + + test("auto-inherit default is safe (both file and code get nothing without an explicit opt-in)", async () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "q", + schema: z.object({ sql: z.string() }), + autoInheritable: true, // even with autoInheritable, no spread without opt-in + handler: () => "ok", + }), + }; + const provider = makeToolProvider("analytics", registry); + const ctx = fakeContext([{ name: "analytics", provider }]); + + writeMarkdownAgent(tmpDir, "assistant", "---\n---\nYou are helpful."); + + const plugin = instantiate( + { + dir: tmpDir, + defaultModel: stubAdapter(), + agents: { + manual: { + instructions: "Manual agent", + model: stubAdapter(), + }, + }, + }, + ctx, + ); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { toolIndex: Map } | null; + }; + const fileAgent = api.get("assistant"); + const codeAgent = api.get("manual"); + + 
expect(fileAgent?.toolIndex.size).toBe(0); + expect(codeAgent?.toolIndex.size).toBe(0); + }); + + test("opting in with autoInheritTools: { file: true } spreads only autoInheritable tools", async () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "read-only query", + schema: z.object({ sql: z.string() }), + autoInheritable: true, + handler: () => "ok", + }), + destructive: defineTool({ + description: "mutation", + schema: z.object({}), + // autoInheritable left unset → skipped even when opted in + handler: () => "ok", + }), + }; + const provider = makeToolProvider("analytics", registry); + const ctx = fakeContext([{ name: "analytics", provider }]); + + writeMarkdownAgent(tmpDir, "assistant", "---\n---\nYou are helpful."); + + const plugin = instantiate( + { + dir: tmpDir, + defaultModel: stubAdapter(), + autoInheritTools: { file: true }, + }, + ctx, + ); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { toolIndex: Map } | null; + }; + const fileAgent = api.get("assistant"); + const keys = Array.from(fileAgent?.toolIndex.keys() ?? 
[]); + expect(keys).toEqual(["analytics.query"]); + }); + + test("autoInheritTools: true enables both origins but still filters by autoInheritable", async () => { + const registry: ToolRegistry = { + safe: defineTool({ + description: "safe", + schema: z.object({}), + autoInheritable: true, + handler: () => "ok", + }), + unsafe: defineTool({ + description: "unsafe", + schema: z.object({}), + handler: () => "ok", + }), + }; + const provider = makeToolProvider("p", registry); + const ctx = fakeContext([{ name: "p", provider }]); + + const plugin = instantiate( + { + dir: false, + defaultModel: stubAdapter(), + autoInheritTools: true, + agents: { + code1: { + instructions: "code agent", + model: stubAdapter(), + }, + }, + }, + ctx, + ); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { toolIndex: Map } | null; + }; + const codeAgent = api.get("code1"); + const keys = Array.from(codeAgent?.toolIndex.keys() ?? []); + expect(keys).toEqual(["p.safe"]); + }); + + test("file-loaded agent respects explicit toolkits (skips auto-inherit)", async () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "q", + schema: z.object({ sql: z.string() }), + handler: () => "ok", + }), + }; + const registry2: ToolRegistry = { + list: defineTool({ + description: "l", + schema: z.object({}), + handler: () => [], + }), + }; + const ctx = fakeContext([ + { name: "analytics", provider: makeToolProvider("analytics", registry) }, + { name: "files", provider: makeToolProvider("files", registry2) }, + ]); + + writeMarkdownAgent( + tmpDir, + "analyst", + "---\ntoolkits: [analytics]\n---\nAnalyst.", + ); + + const plugin = instantiate( + { dir: tmpDir, defaultModel: stubAdapter() }, + ctx, + ); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { toolIndex: Map } | null; + }; + const agent = api.get("analyst"); + const toolNames = Array.from(agent?.toolIndex.keys() ?? 
[]); + expect(toolNames.some((n) => n.startsWith("analytics."))).toBe(true); + expect(toolNames.some((n) => n.startsWith("files."))).toBe(false); + }); + + test("registers sub-agents as agent- tools", async () => { + const plugin = instantiate({ + dir: false, + agents: { + supervisor: { + instructions: "Supervise", + model: stubAdapter(), + agents: { + worker: { + instructions: "Work", + model: stubAdapter(), + }, + }, + }, + }, + }); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { toolIndex: Map } | null; + }; + const sup = api.get("supervisor"); + expect(sup?.toolIndex.has("agent-worker")).toBe(true); + }); + + test("isToolkitEntry type guard recognizes toolkit entries", () => { + const entry: ToolkitEntry = { + __toolkitRef: true, + pluginName: "x", + localName: "y", + def: { name: "x.y", description: "", parameters: { type: "object" } }, + }; + expect(isToolkitEntry(entry)).toBe(true); + expect(isToolkitEntry({ foo: 1 })).toBe(false); + expect(isToolkitEntry(null)).toBe(false); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/approval-route.test.ts b/packages/appkit/src/plugins/agents/tests/approval-route.test.ts new file mode 100644 index 00000000..6e090bd2 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/approval-route.test.ts @@ -0,0 +1,292 @@ +import type express from "express"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { CacheManager } from "../../../cache"; +import { AgentsPlugin } from "../agents"; + +/** + * Focused tests for the `POST /approve` route and the associated + * ownership / error paths on `_handleApprove`. Covers: + * + * - Schema validation of the request body. + * - Ownership check: the user submitting the decision must be the same + * user who initiated the underlying chat stream. + * - 404 for unknown stream (already completed or never existed). + * - 404 for unknown approvalId even when the stream is active. 
+ * - Happy-path resolution of a pending gate with `approve` and `deny`. + * - Cancel of an active stream denies every pending gate on that stream. + */ + +function mockReq(body: unknown, userId?: string): express.Request { + const headers: Record = {}; + if (userId) { + headers["x-forwarded-user"] = userId; + headers["x-forwarded-access-token"] = "fake-token"; + } + return { + body, + headers, + header: (name: string) => headers[name.toLowerCase()], + } as unknown as express.Request; +} + +function mockRes() { + const json = vi.fn(); + const end = vi.fn(); + let statusCode = 200; + const status = vi.fn((code: number) => { + statusCode = code; + return { json, end }; + }); + return { + res: { status, json, end } as unknown as express.Response, + get statusCode() { + return statusCode; + }, + json, + }; +} + +beforeEach(() => { + CacheManager.getInstanceSync = vi.fn(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise) => + fn(), + ), + generateKey: vi.fn(() => "test-key"), + })) as any; + process.env.NODE_ENV = "development"; +}); + +describe("POST /approve route handler", () => { + test("rejects invalid body shape with 400", async () => { + const plugin = new AgentsPlugin({ dir: false }); + const { res, json } = mockRes(); + await (plugin as any)._handleApprove(mockReq({}, "alice"), res); + expect(res.status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ error: "Invalid request" }), + ); + }); + + test("returns 404 when the streamId is unknown", async () => { + const plugin = new AgentsPlugin({ dir: false }); + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleApprove: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleApprove( + mockReq( + { streamId: "ghost", approvalId: "a1", decision: "approve" }, + "alice", + ), + res, + ); + expect(res.status).toHaveBeenCalledWith(404); + 
expect(json).toHaveBeenCalledWith( + expect.objectContaining({ error: expect.stringMatching(/not found/i) }), + ); + }); + + test("returns 403 when submitter is different from stream owner", async () => { + const plugin = new AgentsPlugin({ dir: false }); + (plugin as any).activeStreams.set("stream-x", { + controller: new AbortController(), + userId: "alice", + }); + const gate = (plugin as any).approvalGate; + const waiter = gate.wait({ + approvalId: "a1", + streamId: "stream-x", + userId: "alice", + timeoutMs: 60_000, + }); + + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleApprove: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleApprove( + mockReq( + { streamId: "stream-x", approvalId: "a1", decision: "approve" }, + "bob", + ), + res, + ); + expect(res.status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ error: "Forbidden" }), + ); + + // Settle the waiter to clean up. + gate.submit({ approvalId: "a1", userId: "alice", decision: "deny" }); + await expect(waiter).resolves.toBe("deny"); + }); + + test("returns 404 when approvalId is unknown on an active stream", async () => { + const plugin = new AgentsPlugin({ dir: false }); + (plugin as any).activeStreams.set("stream-y", { + controller: new AbortController(), + userId: "alice", + }); + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleApprove: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleApprove( + mockReq( + { streamId: "stream-y", approvalId: "unknown-a", decision: "approve" }, + "alice", + ), + res, + ); + expect(res.status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ + error: expect.stringMatching(/not found|already settled/i), + }), + ); + }); + + test("happy path: approve resolves pending gate with 'approve'", async () => { + const plugin = new AgentsPlugin({ dir: false }); + (plugin as 
any).activeStreams.set("stream-z", { + controller: new AbortController(), + userId: "alice", + }); + const gate = (plugin as any).approvalGate; + const waiter = gate.wait({ + approvalId: "a42", + streamId: "stream-z", + userId: "alice", + timeoutMs: 60_000, + }); + + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleApprove: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleApprove( + mockReq( + { streamId: "stream-z", approvalId: "a42", decision: "approve" }, + "alice", + ), + res, + ); + expect(res.status).not.toHaveBeenCalled(); + expect(json).toHaveBeenCalledWith({ decision: "approve" }); + await expect(waiter).resolves.toBe("approve"); + }); + + test("happy path: deny resolves pending gate with 'deny'", async () => { + const plugin = new AgentsPlugin({ dir: false }); + (plugin as any).activeStreams.set("stream-z", { + controller: new AbortController(), + userId: "alice", + }); + const gate = (plugin as any).approvalGate; + const waiter = gate.wait({ + approvalId: "a43", + streamId: "stream-z", + userId: "alice", + timeoutMs: 60_000, + }); + + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleApprove: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleApprove( + mockReq( + { streamId: "stream-z", approvalId: "a43", decision: "deny" }, + "alice", + ), + res, + ); + expect(json).toHaveBeenCalledWith({ decision: "deny" }); + await expect(waiter).resolves.toBe("deny"); + }); +}); + +describe("POST /cancel ownership + gate cleanup", () => { + test("cancelling a stream denies every pending approval on that stream", async () => { + const plugin = new AgentsPlugin({ dir: false }); + const controller = new AbortController(); + (plugin as any).activeStreams.set("stream-c", { + controller, + userId: "alice", + }); + const gate = (plugin as any).approvalGate; + const a = gate.wait({ + approvalId: "ca1", + streamId: "stream-c", + userId: "alice", + timeoutMs: 60_000, 
+ }); + const b = gate.wait({ + approvalId: "ca2", + streamId: "stream-c", + userId: "alice", + timeoutMs: 60_000, + }); + + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleCancel: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleCancel(mockReq({ streamId: "stream-c" }, "alice"), res); + + expect(controller.signal.aborted).toBe(true); + expect(json).toHaveBeenCalledWith({ cancelled: true }); + await expect(a).resolves.toBe("deny"); + await expect(b).resolves.toBe("deny"); + }); + + test("cancel from a different user is refused with 403", async () => { + const plugin = new AgentsPlugin({ dir: false }); + const controller = new AbortController(); + (plugin as any).activeStreams.set("stream-d", { + controller, + userId: "alice", + }); + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleCancel: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleCancel(mockReq({ streamId: "stream-d" }, "bob"), res); + expect(res.status).toHaveBeenCalledWith(403); + expect(controller.signal.aborted).toBe(false); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ error: "Forbidden" }), + ); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/build-toolkit.test.ts b/packages/appkit/src/plugins/agents/tests/build-toolkit.test.ts new file mode 100644 index 00000000..08f71da9 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/build-toolkit.test.ts @@ -0,0 +1,101 @@ +import { describe, expect, test } from "vitest"; +import { z } from "zod"; +import { buildToolkitEntries } from "../build-toolkit"; +import { defineTool, type ToolRegistry } from "../tools/define-tool"; +import { isToolkitEntry } from "../types"; + +const registry: ToolRegistry = { + query: defineTool({ + description: "Run a query", + schema: z.object({ sql: z.string() }), + handler: () => "ok", + }), + history: defineTool({ + description: "Get query history", + schema: z.object({}), 
+ handler: () => [], + }), +}; + +describe("buildToolkitEntries", () => { + test("produces ToolkitEntry per registry item with default dotted prefix", () => { + const entries = buildToolkitEntries("analytics", registry); + expect(Object.keys(entries).sort()).toEqual([ + "analytics.history", + "analytics.query", + ]); + for (const entry of Object.values(entries)) { + expect(isToolkitEntry(entry)).toBe(true); + expect(entry.pluginName).toBe("analytics"); + } + }); + + test("respects prefix option (empty drops the namespace)", () => { + const entries = buildToolkitEntries("analytics", registry, { prefix: "" }); + expect(Object.keys(entries).sort()).toEqual(["history", "query"]); + }); + + test("respects custom prefix", () => { + const entries = buildToolkitEntries("analytics", registry, { + prefix: "db.", + }); + expect(Object.keys(entries).sort()).toEqual(["db.history", "db.query"]); + }); + + test("only filter keeps the listed local names", () => { + const entries = buildToolkitEntries("analytics", registry, { + only: ["query"], + }); + expect(Object.keys(entries)).toEqual(["analytics.query"]); + }); + + test("except filter drops the listed local names", () => { + const entries = buildToolkitEntries("analytics", registry, { + except: ["history"], + }); + expect(Object.keys(entries)).toEqual(["analytics.query"]); + }); + + test("rename remaps specific local names (overrides the prefix key)", () => { + const entries = buildToolkitEntries("analytics", registry, { + rename: { query: "sql" }, + }); + expect(Object.keys(entries).sort()).toEqual(["analytics.history", "sql"]); + }); + + test("exposes the original plugin+local name so dispatch can route", () => { + const entries = buildToolkitEntries("analytics", registry, { + prefix: "db.", + }); + const qEntry = entries["db.query"]; + expect(qEntry.pluginName).toBe("analytics"); + expect(qEntry.localName).toBe("query"); + expect(qEntry.def.name).toBe("db.query"); + }); + + test("propagates autoInheritable from the source 
registry", () => { + const mixed: ToolRegistry = { + readIt: defineTool({ + description: "safe read", + schema: z.object({}), + autoInheritable: true, + handler: () => "ok", + }), + writeIt: defineTool({ + description: "unsafe write", + schema: z.object({}), + autoInheritable: false, + handler: () => "ok", + }), + unmarked: defineTool({ + description: "default: not auto-inheritable", + schema: z.object({}), + handler: () => "ok", + }), + }; + const entries = buildToolkitEntries("p", mixed); + expect(entries["p.readIt"].autoInheritable).toBe(true); + expect(entries["p.writeIt"].autoInheritable).toBe(false); + expect(entries["p.unmarked"].autoInheritable).toBeUndefined(); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/consume-adapter-stream.test.ts b/packages/appkit/src/plugins/agents/tests/consume-adapter-stream.test.ts new file mode 100644 index 00000000..98863a62 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/consume-adapter-stream.test.ts @@ -0,0 +1,86 @@ +import type { AgentEvent } from "shared"; +import { describe, expect, test } from "vitest"; +import { consumeAdapterStream } from "../consume-adapter-stream"; + +async function* streamOf( + events: AgentEvent[], +): AsyncGenerator { + for (const event of events) { + yield event; + } +} + +describe("consumeAdapterStream", () => { + test("concatenates message_delta events into the final text", async () => { + const text = await consumeAdapterStream( + streamOf([ + { type: "message_delta", content: "Hello " }, + { type: "message_delta", content: "world" }, + ]), + ); + expect(text).toBe("Hello world"); + }); + + test("a `message` event replaces whatever deltas arrived so far", async () => { + const text = await consumeAdapterStream( + streamOf([ + { type: "message_delta", content: "partial" }, + { type: "message", content: "final answer" }, + ]), + ); + expect(text).toBe("final answer"); + }); + + test("invokes onEvent once per event, in order, with the raw event", async () => { + 
const seen: AgentEvent[] = []; + await consumeAdapterStream( + streamOf([ + { type: "message_delta", content: "a" }, + { type: "thinking", content: "…" }, + { type: "message_delta", content: "b" }, + ]), + { onEvent: (ev) => seen.push(ev) }, + ); + expect(seen.map((e) => e.type)).toEqual([ + "message_delta", + "thinking", + "message_delta", + ]); + }); + + test("stops iterating once the signal aborts", async () => { + const controller = new AbortController(); + const emitted: string[] = []; + await consumeAdapterStream( + (async function* () { + yield { type: "message_delta", content: "first" } as AgentEvent; + controller.abort(); + yield { type: "message_delta", content: "second" } as AgentEvent; + })(), + { + signal: controller.signal, + onEvent: (ev) => { + if (ev.type === "message_delta") emitted.push(ev.content); + }, + }, + ); + expect(emitted).toEqual(["first"]); + }); + + test("returns an empty string for a stream with no content events", async () => { + const text = await consumeAdapterStream( + streamOf([{ type: "thinking", content: "…" }]), + ); + expect(text).toBe(""); + }); + + test("works without a signal (standalone runAgent path)", async () => { + const text = await consumeAdapterStream( + streamOf([ + { type: "message_delta", content: "x" }, + { type: "message_delta", content: "y" }, + ]), + ); + expect(text).toBe("xy"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/create-agent.test.ts b/packages/appkit/src/plugins/agents/tests/create-agent.test.ts new file mode 100644 index 00000000..3822897f --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/create-agent.test.ts @@ -0,0 +1,75 @@ +import { describe, expect, test } from "vitest"; +import { z } from "zod"; +import { createAgent } from "../../../core/create-agent-def"; +import { tool } from "../tools/tool"; +import type { AgentDefinition } from "../types"; + +describe("createAgent", () => { + test("returns the definition unchanged for a simple agent", () => { + const 
def: AgentDefinition = { + name: "support", + instructions: "You help customers.", + model: "endpoint-x", + }; + const result = createAgent(def); + expect(result).toBe(def); + }); + + test("accepts tools as a keyed record", () => { + const get_weather = tool({ + name: "get_weather", + description: "Get the weather", + schema: z.object({ city: z.string() }), + execute: async ({ city }) => `Sunny in ${city}`, + }); + + const def = createAgent({ + instructions: "...", + tools: { get_weather }, + }); + + expect(def.tools?.get_weather).toBe(get_weather); + }); + + test("accepts sub-agents in a keyed record", () => { + const researcher = createAgent({ instructions: "Research." }); + const supervisor = createAgent({ + instructions: "Supervise.", + agents: { researcher }, + }); + expect(supervisor.agents?.researcher).toBe(researcher); + }); + + test("throws on a direct self-cycle", () => { + const a: AgentDefinition = { instructions: "a" }; + // biome-ignore lint/suspicious/noExplicitAny: intentional cycle setup for test + (a as any).agents = { self: a }; + expect(() => createAgent(a)).toThrow(/cycle/i); + }); + + test("throws on an indirect cycle", () => { + const a: AgentDefinition = { instructions: "a" }; + const b: AgentDefinition = { instructions: "b" }; + a.agents = { b }; + b.agents = { a }; + expect(() => createAgent(a)).toThrow(/cycle/i); + }); + + test("accepts a DAG of sub-agents without throwing", () => { + const leaf: AgentDefinition = { instructions: "leaf" }; + const branchA: AgentDefinition = { + instructions: "a", + agents: { leaf }, + }; + const branchB: AgentDefinition = { + instructions: "b", + agents: { leaf }, + }; + const root = createAgent({ + instructions: "root", + agents: { branchA, branchB }, + }); + expect(root.agents?.branchA.agents?.leaf).toBe(leaf); + expect(root.agents?.branchB.agents?.leaf).toBe(leaf); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/define-tool.test.ts 
b/packages/appkit/src/plugins/agents/tests/define-tool.test.ts new file mode 100644 index 00000000..ef61e8c4 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/define-tool.test.ts @@ -0,0 +1,133 @@ +import { describe, expect, test, vi } from "vitest"; +import { z } from "zod"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../tools/define-tool"; + +describe("defineTool()", () => { + test("returns an entry matching the input config", () => { + const entry = defineTool({ + description: "echo", + schema: z.object({ msg: z.string() }), + annotations: { readOnly: true }, + handler: ({ msg }) => msg, + }); + + expect(entry.description).toBe("echo"); + expect(entry.annotations).toEqual({ readOnly: true }); + expect(typeof entry.handler).toBe("function"); + }); +}); + +describe("executeFromRegistry", () => { + const registry: ToolRegistry = { + echo: defineTool({ + description: "echo", + schema: z.object({ msg: z.string() }), + handler: ({ msg }) => `got ${msg}`, + }), + }; + + test("validates args and calls handler on success", async () => { + const result = await executeFromRegistry(registry, "echo", { msg: "hi" }); + expect(result).toBe("got hi"); + }); + + test("returns formatted error string on validation failure", async () => { + const result = await executeFromRegistry(registry, "echo", {}); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for echo"); + expect(result).toContain("msg"); + }); + + test("throws for unknown tool names", async () => { + await expect(executeFromRegistry(registry, "missing", {})).rejects.toThrow( + /Unknown tool: missing/, + ); + }); + + test("forwards AbortSignal to the handler", async () => { + const handler = vi.fn(async (_args: { x: string }, signal?: AbortSignal) => + signal?.aborted ? 
"aborted" : "ok", + ); + const reg: ToolRegistry = { + t: defineTool({ + description: "t", + schema: z.object({ x: z.string() }), + handler, + }), + }; + + const controller = new AbortController(); + controller.abort(); + await executeFromRegistry(reg, "t", { x: "hi" }, controller.signal); + + expect(handler).toHaveBeenCalledTimes(1); + expect(handler.mock.calls[0][1]).toBe(controller.signal); + }); +}); + +describe("toolsFromRegistry", () => { + test("produces AgentToolDefinition[] with JSON Schema parameters", () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "Execute a SQL query", + schema: z.object({ + query: z.string().describe("SQL query"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + handler: () => "ok", + }), + }; + + const defs = toolsFromRegistry(registry); + expect(defs).toHaveLength(1); + expect(defs[0].name).toBe("query"); + expect(defs[0].description).toBe("Execute a SQL query"); + expect(defs[0].parameters).toMatchObject({ + type: "object", + properties: { + query: { type: "string", description: "SQL query" }, + }, + required: ["query"], + }); + expect(defs[0].annotations).toEqual({ + readOnly: true, + requiresUserContext: true, + }); + }); + + test("preserves dotted names like uploads.list from the registry keys", () => { + const registry: ToolRegistry = { + "uploads.list": defineTool({ + description: "list uploads", + schema: z.object({}), + handler: () => [], + }), + "documents.list": defineTool({ + description: "list documents", + schema: z.object({}), + handler: () => [], + }), + }; + + const names = toolsFromRegistry(registry).map((d) => d.name); + expect(names).toContain("uploads.list"); + expect(names).toContain("documents.list"); + }); + + test("omits annotations when none are provided", () => { + const registry: ToolRegistry = { + plain: defineTool({ + description: "plain", + schema: z.object({}), + handler: () => "ok", + }), + }; + const [def] = toolsFromRegistry(registry); + 
expect(def.annotations).toBeUndefined(); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/dos-limits.test.ts b/packages/appkit/src/plugins/agents/tests/dos-limits.test.ts new file mode 100644 index 00000000..935fa240 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/dos-limits.test.ts @@ -0,0 +1,299 @@ +import type express from "express"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { CacheManager } from "../../../cache"; +import { AgentsPlugin } from "../agents"; +import { chatRequestSchema, invocationsRequestSchema } from "../schemas"; + +/** + * Exercises the four DoS caps landed for MVP: + * + * - `chatRequestSchema.message.max(64_000)` — body cap on `POST /chat`. + * - Per-user `maxConcurrentStreamsPerUser` — 429 with Retry-After. + * - Per-run `maxToolCalls` — aborts stream and throws in `executeTool`. + * - Per-delegation `maxSubAgentDepth` — rejects in `runSubAgent`. + * + * Route-level tests exercise the schemas + `_handleChat` directly via the + * mocked req/res pattern already used by approval-route.test.ts. 
+ */ + +function mockReq(body: unknown, userId?: string): express.Request { + const headers: Record = {}; + if (userId) { + headers["x-forwarded-user"] = userId; + headers["x-forwarded-access-token"] = "fake-token"; + } + return { + body, + headers, + header: (name: string) => headers[name.toLowerCase()], + } as unknown as express.Request; +} + +function mockRes() { + const json = vi.fn(); + const setHeader = vi.fn(); + let statusCode = 200; + const status = vi.fn((code: number) => { + statusCode = code; + return { json }; + }); + return { + res: { status, json, setHeader } as unknown as express.Response, + get statusCode() { + return statusCode; + }, + json, + setHeader, + }; +} + +beforeEach(() => { + CacheManager.getInstanceSync = vi.fn(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise) => + fn(), + ), + generateKey: vi.fn(() => "test-key"), + // biome-ignore lint/suspicious/noExplicitAny: test mock + })) as any; + process.env.NODE_ENV = "development"; +}); + +describe("chatRequestSchema — body cap", () => { + test("accepts messages up to 64_000 characters", () => { + const result = chatRequestSchema.safeParse({ + message: "a".repeat(64_000), + }); + expect(result.success).toBe(true); + }); + + test("rejects messages over 64_000 characters", () => { + const result = chatRequestSchema.safeParse({ + message: "a".repeat(64_001), + }); + expect(result.success).toBe(false); + if (!result.success) { + expect(JSON.stringify(result.error.flatten())).toMatch(/64000/); + } + }); + + test("rejects empty message (existing contract)", () => { + expect(chatRequestSchema.safeParse({ message: "" }).success).toBe(false); + }); +}); + +describe("invocationsRequestSchema — input caps", () => { + test("accepts string input up to 64_000 characters", () => { + const result = invocationsRequestSchema.safeParse({ + input: "a".repeat(64_000), + }); + expect(result.success).toBe(true); + }); + + test("rejects string 
input over 64_000 characters", () => { + const result = invocationsRequestSchema.safeParse({ + input: "a".repeat(64_001), + }); + expect(result.success).toBe(false); + }); + + test("accepts array input up to 100 items", () => { + const items = Array.from({ length: 100 }, (_, i) => ({ + role: "user" as const, + content: `m${i}`, + })); + expect(invocationsRequestSchema.safeParse({ input: items }).success).toBe( + true, + ); + }); + + test("rejects array input over 100 items", () => { + const items = Array.from({ length: 101 }, (_, i) => ({ + role: "user" as const, + content: `m${i}`, + })); + const result = invocationsRequestSchema.safeParse({ input: items }); + expect(result.success).toBe(false); + }); + + test("rejects per-item content over 64_000 characters", () => { + const result = invocationsRequestSchema.safeParse({ + input: [{ role: "user", content: "a".repeat(64_001) }], + }); + expect(result.success).toBe(false); + }); +}); + +describe("POST /chat — per-user concurrent-stream limit", () => { + function seedPlugin( + overrides: ConstructorParameters[0] = { dir: false }, + ): AgentsPlugin { + const plugin = new AgentsPlugin(overrides); + // Seed the agents map directly so _handleChat can resolve "hello" + // without running setup() (which would require a live model). 
+ // biome-ignore lint/suspicious/noExplicitAny: seeding private state + (plugin as any).agents.set("hello", { + name: "hello", + instructions: "hi", + adapter: { async *run() {} }, + toolIndex: new Map(), + }); + // biome-ignore lint/suspicious/noExplicitAny: seeding private state + (plugin as any).defaultAgentName = "hello"; + return plugin; + } + + test("rejects with 429 + Retry-After when user is at-limit (default 5)", async () => { + const plugin = seedPlugin(); + for (let i = 0; i < 5; i++) { + // biome-ignore lint/suspicious/noExplicitAny: seeding + (plugin as any).activeStreams.set(`s${i}`, { + controller: new AbortController(), + userId: "alice", + }); + } + + const { res, setHeader, json } = mockRes(); + await ( + plugin as unknown as { + _handleChat: (r: express.Request, w: express.Response) => Promise; + } + )._handleChat(mockReq({ message: "hi" }, "alice"), res); + + expect(res.status).toHaveBeenCalledWith(429); + expect(setHeader).toHaveBeenCalledWith("Retry-After", "5"); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ + error: expect.stringMatching(/Too many concurrent streams/), + }), + ); + }); + + test("does not reject when another user is at-limit (per-user, not global)", async () => { + const plugin = seedPlugin(); + for (let i = 0; i < 5; i++) { + // biome-ignore lint/suspicious/noExplicitAny: seeding + (plugin as any).activeStreams.set(`s${i}`, { + controller: new AbortController(), + userId: "alice", + }); + } + + // Carol's request must not see a 429 even though alice is at-limit. + // Don't bother running the full stream — we assert only that 429 is + // not the response status. 
+ const { res } = mockRes(); + // biome-ignore lint/suspicious/noExplicitAny: stub _streamAgent to avoid needing a real adapter + (plugin as any)._streamAgent = vi.fn(async () => undefined); + + await ( + plugin as unknown as { + _handleChat: (r: express.Request, w: express.Response) => Promise; + } + )._handleChat(mockReq({ message: "hi" }, "carol"), res); + + expect(res.status).not.toHaveBeenCalledWith(429); + }); + + test("honours agents({ limits: { maxConcurrentStreamsPerUser } })", async () => { + const plugin = seedPlugin({ + dir: false, + limits: { maxConcurrentStreamsPerUser: 2 }, + }); + for (let i = 0; i < 2; i++) { + // biome-ignore lint/suspicious/noExplicitAny: seeding + (plugin as any).activeStreams.set(`s${i}`, { + controller: new AbortController(), + userId: "alice", + }); + } + + const { res } = mockRes(); + await ( + plugin as unknown as { + _handleChat: (r: express.Request, w: express.Response) => Promise; + } + )._handleChat(mockReq({ message: "hi" }, "alice"), res); + + expect(res.status).toHaveBeenCalledWith(429); + }); +}); + +describe("resolvedLimits — default values", () => { + test("exposes the documented MVP defaults when unconfigured", () => { + const plugin = new AgentsPlugin({ dir: false }); + // biome-ignore lint/suspicious/noExplicitAny: read private getter + const limits = (plugin as any).resolvedLimits; + expect(limits).toEqual({ + maxConcurrentStreamsPerUser: 5, + maxToolCalls: 50, + maxSubAgentDepth: 3, + }); + }); + + test("lets callers override any subset", () => { + const plugin = new AgentsPlugin({ + dir: false, + limits: { maxToolCalls: 100 }, + }); + // biome-ignore lint/suspicious/noExplicitAny: read private + const limits = (plugin as any).resolvedLimits; + expect(limits.maxToolCalls).toBe(100); + expect(limits.maxConcurrentStreamsPerUser).toBe(5); + expect(limits.maxSubAgentDepth).toBe(3); + }); +}); + +describe("runSubAgent — depth guard", () => { + test("rejects when depth exceeds the configured maximum", async () => { 
+ const plugin = new AgentsPlugin({ + dir: false, + limits: { maxSubAgentDepth: 2 }, + }); + // biome-ignore lint/suspicious/noExplicitAny: call private method directly + await expect( + (plugin as any).runSubAgent( + mockReq({}, "alice"), + { name: "child", toolIndex: new Map() }, + {}, + new AbortController().signal, + 3, // exceeds limit 2 + ), + ).rejects.toThrow(/Sub-agent depth exceeded \(limit 2\)/); + }); + + test("accepts at the boundary (depth === limit)", async () => { + // Use a stub adapter so we don't need a real model. + const plugin = new AgentsPlugin({ + dir: false, + limits: { maxSubAgentDepth: 3 }, + agents: {}, + }); + + const stubAdapter = { + // biome-ignore lint/suspicious/noExplicitAny: adapter shape not under test + async *run(): any { + yield { type: "message", content: "hello from depth-3" }; + }, + }; + const child = { + name: "child", + instructions: "test", + // biome-ignore lint/suspicious/noExplicitAny: stub shape + adapter: stubAdapter as any, + toolIndex: new Map(), + }; + + // biome-ignore lint/suspicious/noExplicitAny: call private + const result = await (plugin as any).runSubAgent( + mockReq({}, "alice"), + child, + { input: "test" }, + new AbortController().signal, + 3, // at the limit, not over + ); + expect(result).toBe("hello from depth-3"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/event-channel.test.ts b/packages/appkit/src/plugins/agents/tests/event-channel.test.ts new file mode 100644 index 00000000..d80d788d --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/event-channel.test.ts @@ -0,0 +1,78 @@ +import { describe, expect, test } from "vitest"; +import { EventChannel } from "../event-channel"; + +async function collect(ch: EventChannel): Promise { + const out: T[] = []; + for await (const v of ch) out.push(v); + return out; +} + +describe("EventChannel", () => { + test("yields pushed values in order", async () => { + const ch = new EventChannel(); + const p = collect(ch); + ch.push(1); 
+ ch.push(2); + ch.push(3); + ch.close(); + await expect(p).resolves.toEqual([1, 2, 3]); + }); + + test("pushes before iteration start are buffered", async () => { + const ch = new EventChannel(); + ch.push("a"); + ch.push("b"); + ch.close(); + await expect(collect(ch)).resolves.toEqual(["a", "b"]); + }); + + test("waiting iterator is unblocked by subsequent push", async () => { + const ch = new EventChannel(); + const promise = collect(ch); + await new Promise((r) => setTimeout(r, 5)); + ch.push(42); + ch.close(); + await expect(promise).resolves.toEqual([42]); + }); + + test("close with no pending values terminates iteration", async () => { + const ch = new EventChannel(); + const p = collect(ch); + ch.close(); + await expect(p).resolves.toEqual([]); + }); + + test("push after close is a no-op (channel is closed)", async () => { + const ch = new EventChannel(); + ch.close(); + ch.push(1); + await expect(collect(ch)).resolves.toEqual([]); + }); + + test("close with error rejects the waiting iterator", async () => { + const ch = new EventChannel(); + const promise = collect(ch); + await new Promise((r) => setTimeout(r, 5)); + ch.close(new Error("boom")); + await expect(promise).rejects.toThrow(/boom/); + }); + + test("interleaved pushes and reads stream through", async () => { + const ch = new EventChannel(); + const received: number[] = []; + const reader = (async () => { + for await (const v of ch) { + received.push(v); + if (received.length === 3) break; + } + })(); + ch.push(1); + await new Promise((r) => setTimeout(r, 0)); + ch.push(2); + await new Promise((r) => setTimeout(r, 0)); + ch.push(3); + await reader; + expect(received).toEqual([1, 2, 3]); + ch.close(); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/event-translator.test.ts b/packages/appkit/src/plugins/agents/tests/event-translator.test.ts new file mode 100644 index 00000000..050af001 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/event-translator.test.ts @@ -0,0 
+1,332 @@ +import type { ResponseStreamEvent } from "shared"; +import { describe, expect, test } from "vitest"; +import { AgentEventTranslator } from "../event-translator"; + +describe("AgentEventTranslator", () => { + test("translates message_delta to output_item.added + output_text.delta on first delta", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "message_delta", + content: "Hello", + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("response.output_item.added"); + expect(events[1].type).toBe("response.output_text.delta"); + + if (events[1].type === "response.output_text.delta") { + expect(events[1].delta).toBe("Hello"); + } + }); + + test("subsequent message_delta only produces output_text.delta", () => { + const translator = new AgentEventTranslator(); + translator.translate({ type: "message_delta", content: "Hello" }); + const events = translator.translate({ + type: "message_delta", + content: " world", + }); + + expect(events).toHaveLength(1); + expect(events[0].type).toBe("response.output_text.delta"); + }); + + test("sequence_number is monotonically increasing", () => { + const translator = new AgentEventTranslator(); + const e1 = translator.translate({ type: "message_delta", content: "a" }); + const e2 = translator.translate({ type: "message_delta", content: "b" }); + const e3 = translator.finalize(); + + const allSeqs = [...e1, ...e2, ...e3].map((e) => + "sequence_number" in e ? 
e.sequence_number : -1, + ); + + for (let i = 1; i < allSeqs.length; i++) { + expect(allSeqs[i]).toBeGreaterThan(allSeqs[i - 1]); + } + }); + + test("translates tool_call to paired output_item.added + output_item.done", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "tool_call", + callId: "call_1", + name: "analytics.query", + args: { sql: "SELECT 1" }, + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("response.output_item.added"); + expect(events[1].type).toBe("response.output_item.done"); + + if (events[0].type === "response.output_item.added") { + expect(events[0].item.type).toBe("function_call"); + if (events[0].item.type === "function_call") { + expect(events[0].item.name).toBe("analytics.query"); + expect(events[0].item.call_id).toBe("call_1"); + } + } + }); + + test("translates tool_result to paired output_item events", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "tool_result", + callId: "call_1", + result: { rows: 42 }, + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("response.output_item.added"); + + if (events[0].type === "response.output_item.added") { + expect(events[0].item.type).toBe("function_call_output"); + } + }); + + test("translates tool_result error", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "tool_result", + callId: "call_1", + result: null, + error: "Query failed", + }); + + if ( + events[0].type === "response.output_item.added" && + events[0].item.type === "function_call_output" + ) { + expect(events[0].item.output).toBe("Query failed"); + } + }); + + test("translates thinking to appkit.thinking extension event", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "thinking", + content: "Let me think about this...", + }); + + expect(events).toHaveLength(1); + 
expect(events[0].type).toBe("appkit.thinking"); + if (events[0].type === "appkit.thinking") { + expect(events[0].content).toBe("Let me think about this..."); + } + }); + + test("translates metadata to appkit.metadata extension event", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "metadata", + data: { threadId: "t-123" }, + }); + + expect(events).toHaveLength(1); + expect(events[0].type).toBe("appkit.metadata"); + if (events[0].type === "appkit.metadata") { + expect(events[0].data.threadId).toBe("t-123"); + } + }); + + test("status:complete triggers finalize with response.completed", () => { + const translator = new AgentEventTranslator(); + translator.translate({ type: "message_delta", content: "Hi" }); + const events = translator.translate({ type: "status", status: "complete" }); + + const types = events.map((e) => e.type); + expect(types).toContain("response.output_item.done"); + expect(types).toContain("response.completed"); + }); + + test("status:error emits error + response.failed", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "status", + status: "error", + error: "Something broke", + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("error"); + expect(events[1].type).toBe("response.failed"); + + if (events[0].type === "error") { + expect(events[0].error).toBe("Something broke"); + } + }); + + test("finalize produces response.completed", () => { + const translator = new AgentEventTranslator(); + const events = translator.finalize(); + + expect(events.some((e) => e.type === "response.completed")).toBe(true); + }); + + test("finalize with accumulated message text produces output_item.done", () => { + const translator = new AgentEventTranslator(); + translator.translate({ type: "message_delta", content: "Hello " }); + translator.translate({ type: "message_delta", content: "world" }); + const events = translator.finalize(); + 
+ const doneEvent = events.find( + (e) => e.type === "response.output_item.done", + ); + expect(doneEvent).toBeDefined(); + if ( + doneEvent?.type === "response.output_item.done" && + doneEvent.item.type === "message" + ) { + expect(doneEvent.item.content[0].text).toBe("Hello world"); + } + }); + + test("output_index increments for tool calls", () => { + const translator = new AgentEventTranslator(); + const e1 = translator.translate({ + type: "tool_call", + callId: "c1", + name: "tool1", + args: {}, + }); + const e2 = translator.translate({ + type: "tool_result", + callId: "c1", + result: "ok", + }); + + if ( + e1[0].type === "response.output_item.added" && + e2[0].type === "response.output_item.added" + ) { + expect(e2[0].output_index).toBeGreaterThan(e1[0].output_index); + } + }); +}); + +describe("AgentEventTranslator — monotonic output_index", () => { + /** + * Helper: every emitted `response.output_item.added`/`output_item.done` + * event's `output_index` must be >= every prior add/done `output_index`. + * This is the strict contract Responses-API clients (OpenAI's own SDK + * parser) enforce. + */ + function assertMonotonic(events: ResponseStreamEvent[]) { + let last = -1; + for (const ev of events) { + if ( + ev.type === "response.output_item.added" || + ev.type === "response.output_item.done" + ) { + expect(ev.output_index).toBeGreaterThanOrEqual(last); + last = ev.output_index; + } + } + } + + test("tool_call followed by message_delta emits monotonic indices (regression)", () => { + // Before the fix this produced: tool_call at index 1, then + // message_delta.added at 0 — monotonicity violated. 
+ const t = new AgentEventTranslator(); + const all: ResponseStreamEvent[] = []; + all.push( + ...t.translate({ + type: "tool_call", + callId: "c1", + name: "lookup", + args: { q: "x" }, + }), + ); + all.push( + ...t.translate({ type: "tool_result", callId: "c1", result: "ok" }), + ); + all.push(...t.translate({ type: "message_delta", content: "Result: " })); + all.push(...t.translate({ type: "message_delta", content: "ok." })); + all.push(...t.finalize()); + + assertMonotonic(all); + + const added = all.filter((e) => e.type === "response.output_item.added"); + // Three items: tool_call, tool_result, message. Indices 0/1/2. + expect(added.map((e) => e.output_index)).toEqual([0, 1, 2]); + }); + + test("message interrupted by tool_call is closed before the tool_call opens", () => { + const t = new AgentEventTranslator(); + const all: ResponseStreamEvent[] = []; + all.push(...t.translate({ type: "message_delta", content: "thinking..." })); + all.push( + ...t.translate({ + type: "tool_call", + callId: "c1", + name: "lookup", + args: {}, + }), + ); + all.push( + ...t.translate({ type: "tool_result", callId: "c1", result: "ok" }), + ); + all.push(...t.translate({ type: "message_delta", content: "final" })); + all.push(...t.finalize()); + + assertMonotonic(all); + + // Structure: msg0.added, msg0.delta, msg0.done (closed before tool), + // tool_call.added/done, tool_result.added/done, msg1.added, msg1.delta, + // msg1.done (from finalize), response.completed. 
+ const addedDone = all.filter( + (e) => + e.type === "response.output_item.added" || + e.type === "response.output_item.done", + ); + expect(addedDone.map((e) => `${e.type}@${e.output_index}`)).toEqual([ + "response.output_item.added@0", + "response.output_item.done@0", + "response.output_item.added@1", + "response.output_item.done@1", + "response.output_item.added@2", + "response.output_item.done@2", + "response.output_item.added@3", + "response.output_item.done@3", + ]); + }); + + test("full `message` event after deltas does not double-emit output_item.added", () => { + const t = new AgentEventTranslator(); + const all: ResponseStreamEvent[] = []; + all.push(...t.translate({ type: "message_delta", content: "partial" })); + all.push( + ...t.translate({ type: "message", content: "full final content" }), + ); + all.push(...t.finalize()); + + const added = all.filter((e) => e.type === "response.output_item.added"); + const done = all.filter((e) => e.type === "response.output_item.done"); + // Exactly one added (from the first delta) and one done (from the full + // message). finalize() must not emit a second done for the same item. 
+ expect(added).toHaveLength(1); + expect(done).toHaveLength(1); + if (done[0].type === "response.output_item.done") { + const item = done[0].item; + if (item.type === "message") { + expect(item.content[0].text).toBe("full final content"); + } + } + }); + + test("tool_result coerces undefined result to empty-string output", () => { + const t = new AgentEventTranslator(); + const events = t.translate({ + type: "tool_result", + callId: "c1", + result: undefined, + }); + const done = events.find((e) => e.type === "response.output_item.done"); + if (done?.type === "response.output_item.done") { + const item = done.item; + if (item.type === "function_call_output") { + expect(item.output).toBe(""); + } + } + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/function-tool.test.ts b/packages/appkit/src/plugins/agents/tests/function-tool.test.ts new file mode 100644 index 00000000..8e668d69 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/function-tool.test.ts @@ -0,0 +1,110 @@ +import { describe, expect, test } from "vitest"; +import { + functionToolToDefinition, + isFunctionTool, +} from "../tools/function-tool"; + +describe("isFunctionTool", () => { + test("returns true for valid FunctionTool", () => { + expect( + isFunctionTool({ + type: "function", + name: "greet", + execute: async () => "hello", + }), + ).toBe(true); + }); + + test("returns true for minimal FunctionTool", () => { + expect( + isFunctionTool({ + type: "function", + name: "x", + execute: () => "y", + }), + ).toBe(true); + }); + + test("returns false for null", () => { + expect(isFunctionTool(null)).toBe(false); + }); + + test("returns false for non-object", () => { + expect(isFunctionTool("function")).toBe(false); + }); + + test("returns false for wrong type", () => { + expect( + isFunctionTool({ + type: "genie-space", + name: "x", + execute: () => "y", + }), + ).toBe(false); + }); + + test("returns false when execute is missing", () => { + expect(isFunctionTool({ type: 
"function", name: "x" })).toBe(false); + }); + + test("returns false when name is missing", () => { + expect(isFunctionTool({ type: "function", execute: () => "y" })).toBe( + false, + ); + }); +}); + +describe("functionToolToDefinition", () => { + test("converts a FunctionTool with all fields", () => { + const def = functionToolToDefinition({ + type: "function", + name: "getWeather", + description: "Get current weather", + parameters: { + type: "object", + properties: { city: { type: "string" } }, + required: ["city"], + }, + execute: async () => "sunny", + }); + + expect(def.name).toBe("getWeather"); + expect(def.description).toBe("Get current weather"); + expect(def.parameters).toEqual({ + type: "object", + properties: { city: { type: "string" } }, + required: ["city"], + }); + }); + + test("uses name as fallback description", () => { + const def = functionToolToDefinition({ + type: "function", + name: "myTool", + execute: async () => "result", + }); + + expect(def.description).toBe("myTool"); + }); + + test("uses empty object schema when parameters are null", () => { + const def = functionToolToDefinition({ + type: "function", + name: "noParams", + parameters: null, + execute: async () => "ok", + }); + + expect(def.parameters).toEqual({ type: "object", properties: {} }); + }); + + test("uses empty object schema when parameters are omitted", () => { + const def = functionToolToDefinition({ + type: "function", + name: "noParams", + execute: async () => "ok", + }); + + expect(def.parameters).toEqual({ type: "object", properties: {} }); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/hosted-tools.test.ts b/packages/appkit/src/plugins/agents/tests/hosted-tools.test.ts new file mode 100644 index 00000000..d62b266b --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/hosted-tools.test.ts @@ -0,0 +1,131 @@ +import { describe, expect, test } from "vitest"; +import { isHostedTool, resolveHostedTools } from "../tools/hosted-tools"; + 
+describe("isHostedTool", () => { + test("returns true for genie-space", () => { + expect( + isHostedTool({ type: "genie-space", genie_space: { id: "abc" } }), + ).toBe(true); + }); + + test("returns true for vector_search_index", () => { + expect( + isHostedTool({ + type: "vector_search_index", + vector_search_index: { name: "cat.schema.idx" }, + }), + ).toBe(true); + }); + + test("returns true for custom_mcp_server", () => { + expect( + isHostedTool({ + type: "custom_mcp_server", + custom_mcp_server: { app_name: "my-app", app_url: "my-app-url" }, + }), + ).toBe(true); + }); + + test("returns true for external_mcp_server", () => { + expect( + isHostedTool({ + type: "external_mcp_server", + external_mcp_server: { connection_name: "conn1" }, + }), + ).toBe(true); + }); + + test("returns false for FunctionTool", () => { + expect( + isHostedTool({ type: "function", name: "x", execute: () => "y" }), + ).toBe(false); + }); + + test("returns false for null", () => { + expect(isHostedTool(null)).toBe(false); + }); + + test("returns false for unknown type", () => { + expect(isHostedTool({ type: "unknown" })).toBe(false); + }); + + test("returns false for non-object", () => { + expect(isHostedTool(42)).toBe(false); + }); +}); + +describe("resolveHostedTools", () => { + test("resolves genie-space to correct MCP endpoint", () => { + const configs = resolveHostedTools([ + { type: "genie-space", genie_space: { id: "space123" } }, + ]); + + expect(configs).toHaveLength(1); + expect(configs[0].name).toBe("genie-space123"); + expect(configs[0].url).toBe("/api/2.0/mcp/genie/space123"); + }); + + test("resolves vector_search_index with 3-part name", () => { + const configs = resolveHostedTools([ + { + type: "vector_search_index", + vector_search_index: { name: "catalog.schema.my_index" }, + }, + ]); + + expect(configs).toHaveLength(1); + expect(configs[0].name).toBe("vs-catalog-schema-my_index"); + expect(configs[0].url).toBe( + "/api/2.0/mcp/vector-search/catalog/schema/my_index", 
+ ); + }); + + test("throws for invalid vector_search_index name", () => { + expect(() => + resolveHostedTools([ + { + type: "vector_search_index", + vector_search_index: { name: "bad.name" }, + }, + ]), + ).toThrow("3-part dotted"); + }); + + test("resolves custom_mcp_server", () => { + const configs = resolveHostedTools([ + { + type: "custom_mcp_server", + custom_mcp_server: { app_name: "my-app", app_url: "my-app-endpoint" }, + }, + ]); + + expect(configs[0].name).toBe("my-app"); + expect(configs[0].url).toBe("my-app-endpoint"); + }); + + test("resolves external_mcp_server", () => { + const configs = resolveHostedTools([ + { + type: "external_mcp_server", + external_mcp_server: { connection_name: "conn1" }, + }, + ]); + + expect(configs[0].name).toBe("conn1"); + expect(configs[0].url).toBe("/api/2.0/mcp/external/conn1"); + }); + + test("resolves multiple tools preserving order", () => { + const configs = resolveHostedTools([ + { type: "genie-space", genie_space: { id: "g1" } }, + { + type: "external_mcp_server", + external_mcp_server: { connection_name: "e1" }, + }, + ]); + + expect(configs).toHaveLength(2); + expect(configs[0].name).toBe("genie-g1"); + expect(configs[1].name).toBe("e1"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/load-agents.test.ts b/packages/appkit/src/plugins/agents/tests/load-agents.test.ts new file mode 100644 index 00000000..1a5b9523 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/load-agents.test.ts @@ -0,0 +1,360 @@ +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; +import { z } from "zod"; +import { buildToolkitEntries } from "../build-toolkit"; +import { + agentIdFromMarkdownPath, + loadAgentFromFile, + loadAgentsFromDir, + parseFrontmatter, +} from "../load-agents"; +import { defineTool, type ToolRegistry } from "../tools/define-tool"; +import { tool } from "../tools/tool"; +import type { 
AgentDefinition } from "../types"; + +let workDir: string; + +beforeEach(() => { + workDir = fs.mkdtempSync(path.join(os.tmpdir(), "agents-test-")); +}); + +afterEach(() => { + fs.rmSync(workDir, { recursive: true, force: true }); +}); + +/** Flat file under workDir (for legacy loadAgentFromFile tests). */ +function writeRoot(name: string, content: string) { + fs.writeFileSync(path.join(workDir, name), content, "utf-8"); + return path.join(workDir, name); +} + +/** Folder layout: `/agent.md`. */ +function writeAgent(id: string, content: string) { + const dir = path.join(workDir, id); + fs.mkdirSync(dir, { recursive: true }); + const p = path.join(dir, "agent.md"); + fs.writeFileSync(p, content, "utf-8"); + return p; +} + +describe("agentIdFromMarkdownPath", () => { + test("uses parent folder name when file is agent.md", () => { + expect(agentIdFromMarkdownPath("/foo/bar/assistant/agent.md")).toBe( + "assistant", + ); + }); + + test("uses file stem for other .md names", () => { + expect(agentIdFromMarkdownPath("/tmp/assistant.md")).toBe("assistant"); + }); +}); + +describe("parseFrontmatter", () => { + test("parses a simple object", () => { + const { data, content } = parseFrontmatter( + "---\nendpoint: foo\ndefault: true\n---\nHello body", + ); + expect(data).toEqual({ endpoint: "foo", default: true }); + expect(content).toBe("Hello body"); + }); + + test("parses nested arrays", () => { + const { data } = parseFrontmatter( + "---\ntoolkits:\n - analytics\n - files: [uploads.list]\n---\nbody", + ); + expect(data?.toolkits).toEqual(["analytics", { files: ["uploads.list"] }]); + }); + + test("returns null data when no frontmatter", () => { + const { data, content } = parseFrontmatter("No frontmatter here"); + expect(data).toBeNull(); + expect(content).toBe("No frontmatter here"); + }); + + test("throws on invalid YAML", () => { + expect(() => parseFrontmatter("---\nkey: : : bad\n---\n")).toThrow(/YAML/); + }); +}); + +describe("loadAgentFromFile", () => { + 
test("returns AgentDefinition with body as instructions", async () => { + const p = writeRoot( + "assistant.md", + "---\nendpoint: e-1\n---\nYou are helpful.", + ); + const def = await loadAgentFromFile(p, {}); + expect(def.name).toBe("assistant"); + expect(def.instructions).toBe("You are helpful."); + expect(def.model).toBe("e-1"); + }); + + test("derives agent id from folder when path ends with agent.md", async () => { + const p = writeAgent("router", "---\nendpoint: e-1\n---\nRoute traffic."); + const def = await loadAgentFromFile(p, {}); + expect(def.name).toBe("router"); + expect(def.instructions).toBe("Route traffic."); + }); +}); + +describe("loadAgentsFromDir", () => { + test("returns empty map when dir doesn't exist", async () => { + const res = await loadAgentsFromDir("/nonexistent-for-tests", {}); + expect(res.defs).toEqual({}); + expect(res.defaultAgent).toBeNull(); + }); + + test("loads each subdirectory with agent.md keyed by folder name", async () => { + writeAgent("support", "---\nendpoint: e-1\n---\nSupport prompt."); + writeAgent("sales", "---\nendpoint: e-2\n---\nSales prompt."); + const res = await loadAgentsFromDir(workDir, {}); + expect(Object.keys(res.defs).sort()).toEqual(["sales", "support"]); + }); + + test("throws when legacy top-level .md exists", async () => { + writeRoot("assistant.md", "---\nendpoint: e\n---\nLegacy flat file."); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /unsupported top-level markdown file\(s\): assistant\.md.*assistant\/agent\.md/s, + ); + }); + + test("throws when a subdirectory lacks agent.md", async () => { + fs.mkdirSync(path.join(workDir, "broken"), { recursive: true }); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /must contain agent\.md/, + ); + }); + + test("ignores reserved skills directory without agent.md", async () => { + fs.mkdirSync(path.join(workDir, "skills"), { recursive: true }); + writeAgent("solo", "---\nendpoint: e\n---\nOnly real agent."); + const 
res = await loadAgentsFromDir(workDir, {}); + expect(Object.keys(res.defs)).toEqual(["solo"]); + }); + + test("picks up default: true from frontmatter (deterministic sorted ids)", async () => { + writeAgent("one", "---\nendpoint: a\n---\nOne."); + writeAgent("two", "---\nendpoint: b\ndefault: true\n---\nTwo."); + const res = await loadAgentsFromDir(workDir, {}); + expect(res.defaultAgent).toBe("two"); + }); + + test("throws when frontmatter references an unregistered plugin", async () => { + writeAgent( + "broken", + "---\nendpoint: e\ntoolkits: [missing]\n---\nBroken agent.", + ); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /references toolkit 'missing'/, + ); + }); + + test("throws when frontmatter references an unknown ambient tool", async () => { + writeAgent( + "broken", + "---\nendpoint: e\ntools: [unknown_tool]\n---\nBroken.", + ); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /references tool 'unknown_tool'/, + ); + }); + + test("resolves toolkits + ambient tools when provided", async () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "q", + schema: z.object({ sql: z.string() }), + handler: () => "ok", + }), + }; + const plugins = new Map< + string, + { toolkit: (opts?: unknown) => Record } + >([ + [ + "analytics", + { + toolkit: (opts) => + buildToolkitEntries("analytics", registry, opts as never), + }, + ], + ]); + + const weather = tool({ + name: "get_weather", + description: "Weather", + schema: z.object({ city: z.string() }), + execute: async () => "sunny", + }); + + writeAgent( + "analyst", + "---\nendpoint: e\ntoolkits:\n - analytics\ntools:\n - get_weather\n---\nBody.", + ); + const res = await loadAgentsFromDir(workDir, { + plugins, + availableTools: { get_weather: weather }, + }); + expect(res.defs.analyst.tools).toBeDefined(); + expect(Object.keys(res.defs.analyst.tools ?? 
{}).sort()).toEqual([ + "analytics.query", + "get_weather", + ]); + }); + + describe("agents: sibling sub-agent references", () => { + test("resolves sibling references into def.agents regardless of folder order", async () => { + writeAgent( + "dispatcher", + "---\nendpoint: e\nagents:\n - analyst\n - writer\n---\nRoute work.", + ); + writeAgent("analyst", "---\nendpoint: e\n---\nAnalyst."); + writeAgent("writer", "---\nendpoint: e\n---\nWriter."); + + const res = await loadAgentsFromDir(workDir, {}); + expect(Object.keys(res.defs.dispatcher.agents ?? {}).sort()).toEqual([ + "analyst", + "writer", + ]); + expect(res.defs.dispatcher.agents?.analyst).toBe(res.defs.analyst); + expect(res.defs.dispatcher.agents?.writer).toBe(res.defs.writer); + expect(res.defs.analyst.agents).toBeUndefined(); + expect(res.defs.writer.agents).toBeUndefined(); + }); + + test("mutual delegation is allowed (runtime depth cap handles cycles)", async () => { + writeAgent("a", "---\nendpoint: e\nagents:\n - b\n---\nA."); + writeAgent("b", "---\nendpoint: e\nagents:\n - a\n---\nB."); + + const res = await loadAgentsFromDir(workDir, {}); + expect(res.defs.a.agents?.b).toBe(res.defs.b); + expect(res.defs.b.agents?.a).toBe(res.defs.a); + }); + + test("throws with available list when a sibling is missing", async () => { + writeAgent("dispatcher", "---\nendpoint: e\nagents:\n - ghost\n---\nD."); + writeAgent("analyst", "---\nendpoint: e\n---\nAnalyst."); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /references sub-agent\(s\) 'ghost'.*Available: analyst, dispatcher/s, + ); + }); + + test("reports every missing sibling in one error, not just the first", async () => { + writeAgent( + "dispatcher", + "---\nendpoint: e\nagents:\n - ghost1\n - ghost2\n---\nD.", + ); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /ghost1, ghost2/, + ); + }); + + test("throws on self-reference", async () => { + writeAgent("solo", "---\nendpoint: e\nagents:\n - solo\n---\nSolo."); + 
await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /'solo'.*cannot reference itself/s, + ); + }); + + test("throws on non-array 'agents:' value", async () => { + writeAgent("bad", "---\nendpoint: e\nagents: analyst\n---\nBad."); + writeAgent("analyst", "---\nendpoint: e\n---\nAnalyst."); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /invalid 'agents:' frontmatter/, + ); + }); + + test("throws on non-string entries in 'agents:'", async () => { + writeAgent("bad", "---\nendpoint: e\nagents:\n - 42\n---\nBad."); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /invalid 'agents:' entry/, + ); + }); + + test("deduplicates repeated entries silently", async () => { + writeAgent( + "dispatcher", + "---\nendpoint: e\nagents:\n - analyst\n - analyst\n---\nD.", + ); + writeAgent("analyst", "---\nendpoint: e\n---\nAnalyst."); + const res = await loadAgentsFromDir(workDir, {}); + expect(Object.keys(res.defs.dispatcher.agents ?? {})).toEqual([ + "analyst", + ]); + }); + + test("empty array yields no sub-agents (no-op)", async () => { + writeAgent("dispatcher", "---\nendpoint: e\nagents: []\n---\nD."); + const res = await loadAgentsFromDir(workDir, {}); + expect(res.defs.dispatcher.agents).toBeUndefined(); + }); + + test("resolves 'agents:' references against codeAgents when provided", async () => { + writeAgent( + "dispatcher", + "---\nendpoint: e\nagents:\n - support\n---\nD.", + ); + const support: AgentDefinition = { + name: "support", + instructions: "Code-defined support.", + }; + const res = await loadAgentsFromDir(workDir, { + codeAgents: { support }, + }); + expect(res.defs.dispatcher.agents?.support).toBe(support); + }); + + test("codeAgents takes precedence over markdown sibling with the same name", async () => { + writeAgent( + "dispatcher", + "---\nendpoint: e\nagents:\n - support\n---\nD.", + ); + writeAgent("support", "---\nendpoint: e\n---\nMarkdown support."); + const codeSupport: AgentDefinition = { + name: 
"support", + instructions: "Code support.", + }; + const res = await loadAgentsFromDir(workDir, { + codeAgents: { support: codeSupport }, + }); + expect(res.defs.dispatcher.agents?.support).toBe(codeSupport); + expect(res.defs.dispatcher.agents?.support.instructions).toBe( + "Code support.", + ); + }); + + test("missing-sibling error lists both markdown and code agent names", async () => { + writeAgent("dispatcher", "---\nendpoint: e\nagents:\n - ghost\n---\nD."); + writeAgent("analyst", "---\nendpoint: e\n---\nAnalyst."); + const codeAgent: AgentDefinition = { + name: "writer", + instructions: "Writer.", + }; + await expect( + loadAgentsFromDir(workDir, { codeAgents: { writer: codeAgent } }), + ).rejects.toThrow(/Available: analyst, dispatcher, writer/); + }); + }); +}); + +describe("loadAgentFromFile — sub-agent refs rejected", () => { + test("throws when 'agents:' is non-empty in a single-file load", async () => { + const p = writeRoot( + "lonely.md", + "---\nendpoint: e\nagents:\n - ghost\n---\nLonely.", + ); + await expect(loadAgentFromFile(p, {})).rejects.toThrow( + /requires loadAgentsFromDir/, + ); + }); + + test("ignores empty 'agents:' array (treated as absent)", async () => { + const p = writeRoot( + "lonely.md", + "---\nendpoint: e\nagents: []\n---\nLonely.", + ); + const def = await loadAgentFromFile(p, {}); + expect(def.agents).toBeUndefined(); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/mcp-server-helper.test.ts b/packages/appkit/src/plugins/agents/tests/mcp-server-helper.test.ts new file mode 100644 index 00000000..96ad8e38 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/mcp-server-helper.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, test } from "vitest"; +import { + isHostedTool, + mcpServer, + resolveHostedTools, +} from "../tools/hosted-tools"; + +describe("mcpServer()", () => { + test("returns a CustomMcpServerTool with correct shape", () => { + const result = mcpServer("my-app", 
"https://example.com/mcp"); + + expect(result).toEqual({ + type: "custom_mcp_server", + custom_mcp_server: { + app_name: "my-app", + app_url: "https://example.com/mcp", + }, + }); + }); + + test("isHostedTool recognizes mcpServer() output", () => { + expect(isHostedTool(mcpServer("x", "y"))).toBe(true); + }); + + test("resolveHostedTools resolves mcpServer() output to an endpoint config", () => { + const configs = resolveHostedTools([ + mcpServer("vector-search", "https://host/mcp/vs"), + ]); + + expect(configs).toHaveLength(1); + expect(configs[0].name).toBe("vector-search"); + expect(configs[0].url).toBe("https://host/mcp/vs"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/normalize-result.test.ts b/packages/appkit/src/plugins/agents/tests/normalize-result.test.ts new file mode 100644 index 00000000..a0545d09 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/normalize-result.test.ts @@ -0,0 +1,63 @@ +import { describe, expect, test } from "vitest"; +import { + MAX_TOOL_RESULT_CHARS, + normalizeToolResult, +} from "../normalize-result"; + +describe("normalizeToolResult", () => { + test("maps undefined to empty string so void tools don't surface as errors", () => { + expect(normalizeToolResult(undefined)).toBe(""); + }); + + test("returns strings unchanged", () => { + expect(normalizeToolResult("hello")).toBe("hello"); + }); + + test("leaves non-string results intact (caller serialises)", () => { + const result = normalizeToolResult({ rows: 2, ok: true }); + expect(result).toEqual({ rows: 2, ok: true }); + }); + + test("returns an empty string input as an empty string (not undefined)", () => { + expect(normalizeToolResult("")).toBe(""); + }); + + test("preserves null without converting to empty string", () => { + expect(normalizeToolResult(null)).toBeNull(); + }); + + test("truncates long strings and appends a marker with the original length", () => { + const big = "x".repeat(MAX_TOOL_RESULT_CHARS + 1000); + const result = 
normalizeToolResult(big); + expect(typeof result).toBe("string"); + const s = result as string; + // Content portion is bounded to MAX_TOOL_RESULT_CHARS (plus the marker). + expect(s.slice(0, MAX_TOOL_RESULT_CHARS)).toBe( + "x".repeat(MAX_TOOL_RESULT_CHARS), + ); + expect(s).toMatch( + new RegExp( + `\\[Result truncated: ${big.length} chars exceeds ${MAX_TOOL_RESULT_CHARS} limit\\]`, + ), + ); + }); + + test("truncates long serialised objects the same way", () => { + const big = { blob: "x".repeat(MAX_TOOL_RESULT_CHARS + 10) }; + const result = normalizeToolResult(big); + expect(typeof result).toBe("string"); + expect(result as string).toMatch(/\[Result truncated:/); + }); + + test("honours a custom maxChars parameter", () => { + const result = normalizeToolResult("hello world", 5); + expect(result).toBe( + "hello\n\n[Result truncated: 11 chars exceeds 5 limit]", + ); + }); + + test("does not truncate at the boundary (exact length is fine)", () => { + const s = "x".repeat(MAX_TOOL_RESULT_CHARS); + expect(normalizeToolResult(s)).toBe(s); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/run-agent.test.ts b/packages/appkit/src/plugins/agents/tests/run-agent.test.ts new file mode 100644 index 00000000..1a974811 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/run-agent.test.ts @@ -0,0 +1,120 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, +} from "shared"; +import { describe, expect, test, vi } from "vitest"; +import { z } from "zod"; +import { createAgent } from "../../../core/create-agent-def"; +import { runAgent } from "../../../core/run-agent"; +import { tool } from "../tools/tool"; +import type { ToolkitEntry } from "../types"; + +function scriptedAdapter(events: AgentEvent[]): AgentAdapter { + return { + async *run(_input: AgentInput, _context: AgentRunContext) { + for (const event of events) { + yield event; + } + }, + }; +} + +describe("runAgent", () => { + test("drives the adapter and returns 
aggregated text", async () => { + const events: AgentEvent[] = [ + { type: "message_delta", content: "Hello " }, + { type: "message_delta", content: "world" }, + { type: "status", status: "complete" }, + ]; + const def = createAgent({ + instructions: "Say hello", + model: scriptedAdapter(events), + }); + + const result = await runAgent(def, { messages: "hi" }); + expect(result.text).toBe("Hello world"); + expect(result.events).toHaveLength(3); + }); + + test("prefers terminal 'message' event over deltas when present", async () => { + const events: AgentEvent[] = [ + { type: "message_delta", content: "partial" }, + { type: "message", content: "final answer" }, + ]; + const def = createAgent({ + instructions: "x", + model: scriptedAdapter(events), + }); + const result = await runAgent(def, { messages: "hi" }); + expect(result.text).toBe("final answer"); + }); + + test("invokes inline tools via executeTool callback", async () => { + const weatherFn = vi.fn(async () => "Sunny in NYC"); + const weather = tool({ + name: "get_weather", + description: "Weather", + schema: z.object({ city: z.string() }), + execute: weatherFn, + }); + + let capturedCtx: AgentRunContext | null = null; + const adapter: AgentAdapter = { + async *run(_input, context) { + capturedCtx = context; + yield { type: "message_delta", content: "" }; + }, + }; + + const def = createAgent({ + instructions: "x", + model: adapter, + tools: { get_weather: weather }, + }); + + await runAgent(def, { messages: "hi" }); + expect(capturedCtx).not.toBeNull(); + // biome-ignore lint/style/noNonNullAssertion: asserted above + const result = await capturedCtx!.executeTool("get_weather", { + city: "NYC", + }); + expect(result).toBe("Sunny in NYC"); + expect(weatherFn).toHaveBeenCalledWith({ city: "NYC" }); + }); + + test("throws a clear error when a ToolkitEntry is invoked", async () => { + const toolkitEntry: ToolkitEntry = { + __toolkitRef: true, + pluginName: "analytics", + localName: "query", + def: { + name: 
"analytics.query", + description: "SQL", + parameters: { type: "object", properties: {} }, + }, + }; + + let capturedCtx: AgentRunContext | null = null; + const adapter: AgentAdapter = { + async *run(_input, context) { + capturedCtx = context; + yield { type: "message_delta", content: "" }; + }, + }; + + const def = createAgent({ + instructions: "x", + model: adapter, + tools: { "analytics.query": toolkitEntry }, + }); + + await runAgent(def, { messages: "hi" }); + expect(capturedCtx).not.toBeNull(); + await expect( + // biome-ignore lint/style/noNonNullAssertion: asserted above + capturedCtx!.executeTool("analytics.query", {}), + ).rejects.toThrow(/only usable via createApp/); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/sql-policy.test.ts b/packages/appkit/src/plugins/agents/tests/sql-policy.test.ts new file mode 100644 index 00000000..fb81493e --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/sql-policy.test.ts @@ -0,0 +1,227 @@ +import { describe, expect, test } from "vitest"; +import { + assertReadOnlySql, + classifyReadOnly, + ReadOnlySqlViolation, +} from "../tools/sql-policy"; + +function ok(sql: string) { + const result = classifyReadOnly(sql); + if (!result.readOnly) { + throw new Error( + `Expected readOnly=true for ${JSON.stringify(sql)}, got reason: ${result.reason}`, + ); + } + return result; +} + +function rejected(sql: string) { + const result = classifyReadOnly(sql); + if (result.readOnly) { + throw new Error( + `Expected readOnly=false for ${JSON.stringify(sql)}, got readOnly=true`, + ); + } + return result; +} + +describe("classifyReadOnly: plain reads are admitted", () => { + test.each([ + "SELECT 1", + "select 1", + "SELECT * FROM users", + "SELECT * FROM main.sales.orders WHERE created_at > now() - interval '7 days'", + "SELECT COUNT(*) FROM main.sales.orders", + "WITH a AS (SELECT 1) SELECT * FROM a", + "WITH RECURSIVE t AS (SELECT 1) SELECT * FROM t", + "SHOW TABLES", + "SHOW TABLES IN main.sales", + "DESCRIBE 
EXTENDED main.sales.orders", + "DESC main.sales.orders", + "EXPLAIN SELECT 1", + "EXPLAIN ANALYZE SELECT 1", + ])("admits %s", (sql) => { + expect(ok(sql).statements).toBe(1); + }); +}); + +describe("classifyReadOnly: writes are rejected", () => { + test.each([ + ["DROP TABLE users", "DROP"], + ["UPDATE users SET email = 'x@y.com'", "UPDATE"], + ["DELETE FROM orders WHERE id = 1", "DELETE"], + ["INSERT INTO x VALUES (1)", "INSERT"], + ["CREATE TABLE x (id INT)", "CREATE"], + ["ALTER TABLE x ADD COLUMN y INT", "ALTER"], + ["TRUNCATE TABLE orders", "TRUNCATE"], + ["GRANT SELECT ON t TO u", "GRANT"], + ["REVOKE ALL ON t FROM u", "REVOKE"], + ["CALL sp_do_thing()", "CALL"], + ["COPY t FROM '/tmp/x'", "COPY"], + ["MERGE INTO t USING s", "MERGE"], + ["REFRESH TABLE t", "REFRESH"], + ["VACUUM t", "VACUUM"], + ])("rejects %s", (sql, keyword) => { + const result = rejected(sql); + expect(result.reason).toContain(keyword); + }); +}); + +describe("classifyReadOnly: stacked statements", () => { + test("rejects SELECT followed by DROP", () => { + const result = rejected("SELECT 1; DROP TABLE x"); + expect(result.reason).toMatch(/DROP/); + }); + + test("rejects DROP followed by SELECT (write comes first)", () => { + const result = rejected("DROP TABLE x; SELECT 1"); + expect(result.reason).toMatch(/DROP/); + }); + + test("admits multiple SELECTs", () => { + expect(ok("SELECT 1; SELECT 2").statements).toBe(2); + }); + + test("admits trailing semicolon on single statement", () => { + expect(ok("SELECT 1;").statements).toBe(1); + }); + + test("admits SELECT, SHOW, DESCRIBE batch", () => { + const result = ok("SELECT 1; SHOW TABLES; DESCRIBE x;"); + expect(result.statements).toBe(3); + }); +}); + +describe("classifyReadOnly: comment handling", () => { + test("admits SELECT with line comment hiding a write keyword", () => { + ok("SELECT 1 -- DROP TABLE x\n"); + }); + + test("admits SELECT preceded by line comment with write keyword", () => { + ok("-- DROP TABLE x\nSELECT 1"); + }); + 
+ test("admits SELECT with block comment containing stacked write", () => { + ok("SELECT 1 /* ; DROP TABLE x */"); + }); + + test("handles nested block comments (PostgreSQL style)", () => { + ok("SELECT 1 /* outer /* inner */ still inside */"); + }); + + test("rejects when write is outside the comment", () => { + const result = rejected("/* SELECT 1 */ DROP TABLE x"); + expect(result.reason).toMatch(/DROP/); + }); + + test("empty after stripping comments is rejected", () => { + rejected("-- only a comment"); + rejected("/* nothing */"); + }); +}); + +describe("classifyReadOnly: string literal handling", () => { + test("admits SELECT with write keyword inside single-quoted string", () => { + ok("SELECT 'DROP TABLE x' AS msg"); + }); + + test("admits SELECT with semicolon inside single-quoted string", () => { + ok("SELECT 'value; DROP TABLE x' AS msg"); + }); + + test("admits SELECT with doubled-quote escape", () => { + ok("SELECT 'it''s ok; DROP' AS msg"); + }); + + test("admits SELECT with backslash escape inside string", () => { + ok("SELECT E'line\\'s end; DROP' AS msg"); + }); + + test("admits SELECT with dollar-quoted string hiding a write", () => { + ok("SELECT $body$ arbitrary ; DROP TABLE x $body$ AS msg"); + }); + + test("admits SELECT with untagged dollar quote", () => { + ok("SELECT $$hello; DROP$$ AS msg"); + }); + + test("admits SELECT with ANSI double-quoted identifier named drop", () => { + ok('SELECT * FROM "drop"'); + }); + + test("admits SELECT with doubled-quote inside ANSI identifier", () => { + ok('SELECT * FROM "weird""name"'); + }); + + test("admits SELECT with backtick identifier (Databricks)", () => { + ok("SELECT * FROM `my table`"); + }); +}); + +describe("classifyReadOnly: degenerate input", () => { + test("rejects empty string", () => { + rejected(""); + }); + + test("rejects whitespace-only", () => { + rejected(" \n\t "); + }); + + test("rejects semicolons only", () => { + rejected(";;;"); + }); + + test("rejects non-SQL garbage", () => 
{
+    rejected("-- this is just a comment\n-- nothing else");
+    rejected("random garbage text");
+  });
+
+  test("admits a single empty statement between two selects (empty dropped)", () => {
+    // "SELECT 1;; SELECT 2" — the middle empty statement is dropped by
+    // splitter; the surviving two statements are both SELECT, admitted.
+    ok("SELECT 1;; SELECT 2");
+  });
+});
+
+describe("classifyReadOnly: evasion-resistance", () => {
+  test("cannot hide DROP after a comment-ended newline", () => {
+    const result = rejected("-- intent\nDROP TABLE x");
+    expect(result.reason).toMatch(/DROP/);
+  });
+
+  test("cannot hide DROP via concatenated strings (strings end cleanly)", () => {
+    rejected("'SELECT 1'; DROP TABLE x");
+  });
+
+  test("bare DROP after unclosed string is still considered part of the string (defensive)", () => {
+    // An unclosed single quote eats the rest of the input — classifier
+    // sees the whole thing as one stripped, empty-ish statement and rejects.
+    rejected("SELECT 'unterminated ; DROP TABLE x");
+  });
+
+  test("dollar-quoted literal with malicious tag is handled", () => {
+    ok("SELECT $tag$ DROP $tag$ AS harmless");
+  });
+
+  test("mismatched dollar-quote tag is treated as unterminated", () => {
+    rejected("SELECT $a$ DROP TABLE x $b$");
+  });
+});
+
+describe("assertReadOnlySql", () => {
+  test("returns void on read-only SQL", () => {
+    expect(() => assertReadOnlySql("SELECT 1")).not.toThrow();
+  });
+
+  test("throws ReadOnlySqlViolation with descriptive message on writes", () => {
+    expect(() => assertReadOnlySql("DROP TABLE x")).toThrow(
+      ReadOnlySqlViolation,
+    );
+    try {
+      assertReadOnlySql("DROP TABLE x");
+    } catch (e) {
+      expect((e as Error).message).toMatch(/SQL read-only policy violation/);
+      expect((e as Error).message).toMatch(/DROP/);
+    }
+  });
+});
diff --git a/packages/appkit/src/plugins/agents/tests/system-prompt.test.ts b/packages/appkit/src/plugins/agents/tests/system-prompt.test.ts
new file mode 100644
index 00000000..25724259
--- /dev/null
+++ 
b/packages/appkit/src/plugins/agents/tests/system-prompt.test.ts @@ -0,0 +1,59 @@ +import { describe, expect, test } from "vitest"; +import { buildBaseSystemPrompt, composeSystemPrompt } from "../system-prompt"; + +const emptyCtx = { + agentName: "a", + pluginNames: [] as string[], + toolNames: [] as string[], +}; + +describe("buildBaseSystemPrompt", () => { + test("includes plugin names", () => { + const prompt = buildBaseSystemPrompt({ + agentName: "assistant", + pluginNames: ["analytics", "files", "genie"], + toolNames: [], + }); + expect(prompt).toContain("Active AppKit plugins: analytics, files, genie"); + }); + + test("includes guidelines", () => { + const prompt = buildBaseSystemPrompt(emptyCtx); + expect(prompt).toContain("Guidelines:"); + expect(prompt).toContain("syntax, dialect, or path rules"); + expect(prompt).toContain("summarize what matters"); + }); + + test("works with no plugins", () => { + const prompt = buildBaseSystemPrompt(emptyCtx); + expect(prompt).toContain("AI assistant running on Databricks AppKit"); + expect(prompt).not.toContain("Active AppKit plugins:"); + }); + + test("does NOT include individual tool names", () => { + const prompt = buildBaseSystemPrompt({ + agentName: "a", + pluginNames: ["analytics"], + toolNames: ["analytics.query"], + }); + expect(prompt).not.toContain("analytics.query"); + expect(prompt).not.toContain("Available tools:"); + }); +}); + +describe("composeSystemPrompt", () => { + test("concatenates base + agent prompt with double newline", () => { + const composed = composeSystemPrompt("Base prompt.", "Agent prompt."); + expect(composed).toBe("Base prompt.\n\nAgent prompt."); + }); + + test("returns base prompt alone when no agent prompt", () => { + const composed = composeSystemPrompt("Base prompt."); + expect(composed).toBe("Base prompt."); + }); + + test("returns base prompt when agent prompt is empty string", () => { + const composed = composeSystemPrompt("Base prompt.", ""); + expect(composed).toBe("Base 
prompt."); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/thread-store.test.ts b/packages/appkit/src/plugins/agents/tests/thread-store.test.ts new file mode 100644 index 00000000..ed4f70ba --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/thread-store.test.ts @@ -0,0 +1,138 @@ +import { describe, expect, test } from "vitest"; +import { InMemoryThreadStore } from "../thread-store"; + +describe("InMemoryThreadStore", () => { + test("create() returns a new thread with the given userId", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + expect(thread.id).toBeDefined(); + expect(thread.userId).toBe("user-1"); + expect(thread.messages).toEqual([]); + expect(thread.createdAt).toBeInstanceOf(Date); + expect(thread.updatedAt).toBeInstanceOf(Date); + }); + + test("get() returns the thread for the correct user", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const retrieved = await store.get(thread.id, "user-1"); + expect(retrieved).toEqual(thread); + }); + + test("get() returns null for wrong user", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const retrieved = await store.get(thread.id, "user-2"); + expect(retrieved).toBeNull(); + }); + + test("get() returns null for non-existent thread", async () => { + const store = new InMemoryThreadStore(); + const retrieved = await store.get("non-existent", "user-1"); + expect(retrieved).toBeNull(); + }); + + test("list() returns threads sorted by updatedAt desc", async () => { + const store = new InMemoryThreadStore(); + const t1 = await store.create("user-1"); + const t2 = await store.create("user-1"); + + // Make t1 more recently updated + await store.addMessage(t1.id, "user-1", { + id: "msg-1", + role: "user", + content: "hello", + createdAt: new Date(), + }); + + const threads = await store.list("user-1"); + 
expect(threads).toHaveLength(2); + expect(threads[0].id).toBe(t1.id); + expect(threads[1].id).toBe(t2.id); + }); + + test("list() returns empty for unknown user", async () => { + const store = new InMemoryThreadStore(); + await store.create("user-1"); + + const threads = await store.list("user-2"); + expect(threads).toEqual([]); + }); + + test("addMessage() appends to thread and updates timestamp", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + const originalUpdatedAt = thread.updatedAt; + + // Small delay to ensure timestamp differs + await new Promise((r) => setTimeout(r, 5)); + + await store.addMessage(thread.id, "user-1", { + id: "msg-1", + role: "user", + content: "hello", + createdAt: new Date(), + }); + + const updated = await store.get(thread.id, "user-1"); + expect(updated?.messages).toHaveLength(1); + expect(updated?.messages[0].content).toBe("hello"); + expect(updated?.updatedAt.getTime()).toBeGreaterThanOrEqual( + originalUpdatedAt.getTime(), + ); + }); + + test("addMessage() throws for non-existent thread", async () => { + const store = new InMemoryThreadStore(); + + await expect( + store.addMessage("non-existent", "user-1", { + id: "msg-1", + role: "user", + content: "hello", + createdAt: new Date(), + }), + ).rejects.toThrow("Thread non-existent not found"); + }); + + test("delete() removes a thread and returns true", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const deleted = await store.delete(thread.id, "user-1"); + expect(deleted).toBe(true); + + const retrieved = await store.get(thread.id, "user-1"); + expect(retrieved).toBeNull(); + }); + + test("delete() returns false for non-existent thread", async () => { + const store = new InMemoryThreadStore(); + const deleted = await store.delete("non-existent", "user-1"); + expect(deleted).toBe(false); + }); + + test("delete() returns false for wrong user", async () => { + const 
store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const deleted = await store.delete(thread.id, "user-2"); + expect(deleted).toBe(false); + }); + + test("threads are isolated per user", async () => { + const store = new InMemoryThreadStore(); + await store.create("user-1"); + await store.create("user-1"); + await store.create("user-2"); + + const user1Threads = await store.list("user-1"); + const user2Threads = await store.list("user-2"); + + expect(user1Threads).toHaveLength(2); + expect(user2Threads).toHaveLength(1); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/tool-approval-gate.test.ts b/packages/appkit/src/plugins/agents/tests/tool-approval-gate.test.ts new file mode 100644 index 00000000..1e17ddf6 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/tool-approval-gate.test.ts @@ -0,0 +1,156 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { ToolApprovalGate } from "../tool-approval-gate"; + +describe("ToolApprovalGate", () => { + let gate: ToolApprovalGate; + + beforeEach(() => { + vi.useFakeTimers(); + gate = new ToolApprovalGate(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + test("resolves with 'approve' when a matching submit arrives", async () => { + const waiter = gate.wait({ + approvalId: "a1", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + expect(gate.size).toBe(1); + + const result = gate.submit({ + approvalId: "a1", + userId: "alice", + decision: "approve", + }); + + expect(result).toEqual({ ok: true }); + await expect(waiter).resolves.toBe("approve"); + expect(gate.size).toBe(0); + }); + + test("resolves with 'deny' on explicit deny", async () => { + const waiter = gate.wait({ + approvalId: "a2", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + gate.submit({ + approvalId: "a2", + userId: "alice", + decision: "deny", + }); + await expect(waiter).resolves.toBe("deny"); + }); + + test("auto-denies 
after timeoutMs with no submit", async () => { + const waiter = gate.wait({ + approvalId: "a3", + streamId: "s1", + userId: "alice", + timeoutMs: 1000, + }); + vi.advanceTimersByTime(1000); + await expect(waiter).resolves.toBe("deny"); + expect(gate.size).toBe(0); + }); + + test("refuses a submit from a different user (ownership check)", async () => { + const waiter = gate.wait({ + approvalId: "a4", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + const result = gate.submit({ + approvalId: "a4", + userId: "bob", + decision: "approve", + }); + expect(result).toEqual({ ok: false, reason: "forbidden" }); + expect(gate.size).toBe(1); + // Waiter is still pending; cleanup to let fake timers drain. + gate.submit({ + approvalId: "a4", + userId: "alice", + decision: "deny", + }); + await expect(waiter).resolves.toBe("deny"); + }); + + test("returns 'unknown' reason when approvalId is not registered", () => { + expect( + gate.submit({ approvalId: "nope", userId: "x", decision: "approve" }), + ).toEqual({ ok: false, reason: "unknown" }); + }); + + test("abortStream denies every pending gate for that stream", async () => { + const a = gate.wait({ + approvalId: "a5", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + const b = gate.wait({ + approvalId: "a6", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + const c = gate.wait({ + approvalId: "a7", + streamId: "s2", + userId: "alice", + timeoutMs: 60_000, + }); + gate.abortStream("s1"); + await expect(a).resolves.toBe("deny"); + await expect(b).resolves.toBe("deny"); + expect(gate.size).toBe(1); + // s2's waiter is still pending; settle it to clean up timers. 
+ gate.submit({ approvalId: "a7", userId: "alice", decision: "deny" }); + await expect(c).resolves.toBe("deny"); + }); + + test("abortAll denies every pending gate", async () => { + const a = gate.wait({ + approvalId: "a8", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + const b = gate.wait({ + approvalId: "a9", + streamId: "s2", + userId: "bob", + timeoutMs: 60_000, + }); + gate.abortAll(); + await expect(a).resolves.toBe("deny"); + await expect(b).resolves.toBe("deny"); + expect(gate.size).toBe(0); + }); + + test("a timed-out approval cannot be resolved by a late submit", async () => { + const waiter = gate.wait({ + approvalId: "a10", + streamId: "s1", + userId: "alice", + timeoutMs: 500, + }); + vi.advanceTimersByTime(500); + await expect(waiter).resolves.toBe("deny"); + + const late = gate.submit({ + approvalId: "a10", + userId: "alice", + decision: "approve", + }); + expect(late).toEqual({ ok: false, reason: "unknown" }); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/tool.test.ts b/packages/appkit/src/plugins/agents/tests/tool.test.ts new file mode 100644 index 00000000..3d47f3a9 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/tool.test.ts @@ -0,0 +1,110 @@ +import { describe, expect, test } from "vitest"; +import { z } from "zod"; +import { formatZodError, tool } from "../tools/tool"; + +describe("tool()", () => { + test("produces a FunctionTool with JSON Schema parameters from the Zod schema", () => { + const weather = tool({ + name: "get_weather", + description: "Get the current weather for a city", + schema: z.object({ + city: z.string().describe("City name"), + }), + execute: async ({ city }) => `Sunny in ${city}`, + }); + + expect(weather.type).toBe("function"); + expect(weather.name).toBe("get_weather"); + expect(weather.description).toBe("Get the current weather for a city"); + expect(weather.parameters).toMatchObject({ + type: "object", + properties: { + city: { type: "string", description: "City name" 
}, + }, + required: ["city"], + }); + }); + + test("execute receives typed args on valid input", async () => { + const echo = tool({ + name: "echo", + schema: z.object({ message: z.string() }), + execute: async ({ message }) => { + const _typed: string = message; + return `got ${_typed}`; + }, + }); + + const result = await echo.execute({ message: "hi" }); + expect(result).toBe("got hi"); + }); + + test("returns formatted error string (does not throw) when args are invalid", async () => { + const weather = tool({ + name: "get_weather", + schema: z.object({ city: z.string() }), + execute: async ({ city }) => `Sunny in ${city}`, + }); + + const result = await weather.execute({}); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for get_weather"); + expect(result).toContain("city"); + }); + + test("joins multiple validation errors with '; '", async () => { + const t = tool({ + name: "multi", + schema: z.object({ a: z.string(), b: z.number() }), + execute: async () => "ok", + }); + + const result = await t.execute({}); + expect(result).toContain("a:"); + expect(result).toContain("b:"); + expect(result).toContain(";"); + }); + + test("optional fields validate when absent", async () => { + const t = tool({ + name: "opt", + schema: z.object({ note: z.string().optional() }), + execute: async ({ note }) => note ?? 
"(no note)", + }); + + expect(await t.execute({})).toBe("(no note)"); + expect(await t.execute({ note: "hello" })).toBe("hello"); + }); + + test("description falls back to the tool name when omitted", () => { + const t = tool({ + name: "my_tool", + schema: z.object({}), + execute: async () => "ok", + }); + + expect(t.description).toBe("my_tool"); + expect(t.parameters).toBeDefined(); + }); +}); + +describe("formatZodError", () => { + test("formats a single issue with the tool name", () => { + const schema = z.object({ city: z.string() }); + const result = schema.safeParse({}); + if (result.success) throw new Error("expected failure"); + + const msg = formatZodError(result.error, "get_weather"); + expect(msg).toMatch(/^Invalid arguments for get_weather: /); + expect(msg).toContain("city:"); + }); + + test("joins multiple issues with '; '", () => { + const schema = z.object({ a: z.string(), b: z.number() }); + const result = schema.safeParse({}); + if (result.success) throw new Error("expected failure"); + + const msg = formatZodError(result.error, "t"); + expect(msg.split(";").length).toBeGreaterThanOrEqual(2); + }); +}); diff --git a/packages/appkit/src/plugins/agents/thread-store.ts b/packages/appkit/src/plugins/agents/thread-store.ts new file mode 100644 index 00000000..7c4622cd --- /dev/null +++ b/packages/appkit/src/plugins/agents/thread-store.ts @@ -0,0 +1,66 @@ +import { randomUUID } from "node:crypto"; +import type { Message, Thread, ThreadStore } from "shared"; + +/** + * In-memory thread store backed by a nested Map. + * + * Outer key: userId, inner key: threadId. Thread history is retained for the + * lifetime of the process with no eviction, caps, or TTL — a chatty user will + * grow the in-memory footprint monotonically, and the server loses every + * thread on restart. **This implementation is intended for local development + * and single-process demos only.** + * + * For any real deployment, pass a persistent `ThreadStore` to `agents({ ... 
})`
 * (e.g. a Lakebase- or Postgres-backed implementation). A bounded
 * `InMemoryThreadStore` with eviction policies is tracked as a follow-up.
 */
export class InMemoryThreadStore implements ThreadStore {
  // userId -> (threadId -> Thread). Type arguments restored: the extraction
  // had stripped every `<...>` list in this class (`new Map>()`, bare
  // `Promise` returns), which does not compile under strict TS.
  private store = new Map<string, Map<string, Thread>>();

  /** Create a fresh, empty thread for `userId` with a random id. */
  async create(userId: string): Promise<Thread> {
    const now = new Date();
    const thread: Thread = {
      id: randomUUID(),
      userId,
      messages: [],
      createdAt: now,
      updatedAt: now,
    };
    this.userMap(userId).set(thread.id, thread);
    return thread;
  }

  /** Fetch one thread, or null when the id is unknown for this user. */
  async get(threadId: string, userId: string): Promise<Thread | null> {
    return this.userMap(userId).get(threadId) ?? null;
  }

  /** All of a user's threads, most recently updated first. */
  async list(userId: string): Promise<Thread[]> {
    return Array.from(this.userMap(userId).values()).sort(
      (a, b) => b.updatedAt.getTime() - a.updatedAt.getTime(),
    );
  }

  /**
   * Append a message to an existing thread and bump `updatedAt`.
   * Throws when the thread does not exist for this user.
   */
  async addMessage(
    threadId: string,
    userId: string,
    message: Message,
  ): Promise<void> {
    const thread = this.userMap(userId).get(threadId);
    if (!thread) throw new Error(`Thread ${threadId} not found`);
    thread.messages.push(message);
    thread.updatedAt = new Date();
  }

  /** Delete a thread; returns false when it did not exist. */
  async delete(threadId: string, userId: string): Promise<boolean> {
    return this.userMap(userId).delete(threadId);
  }

  /** Lazily create and return the per-user inner map. */
  private userMap(userId: string): Map<string, Thread> {
    let map = this.store.get(userId);
    if (!map) {
      map = new Map();
      this.store.set(userId, map);
    }
    return map;
  }
}
diff --git a/packages/appkit/src/plugins/agents/tool-approval-gate.ts b/packages/appkit/src/plugins/agents/tool-approval-gate.ts
new file mode 100644
index 00000000..669f30a9
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tool-approval-gate.ts
@@ -0,0 +1,122 @@
/**
 * Server-side state for the human-in-the-loop approval gate on
 * `destructive: true` agent tool calls.
 *
 * Lifecycle:
 *
 * 1. `wait(...)` is called from inside `executeTool` when a destructive tool
 *    is about to execute. A `Pending` record is registered and a timer is
 *    scheduled for auto-deny.
The returned promise is what blocks the + * adapter until the decision arrives. + * 2. The client receives an `appkit.approval_pending` SSE event carrying the + * `approvalId` + `streamId` and posts a decision to `POST /chat/approve`. + * The route calls {@link ToolApprovalGate.submit} which resolves the + * pending promise and clears the timer. + * 3. If no submit arrives within `timeoutMs`, the timer fires and the + * promise resolves with `"deny"`. + * + * Security invariants: + * + * - `submit` verifies that the decider's user id matches the user who + * initiated the stream (set by `wait`). Mismatches are rejected without + * resolving the pending promise — this prevents a second user from + * approving (or denying) another user's destructive action. + * - `abort(streamId)` cancels every pending gate for a stream and denies + * each one. Used when the enclosing stream is cancelled or the plugin is + * shutting down. + */ +type ApprovalDecision = "approve" | "deny"; + +interface Pending { + resolve: (decision: ApprovalDecision) => void; + userId: string; + streamId: string; + timeout: ReturnType; +} + +type ApprovalSubmitResult = + | { ok: true } + | { ok: false; reason: "unknown" | "forbidden" }; + +export class ToolApprovalGate { + private pending = new Map(); + + /** + * Register a pending approval and return a promise that resolves with the + * user's decision or with `"deny"` when the timeout elapses. The returned + * promise never rejects. + */ + wait(args: { + approvalId: string; + streamId: string; + userId: string; + timeoutMs: number; + }): Promise { + const { approvalId, streamId, userId, timeoutMs } = args; + return new Promise((resolve) => { + const timeout = setTimeout(() => { + if (this.pending.delete(approvalId)) { + resolve("deny"); + } + }, timeoutMs); + this.pending.set(approvalId, { + resolve, + userId, + streamId, + timeout, + }); + }); + } + + /** + * Settle an approval with a user decision. 
Returns: + * - `{ ok: true }` if the pending record existed, the userId matched, and + * the promise was resolved. + * - `{ ok: false, reason: "unknown" }` if no pending record matches the id. + * - `{ ok: false, reason: "forbidden" }` if the userId does not match the + * user who initiated the stream. + */ + submit(args: { + approvalId: string; + userId: string; + decision: ApprovalDecision; + }): ApprovalSubmitResult { + const { approvalId, userId, decision } = args; + const p = this.pending.get(approvalId); + if (!p) return { ok: false, reason: "unknown" }; + if (p.userId !== userId) return { ok: false, reason: "forbidden" }; + clearTimeout(p.timeout); + this.pending.delete(approvalId); + p.resolve(decision); + return { ok: true }; + } + + /** + * Cancel all pending gates for a specific stream (e.g., when the user + * cancels the stream). Each gate resolves with `"deny"` so the adapter + * unwinds cleanly. + */ + abortStream(streamId: string): void { + for (const [id, p] of this.pending) { + if (p.streamId === streamId) { + clearTimeout(p.timeout); + this.pending.delete(id); + p.resolve("deny"); + } + } + } + + /** Cancel every pending gate. Used at plugin shutdown. */ + abortAll(): void { + for (const [id, p] of this.pending) { + clearTimeout(p.timeout); + this.pending.delete(id); + p.resolve("deny"); + } + } + + /** Number of pending approvals (test/diagnostic helper). 
*/ + get size(): number { + return this.pending.size; + } +} diff --git a/packages/appkit/src/plugins/agents/tool-dispatch.ts b/packages/appkit/src/plugins/agents/tool-dispatch.ts new file mode 100644 index 00000000..a3e220bb --- /dev/null +++ b/packages/appkit/src/plugins/agents/tool-dispatch.ts @@ -0,0 +1,97 @@ +import type express from "express"; +import type { AppKitMcpClient } from "../../connectors/mcp"; +import type { PluginContext } from "../../core/plugin-context"; +import type { ResolvedToolEntry } from "./types"; + +interface ToolDispatchContext { + /** + * The originating HTTP request. Used by `toolkit` entries to scope execution + * to the caller's user context (`asUser(req)`) and by `mcp` entries to pick + * up the OBO bearer token from `x-forwarded-access-token`. + */ + req: express.Request; + /** Cancellation signal, forwarded to the tool implementation. */ + signal: AbortSignal; + /** + * PluginContext mediator — required to dispatch `toolkit` entries. Absent in + * unit tests that construct `AgentsPlugin` directly; callers may pass + * `null` / `undefined`, in which case toolkit calls throw a clear error. + */ + pluginContext?: PluginContext | null; + /** Live MCP client. Required for `mcp` entries. */ + mcpClient?: AppKitMcpClient | null; + /** + * Delegates a sub-agent invocation. The closure owns the recursion depth so + * the dispatcher itself remains depth-agnostic — the top-level caller + * passes `depth = 1`, and a sub-agent's inner dispatcher passes `depth + 1`. + */ + runSubAgent: (agentName: string, args: unknown) => Promise; +} + +/** + * Fan-out a resolved tool entry to the correct executor. One place to add a + * fifth `source` variant; `never`-typed default forces every caller to + * update in lockstep. + * + * This only handles dispatch — result normalisation (`normalizeToolResult`), + * budget counting, and approval gating remain at the call site, where each + * stream has different policies. 
+ */ +export async function dispatchToolCall( + entry: ResolvedToolEntry, + args: unknown, + ctx: ToolDispatchContext, +): Promise { + switch (entry.source) { + case "toolkit": { + if (!ctx.pluginContext) { + throw new Error( + "Plugin tool execution requires PluginContext; " + + "this should never happen through createApp.", + ); + } + return ctx.pluginContext.executeTool( + ctx.req, + entry.pluginName, + entry.localName, + args, + ctx.signal, + ); + } + case "function": + return entry.functionTool.execute(args as Record); + case "mcp": { + if (!ctx.mcpClient) throw new Error("MCP client not connected"); + return ctx.mcpClient.callTool( + entry.mcpToolName, + args, + extractOboMcpAuth(ctx.req), + ); + } + case "subagent": + return ctx.runSubAgent(entry.agentName, args); + default: { + // Exhaustiveness guard: adding a new `source` to ResolvedToolEntry + // without teaching this switch breaks the build here. + const _exhaustive: never = entry; + throw new Error( + `Unsupported tool source: ${(_exhaustive as ResolvedToolEntry).source}`, + ); + } + } +} + +/** + * Extracts the caller's OBO bearer token from the standard Databricks Apps + * forwarded-auth header. MCP destinations that `forwardWorkspaceAuth` admits + * as same-origin will receive this header; non-workspace destinations drop + * it inside {@link AppKitMcpClient.callTool}. + */ +function extractOboMcpAuth( + req: express.Request, +): Record | undefined { + const oboToken = req.headers["x-forwarded-access-token"]; + return typeof oboToken === "string" + ? 
{ Authorization: `Bearer ${oboToken}` } : undefined;
}
diff --git a/packages/appkit/src/plugins/agents/tools/define-tool.ts b/packages/appkit/src/plugins/agents/tools/define-tool.ts
new file mode 100644
index 00000000..dc269ba6
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tools/define-tool.ts
@@ -0,0 +1,94 @@
import type { AgentToolDefinition, ToolAnnotations } from "shared";
import type { z } from "zod";
import { toToolJSONSchema } from "./json-schema";
import { formatZodError } from "./tool";

/**
 * Single-tool entry for a plugin's internal tool registry.
 *
 * Plugins collect these into a `Record` keyed by the tool's
 * public name and dispatch via `executeFromRegistry`.
 */
// Generic parameter restored — the extraction stripped `<S extends
// z.ZodType = z.ZodType>`, leaving a dangling `schema: S` reference.
export interface ToolEntry<S extends z.ZodType = z.ZodType> {
  description: string;
  schema: S;
  annotations?: ToolAnnotations;
  /**
   * Whether this tool is eligible for auto-inheritance into markdown or
   * code-defined agents that enable `autoInheritTools`. Defaults to `false`
   * (safe-by-default) — plugin authors must explicitly opt a tool in if they
   * consider it safe enough to appear in every agent's tool record without an
   * explicit `tools:` declaration. Destructive or privilege-sensitive tools
   * should leave this unset so that they only reach agents that wire them
   * explicitly (via `tools:`, `toolkits:`, or `fromPlugin({ only: [...] })`).
   */
  autoInheritable?: boolean;
  handler: (
    args: z.infer<S>,
    signal?: AbortSignal,
  ) => unknown | Promise<unknown>;
}

export type ToolRegistry = Record<string, ToolEntry>;

/**
 * Defines a single tool entry for a plugin's internal registry.
 *
 * The generic `S` flows from `schema` through to the `handler` callback so
 * `args` is fully typed from the Zod schema. Names are assigned by the
 * registry key, so they are not repeated inside the entry.
 */
export function defineTool<S extends z.ZodType>(
  config: ToolEntry<S>,
): ToolEntry<S> {
  return config;
}

/**
 * Validates tool-call arguments against the entry's schema and invokes its
 * handler.
On validation failure, returns an LLM-friendly error string + * (matching the behavior of `tool()`) rather than throwing, so the model + * can self-correct on its next turn. + */ +export async function executeFromRegistry( + registry: ToolRegistry, + name: string, + args: unknown, + signal?: AbortSignal, +): Promise { + const entry = registry[name]; + if (!entry) { + throw new Error(`Unknown tool: ${name}`); + } + const parsed = entry.schema.safeParse(args); + if (!parsed.success) { + return formatZodError(parsed.error, name); + } + return entry.handler(parsed.data, signal); +} + +/** + * Produces the `AgentToolDefinition[]` a ToolProvider exposes to the LLM, + * deriving `parameters` JSON Schema from each entry's Zod schema. + * + * Tool names come from registry keys (supports dotted names like + * `uploads.list` for dynamic plugins). + */ +export function toolsFromRegistry( + registry: ToolRegistry, +): AgentToolDefinition[] { + return Object.entries(registry).map(([name, entry]) => { + const parameters = toToolJSONSchema( + entry.schema, + ) as unknown as AgentToolDefinition["parameters"]; + const def: AgentToolDefinition = { + name, + description: entry.description, + parameters, + }; + if (entry.annotations) { + def.annotations = entry.annotations; + } + return def; + }); +} diff --git a/packages/appkit/src/plugins/agents/tools/function-tool.ts b/packages/appkit/src/plugins/agents/tools/function-tool.ts new file mode 100644 index 00000000..19820f8f --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/function-tool.ts @@ -0,0 +1,44 @@ +import type { AgentToolDefinition, ToolAnnotations } from "shared"; + +export interface FunctionTool { + type: "function"; + name: string; + description?: string | null; + parameters?: Record | null; + strict?: boolean | null; + /** + * Behavioural hints that drive the agents plugin's approval gate and the + * client's approval-card styling. 
Prefer setting `effect` (one of
   * `"read" | "write" | "update" | "destructive"`) — any mutating value
   * forces HITL approval before `execute()` runs. Legacy `destructive: true`
   * is still honoured. Must be preserved through {@link
   * functionToolToDefinition} so the plugin sees them when building agent
   * tool indexes.
   */
  annotations?: ToolAnnotations;
  // Type arguments restored (extraction-stripped `<...>` lists).
  execute: (args: Record<string, unknown>) => Promise<string> | string;
}

/** Structural runtime guard for FunctionTool values. */
export function isFunctionTool(value: unknown): value is FunctionTool {
  if (typeof value !== "object" || value === null) return false;
  const obj = value as Record<string, unknown>;
  return (
    obj.type === "function" &&
    typeof obj.name === "string" &&
    typeof obj.execute === "function"
  );
}

/**
 * Converts a FunctionTool into the provider-facing AgentToolDefinition,
 * defaulting the description to the tool name and parameters to an empty
 * object schema, and carrying annotations through when present.
 */
export function functionToolToDefinition(
  tool: FunctionTool,
): AgentToolDefinition {
  return {
    name: tool.name,
    description: tool.description ?? tool.name,
    parameters: (tool.parameters as AgentToolDefinition["parameters"]) ?? {
      type: "object",
      properties: {},
    },
    ...(tool.annotations ? { annotations: tool.annotations } : {}),
  };
}
diff --git a/packages/appkit/src/plugins/agents/tools/hosted-tools.ts b/packages/appkit/src/plugins/agents/tools/hosted-tools.ts
new file mode 100644
index 00000000..c1f06767
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tools/hosted-tools.ts
@@ -0,0 +1,98 @@
import type { McpEndpointConfig } from "../../../connectors/mcp";

export interface GenieTool {
  type: "genie-space";
  genie_space: { id: string };
}

export interface VectorSearchIndexTool {
  type: "vector_search_index";
  vector_search_index: { name: string };
}

export interface CustomMcpServerTool {
  type: "custom_mcp_server";
  custom_mcp_server: { app_name: string; app_url: string };
}

export interface ExternalMcpServerTool {
  type: "external_mcp_server";
  external_mcp_server: { connection_name: string };
}

export type HostedTool =
  | GenieTool
  | VectorSearchIndexTool
  | CustomMcpServerTool
  | ExternalMcpServerTool;

// Set<string> (not a literal-union Set) so `.has(obj.type)` below accepts a
// plain string under strict TS. Type argument restored by reconstruction.
const HOSTED_TOOL_TYPES = new Set<string>([
  "genie-space",
  "vector_search_index",
  "custom_mcp_server",
  "external_mcp_server",
]);

/** Structural runtime guard for HostedTool values. */
export function isHostedTool(value: unknown): value is HostedTool {
  if (typeof value !== "object" || value === null) return false;
  const obj = value as Record<string, unknown>;
  return typeof obj.type === "string" && HOSTED_TOOL_TYPES.has(obj.type);
}

/**
 * Resolves HostedTool configs into MCP endpoint configurations
 * that the MCP client can connect to.
+ */ +function resolveHostedTool(tool: HostedTool): McpEndpointConfig { + switch (tool.type) { + case "genie-space": + return { + name: `genie-${tool.genie_space.id}`, + url: `/api/2.0/mcp/genie/${tool.genie_space.id}`, + }; + case "vector_search_index": { + const parts = tool.vector_search_index.name.split("."); + if (parts.length !== 3) { + throw new Error( + `vector_search_index name must be 3-part dotted (catalog.schema.index), got: ${tool.vector_search_index.name}`, + ); + } + return { + name: `vs-${parts.join("-")}`, + url: `/api/2.0/mcp/vector-search/${parts[0]}/${parts[1]}/${parts[2]}`, + }; + } + case "custom_mcp_server": + return { + name: tool.custom_mcp_server.app_name, + url: tool.custom_mcp_server.app_url, + }; + case "external_mcp_server": + return { + name: tool.external_mcp_server.connection_name, + url: `/api/2.0/mcp/external/${tool.external_mcp_server.connection_name}`, + }; + } +} + +export function resolveHostedTools(tools: HostedTool[]): McpEndpointConfig[] { + return tools.map(resolveHostedTool); +} + +/** + * Factory for declaring a custom MCP server tool. + * + * Replaces the verbose `{ type: "custom_mcp_server", custom_mcp_server: { app_name, app_url } }` + * wrapper with a concise positional call. 
+ * + * Example: + * ```ts + * mcpServer("my-app", "https://my-app.databricksapps.com/mcp") + * ``` + */ +export function mcpServer(name: string, url: string): CustomMcpServerTool { + return { + type: "custom_mcp_server", + custom_mcp_server: { app_name: name, app_url: url }, + }; +} diff --git a/packages/appkit/src/plugins/agents/tools/index.ts b/packages/appkit/src/plugins/agents/tools/index.ts new file mode 100644 index 00000000..004c96b5 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/index.ts @@ -0,0 +1,19 @@ +export { + defineTool, + executeFromRegistry, + type ToolEntry, + type ToolRegistry, + toolsFromRegistry, +} from "./define-tool"; +export { + type FunctionTool, + functionToolToDefinition, + isFunctionTool, +} from "./function-tool"; +export { + type HostedTool, + isHostedTool, + mcpServer, + resolveHostedTools, +} from "./hosted-tools"; +export { type ToolConfig, tool } from "./tool"; diff --git a/packages/appkit/src/plugins/agents/tools/json-schema.ts b/packages/appkit/src/plugins/agents/tools/json-schema.ts new file mode 100644 index 00000000..c5c10dbf --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/json-schema.ts @@ -0,0 +1,20 @@ +import { toJSONSchema, type z } from "zod"; + +/** + * Converts a Zod schema to JSON Schema suitable for an LLM tool-call + * `parameters` field. + * + * Wraps `zod`'s `toJSONSchema()` and strips the top-level `$schema` annotation + * that Zod v4 emits by default (e.g. `"https://json-schema.org/draft/..."`). + * The Databricks Mosaic serving endpoint forwards tool schemas to Google's + * Gemini `function_declarations` format, which rejects any top-level key it + * doesn't explicitly recognize — including `$schema` — with a 400 + * `Invalid JSON payload received. Unknown name "$schema"` error. Other LLM + * providers either ignore the field or also trip on it, so stripping here is + * safe across backends. 
+ */ +export function toToolJSONSchema(schema: z.ZodType): Record { + const raw = toJSONSchema(schema) as Record; + const { $schema: _ignored, ...rest } = raw; + return rest; +} diff --git a/packages/appkit/src/plugins/agents/tools/sql-policy.ts b/packages/appkit/src/plugins/agents/tools/sql-policy.ts new file mode 100644 index 00000000..6f889d44 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/sql-policy.ts @@ -0,0 +1,317 @@ +/** + * Conservative SQL classifier used by agent-facing query tools to enforce + * `readOnly: true` annotations at execution time. + * + * Why a hand-rolled tokenizer rather than `node-sql-parser` or `pgsql-parser`: + * + * - `node-sql-parser`'s Hive/Spark dialect coverage rejects common Databricks + * SQL patterns (three-part `catalog.schema.table` names, `SHOW TABLES IN`, + * `DESCRIBE EXTENDED`, `EXPLAIN`) that must be allowed by a read-only + * classifier. Its PostgreSQL grammar rejects `SHOW`/`DESCRIBE` too. + * - `pgsql-parser` (libpg_query) is a native binding and fails to install + * cleanly on every Databricks App runtime we care about. + * + * We don't need to fully parse SQL — we only need to decide whether every + * statement in the batch starts with a read-only keyword. A small tokenizer + * that correctly strips strings, identifiers, and comments is enough and + * costs no extra dependencies. + * + * What this classifier guarantees (when it returns `readOnly: true`): + * + * - Every semicolon-separated statement outside a string, identifier, or + * comment begins with `SELECT`, `WITH`, `SHOW`, `EXPLAIN`, `DESCRIBE`, or + * `DESC`. + * - `SELECT 1; DROP TABLE x` is rejected (stacked write detected). + * - `SELECT 'value; DROP TABLE x'` passes (literal inside a string). + * - `-- DROP TABLE x\nSELECT 1` passes (comment stripped). + * - `SELECT 1 ` passes (comment stripped). 
+ * + * What this classifier does NOT guarantee: + * + * - A `SELECT` statement may still have side effects via function calls + * (`SELECT pg_advisory_lock(...)`, `SELECT lo_import('/etc/passwd')`, CTEs + * with DML in Postgres 9.1+). Callers that need stronger guarantees should + * combine this check with a runtime mechanism: for PostgreSQL, execute the + * statement inside a dedicated client's `BEGIN READ ONLY … ROLLBACK` + * transaction (see `LakebasePlugin.runReadOnlyStatement`). A batched + * `pool.query("BEGIN READ ONLY; ; ROLLBACK")` cannot be used because + * the Postgres Extended Query protocol rejects multi-statement prepared + * queries, which silently breaks parameterized SQL. + */ + +const READ_ONLY_KEYWORDS = new Set([ + "SELECT", + "WITH", + "SHOW", + "EXPLAIN", + "DESCRIBE", + "DESC", +]); + +type SqlReadOnlyResult = + | { readOnly: true; statements: number } + | { readOnly: false; reason: string }; + +/** + * Classify a SQL string as read-only or not. See module docstring for the + * precise guarantee this offers. 
+ */ +export function classifyReadOnly(sql: string): SqlReadOnlyResult { + const strip = stripCommentsAndQuoted(sql); + if (strip.unterminated) { + return { + readOnly: false, + reason: `SQL has an unterminated ${strip.unterminated} literal`, + }; + } + const statements = splitStatements(strip.cleaned); + + if (statements.length === 0) { + return { + readOnly: false, + reason: "SQL is empty or contains only comments", + }; + } + + for (let i = 0; i < statements.length; i++) { + const stmt = statements[i]; + const firstWord = firstKeyword(stmt); + if (!firstWord) { + return { + readOnly: false, + reason: `statement ${i + 1} of ${statements.length} is empty`, + }; + } + if (!READ_ONLY_KEYWORDS.has(firstWord.toUpperCase())) { + return { + readOnly: false, + reason: `statement starts with '${firstWord}'; only SELECT, WITH, SHOW, EXPLAIN, DESCRIBE, DESC are allowed in read-only mode`, + }; + } + } + + return { readOnly: true, statements: statements.length }; +} + +/** + * Assert `sql` is read-only or throw {@link ReadOnlySqlViolation}. Suitable + * for calling from agent-tool handlers where the thrown string surfaces back + * to the LLM as the tool's error output. + */ +export function assertReadOnlySql(sql: string): void { + const result = classifyReadOnly(sql); + if (!result.readOnly) { + throw new ReadOnlySqlViolation(result.reason); + } +} + +export class ReadOnlySqlViolation extends Error { + constructor(reason: string) { + super(`SQL read-only policy violation: ${reason}`); + this.name = "ReadOnlySqlViolation"; + } +} + +// --------------------------------------------------------------------------- +// Tokenizer helpers +// --------------------------------------------------------------------------- + +/** + * Walk `sql` character-by-character and replace every string literal, + * identifier quote, and comment body with a single space of equivalent + * length. Leaves structural tokens (semicolons, whitespace, identifiers, + * operators) in place. 
+ * + * Handles: + * - `-- line comments` through end-of-line + * - SQL block comments (slash-star ... star-slash) with correct nesting (PostgreSQL) + * - `'single-quoted strings'` with `''` escape + * - `"double-quoted identifiers"` with `""` escape (ANSI) + * - `` `backtick identifiers` `` (Databricks) + * - `$tag$dollar quoted$tag$` strings (PostgreSQL) + * - `E'escape-style'` strings (PostgreSQL) + */ +type StripResult = { + cleaned: string; + /** Non-null if tokenization ended inside an unterminated literal or comment. */ + unterminated: + | null + | "string" + | "identifier" + | "block comment" + | "dollar-quoted string"; +}; + +function stripCommentsAndQuoted(sql: string): StripResult { + const out: string[] = []; + let i = 0; + const n = sql.length; + let unterminated: StripResult["unterminated"] = null; + + while (i < n) { + const ch = sql[i]; + const next = i + 1 < n ? sql[i + 1] : ""; + + if (ch === "-" && next === "-") { + out.push(" "); + i += 2; + while (i < n && sql[i] !== "\n") { + out.push(" "); + i++; + } + continue; + } + + if (ch === "/" && next === "*") { + out.push(" "); + i += 2; + let depth = 1; + while (i < n && depth > 0) { + if (sql[i] === "/" && sql[i + 1] === "*") { + out.push(" "); + i += 2; + depth++; + continue; + } + if (sql[i] === "*" && sql[i + 1] === "/") { + out.push(" "); + i += 2; + depth--; + continue; + } + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + if (depth > 0) { + unterminated = "block comment"; + } + continue; + } + + if ( + ch === "'" || + (ch === "E" && next === "'") || + (ch === "e" && next === "'") + ) { + if (ch === "E" || ch === "e") { + out.push(" "); + i++; + } + out.push(" "); + i++; + let closed = false; + while (i < n) { + if (sql[i] === "'" && sql[i + 1] === "'") { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === "\\" && sql[i + 1]) { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === "'") { + out.push(" "); + i++; + closed = true; + break; + } + out.push(sql[i] === "\n" ? 
"\n" : " "); + i++; + } + if (!closed) unterminated = "string"; + continue; + } + + if (ch === '"') { + out.push(" "); + i++; + let closed = false; + while (i < n) { + if (sql[i] === '"' && sql[i + 1] === '"') { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === '"') { + out.push(" "); + i++; + closed = true; + break; + } + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + if (!closed) unterminated = "identifier"; + continue; + } + + if (ch === "`") { + out.push(" "); + i++; + let closed = false; + while (i < n) { + if (sql[i] === "`" && sql[i + 1] === "`") { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === "`") { + out.push(" "); + i++; + closed = true; + break; + } + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + if (!closed) unterminated = "identifier"; + continue; + } + + if (ch === "$") { + const tagMatch = sql.slice(i).match(/^\$([A-Za-z_][A-Za-z0-9_]*)?\$/); + if (tagMatch) { + const tag = tagMatch[0]; + out.push(" ".repeat(tag.length)); + i += tag.length; + const closeIdx = sql.indexOf(tag, i); + if (closeIdx === -1) { + while (i < n) { + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + unterminated = "dollar-quoted string"; + } else { + while (i < closeIdx) { + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + out.push(" ".repeat(tag.length)); + i += tag.length; + } + continue; + } + } + + out.push(ch); + i++; + } + + return { cleaned: out.join(""), unterminated }; +} + +/** Split on unquoted `;`, trim, drop empty segments. */ +function splitStatements(cleanedSql: string): string[] { + return cleanedSql + .split(";") + .map((s) => s.trim()) + .filter((s) => s.length > 0); +} + +/** Return the first bareword keyword of a statement, or null if empty. */ +function firstKeyword(stmt: string): string | null { + const match = stmt.match(/^\s*([A-Za-z_][A-Za-z0-9_]*)/); + return match ? 
match[1] : null; +} diff --git a/packages/appkit/src/plugins/agents/tools/tool.ts b/packages/appkit/src/plugins/agents/tools/tool.ts new file mode 100644 index 00000000..53305c23 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/tool.ts @@ -0,0 +1,64 @@ +import type { ToolAnnotations } from "shared"; +import type { z } from "zod"; +import type { FunctionTool } from "./function-tool"; +import { toToolJSONSchema } from "./json-schema"; + +export interface ToolConfig { + name: string; + description?: string; + schema: S; + /** + * Behavioural hints forwarded to the resolved tool definition. Prefer + * `effect` (`"read" | "write" | "update" | "destructive"`) — any mutating + * value forces the agents-plugin approval gate before `execute()` runs + * and the client's approval card will colour itself accordingly. Legacy + * `destructive: true` still gates. Dropped silently before the fix that + * added this field. + */ + annotations?: ToolAnnotations; + execute: (args: z.infer) => Promise | string; +} + +/** + * Factory for defining function tools with Zod schemas. + * + * - Generates JSON Schema (for the LLM) from the Zod schema via `z.toJSONSchema()`. + * - Infers the `execute` argument type from the schema. + * - Validates tool call arguments at runtime. On validation failure, returns + * a formatted error string to the LLM instead of throwing, so the model + * can self-correct on its next turn. + */ +export function tool(config: ToolConfig): FunctionTool { + const parameters = toToolJSONSchema(config.schema) as unknown as Record< + string, + unknown + >; + + return { + type: "function", + name: config.name, + description: config.description ?? config.name, + parameters, + ...(config.annotations ? 
{ annotations: config.annotations } : {}), + execute: async (args: Record) => { + const parsed = config.schema.safeParse(args); + if (!parsed.success) { + return formatZodError(parsed.error, config.name); + } + return config.execute(parsed.data as z.infer); + }, + }; +} + +/** + * Formats a Zod validation error into an LLM-friendly string. + * + * Example: `Invalid arguments for get_weather: city: Invalid input: expected string, received undefined` + */ +export function formatZodError(error: z.ZodError, toolName: string): string { + const parts = error.issues.map((issue) => { + const field = issue.path.length > 0 ? issue.path.join(".") : "(root)"; + return `${field}: ${issue.message}`; + }); + return `Invalid arguments for ${toolName}: ${parts.join("; ")}`; +} diff --git a/packages/appkit/src/plugins/agents/types.ts b/packages/appkit/src/plugins/agents/types.ts new file mode 100644 index 00000000..14366e9a --- /dev/null +++ b/packages/appkit/src/plugins/agents/types.ts @@ -0,0 +1,225 @@ +import type { + AgentAdapter, + AgentToolDefinition, + BasePluginConfig, + ThreadStore, + ToolAnnotations, +} from "shared"; +import type { McpHostPolicyConfig } from "../../connectors/mcp"; +import type { FunctionTool } from "./tools/function-tool"; +import type { HostedTool } from "./tools/hosted-tools"; + +/** + * A tool reference produced by a plugin's `.toolkit()` call. The agents plugin + * recognizes the `__toolkitRef` brand and dispatches tool invocations through + * `PluginContext.executeTool(req, pluginName, localName, ...)`, preserving + * OBO (asUser) and telemetry spans. + */ +export interface ToolkitEntry { + readonly __toolkitRef: true; + pluginName: string; + localName: string; + def: AgentToolDefinition; + annotations?: ToolAnnotations; + /** + * Whether this tool is eligible for `autoInheritTools` spreading. 
Mirrors + * {@link ToolEntry.autoInheritable} from the source registry so the agents + * plugin can filter auto-inherited tools without re-walking the provider's + * internal registry. + */ + autoInheritable?: boolean; +} + +/** + * Any tool an agent can invoke: inline function tools (`tool()`), hosted MCP + * tools (`mcpServer()` / raw hosted), or toolkit references from plugins + * (`analytics().toolkit()`). + */ +export type AgentTool = FunctionTool | HostedTool | ToolkitEntry; + +export interface ToolkitOptions { + /** Key prefix to prepend to each tool's local name. Defaults to `${pluginName}.`. */ + prefix?: string; + /** Only include tools whose local name matches one of these. */ + only?: string[]; + /** Exclude tools whose local name matches one of these. */ + except?: string[]; + /** Remap specific local names to different keys (applied after prefix). */ + rename?: Record; +} + +/** + * Context passed to `baseSystemPrompt` callbacks. + */ +export interface PromptContext { + agentName: string; + pluginNames: string[]; + toolNames: string[]; +} + +export type BaseSystemPromptOption = + | false + | string + | ((ctx: PromptContext) => string); + +export interface AgentDefinition { + /** Filled in from the enclosing key when used in `agents: { foo: def }`. */ + name?: string; + /** System prompt body. For markdown-loaded agents this is the file body. */ + instructions: string; + /** + * Model adapter (or endpoint-name string sugar for + * `DatabricksAdapter.fromServingEndpoint({ endpointName })`). Optional — + * falls back to the plugin's `defaultModel`. + */ + model?: AgentAdapter | Promise | string; + /** Per-agent tool record. Key is the LLM-visible tool-call name. */ + tools?: Record; + /** Sub-agents, exposed as `agent-` tools on this agent. */ + agents?: Record; + /** Override the plugin's baseSystemPrompt for this agent only. 
*/ + baseSystemPrompt?: BaseSystemPromptOption; + maxSteps?: number; + maxTokens?: number; + /** + * When true, the thread used for a chat request against this agent is + * deleted from `ThreadStore` after the stream completes (success or + * failure). Use for stateless one-shot agents — e.g. autocomplete, where + * each request is independent and retaining history would both poison + * future calls and accumulate unbounded state in the default + * `InMemoryThreadStore`. Defaults to `false`. + */ + ephemeral?: boolean; +} + +/** + * Auto-inherit configuration. When enabled for a given agent origin, agents + * with no explicit `tools:` declaration receive every registered ToolProvider + * plugin tool whose author marked `autoInheritable: true`. Tools without that + * flag — destructive, state-mutating, or privilege-sensitive — never spread + * automatically and must be wired via `tools:`, `toolkits:`, or `fromPlugin`. + * + * Defaults are `false` for both origins (safe-by-default): developers must + * consciously opt an origin in to any auto-inherit behaviour. + */ +export interface AutoInheritToolsConfig { + /** Default for agents loaded from markdown files. Default: `false`. */ + file?: boolean; + /** Default for code-defined agents (via `agents: { foo: createAgent(...) }`). Default: `false`. */ + code?: boolean; +} + +export interface AgentsPluginConfig extends BasePluginConfig { + /** Directory of agent packages (`/agent.md` each). Default `./config/agents`. Set to `false` to disable. */ + dir?: string | false; + /** Code-defined agents, merged with file-loaded ones (code wins on key collision). */ + agents?: Record; + /** Agent used when clients don't specify one. Defaults to the first-registered agent or the file with `default: true` frontmatter. */ + defaultAgent?: string; + /** Default model for agents that don't specify their own (in code or frontmatter). */ + defaultModel?: AgentAdapter | Promise | string; + /** Ambient tool library. 
Keys may be referenced by markdown frontmatter via `tools: [key1, key2]`. */ + tools?: Record; + /** Whether to auto-inherit every ToolProvider plugin's toolkit. Accepts a boolean shorthand. */ + autoInheritTools?: boolean | AutoInheritToolsConfig; + /** Persistent thread store. Default: in-memory. */ + threadStore?: ThreadStore; + /** Customize or disable the AppKit base system prompt. */ + baseSystemPrompt?: BaseSystemPromptOption; + /** + * MCP server host policy. By default only same-origin Databricks workspace + * URLs may be used as MCP endpoints; custom hosts must be explicitly + * allowlisted here. Workspace credentials (SP / OBO) are never forwarded + * to non-workspace hosts. + */ + mcp?: McpHostPolicyConfig; + /** + * Human-in-the-loop approval gate for destructive tool calls. When enabled + * (the default), the agents plugin emits an `appkit.approval_pending` SSE + * event before executing any tool annotated `destructive: true` and waits + * for a `POST /chat/approve` decision from the same user who initiated the + * stream. A missing decision after `timeoutMs` auto-denies the call. + */ + approval?: { + /** Require human approval for tools annotated `destructive: true`. Default: `true`. */ + requireForDestructive?: boolean; + /** Milliseconds to wait before auto-denying. Default: 60_000. */ + timeoutMs?: number; + }; + /** + * Runtime resource limits applied during agent execution. Defaults are + * tuned to protect a single-instance deployment from a misbehaving user or + * a runaway prompt injection; tighten or relax as appropriate for the + * deployment's scale and trust model. Request-body caps (chat message + * size, invocations input size / length) are enforced statically by the + * Zod schemas and are not configurable here. + */ + limits?: { + /** + * Max concurrent chat streams a single user may have open. Subsequent + * `POST /chat` requests from that user while at-limit are rejected with + * HTTP 429. Default: `5`. 
+ */ + maxConcurrentStreamsPerUser?: number; + /** + * Max tool invocations per agent run (across the full tool-call graph, + * including sub-agent invocations). A run that exceeds the budget is + * aborted with a terminal error event. Default: `50`. + */ + maxToolCalls?: number; + /** + * Max sub-agent recursion depth. Protects against a prompt-injected + * agent that delegates to a sub-agent which in turn delegates back to + * itself (directly or transitively). Default: `3`. + */ + maxSubAgentDepth?: number; + }; +} + +/** Internal tool-index entry after a tool record has been resolved to a dispatchable form. */ +export type ResolvedToolEntry = + | { + source: "toolkit"; + pluginName: string; + localName: string; + def: AgentToolDefinition; + } + | { + source: "function"; + functionTool: FunctionTool; + def: AgentToolDefinition; + } + | { + source: "mcp"; + mcpToolName: string; + def: AgentToolDefinition; + } + | { + source: "subagent"; + agentName: string; + def: AgentToolDefinition; + }; + +export interface RegisteredAgent { + name: string; + instructions: string; + adapter: AgentAdapter; + toolIndex: Map; + baseSystemPrompt?: BaseSystemPromptOption; + maxSteps?: number; + maxTokens?: number; + /** Mirrors `AgentDefinition.ephemeral` — skip thread persistence. */ + ephemeral?: boolean; +} + +/** + * Type guard for `ToolkitEntry` — used by the agents plugin to differentiate + * toolkit references from inline tools in a mixed `tools` record. 
+ */ +export function isToolkitEntry(value: unknown): value is ToolkitEntry { + return ( + typeof value === "object" && + value !== null && + (value as { __toolkitRef?: unknown }).__toolkitRef === true + ); +} diff --git a/packages/appkit/src/plugins/analytics/analytics.ts b/packages/appkit/src/plugins/analytics/analytics.ts index a9c688da..78d11d4d 100644 --- a/packages/appkit/src/plugins/analytics/analytics.ts +++ b/packages/appkit/src/plugins/analytics/analytics.ts @@ -1,16 +1,26 @@ import type { WorkspaceClient } from "@databricks/sdk-experimental"; import type express from "express"; import type { + AgentToolDefinition, IAppRouter, PluginExecuteConfig, SQLTypeMarker, StreamExecutionSettings, + ToolProvider, } from "shared"; +import { z } from "zod"; import { SQLWarehouseConnector } from "../../connectors"; import { getWarehouseId, getWorkspaceClient } from "../../context"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; +import { assertReadOnlySql } from "../agents/tools/sql-policy"; import { queryDefaults } from "./defaults"; import manifest from "./manifest.json"; import { QueryProcessor } from "./query"; @@ -22,7 +32,7 @@ import type { const logger = createLogger("analytics"); -export class AnalyticsPlugin extends Plugin { +export class AnalyticsPlugin extends Plugin implements ToolProvider { /** Plugin manifest declaring metadata and resource requirements */ static manifest = manifest as PluginManifest<"analytics">; @@ -262,6 +272,52 @@ export class AnalyticsPlugin extends Plugin { this.streamManager.abortAll(); } + private tools = { + query: defineTool({ + description: + "Execute a read-only SQL query against the Databricks SQL warehouse. 
Only SELECT, WITH, SHOW, EXPLAIN, and DESCRIBE statements are accepted; writes are rejected. Returns the query results as JSON.", + schema: z.object({ + query: z + .string() + .describe( + "The SQL query to execute. Must be a SELECT, WITH, SHOW, EXPLAIN, or DESCRIBE statement.", + ), + }), + annotations: { + readOnly: true, + requiresUserContext: true, + }, + autoInheritable: true, + handler: (args, signal) => { + assertReadOnlySql(args.query); + return this.query(args.query, undefined, undefined, signal); + }, + }), + }; + + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + /** + * Returns the plugin's tools as a keyed record of `ToolkitEntry` markers. + * Called by the agents plugin (via `resolveToolkitFromProvider`) to spread + * a filtered, renamed view of the plugin's tools into an agent's tool + * index. Most callers should go through `fromPlugin(analytics, opts)` at + * module scope instead of reaching for this directly. + */ + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, this.tools, opts); + } + /** * Returns the public exports for the analytics plugin. * Note: `asUser()` is automatically added by AppKit. 
diff --git a/packages/appkit/src/plugins/analytics/tests/analytics.integration.test.ts b/packages/appkit/src/plugins/analytics/tests/analytics.integration.test.ts index cb73394a..0cec2298 100644 --- a/packages/appkit/src/plugins/analytics/tests/analytics.integration.test.ts +++ b/packages/appkit/src/plugins/analytics/tests/analytics.integration.test.ts @@ -46,13 +46,11 @@ describe("Analytics Plugin Integration", () => { serverPlugin({ port: TEST_PORT, host: "127.0.0.1", - autoStart: false, }), analytics({}), ], }); - await app.server.start(); server = app.server.getServer(); baseUrl = `http://127.0.0.1:${TEST_PORT}`; }); diff --git a/packages/appkit/src/plugins/analytics/tests/analytics.readonly.test.ts b/packages/appkit/src/plugins/analytics/tests/analytics.readonly.test.ts new file mode 100644 index 00000000..42c9b516 --- /dev/null +++ b/packages/appkit/src/plugins/analytics/tests/analytics.readonly.test.ts @@ -0,0 +1,133 @@ +import { describe, expect, test, vi } from "vitest"; + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise) => + fn(), + ), + generateKey: vi.fn(() => "test-key"), + })), + }, +})); + +import { AnalyticsPlugin } from "../analytics"; + +/** + * Tests the read-only SQL enforcement on the analytics agent tool. + * + * The tool is annotated `{ readOnly: true, requiresUserContext: true }`; this + * suite verifies that the annotation is enforced at execution time — not just + * exposed as metadata to the LLM — by the `assertReadOnlySql` guard in the + * tool's handler. 
+ */ + +function makePlugin(): AnalyticsPlugin { + return new AnalyticsPlugin({}); +} + +describe("AnalyticsPlugin.query agent tool — readOnly annotation", () => { + test("is advertised with readOnly:true and requiresUserContext:true", () => { + const plugin = makePlugin(); + const defs = plugin.getAgentTools(); + const query = defs.find((d) => d.name === "query"); + expect(query).toBeDefined(); + expect(query?.annotations).toEqual({ + readOnly: true, + requiresUserContext: true, + }); + }); +}); + +describe("AnalyticsPlugin.query agent tool — runtime enforcement", () => { + test("rejects a DROP statement before it reaches this.query", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [] } as any); + await expect( + plugin.executeAgentTool("query", { query: "DROP TABLE users" }), + ).rejects.toThrow(/read-only policy violation/i); + expect(spy).not.toHaveBeenCalled(); + }); + + test("rejects UPDATE, DELETE, INSERT, TRUNCATE, GRANT", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [] } as any); + for (const q of [ + "UPDATE users SET email='x'", + "DELETE FROM orders", + "INSERT INTO x VALUES (1)", + "TRUNCATE TABLE orders", + "GRANT SELECT ON t TO u", + ]) { + await expect( + plugin.executeAgentTool("query", { query: q }), + ).rejects.toThrow(/read-only policy violation/i); + } + expect(spy).not.toHaveBeenCalled(); + }); + + test("rejects a stacked SELECT + DROP", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [] } as any); + await expect( + plugin.executeAgentTool("query", { + query: "SELECT 1; DROP TABLE users", + }), + ).rejects.toThrow(/DROP/); + 
expect(spy).not.toHaveBeenCalled(); + }); + + test("passes a plain SELECT through to this.query", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [{ id: 1 }] } as any); + const result = await plugin.executeAgentTool("query", { + query: "SELECT * FROM main.sales.orders", + }); + expect(result).toEqual({ rows: [{ id: 1 }] }); + expect(spy).toHaveBeenCalledWith( + "SELECT * FROM main.sales.orders", + undefined, + undefined, + undefined, + ); + }); + + test("passes WITH … SELECT through", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [] } as any); + await plugin.executeAgentTool("query", { + query: "WITH a AS (SELECT 1) SELECT * FROM a", + }); + expect(spy).toHaveBeenCalledOnce(); + }); + + test("passes SHOW TABLES through", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [] } as any); + await plugin.executeAgentTool("query", { + query: "SHOW TABLES IN main.sales", + }); + expect(spy).toHaveBeenCalledOnce(); + }); +}); diff --git a/packages/appkit/src/plugins/analytics/tests/analytics.test.ts b/packages/appkit/src/plugins/analytics/tests/analytics.test.ts index 9a30440e..29157fff 100644 --- a/packages/appkit/src/plugins/analytics/tests/analytics.test.ts +++ b/packages/appkit/src/plugins/analytics/tests/analytics.test.ts @@ -608,4 +608,22 @@ describe("Analytics Plugin", () => { }); }); }); + + describe("toolkit()", () => { + test("produces ToolkitEntry records keyed by the plugin name", () => { + const plugin = new AnalyticsPlugin({ name: "analytics" }); + const entries = plugin.toolkit(); + expect(Object.keys(entries)).toContain("analytics.query"); + const entry = 
entries["analytics.query"]; + expect(entry.__toolkitRef).toBe(true); + expect(entry.pluginName).toBe("analytics"); + expect(entry.localName).toBe("query"); + }); + + test("respects prefix and only options", () => { + const plugin = new AnalyticsPlugin({ name: "analytics" }); + const entries = plugin.toolkit({ prefix: "", only: ["query"] }); + expect(Object.keys(entries)).toEqual(["query"]); + }); + }); }); diff --git a/packages/appkit/src/plugins/files/plugin.ts b/packages/appkit/src/plugins/files/plugin.ts index 75f2e14d..1a80e868 100644 --- a/packages/appkit/src/plugins/files/plugin.ts +++ b/packages/appkit/src/plugins/files/plugin.ts @@ -2,19 +2,37 @@ import { STATUS_CODES } from "node:http"; import { Readable } from "node:stream"; import { ApiError } from "@databricks/sdk-experimental"; import type express from "express"; -import type { IAppRouter, PluginExecutionSettings } from "shared"; +import type { + AgentToolDefinition, + IAppRouter, + PluginExecutionSettings, + ToolProvider, +} from "shared"; +import { z } from "zod"; import { contentTypeFromPath, FilesConnector, isSafeInlineContentType, validateCustomContentTypes, } from "../../connectors/files"; -import { getCurrentUserId, getWorkspaceClient } from "../../context"; +import { + getCurrentUserId, + getExecutionContext, + getWorkspaceClient, +} from "../../context"; +import { isUserContext } from "../../context/user-context"; import { AuthenticationError } from "../../errors"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest, ResourceRequirement } from "../../registry"; import { ResourceType } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; import { FILES_DOWNLOAD_DEFAULTS, FILES_MAX_UPLOAD_SIZE, @@ -41,7 +59,7 @@ import type { const logger = 
createLogger("files"); -export class FilesPlugin extends Plugin { +export class FilesPlugin extends Plugin implements ToolProvider { name = "files"; /** Plugin manifest declaring metadata and resource requirements. */ @@ -52,6 +70,7 @@ export class FilesPlugin extends Plugin { private volumeConnectors: Record = {}; private volumeConfigs: Record = {}; private volumeKeys: string[] = []; + private tools: ToolRegistry = {}; /** * Scans `process.env` for `DATABRICKS_VOLUME_*` keys and merges them with @@ -224,11 +243,11 @@ export class FilesPlugin extends Plugin { telemetry: config.telemetry, customContentTypes: mergedConfig.customContentTypes, }); - } - // Warn at startup for volumes without an explicit policy - for (const key of this.volumeKeys) { - if (!volumes[key].policy) { + Object.assign(this.tools, this._defineVolumeTools(key)); + + // Warn at startup for volumes without an explicit policy + if (!volumeCfg.policy) { logger.warn( 'Volume "%s" has no explicit policy — defaulting to publicRead(). ' + "Set a policy in files({ volumes: { %s: { policy: ... } } }) to silence this warning.", @@ -1019,6 +1038,91 @@ export class FilesPlugin extends Plugin { }; } + /** + * Builds the agent-tool registry entries for a single volume. One set of + * tools per configured volume, keyed by `${volumeKey}.${method}`. + * + * Each handler resolves the caller's identity from the current execution + * context (OBO user when the agent run is wrapped in `asUser(req)`, service + * principal otherwise in local dev) and dispatches through + * `createVolumeAPI(volumeKey, user)` so the volume's policy is enforced + * uniformly for agent and HTTP callers. + */ + private _defineVolumeTools(volumeKey: string): ToolRegistry { + const buildUser = (): FilePolicyUser => { + const ctx = getExecutionContext(); + return isUserContext(ctx) + ? 
{ id: ctx.userId } + : { id: ctx.serviceUserId, isServicePrincipal: true }; + }; + const api = () => this.createVolumeAPI(volumeKey, buildUser()); + return { + [`${volumeKey}.list`]: defineTool({ + description: `List files and directories in the "${volumeKey}" volume`, + schema: z.object({ + path: z + .string() + .optional() + .describe("Directory path to list (optional, defaults to root)"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().list(args.path), + }), + [`${volumeKey}.read`]: defineTool({ + description: `Read a text file from the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path to read"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().read(args.path), + }), + [`${volumeKey}.exists`]: defineTool({ + description: `Check if a file or directory exists in the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("Path to check"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().exists(args.path), + }), + [`${volumeKey}.metadata`]: defineTool({ + description: `Get metadata (size, type, last modified) for a file in the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().metadata(args.path), + }), + [`${volumeKey}.upload`]: defineTool({ + description: `Upload a text file to the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("Destination file path"), + contents: z.string().describe("File contents as a string"), + overwrite: z + .boolean() + .optional() + .describe("Whether to overwrite existing file"), + }), + annotations: { destructive: true, requiresUserContext: true }, + handler: (args) => + api().upload(args.path, 
args.contents, { + overwrite: args.overwrite, + }), + }), + [`${volumeKey}.delete`]: defineTool({ + description: `Delete a file from the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path to delete"), + }), + annotations: { destructive: true, requiresUserContext: true }, + handler: (args) => api().delete(args.path), + }), + }; + } + private inflightWrites = 0; private trackWrite(fn: () => Promise): Promise { @@ -1047,6 +1151,22 @@ export class FilesPlugin extends Plugin { this.streamManager.abortAll(); } + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, this.tools, opts); + } + /** * Returns the programmatic API for the Files plugin. * Callable with a volume key to get a volume-scoped handle. 
diff --git a/packages/appkit/src/plugins/files/tests/plugin.integration.test.ts b/packages/appkit/src/plugins/files/tests/plugin.integration.test.ts index da90760d..3c8ff74e 100644 --- a/packages/appkit/src/plugins/files/tests/plugin.integration.test.ts +++ b/packages/appkit/src/plugins/files/tests/plugin.integration.test.ts @@ -87,13 +87,11 @@ describe("Files Plugin Integration", () => { serverPlugin({ port: TEST_PORT, host: "127.0.0.1", - autoStart: false, }), files(), ], }); - await appkit.server.start(); server = appkit.server.getServer(); baseUrl = `http://127.0.0.1:${TEST_PORT}`; }); diff --git a/packages/appkit/src/plugins/files/tests/plugin.test.ts b/packages/appkit/src/plugins/files/tests/plugin.test.ts index a4b9bea2..bbaa1b98 100644 --- a/packages/appkit/src/plugins/files/tests/plugin.test.ts +++ b/packages/appkit/src/plugins/files/tests/plugin.test.ts @@ -205,6 +205,62 @@ describe("FilesPlugin", () => { }); }); + describe("getAgentTools / executeAgentTool", () => { + test("produces independent tool entries per volume", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const tools = plugin.getAgentTools(); + const names = tools.map((t) => t.name); + + expect(names).toContain("uploads.list"); + expect(names).toContain("uploads.read"); + expect(names).toContain("uploads.exists"); + expect(names).toContain("uploads.metadata"); + expect(names).toContain("uploads.upload"); + expect(names).toContain("uploads.delete"); + + expect(names).toContain("exports.list"); + expect(names).toContain("exports.read"); + expect(names).toContain("exports.delete"); + + expect(tools).toHaveLength(12); + }); + + test("dispatches to the correct volume API based on the tool name", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const asyncIterable = (items: { path: string }[]) => ({ + [Symbol.asyncIterator]: async function* () { + for (const item of items) yield item; + }, + }); + mockClient.files.listDirectoryContents.mockReturnValueOnce( + 
asyncIterable([{ path: "uploads-file" }]), + ); + mockClient.files.listDirectoryContents.mockReturnValueOnce( + asyncIterable([{ path: "exports-file" }]), + ); + + const uploadsResult = (await plugin.executeAgentTool( + "uploads.list", + {}, + )) as { path: string }[]; + const exportsResult = (await plugin.executeAgentTool( + "exports.list", + {}, + )) as { path: string }[]; + + expect(uploadsResult[0].path).toBe("uploads-file"); + expect(exportsResult[0].path).toBe("exports-file"); + }); + + test("returns LLM-friendly error string for invalid tool args", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const result = await plugin.executeAgentTool("uploads.read", {}); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for uploads.read"); + expect(result).toContain("path"); + }); + }); + describe("exports()", () => { test("returns a callable function with a .volume alias", () => { const plugin = new FilesPlugin(VOLUMES_CONFIG); diff --git a/packages/appkit/src/plugins/genie/genie.ts b/packages/appkit/src/plugins/genie/genie.ts index 712aadbf..3167794e 100644 --- a/packages/appkit/src/plugins/genie/genie.ts +++ b/packages/appkit/src/plugins/genie/genie.ts @@ -1,11 +1,24 @@ import { randomUUID } from "node:crypto"; import type express from "express"; -import type { IAppRouter, StreamExecutionSettings } from "shared"; +import type { + AgentToolDefinition, + IAppRouter, + StreamExecutionSettings, + ToolProvider, +} from "shared"; +import { z } from "zod"; import { GenieConnector } from "../../connectors"; import { getWorkspaceClient } from "../../context"; import { createLogger } from "../../logging"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; import { genieStreamDefaults 
} from "./defaults"; import manifest from "./manifest.json"; import type { @@ -17,7 +30,7 @@ import type { const logger = createLogger("genie"); -export class GeniePlugin extends Plugin { +export class GeniePlugin extends Plugin implements ToolProvider { static manifest = manifest as PluginManifest<"genie">; protected static description = @@ -25,6 +38,7 @@ export class GeniePlugin extends Plugin { protected declare config: IGenieConfig; private readonly genieConnector: GenieConnector; + private tools: ToolRegistry = {}; constructor(config: IGenieConfig) { super(config); @@ -36,6 +50,54 @@ export class GeniePlugin extends Plugin { timeout: this.config.timeout, maxMessages: 200, }); + + for (const alias of Object.keys(this.config.spaces ?? {})) { + Object.assign(this.tools, this._defineSpaceTools(alias)); + } + } + + /** + * Builds the registry entries for a single Genie space alias. + * One set of tools per configured space, keyed by `${alias}.${method}`. + */ + private _defineSpaceTools(alias: string): ToolRegistry { + return { + [`${alias}.sendMessage`]: defineTool({ + description: `Send a natural language question to the Genie space "${alias}" and get data analysis results`, + schema: z.object({ + content: z.string().describe("The natural language question to ask"), + conversationId: z + .string() + .optional() + .describe( + "Optional conversation ID to continue an existing conversation", + ), + }), + annotations: { requiresUserContext: true }, + handler: async (args) => { + const events: GenieStreamEvent[] = []; + for await (const event of this.sendMessage( + alias, + args.content, + args.conversationId, + )) { + events.push(event); + } + return events; + }, + }), + [`${alias}.getConversation`]: defineTool({ + description: `Retrieve the conversation history from the Genie space "${alias}"`, + schema: z.object({ + conversationId: z + .string() + .describe("The conversation ID to retrieve"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + 
autoInheritable: true, + handler: (args) => this.getConversation(alias, args.conversationId), + }), + }; } private defaultSpaces(): Record { @@ -287,6 +349,22 @@ export class GeniePlugin extends Plugin { this.streamManager.abortAll(); } + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, this.tools, opts); + } + exports() { return { sendMessage: this.sendMessage, diff --git a/packages/appkit/src/plugins/genie/tests/genie.test.ts b/packages/appkit/src/plugins/genie/tests/genie.test.ts index 3cf0784d..672e6242 100644 --- a/packages/appkit/src/plugins/genie/tests/genie.test.ts +++ b/packages/appkit/src/plugins/genie/tests/genie.test.ts @@ -187,6 +187,30 @@ describe("Genie Plugin", () => { }); }); + describe("getAgentTools / executeAgentTool", () => { + test("produces independent tool entries per configured space", () => { + const plugin = new GeniePlugin(config); + const names = plugin.getAgentTools().map((t) => t.name); + + expect(names).toContain("myspace.sendMessage"); + expect(names).toContain("myspace.getConversation"); + expect(names).toContain("salesbot.sendMessage"); + expect(names).toContain("salesbot.getConversation"); + expect(names).toHaveLength(4); + }); + + test("returns LLM-friendly error string for invalid tool args", async () => { + const plugin = new GeniePlugin(config); + const result = await plugin.executeAgentTool( + "myspace.getConversation", + {}, + ); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for myspace.getConversation"); + expect(result).toContain("conversationId"); + }); + }); + describe("space alias resolution", () => { test("should return 404 for unknown alias", async () => { const plugin = new 
GeniePlugin(config); diff --git a/packages/appkit/src/plugins/lakebase/lakebase.ts b/packages/appkit/src/plugins/lakebase/lakebase.ts index 3071d539..aaf61b51 100644 --- a/packages/appkit/src/plugins/lakebase/lakebase.ts +++ b/packages/appkit/src/plugins/lakebase/lakebase.ts @@ -1,4 +1,6 @@ import type { Pool, QueryResult, QueryResultRow } from "pg"; +import type { AgentToolDefinition, ToolProvider } from "shared"; +import { z } from "zod"; import { createLakebasePool, getLakebaseOrmConfig, @@ -8,6 +10,13 @@ import { import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; +import { assertReadOnlySql } from "../agents/tools/sql-policy"; import manifest from "./manifest.json"; import type { ILakebaseConfig } from "./types"; @@ -30,18 +39,13 @@ const logger = createLogger("lakebase"); * const result = await AppKit.lakebase.query("SELECT * FROM users WHERE id = $1", [userId]); * ``` */ -class LakebasePlugin extends Plugin { +export class LakebasePlugin extends Plugin implements ToolProvider { /** Plugin manifest declaring metadata and resource requirements */ static manifest = manifest as PluginManifest<"lakebase">; protected declare config: ILakebaseConfig; private pool: Pool | null = null; - constructor(config: ILakebaseConfig) { - super(config); - this.config = config; - } - /** * Initializes the Lakebase connection pool. * Called automatically by AppKit during the plugin setup phase. @@ -79,6 +83,39 @@ class LakebasePlugin extends Plugin { return this.pool!.query(text, values); } + /** + * Execute a single statement inside a `BEGIN READ ONLY … ROLLBACK` + * transaction on a dedicated client. 
+ * + * The three commands MUST share a connection — a naive + * `pool.query("BEGIN READ ONLY; ; ROLLBACK")` batch cannot accept + * parameter values (PostgreSQL's Extended Query protocol rejects multi- + * statement prepared queries), which would silently break every + * parameterized query the agent tool issues. + * + * Returns the raw `rows` array for the user's statement. Side effects the + * statement may attempt (writes, writable-function side effects) are + * rejected by PostgreSQL under the read-only transaction posture. + */ + private async runReadOnlyStatement( + text: string, + values?: unknown[], + ): Promise { + // biome-ignore lint/style/noNonNullAssertion: pool is guaranteed non-null after setup() + const client = await this.pool!.connect(); + try { + await client.query("BEGIN READ ONLY"); + const result = await client.query(text, values); + return result.rows; + } finally { + try { + await client.query("ROLLBACK"); + } finally { + client.release(); + } + } + } + /** * Gracefully drains and closes the connection pool. * Called automatically by AppKit during shutdown. @@ -102,6 +139,82 @@ class LakebasePlugin extends Plugin { * - `getOrmConfig()` — Returns a config object compatible with Drizzle, TypeORM, Sequelize, etc. * - `getPgConfig()` — Returns a `pg.PoolConfig` object for manual pool construction */ + + /** + * Agent tool registry. Empty by default — the Lakebase plugin does NOT + * expose its SQL connection to LLM agents unless the developer explicitly + * opts in via `config.exposeAsAgentTool`. See {@link buildQueryTool}. 
+ */ + private tools: Record> = {}; + + constructor(config: ILakebaseConfig) { + super(config); + this.config = config; + if (config.exposeAsAgentTool) { + if (config.exposeAsAgentTool.iUnderstandRunsAsServicePrincipal !== true) { + throw new Error( + "lakebase.exposeAsAgentTool requires iUnderstandRunsAsServicePrincipal: true — this acknowledges that SQL statements authored by the LLM run with the application's service-principal credentials regardless of which end user initiated the request.", + ); + } + this.tools = { query: this.buildQueryTool(config.exposeAsAgentTool) }; + logger.warn( + "Lakebase agent tool is enabled (readOnly=%s). Every agent with access to this plugin can execute SQL against the Lakebase database as the service principal.", + config.exposeAsAgentTool.readOnly !== false, + ); + } + } + + private buildQueryTool( + opt: NonNullable, + ) { + const readOnly = opt.readOnly !== false; + return defineTool({ + description: readOnly + ? "Execute a read-only SQL query against the Lakebase PostgreSQL database. Only SELECT, WITH, SHOW, EXPLAIN, and DESCRIBE statements are accepted. Use $1, $2, etc. as placeholders and pass values separately. Runs as the application's service principal." + : "Execute a parameterized SQL statement against the Lakebase PostgreSQL database. Use $1, $2, etc. as placeholders and pass values separately. Runs as the application's service principal. This tool can modify data; every invocation requires explicit human approval.", + schema: z.object({ + text: z + .string() + .describe( + "SQL statement with $1, $2, ... 
placeholders for parameters", + ), + values: z + .array(z.unknown()) + .optional() + .describe("Parameter values corresponding to placeholders"), + }), + annotations: { + readOnly, + destructive: !readOnly, + idempotent: false, + }, + handler: async (args) => { + if (readOnly) { + assertReadOnlySql(args.text); + return this.runReadOnlyStatement(args.text, args.values); + } + const result = await this.query(args.text, args.values); + return result.rows; + }, + }); + } + + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, this.tools, opts); + } + exports() { return { // biome-ignore lint/style/noNonNullAssertion: pool is guaranteed non-null after setup(), which AppKit always awaits before exposing the plugin API diff --git a/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts b/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts new file mode 100644 index 00000000..8e59fb32 --- /dev/null +++ b/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts @@ -0,0 +1,238 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; + +/** + * Tests the agent-tool surface of the Lakebase plugin. + * + * The plugin defaults to **not** exposing an agent tool at all. Enabling the + * tool is an explicit opt-in (`exposeAsAgentTool` with an acknowledgement + * flag) because every invocation runs with the application's service- + * principal credentials regardless of which end user initiated the request. 
+ */ + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise) => + fn(), + ), + generateKey: vi.fn(() => "test-key"), + })), + }, +})); + +// Client calls recorded by the read-only-statement test. The `connect()` +// mock returns a fresh client whose `query` pushes to this array so tests +// can assert the exact sequence of statements emitted on the dedicated +// connection. +const clientQueries: Array<{ text: string; values?: unknown[] }> = []; +const clientReleases: number[] = []; + +vi.mock("../../../connectors/lakebase", () => ({ + createLakebasePool: vi.fn(() => ({ + query: vi.fn(), + connect: vi.fn(async () => { + let releaseCalls = 0; + return { + query: vi.fn(async (text: string, values?: unknown[]) => { + clientQueries.push({ text, values }); + return { rows: [{ n: 1 }] }; + }), + release: vi.fn(() => { + releaseCalls += 1; + clientReleases.push(releaseCalls); + }), + }; + }), + end: vi.fn(), + })), + getLakebaseOrmConfig: vi.fn(() => ({})), + getLakebasePgConfig: vi.fn(() => ({})), + getUsernameWithApiLookup: vi.fn(async () => "test-user"), +})); + +import { LakebasePlugin } from "../lakebase"; + +function makePlugin( + config: ConstructorParameters[0], +): LakebasePlugin { + return new LakebasePlugin(config); +} + +describe("LakebasePlugin — agent tool opt-in", () => { + test("does not register an agent tool by default", () => { + const plugin = makePlugin({}); + expect(plugin.getAgentTools()).toEqual([]); + }); + + test("does not register a tool when `pool` is set but `exposeAsAgentTool` is absent", () => { + const plugin = makePlugin({ pool: {} }); + expect(plugin.getAgentTools()).toEqual([]); + }); + + test("throws when exposeAsAgentTool is set without the acknowledgement flag", () => { + expect(() => + makePlugin({ + exposeAsAgentTool: + // biome-ignore lint/suspicious/noExplicitAny: intentionally bypass 
the required flag for the negative case + {} as any, + }), + ).toThrow(/iUnderstandRunsAsServicePrincipal/); + }); + + test("registers a read-only tool when opted in with defaults", () => { + const plugin = makePlugin({ + exposeAsAgentTool: { iUnderstandRunsAsServicePrincipal: true }, + }); + const defs = plugin.getAgentTools(); + expect(defs).toHaveLength(1); + expect(defs[0].name).toBe("query"); + expect(defs[0].annotations).toEqual({ + readOnly: true, + destructive: false, + idempotent: false, + }); + }); + + test("registers a destructive tool when readOnly: false is explicit", () => { + const plugin = makePlugin({ + exposeAsAgentTool: { + iUnderstandRunsAsServicePrincipal: true, + readOnly: false, + }, + }); + const defs = plugin.getAgentTools(); + expect(defs[0].annotations).toEqual({ + readOnly: false, + destructive: true, + idempotent: false, + }); + }); +}); + +describe("LakebasePlugin — readOnly enforcement", () => { + let plugin: LakebasePlugin; + + beforeEach(async () => { + clientQueries.length = 0; + clientReleases.length = 0; + plugin = makePlugin({ + exposeAsAgentTool: { iUnderstandRunsAsServicePrincipal: true }, + }); + await plugin.setup(); + }); + + test("rejects DROP before acquiring a client", async () => { + await expect( + plugin.executeAgentTool("query", { text: "DROP TABLE users" }), + ).rejects.toThrow(/read-only policy violation/i); + expect(clientQueries).toHaveLength(0); + }); + + test("rejects UPDATE, DELETE, INSERT", async () => { + for (const text of [ + "UPDATE users SET email='x'", + "DELETE FROM orders", + "INSERT INTO x VALUES (1)", + ]) { + await expect(plugin.executeAgentTool("query", { text })).rejects.toThrow( + /read-only policy violation/i, + ); + } + expect(clientQueries).toHaveLength(0); + }); + + test("runs SELECT inside BEGIN READ ONLY / ROLLBACK on a dedicated client", async () => { + const rows = await plugin.executeAgentTool("query", { + text: "SELECT * FROM users", + }); + expect(rows).toEqual([{ n: 1 }]); + 
expect(clientQueries.map((c) => c.text)).toEqual([ + "BEGIN READ ONLY", + "SELECT * FROM users", + "ROLLBACK", + ]); + // Client must be released exactly once, regardless of outcome. + expect(clientReleases).toHaveLength(1); + }); + + test("forwards parameter values to the user statement only (the regression fix)", async () => { + // Prior to the fix this would have failed with "cannot insert multiple + // commands into a prepared statement" because pg's Extended Query + // protocol rejects multi-statement batches when values are supplied. + await plugin.executeAgentTool("query", { + text: "SELECT * FROM users WHERE id = $1", + values: [42], + }); + expect(clientQueries).toEqual([ + { text: "BEGIN READ ONLY", values: undefined }, + { text: "SELECT * FROM users WHERE id = $1", values: [42] }, + { text: "ROLLBACK", values: undefined }, + ]); + }); + + test("releases the client even when the user statement throws", async () => { + // Poison the client so the middle query throws (simulates a Postgres + // error like "cannot execute UPDATE in a read-only transaction"). 
+ const { createLakebasePool } = await import("../../../connectors/lakebase"); + const connect = vi.fn(async () => ({ + query: vi + .fn() + .mockResolvedValueOnce({ rows: [] }) + .mockRejectedValueOnce(new Error("read-only violation")) + .mockResolvedValueOnce({ rows: [] }), + release: vi.fn(() => { + clientReleases.push(clientReleases.length + 1); + }), + })); + // biome-ignore lint/suspicious/noExplicitAny: test override + ( + createLakebasePool as unknown as { mockReturnValueOnce: any } + ).mockReturnValueOnce({ query: vi.fn(), connect, end: vi.fn() }); + + clientQueries.length = 0; + clientReleases.length = 0; + const leakyPlugin = makePlugin({ + exposeAsAgentTool: { iUnderstandRunsAsServicePrincipal: true }, + }); + await leakyPlugin.setup(); + + await expect( + leakyPlugin.executeAgentTool("query", { + text: "SELECT * FROM users", + }), + ).rejects.toThrow(/read-only violation/); + expect(clientReleases).toHaveLength(1); + }); +}); + +describe("LakebasePlugin — destructive mode", () => { + test("does NOT wrap in read-only transaction when readOnly: false", async () => { + const queryMock = vi.fn((_text: string, _values?: unknown[]) => + Promise.resolve({ rows: [] }), + ); + const plugin = makePlugin({ + exposeAsAgentTool: { + iUnderstandRunsAsServicePrincipal: true, + readOnly: false, + }, + }); + await plugin.setup(); + vi.spyOn(plugin, "query").mockImplementation(async (text, values) => { + queryMock(text, values); + return { rows: [] } as never; + }); + + await plugin.executeAgentTool("query", { + text: "UPDATE t SET x=1 WHERE id=$1", + values: [42], + }); + + expect(queryMock).toHaveBeenCalledWith( + "UPDATE t SET x=1 WHERE id=$1", + [42], + ); + }); +}); diff --git a/packages/appkit/src/plugins/lakebase/types.ts b/packages/appkit/src/plugins/lakebase/types.ts index ac6997c6..6703e425 100644 --- a/packages/appkit/src/plugins/lakebase/types.ts +++ b/packages/appkit/src/plugins/lakebase/types.ts @@ -1,6 +1,42 @@ import type { BasePluginConfig } from 
"shared"; import type { LakebasePoolConfig } from "../../connectors/lakebase"; +/** + * Opt-in configuration for exposing Lakebase as an agent-callable SQL tool. + * + * This tool executes LLM-authored SQL against the Lakebase pool. The pool is + * **always bound to the application's service-principal credentials**, so any + * agent that can call this tool effectively has full SP access to the database + * regardless of which end user initiated the request. Exposing it is a + * deliberate decision the developer must make explicitly — hence the required + * acknowledgement flag. + * + * When `readOnly: true` (default when opted in), every statement is: + * 1. Classified by {@link @databricks/appkit's sql-policy classifier}; anything + * that isn't a pure `SELECT`/`WITH`/`SHOW`/`EXPLAIN`/`DESCRIBE` is rejected. + * 2. Executed inside a `BEGIN READ ONLY … ROLLBACK` transaction so the + * PostgreSQL server rejects writes that slip past the classifier (e.g., a + * `SELECT` over a function with side effects). + * + * When `readOnly: false`, the tool is annotated `destructive: true` and the + * agents plugin will require human approval for every invocation (see + * `AgentsPluginConfig.approval`). + */ +export interface LakebaseExposeAsAgentTool { + /** + * Required acknowledgement that tool invocations run as the service principal + * and share that privilege across end users. Must be set to `true` to opt in. + */ + iUnderstandRunsAsServicePrincipal: true; + /** + * Enforce read-only execution. Defaults to `true`. Set to `false` to allow + * destructive statements — highly discouraged outside of tightly controlled + * single-user deployments. Combined with the `destructive: true` annotation, + * the agents plugin will require explicit human approval for each call. + */ + readOnly?: boolean; +} + /** * Configuration for the Lakebase plugin. 
* @@ -17,4 +53,11 @@ export interface ILakebaseConfig extends BasePluginConfig { * Common overrides: `max` (pool size), `connectionTimeoutMillis`, `idleTimeoutMillis`. */ pool?: Partial; + /** + * Opt-in to expose Lakebase as an agent-callable SQL tool. By default no + * agent tool is registered — the Lakebase plugin only exposes its API to + * application code. See {@link LakebaseExposeAsAgentTool} for the privilege + * implications of enabling this. + */ + exposeAsAgentTool?: LakebaseExposeAsAgentTool; } diff --git a/packages/appkit/src/plugins/server/index.ts b/packages/appkit/src/plugins/server/index.ts index e7b9b31a..cc58cc0d 100644 --- a/packages/appkit/src/plugins/server/index.ts +++ b/packages/appkit/src/plugins/server/index.ts @@ -27,17 +27,23 @@ const logger = createLogger("server"); * This plugin is responsible for starting the server and serving the static files. * It also handles the remote tunneling for development purposes. * + * The server is started automatically by `createApp` after all plugins are set up + * and the optional `onPluginsReady` callback has run. + * * @example * ```ts * createApp({ - * plugins: [server(), telemetryExamples(), analytics({})], + * plugins: [server(), analytics({})], + * onPluginsReady(appkit) { + * appkit.server.extend((app) => { + * app.get("/custom", (_req, res) => res.json({ ok: true })); + * }); + * }, * }); * ``` - * */ export class ServerPlugin extends Plugin { public static DEFAULT_CONFIG = { - autoStart: true, host: process.env.FLASK_RUN_HOST || "0.0.0.0", port: Number(process.env.DATABRICKS_APP_PORT) || 8000, }; @@ -54,23 +60,30 @@ export class ServerPlugin extends Plugin { static phase: PluginPhase = "deferred"; constructor(config: ServerConfig) { + if ("autoStart" in config) { + throw new ServerError( + "server({ autoStart }) has been removed. 
" + + "The server is now started automatically by createApp.\n\n" + + "Run `npx appkit codemod on-plugins-ready --write` to auto-migrate.", + ); + } super(config); this.config = config; this.serverApplication = express(); this.server = null; this.serverExtensions = []; + } + + attachContext(deps: Parameters[0] = {}): void { + super.attachContext(deps); this.telemetry.registerInstrumentations([ instrumentations.http, instrumentations.express, ]); + this.context?.registerAsRouteTarget(this); } - /** Setup the server plugin. */ - async setup() { - if (this.shouldAutoStart()) { - await this.start(); - } - } + async setup() {} /** Get the server configuration. */ getConfig() { @@ -79,11 +92,6 @@ export class ServerPlugin extends Plugin { return config; } - /** Check if the server should auto start. */ - shouldAutoStart() { - return this.config.autoStart; - } - /** * Start the server. * @@ -148,14 +156,10 @@ export class ServerPlugin extends Plugin { * * Only use this method if you need to access the server instance for advanced usage like a custom websocket server, etc. * - * @throws {Error} If the server is not started or autoStart is true. + * @throws {Error} If the server has not started yet. * @returns {HTTPServer} The server instance. */ getServer(): HTTPServer { - if (this.shouldAutoStart()) { - throw ServerError.autoStartConflict("get server"); - } - if (!this.server) { throw ServerError.notStarted(); } @@ -166,19 +170,27 @@ export class ServerPlugin extends Plugin { /** * Extend the server with custom routes or middleware. * + * Call this inside the `onPluginsReady` callback of `createApp` to register + * custom Express routes or middleware before the server starts listening. + * * @param fn - A function that receives the express application. * @returns The server plugin instance for chaining. - * @throws {Error} If autoStart is true. 
*/ extend(fn: (app: express.Application) => void) { - if (this.shouldAutoStart()) { - throw ServerError.autoStartConflict("extend server"); - } - this.serverExtensions.push(fn); return this; } + /** + * Register a server extension from another plugin during setup. + * Unlike extend(), this does not guard on autoStart — it's designed + * for internal plugin-to-plugin coordination where extensions are + * registered before the server starts listening. + */ + addExtension(fn: (app: express.Application) => void) { + this.serverExtensions.push(fn); + } + /** * Setup the routes with the plugins. * @@ -193,14 +205,15 @@ export class ServerPlugin extends Plugin { const endpoints: PluginEndpoints = {}; const pluginConfigs: PluginClientConfigs = {}; - if (!this.config.plugins) return { endpoints, pluginConfigs }; + const plugins = this.context?.getPlugins(); + if (!plugins || plugins.size === 0) return { endpoints, pluginConfigs }; this.serverApplication.get("/health", (_, res) => { res.status(200).json({ status: "ok" }); }); this.registerEndpoint("health", "/health"); - for (const plugin of Object.values(this.config.plugins)) { + for (const plugin of plugins.values()) { if (EXCLUDED_PLUGINS.includes(plugin.name)) continue; if (plugin?.injectRoutes && typeof plugin.injectRoutes === "function") { @@ -349,8 +362,9 @@ export class ServerPlugin extends Plugin { } // 1. 
abort active operations from plugins - if (this.config.plugins) { - for (const plugin of Object.values(this.config.plugins)) { + const shutdownPlugins = this.context?.getPlugins(); + if (shutdownPlugins) { + for (const plugin of shutdownPlugins.values()) { if (plugin.abortActiveOperations) { try { plugin.abortActiveOperations(); @@ -389,8 +403,6 @@ export class ServerPlugin extends Plugin { exports() { const self = this; return { - /** Start the server */ - start: this.start, /** Extend the server with custom routes or middleware */ extend(fn: (app: express.Application) => void) { self.extend(fn); @@ -400,6 +412,19 @@ export class ServerPlugin extends Plugin { getServer: this.getServer, /** Get the server configuration */ getConfig: this.getConfig, + /** @deprecated Server is now started automatically by createApp. */ + start() { + throw new ServerError( + "server.start() has been removed. Use the onPluginsReady callback instead:\n\n" + + " createApp({\n" + + " plugins: [server(), ...],\n" + + " onPluginsReady(appkit) {\n" + + " appkit.server.extend(...);\n" + + " },\n" + + " });\n\n" + + "Run `npx appkit codemod on-plugins-ready --write` to auto-migrate.", + ); + }, }; } } diff --git a/packages/appkit/src/plugins/server/manifest.json b/packages/appkit/src/plugins/server/manifest.json index 11822beb..1112fbf5 100644 --- a/packages/appkit/src/plugins/server/manifest.json +++ b/packages/appkit/src/plugins/server/manifest.json @@ -11,11 +11,6 @@ "schema": { "type": "object", "properties": { - "autoStart": { - "type": "boolean", - "default": true, - "description": "Automatically start the server on plugin setup" - }, "host": { "type": "string", "default": "0.0.0.0", diff --git a/packages/appkit/src/plugins/server/tests/server.integration.test.ts b/packages/appkit/src/plugins/server/tests/server.integration.test.ts index c3a646ea..0b67e4c0 100644 --- a/packages/appkit/src/plugins/server/tests/server.integration.test.ts +++ 
b/packages/appkit/src/plugins/server/tests/server.integration.test.ts @@ -29,13 +29,10 @@ describe("ServerPlugin Integration", () => { serverPlugin({ port: TEST_PORT, host: "127.0.0.1", - autoStart: false, }), ], }); - // Start server manually - await app.server.start(); server = app.server.getServer(); baseUrl = `http://127.0.0.1:${TEST_PORT}`; @@ -124,13 +121,11 @@ describe("ServerPlugin with custom plugin", () => { serverPlugin({ port: TEST_PORT, host: "127.0.0.1", - autoStart: false, }), testPlugin({}), ], }); - await app.server.start(); server = app.server.getServer(); baseUrl = `http://127.0.0.1:${TEST_PORT}`; @@ -172,7 +167,7 @@ describe("ServerPlugin with custom plugin", () => { }); }); -describe("ServerPlugin with extend()", () => { +describe("ServerPlugin with extend() via onPluginsReady", () => { let server: Server; let baseUrl: string; let serviceContextMock: Awaited>; @@ -188,19 +183,73 @@ describe("ServerPlugin with extend()", () => { serverPlugin({ port: TEST_PORT, host: "127.0.0.1", - autoStart: false, }), ], + onPluginsReady(appkit) { + appkit.server.extend((expressApp) => { + expressApp.get("/custom", (_req, res) => { + res.json({ custom: true }); + }); + }); + }, }); - // Add custom route via extend() - app.server.extend((expressApp) => { - expressApp.get("/custom", (_req, res) => { - res.json({ custom: true }); + server = app.server.getServer(); + baseUrl = `http://127.0.0.1:${TEST_PORT}`; + + await new Promise((resolve) => setTimeout(resolve, 100)); + }); + + afterAll(async () => { + serviceContextMock?.restore(); + if (server) { + await new Promise((resolve, reject) => { + server.close((err) => { + if (err) reject(err); + else resolve(); + }); }); + } + }); + + test("custom route via extend() in onPluginsReady callback works", async () => { + const response = await fetch(`${baseUrl}/custom`); + + expect(response.status).toBe(200); + + const data = await response.json(); + expect(data).toEqual({ custom: true }); + }); +}); + 
+describe("createApp with async onPluginsReady callback", () => { + let server: Server; + let baseUrl: string; + let serviceContextMock: Awaited>; + const TEST_PORT = 9885; + + beforeAll(async () => { + setupDatabricksEnv(); + ServiceContext.reset(); + serviceContextMock = await mockServiceContext(); + + const app = await createApp({ + plugins: [ + serverPlugin({ + port: TEST_PORT, + host: "127.0.0.1", + }), + ], + async onPluginsReady(appkit) { + await new Promise((resolve) => setTimeout(resolve, 10)); + appkit.server.extend((expressApp) => { + expressApp.get("/async-custom", (_req, res) => { + res.json({ asyncSetup: true }); + }); + }); + }, }); - await app.server.start(); server = app.server.getServer(); baseUrl = `http://127.0.0.1:${TEST_PORT}`; @@ -219,12 +268,38 @@ describe("ServerPlugin with extend()", () => { } }); - test("custom route via extend() works", async () => { - const response = await fetch(`${baseUrl}/custom`); + test("async onPluginsReady callback runs before server starts", async () => { + const response = await fetch(`${baseUrl}/async-custom`); expect(response.status).toBe(200); const data = await response.json(); - expect(data).toEqual({ custom: true }); + expect(data).toEqual({ asyncSetup: true }); + }); +}); + +describe("createApp without server plugin", () => { + let serviceContextMock: Awaited>; + let onPluginsReadyWasCalled = false; + + beforeAll(async () => { + setupDatabricksEnv(); + ServiceContext.reset(); + serviceContextMock = await mockServiceContext(); + + await createApp({ + plugins: [], + onPluginsReady() { + onPluginsReadyWasCalled = true; + }, + }); + }); + + afterAll(async () => { + serviceContextMock?.restore(); + }); + + test("onPluginsReady callback is still called without server plugin", () => { + expect(onPluginsReadyWasCalled).toBe(true); }); }); diff --git a/packages/appkit/src/plugins/server/tests/server.test.ts b/packages/appkit/src/plugins/server/tests/server.test.ts index 22f18129..76088348 100644 --- 
a/packages/appkit/src/plugins/server/tests/server.test.ts +++ b/packages/appkit/src/plugins/server/tests/server.test.ts @@ -1,4 +1,6 @@ +import type { BasePlugin } from "shared"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { PluginContext } from "../../../core/plugin-context"; // Use vi.hoisted for mocks that need to be available before module loading const { @@ -171,6 +173,14 @@ import { RemoteTunnelController } from "../remote-tunnel/remote-tunnel-controlle import { StaticServer } from "../static-server"; import { ViteDevServer } from "../vite-dev-server"; +function createContextWithPlugins(plugins: Record): PluginContext { + const ctx = new PluginContext(); + for (const [name, instance] of Object.entries(plugins)) { + ctx.registerPlugin(name, instance as BasePlugin); + } + return ctx; +} + describe("ServerPlugin", () => { let originalEnv: NodeJS.ProcessEnv; @@ -197,19 +207,22 @@ describe("ServerPlugin", () => { const plugin = new ServerPlugin({ port: 3000, host: "127.0.0.1", - autoStart: false, }); const config = plugin.getConfig(); expect(config.port).toBe(3000); expect(config.host).toBe("127.0.0.1"); - expect(config.autoStart).toBe(false); + }); + + test("should throw when autoStart is passed", () => { + expect(() => new ServerPlugin({ autoStart: false } as any)).toThrow( + "server({ autoStart }) has been removed", + ); }); }); describe("DEFAULT_CONFIG", () => { test("should have correct default values", () => { - expect(ServerPlugin.DEFAULT_CONFIG.autoStart).toBe(true); expect(ServerPlugin.DEFAULT_CONFIG.host).toBe("0.0.0.0"); expect(ServerPlugin.DEFAULT_CONFIG.port).toBe(8000); }); @@ -220,30 +233,9 @@ describe("ServerPlugin", () => { }); }); - describe("shouldAutoStart", () => { - test("should return true when autoStart is true", () => { - const plugin = new ServerPlugin({ autoStart: true }); - expect(plugin.shouldAutoStart()).toBe(true); - }); - - test("should return false when autoStart is false", () => { - const 
plugin = new ServerPlugin({ autoStart: false }); - expect(plugin.shouldAutoStart()).toBe(false); - }); - }); - describe("setup", () => { - test("should call start when autoStart is true", async () => { - const plugin = new ServerPlugin({ autoStart: true }); - const startSpy = vi.spyOn(plugin, "start").mockResolvedValue({} as any); - - await plugin.setup(); - - expect(startSpy).toHaveBeenCalled(); - }); - - test("should not call start when autoStart is false", async () => { - const plugin = new ServerPlugin({ autoStart: false }); + test("should be a no-op (server start is orchestrated by createApp)", async () => { + const plugin = new ServerPlugin({}); const startSpy = vi.spyOn(plugin, "start").mockResolvedValue({} as any); await plugin.setup(); @@ -254,7 +246,7 @@ describe("ServerPlugin", () => { describe("start", () => { test("should call listen on express app", async () => { - const plugin = new ServerPlugin({ autoStart: false, port: 3000 }); + const plugin = new ServerPlugin({ port: 3000 }); await plugin.start(); @@ -267,7 +259,7 @@ describe("ServerPlugin", () => { test("should setup ViteDevServer in development mode", async () => { process.env.NODE_ENV = "development"; - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); await plugin.start(); @@ -277,7 +269,7 @@ describe("ServerPlugin", () => { }); test("should register RemoteTunnelController middleware and set server", async () => { - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); await plugin.start(); @@ -304,7 +296,7 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ plugins }); await plugin.start(); // Get the type function passed to express.json @@ -340,7 +332,7 @@ describe("ServerPlugin", () => { process.env.NODE_ENV = "production"; const injectRoutes = vi.fn(); - const plugins: any = { + const testPlugins: any = { 
"test-plugin": { name: "test-plugin", injectRoutes, @@ -348,7 +340,9 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + context: createContextWithPlugins(testPlugins), + } as any); await plugin.start(); const routerFn = (express as any).Router as ReturnType; @@ -386,7 +380,9 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + context: createContextWithPlugins(plugins), + } as any); await plugin.start(); expect(plugins["plugin-a"].clientConfig).toHaveBeenCalled(); @@ -413,7 +409,9 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + context: createContextWithPlugins(plugins), + } as any); await plugin.start(); expect(plugins["plugin-null"].clientConfig).toHaveBeenCalled(); @@ -444,7 +442,9 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + context: createContextWithPlugins(plugins), + } as any); await expect(plugin.start()).resolves.toBeDefined(); expect(mockLoggerError).toHaveBeenCalledWith( "Plugin '%s' clientConfig() failed, skipping its config: %O", @@ -457,7 +457,7 @@ describe("ServerPlugin", () => { process.env.NODE_ENV = "production"; vi.mocked(fs.existsSync).mockReturnValue(true); - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); await plugin.start(); @@ -470,7 +470,7 @@ describe("ServerPlugin", () => { process.env.NODE_ENV = "production"; vi.mocked(fs.existsSync).mockReturnValue(false); - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); await plugin.start(); @@ -479,8 +479,8 @@ describe("ServerPlugin", () => { }); describe("extend", () => { - test("should add extension function when autoStart is 
false", () => { - const plugin = new ServerPlugin({ autoStart: false }); + test("should add extension function and return plugin for chaining", () => { + const plugin = new ServerPlugin({}); const extensionFn = vi.fn(); const result = plugin.extend(extensionFn); @@ -488,17 +488,8 @@ describe("ServerPlugin", () => { expect(result).toBe(plugin); }); - test("should throw when autoStart is true", () => { - const plugin = new ServerPlugin({ autoStart: true }); - const extensionFn = vi.fn(); - - expect(() => plugin.extend(extensionFn)).toThrow( - "Cannot extend server when autoStart is true", - ); - }); - test("should call extension functions during start", async () => { - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); const extensionFn = vi.fn(); plugin.extend(extensionFn); @@ -508,17 +499,18 @@ describe("ServerPlugin", () => { }); }); - describe("getServer", () => { - test("should throw when autoStart is true", () => { - const plugin = new ServerPlugin({ autoStart: true }); + describe("exports().start() trap", () => { + test("should throw migration error when start() is called via exports", () => { + const plugin = new ServerPlugin({}); + const exported = plugin.exports(); - expect(() => plugin.getServer()).toThrow( - "Cannot get server when autoStart is true", - ); + expect(() => exported.start()).toThrow("server.start() has been removed"); }); + }); + describe("getServer", () => { test("should throw when server not started", () => { - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); expect(() => plugin.getServer()).toThrow( "Server not started. 
Please start the server first by calling the start() method", @@ -526,7 +518,7 @@ describe("ServerPlugin", () => { }); test("should return server after start", async () => { - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); await plugin.start(); const server = plugin.getServer(); @@ -553,7 +545,7 @@ describe("ServerPlugin", () => { describe("logStartupInfo", () => { test("logs remote tunnel controller disabled when missing", () => { mockLoggerDebug.mockClear(); - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); (plugin as any).remoteTunnelController = undefined; (plugin as any).logStartupInfo(); @@ -565,7 +557,7 @@ describe("ServerPlugin", () => { test("logs remote tunnel allowed/active when controller present", () => { mockLoggerDebug.mockClear(); - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); (plugin as any).remoteTunnelController = { isAllowedByEnv: () => true, isActive: () => true, @@ -607,20 +599,19 @@ describe("ServerPlugin", () => { .mockImplementation(((_code?: number) => undefined) as any); const plugin = new ServerPlugin({ - autoStart: false, - plugins: { + context: createContextWithPlugins({ ok: { name: "ok", abortActiveOperations: vi.fn(), - } as any, + }, bad: { name: "bad", abortActiveOperations: vi.fn(() => { throw new Error("boom"); }), - } as any, - }, - }); + }, + }), + } as any); // pretend started (plugin as any).server = mockHttpServer; diff --git a/packages/appkit/src/plugins/server/types.ts b/packages/appkit/src/plugins/server/types.ts index e187cacc..f9f6ebce 100644 --- a/packages/appkit/src/plugins/server/types.ts +++ b/packages/appkit/src/plugins/server/types.ts @@ -5,6 +5,5 @@ export interface ServerConfig extends BasePluginConfig { port?: number; plugins?: Record; staticPath?: string; - autoStart?: boolean; host?: string; } diff --git a/packages/appkit/tsdown.config.ts 
b/packages/appkit/tsdown.config.ts index 97698714..0e6a4b6b 100644 --- a/packages/appkit/tsdown.config.ts +++ b/packages/appkit/tsdown.config.ts @@ -4,7 +4,12 @@ export default defineConfig([ { publint: true, name: "@databricks/appkit", - entry: "src/index.ts", + entry: [ + "src/index.ts", + "src/agents/vercel-ai.ts", + "src/agents/langchain.ts", + "src/agents/databricks.ts", + ], outDir: "dist", hash: false, format: "esm", diff --git a/packages/shared/src/agent.ts b/packages/shared/src/agent.ts new file mode 100644 index 00000000..5e22126b --- /dev/null +++ b/packages/shared/src/agent.ts @@ -0,0 +1,276 @@ +import type { JSONSchema7 } from "json-schema"; + +// --------------------------------------------------------------------------- +// Tool definitions +// --------------------------------------------------------------------------- + +/** + * Semantic hint for what the tool does to the world. Drives both the + * agents-plugin approval gate and the client's approval-card styling. + * + * - `read` — observes only; never needs approval. + * - `write` — creates or appends new state (e.g. saving a new view). Approval + * required by default. Rendered as a low-severity "writes" card. + * - `update` — mutates existing state in place (e.g. renaming, toggling). + * Approval required. Rendered as a medium-severity "updates" card. + * - `destructive` — deletes or irreversibly mutates (e.g. dropping a view). + * Approval required. Rendered as a high-severity "destructive" card. + * + * Prefer this over the legacy `readOnly`/`destructive` booleans: it lets the + * UI distinguish "captured a screenshot" from "deleted a dashboard", both of + * which today are lumped under a single red "destructive" label. + */ +export type ToolEffect = "read" | "write" | "update" | "destructive"; + +export interface ToolAnnotations { + /** + * Preferred semantic label. When set, drives both the approval gate (fires + * for `write`/`update`/`destructive`) and the approval-card styling. 
+ */ + effect?: ToolEffect; + /** + * @deprecated Prefer {@link effect}. Retained for backward compatibility + * with tools authored against the original flags and for MCP interop. + */ + readOnly?: boolean; + /** + * @deprecated Prefer {@link effect} with value `"destructive"`. Retained + * so existing annotations continue to force the approval gate, and so + * MCP-style consumers that only read `destructive` still see the hint. + */ + destructive?: boolean; + idempotent?: boolean; + requiresUserContext?: boolean; +} + +export interface AgentToolDefinition { + name: string; + description: string; + parameters: JSONSchema7; + annotations?: ToolAnnotations; +} + +export interface ToolProvider { + getAgentTools(): AgentToolDefinition[]; + executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise; +} + +// --------------------------------------------------------------------------- +// Messages & threads +// --------------------------------------------------------------------------- + +export interface Message { + id: string; + role: "user" | "assistant" | "system" | "tool"; + content: string; + toolCallId?: string; + toolCalls?: ToolCall[]; + createdAt: Date; +} + +export interface ToolCall { + id: string; + name: string; + args: unknown; +} + +export interface Thread { + id: string; + userId: string; + messages: Message[]; + createdAt: Date; + updatedAt: Date; +} + +// --------------------------------------------------------------------------- +// Thread store +// --------------------------------------------------------------------------- + +export interface ThreadStore { + create(userId: string): Promise; + get(threadId: string, userId: string): Promise; + list(userId: string): Promise; + addMessage(threadId: string, userId: string, message: Message): Promise; + delete(threadId: string, userId: string): Promise; +} + +// --------------------------------------------------------------------------- +// Agent events (SSE protocol) +// 
--------------------------------------------------------------------------- + +export type AgentEvent = + | { type: "message_delta"; content: string } + | { type: "message"; content: string } + | { type: "tool_call"; callId: string; name: string; args: unknown } + | { + type: "tool_result"; + callId: string; + result: unknown; + error?: string; + } + | { type: "thinking"; content: string } + | { + type: "status"; + status: "running" | "waiting" | "complete" | "error"; + error?: string; + } + | { type: "metadata"; data: Record } + | { + /** + * Emitted by the agents plugin (not adapters) when a tool call annotated + * `destructive: true` is awaiting human approval. Clients should render + * an approval prompt and POST to `/chat/approve` with the matching + * `approvalId` and a `decision` of `approve` or `deny`. + */ + type: "approval_pending"; + approvalId: string; + streamId: string; + toolName: string; + args: unknown; + annotations?: ToolAnnotations; + }; + +// --------------------------------------------------------------------------- +// Responses API types (OpenAI-compatible wire format for HTTP boundary) +// Self-contained — no openai package dependency. 
+// --------------------------------------------------------------------------- + +export interface OutputTextContent { + type: "output_text"; + text: string; +} + +export interface ResponseOutputMessage { + type: "message"; + id: string; + status: "in_progress" | "completed"; + role: "assistant"; + content: OutputTextContent[]; +} + +export interface ResponseFunctionToolCall { + type: "function_call"; + id: string; + call_id: string; + name: string; + arguments: string; +} + +export interface ResponseFunctionCallOutput { + type: "function_call_output"; + id: string; + call_id: string; + output: string; +} + +export type ResponseOutputItem = + | ResponseOutputMessage + | ResponseFunctionToolCall + | ResponseFunctionCallOutput; + +export interface ResponseOutputItemAddedEvent { + type: "response.output_item.added"; + output_index: number; + item: ResponseOutputItem; + sequence_number: number; +} + +export interface ResponseOutputItemDoneEvent { + type: "response.output_item.done"; + output_index: number; + item: ResponseOutputItem; + sequence_number: number; +} + +export interface ResponseTextDeltaEvent { + type: "response.output_text.delta"; + item_id: string; + output_index: number; + content_index: number; + delta: string; + sequence_number: number; +} + +export interface ResponseCompletedEvent { + type: "response.completed"; + sequence_number: number; + response: Record; +} + +export interface ResponseErrorEvent { + type: "error"; + error: string; + sequence_number: number; +} + +export interface ResponseFailedEvent { + type: "response.failed"; + sequence_number: number; +} + +export interface AppKitThinkingEvent { + type: "appkit.thinking"; + content: string; + sequence_number: number; +} + +export interface AppKitMetadataEvent { + type: "appkit.metadata"; + data: Record; + sequence_number: number; +} + +/** + * Emitted when a destructive tool call is awaiting human approval. 
The client + * should render an approval UI and POST the decision to `/chat/approve` with + * `{ streamId, approvalId, decision: "approve" | "deny" }`. If no decision + * arrives before the server-side timeout, the call is auto-denied and the + * agent receives a denial string as the tool output. + */ +export interface AppKitApprovalPendingEvent { + type: "appkit.approval_pending"; + approval_id: string; + stream_id: string; + tool_name: string; + args: unknown; + annotations?: ToolAnnotations; + sequence_number: number; +} + +export type ResponseStreamEvent = + | ResponseOutputItemAddedEvent + | ResponseOutputItemDoneEvent + | ResponseTextDeltaEvent + | ResponseCompletedEvent + | ResponseErrorEvent + | ResponseFailedEvent + | AppKitThinkingEvent + | AppKitMetadataEvent + | AppKitApprovalPendingEvent; + +// --------------------------------------------------------------------------- +// Adapter contract +// --------------------------------------------------------------------------- + +export interface AgentInput { + messages: Message[]; + tools: AgentToolDefinition[]; + threadId: string; + signal?: AbortSignal; +} + +export interface AgentRunContext { + executeTool: (name: string, args: unknown) => Promise; + signal?: AbortSignal; +} + +export interface AgentAdapter { + run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator; +} diff --git a/packages/shared/src/cli/commands/codemod/index.ts b/packages/shared/src/cli/commands/codemod/index.ts new file mode 100644 index 00000000..2f9c160d --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/index.ts @@ -0,0 +1,17 @@ +import { Command } from "commander"; +import { onPluginsReadyCommand } from "./on-plugins-ready"; + +/** + * Parent command for codemod operations. 
+ * Subcommands: + * - on-plugins-ready: Migrate from autoStart/extend/start to onPluginsReady callback + */ +export const codemodCommand = new Command("codemod") + .description("Run codemods to migrate to newer AppKit APIs") + .addCommand(onPluginsReadyCommand) + .addHelpText( + "after", + ` +Examples: + $ appkit codemod on-plugins-ready --write`, + ); diff --git a/packages/shared/src/cli/commands/codemod/on-plugins-ready.ts b/packages/shared/src/cli/commands/codemod/on-plugins-ready.ts new file mode 100644 index 00000000..37faeefe --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/on-plugins-ready.ts @@ -0,0 +1,484 @@ +import fs from "node:fs"; +import path from "node:path"; +import { Lang, parse } from "@ast-grep/napi"; +import { Command } from "commander"; + +const SEARCH_DIRS = ["server", "src", "."]; +const CANDIDATE_NAMES = ["server.ts", "index.ts"]; +const SKIP_DIRS = new Set(["node_modules", "dist", "build", ".git"]); + +function findServerEntryFiles(rootDir: string): string[] { + const results: string[] = []; + + for (const dir of SEARCH_DIRS) { + const absDir = path.resolve(rootDir, dir); + if (!fs.existsSync(absDir)) continue; + + const files = + dir === "." + ? 
CANDIDATE_NAMES.map((n) => path.join(absDir, n)).filter(fs.existsSync) + : findTsFiles(absDir); + + for (const file of files) { + const content = fs.readFileSync(file, "utf-8"); + if ( + content.includes("createApp") && + content.includes("@databricks/appkit") + ) { + results.push(file); + } + } + } + + return [...new Set(results)]; +} + +function findTsFiles(dir: string, files: string[] = []): string[] { + let entries: fs.Dirent[]; + try { + entries = fs.readdirSync(dir, { withFileTypes: true }); + } catch { + return files; + } + + for (const entry of entries) { + const fullPath = path.join(dir, entry.name); + if (entry.isDirectory()) { + if (SKIP_DIRS.has(entry.name)) continue; + findTsFiles(fullPath, files); + } else if (entry.isFile() && entry.name.endsWith(".ts")) { + files.push(fullPath); + } + } + + return files; +} + +function isAlreadyMigrated(content: string): boolean { + const ast = parse(Lang.TypeScript, content); + const root = ast.root(); + return root.findAll("createApp({ $$$PROPS })").some((match) => { + const text = match.text(); + return /\bonPluginsReady\s*[(:]/.test(text); + }); +} + +/** + * Find the index of the matching closing delimiter for an opening one. + * Supports (), {}, and []. + */ +function findMatchingClose(content: string, openIdx: number): number { + const open = content[openIdx]; + const closeMap: Record = { + "(": ")", + "{": "}", + "[": "]", + }; + const close = closeMap[open]; + if (!close) return -1; + + let depth = 1; + let i = openIdx + 1; + while (i < content.length && depth > 0) { + const ch = content[i]; + if (ch === open) depth++; + else if (ch === close) depth--; + + // skip string literals + if (ch === '"' || ch === "'" || ch === "`") { + i = skipString(content, i); + continue; + } + i++; + } + return depth === 0 ? 
i - 1 : -1; +} + +function skipString(content: string, startIdx: number): number { + const quote = content[startIdx]; + let i = startIdx + 1; + while (i < content.length) { + if (content[i] === "\\") { + i += 2; + continue; + } + if (content[i] === quote) return i + 1; + i++; + } + return i; +} + +function stripAutoStartFromServerCalls(content: string): string { + return content.replace( + /server\(\{([^}]*)\}\)/g, + (_fullMatch, propsStr: string) => { + const cleaned = propsStr + .replace(/autoStart\s*:\s*(true|false)\s*,?\s*/g, "") + .replace(/,\s*$/, "") + .trim(); + if (!cleaned) return "server()"; + return `server({ ${cleaned} })`; + }, + ); +} + +interface MigrationResult { + migrated: boolean; + content: string; + warnings: string[]; +} + +function migratePatternA(content: string): MigrationResult { + const warnings: string[] = []; + + // Find createApp(...).then( + const createAppIdx = content.indexOf("createApp("); + if (createAppIdx === -1) return { migrated: false, content, warnings }; + + // Find the opening paren of createApp( + const configOpenParen = content.indexOf("(", createAppIdx); + const configCloseParen = findMatchingClose(content, configOpenParen); + if (configCloseParen === -1) return { migrated: false, content, warnings }; + + // Check for .then( after the closing paren + const afterCreateApp = content.slice(configCloseParen + 1); + const thenMatch = afterCreateApp.match(/^\s*\.then\s*\(/); + if (!thenMatch) return { migrated: false, content, warnings }; + + const thenStart = configCloseParen + 1 + afterCreateApp.indexOf(".then"); + const thenOpenParen = content.indexOf("(", thenStart + 4); + const thenCloseParen = findMatchingClose(content, thenOpenParen); + if (thenCloseParen === -1) return { migrated: false, content, warnings }; + + // Extract the callback inside .then(...) 
+ const thenRaw = content.slice(thenOpenParen + 1, thenCloseParen); + const thenInner = thenRaw.trim(); + + // Parse callback: (param) => { body } or async (param) => { body } + const callbackMatch = thenInner.match( + /^(?:async\s+)?\(\s*(\w+)\s*\)\s*=>\s*\{/, + ); + if (!callbackMatch) return { migrated: false, content, warnings }; + + const paramName = callbackMatch[1]; + const bodyOpenBrace = thenOpenParen + 1 + thenRaw.indexOf("{"); + const bodyCloseBrace = findMatchingClose(content, bodyOpenBrace); + if (bodyCloseBrace === -1) return { migrated: false, content, warnings }; + + let callbackBody = content.slice(bodyOpenBrace + 1, bodyCloseBrace).trim(); + + // Remove entire statements that are just .start() calls (e.g. `await appkit.server.start();`) + callbackBody = callbackBody + .replace(/^\s*(?:await\s+)?\w+\.server\s*\.\s*start\(\s*\)\s*;?\s*$/gm, "") + .replace(/\n\s*\.start\(\s*\)\s*;?/g, ";") + .replace(/\.start\(\s*\)/g, "") + .replace(/\n\s*\n\s*\n/g, "\n\n") + .trim(); + + // Clean up trailing semicolons + if (callbackBody.endsWith(";")) { + // fine + } else if (!callbackBody.endsWith("}")) { + callbackBody += ";"; + } + + // Detect if the callback was async + const isAsync = /^async\s/.test(thenInner.trim()); + + // Check for .catch() after .then(...) 
using brace-aware parsing + const afterThenClose = content.slice(thenCloseParen + 1); + const catchPatternMatch = afterThenClose.match(/^\s*(?:\)\s*)?\.catch\s*\(/); + + let catchSuffix: string; + let consumeAfterThen: number; + + if (catchPatternMatch) { + const catchOpenParen = thenCloseParen + 1 + catchPatternMatch[0].length - 1; + const catchCloseParen = findMatchingClose(content, catchOpenParen); + if (catchCloseParen !== -1) { + const catchArg = content + .slice(catchOpenParen + 1, catchCloseParen) + .trim(); + catchSuffix = `.catch(${catchArg})`; + consumeAfterThen = catchCloseParen + 1 - (thenCloseParen + 1); + } else { + catchSuffix = ".catch(console.error)"; + consumeAfterThen = 0; + } + } else { + catchSuffix = ".catch(console.error)"; + consumeAfterThen = 0; + } + + // Build the onPluginsReady property + const configStr = content.slice(configOpenParen + 1, configCloseParen); + const lastBraceIdx = configStr.lastIndexOf("}"); + if (lastBraceIdx === -1) return { migrated: false, content, warnings }; + + const beforeLastBrace = configStr.slice(0, lastBraceIdx).trimEnd(); + const needsComma = beforeLastBrace.endsWith(",") ? "" : ","; + + // Indent the body properly + const bodyLines = callbackBody.split("\n"); + const indentedBody = bodyLines + .map((line) => ` ${line.trimStart()}`) + .join("\n"); + + const asyncPrefix = isAsync ? 
"async " : ""; + const onPluginsReadyProp = `${needsComma}\n ${asyncPrefix}onPluginsReady(${paramName}) {\n${indentedBody}\n },`; + const newConfig = `${beforeLastBrace}${onPluginsReadyProp}\n}`; + + // Build the replacement + const endIdx = thenCloseParen + 1 + consumeAfterThen; + // Consume trailing ) ; and whitespace + let finalEnd = endIdx; + const trailing = content.slice(finalEnd).match(/^\s*\)?\s*;?\s*/); + if (trailing) finalEnd += trailing[0].length; + + const newContent = + content.slice(0, createAppIdx) + + `createApp(${newConfig})${catchSuffix};` + + content.slice(finalEnd); + + return { migrated: true, content: newContent, warnings }; +} + +function migratePatternB(content: string): MigrationResult { + const warnings: string[] = []; + + // Match: const/let varName = await createApp({...}); + const awaitPattern = /(?:const|let)\s+(\w+)\s*=\s*await\s+createApp\s*\(/; + + const match = content.match(awaitPattern); + if (!match) return { migrated: false, content, warnings }; + + const varName = match[1]; + const matchIdx = content.indexOf(match[0]); + + // Find the createApp(...) closing paren + const configOpenParen = matchIdx + match[0].length - 1; + const configCloseParen = findMatchingClose(content, configOpenParen); + if (configCloseParen === -1) return { migrated: false, content, warnings }; + + // Find the semicolon after the createApp call + const afterCall = content.slice(configCloseParen + 1); + const semiMatch = afterCall.match(/^\s*;/); + const createAppEnd = + configCloseParen + 1 + (semiMatch ? 
semiMatch[0].length : 0); + + // Find all uses of varName after the createApp call + const afterCreateApp = content.slice(createAppEnd); + const varUsagePattern = new RegExp(`\\b${varName}\\.(\\w+)`, "g"); + + const usages: { plugin: string; index: number }[] = []; + for (const usageMatch of afterCreateApp.matchAll(varUsagePattern)) { + usages.push({ plugin: usageMatch[1], index: usageMatch.index }); + } + + // Check for non-server usage + const nonServerUsage = usages.filter((u) => u.plugin !== "server"); + if (nonServerUsage.length > 0) { + warnings.push( + `Found additional usage of '${varName}' handle outside server.extend/start. Please migrate manually.`, + ); + return { migrated: false, content, warnings }; + } + + // Find the extend call(s) and start call in the after-createApp region + const extendPattern = new RegExp( + `\\b${varName}\\.server\\.extend\\s*\\(`, + "g", + ); + const startPattern = new RegExp( + `(?:await\\s+)?${varName}\\.server\\.start\\s*\\(\\s*\\)\\s*;`, + ); + + const extendMatches = [...afterCreateApp.matchAll(extendPattern)]; + if (extendMatches.length > 1) { + warnings.push( + `Found ${extendMatches.length} server.extend() calls. Please migrate manually.`, + ); + return { migrated: false, content, warnings }; + } + + const extendExec = extendMatches[0] ?? 
null; + const startExec = startPattern.exec(afterCreateApp); + + if (!startExec) return { migrated: false, content, warnings }; + + // Extract the extend call's argument + let extendArg = ""; + let extendFullStatement = ""; + if (extendExec) { + const extendOpenParen = + createAppEnd + extendExec.index + extendExec[0].length - 1; + const extendCloseParen = findMatchingClose(content, extendOpenParen); + if (extendCloseParen !== -1) { + extendArg = content.slice(extendOpenParen + 1, extendCloseParen).trim(); + // Find the full statement including trailing semicolon + const stmtStart = createAppEnd + extendExec.index; + let stmtEnd = extendCloseParen + 1; + const afterExtend = content.slice(stmtEnd); + const trailingSemi = afterExtend.match(/^\s*;/); + if (trailingSemi) stmtEnd += trailingSemi[0].length; + extendFullStatement = content.slice(stmtStart, stmtEnd); + } + } + + const startFullStatement = startExec[0]; + + // Build the onPluginsReady callback + const configStr = content.slice(configOpenParen + 1, configCloseParen); + const lastBraceIdx = configStr.lastIndexOf("}"); + if (lastBraceIdx === -1) return { migrated: false, content, warnings }; + + const beforeLastBrace = configStr.slice(0, lastBraceIdx).trimEnd(); + const needsComma = beforeLastBrace.endsWith(",") ? 
"" : ","; + + let onPluginsReadyProp: string; + if (extendArg) { + onPluginsReadyProp = + `${needsComma}\n onPluginsReady(${varName}) {\n` + + ` ${varName}.server.extend(${extendArg});\n` + + " },"; + } else { + onPluginsReadyProp = ""; + } + + const newConfig = `${beforeLastBrace}${onPluginsReadyProp}\n}`; + const newCreateApp = `await createApp(${newConfig});`; + + // Replace: remove const declaration, replace with plain await, remove extend + start + let result = content.slice(0, matchIdx) + newCreateApp; + let remaining = afterCreateApp; + + if (extendFullStatement) { + remaining = remaining.replace(extendFullStatement, ""); + } + remaining = remaining.replace(startFullStatement, ""); + + // Clean up consecutive blank lines + remaining = remaining.replace(/\n\s*\n\s*\n/g, "\n\n"); + + result += remaining; + + return { migrated: true, content: result, warnings }; +} + +export function migrateFile(filePath: string): MigrationResult { + const original = fs.readFileSync(filePath, "utf-8"); + + if (isAlreadyMigrated(original)) { + return { + migrated: false, + content: original, + warnings: ["Already migrated -- no changes needed."], + }; + } + + const content = stripAutoStartFromServerCalls(original); + const allWarnings: string[] = []; + + // Try Pattern A first + const patternA = migratePatternA(content); + if (patternA.migrated) { + allWarnings.push(...patternA.warnings); + return { + migrated: true, + content: patternA.content, + warnings: allWarnings, + }; + } + allWarnings.push(...patternA.warnings); + + // Try Pattern B + const patternB = migratePatternB(content); + if (patternB.migrated) { + allWarnings.push(...patternB.warnings); + return { + migrated: true, + content: patternB.content, + warnings: allWarnings, + }; + } + allWarnings.push(...patternB.warnings); + + // Check if autoStart was stripped (content changed but no pattern matched) + if (content !== original) { + return { migrated: true, content, warnings: allWarnings }; + } + + return { migrated: 
false, content: original, warnings: allWarnings }; +} + +function runCodemod(options: { path?: string; write?: boolean }) { + const rootDir = process.cwd(); + const write = options.write ?? false; + + let files: string[]; + if (options.path) { + const absPath = path.resolve(rootDir, options.path); + if (!fs.existsSync(absPath)) { + console.error(`File not found: ${absPath}`); + process.exit(1); + } + files = [absPath]; + } else { + files = findServerEntryFiles(rootDir); + } + + if (files.length === 0) { + console.log("No files found importing createApp from @databricks/appkit."); + console.log("Use --path to specify a file explicitly."); + process.exit(0); + } + + let hasChanges = false; + + for (const file of files) { + const relPath = path.relative(rootDir, file); + const result = migrateFile(file); + + for (const warning of result.warnings) { + console.log(` ${relPath}: ${warning}`); + } + + if (!result.migrated) { + if (result.warnings.length === 0) { + console.log(` ${relPath}: No migration needed.`); + } + continue; + } + + hasChanges = true; + + if (write) { + fs.writeFileSync(file, result.content, "utf-8"); + console.log(` ${relPath}: Migrated successfully.`); + } else { + console.log(`\n--- ${relPath} (dry run) ---`); + console.log(result.content); + console.log("---"); + } + } + + if (hasChanges && !write) { + console.log("\nDry run complete. 
Run with --write to apply changes."); + } +} + +export const onPluginsReadyCommand = new Command("on-plugins-ready") + .description( + "Migrate createApp usage from autoStart/extend/start pattern to onPluginsReady callback", + ) + .option("--path ", "Path to the server entry file to migrate") + .option("--write", "Apply changes (default: dry-run)", false) + .addHelpText( + "after", + ` +Examples: + $ appkit codemod on-plugins-ready # dry-run, auto-detect files + $ appkit codemod on-plugins-ready --write # apply changes + $ appkit codemod on-plugins-ready --path server.ts # migrate a specific file`, + ) + .action(runCodemod); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/already-migrated.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/already-migrated.input.ts new file mode 100644 index 00000000..ab1cf6d0 --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/already-migrated.input.ts @@ -0,0 +1,10 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +createApp({ + plugins: [server(), analytics({})], + onPluginsReady(appkit) { + appkit.server.extend((app) => { + app.get("/custom", (_req, res) => res.json({ ok: true })); + }); + }, +}).catch(console.error); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/autostart-true-with-port.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/autostart-true-with-port.input.ts new file mode 100644 index 00000000..96b70a4f --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/autostart-true-with-port.input.ts @@ -0,0 +1,5 @@ +import { createApp, server } from "@databricks/appkit"; + +createApp({ + plugins: [server({ autoStart: true, port: 3000 })], +}).catch(console.error); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-arrow-catch.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-arrow-catch.input.ts new file mode 100644 index 
00000000..b6ae8c8a --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-arrow-catch.input.ts @@ -0,0 +1,15 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}) + .then((appkit) => { + appkit.server + .extend((app) => { + app.get("/custom", (_req, res) => { + res.json({ ok: true }); + }); + }) + .start(); + }) + .catch((err) => console.error(err)); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-with-catch.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-with-catch.input.ts new file mode 100644 index 00000000..faa04d5e --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-with-catch.input.ts @@ -0,0 +1,15 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}) + .then((appkit) => { + appkit.server + .extend((app) => { + app.get("/custom", (_req, res) => { + res.json({ ok: true }); + }); + }) + .start(); + }) + .catch(console.error); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a.input.ts new file mode 100644 index 00000000..73523d6a --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a.input.ts @@ -0,0 +1,13 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}).then((appkit) => { + appkit.server + .extend((app) => { + app.get("/custom", (_req, res) => { + res.json({ ok: true }); + }); + }) + .start(); +}); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-complex.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-complex.input.ts new file mode 100644 index 00000000..c1fb25fa 
--- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-complex.input.ts @@ -0,0 +1,15 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +const appkit = await createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}); + +appkit.server.extend((app) => { + app.get("/custom", (_req, res) => { + res.json({ ok: true }); + }); +}); + +appkit.analytics.query("SELECT 1"); + +await appkit.server.start(); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-multi-extend.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-multi-extend.input.ts new file mode 100644 index 00000000..dded09f2 --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-multi-extend.input.ts @@ -0,0 +1,15 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +const appkit = await createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}); + +appkit.server.extend((app) => { + app.get("/one", (_req, res) => res.json({ route: 1 })); +}); + +appkit.server.extend((app) => { + app.get("/two", (_req, res) => res.json({ route: 2 })); +}); + +await appkit.server.start(); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b.input.ts new file mode 100644 index 00000000..b56c0048 --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b.input.ts @@ -0,0 +1,13 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +const appkit = await createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}); + +appkit.server.extend((app) => { + app.get("/custom", (_req, res) => { + res.json({ ok: true }); + }); +}); + +await appkit.server.start(); diff --git a/packages/shared/src/cli/commands/codemod/tests/on-plugins-ready.test.ts 
b/packages/shared/src/cli/commands/codemod/tests/on-plugins-ready.test.ts new file mode 100644 index 00000000..299e8f1d --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/on-plugins-ready.test.ts @@ -0,0 +1,129 @@ +import fs from "node:fs"; +import path from "node:path"; +import { describe, expect, test } from "vitest"; +import { migrateFile } from "../on-plugins-ready"; + +const fixturesDir = path.join(__dirname, "fixtures"); + +function readFixture(name: string): string { + return fs.readFileSync(path.join(fixturesDir, name), "utf-8"); +} + +describe("onPluginsReady-callback codemod", () => { + describe("Pattern A: .then() chain", () => { + test("migrates .then chain without .catch, adds .catch(console.error)", () => { + const fixturePath = path.join(fixturesDir, "pattern-a.input.ts"); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(true); + expect(result.content).toContain("onPluginsReady(appkit)"); + expect(result.content).not.toContain(".then("); + expect(result.content).not.toContain(".start()"); + expect(result.content).not.toContain("autoStart"); + expect(result.content).toContain(".catch(console.error)"); + expect(result.content).toContain("server()"); + }); + + test("migrates .then chain with existing .catch, preserves it", () => { + const fixturePath = path.join( + fixturesDir, + "pattern-a-with-catch.input.ts", + ); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(true); + expect(result.content).toContain("onPluginsReady(appkit)"); + expect(result.content).not.toContain(".then("); + expect(result.content).not.toContain(".start()"); + expect(result.content).toContain(".catch(console.error)"); + expect(result.content).toContain("server()"); + }); + + test("preserves the extend callback content", () => { + const fixturePath = path.join(fixturesDir, "pattern-a.input.ts"); + const result = migrateFile(fixturePath); + + expect(result.content).toContain('app.get("/custom"'); + 
expect(result.content).toContain("res.json({ ok: true })"); + }); + + test("preserves arrow function .catch handler with parens", () => { + const fixturePath = path.join( + fixturesDir, + "pattern-a-arrow-catch.input.ts", + ); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(true); + expect(result.content).toContain(".catch((err) => console.error(err))"); + expect(result.content).not.toContain(".then("); + expect(result.content).not.toContain(".start()"); + }); + }); + + describe("Pattern B: await + imperative", () => { + test("migrates await pattern with extend + start", () => { + const fixturePath = path.join(fixturesDir, "pattern-b.input.ts"); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(true); + expect(result.content).toContain("onPluginsReady(appkit)"); + expect(result.content).not.toContain("appkit.server.start()"); + expect(result.content).not.toContain("autoStart"); + expect(result.content).toContain("server()"); + }); + + test("bails out when non-server usage of appkit handle exists", () => { + const fixturePath = path.join(fixturesDir, "pattern-b-complex.input.ts"); + const result = migrateFile(fixturePath); + + expect(result.warnings.some((w) => w.includes("migrate manually"))).toBe( + true, + ); + expect(result.content).toContain("server()"); + expect(result.content).not.toContain("autoStart"); + }); + + test("bails out when multiple .extend() calls exist", () => { + const fixturePath = path.join( + fixturesDir, + "pattern-b-multi-extend.input.ts", + ); + const result = migrateFile(fixturePath); + + expect(result.warnings.some((w) => w.includes("migrate manually"))).toBe( + true, + ); + expect(result.content).toContain("server()"); + expect(result.content).not.toContain("autoStart"); + }); + }); + + describe("autoStart stripping", () => { + test("strips autoStart: true and preserves other config", () => { + const fixturePath = path.join( + fixturesDir, + "autostart-true-with-port.input.ts", + ); 
+ const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(true); + expect(result.content).not.toContain("autoStart"); + expect(result.content).toContain("port: 3000"); + expect(result.content).toContain("server({"); + }); + }); + + describe("idempotency", () => { + test("no-ops on already migrated file", () => { + const fixturePath = path.join(fixturesDir, "already-migrated.input.ts"); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(false); + expect(result.warnings.some((w) => w.includes("Already migrated"))).toBe( + true, + ); + expect(result.content).toBe(readFixture("already-migrated.input.ts")); + }); + }); +}); diff --git a/packages/shared/src/cli/index.ts b/packages/shared/src/cli/index.ts index 71f09e6f..4d0ed65b 100644 --- a/packages/shared/src/cli/index.ts +++ b/packages/shared/src/cli/index.ts @@ -4,6 +4,7 @@ import { readFileSync } from "node:fs"; import { dirname, join } from "node:path"; import { fileURLToPath } from "node:url"; import { Command } from "commander"; +import { codemodCommand } from "./commands/codemod/index.js"; import { docsCommand } from "./commands/docs.js"; import { generateTypesCommand } from "./commands/generate-types.js"; import { lintCommand } from "./commands/lint.js"; @@ -26,5 +27,6 @@ cmd.addCommand(generateTypesCommand); cmd.addCommand(lintCommand); cmd.addCommand(docsCommand); cmd.addCommand(pluginCommand); +cmd.addCommand(codemodCommand); cmd.parse(); diff --git a/packages/shared/src/index.ts b/packages/shared/src/index.ts index 627d70d6..9829729a 100644 --- a/packages/shared/src/index.ts +++ b/packages/shared/src/index.ts @@ -1,3 +1,4 @@ +export * from "./agent"; export * from "./cache"; export * from "./execute"; export * from "./genie"; diff --git a/packages/shared/src/plugin.ts b/packages/shared/src/plugin.ts index 9fa8066c..651840c7 100644 --- a/packages/shared/src/plugin.ts +++ b/packages/shared/src/plugin.ts @@ -26,6 +26,15 @@ export interface BasePlugin { exports?(): 
unknown; clientConfig?(): Record; + + /** + * Binds runtime dependencies (telemetry, cache, plugin context) after the + * plugin has been constructed. Called by the AppKit core before `setup()`. + */ + attachContext?(deps: { + context?: unknown; + telemetryConfig?: TelemetryOptions; + }): void; } /** Base configuration interface for AppKit plugins */ diff --git a/packages/shared/tsconfig.json b/packages/shared/tsconfig.json index 4a6e68b3..5e195c3b 100644 --- a/packages/shared/tsconfig.json +++ b/packages/shared/tsconfig.json @@ -8,5 +8,5 @@ } }, "include": ["src/**/*"], - "exclude": ["node_modules", "dist"] + "exclude": ["node_modules", "dist", "src/**/fixtures"] } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9ca11b81..307f44cf 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -305,6 +305,9 @@ importers: express: specifier: 4.22.0 version: 4.22.0 + js-yaml: + specifier: ^4.1.1 + version: 4.1.1 obug: specifier: 2.1.1 version: 2.1.1 @@ -326,10 +329,22 @@ importers: ws: specifier: 8.18.3 version: 8.18.3(bufferutil@4.0.9) + zod: + specifier: ^4.0.0 + version: 4.1.13 devDependencies: + '@ai-sdk/openai': + specifier: 4.0.0-beta.27 + version: 4.0.0-beta.27(zod@4.1.13) + '@langchain/core': + specifier: ^1.1.39 + version: 1.1.39(@opentelemetry/api@1.9.0)(@opentelemetry/exporter-trace-otlp-proto@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.6.0(@opentelemetry/api@1.9.0))(ws@8.18.3(bufferutil@4.0.9)) '@types/express': specifier: 4.17.25 version: 4.17.25 + '@types/js-yaml': + specifier: ^4.0.9 + version: 4.0.9 '@types/json-schema': specifier: 7.0.15 version: 7.0.15 @@ -342,6 +357,9 @@ importers: '@vitejs/plugin-react': specifier: 5.1.1 version: 5.1.1(rolldown-vite@7.1.14(@types/node@25.2.3)(esbuild@0.25.10)(jiti@2.6.1)(terser@5.44.1)(tsx@4.20.6)(yaml@2.8.2)) + ai: + specifier: 7.0.0-beta.76 + version: 7.0.0-beta.76(zod@4.1.13) packages/appkit-ui: dependencies: @@ -561,16 +579,38 @@ packages: peerDependencies: zod: ^3.25.76 || ^4.1.8 + 
'@ai-sdk/gateway@4.0.0-beta.43': + resolution: {integrity: sha512-EGQe4If6jt1ZhENmwZn8UAeHbEc7DRiK7ff7dwgfNthwso2hdzLbgXzuTO+W/op+oDFQK1pKiAz5RrPsVQWiew==} + engines: {node: '>=18'} + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + + '@ai-sdk/openai@4.0.0-beta.27': + resolution: {integrity: sha512-7DpXCE4pcc4pVzuEc0whMrQN6Whi14Qsqjx97mLPGjpS6Lff48Zcn2322IFpWuhVJ10hIM1kEZNxUYvXt1O/yg==} + engines: {node: '>=18'} + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + '@ai-sdk/provider-utils@3.0.19': resolution: {integrity: sha512-W41Wc9/jbUVXVwCN/7bWa4IKe8MtxO3EyA0Hfhx6grnmiYlCvpI8neSYWFE0zScXJkgA/YK3BRybzgyiXuu6JA==} engines: {node: '>=18'} peerDependencies: zod: ^3.25.76 || ^4.1.8 + '@ai-sdk/provider-utils@5.0.0-beta.16': + resolution: {integrity: sha512-CyMV5go6libw5WaZ4m7nO0uRLTENxbIODiDrTXJNwxYIBR8p5aCGaxt9oj3prbvNkTt0Srh/Gyw+n2pR9hQ5Pg==} + engines: {node: '>=18'} + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + '@ai-sdk/provider@2.0.0': resolution: {integrity: sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==} engines: {node: '>=18'} + '@ai-sdk/provider@4.0.0-beta.10': + resolution: {integrity: sha512-E2O/LCWjqOxAUfpykQR4xLmcGXySIu6L+wYJjav2xiHu38otPq0qIexgH9ZKulBvBWkrtJ3fxz0kzHDlCBkwng==} + engines: {node: '>=18'} + '@ai-sdk/react@2.0.115': resolution: {integrity: sha512-Etu7gWSEi2dmXss1PoR5CAZGwGShXsF9+Pon1eRO6EmatjYaBMhq1CfHPyYhGzWrint8jJIK2VaAhiMef29qZw==} engines: {node: '>=18'} @@ -1520,6 +1560,9 @@ packages: resolution: {integrity: sha512-hAs5PPKPCQ3/Nha+1fo4A4/gL85fIfxZwHPehsjCJ+BhQH2/yw6/xReuaPA/RfNQr6iz1PcD7BZcE3ctyyl3EA==} cpu: [x64] + '@cfworker/json-schema@4.1.1': + resolution: {integrity: sha512-gAmrUZSGtKc3AiBL71iNWxDsyUC5uMaKKGdvzYsBoTW/xi42JQHl7eKV2OYzCUqvc+D2RCcf7EXY2iCyFIk6og==} + '@chevrotain/cst-dts-gen@11.0.3': resolution: {integrity: sha512-BvIKpRLeS/8UbfxXxgC33xOumsacaeCKAjAeLyOn7Pcp95HiRbrpl14S+9vaZLolnbssPIUuiUd8IvgkRyt6NQ==} @@ -2646,6 +2689,10 @@ packages: '@kwsites/file-exists@1.1.1': resolution: 
{integrity: sha512-m9/5YGR18lIwxSFDwfE3oA7bWuq9kdau6ugN4H2rJeyhFQZcG9AgSHkQtSD15a8WvTgfz9aikZMrKPHvbpqFiw==} + '@langchain/core@1.1.39': + resolution: {integrity: sha512-DP9c7TREy6iA7HnywstmUAsNyJNYTFpRg2yBfQ+6H0l1HnvQzei9GsQ36GeOLxgRaD3vm9K8urCcawSC7yQpCw==} + engines: {node: '>=20'} + '@leichtgewicht/ip-codec@2.0.5': resolution: {integrity: sha512-Vo+PSpZG2/fmgmiNzYK9qWRh8h/CHrwD0mo1h1DzL4yzHNSfWYujGTYsWGreD000gcgmZ7K4Ys6Tx9TxtsKdDw==} @@ -4948,6 +4995,9 @@ packages: '@types/istanbul-reports@3.0.4': resolution: {integrity: sha512-pk2B1NWalF9toCRu6gjBzR69syFjP4Od8WRAX+0mmf9lAjCRicLOWc+ZrxZHx/0XRjotgkF9t6iaMJ+aXcOdZQ==} + '@types/js-yaml@4.0.9': + resolution: {integrity: sha512-k4MGaQl5TGo/iipqb2UDG2UwjXziSWkh0uysQelTlJpX1qGlpUZYm8PnO4DxG1qBomtJUdYJ6qR6xdIah10JLg==} + '@types/jsesc@2.5.1': resolution: {integrity: sha512-9VN+6yxLOPLOav+7PwjZbxiID2bVaeq0ED4qSQmdQTdjnXJSaCVKTR58t15oqH1H5t8Ng2ZX1SabJVoN9Q34bw==} @@ -5166,6 +5216,10 @@ packages: resolution: {integrity: sha512-fnYhv671l+eTTp48gB4zEsTW/YtRgRPnkI2nT7x6qw5rkI1Lq2hTmQIpHPgyThI0znLK+vX2n9XxKdXZ7BUbbw==} engines: {node: '>= 20'} + '@vercel/oidc@3.2.0': + resolution: {integrity: sha512-UycprH3T6n3jH0k44NHMa7pnFHGu/N05MjojYr+Mc6I7obkoLIJujSWwin1pCvdy/eOxrI/l3uDLQsmcrOb4ug==} + engines: {node: '>= 20'} + '@vitejs/plugin-react@5.0.4': resolution: {integrity: sha512-La0KD0vGkVkSk6K+piWDKRUyg8Rl5iAIKRMH0vMJI0Eg47bq1eOxmoObAaQG37WMW9MSyk7Cs8EIWwJC1PtzKA==} engines: {node: ^20.19.0 || >=22.12.0} @@ -5321,6 +5375,12 @@ packages: peerDependencies: zod: ^3.25.76 || ^4.1.8 + ai@7.0.0-beta.76: + resolution: {integrity: sha512-yJMCqsnfUi8jnFOvxmXhjMZd0YVSCLk1E5PZpqmGWynvo3uADt1XADYYYRcj0I9Q2wsL4HbCLAKe01I8aswzJg==} + engines: {node: '>=18'} + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + ajv-formats@2.1.1: resolution: {integrity: sha512-Wx0Kx52hxE7C18hkMEggYlEifqWZtYaRgouJor+WMdPnQyEK13vgEWyVNup7SoeeoLMsr4kf5h6dOW11I15MUA==} peerDependencies: @@ -6453,6 +6513,10 @@ packages: supports-color: optional: true + 
decamelize@1.2.0: + resolution: {integrity: sha512-z2S+W9X73hAUUki+N+9Za2lBlun89zigOyGrsax+KUQ6wKW4ZoWpEYBkGhQjwAjjDCkWxhY0VKEhk8wzY7F5cA==} + engines: {node: '>=0.10.0'} + decimal.js-light@2.5.1: resolution: {integrity: sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==} @@ -8109,6 +8173,9 @@ packages: joi@17.13.3: resolution: {integrity: sha512-otDA4ldcIx+ZXsKHWmp0YizCweVRZG96J10b0FevjfuncLO1oX59THoAmHkNubYJ+9gWsYsp5k8v4ib6oDv1fA==} + js-tiktoken@1.0.21: + resolution: {integrity: sha512-biOj/6M5qdgx5TKjDnFT1ymSpM5tbd3ylwDtrQvFQSu0Z7bBYko2dF+W/aUkXUPuk6IVpRxk/3Q2sHOzGlS36g==} + js-tokens@4.0.0: resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} @@ -8238,6 +8305,26 @@ packages: resolution: {integrity: sha512-QJv/h939gDpvT+9SiLVlY7tZC3xB2qK57v0J04Sh9wpMb6MP1q8gB21L3WIo8T5P1MSMg3Ep14L7KkDCFG3y4w==} engines: {node: '>=16.0.0'} + langsmith@0.5.18: + resolution: {integrity: sha512-3zuZUWffTHQ+73EAwnodADtf534VNEZUpXr9jC12qyG8/IQuJET7PRsCpTb9wX2lmBspakwLUpqpj3tNm/0bVA==} + peerDependencies: + '@opentelemetry/api': '*' + '@opentelemetry/exporter-trace-otlp-proto': '*' + '@opentelemetry/sdk-trace-base': '*' + openai: '*' + ws: '>=7' + peerDependenciesMeta: + '@opentelemetry/api': + optional: true + '@opentelemetry/exporter-trace-otlp-proto': + optional: true + '@opentelemetry/sdk-trace-base': + optional: true + openai: + optional: true + ws: + optional: true + latest-version@7.0.0: resolution: {integrity: sha512-KvNT4XqAMzdcL6ka6Tl3i2lYeFDgXNCuIX+xNx6ZMVR1dFq+idXd9FLKNMOIx0t9mJ9/HudyX4oZWXZQ0UJHeg==} engines: {node: '>=14.16'} @@ -8910,6 +8997,10 @@ packages: resolution: {integrity: sha512-2eznPJP8z2BFLX50tf0LuODrpINqP1RVIm/CObbTcBRITQgmC/TjcREF1NeTBzIcR5XO/ukWo+YHOjBbFwIupg==} hasBin: true + mustache@4.2.0: + resolution: {integrity: sha512-71ippSywq5Yb7/tVYyGbkBggbU8H3u5Rz56fH60jGFgr8uHwxs+aSKeqmluIVzM0m0kB7xQjKS6qPfd0b2ZoqQ==} + hasBin: true + 
mute-stream@2.0.0: resolution: {integrity: sha512-WWdIxpyjEn+FhQJQQv9aQAYlHoNVdzIzUySNV1gHUPDSdZJ3yZn7pAAbQcV7B56Mvu881q9FZV+0Vx2xC44VWA==} engines: {node: ^18.17.0 || >=20.5.0} @@ -11463,6 +11554,10 @@ packages: resolution: {integrity: sha512-pMZTvIkT1d+TFGvDOqodOclx0QWkkgi6Tdoa8gC8ffGAAqz9pzPTZWAybbsHHoED/ztMtkv/VoYTYyShUn81hA==} engines: {node: '>= 0.4.0'} + uuid@10.0.0: + resolution: {integrity: sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==} + hasBin: true + uuid@11.1.0: resolution: {integrity: sha512-0/A9rDy9P7cJ+8w1c9WD9V//9Wj15Ce2MPz8Ri6032usz+NfePxx5AcN3bN+r6ZL6jEo066/yNYB3tn4pQEx+A==} hasBin: true @@ -11937,6 +12032,19 @@ snapshots: '@vercel/oidc': 3.0.5 zod: 4.1.13 + '@ai-sdk/gateway@4.0.0-beta.43(zod@4.1.13)': + dependencies: + '@ai-sdk/provider': 4.0.0-beta.10 + '@ai-sdk/provider-utils': 5.0.0-beta.16(zod@4.1.13) + '@vercel/oidc': 3.2.0 + zod: 4.1.13 + + '@ai-sdk/openai@4.0.0-beta.27(zod@4.1.13)': + dependencies: + '@ai-sdk/provider': 4.0.0-beta.10 + '@ai-sdk/provider-utils': 5.0.0-beta.16(zod@4.1.13) + zod: 4.1.13 + '@ai-sdk/provider-utils@3.0.19(zod@4.1.13)': dependencies: '@ai-sdk/provider': 2.0.0 @@ -11944,10 +12052,21 @@ snapshots: eventsource-parser: 3.0.6 zod: 4.1.13 + '@ai-sdk/provider-utils@5.0.0-beta.16(zod@4.1.13)': + dependencies: + '@ai-sdk/provider': 4.0.0-beta.10 + '@standard-schema/spec': 1.1.0 + eventsource-parser: 3.0.6 + zod: 4.1.13 + '@ai-sdk/provider@2.0.0': dependencies: json-schema: 0.4.0 + '@ai-sdk/provider@4.0.0-beta.10': + dependencies: + json-schema: 0.4.0 + '@ai-sdk/react@2.0.115(react@19.2.0)(zod@4.1.13)': dependencies: '@ai-sdk/provider-utils': 3.0.19(zod@4.1.13) @@ -13078,6 +13197,8 @@ snapshots: '@cdxgen/cdxgen-plugins-bin@2.0.2': optional: true + '@cfworker/json-schema@4.1.1': {} + '@chevrotain/cst-dts-gen@11.0.3': dependencies: '@chevrotain/gast': 11.0.3 @@ -14858,6 +14979,26 @@ snapshots: transitivePeerDependencies: - supports-color + 
'@langchain/core@1.1.39(@opentelemetry/api@1.9.0)(@opentelemetry/exporter-trace-otlp-proto@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.6.0(@opentelemetry/api@1.9.0))(ws@8.18.3(bufferutil@4.0.9))': + dependencies: + '@cfworker/json-schema': 4.1.1 + '@standard-schema/spec': 1.1.0 + ansi-styles: 5.2.0 + camelcase: 6.3.0 + decamelize: 1.2.0 + js-tiktoken: 1.0.21 + langsmith: 0.5.18(@opentelemetry/api@1.9.0)(@opentelemetry/exporter-trace-otlp-proto@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.6.0(@opentelemetry/api@1.9.0))(ws@8.18.3(bufferutil@4.0.9)) + mustache: 4.2.0 + p-queue: 6.6.2 + uuid: 11.1.0 + zod: 4.1.13 + transitivePeerDependencies: + - '@opentelemetry/api' + - '@opentelemetry/exporter-trace-otlp-proto' + - '@opentelemetry/sdk-trace-base' + - openai + - ws + '@leichtgewicht/ip-codec@2.0.5': {} '@mdx-js/mdx@3.1.1': @@ -17289,6 +17430,8 @@ snapshots: dependencies: '@types/istanbul-lib-report': 3.0.3 + '@types/js-yaml@4.0.9': {} + '@types/jsesc@2.5.1': {} '@types/json-schema@7.0.15': {} @@ -17555,6 +17698,8 @@ snapshots: '@vercel/oidc@3.0.5': {} + '@vercel/oidc@3.2.0': {} + '@vitejs/plugin-react@5.0.4(vite@7.2.4(@types/node@24.7.2)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.1)(tsx@4.20.6)(yaml@2.8.2))': dependencies: '@babel/core': 7.28.5 @@ -17779,6 +17924,13 @@ snapshots: '@opentelemetry/api': 1.9.0 zod: 4.1.13 + ai@7.0.0-beta.76(zod@4.1.13): + dependencies: + '@ai-sdk/gateway': 4.0.0-beta.43(zod@4.1.13) + '@ai-sdk/provider': 4.0.0-beta.10 + '@ai-sdk/provider-utils': 5.0.0-beta.16(zod@4.1.13) + zod: 4.1.13 + ajv-formats@2.1.1(ajv@8.17.1): optionalDependencies: ajv: 8.17.1 @@ -19053,6 +19205,8 @@ snapshots: dependencies: ms: 2.1.3 + decamelize@1.2.0: {} + decimal.js-light@2.5.1: {} decimal.js@10.6.0: {} @@ -20873,6 +21027,10 @@ snapshots: '@sideway/formula': 3.0.1 '@sideway/pinpoint': 2.0.0 + js-tiktoken@1.0.21: + dependencies: + base64-js: 1.5.1 + js-tokens@4.0.0: {} js-tokens@9.0.1: {} @@ -21027,6 +21185,16 @@ 
snapshots: vscode-languageserver-textdocument: 1.0.12 vscode-uri: 3.0.8 + langsmith@0.5.18(@opentelemetry/api@1.9.0)(@opentelemetry/exporter-trace-otlp-proto@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.6.0(@opentelemetry/api@1.9.0))(ws@8.18.3(bufferutil@4.0.9)): + dependencies: + p-queue: 6.6.2 + uuid: 10.0.0 + optionalDependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/exporter-trace-otlp-proto': 0.208.0(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-trace-base': 2.6.0(@opentelemetry/api@1.9.0) + ws: 8.18.3(bufferutil@4.0.9) + latest-version@7.0.0: dependencies: package-json: 8.1.1 @@ -21964,6 +22132,8 @@ snapshots: dns-packet: 5.6.1 thunky: 1.1.0 + mustache@4.2.0: {} + mute-stream@2.0.0: {} nanoid@3.3.11: {} @@ -24753,6 +24923,8 @@ snapshots: utils-merge@1.0.1: {} + uuid@10.0.0: {} + uuid@11.1.0: {} uuid@13.0.0: {} diff --git a/template/appkit.plugins.json b/template/appkit.plugins.json index d1420d2e..c8589f9a 100644 --- a/template/appkit.plugins.json +++ b/template/appkit.plugins.json @@ -2,6 +2,30 @@ "$schema": "https://databricks.github.io/appkit/schemas/template-plugins.schema.json", "version": "1.0", "plugins": { + "agents": { + "name": "agents", + "displayName": "Agents Plugin", + "description": "AI agents driven by markdown configs or code, with auto-tool-discovery from registered plugins", + "package": "@databricks/appkit", + "resources": { + "required": [], + "optional": [ + { + "type": "serving_endpoint", + "alias": "Model Serving (agents)", + "resourceKey": "agents-serving-endpoint", + "description": "Databricks Model Serving endpoint for agents using workspace-hosted models (`DatabricksAdapter.fromModelServing`). Wire the same endpoint name AppKit reads from `DATABRICKS_AGENT_ENDPOINT` when no per-agent model is configured. 
Omit when agents use only external adapters.", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_AGENT_ENDPOINT", + "description": "Endpoint name passed to Model Serving when agents default to `DatabricksAdapter.fromModelServing()`" + } + } + } + ] + } + }, "analytics": { "name": "analytics", "displayName": "Analytics Plugin", diff --git a/template/package-lock.json b/template/package-lock.json index 6cee76d7..a79f1eeb 100644 --- a/template/package-lock.json +++ b/template/package-lock.json @@ -10,8 +10,8 @@ "hasInstallScript": true, "license": "Unlicensed", "dependencies": { - "@databricks/appkit": "0.24.0", - "@databricks/appkit-ui": "0.24.0", + "@databricks/appkit": "0.25.1", + "@databricks/appkit-ui": "0.25.1", "@databricks/sdk-experimental": "0.14.2", "clsx": "2.1.1", "embla-carousel-react": "8.6.0", @@ -558,9 +558,9 @@ } }, "node_modules/@databricks/appkit": { - "version": "0.24.0", - "resolved": "https://registry.npmjs.org/@databricks/appkit/-/appkit-0.24.0.tgz", - "integrity": "sha512-GvQFUbp6FPo0CVNHnHTyZOyAhnRbwkw1oTIB82TczGH4XawIT06TS66QezdEq/Lt+3d3XBCmBdQwa3Zp/93+BA==", + "version": "0.25.1", + "resolved": "https://registry.npmjs.org/@databricks/appkit/-/appkit-0.25.1.tgz", + "integrity": "sha512-jYxHtl03bQ1F/0xusJix+0U/YjwcXxlvHO5vZUWcdsArc108pqcblFQZj+EfWu0OEZQVduOy2eqt5D3I3yhUFg==", "hasInstallScript": true, "license": "Apache-2.0", "dependencies": { @@ -601,9 +601,9 @@ } }, "node_modules/@databricks/appkit-ui": { - "version": "0.24.0", - "resolved": "https://registry.npmjs.org/@databricks/appkit-ui/-/appkit-ui-0.24.0.tgz", - "integrity": "sha512-UovTw4LF2n/+WLRxecSzZsI34Z6/iLOMIu51IUFEfJzLYOaON+OTFUMfxyWOlwWNjZsJIUdzld9rTRgaTckXhg==", + "version": "0.25.1", + "resolved": "https://registry.npmjs.org/@databricks/appkit-ui/-/appkit-ui-0.25.1.tgz", + "integrity": "sha512-brYU4SF97iFh27XhPN/rPZXBIoACngmkFM25Db3Ijnu96b1S22bkoeG0uBWkxoTgnrE21lfMTIv11LKrFO7AHQ==", "hasInstallScript": true, "license": "Apache-2.0", "dependencies": { 
diff --git a/template/package.json b/template/package.json index 2ca08581..b3488c02 100644 --- a/template/package.json +++ b/template/package.json @@ -32,8 +32,8 @@ "license": "Unlicensed", "description": "{{.appDescription}}", "dependencies": { - "@databricks/appkit": "0.24.0", - "@databricks/appkit-ui": "0.24.0", + "@databricks/appkit": "0.25.1", + "@databricks/appkit-ui": "0.25.1", "@databricks/sdk-experimental": "0.14.2", "clsx": "2.1.1", "embla-carousel-react": "8.6.0", diff --git a/template/server/server.ts b/template/server/server.ts index 214ac1ce..e28f3ef4 100644 --- a/template/server/server.ts +++ b/template/server/server.ts @@ -5,24 +5,13 @@ import { setupSampleLakebaseRoutes } from './routes/lakebase/todo-routes'; createApp({ plugins: [ -{{- if .plugins.lakebase}} - server({ autoStart: false }), -{{- range $name, $_ := .plugins}} -{{- if ne $name "server"}} - {{$name}}(), -{{- end}} -{{- end}} -{{- else}} {{- range $name, $_ := .plugins}} {{$name}}(), -{{- end}} {{- end}} ], -}) {{- if .plugins.lakebase}} - .then(async (appkit) => { + async onPluginsReady(appkit) { await setupSampleLakebaseRoutes(appkit); - await appkit.server.start(); - }) + }, {{- end}} - .catch(console.error); +}).catch(console.error);