From 1c858c067bb2d38222745b910b050a2b0dae5289 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Thu, 23 Apr 2026 12:00:29 +0200 Subject: [PATCH 01/23] docs: fix Model Serving plugin naming for consistency (#309) Rename serving.md to model-serving.md and update the heading to "Model Serving plugin" to match the manifest displayName and the official Databricks product name. Signed-off-by: Pawel Kosiec --- docs/docs/faq.md | 2 +- docs/docs/plugins/{serving.md => model-serving.md} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename docs/docs/plugins/{serving.md => model-serving.md} (99%) diff --git a/docs/docs/faq.md b/docs/docs/faq.md index b3fd50e1..41667cae 100644 --- a/docs/docs/faq.md +++ b/docs/docs/faq.md @@ -16,7 +16,7 @@ AppKit provides built-in integrations with the following Databricks services via | [Lakebase](./plugins/lakebase) | Lakebase Autoscaling (PostgreSQL) | Relational database access via standard pg.Pool with automatic OAuth token refresh | | [Genie](./plugins/genie) | AI/BI Genie Spaces | Natural language data queries with conversation management and streaming | | [Files](./plugins/files) | Unity Catalog Volumes | Multi-volume file operations (list, read, upload, download, delete, preview) | -| [Serving](./plugins/serving) | Model Serving | Authenticated proxy to Model Serving endpoints with invoke and streaming support | +| [Model Serving](./plugins/model-serving) | Model Serving | Authenticated proxy to Model Serving endpoints with invoke and streaming support | | [Server](./plugins/server) | N/A | Express HTTP server with static file serving, Vite dev mode, and plugin route injection | Stay tuned for new plugins as we constantly expand integrations! diff --git a/docs/docs/plugins/serving.md b/docs/docs/plugins/model-serving.md similarity index 99% rename from docs/docs/plugins/serving.md rename to docs/docs/plugins/model-serving.md index e2ee052b..00f60245 100644 --- a/docs/docs/plugins/serving.md +++ b/docs/docs/plugins/model-serving.md @@ -2,7 +2,7 @@ sidebar_position: 7 --- -# Serving plugin +# Model Serving plugin Provides an authenticated proxy to [Databricks Model Serving](https://docs.databricks.com/aws/en/machine-learning/model-serving) endpoints, with invoke and streaming support. From a9e98a66b21c184b8089cff45e0ab4eb64b0fa8a Mon Sep 17 00:00:00 2001 From: "databricks-appkit[bot]" Date: Thu, 23 Apr 2026 13:18:05 +0000 Subject: [PATCH 02/23] chore: release v0.25.0 [skip ci] Signed-off-by: databricks-appkit[bot] --- CHANGELOG.md | 9 +++++++++ packages/appkit-ui/package.json | 2 +- packages/appkit/package.json | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03a1e921..3ea18946 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,15 @@ All notable changes to this project will be documented in this file. # Changelog +# Changelog + +## [0.25.0](https://github.com/databricks/appkit/compare/v0.24.0...v0.25.0) (2026-04-23) + +### files + +* **files:** per-volume in-app policy enforcement ([#197](https://github.com/databricks/appkit/issues/197)) ([f54dca5](https://github.com/databricks/appkit/commit/f54dca5da5af5368c7bcb18745715b54a99d47e9)) + + ## [0.24.0](https://github.com/databricks/appkit/compare/v0.23.0...v0.24.0) (2026-04-20) * add AST extraction to serving type generator and move types to shared/ ([#279](https://github.com/databricks/appkit/issues/279)) ([422afb3](https://github.com/databricks/appkit/commit/422afb38aa73f8adb94e091225dc3381bd92cfcd)) diff --git a/packages/appkit-ui/package.json b/packages/appkit-ui/package.json index fa2953b1..a3e3f279 100644 --- a/packages/appkit-ui/package.json +++ b/packages/appkit-ui/package.json @@ -1,7 +1,7 @@ { "name": "@databricks/appkit-ui", "type": "module", - "version": "0.24.0", + "version": "0.25.0", "license": "Apache-2.0", "repository": { "type": "git", diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 146be5a9..4b86106b 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -1,7 +1,7 @@ { "name": "@databricks/appkit", "type": "module", - "version": "0.24.0", + "version": "0.25.0", "main": "./dist/index.js", "types": "./dist/index.d.ts", "packageManager": "pnpm@10.21.0", From 2bc3e70e3f06ce02a12d183ffdc7d26863b8bc95 Mon Sep 17 00:00:00 2001 From: "databricks-appkit[bot]" Date: Thu, 23 Apr 2026 13:19:59 +0000 Subject: [PATCH 03/23] chore: sync template to v0.25.0 [skip ci] Signed-off-by: databricks-appkit[bot] --- template/package-lock.json | 16 ++++++++-------- template/package.json | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/template/package-lock.json b/template/package-lock.json index 6cee76d7..77af65bc 100644 --- a/template/package-lock.json +++ b/template/package-lock.json @@ -10,8 +10,8 @@ "hasInstallScript": true, "license": "Unlicensed", "dependencies": { - "@databricks/appkit": "0.24.0", - "@databricks/appkit-ui": "0.24.0", + "@databricks/appkit": "0.25.0", + "@databricks/appkit-ui": "0.25.0", "@databricks/sdk-experimental": "0.14.2", "clsx": "2.1.1", "embla-carousel-react": "8.6.0", @@ -558,9 +558,9 @@ } }, "node_modules/@databricks/appkit": { - "version": "0.24.0", - "resolved": "https://registry.npmjs.org/@databricks/appkit/-/appkit-0.24.0.tgz", - "integrity": "sha512-GvQFUbp6FPo0CVNHnHTyZOyAhnRbwkw1oTIB82TczGH4XawIT06TS66QezdEq/Lt+3d3XBCmBdQwa3Zp/93+BA==", + "version": "0.25.0", + "resolved": "https://registry.npmjs.org/@databricks/appkit/-/appkit-0.25.0.tgz", + "integrity": "sha512-JbcZbUddMGrQiHlMqMdbpNdtW/riBs3QKi4HKduzswGgjwKD/xZsqyImSUlJq3zvsPzxscqC+EIFiIBYv2VxfA==", "hasInstallScript": true, "license": "Apache-2.0", "dependencies": { @@ -601,9 +601,9 @@ } }, "node_modules/@databricks/appkit-ui": { - "version": "0.24.0", - "resolved": "https://registry.npmjs.org/@databricks/appkit-ui/-/appkit-ui-0.24.0.tgz", - "integrity": "sha512-UovTw4LF2n/+WLRxecSzZsI34Z6/iLOMIu51IUFEfJzLYOaON+OTFUMfxyWOlwWNjZsJIUdzld9rTRgaTckXhg==", + "version": "0.25.0", + "resolved": "https://registry.npmjs.org/@databricks/appkit-ui/-/appkit-ui-0.25.0.tgz", + "integrity": "sha512-gW/OGjmpnE0nq/nx4M+W39RLlFid4soptqbhTmkRlp8RPdXQJJ08sHUGwGJXYSTzYelbzbK5hwHemFJdQa4Yng==", "hasInstallScript": true, "license": "Apache-2.0", "dependencies": { diff --git a/template/package.json b/template/package.json index 2ca08581..f9c32275 100644 --- a/template/package.json +++ b/template/package.json @@ -32,8 +32,8 @@ "license": "Unlicensed", "description": "{{.appDescription}}", "dependencies": { - "@databricks/appkit": "0.24.0", - "@databricks/appkit-ui": "0.24.0", + "@databricks/appkit": "0.25.0", + "@databricks/appkit-ui": "0.25.0", "@databricks/sdk-experimental": "0.14.2", "clsx": "2.1.1", "embla-carousel-react": "8.6.0", From 1c994a6d99f397b56e90f1b53df06a61f02b9e82 Mon Sep 17 00:00:00 2001 From: Mario Cadenas <17888484+MarioCadenas@users.noreply.github.com> Date: Mon, 27 Apr 2026 12:56:49 +0200 Subject: [PATCH 04/23] fix(appkit): check isRetryable before retrying in interceptor (#276) The RetryInterceptor retried all errors including AuthenticationError and ValidationError which have isRetryable=false. Now checks error.isRetryable before scheduling a retry attempt. Also handles Databricks SDK ApiError via duck-typed isRetryable() method and status-code heuristic (4xx not retried, 5xx/429 retried). Signed-off-by: MarioCadenas Co-authored-by: MarioCadenas --- .../appkit/src/plugin/interceptors/retry.ts | 37 +++- .../appkit/src/plugin/tests/retry.test.ts | 191 ++++++++++++++++++ 2 files changed, 224 insertions(+), 4 deletions(-) diff --git a/packages/appkit/src/plugin/interceptors/retry.ts b/packages/appkit/src/plugin/interceptors/retry.ts index 435e0fde..aeddf3d5 100644 --- a/packages/appkit/src/plugin/interceptors/retry.ts +++ b/packages/appkit/src/plugin/interceptors/retry.ts @@ -1,10 +1,38 @@ import type { RetryConfig } from "shared"; +import { AppKitError } from "../../errors/base"; import { createLogger } from "../../logging/logger"; import type { ExecutionInterceptor, InterceptorContext } from "./types"; const logger = createLogger("interceptors:retry"); -// interceptor to handle retry logic +/** + * Determines whether an error is safe to retry. + * + * Priority: + * 1. AppKitError — reads the `isRetryable` boolean property. + * 2. Databricks SDK ApiError (duck-typed) — calls `isRetryable()` method, + * or falls back to status-code heuristic (5xx / 429 → retryable). + * 3. Unknown errors — treated as retryable to preserve backward compatibility. + */ +function isRetryableError(error: unknown): boolean { + if (error instanceof AppKitError) { + return error.isRetryable; + } + + if (error instanceof Error && "statusCode" in error) { + const record = error as Record; + if (typeof record.statusCode !== "number") { + return true; + } + if (typeof record.isRetryable === "function") { + return (record.isRetryable as () => boolean)(); + } + return record.statusCode >= 500 || record.statusCode === 429; + } + + return true; +} + export class RetryInterceptor implements ExecutionInterceptor { private attempts: number; private initialDelay: number; @@ -36,7 +64,6 @@ export class RetryInterceptor implements ExecutionInterceptor { } catch (error) { lastError = error; - // last attempt, rethrow the error if (attempt === this.attempts) { logger.event()?.setExecution({ retry_attempts: attempt - 1, @@ -44,17 +71,19 @@ export class RetryInterceptor implements ExecutionInterceptor { throw error; } - // don't retry if was already aborted if (context.signal?.aborted) { throw error; } + if (!isRetryableError(error)) { + throw error; + } + const delay = this.calculateDelay(attempt); await this.sleep(delay); } } - // type guard throw lastError; } diff --git a/packages/appkit/src/plugin/tests/retry.test.ts b/packages/appkit/src/plugin/tests/retry.test.ts index b897c585..f6976f11 100644 --- a/packages/appkit/src/plugin/tests/retry.test.ts +++ b/packages/appkit/src/plugin/tests/retry.test.ts @@ -1,5 +1,9 @@ import type { RetryConfig } from "shared"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { AuthenticationError } from "../../errors/authentication"; +import { ConnectionError } from "../../errors/connection"; +import { ExecutionError } from "../../errors/execution"; +import { ValidationError } from "../../errors/validation"; import { RetryInterceptor } from "../interceptors/retry"; import type { InterceptorContext } from "../interceptors/types"; @@ -241,4 +245,191 @@ describe("RetryInterceptor", () => { vi.spyOn(Math, "random").mockRestore(); }); + + test("should not retry AuthenticationError (isRetryable=false)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi.fn().mockRejectedValue(AuthenticationError.missingToken()); + + await expect(interceptor.intercept(fn, context)).rejects.toThrow( + AuthenticationError, + ); + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("should not retry ValidationError (isRetryable=false)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi.fn().mockRejectedValue(ValidationError.missingField("name")); + + await expect(interceptor.intercept(fn, context)).rejects.toThrow( + ValidationError, + ); + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("should not retry ExecutionError (isRetryable=false)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi + .fn() + .mockRejectedValue(ExecutionError.statementFailed("syntax error")); + + await expect(interceptor.intercept(fn, context)).rejects.toThrow( + ExecutionError, + ); + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("should retry ConnectionError (isRetryable=true)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi + .fn() + .mockRejectedValueOnce(ConnectionError.queryFailed()) + .mockResolvedValue("recovered"); + + const promise = interceptor.intercept(fn, context); + await vi.runAllTimersAsync(); + + expect(await promise).toBe("recovered"); + expect(fn).toHaveBeenCalledTimes(2); + }); + + test("should not retry errors with 4xx statusCode", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const error = Object.assign(new Error("bad request"), { statusCode: 400 }); + const fn = vi.fn().mockRejectedValue(error); + + await expect(interceptor.intercept(fn, context)).rejects.toThrow( + "bad request", + ); + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("should retry errors with 5xx statusCode", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi + .fn() + .mockRejectedValueOnce( + Object.assign(new Error("internal"), { statusCode: 500 }), + ) + .mockResolvedValue("recovered"); + + const promise = interceptor.intercept(fn, context); + await vi.runAllTimersAsync(); + + expect(await promise).toBe("recovered"); + expect(fn).toHaveBeenCalledTimes(2); + }); + + test("should retry errors with 429 statusCode (rate limit)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi + .fn() + .mockRejectedValueOnce( + Object.assign(new Error("rate limited"), { statusCode: 429 }), + ) + .mockResolvedValue("recovered"); + + const promise = interceptor.intercept(fn, context); + await vi.runAllTimersAsync(); + + expect(await promise).toBe("recovered"); + expect(fn).toHaveBeenCalledTimes(2); + }); + + test("should use isRetryable() method when available on error", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + + const nonRetryable = Object.assign(new Error("not found"), { + statusCode: 404, + isRetryable: () => false, + }); + const fn = vi.fn().mockRejectedValue(nonRetryable); + + await expect(interceptor.intercept(fn, context)).rejects.toThrow( + "not found", + ); + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("should respect isRetryable() returning true even for 4xx", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + + const retryableClientError = Object.assign(new Error("conflict"), { + statusCode: 409, + isRetryable: () => true, + }); + const fn = vi + .fn() + .mockRejectedValueOnce(retryableClientError) + .mockResolvedValue("ok"); + + const promise = interceptor.intercept(fn, context); + await vi.runAllTimersAsync(); + + expect(await promise).toBe("ok"); + expect(fn).toHaveBeenCalledTimes(2); + }); + + test("should still retry plain Error (backward compatibility)", async () => { + const config: RetryConfig = { + enabled: true, + attempts: 3, + initialDelay: 1000, + }; + const interceptor = new RetryInterceptor(config); + const fn = vi + .fn() + .mockRejectedValueOnce(new Error("transient")) + .mockResolvedValue("ok"); + + const promise = interceptor.intercept(fn, context); + await vi.runAllTimersAsync(); + + expect(await promise).toBe("ok"); + expect(fn).toHaveBeenCalledTimes(2); + }); }); From a6ad84d451ada84668c9b7b2deb26538cc95787e Mon Sep 17 00:00:00 2001 From: "databricks-appkit[bot]" Date: Mon, 27 Apr 2026 11:44:01 +0000 Subject: [PATCH 05/23] chore: release v0.25.1 [skip ci] Signed-off-by: databricks-appkit[bot] --- CHANGELOG.md | 9 +++++++++ packages/appkit-ui/package.json | 2 +- packages/appkit/package.json | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ea18946..b32a1916 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,15 @@ All notable changes to this project will be documented in this file. # Changelog +# Changelog + +## [0.25.1](https://github.com/databricks/appkit/compare/v0.25.0...v0.25.1) (2026-04-27) + +### appkit + +* **appkit:** check isRetryable before retrying in interceptor ([#276](https://github.com/databricks/appkit/issues/276)) ([1c994a6](https://github.com/databricks/appkit/commit/1c994a6d99f397b56e90f1b53df06a61f02b9e82)) + + ## [0.25.0](https://github.com/databricks/appkit/compare/v0.24.0...v0.25.0) (2026-04-23) ### files diff --git a/packages/appkit-ui/package.json b/packages/appkit-ui/package.json index a3e3f279..beaca8b9 100644 --- a/packages/appkit-ui/package.json +++ b/packages/appkit-ui/package.json @@ -1,7 +1,7 @@ { "name": "@databricks/appkit-ui", "type": "module", - "version": "0.25.0", + "version": "0.25.1", "license": "Apache-2.0", "repository": { "type": "git", diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 4b86106b..0e0af1ba 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -1,7 +1,7 @@ { "name": "@databricks/appkit", "type": "module", - "version": "0.25.0", + "version": "0.25.1", "main": "./dist/index.js", "types": "./dist/index.d.ts", "packageManager": "pnpm@10.21.0", From 5334308f6d423396d44e27ca48785d4035702624 Mon Sep 17 00:00:00 2001 From: "databricks-appkit[bot]" Date: Mon, 27 Apr 2026 11:45:52 +0000 Subject: [PATCH 06/23] chore: sync template to v0.25.1 [skip ci] Signed-off-by: databricks-appkit[bot] --- template/package-lock.json | 16 ++++++++-------- template/package.json | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/template/package-lock.json b/template/package-lock.json index 77af65bc..a79f1eeb 100644 --- a/template/package-lock.json +++ b/template/package-lock.json @@ -10,8 +10,8 @@ "hasInstallScript": true, "license": "Unlicensed", "dependencies": { - "@databricks/appkit": "0.25.0", - "@databricks/appkit-ui": "0.25.0", + "@databricks/appkit": "0.25.1", + "@databricks/appkit-ui": "0.25.1", "@databricks/sdk-experimental": "0.14.2", "clsx": "2.1.1", "embla-carousel-react": "8.6.0", @@ -558,9 +558,9 @@ } }, "node_modules/@databricks/appkit": { - "version": "0.25.0", - "resolved": "https://registry.npmjs.org/@databricks/appkit/-/appkit-0.25.0.tgz", - "integrity": "sha512-JbcZbUddMGrQiHlMqMdbpNdtW/riBs3QKi4HKduzswGgjwKD/xZsqyImSUlJq3zvsPzxscqC+EIFiIBYv2VxfA==", + "version": "0.25.1", + "resolved": "https://registry.npmjs.org/@databricks/appkit/-/appkit-0.25.1.tgz", + "integrity": "sha512-jYxHtl03bQ1F/0xusJix+0U/YjwcXxlvHO5vZUWcdsArc108pqcblFQZj+EfWu0OEZQVduOy2eqt5D3I3yhUFg==", "hasInstallScript": true, "license": "Apache-2.0", "dependencies": { @@ -601,9 +601,9 @@ } }, "node_modules/@databricks/appkit-ui": { - "version": "0.25.0", - "resolved": "https://registry.npmjs.org/@databricks/appkit-ui/-/appkit-ui-0.25.0.tgz", - "integrity": "sha512-gW/OGjmpnE0nq/nx4M+W39RLlFid4soptqbhTmkRlp8RPdXQJJ08sHUGwGJXYSTzYelbzbK5hwHemFJdQa4Yng==", + "version": "0.25.1", + "resolved": "https://registry.npmjs.org/@databricks/appkit-ui/-/appkit-ui-0.25.1.tgz", + "integrity": "sha512-brYU4SF97iFh27XhPN/rPZXBIoACngmkFM25Db3Ijnu96b1S22bkoeG0uBWkxoTgnrE21lfMTIv11LKrFO7AHQ==", "hasInstallScript": true, "license": "Apache-2.0", "dependencies": { diff --git a/template/package.json b/template/package.json index f9c32275..b3488c02 100644 --- a/template/package.json +++ b/template/package.json @@ -32,8 +32,8 @@ "license": "Unlicensed", "description": "{{.appDescription}}", "dependencies": { - "@databricks/appkit": "0.25.0", - "@databricks/appkit-ui": "0.25.0", + "@databricks/appkit": "0.25.1", + "@databricks/appkit-ui": "0.25.1", "@databricks/sdk-experimental": "0.14.2", "clsx": "2.1.1", "embla-carousel-react": "8.6.0", From a06b6e392de2bab2004e7ea2ccf8bb080dfea545 Mon Sep 17 00:00:00 2001 From: Mario Cadenas <17888484+MarioCadenas@users.noreply.github.com> Date: Mon, 27 Apr 2026 14:49:44 +0200 Subject: [PATCH 07/23] feat: add onPluginsReady callback to createApp, remove autoStart (#280) * feat: add customize callback to createApp, remove autoStart Replace the post-await extend/start ceremony with a declarative `customize` callback on createApp config. The callback runs after plugin setup but before the server starts, giving access to the full appkit handle for registering custom routes or async setup. - Add `customize` option to createApp config - Server start is now orchestrated by createApp (lookup by name) - Remove `autoStart` from public API, ServerConfig, and manifest - Remove `start()` from server plugin exports - Remove autoStart guards from extend() and getServer() - Remove ServerError.autoStartConflict() - Migrate dev-playground, template, and all tests Signed-off-by: MarioCadenas * feat: rename customize to onPluginsReady, add codemod CLI and runtime detection Rename the lifecycle hook from `customize` to `onPluginsReady` to clearly communicate when it fires (after plugins are ready, before server starts). Add `appkit codemod customize-callback` CLI command that auto-migrates old autoStart/extend/start patterns to the new onPluginsReady callback. Supports both .then() chain (Pattern A) and await + imperative (Pattern B, with bail-out for complex cases). Add runtime detection that throws helpful errors when users pass autoStart to server() or call server.start() after upgrading, directing them to run the codemod. Signed-off-by: MarioCadenas * fix: exclude codemod fixture files from typecheck The test fixture .ts files import @databricks/appkit which doesn't exist in the shared package, causing tsc to fail in CI. Exclude the fixtures directory from the shared tsconfig. Signed-off-by: MarioCadenas * refactor: split codemod into separate PR Remove the codemod CLI from this PR to keep the review focused on the core lifecycle change. The codemod will land as a follow-up with bug fixes from review. Runtime detection (constructor autoStart throw + exports().start() trap) stays since it's part of the migration story. Signed-off-by: MarioCadenas * fix: add debug logging for onPluginsReady lifecycle hook Log when the onPluginsReady hook starts and completes to aid debugging in development mode. Signed-off-by: MarioCadenas * fix: rename codemod reference to on-plugins-ready Update runtime detection error messages to point users to `npx appkit codemod on-plugins-ready` to match the hook name. Signed-off-by: MarioCadenas --------- Signed-off-by: MarioCadenas Co-authored-by: MarioCadenas --- apps/dev-playground/server/index.ts | 13 +-- .../shared/appkit-types/analytics.d.ts | 8 +- docs/docs/api/appkit/Class.ServerError.md | 21 ---- docs/docs/api/appkit/Function.createApp.md | 21 ++-- docs/docs/plugins/server.md | 33 ++++-- packages/appkit/src/core/appkit.ts | 39 +++++-- packages/appkit/src/errors/server.ts | 10 -- .../appkit/src/errors/tests/errors.test.ts | 6 - .../tests/analytics.integration.test.ts | 2 - .../files/tests/plugin.integration.test.ts | 2 - packages/appkit/src/plugins/server/index.ts | 60 +++++----- .../appkit/src/plugins/server/manifest.json | 5 - .../server/tests/server.integration.test.ts | 105 +++++++++++++++--- .../src/plugins/server/tests/server.test.ts | 91 ++++++--------- packages/appkit/src/plugins/server/types.ts | 1 - template/server/server.ts | 17 +-- 16 files changed, 236 insertions(+), 198 deletions(-) diff --git a/apps/dev-playground/server/index.ts b/apps/dev-playground/server/index.ts index 8a77b76c..94f1cc12 100644 --- a/apps/dev-playground/server/index.ts +++ b/apps/dev-playground/server/index.ts @@ -50,7 +50,7 @@ const adminOnly: FilePolicy = (action, _resource, user) => { createApp({ plugins: [ - server({ autoStart: false }), + server(), reconnect(), telemetryExamples(), analytics({}), @@ -95,9 +95,8 @@ createApp({ // }), ], ...(process.env.APPKIT_E2E_TEST && { client: createMockClient() }), -}).then((appkit) => { - appkit.server - .extend((app) => { + onPluginsReady(appkit) { + appkit.server.extend((app) => { app.get("/sp", (_req, res) => { appkit.analytics .query("SELECT * FROM samples.nyctaxi.trips;") @@ -195,9 +194,9 @@ createApp({ results, }); }); - }) - .start(); -}); + }); + }, +}).catch(console.error); type ProbeResult = { volume: string; diff --git a/apps/dev-playground/shared/appkit-types/analytics.d.ts b/apps/dev-playground/shared/appkit-types/analytics.d.ts index 0e0ae0b0..43666dd0 100644 --- a/apps/dev-playground/shared/appkit-types/analytics.d.ts +++ b/apps/dev-playground/shared/appkit-types/analytics.d.ts @@ -119,10 +119,10 @@ declare module "@databricks/appkit-ui/react" { result: Array<{ /** @sqlType STRING */ string_value: string; - /** @sqlType STRING */ - number_value: string; - /** @sqlType STRING */ - boolean_value: string; + /** @sqlType INT */ + number_value: number; + /** @sqlType BOOLEAN */ + boolean_value: boolean; /** @sqlType STRING */ date_value: string; /** @sqlType STRING */ diff --git a/docs/docs/api/appkit/Class.ServerError.md b/docs/docs/api/appkit/Class.ServerError.md index d3dce68e..cce86cad 100644 --- a/docs/docs/api/appkit/Class.ServerError.md +++ b/docs/docs/api/appkit/Class.ServerError.md @@ -6,7 +6,6 @@ Use for server start/stop issues, configuration conflicts, etc. ## Example ```typescript -throw new ServerError("Cannot get server when autoStart is true"); throw new ServerError("Server not started"); ``` @@ -151,26 +150,6 @@ Create a human-readable string representation *** -### autoStartConflict() - -```ts -static autoStartConflict(operation: string): ServerError; -``` - -Create a server error for autoStart conflict - -#### Parameters - -| Parameter | Type | -| ------ | ------ | -| `operation` | `string` | - -#### Returns - -`ServerError` - -*** - ### clientDirectoryNotFound() ```ts diff --git a/docs/docs/api/appkit/Function.createApp.md b/docs/docs/api/appkit/Function.createApp.md index cb703386..6a0b7cb2 100644 --- a/docs/docs/api/appkit/Function.createApp.md +++ b/docs/docs/api/appkit/Function.createApp.md @@ -4,6 +4,7 @@ function createApp(config: { cache?: CacheConfig; client?: WorkspaceClient; + onPluginsReady?: (appkit: PluginMap) => void | Promise; plugins?: T; telemetry?: TelemetryConfig; }): Promise>; @@ -13,6 +14,9 @@ Bootstraps AppKit with the provided configuration. Initializes telemetry, cache, and service context, then registers plugins in phase order (core, normal, deferred) and awaits their setup. +If a `onPluginsReady` callback is provided it runs after plugin setup but +before the server starts, giving you access to the full appkit handle +for registering custom routes or performing async setup. The returned object maps each plugin name to its `exports()` API, with an `asUser(req)` method for user-scoped execution. @@ -26,9 +30,10 @@ with an `asUser(req)` method for user-scoped execution. | Parameter | Type | | ------ | ------ | -| `config` | \{ `cache?`: [`CacheConfig`](Interface.CacheConfig.md); `client?`: `WorkspaceClient`; `plugins?`: `T`; `telemetry?`: [`TelemetryConfig`](Interface.TelemetryConfig.md); \} | +| `config` | \{ `cache?`: [`CacheConfig`](Interface.CacheConfig.md); `client?`: `WorkspaceClient`; `onPluginsReady?`: (`appkit`: `PluginMap`\<`T`\>) => `void` \| `Promise`\<`void`\>; `plugins?`: `T`; `telemetry?`: [`TelemetryConfig`](Interface.TelemetryConfig.md); \} | | `config.cache?` | [`CacheConfig`](Interface.CacheConfig.md) | | `config.client?` | `WorkspaceClient` | +| `config.onPluginsReady?` | (`appkit`: `PluginMap`\<`T`\>) => `void` \| `Promise`\<`void`\> | | `config.plugins?` | `T` | | `config.telemetry?` | [`TelemetryConfig`](Interface.TelemetryConfig.md) | @@ -51,12 +56,12 @@ await createApp({ ```ts import { createApp, server, analytics } from "@databricks/appkit"; -const appkit = await createApp({ - plugins: [server({ autoStart: false }), analytics({})], -}); - -appkit.server.extend((app) => { - app.get("/custom", (_req, res) => res.json({ ok: true })); +await createApp({ + plugins: [server(), analytics({})], + onPluginsReady(appkit) { + appkit.server.extend((app) => { + app.get("/custom", (_req, res) => res.json({ ok: true })); + }); + }, }); -await appkit.server.start(); ``` diff --git a/docs/docs/plugins/server.md b/docs/docs/plugins/server.md index 389828dc..6cfaa7b7 100644 --- a/docs/docs/plugins/server.md +++ b/docs/docs/plugins/server.md @@ -36,22 +36,38 @@ await createApp({ }); ``` -## Manual server start example +## Custom routes example -When you need to extend Express with custom routes: +Use the `onPluginsReady` callback to extend Express with custom routes before the server starts: ```ts import { createApp, server } from "@databricks/appkit"; -const appkit = await createApp({ - plugins: [server({ autoStart: false })], +await createApp({ + plugins: [server()], + onPluginsReady(appkit) { + appkit.server.extend((app) => { + app.get("/custom", (_req, res) => res.json({ ok: true })); + }); + }, }); +``` -appkit.server.extend((app) => { - app.get("/custom", (_req, res) => res.json({ ok: true })); -}); +The `onPluginsReady` callback also supports async operations: -await appkit.server.start(); +```ts +await createApp({ + plugins: [server()], + async onPluginsReady(appkit) { + const pool = await initializeDatabase(); + appkit.server.extend((app) => { + app.get("/data", async (_req, res) => { + const result = await pool.query("SELECT 1"); + res.json(result); + }); + }); + }, +}); ``` ## Configuration options @@ -64,7 +80,6 @@ await createApp({ server({ port: 8000, // default: Number(process.env.DATABRICKS_APP_PORT) || 8000 host: "0.0.0.0", // default: process.env.FLASK_RUN_HOST || "0.0.0.0" - autoStart: true, // default: true staticPath: "dist", // optional: force a specific static directory }), ], diff --git a/packages/appkit/src/core/appkit.ts b/packages/appkit/src/core/appkit.ts index a2cba994..607a1552 100644 --- a/packages/appkit/src/core/appkit.ts +++ b/packages/appkit/src/core/appkit.ts @@ -10,10 +10,13 @@ import type { } from "shared"; import { CacheManager } from "../cache"; import { ServiceContext } from "../context"; +import { createLogger } from "../logging/logger"; import { ResourceRegistry, ResourceType } from "../registry"; import type { TelemetryConfig } from "../telemetry"; import { TelemetryManager } from "../telemetry"; +const logger = createLogger("appkit"); + export class AppKit { #pluginInstances: Record = {}; #setupPromises: Promise[] = []; @@ -167,6 +170,7 @@ export class AppKit { telemetry?: TelemetryConfig; cache?: CacheConfig; client?: WorkspaceClient; + onPluginsReady?: (appkit: PluginMap) => void | Promise; } = {}, ): Promise> { // Initialize core services @@ -200,7 +204,20 @@ export class AppKit { await Promise.all(instance.#setupPromises); - return instance as unknown as PluginMap; + const handle = instance as unknown as PluginMap; + + if (config.onPluginsReady) { + logger.debug("Running onPluginsReady hook"); + await config.onPluginsReady(handle); + logger.debug("onPluginsReady hook completed"); + } + + const serverPlugin = instance.#pluginInstances.server; + if (serverPlugin && typeof (serverPlugin as any).start === "function") { + await (serverPlugin as any).start(); + } + + return handle; } private static preparePlugins( @@ -222,6 +239,9 @@ export class AppKit { * * Initializes telemetry, cache, and service context, then registers plugins * in phase order (core, normal, deferred) and awaits their setup. + * If a `onPluginsReady` callback is provided it runs after plugin setup but + * before the server starts, giving you access to the full appkit handle + * for registering custom routes or performing async setup. * The returned object maps each plugin name to its `exports()` API, * with an `asUser(req)` method for user-scoped execution. * @@ -236,18 +256,18 @@ export class AppKit { * }); * ``` * - * @example Extended Server with analytics and custom endpoint + * @example Server with custom routes via onPluginsReady * ```ts * import { createApp, server, analytics } from "@databricks/appkit"; * - * const appkit = await createApp({ - * plugins: [server({ autoStart: false }), analytics({})], - * }); - * - * appkit.server.extend((app) => { - * app.get("/custom", (_req, res) => res.json({ ok: true })); + * await createApp({ + * plugins: [server(), analytics({})], + * onPluginsReady(appkit) { + * appkit.server.extend((app) => { + * app.get("/custom", (_req, res) => res.json({ ok: true })); + * }); + * }, * }); - * await appkit.server.start(); * ``` */ export async function createApp< @@ -258,6 +278,7 @@ export async function createApp< telemetry?: TelemetryConfig; cache?: CacheConfig; client?: WorkspaceClient; + onPluginsReady?: (appkit: PluginMap) => void | Promise; } = {}, ): Promise> { return AppKit._createApp(config); diff --git a/packages/appkit/src/errors/server.ts b/packages/appkit/src/errors/server.ts index 6af5b59f..d45148d8 100644 --- a/packages/appkit/src/errors/server.ts +++ b/packages/appkit/src/errors/server.ts @@ -6,7 +6,6 @@ import { AppKitError } from "./base"; * * @example * ```typescript - * throw new ServerError("Cannot get server when autoStart is true"); * throw new ServerError("Server not started"); * ``` */ @@ -15,15 +14,6 @@ export class ServerError extends AppKitError { readonly statusCode = 500; readonly isRetryable = false; - /** - * Create a server error for autoStart conflict - */ - static autoStartConflict(operation: string): ServerError { - return new ServerError(`Cannot ${operation} when autoStart is true`, { - context: { operation }, - }); - } - /** * Create a server error for server not started */ diff --git a/packages/appkit/src/errors/tests/errors.test.ts b/packages/appkit/src/errors/tests/errors.test.ts index c404a18f..347ce1c0 100644 --- a/packages/appkit/src/errors/tests/errors.test.ts +++ b/packages/appkit/src/errors/tests/errors.test.ts @@ -348,12 +348,6 @@ describe("ServerError", () => { expect(error.isRetryable).toBe(false); }); - test("autoStartConflict should create proper error", () => { - const error = ServerError.autoStartConflict("get server"); - expect(error.message).toBe("Cannot get server when autoStart is true"); - expect(error.context?.operation).toBe("get server"); - }); - test("notStarted should create proper error", () => { const error = ServerError.notStarted(); expect(error.message).toContain("Server not started"); diff --git a/packages/appkit/src/plugins/analytics/tests/analytics.integration.test.ts b/packages/appkit/src/plugins/analytics/tests/analytics.integration.test.ts index cb73394a..0cec2298 100644 --- a/packages/appkit/src/plugins/analytics/tests/analytics.integration.test.ts +++ b/packages/appkit/src/plugins/analytics/tests/analytics.integration.test.ts @@ -46,13 +46,11 @@ describe("Analytics Plugin Integration", () => { serverPlugin({ port: TEST_PORT, host: "127.0.0.1", - autoStart: false, }), analytics({}), ], }); - await app.server.start(); server = app.server.getServer(); baseUrl = `http://127.0.0.1:${TEST_PORT}`; }); diff --git a/packages/appkit/src/plugins/files/tests/plugin.integration.test.ts b/packages/appkit/src/plugins/files/tests/plugin.integration.test.ts index da90760d..3c8ff74e 100644 --- a/packages/appkit/src/plugins/files/tests/plugin.integration.test.ts +++ b/packages/appkit/src/plugins/files/tests/plugin.integration.test.ts @@ -87,13 +87,11 @@ describe("Files Plugin Integration", () => { serverPlugin({ port: TEST_PORT, host: "127.0.0.1", - autoStart: false, }), files(), ], }); - await appkit.server.start(); server = appkit.server.getServer(); baseUrl = `http://127.0.0.1:${TEST_PORT}`; }); diff --git a/packages/appkit/src/plugins/server/index.ts b/packages/appkit/src/plugins/server/index.ts index e7b9b31a..8ed13cea 100644 --- a/packages/appkit/src/plugins/server/index.ts +++ b/packages/appkit/src/plugins/server/index.ts @@ -27,17 +27,23 @@ const logger = createLogger("server"); * This plugin is responsible for starting the server and serving the static files. * It also handles the remote tunneling for development purposes. * + * The server is started automatically by `createApp` after all plugins are set up + * and the optional `onPluginsReady` callback has run. + * * @example * ```ts * createApp({ - * plugins: [server(), telemetryExamples(), analytics({})], + * plugins: [server(), analytics({})], + * onPluginsReady(appkit) { + * appkit.server.extend((app) => { + * app.get("/custom", (_req, res) => res.json({ ok: true })); + * }); + * }, * }); * ``` - * */ export class ServerPlugin extends Plugin { public static DEFAULT_CONFIG = { - autoStart: true, host: process.env.FLASK_RUN_HOST || "0.0.0.0", port: Number(process.env.DATABRICKS_APP_PORT) || 8000, }; @@ -54,6 +60,13 @@ export class ServerPlugin extends Plugin { static phase: PluginPhase = "deferred"; constructor(config: ServerConfig) { + if ("autoStart" in config) { + throw new ServerError( + "server({ autoStart }) has been removed. " + + "The server is now started automatically by createApp.\n\n" + + "Run `npx appkit codemod on-plugins-ready --write` to auto-migrate.", + ); + } super(config); this.config = config; this.serverApplication = express(); @@ -65,12 +78,7 @@ export class ServerPlugin extends Plugin { ]); } - /** Setup the server plugin. */ - async setup() { - if (this.shouldAutoStart()) { - await this.start(); - } - } + async setup() {} /** Get the server configuration. */ getConfig() { @@ -79,11 +87,6 @@ export class ServerPlugin extends Plugin { return config; } - /** Check if the server should auto start. */ - shouldAutoStart() { - return this.config.autoStart; - } - /** * Start the server. * @@ -148,14 +151,10 @@ export class ServerPlugin extends Plugin { * * Only use this method if you need to access the server instance for advanced usage like a custom websocket server, etc. * - * @throws {Error} If the server is not started or autoStart is true. + * @throws {Error} If the server has not started yet. * @returns {HTTPServer} The server instance. */ getServer(): HTTPServer { - if (this.shouldAutoStart()) { - throw ServerError.autoStartConflict("get server"); - } - if (!this.server) { throw ServerError.notStarted(); } @@ -166,15 +165,13 @@ export class ServerPlugin extends Plugin { /** * Extend the server with custom routes or middleware. * + * Call this inside the `onPluginsReady` callback of `createApp` to register + * custom Express routes or middleware before the server starts listening. + * * @param fn - A function that receives the express application. * @returns The server plugin instance for chaining. - * @throws {Error} If autoStart is true. */ extend(fn: (app: express.Application) => void) { - if (this.shouldAutoStart()) { - throw ServerError.autoStartConflict("extend server"); - } - this.serverExtensions.push(fn); return this; } @@ -389,8 +386,6 @@ export class ServerPlugin extends Plugin { exports() { const self = this; return { - /** Start the server */ - start: this.start, /** Extend the server with custom routes or middleware */ extend(fn: (app: express.Application) => void) { self.extend(fn); @@ -400,6 +395,19 @@ export class ServerPlugin extends Plugin { getServer: this.getServer, /** Get the server configuration */ getConfig: this.getConfig, + /** @deprecated Server is now started automatically by createApp. */ + start() { + throw new ServerError( + "server.start() has been removed. Use the onPluginsReady callback instead:\n\n" + + " createApp({\n" + + " plugins: [server(), ...],\n" + + " onPluginsReady(appkit) {\n" + + " appkit.server.extend(...);\n" + + " },\n" + + " });\n\n" + + "Run `npx appkit codemod on-plugins-ready --write` to auto-migrate.", + ); + }, }; } } diff --git a/packages/appkit/src/plugins/server/manifest.json b/packages/appkit/src/plugins/server/manifest.json index 11822beb..1112fbf5 100644 --- a/packages/appkit/src/plugins/server/manifest.json +++ b/packages/appkit/src/plugins/server/manifest.json @@ -11,11 +11,6 @@ "schema": { "type": "object", "properties": { - "autoStart": { - "type": "boolean", - "default": true, - "description": "Automatically start the server on plugin setup" - }, "host": { "type": "string", "default": "0.0.0.0", diff --git a/packages/appkit/src/plugins/server/tests/server.integration.test.ts b/packages/appkit/src/plugins/server/tests/server.integration.test.ts index c3a646ea..0b67e4c0 100644 --- a/packages/appkit/src/plugins/server/tests/server.integration.test.ts +++ b/packages/appkit/src/plugins/server/tests/server.integration.test.ts @@ -29,13 +29,10 @@ describe("ServerPlugin Integration", () => { serverPlugin({ port: TEST_PORT, host: "127.0.0.1", - autoStart: false, }), ], }); - // Start server manually - await app.server.start(); server = app.server.getServer(); baseUrl = `http://127.0.0.1:${TEST_PORT}`; @@ -124,13 +121,11 @@ describe("ServerPlugin with custom plugin", () => { serverPlugin({ port: TEST_PORT, host: "127.0.0.1", - autoStart: false, }), testPlugin({}), ], }); - await app.server.start(); server = app.server.getServer(); baseUrl = `http://127.0.0.1:${TEST_PORT}`; @@ -172,7 +167,7 @@ describe("ServerPlugin with custom plugin", () => { }); }); -describe("ServerPlugin with extend()", () => { +describe("ServerPlugin with extend() via onPluginsReady", () => { let server: Server; let baseUrl: string; let serviceContextMock: Awaited>; @@ -188,19 +183,73 @@ describe("ServerPlugin with extend()", () => { serverPlugin({ port: TEST_PORT, host: "127.0.0.1", - autoStart: false, }), ], + onPluginsReady(appkit) { + appkit.server.extend((expressApp) => { + expressApp.get("/custom", (_req, res) => { + res.json({ custom: true }); + }); + }); + }, }); - // Add custom route via extend() - app.server.extend((expressApp) => { - expressApp.get("/custom", (_req, res) => { - res.json({ custom: true }); + server = app.server.getServer(); + baseUrl = `http://127.0.0.1:${TEST_PORT}`; + + await new Promise((resolve) => setTimeout(resolve, 100)); + }); + + afterAll(async () => { + serviceContextMock?.restore(); + if (server) { + await new Promise((resolve, reject) => { + server.close((err) => { + if (err) reject(err); + else resolve(); + }); }); + } + }); + + test("custom route via extend() in onPluginsReady callback works", async () => { + const response = await fetch(`${baseUrl}/custom`); + + expect(response.status).toBe(200); + + const data = await response.json(); + expect(data).toEqual({ custom: true }); + }); +}); + +describe("createApp with async onPluginsReady callback", () => { + let server: Server; + let baseUrl: string; + let serviceContextMock: Awaited>; + const TEST_PORT = 9885; + + beforeAll(async () => { + setupDatabricksEnv(); + ServiceContext.reset(); + serviceContextMock = await mockServiceContext(); + + const app = await createApp({ + plugins: [ + serverPlugin({ + port: TEST_PORT, + host: "127.0.0.1", + }), + ], + async onPluginsReady(appkit) { + await new Promise((resolve) => setTimeout(resolve, 10)); + appkit.server.extend((expressApp) => { + expressApp.get("/async-custom", (_req, res) => { + res.json({ asyncSetup: true }); + }); + }); + }, }); - await app.server.start(); server = app.server.getServer(); baseUrl = `http://127.0.0.1:${TEST_PORT}`; @@ -219,12 +268,38 @@ describe("ServerPlugin with extend()", () => { } }); - test("custom route via extend() works", async () => { - const response = await fetch(`${baseUrl}/custom`); + test("async onPluginsReady callback runs before server starts", async () => { + const response = await fetch(`${baseUrl}/async-custom`); expect(response.status).toBe(200); const data = await response.json(); - expect(data).toEqual({ custom: true }); + expect(data).toEqual({ asyncSetup: true }); + }); +}); + +describe("createApp without server plugin", () => { + let serviceContextMock: Awaited>; + let onPluginsReadyWasCalled = false; + + beforeAll(async () => { + setupDatabricksEnv(); + ServiceContext.reset(); + serviceContextMock = await mockServiceContext(); + + await createApp({ + plugins: [], + onPluginsReady() { + onPluginsReadyWasCalled = true; + }, + }); + }); + + afterAll(async () => { + serviceContextMock?.restore(); + }); + + test("onPluginsReady callback is still called without server plugin", () => { + expect(onPluginsReadyWasCalled).toBe(true); }); }); diff --git a/packages/appkit/src/plugins/server/tests/server.test.ts b/packages/appkit/src/plugins/server/tests/server.test.ts index 22f18129..fae11fb5 100644 --- a/packages/appkit/src/plugins/server/tests/server.test.ts +++ b/packages/appkit/src/plugins/server/tests/server.test.ts @@ -197,19 +197,22 @@ describe("ServerPlugin", () => { const plugin = new ServerPlugin({ port: 3000, host: "127.0.0.1", - autoStart: false, }); const config = plugin.getConfig(); expect(config.port).toBe(3000); expect(config.host).toBe("127.0.0.1"); - expect(config.autoStart).toBe(false); + }); + + test("should throw when autoStart is passed", () => { + expect(() => new ServerPlugin({ autoStart: false } as any)).toThrow( + "server({ autoStart }) has been removed", + ); }); }); describe("DEFAULT_CONFIG", () => { test("should have correct default values", () => { - expect(ServerPlugin.DEFAULT_CONFIG.autoStart).toBe(true); expect(ServerPlugin.DEFAULT_CONFIG.host).toBe("0.0.0.0"); expect(ServerPlugin.DEFAULT_CONFIG.port).toBe(8000); }); @@ -220,30 +223,9 @@ describe("ServerPlugin", () => { }); }); - describe("shouldAutoStart", () => { - test("should return true when autoStart is true", () => { - const plugin = new ServerPlugin({ autoStart: true }); - expect(plugin.shouldAutoStart()).toBe(true); - }); - - test("should return false when autoStart is false", () => { - const plugin = new ServerPlugin({ autoStart: false }); - expect(plugin.shouldAutoStart()).toBe(false); - }); - }); - describe("setup", () => { - test("should call start when autoStart is true", async () => { - const plugin = new ServerPlugin({ autoStart: true }); - const startSpy = vi.spyOn(plugin, "start").mockResolvedValue({} as any); - - await plugin.setup(); - - expect(startSpy).toHaveBeenCalled(); - }); - - test("should not call start when autoStart is false", async () => { - const plugin = new ServerPlugin({ autoStart: false }); + test("should be a no-op (server start is orchestrated by createApp)", async () => { + const plugin = new ServerPlugin({}); const startSpy = vi.spyOn(plugin, "start").mockResolvedValue({} as any); await plugin.setup(); @@ -254,7 +236,7 @@ describe("ServerPlugin", () => { describe("start", () => { test("should call listen on express app", async () => { - const plugin = new ServerPlugin({ autoStart: false, port: 3000 }); + const plugin = new ServerPlugin({ port: 3000 }); await plugin.start(); @@ -267,7 +249,7 @@ describe("ServerPlugin", () => { test("should setup ViteDevServer in development mode", async () => { process.env.NODE_ENV = "development"; - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); await plugin.start(); @@ -277,7 +259,7 @@ describe("ServerPlugin", () => { }); test("should register RemoteTunnelController middleware and set server", async () => { - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); await plugin.start(); @@ -304,7 +286,7 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ plugins }); await plugin.start(); // Get the type function passed to express.json @@ -348,7 +330,7 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ plugins }); await plugin.start(); const routerFn = (express as any).Router as ReturnType; @@ -386,7 +368,7 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ plugins }); await plugin.start(); expect(plugins["plugin-a"].clientConfig).toHaveBeenCalled(); @@ -413,7 +395,7 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ plugins }); await plugin.start(); expect(plugins["plugin-null"].clientConfig).toHaveBeenCalled(); @@ -444,7 +426,7 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ plugins }); await expect(plugin.start()).resolves.toBeDefined(); expect(mockLoggerError).toHaveBeenCalledWith( "Plugin '%s' clientConfig() failed, skipping its config: %O", @@ -457,7 +439,7 @@ describe("ServerPlugin", () => { process.env.NODE_ENV = "production"; vi.mocked(fs.existsSync).mockReturnValue(true); - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); await plugin.start(); @@ -470,7 +452,7 @@ describe("ServerPlugin", () => { process.env.NODE_ENV = "production"; vi.mocked(fs.existsSync).mockReturnValue(false); - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); await plugin.start(); @@ -479,8 +461,8 @@ describe("ServerPlugin", () => { }); describe("extend", () => { - test("should add extension function when autoStart is false", () => { - const plugin = new ServerPlugin({ autoStart: false }); + test("should add extension function and return plugin for chaining", () => { + const plugin = new ServerPlugin({}); const extensionFn = vi.fn(); const result = plugin.extend(extensionFn); @@ -488,17 +470,8 @@ describe("ServerPlugin", () => { expect(result).toBe(plugin); }); - test("should throw when autoStart is true", () => { - const plugin = new ServerPlugin({ autoStart: true }); - const extensionFn = vi.fn(); - - expect(() => plugin.extend(extensionFn)).toThrow( - "Cannot extend server when autoStart is true", - ); - }); - test("should call extension functions during start", async () => { - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); const extensionFn = vi.fn(); plugin.extend(extensionFn); @@ -508,17 +481,18 @@ describe("ServerPlugin", () => { }); }); - describe("getServer", () => { - test("should throw when autoStart is true", () => { - const plugin = new ServerPlugin({ autoStart: true }); + describe("exports().start() trap", () => { + test("should throw migration error when start() is called via exports", () => { + const plugin = new ServerPlugin({}); + const exported = plugin.exports(); - expect(() => plugin.getServer()).toThrow( - "Cannot get server when autoStart is true", - ); + expect(() => exported.start()).toThrow("server.start() has been removed"); }); + }); + describe("getServer", () => { test("should throw when server not started", () => { - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); expect(() => plugin.getServer()).toThrow( "Server not started. Please start the server first by calling the start() method", @@ -526,7 +500,7 @@ describe("ServerPlugin", () => { }); test("should return server after start", async () => { - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); await plugin.start(); const server = plugin.getServer(); @@ -553,7 +527,7 @@ describe("ServerPlugin", () => { describe("logStartupInfo", () => { test("logs remote tunnel controller disabled when missing", () => { mockLoggerDebug.mockClear(); - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); (plugin as any).remoteTunnelController = undefined; (plugin as any).logStartupInfo(); @@ -565,7 +539,7 @@ describe("ServerPlugin", () => { test("logs remote tunnel allowed/active when controller present", () => { mockLoggerDebug.mockClear(); - const plugin = new ServerPlugin({ autoStart: false }); + const plugin = new ServerPlugin({}); (plugin as any).remoteTunnelController = { isAllowedByEnv: () => true, isActive: () => true, @@ -607,7 +581,6 @@ describe("ServerPlugin", () => { .mockImplementation(((_code?: number) => undefined) as any); const plugin = new ServerPlugin({ - autoStart: false, plugins: { ok: { name: "ok", diff --git a/packages/appkit/src/plugins/server/types.ts b/packages/appkit/src/plugins/server/types.ts index e187cacc..f9f6ebce 100644 --- a/packages/appkit/src/plugins/server/types.ts +++ b/packages/appkit/src/plugins/server/types.ts @@ -5,6 +5,5 @@ export interface ServerConfig extends BasePluginConfig { port?: number; plugins?: Record; staticPath?: string; - autoStart?: boolean; host?: string; } diff --git a/template/server/server.ts b/template/server/server.ts index 214ac1ce..e28f3ef4 100644 --- a/template/server/server.ts +++ b/template/server/server.ts @@ -5,24 +5,13 @@ import { setupSampleLakebaseRoutes } from './routes/lakebase/todo-routes'; createApp({ plugins: [ -{{- if .plugins.lakebase}} - server({ autoStart: false }), -{{- range $name, $_ := .plugins}} -{{- if ne $name "server"}} - {{$name}}(), -{{- end}} -{{- end}} -{{- else}} {{- range $name, $_ := .plugins}} {{$name}}(), -{{- end}} {{- end}} ], -}) {{- if .plugins.lakebase}} - .then(async (appkit) => { + async onPluginsReady(appkit) { await setupSampleLakebaseRoutes(appkit); - await appkit.server.start(); - }) + }, {{- end}} - .catch(console.error); +}).catch(console.error); From f2f05044f3216b1eafcd118bd995960c35392020 Mon Sep 17 00:00:00 2001 From: Mario Cadenas <17888484+MarioCadenas@users.noreply.github.com> Date: Mon, 27 Apr 2026 14:56:46 +0200 Subject: [PATCH 08/23] feat: add appkit codemod on-plugins-ready CLI (#291) * feat: add customize callback to createApp, remove autoStart Replace the post-await extend/start ceremony with a declarative `customize` callback on createApp config. The callback runs after plugin setup but before the server starts, giving access to the full appkit handle for registering custom routes or async setup. - Add `customize` option to createApp config - Server start is now orchestrated by createApp (lookup by name) - Remove `autoStart` from public API, ServerConfig, and manifest - Remove `start()` from server plugin exports - Remove autoStart guards from extend() and getServer() - Remove ServerError.autoStartConflict() - Migrate dev-playground, template, and all tests Signed-off-by: MarioCadenas * feat: rename customize to onPluginsReady, add codemod CLI and runtime detection Rename the lifecycle hook from `customize` to `onPluginsReady` to clearly communicate when it fires (after plugins are ready, before server starts). Add `appkit codemod customize-callback` CLI command that auto-migrates old autoStart/extend/start patterns to the new onPluginsReady callback. Supports both .then() chain (Pattern A) and await + imperative (Pattern B, with bail-out for complex cases). Add runtime detection that throws helpful errors when users pass autoStart to server() or call server.start() after upgrading, directing them to run the codemod. Signed-off-by: MarioCadenas * fix: exclude codemod fixture files from typecheck The test fixture .ts files import @databricks/appkit which doesn't exist in the shared package, causing tsc to fail in CI. Exclude the fixtures directory from the shared tsconfig. Signed-off-by: MarioCadenas * refactor: split codemod into separate PR Remove the codemod CLI from this PR to keep the review focused on the core lifecycle change. The codemod will land as a follow-up with bug fixes from review. Runtime detection (constructor autoStart throw + exports().start() trap) stays since it's part of the migration story. Signed-off-by: MarioCadenas * fix: add debug logging for onPluginsReady lifecycle hook Log when the onPluginsReady hook starts and completes to aid debugging in development mode. Signed-off-by: MarioCadenas * fix: rename codemod reference to on-plugins-ready Update runtime detection error messages to point users to `npx appkit codemod on-plugins-ready` to match the hook name. Signed-off-by: MarioCadenas * feat: add appkit codemod on-plugins-ready CLI Add `appkit codemod on-plugins-ready` command that auto-migrates old autoStart/extend/start patterns to the new onPluginsReady callback. Handles Pattern A (.then chain) and Pattern B (await + imperative). Bails with a warning for complex cases (non-server usage of appkit handle, multiple extend calls). Includes fixes from review: - Use raw slice offset for brace matching (not trimmed) - Use brace-aware parsing for .catch() handlers with arrow functions - Bail out when multiple .extend() calls detected in Pattern B Signed-off-by: MarioCadenas * fix: handle async .then callbacks and full start() statement removal - Remove entire `await appkit.server.start();` statements instead of just stripping `.start()` (which left dangling `await appkit.server;`) - Detect async callbacks in .then() and emit `async onPluginsReady` so await expressions inside the body remain valid Signed-off-by: MarioCadenas --------- Signed-off-by: MarioCadenas Co-authored-by: MarioCadenas --- .../shared/src/cli/commands/codemod/index.ts | 17 + .../cli/commands/codemod/on-plugins-ready.ts | 484 ++++++++++++++++++ .../tests/fixtures/already-migrated.input.ts | 10 + .../autostart-true-with-port.input.ts | 5 + .../fixtures/pattern-a-arrow-catch.input.ts | 15 + .../fixtures/pattern-a-with-catch.input.ts | 15 + .../codemod/tests/fixtures/pattern-a.input.ts | 13 + .../tests/fixtures/pattern-b-complex.input.ts | 15 + .../fixtures/pattern-b-multi-extend.input.ts | 15 + .../codemod/tests/fixtures/pattern-b.input.ts | 13 + .../codemod/tests/on-plugins-ready.test.ts | 129 +++++ packages/shared/src/cli/index.ts | 2 + packages/shared/tsconfig.json | 2 +- 13 files changed, 734 insertions(+), 1 deletion(-) create mode 100644 packages/shared/src/cli/commands/codemod/index.ts create mode 100644 packages/shared/src/cli/commands/codemod/on-plugins-ready.ts create mode 100644 packages/shared/src/cli/commands/codemod/tests/fixtures/already-migrated.input.ts create mode 100644 packages/shared/src/cli/commands/codemod/tests/fixtures/autostart-true-with-port.input.ts create mode 100644 packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-arrow-catch.input.ts create mode 100644 packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-with-catch.input.ts create mode 100644 packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a.input.ts create mode 100644 packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-complex.input.ts create mode 100644 packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-multi-extend.input.ts create mode 100644 packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b.input.ts create mode 100644 packages/shared/src/cli/commands/codemod/tests/on-plugins-ready.test.ts diff --git a/packages/shared/src/cli/commands/codemod/index.ts b/packages/shared/src/cli/commands/codemod/index.ts new file mode 100644 index 00000000..2f9c160d --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/index.ts @@ -0,0 +1,17 @@ +import { Command } from "commander"; +import { onPluginsReadyCommand } from "./on-plugins-ready"; + +/** + * Parent command for codemod operations. + * Subcommands: + * - on-plugins-ready: Migrate from autoStart/extend/start to onPluginsReady callback + */ +export const codemodCommand = new Command("codemod") + .description("Run codemods to migrate to newer AppKit APIs") + .addCommand(onPluginsReadyCommand) + .addHelpText( + "after", + ` +Examples: + $ appkit codemod on-plugins-ready --write`, + ); diff --git a/packages/shared/src/cli/commands/codemod/on-plugins-ready.ts b/packages/shared/src/cli/commands/codemod/on-plugins-ready.ts new file mode 100644 index 00000000..37faeefe --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/on-plugins-ready.ts @@ -0,0 +1,484 @@ +import fs from "node:fs"; +import path from "node:path"; +import { Lang, parse } from "@ast-grep/napi"; +import { Command } from "commander"; + +const SEARCH_DIRS = ["server", "src", "."]; +const CANDIDATE_NAMES = ["server.ts", "index.ts"]; +const SKIP_DIRS = new Set(["node_modules", "dist", "build", ".git"]); + +function findServerEntryFiles(rootDir: string): string[] { + const results: string[] = []; + + for (const dir of SEARCH_DIRS) { + const absDir = path.resolve(rootDir, dir); + if (!fs.existsSync(absDir)) continue; + + const files = + dir === "." + ? CANDIDATE_NAMES.map((n) => path.join(absDir, n)).filter(fs.existsSync) + : findTsFiles(absDir); + + for (const file of files) { + const content = fs.readFileSync(file, "utf-8"); + if ( + content.includes("createApp") && + content.includes("@databricks/appkit") + ) { + results.push(file); + } + } + } + + return [...new Set(results)]; +} + +function findTsFiles(dir: string, files: string[] = []): string[] { + let entries: fs.Dirent[]; + try { + entries = fs.readdirSync(dir, { withFileTypes: true }); + } catch { + return files; + } + + for (const entry of entries) { + const fullPath = path.join(dir, entry.name); + if (entry.isDirectory()) { + if (SKIP_DIRS.has(entry.name)) continue; + findTsFiles(fullPath, files); + } else if (entry.isFile() && entry.name.endsWith(".ts")) { + files.push(fullPath); + } + } + + return files; +} + +function isAlreadyMigrated(content: string): boolean { + const ast = parse(Lang.TypeScript, content); + const root = ast.root(); + return root.findAll("createApp({ $$$PROPS })").some((match) => { + const text = match.text(); + return /\bonPluginsReady\s*[(:]/.test(text); + }); +} + +/** + * Find the index of the matching closing delimiter for an opening one. + * Supports (), {}, and []. + */ +function findMatchingClose(content: string, openIdx: number): number { + const open = content[openIdx]; + const closeMap: Record = { + "(": ")", + "{": "}", + "[": "]", + }; + const close = closeMap[open]; + if (!close) return -1; + + let depth = 1; + let i = openIdx + 1; + while (i < content.length && depth > 0) { + const ch = content[i]; + if (ch === open) depth++; + else if (ch === close) depth--; + + // skip string literals + if (ch === '"' || ch === "'" || ch === "`") { + i = skipString(content, i); + continue; + } + i++; + } + return depth === 0 ? i - 1 : -1; +} + +function skipString(content: string, startIdx: number): number { + const quote = content[startIdx]; + let i = startIdx + 1; + while (i < content.length) { + if (content[i] === "\\") { + i += 2; + continue; + } + if (content[i] === quote) return i + 1; + i++; + } + return i; +} + +function stripAutoStartFromServerCalls(content: string): string { + return content.replace( + /server\(\{([^}]*)\}\)/g, + (_fullMatch, propsStr: string) => { + const cleaned = propsStr + .replace(/autoStart\s*:\s*(true|false)\s*,?\s*/g, "") + .replace(/,\s*$/, "") + .trim(); + if (!cleaned) return "server()"; + return `server({ ${cleaned} })`; + }, + ); +} + +interface MigrationResult { + migrated: boolean; + content: string; + warnings: string[]; +} + +function migratePatternA(content: string): MigrationResult { + const warnings: string[] = []; + + // Find createApp(...).then( + const createAppIdx = content.indexOf("createApp("); + if (createAppIdx === -1) return { migrated: false, content, warnings }; + + // Find the opening paren of createApp( + const configOpenParen = content.indexOf("(", createAppIdx); + const configCloseParen = findMatchingClose(content, configOpenParen); + if (configCloseParen === -1) return { migrated: false, content, warnings }; + + // Check for .then( after the closing paren + const afterCreateApp = content.slice(configCloseParen + 1); + const thenMatch = afterCreateApp.match(/^\s*\.then\s*\(/); + if (!thenMatch) return { migrated: false, content, warnings }; + + const thenStart = configCloseParen + 1 + afterCreateApp.indexOf(".then"); + const thenOpenParen = content.indexOf("(", thenStart + 4); + const thenCloseParen = findMatchingClose(content, thenOpenParen); + if (thenCloseParen === -1) return { migrated: false, content, warnings }; + + // Extract the callback inside .then(...) + const thenRaw = content.slice(thenOpenParen + 1, thenCloseParen); + const thenInner = thenRaw.trim(); + + // Parse callback: (param) => { body } or async (param) => { body } + const callbackMatch = thenInner.match( + /^(?:async\s+)?\(\s*(\w+)\s*\)\s*=>\s*\{/, + ); + if (!callbackMatch) return { migrated: false, content, warnings }; + + const paramName = callbackMatch[1]; + const bodyOpenBrace = thenOpenParen + 1 + thenRaw.indexOf("{"); + const bodyCloseBrace = findMatchingClose(content, bodyOpenBrace); + if (bodyCloseBrace === -1) return { migrated: false, content, warnings }; + + let callbackBody = content.slice(bodyOpenBrace + 1, bodyCloseBrace).trim(); + + // Remove entire statements that are just .start() calls (e.g. `await appkit.server.start();`) + callbackBody = callbackBody + .replace(/^\s*(?:await\s+)?\w+\.server\s*\.\s*start\(\s*\)\s*;?\s*$/gm, "") + .replace(/\n\s*\.start\(\s*\)\s*;?/g, ";") + .replace(/\.start\(\s*\)/g, "") + .replace(/\n\s*\n\s*\n/g, "\n\n") + .trim(); + + // Clean up trailing semicolons + if (callbackBody.endsWith(";")) { + // fine + } else if (!callbackBody.endsWith("}")) { + callbackBody += ";"; + } + + // Detect if the callback was async + const isAsync = /^async\s/.test(thenInner.trim()); + + // Check for .catch() after .then(...) using brace-aware parsing + const afterThenClose = content.slice(thenCloseParen + 1); + const catchPatternMatch = afterThenClose.match(/^\s*(?:\)\s*)?\.catch\s*\(/); + + let catchSuffix: string; + let consumeAfterThen: number; + + if (catchPatternMatch) { + const catchOpenParen = thenCloseParen + 1 + catchPatternMatch[0].length - 1; + const catchCloseParen = findMatchingClose(content, catchOpenParen); + if (catchCloseParen !== -1) { + const catchArg = content + .slice(catchOpenParen + 1, catchCloseParen) + .trim(); + catchSuffix = `.catch(${catchArg})`; + consumeAfterThen = catchCloseParen + 1 - (thenCloseParen + 1); + } else { + catchSuffix = ".catch(console.error)"; + consumeAfterThen = 0; + } + } else { + catchSuffix = ".catch(console.error)"; + consumeAfterThen = 0; + } + + // Build the onPluginsReady property + const configStr = content.slice(configOpenParen + 1, configCloseParen); + const lastBraceIdx = configStr.lastIndexOf("}"); + if (lastBraceIdx === -1) return { migrated: false, content, warnings }; + + const beforeLastBrace = configStr.slice(0, lastBraceIdx).trimEnd(); + const needsComma = beforeLastBrace.endsWith(",") ? "" : ","; + + // Indent the body properly + const bodyLines = callbackBody.split("\n"); + const indentedBody = bodyLines + .map((line) => ` ${line.trimStart()}`) + .join("\n"); + + const asyncPrefix = isAsync ? "async " : ""; + const onPluginsReadyProp = `${needsComma}\n ${asyncPrefix}onPluginsReady(${paramName}) {\n${indentedBody}\n },`; + const newConfig = `${beforeLastBrace}${onPluginsReadyProp}\n}`; + + // Build the replacement + const endIdx = thenCloseParen + 1 + consumeAfterThen; + // Consume trailing ) ; and whitespace + let finalEnd = endIdx; + const trailing = content.slice(finalEnd).match(/^\s*\)?\s*;?\s*/); + if (trailing) finalEnd += trailing[0].length; + + const newContent = + content.slice(0, createAppIdx) + + `createApp(${newConfig})${catchSuffix};` + + content.slice(finalEnd); + + return { migrated: true, content: newContent, warnings }; +} + +function migratePatternB(content: string): MigrationResult { + const warnings: string[] = []; + + // Match: const/let varName = await createApp({...}); + const awaitPattern = /(?:const|let)\s+(\w+)\s*=\s*await\s+createApp\s*\(/; + + const match = content.match(awaitPattern); + if (!match) return { migrated: false, content, warnings }; + + const varName = match[1]; + const matchIdx = content.indexOf(match[0]); + + // Find the createApp(...) closing paren + const configOpenParen = matchIdx + match[0].length - 1; + const configCloseParen = findMatchingClose(content, configOpenParen); + if (configCloseParen === -1) return { migrated: false, content, warnings }; + + // Find the semicolon after the createApp call + const afterCall = content.slice(configCloseParen + 1); + const semiMatch = afterCall.match(/^\s*;/); + const createAppEnd = + configCloseParen + 1 + (semiMatch ? semiMatch[0].length : 0); + + // Find all uses of varName after the createApp call + const afterCreateApp = content.slice(createAppEnd); + const varUsagePattern = new RegExp(`\\b${varName}\\.(\\w+)`, "g"); + + const usages: { plugin: string; index: number }[] = []; + for (const usageMatch of afterCreateApp.matchAll(varUsagePattern)) { + usages.push({ plugin: usageMatch[1], index: usageMatch.index }); + } + + // Check for non-server usage + const nonServerUsage = usages.filter((u) => u.plugin !== "server"); + if (nonServerUsage.length > 0) { + warnings.push( + `Found additional usage of '${varName}' handle outside server.extend/start. Please migrate manually.`, + ); + return { migrated: false, content, warnings }; + } + + // Find the extend call(s) and start call in the after-createApp region + const extendPattern = new RegExp( + `\\b${varName}\\.server\\.extend\\s*\\(`, + "g", + ); + const startPattern = new RegExp( + `(?:await\\s+)?${varName}\\.server\\.start\\s*\\(\\s*\\)\\s*;`, + ); + + const extendMatches = [...afterCreateApp.matchAll(extendPattern)]; + if (extendMatches.length > 1) { + warnings.push( + `Found ${extendMatches.length} server.extend() calls. Please migrate manually.`, + ); + return { migrated: false, content, warnings }; + } + + const extendExec = extendMatches[0] ?? null; + const startExec = startPattern.exec(afterCreateApp); + + if (!startExec) return { migrated: false, content, warnings }; + + // Extract the extend call's argument + let extendArg = ""; + let extendFullStatement = ""; + if (extendExec) { + const extendOpenParen = + createAppEnd + extendExec.index + extendExec[0].length - 1; + const extendCloseParen = findMatchingClose(content, extendOpenParen); + if (extendCloseParen !== -1) { + extendArg = content.slice(extendOpenParen + 1, extendCloseParen).trim(); + // Find the full statement including trailing semicolon + const stmtStart = createAppEnd + extendExec.index; + let stmtEnd = extendCloseParen + 1; + const afterExtend = content.slice(stmtEnd); + const trailingSemi = afterExtend.match(/^\s*;/); + if (trailingSemi) stmtEnd += trailingSemi[0].length; + extendFullStatement = content.slice(stmtStart, stmtEnd); + } + } + + const startFullStatement = startExec[0]; + + // Build the onPluginsReady callback + const configStr = content.slice(configOpenParen + 1, configCloseParen); + const lastBraceIdx = configStr.lastIndexOf("}"); + if (lastBraceIdx === -1) return { migrated: false, content, warnings }; + + const beforeLastBrace = configStr.slice(0, lastBraceIdx).trimEnd(); + const needsComma = beforeLastBrace.endsWith(",") ? "" : ","; + + let onPluginsReadyProp: string; + if (extendArg) { + onPluginsReadyProp = + `${needsComma}\n onPluginsReady(${varName}) {\n` + + ` ${varName}.server.extend(${extendArg});\n` + + " },"; + } else { + onPluginsReadyProp = ""; + } + + const newConfig = `${beforeLastBrace}${onPluginsReadyProp}\n}`; + const newCreateApp = `await createApp(${newConfig});`; + + // Replace: remove const declaration, replace with plain await, remove extend + start + let result = content.slice(0, matchIdx) + newCreateApp; + let remaining = afterCreateApp; + + if (extendFullStatement) { + remaining = remaining.replace(extendFullStatement, ""); + } + remaining = remaining.replace(startFullStatement, ""); + + // Clean up consecutive blank lines + remaining = remaining.replace(/\n\s*\n\s*\n/g, "\n\n"); + + result += remaining; + + return { migrated: true, content: result, warnings }; +} + +export function migrateFile(filePath: string): MigrationResult { + const original = fs.readFileSync(filePath, "utf-8"); + + if (isAlreadyMigrated(original)) { + return { + migrated: false, + content: original, + warnings: ["Already migrated -- no changes needed."], + }; + } + + const content = stripAutoStartFromServerCalls(original); + const allWarnings: string[] = []; + + // Try Pattern A first + const patternA = migratePatternA(content); + if (patternA.migrated) { + allWarnings.push(...patternA.warnings); + return { + migrated: true, + content: patternA.content, + warnings: allWarnings, + }; + } + allWarnings.push(...patternA.warnings); + + // Try Pattern B + const patternB = migratePatternB(content); + if (patternB.migrated) { + allWarnings.push(...patternB.warnings); + return { + migrated: true, + content: patternB.content, + warnings: allWarnings, + }; + } + allWarnings.push(...patternB.warnings); + + // Check if autoStart was stripped (content changed but no pattern matched) + if (content !== original) { + return { migrated: true, content, warnings: allWarnings }; + } + + return { migrated: false, content: original, warnings: allWarnings }; +} + +function runCodemod(options: { path?: string; write?: boolean }) { + const rootDir = process.cwd(); + const write = options.write ?? false; + + let files: string[]; + if (options.path) { + const absPath = path.resolve(rootDir, options.path); + if (!fs.existsSync(absPath)) { + console.error(`File not found: ${absPath}`); + process.exit(1); + } + files = [absPath]; + } else { + files = findServerEntryFiles(rootDir); + } + + if (files.length === 0) { + console.log("No files found importing createApp from @databricks/appkit."); + console.log("Use --path to specify a file explicitly."); + process.exit(0); + } + + let hasChanges = false; + + for (const file of files) { + const relPath = path.relative(rootDir, file); + const result = migrateFile(file); + + for (const warning of result.warnings) { + console.log(` ${relPath}: ${warning}`); + } + + if (!result.migrated) { + if (result.warnings.length === 0) { + console.log(` ${relPath}: No migration needed.`); + } + continue; + } + + hasChanges = true; + + if (write) { + fs.writeFileSync(file, result.content, "utf-8"); + console.log(` ${relPath}: Migrated successfully.`); + } else { + console.log(`\n--- ${relPath} (dry run) ---`); + console.log(result.content); + console.log("---"); + } + } + + if (hasChanges && !write) { + console.log("\nDry run complete. Run with --write to apply changes."); + } +} + +export const onPluginsReadyCommand = new Command("on-plugins-ready") + .description( + "Migrate createApp usage from autoStart/extend/start pattern to onPluginsReady callback", + ) + .option("--path ", "Path to the server entry file to migrate") + .option("--write", "Apply changes (default: dry-run)", false) + .addHelpText( + "after", + ` +Examples: + $ appkit codemod on-plugins-ready # dry-run, auto-detect files + $ appkit codemod on-plugins-ready --write # apply changes + $ appkit codemod on-plugins-ready --path server.ts # migrate a specific file`, + ) + .action(runCodemod); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/already-migrated.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/already-migrated.input.ts new file mode 100644 index 00000000..ab1cf6d0 --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/already-migrated.input.ts @@ -0,0 +1,10 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +createApp({ + plugins: [server(), analytics({})], + onPluginsReady(appkit) { + appkit.server.extend((app) => { + app.get("/custom", (_req, res) => res.json({ ok: true })); + }); + }, +}).catch(console.error); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/autostart-true-with-port.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/autostart-true-with-port.input.ts new file mode 100644 index 00000000..96b70a4f --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/autostart-true-with-port.input.ts @@ -0,0 +1,5 @@ +import { createApp, server } from "@databricks/appkit"; + +createApp({ + plugins: [server({ autoStart: true, port: 3000 })], +}).catch(console.error); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-arrow-catch.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-arrow-catch.input.ts new file mode 100644 index 00000000..b6ae8c8a --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-arrow-catch.input.ts @@ -0,0 +1,15 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}) + .then((appkit) => { + appkit.server + .extend((app) => { + app.get("/custom", (_req, res) => { + res.json({ ok: true }); + }); + }) + .start(); + }) + .catch((err) => console.error(err)); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-with-catch.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-with-catch.input.ts new file mode 100644 index 00000000..faa04d5e --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a-with-catch.input.ts @@ -0,0 +1,15 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}) + .then((appkit) => { + appkit.server + .extend((app) => { + app.get("/custom", (_req, res) => { + res.json({ ok: true }); + }); + }) + .start(); + }) + .catch(console.error); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a.input.ts new file mode 100644 index 00000000..73523d6a --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-a.input.ts @@ -0,0 +1,13 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}).then((appkit) => { + appkit.server + .extend((app) => { + app.get("/custom", (_req, res) => { + res.json({ ok: true }); + }); + }) + .start(); +}); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-complex.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-complex.input.ts new file mode 100644 index 00000000..c1fb25fa --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-complex.input.ts @@ -0,0 +1,15 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +const appkit = await createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}); + +appkit.server.extend((app) => { + app.get("/custom", (_req, res) => { + res.json({ ok: true }); + }); +}); + +appkit.analytics.query("SELECT 1"); + +await appkit.server.start(); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-multi-extend.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-multi-extend.input.ts new file mode 100644 index 00000000..dded09f2 --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b-multi-extend.input.ts @@ -0,0 +1,15 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +const appkit = await createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}); + +appkit.server.extend((app) => { + app.get("/one", (_req, res) => res.json({ route: 1 })); +}); + +appkit.server.extend((app) => { + app.get("/two", (_req, res) => res.json({ route: 2 })); +}); + +await appkit.server.start(); diff --git a/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b.input.ts b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b.input.ts new file mode 100644 index 00000000..b56c0048 --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/fixtures/pattern-b.input.ts @@ -0,0 +1,13 @@ +import { analytics, createApp, server } from "@databricks/appkit"; + +const appkit = await createApp({ + plugins: [server({ autoStart: false }), analytics({})], +}); + +appkit.server.extend((app) => { + app.get("/custom", (_req, res) => { + res.json({ ok: true }); + }); +}); + +await appkit.server.start(); diff --git a/packages/shared/src/cli/commands/codemod/tests/on-plugins-ready.test.ts b/packages/shared/src/cli/commands/codemod/tests/on-plugins-ready.test.ts new file mode 100644 index 00000000..299e8f1d --- /dev/null +++ b/packages/shared/src/cli/commands/codemod/tests/on-plugins-ready.test.ts @@ -0,0 +1,129 @@ +import fs from "node:fs"; +import path from "node:path"; +import { describe, expect, test } from "vitest"; +import { migrateFile } from "../on-plugins-ready"; + +const fixturesDir = path.join(__dirname, "fixtures"); + +function readFixture(name: string): string { + return fs.readFileSync(path.join(fixturesDir, name), "utf-8"); +} + +describe("onPluginsReady-callback codemod", () => { + describe("Pattern A: .then() chain", () => { + test("migrates .then chain without .catch, adds .catch(console.error)", () => { + const fixturePath = path.join(fixturesDir, "pattern-a.input.ts"); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(true); + expect(result.content).toContain("onPluginsReady(appkit)"); + expect(result.content).not.toContain(".then("); + expect(result.content).not.toContain(".start()"); + expect(result.content).not.toContain("autoStart"); + expect(result.content).toContain(".catch(console.error)"); + expect(result.content).toContain("server()"); + }); + + test("migrates .then chain with existing .catch, preserves it", () => { + const fixturePath = path.join( + fixturesDir, + "pattern-a-with-catch.input.ts", + ); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(true); + expect(result.content).toContain("onPluginsReady(appkit)"); + expect(result.content).not.toContain(".then("); + expect(result.content).not.toContain(".start()"); + expect(result.content).toContain(".catch(console.error)"); + expect(result.content).toContain("server()"); + }); + + test("preserves the extend callback content", () => { + const fixturePath = path.join(fixturesDir, "pattern-a.input.ts"); + const result = migrateFile(fixturePath); + + expect(result.content).toContain('app.get("/custom"'); + expect(result.content).toContain("res.json({ ok: true })"); + }); + + test("preserves arrow function .catch handler with parens", () => { + const fixturePath = path.join( + fixturesDir, + "pattern-a-arrow-catch.input.ts", + ); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(true); + expect(result.content).toContain(".catch((err) => console.error(err))"); + expect(result.content).not.toContain(".then("); + expect(result.content).not.toContain(".start()"); + }); + }); + + describe("Pattern B: await + imperative", () => { + test("migrates await pattern with extend + start", () => { + const fixturePath = path.join(fixturesDir, "pattern-b.input.ts"); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(true); + expect(result.content).toContain("onPluginsReady(appkit)"); + expect(result.content).not.toContain("appkit.server.start()"); + expect(result.content).not.toContain("autoStart"); + expect(result.content).toContain("server()"); + }); + + test("bails out when non-server usage of appkit handle exists", () => { + const fixturePath = path.join(fixturesDir, "pattern-b-complex.input.ts"); + const result = migrateFile(fixturePath); + + expect(result.warnings.some((w) => w.includes("migrate manually"))).toBe( + true, + ); + expect(result.content).toContain("server()"); + expect(result.content).not.toContain("autoStart"); + }); + + test("bails out when multiple .extend() calls exist", () => { + const fixturePath = path.join( + fixturesDir, + "pattern-b-multi-extend.input.ts", + ); + const result = migrateFile(fixturePath); + + expect(result.warnings.some((w) => w.includes("migrate manually"))).toBe( + true, + ); + expect(result.content).toContain("server()"); + expect(result.content).not.toContain("autoStart"); + }); + }); + + describe("autoStart stripping", () => { + test("strips autoStart: true and preserves other config", () => { + const fixturePath = path.join( + fixturesDir, + "autostart-true-with-port.input.ts", + ); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(true); + expect(result.content).not.toContain("autoStart"); + expect(result.content).toContain("port: 3000"); + expect(result.content).toContain("server({"); + }); + }); + + describe("idempotency", () => { + test("no-ops on already migrated file", () => { + const fixturePath = path.join(fixturesDir, "already-migrated.input.ts"); + const result = migrateFile(fixturePath); + + expect(result.migrated).toBe(false); + expect(result.warnings.some((w) => w.includes("Already migrated"))).toBe( + true, + ); + expect(result.content).toBe(readFixture("already-migrated.input.ts")); + }); + }); +}); diff --git a/packages/shared/src/cli/index.ts b/packages/shared/src/cli/index.ts index 71f09e6f..4d0ed65b 100644 --- a/packages/shared/src/cli/index.ts +++ b/packages/shared/src/cli/index.ts @@ -4,6 +4,7 @@ import { readFileSync } from "node:fs"; import { dirname, join } from "node:path"; import { fileURLToPath } from "node:url"; import { Command } from "commander"; +import { codemodCommand } from "./commands/codemod/index.js"; import { docsCommand } from "./commands/docs.js"; import { generateTypesCommand } from "./commands/generate-types.js"; import { lintCommand } from "./commands/lint.js"; @@ -26,5 +27,6 @@ cmd.addCommand(generateTypesCommand); cmd.addCommand(lintCommand); cmd.addCommand(docsCommand); cmd.addCommand(pluginCommand); +cmd.addCommand(codemodCommand); cmd.parse(); diff --git a/packages/shared/tsconfig.json b/packages/shared/tsconfig.json index 4a6e68b3..5e195c3b 100644 --- a/packages/shared/tsconfig.json +++ b/packages/shared/tsconfig.json @@ -8,5 +8,5 @@ } }, "include": ["src/**/*"], - "exclude": ["node_modules", "dist"] + "exclude": ["node_modules", "dist", "src/**/fixtures"] } From 323c8cc56334c6f5a248759020be61ba706758e5 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Tue, 21 Apr 2026 19:42:39 +0200 Subject: [PATCH 09/23] feat(appkit): shared agent types and LLM adapter implementations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Foundation layer for the agents feature. Adds the portable type surface that every downstream layer builds on, plus three LLM adapter implementations so the agents plugin (later PR) can target whatever the user has. ### Shared agent types `packages/shared/src/agent.ts` — no behavior, just the type vocabulary: `AgentAdapter`, `AgentEvent`, `AgentInput`, `AgentRunContext`, `AgentToolDefinition`, `Message`, `Thread`, `ThreadStore`, `ToolAnnotations`, `ToolCall`, `ToolProvider`, `ResponseStreamEvent`. Exported from the shared barrel. ### Adapters - `packages/appkit/src/agents/databricks.ts` — `DatabricksAdapter`: streams OpenAI-compatible completions against a Databricks Model Serving endpoint (raw fetch + SSE, no vendor SDKs). - `packages/appkit/src/agents/vercel-ai.ts` — `VercelAIAdapter`: wraps any Vercel AI SDK `streamText` call. Maps Vercel SDK events to AppKit `AgentEvent`s and tool calls. - `packages/appkit/src/agents/langchain.ts` — `LangChainAdapter`: wraps any LangChain `Runnable` (AgentExecutor, compiled LangGraph, etc.). Subscribes to `streamEvents(v2)` and maps to `AgentEvent`s. Each adapter is self-contained and independently testable. ### Package plumbing - Subpath exports `@databricks/appkit/agents/{databricks,vercel-ai,langchain}` so consumers pick only the adapter they want. - `@langchain/core` and `ai` declared as optional peer dependencies. - `@ai-sdk/openai`, `@langchain/core`, `ai` added as devDeps for tests. - `tsdown.config.ts` emits the three adapter entry points alongside the main bundle. ### Test plan - 24 adapter tests (Databricks: 16, Vercel AI: 4, LangChain: 4) passing - Full appkit vitest suite: 1154 tests passing - Typecheck clean - Build clean, publint clean Signed-off-by: MarioCadenas ### MVP polish - **LangChain adapter `callId` correlation fix.** The previous implementation emitted `tool_call` with the LLM-provided `tc.id ?? tc.name` and `tool_result` with LangChain's internal `event.run_id` — these never matched, breaking every Responses-API client that pairs calls by `call_id`. The adapter now records a `run_id → callId` mapping at `on_tool_start` (matching the accumulated tool_call by name) and resolves it at `on_tool_end`. A deterministic `lc___` fallback id prevents collisions when the same tool is called multiple times in one turn without a model-provided id. Adds three regression tests covering happy-path correlation, duplicate-name disambiguation, and the no-accumulator-match fallback. - **Adapter docstring cleanup.** The four `@example` blocks in `databricks.ts`, `langchain.ts`, and `vercel-ai.ts` referenced a fictional `appkit.agent.registerAgent("assistant", adapter)` API that has never existed. Replaced with real usage via `createApp({ plugins: [agents({ agents: { assistant: createAgent( { model: adapter }) } })] })`. Signed-off-by: MarioCadenas --- knip.json | 9 +- packages/appkit/package.json | 35 +- packages/appkit/src/agents/databricks.ts | 775 ++++++++++++++++++ packages/appkit/src/agents/langchain.ts | 292 +++++++ .../src/agents/tests/databricks.test.ts | 486 +++++++++++ .../appkit/src/agents/tests/langchain.test.ts | 366 +++++++++ .../appkit/src/agents/tests/vercel-ai.test.ts | 190 +++++ packages/appkit/src/agents/vercel-ai.ts | 138 ++++ packages/appkit/tsdown.config.ts | 7 +- packages/shared/src/agent.ts | 212 +++++ packages/shared/src/index.ts | 1 + pnpm-lock.yaml | 161 ++++ 12 files changed, 2668 insertions(+), 4 deletions(-) create mode 100644 packages/appkit/src/agents/databricks.ts create mode 100644 packages/appkit/src/agents/langchain.ts create mode 100644 packages/appkit/src/agents/tests/databricks.test.ts create mode 100644 packages/appkit/src/agents/tests/langchain.test.ts create mode 100644 packages/appkit/src/agents/tests/vercel-ai.test.ts create mode 100644 packages/appkit/src/agents/vercel-ai.ts create mode 100644 packages/shared/src/agent.ts diff --git a/knip.json b/knip.json index b777d8c2..13a43187 100644 --- a/knip.json +++ b/knip.json @@ -7,7 +7,9 @@ "docs" ], "workspaces": { - "packages/appkit": {}, + "packages/appkit": { + "ignoreDependencies": ["@langchain/core", "ai"] + }, "packages/appkit-ui": { "ignoreDependencies": ["tailwindcss", "tw-animate-css"] } @@ -17,6 +19,11 @@ "**/*.example.tsx", "**/*.css", "packages/appkit/src/plugins/vector-search/**", + "packages/appkit/src/plugin/index.ts", + "packages/appkit/src/plugins/agents/index.ts", + "packages/appkit/src/plugins/agents/tools/index.ts", + "packages/appkit/src/plugins/agents/from-plugin.ts", + "packages/appkit/src/plugins/agents/load-agents.ts", "template/**", "tools/**", "docs/**" diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 0e0af1ba..3d5b1ddf 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -29,6 +29,18 @@ "development": "./src/index.ts", "default": "./dist/index.js" }, + "./agents/vercel-ai": { + "development": "./src/agents/vercel-ai.ts", + "default": "./dist/agents/vercel-ai.js" + }, + "./agents/langchain": { + "development": "./src/agents/langchain.ts", + "default": "./dist/agents/langchain.js" + }, + "./agents/databricks": { + "development": "./src/agents/databricks.ts", + "default": "./dist/agents/databricks.js" + }, "./type-generator": { "types": "./dist/type-generator/index.d.ts", "development": "./src/type-generator/index.ts", @@ -77,14 +89,30 @@ "semver": "7.7.3", "shared": "workspace:*", "vite": "npm:rolldown-vite@7.1.14", - "ws": "8.18.3" + "ws": "8.18.3", + "zod": "^4.0.0" + }, + "peerDependencies": { + "@langchain/core": ">=0.3.0", + "ai": ">=4.0.0" + }, + "peerDependenciesMeta": { + "ai": { + "optional": true + }, + "@langchain/core": { + "optional": true + } }, "devDependencies": { + "@ai-sdk/openai": "4.0.0-beta.27", + "@langchain/core": "^1.1.39", "@types/express": "4.17.25", "@types/json-schema": "7.0.15", "@types/pg": "8.16.0", "@types/ws": "8.18.1", - "@vitejs/plugin-react": "5.1.1" + "@vitejs/plugin-react": "5.1.1", + "ai": "7.0.0-beta.76" }, "overrides": { "vite": "npm:rolldown-vite@7.1.14" @@ -93,6 +121,9 @@ "publishConfig": { "exports": { ".": "./dist/index.js", + "./agents/vercel-ai": "./dist/agents/vercel-ai.js", + "./agents/langchain": "./dist/agents/langchain.js", + "./agents/databricks": "./dist/agents/databricks.js", "./dist/shared/src/plugin": "./dist/shared/src/plugin.d.ts", "./type-generator": "./dist/type-generator/index.js", "./package.json": "./package.json" diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts new file mode 100644 index 00000000..6cc98ca4 --- /dev/null +++ b/packages/appkit/src/agents/databricks.ts @@ -0,0 +1,775 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, +} from "shared"; +import { stream as servingStream } from "../connectors/serving/client"; + +/** + * Transport shim: given an OpenAI-compatible request body, returns the raw + * SSE byte stream from the serving endpoint. Injected at construction time so + * callers can swap in the workspace SDK (factory paths), a bare `fetch` + * (the raw constructor), or a test fake. + */ +type StreamBody = ( + body: Record, + signal?: AbortSignal, +) => Promise>; + +/** + * Escape-hatch options: provide an `endpointUrl` + `authenticate()` and the + * adapter uses a bare `fetch()` to call it. Useful for tests and for pointing + * the adapter at non-workspace endpoints (reverse proxies, mocks). + */ +interface RawFetchAdapterOptions { + endpointUrl: string; + authenticate: () => Promise>; + maxSteps?: number; + maxTokens?: number; +} + +/** + * Preferred options: caller provides the transport function directly. + * The `fromServingEndpoint` / `fromModelServing` factories use this to route + * through `connectors/serving/stream`, which centralises URL encoding, auth + * via the SDK's `apiClient.request`, and any future retries/telemetry. + */ +interface StreamBodyAdapterOptions { + streamBody: StreamBody; + maxSteps?: number; + maxTokens?: number; +} + +type DatabricksAdapterOptions = + | RawFetchAdapterOptions + | StreamBodyAdapterOptions; + +function isStreamBodyOptions( + o: DatabricksAdapterOptions, +): o is StreamBodyAdapterOptions { + return "streamBody" in o; +} + +/** + * Minimal structural shape consumed by `connectors/serving/stream`. We avoid + * importing the concrete `WorkspaceClient` type to keep the adapter free of a + * compile-time dependency on the SDK. + */ +interface WorkspaceClientLike { + apiClient: { + request(options: Record): Promise; + }; +} + +interface ServingEndpointOptions { + workspaceClient: WorkspaceClientLike; + endpointName: string; + maxSteps?: number; + maxTokens?: number; +} + +interface ModelServingOptions { + maxSteps?: number; + maxTokens?: number; + workspaceClient?: WorkspaceClientLike; +} + +/** + * Structural shape for {@link createDatabricksModel}. The Vercel AI helper + * builds its own `fetch` override and so needs the workspace config surface + * (host, authenticate, ensureResolved) rather than the `apiClient` used by + * the adapter factories. + */ +interface WorkspaceConfig { + host?: string; + authenticate(headers: Headers): Promise; + ensureResolved(): Promise; +} + +interface VercelDatabricksModelOptions { + workspaceClient: { config: WorkspaceConfig }; + endpointName: string; +} + +interface OpenAIMessage { + role: "system" | "user" | "assistant" | "tool"; + content: string | null; + tool_calls?: OpenAIToolCall[]; + tool_call_id?: string; +} + +interface OpenAIToolCall { + id: string; + type: "function"; + function: { name: string; arguments: string }; +} + +interface OpenAITool { + type: "function"; + function: { + name: string; + description: string; + parameters: unknown; + }; +} + +interface DeltaToolCall { + index: number; + id?: string; + type?: string; + function?: { name?: string; arguments?: string }; +} + +/** + * Adapter that talks directly to Databricks Model Serving `/invocations` endpoint. + * + * No dependency on the Vercel AI SDK or LangChain. Uses raw `fetch()` to POST + * OpenAI-compatible payloads and parses the SSE stream itself. Calls + * `authenticate()` per-request so tokens are always fresh. + * + * Handles both structured `tool_calls` responses and text-based tool call + * fallback parsing for models that output tool calls as text. + * + * @example Using the factory (recommended) + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { DatabricksAdapter } from "@databricks/appkit/agents/databricks"; + * import { WorkspaceClient } from "@databricks/sdk-experimental"; + * + * const adapter = DatabricksAdapter.fromServingEndpoint({ + * workspaceClient: new WorkspaceClient({}), + * endpointName: "my-endpoint", + * }); + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: adapter, + * }), + * }, + * }), + * ], + * }); + * ``` + * + * @example Using the raw constructor + * ```ts + * const adapter = new DatabricksAdapter({ + * endpointUrl: "https://host/serving-endpoints/my-endpoint/invocations", + * authenticate: async () => ({ Authorization: `Bearer ${token}` }), + * }); + * ``` + */ +export class DatabricksAdapter implements AgentAdapter { + private streamBody: StreamBody; + private maxSteps: number; + private maxTokens: number; + + constructor(options: DatabricksAdapterOptions) { + this.maxSteps = options.maxSteps ?? 10; + this.maxTokens = options.maxTokens ?? 4096; + + if (isStreamBodyOptions(options)) { + this.streamBody = options.streamBody; + } else { + const { endpointUrl, authenticate } = options; + this.streamBody = async (body, signal) => { + const authHeaders = await authenticate(); + const response = await fetch(endpointUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + ...authHeaders, + }, + body: JSON.stringify(body), + signal, + }); + if (!response.ok) { + const errorText = await response.text().catch(() => "Unknown error"); + throw new Error( + `Databricks API error (${response.status}): ${errorText}`, + ); + } + if (!response.body) throw new Error("No response body"); + return response.body; + }; + } + } + + /** + * Creates a DatabricksAdapter for a Databricks Model Serving endpoint. + * + * Routes through the shared `connectors/serving/stream` helper, which + * delegates to the SDK's `apiClient.request({ raw: true })`. That gives the + * adapter centralised URL encoding + authentication with the rest of the + * serving surface — no bespoke `fetch()` + `authenticate()` plumbing. + */ + static async fromServingEndpoint( + options: ServingEndpointOptions, + ): Promise { + const { workspaceClient, endpointName, maxSteps, maxTokens } = options; + return new DatabricksAdapter({ + streamBody: (body) => + // Cast through the structural shape: the connector types + // `workspaceClient` as the SDK's concrete `WorkspaceClient`, but we + // only need `apiClient.request`. + servingStream( + workspaceClient as unknown as Parameters[0], + endpointName, + body, + ), + maxSteps, + maxTokens, + }); + } + + /** + * Creates a DatabricksAdapter from a Model Serving endpoint name. + * Auto-creates a WorkspaceClient internally. Reads the endpoint name + * from the argument or the `DATABRICKS_AGENT_ENDPOINT` env var. + * + * @example + * ```ts + * // Reads endpoint from DATABRICKS_AGENT_ENDPOINT env var + * const adapter = await DatabricksAdapter.fromModelServing(); + * + * // Explicit endpoint + * const adapter = await DatabricksAdapter.fromModelServing("my-endpoint"); + * + * // With options + * const adapter = await DatabricksAdapter.fromModelServing("my-endpoint", { + * maxSteps: 5, + * maxTokens: 2048, + * }); + * ``` + */ + static async fromModelServing( + endpointName?: string, + options?: ModelServingOptions, + ): Promise { + const resolvedEndpoint = + endpointName ?? process.env.DATABRICKS_AGENT_ENDPOINT; + + if (!resolvedEndpoint) { + throw new Error( + "No endpoint name provided and DATABRICKS_AGENT_ENDPOINT env var is not set. " + + "Pass an endpoint name or set the environment variable.", + ); + } + + let workspaceClient: WorkspaceClientLike | undefined = + options?.workspaceClient; + if (!workspaceClient) { + const sdk = await import("@databricks/sdk-experimental"); + workspaceClient = new sdk.WorkspaceClient( + {}, + ) as unknown as WorkspaceClientLike; + } + + return DatabricksAdapter.fromServingEndpoint({ + workspaceClient, + endpointName: resolvedEndpoint, + maxSteps: options?.maxSteps, + maxTokens: options?.maxTokens, + }); + } + + async *run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + // Databricks API requires tool names to match [a-zA-Z0-9_-]. + // Our tool names use dots (e.g. "analytics.query"), so we swap dots + // for double-underscores in the wire format and map back on receipt. + const nameToWire = new Map(); + const wireToName = new Map(); + for (const tool of input.tools) { + const wire = tool.name.replace(/\./g, "__"); + nameToWire.set(tool.name, wire); + wireToName.set(wire, tool.name); + } + + const tools = this.buildTools(input.tools, nameToWire); + const messages = this.buildMessages(input.messages); + + yield { type: "status", status: "running" }; + + for (let step = 0; step < this.maxSteps; step++) { + if (context.signal?.aborted) break; + + const { text, toolCalls } = yield* this.streamCompletion( + messages, + tools, + context, + ); + + if (toolCalls.length === 0) { + const parsed = parseTextToolCalls(text); + if (parsed.length > 0) { + yield* this.executeToolCalls(parsed, messages, context); + continue; + } + break; + } + + messages.push({ + role: "assistant", + content: text || null, + tool_calls: toolCalls, + }); + + for (const tc of toolCalls) { + const wireName = tc.function.name; + const originalName = wireToName.get(wireName) ?? wireName; + let args: unknown; + try { + args = JSON.parse(tc.function.arguments); + } catch { + args = {}; + } + + yield { type: "tool_call", callId: tc.id, name: originalName, args }; + + try { + const result = await context.executeTool(originalName, args); + const resultStr = + typeof result === "string" ? result : JSON.stringify(result); + + yield { type: "tool_result", callId: tc.id, result }; + + messages.push({ + role: "tool", + content: resultStr, + tool_call_id: tc.id, + }); + } catch (error) { + const errMsg = + error instanceof Error ? error.message : "Tool execution failed"; + + yield { + type: "tool_result", + callId: tc.id, + result: null, + error: errMsg, + }; + + messages.push({ + role: "tool", + content: JSON.stringify({ error: errMsg }), + tool_call_id: tc.id, + }); + } + } + } + } + + private async *streamCompletion( + messages: OpenAIMessage[], + tools: OpenAITool[], + context: AgentRunContext, + ): AsyncGenerator< + AgentEvent, + { text: string; toolCalls: OpenAIToolCall[] }, + unknown + > { + const body: Record = { + messages, + stream: true, + max_tokens: this.maxTokens, + }; + + if (tools.length > 0) { + body.tools = tools; + } + + const responseBody = await this.streamBody(body, context.signal); + const reader = responseBody.getReader(); + + const decoder = new TextDecoder(); + let buffer = ""; + let fullText = ""; + const toolCallAccumulator = new Map< + number, + { id: string; name: string; arguments: string } + >(); + + try { + while (true) { + if (context.signal?.aborted) break; + + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() ?? ""; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed.startsWith("data: ")) continue; + const data = trimmed.slice(6); + if (data === "[DONE]") continue; + + let parsed: any; + try { + parsed = JSON.parse(data); + } catch { + continue; + } + + const delta = parsed.choices?.[0]?.delta; + if (!delta) continue; + + if (delta.content) { + fullText += delta.content; + yield { type: "message_delta" as const, content: delta.content }; + } + + if (delta.tool_calls) { + for (const tc of delta.tool_calls as DeltaToolCall[]) { + const existing = toolCallAccumulator.get(tc.index); + if (existing) { + if (tc.function?.arguments) { + existing.arguments += tc.function.arguments; + } + } else { + toolCallAccumulator.set(tc.index, { + id: tc.id ?? `call_${tc.index}`, + name: tc.function?.name ?? "", + arguments: tc.function?.arguments ?? "", + }); + } + } + } + } + } + } finally { + reader.releaseLock(); + } + + const toolCalls: OpenAIToolCall[] = Array.from( + toolCallAccumulator.values(), + ).map((tc) => ({ + id: tc.id, + type: "function" as const, + function: { name: tc.name, arguments: tc.arguments || "{}" }, + })); + + return { text: fullText, toolCalls }; + } + + private async *executeToolCalls( + calls: Array<{ name: string; args: unknown }>, + messages: OpenAIMessage[], + context: AgentRunContext, + ): AsyncGenerator { + const toolCallObjs: OpenAIToolCall[] = calls.map((c, i) => ({ + id: `text_call_${i}`, + type: "function" as const, + function: { + name: c.name, + arguments: JSON.stringify(c.args), + }, + })); + + messages.push({ + role: "assistant", + content: null, + tool_calls: toolCallObjs, + }); + + for (const tc of toolCallObjs) { + const name = tc.function.name; + let args: unknown; + try { + args = JSON.parse(tc.function.arguments); + } catch { + args = {}; + } + + yield { type: "tool_call", callId: tc.id, name, args }; + + try { + const result = await context.executeTool(name, args); + const resultStr = + typeof result === "string" ? result : JSON.stringify(result); + + yield { type: "tool_result", callId: tc.id, result }; + + messages.push({ + role: "tool", + content: resultStr, + tool_call_id: tc.id, + }); + } catch (error) { + const errMsg = + error instanceof Error ? error.message : "Tool execution failed"; + + yield { + type: "tool_result", + callId: tc.id, + result: null, + error: errMsg, + }; + + messages.push({ + role: "tool", + content: JSON.stringify({ error: errMsg }), + tool_call_id: tc.id, + }); + } + } + } + + private buildMessages(messages: AgentInput["messages"]): OpenAIMessage[] { + return messages.map((m) => ({ + role: m.role as OpenAIMessage["role"], + content: m.content, + })); + } + + private buildTools( + definitions: AgentToolDefinition[], + nameToWire: Map, + ): OpenAITool[] { + return definitions.map((def) => ({ + type: "function" as const, + function: { + name: nameToWire.get(def.name) ?? def.name, + description: def.description, + parameters: def.parameters, + }, + })); + } +} + +// --------------------------------------------------------------------------- +// Vercel AI SDK helper +// --------------------------------------------------------------------------- + +/** + * Creates a Vercel AI-compatible model backed by a Databricks Model Serving endpoint. + * + * Use with `VercelAIAdapter` to get the Vercel AI SDK ecosystem (useChat, etc.) + * while targeting a Databricks `/invocations` endpoint. + * + * Handles URL rewriting (`/chat/completions` -> `/invocations`), per-request + * auth refresh, and tool name sanitization (dots -> double-underscores). + * + * Requires the `ai` and `@ai-sdk/openai` packages as peer dependencies. + * + * @example + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { createDatabricksModel } from "@databricks/appkit/agents/databricks"; + * import { VercelAIAdapter } from "@databricks/appkit/agents/vercel-ai"; + * import { WorkspaceClient } from "@databricks/sdk-experimental"; + * + * const model = await createDatabricksModel({ + * workspaceClient: new WorkspaceClient({}), + * endpointName: "my-endpoint", + * }); + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: new VercelAIAdapter({ model }), + * }), + * }, + * }), + * ], + * }); + * ``` + */ +export async function createDatabricksModel( + options: VercelDatabricksModelOptions, +): Promise { + let createOpenAI: any; + try { + const mod = await import("@ai-sdk/openai"); + createOpenAI = mod.createOpenAI; + } catch { + throw new Error( + "createDatabricksModel requires '@ai-sdk/openai' as a dependency. Install it with: npm install @ai-sdk/openai ai", + ); + } + + const config = options.workspaceClient.config; + await config.ensureResolved(); + + const baseURL = `${config.host}/serving-endpoints/${options.endpointName}`; + + const provider = createOpenAI({ + baseURL, + apiKey: "databricks", + fetch: async (url: string | URL | Request, init?: RequestInit) => { + const rewritten = String(url).replace( + "/chat/completions", + "/invocations", + ); + + const headers = new Headers(init?.headers); + await config.authenticate(headers); + + let body = init?.body; + if (typeof body === "string") { + body = rewriteToolNamesOutbound(body); + } + + const response = await globalThis.fetch(rewritten, { + ...init, + headers, + body, + }); + + if ( + !response.body || + !response.headers.get("content-type")?.includes("text/event-stream") + ) { + return response; + } + + const transformed = response.body.pipeThrough( + createToolNameRewriteStream(), + ); + + return new Response(transformed, { + status: response.status, + statusText: response.statusText, + headers: response.headers, + }); + }, + }); + + return provider(options.endpointName); +} + +/** + * Rewrites tool names in outbound request body (dots -> double-underscores). + */ +function rewriteToolNamesOutbound(body: string): string { + try { + const parsed = JSON.parse(body); + if (parsed.tools) { + for (const tool of parsed.tools) { + if (tool.function?.name) { + tool.function.name = tool.function.name.replace(/\./g, "__"); + } + } + } + return JSON.stringify(parsed); + } catch { + return body; + } +} + +/** + * Creates a TransformStream that rewrites tool names in SSE response chunks + * (double-underscores -> dots). + */ +function createToolNameRewriteStream(): TransformStream< + Uint8Array, + Uint8Array +> { + const decoder = new TextDecoder(); + const encoder = new TextEncoder(); + + return new TransformStream({ + transform(chunk, controller) { + const text = decoder.decode(chunk, { stream: true }); + const rewritten = text.replace( + /"name"\s*:\s*"([a-zA-Z0-9_-]+)"/g, + (match, name: string) => { + if (name.includes("__")) { + return match.replace(name, name.replace(/__/g, ".")); + } + return match; + }, + ); + controller.enqueue(encoder.encode(rewritten)); + }, + }); +} + +// --------------------------------------------------------------------------- +// Text-based tool call parsing (fallback) +// --------------------------------------------------------------------------- + +/** + * Parses text-based tool calls from model output. + * + * Handles two formats: + * 1. Llama native: `[{"name": "tool_name", "parameters": {"arg": "val"}}]` + * 2. Python-style: `[tool_name(arg1='val1', arg2='val2')]` + */ +export function parseTextToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const trimmed = text.trim(); + + const jsonResult = tryParseLlamaJsonToolCalls(trimmed); + if (jsonResult.length > 0) return jsonResult; + + const pyResult = tryParsePythonStyleToolCalls(trimmed); + if (pyResult.length > 0) return pyResult; + + return []; +} + +function tryParseLlamaJsonToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const match = text.match(/\[\s*\{[\s\S]*\}\s*\]/); + if (!match) return []; + + try { + const parsed = JSON.parse(match[0]); + if (!Array.isArray(parsed)) return []; + + return parsed + .filter( + (item: any) => + typeof item === "object" && + item !== null && + typeof item.name === "string", + ) + .map((item: any) => ({ + name: item.name, + args: item.parameters ?? item.arguments ?? item.args ?? {}, + })); + } catch { + return []; + } +} + +function tryParsePythonStyleToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const pattern = /\[?([a-zA-Z_][\w.]*)\(([^)]*)\)\]?/g; + const results: Array<{ name: string; args: unknown }> = []; + + for (const match of text.matchAll(pattern)) { + const name = match[1]; + const argsStr = match[2]; + + const args: Record = {}; + const argPattern = /(\w+)\s*=\s*(?:'([^']*)'|"([^"]*)"|(\S+))/g; + for (const argMatch of argsStr.matchAll(argPattern)) { + const key = argMatch[1]; + const value = argMatch[2] ?? argMatch[3] ?? argMatch[4]; + args[key] = value; + } + + results.push({ name, args }); + } + + return results; +} diff --git a/packages/appkit/src/agents/langchain.ts b/packages/appkit/src/agents/langchain.ts new file mode 100644 index 00000000..77961bcf --- /dev/null +++ b/packages/appkit/src/agents/langchain.ts @@ -0,0 +1,292 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, +} from "shared"; + +/** + * Adapter bridging LangChain/LangGraph to the AppKit agent protocol. + * + * Accepts any LangChain `Runnable` (e.g. AgentExecutor, compiled LangGraph) + * and maps `streamEvents` v2 to `AgentEvent`. + * + * Requires `@langchain/core` as an optional peer dependency. + * + * @example + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { LangChainAdapter } from "@databricks/appkit/agents/langchain"; + * import { ChatOpenAI } from "@langchain/openai"; + * import { createReactAgent } from "@langchain/langgraph/prebuilt"; + * + * const model = new ChatOpenAI({ model: "gpt-4o" }); + * const agentExecutor = createReactAgent({ llm: model, tools: [] }); + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: new LangChainAdapter({ runnable: agentExecutor }), + * }), + * }, + * }), + * ], + * }); + * ``` + */ +export class LangChainAdapter implements AgentAdapter { + private runnable: any; + + constructor(options: { runnable: any }) { + this.runnable = options.runnable; + } + + async *run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + const lcTools = await import("@langchain/core/tools"); + const DynamicStructuredTool = lcTools.DynamicStructuredTool; + const zodModule: any = await import("zod"); + const z = zodModule.z; + + const tools = this.buildTools( + input.tools, + context, + DynamicStructuredTool, + z, + ); + + const messages = input.messages.map((m) => ({ + role: m.role, + content: m.content, + })); + + yield { type: "status", status: "running" }; + + const runnableWithTools = + tools.length > 0 && typeof this.runnable.bindTools === "function" + ? this.runnable.bindTools(tools) + : this.runnable; + + const stream = await runnableWithTools.streamEvents( + { messages }, + { + version: "v2", + signal: input.signal, + }, + ); + + // Tool-call chunks from `on_chat_model_stream` come in fragments keyed by + // the model's `index`. We accumulate them and flush on `on_tool_start`. + const toolCallAccumulator = new Map< + number, + { id: string; name: string; arguments: string } + >(); + // LangChain's `on_tool_end` reports the tool via `event.run_id` (its own + // internal identifier), not the model-provided tool_call id. To keep the + // `call_id` on `tool_call` and `tool_result` matching (so clients can + // correlate them via the Responses API `call_id` field), we record the + // mapping from `run_id` to the original model `tc.id` at `on_tool_start` + // and look it up at `on_tool_end`. + const runIdToCallId = new Map(); + // Counter for fallback callIds when the model does not provide `tc.id`. + let fallbackIdx = 0; + + for await (const event of stream) { + if (context.signal?.aborted) break; + + switch (event.event) { + case "on_chat_model_stream": { + const chunk = event.data?.chunk; + if (chunk?.content && typeof chunk.content === "string") { + yield { type: "message_delta", content: chunk.content }; + } + if (chunk?.tool_call_chunks) { + for (const tc of chunk.tool_call_chunks) { + const idx = tc.index ?? 0; + const existing = toolCallAccumulator.get(idx); + if (existing) { + if (tc.args) existing.arguments += tc.args; + // Later chunks for the same tool call may carry the id/name + // that the first chunk lacked. + if (tc.id && !existing.id.startsWith("lc_")) + existing.id = tc.id; + if (tc.name && !existing.name) existing.name = tc.name; + } else if (tc.name || tc.id) { + toolCallAccumulator.set(idx, { + // Use a deterministic fallback that cannot collide if the + // same tool is called twice in one turn without a model id. + id: + tc.id ?? `lc_${tc.name ?? "tool"}_${idx}_${++fallbackIdx}`, + name: tc.name ?? "", + arguments: tc.args ?? "", + }); + } + } + } + break; + } + + case "on_tool_start": { + // Find the accumulated tool_call that matches this tool invocation + // by name so we can record the run_id → callId mapping and yield + // the `tool_call` event with a callId that will match the + // subsequent `tool_result`. + const toolName = event.name; + let matched: { id: string; name: string; arguments: string } | null = + null; + let matchedKey: number | null = null; + for (const [key, tc] of toolCallAccumulator) { + if (tc.name === toolName) { + matched = tc; + matchedKey = key; + break; + } + } + + if (matched) { + const runId = event.run_id; + if (typeof runId === "string" && runId.length > 0) { + runIdToCallId.set(runId, matched.id); + } + let args: unknown; + try { + args = JSON.parse(matched.arguments || "{}"); + } catch { + args = {}; + } + yield { + type: "tool_call" as const, + callId: matched.id, + name: matched.name, + args, + }; + if (matchedKey !== null) toolCallAccumulator.delete(matchedKey); + } else { + // Fallback: no accumulated tool_call matched this name. Emit a + // tool_call anyway with run_id as the correlating key so the + // client at least sees a call/result pair. + const runId = event.run_id ?? `lc_${toolName}_${++fallbackIdx}`; + runIdToCallId.set(runId, runId); + yield { + type: "tool_call" as const, + callId: runId, + name: toolName ?? "", + args: event.data?.input ?? {}, + }; + } + break; + } + + case "on_tool_end": { + const output = event.data?.output; + const runId = event.run_id; + const callId = + (typeof runId === "string" && runIdToCallId.get(runId)) || runId; + if (typeof runId === "string") runIdToCallId.delete(runId); + yield { + type: "tool_result", + callId, + result: output?.content ?? output, + }; + break; + } + + case "on_chain_end": { + const output = event.data?.output; + if (output?.content && typeof output.content === "string") { + yield { type: "message", content: output.content }; + } + break; + } + } + } + } + + /** + * Converts AgentToolDefinitions into LangChain DynamicStructuredTool instances. + * + * JSON Schema properties are mapped to Zod schemas using a lightweight + * recursive converter for the subset of JSON Schema types that tools use. + */ + private buildTools( + definitions: AgentToolDefinition[], + context: AgentRunContext, + DynamicStructuredTool: any, + z: any, + ): any[] { + return definitions.map( + (def) => + new DynamicStructuredTool({ + name: def.name, + description: def.description, + schema: jsonSchemaToZod(def.parameters, z), + func: async (args: unknown) => { + try { + const result = await context.executeTool(def.name, args); + return typeof result === "string" + ? result + : JSON.stringify(result); + } catch (error) { + return `Error: ${error instanceof Error ? error.message : "Tool execution failed"}`; + } + }, + }), + ); + } +} + +/** + * Lightweight JSON Schema (subset) to Zod converter. + * Handles the types commonly used in tool parameters. + */ +function jsonSchemaToZod(schema: any, z: any): any { + if (!schema) return z.object({}); + + switch (schema.type) { + case "object": { + const shape: Record = {}; + const properties = schema.properties ?? {}; + const required = new Set(schema.required ?? []); + + for (const [key, prop] of Object.entries(properties)) { + let field = jsonSchemaToZod(prop, z); + if (!required.has(key)) { + field = field.optional(); + } + if ((prop as any).description) { + field = field.describe((prop as any).description); + } + shape[key] = field; + } + return z.object(shape); + } + + case "array": + return z.array(jsonSchemaToZod(schema.items ?? {}, z)); + + case "string": { + let s = z.string(); + if (schema.enum) s = z.enum(schema.enum); + return s; + } + + case "number": + case "integer": + return z.number(); + + case "boolean": + return z.boolean(); + + case "null": + return z.null(); + + default: + return z.any(); + } +} diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts new file mode 100644 index 00000000..8a835094 --- /dev/null +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -0,0 +1,486 @@ +import type { AgentEvent, AgentToolDefinition, Message } from "shared"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { DatabricksAdapter, parseTextToolCalls } from "../databricks"; + +const mockAuthenticate = vi + .fn() + .mockResolvedValue({ Authorization: "Bearer test-token" }); + +function sseChunk(data: string): string { + return `data: ${data}\n\n`; +} + +function textDelta(content: string): string { + return sseChunk( + JSON.stringify({ + choices: [{ delta: { content } }], + }), + ); +} + +function toolCallDelta( + index: number, + id: string | undefined, + name: string | undefined, + args: string, +): string { + return sseChunk( + JSON.stringify({ + choices: [ + { + delta: { + tool_calls: [ + { + index, + ...(id && { id }), + ...(name && { type: "function" }), + function: { + ...(name && { name }), + arguments: args, + }, + }, + ], + }, + }, + ], + }), + ); +} + +function createReadableStream(chunks: string[]): ReadableStream { + const encoder = new TextEncoder(); + let i = 0; + return new ReadableStream({ + pull(controller) { + if (i < chunks.length) { + controller.enqueue(encoder.encode(chunks[i])); + i++; + } else { + controller.close(); + } + }, + }); +} + +function mockFetch(chunks: string[]): typeof globalThis.fetch { + return vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream(chunks), + text: () => Promise.resolve(""), + }); +} + +function createTestMessages(): Message[] { + return [{ id: "1", role: "user", content: "Hello", createdAt: new Date() }]; +} + +function createTestTools(): AgentToolDefinition[] { + return [ + { + name: "analytics.query", + description: "Run SQL", + parameters: { + type: "object", + properties: { query: { type: "string" } }, + required: ["query"], + }, + }, + ]; +} + +function createAdapter(overrides?: { + endpointUrl?: string; + authenticate?: () => Promise>; + maxSteps?: number; + maxTokens?: number; +}) { + return new DatabricksAdapter({ + endpointUrl: + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + authenticate: mockAuthenticate, + ...overrides, + }); +} + +describe("DatabricksAdapter", () => { + const originalFetch = globalThis.fetch; + + afterEach(() => { + globalThis.fetch = originalFetch; + mockAuthenticate.mockClear(); + }); + + test("streams text deltas from the model", async () => { + globalThis.fetch = mockFetch([ + textDelta("Hello"), + textDelta(" world"), + sseChunk("[DONE]"), + ]); + + const adapter = createAdapter(); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events[0]).toEqual({ type: "status", status: "running" }); + expect(events[1]).toEqual({ type: "message_delta", content: "Hello" }); + expect(events[2]).toEqual({ type: "message_delta", content: " world" }); + }); + + test("calls authenticate() per request for fresh headers", async () => { + globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(mockAuthenticate).toHaveBeenCalledTimes(1); + + const [, init] = (globalThis.fetch as any).mock.calls[0]; + expect(init.headers.Authorization).toBe("Bearer test-token"); + }); + + test("handles structured tool calls and executes them", async () => { + const executeTool = vi.fn().mockResolvedValue([{ trip_id: 1 }]); + + let callCount = 0; + globalThis.fetch = vi.fn().mockImplementation(() => { + callCount++; + if (callCount === 1) { + return Promise.resolve({ + ok: true, + body: createReadableStream([ + toolCallDelta(0, "call_1", "analytics__query", ""), + toolCallDelta(0, undefined, undefined, '{"query":'), + toolCallDelta(0, undefined, undefined, '"SELECT 1"}'), + sseChunk("[DONE]"), + ]), + }); + } + return Promise.resolve({ + ok: true, + body: createReadableStream([ + textDelta("Here are the results"), + sseChunk("[DONE]"), + ]), + }); + }); + + const adapter = createAdapter(); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool }, + )) { + events.push(event); + } + + expect(events).toContainEqual({ + type: "tool_call", + callId: "call_1", + name: "analytics.query", + args: { query: "SELECT 1" }, + }); + + expect(executeTool).toHaveBeenCalledWith("analytics.query", { + query: "SELECT 1", + }); + + expect(events).toContainEqual( + expect.objectContaining({ + type: "tool_result", + callId: "call_1", + result: [{ trip_id: 1 }], + }), + ); + + expect(events).toContainEqual({ + type: "message_delta", + content: "Here are the results", + }); + + // authenticate() called once per streamCompletion + expect(mockAuthenticate).toHaveBeenCalledTimes(2); + }); + + test("respects maxSteps limit", async () => { + globalThis.fetch = vi.fn().mockImplementation(() => + Promise.resolve({ + ok: true, + body: createReadableStream([ + toolCallDelta( + 0, + "call_loop", + "analytics__query", + '{"query":"SELECT 1"}', + ), + sseChunk("[DONE]"), + ]), + }), + ); + + const adapter = createAdapter({ maxSteps: 2 }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn().mockResolvedValue("ok") }, + )) { + events.push(event); + } + + expect(globalThis.fetch).toHaveBeenCalledTimes(2); + }); + + test("sends correct request to endpoint URL", async () => { + globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [url, init] = (globalThis.fetch as any).mock.calls[0]; + expect(url).toBe( + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + ); + + const body = JSON.parse(init.body); + expect(body.stream).toBe(true); + expect(body.tools).toHaveLength(1); + expect(body.tools[0].function.name).toBe("analytics__query"); + expect(body.messages[0]).toEqual({ + role: "user", + content: "Hello", + }); + }); + + test("throws on non-ok response", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 401, + text: () => Promise.resolve("Unauthorized"), + }); + + const adapter = createAdapter(); + + await expect(async () => { + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow("Databricks API error (401): Unauthorized"); + }); +}); + +describe("DatabricksAdapter.fromServingEndpoint", () => { + test("routes tool-free chat through apiClient.request with a streaming payload", async () => { + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromServingEndpoint({ + workspaceClient: { apiClient }, + endpointName: "my-model", + }); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(apiClient.request).toHaveBeenCalledTimes(1); + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe("/serving-endpoints/my-model/invocations"); + expect(requestArgs.method).toBe("POST"); + expect(requestArgs.raw).toBe(true); + expect(requestArgs.payload.stream).toBe(true); + // Auth + url encoding are the connector's (and the SDK's) concerns — the + // adapter no longer reaches into the workspace config. + }); + + test("URL-encodes endpoint names with special characters", async () => { + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromServingEndpoint({ + workspaceClient: { apiClient }, + endpointName: "my model/with spaces", + }); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe( + "/serving-endpoints/my%20model%2Fwith%20spaces/invocations", + ); + }); +}); + +describe("DatabricksAdapter.fromModelServing", () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + test("reads endpoint from DATABRICKS_AGENT_ENDPOINT env var", async () => { + process.env.DATABRICKS_AGENT_ENDPOINT = "my-model"; + + vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: vi.fn().mockImplementation(() => ({ + apiClient: { request: vi.fn() }, + })), + })); + + const adapter = await DatabricksAdapter.fromModelServing(); + expect(adapter).toBeInstanceOf(DatabricksAdapter); + }); + + test("throws when no endpoint name and no env var", async () => { + delete process.env.DATABRICKS_AGENT_ENDPOINT; + + await expect(DatabricksAdapter.fromModelServing()).rejects.toThrow( + "No endpoint name provided", + ); + }); + + test("explicit endpoint name takes precedence over env var", async () => { + process.env.DATABRICKS_AGENT_ENDPOINT = "env-model"; + + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromModelServing("explicit-model", { + workspaceClient: { apiClient }, + }); + + expect(adapter).toBeInstanceOf(DatabricksAdapter); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe( + "/serving-endpoints/explicit-model/invocations", + ); + }); +}); + +describe("parseTextToolCalls", () => { + test("parses Llama JSON format", () => { + const text = + '[{"name": "analytics.query", "parameters": {"query": "SELECT 1"}}]'; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { name: "analytics.query", args: { query: "SELECT 1" } }, + ]); + }); + + test("parses multiple Llama JSON tool calls", () => { + const text = + '[{"name": "analytics.query", "parameters": {"query": "SELECT 1"}}, {"name": "files.uploads.list", "parameters": {}}]'; + const result = parseTextToolCalls(text); + + expect(result).toHaveLength(2); + expect(result[0].name).toBe("analytics.query"); + expect(result[1].name).toBe("files.uploads.list"); + }); + + test("parses Python-style tool calls", () => { + const text = + "[analytics.query(query='SELECT * FROM trips ORDER BY date DESC LIMIT 10')]"; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { + name: "analytics.query", + args: { + query: "SELECT * FROM trips ORDER BY date DESC LIMIT 10", + }, + }, + ]); + }); + + test("parses Python-style with multiple args", () => { + const text = + "[files.uploads.read(path='/data/file.csv', encoding='utf-8')]"; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { + name: "files.uploads.read", + args: { path: "/data/file.csv", encoding: "utf-8" }, + }, + ]); + }); + + test("returns empty array for plain text", () => { + expect(parseTextToolCalls("Hello, how can I help?")).toEqual([]); + expect(parseTextToolCalls("")).toEqual([]); + expect(parseTextToolCalls("The answer is 42")).toEqual([]); + }); + + test("handles Llama format with 'arguments' key", () => { + const text = + '[{"name": "lakebase.query", "arguments": {"text": "SELECT 1"}}]'; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { name: "lakebase.query", args: { text: "SELECT 1" } }, + ]); + }); +}); diff --git a/packages/appkit/src/agents/tests/langchain.test.ts b/packages/appkit/src/agents/tests/langchain.test.ts new file mode 100644 index 00000000..3bf1a471 --- /dev/null +++ b/packages/appkit/src/agents/tests/langchain.test.ts @@ -0,0 +1,366 @@ +import type { AgentEvent, AgentToolDefinition, Message } from "shared"; +import { describe, expect, test, vi } from "vitest"; +import { LangChainAdapter } from "../langchain"; + +vi.mock("@langchain/core/tools", () => ({ + DynamicStructuredTool: vi.fn().mockImplementation((config: any) => ({ + name: config.name, + description: config.description, + schema: config.schema, + func: config.func, + })), +})); + +vi.mock("zod", () => { + const createChainable = (base: Record = {}): any => { + const obj: any = { ...base }; + obj.optional = () => createChainable({ ...obj, _optional: true }); + obj.describe = (d: string) => createChainable({ ...obj, _description: d }); + return obj; + }; + + return { + z: { + object: (shape: any) => createChainable({ type: "object", shape }), + string: () => createChainable({ type: "string" }), + number: () => createChainable({ type: "number" }), + boolean: () => createChainable({ type: "boolean" }), + array: (item: any) => createChainable({ type: "array", item }), + enum: (vals: any) => createChainable({ type: "enum", values: vals }), + any: () => createChainable({ type: "any" }), + null: () => createChainable({ type: "null" }), + }, + }; +}); + +function createTestMessages(): Message[] { + return [{ id: "1", role: "user", content: "Hello", createdAt: new Date() }]; +} + +function createTestTools(): AgentToolDefinition[] { + return [ + { + name: "lakebase.query", + description: "Run SQL", + parameters: { + type: "object", + properties: { + text: { type: "string", description: "SQL query" }, + values: { type: "array", items: {} }, + }, + required: ["text"], + }, + }, + ]; +} + +describe("LangChainAdapter", () => { + test("yields status running on start and maps chat_model_stream", async () => { + async function* mockStreamEvents() { + yield { + event: "on_chat_model_stream", + data: { chunk: { content: "Hello" } }, + }; + yield { + event: "on_chat_model_stream", + data: { chunk: { content: " world" } }, + }; + } + + const mockRunnable = { + bindTools: vi.fn().mockReturnValue({ + streamEvents: vi.fn().mockResolvedValue(mockStreamEvents()), + }), + }; + + const adapter = new LangChainAdapter({ runnable: mockRunnable }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events[0]).toEqual({ type: "status", status: "running" }); + expect(events[1]).toEqual({ type: "message_delta", content: "Hello" }); + expect(events[2]).toEqual({ type: "message_delta", content: " world" }); + }); + + test("maps on_tool_end events to tool_result", async () => { + async function* mockStreamEvents() { + yield { + event: "on_tool_end", + run_id: "run-1", + data: { output: { content: "42 rows" } }, + }; + } + + const mockRunnable = { + bindTools: vi.fn().mockReturnValue({ + streamEvents: vi.fn().mockResolvedValue(mockStreamEvents()), + }), + }; + + const adapter = new LangChainAdapter({ runnable: mockRunnable }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events).toContainEqual({ + type: "tool_result", + callId: "run-1", + result: "42 rows", + }); + }); + + test("calls bindTools when tools are provided", async () => { + const streamEvents = vi.fn().mockResolvedValue((async function* () {})()); + const bindTools = vi.fn().mockReturnValue({ streamEvents }); + + const adapter = new LangChainAdapter({ + runnable: { bindTools }, + }); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(bindTools).toHaveBeenCalledTimes(1); + expect(bindTools.mock.calls[0][0]).toHaveLength(1); + expect(bindTools.mock.calls[0][0][0].name).toBe("lakebase.query"); + }); + + test("does not call bindTools when no tools provided", async () => { + const streamEvents = vi.fn().mockResolvedValue((async function* () {})()); + const bindTools = vi.fn().mockReturnValue({ streamEvents }); + + const adapter = new LangChainAdapter({ + runnable: { bindTools, streamEvents }, + }); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: [], + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(bindTools).not.toHaveBeenCalled(); + }); + + test("callId on tool_call and tool_result match across on_tool_start / on_tool_end", async () => { + // Simulates the realistic LangChain stream: first the chat-model emits a + // tool_call_chunk carrying the model-provided id and name; then + // on_tool_start fires with LangChain's own run_id; then on_tool_end + // fires with the same run_id. The adapter must yield tool_call with + // callId = model id, then tool_result with the SAME callId. + async function* mockStreamEvents() { + yield { + event: "on_chat_model_stream", + data: { + chunk: { + content: "", + tool_call_chunks: [ + { + index: 0, + id: "call_abc123", + name: "lakebase.query", + args: '{"text":"SELECT 1"}', + }, + ], + }, + }, + }; + yield { + event: "on_tool_start", + name: "lakebase.query", + run_id: "run-uuid-xyz", + data: { input: { text: "SELECT 1" } }, + }; + yield { + event: "on_tool_end", + name: "lakebase.query", + run_id: "run-uuid-xyz", + data: { output: { content: "42 rows" } }, + }; + } + + const mockRunnable = { + bindTools: vi.fn().mockReturnValue({ + streamEvents: vi.fn().mockResolvedValue(mockStreamEvents()), + }), + }; + + const adapter = new LangChainAdapter({ runnable: mockRunnable }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + const toolCall = events.find((e) => e.type === "tool_call"); + const toolResult = events.find((e) => e.type === "tool_result"); + expect(toolCall).toBeDefined(); + expect(toolResult).toBeDefined(); + // Critical invariant: same callId on the pair, using the model-provided id + // (not LangChain's internal run_id). + expect(toolCall && "callId" in toolCall ? toolCall.callId : undefined).toBe( + "call_abc123", + ); + expect( + toolResult && "callId" in toolResult ? toolResult.callId : undefined, + ).toBe("call_abc123"); + }); + + test("same tool invoked twice in one turn gets unique callIds when model omits id", async () => { + async function* mockStreamEvents() { + yield { + event: "on_chat_model_stream", + data: { + chunk: { + tool_call_chunks: [ + // Two calls to the same tool, no model-provided id on either. + { index: 0, name: "search", args: '{"q":"a"}' }, + { index: 1, name: "search", args: '{"q":"b"}' }, + ], + }, + }, + }; + yield { + event: "on_tool_start", + name: "search", + run_id: "run-1", + data: { input: { q: "a" } }, + }; + yield { + event: "on_tool_end", + name: "search", + run_id: "run-1", + data: { output: { content: "A-result" } }, + }; + yield { + event: "on_tool_start", + name: "search", + run_id: "run-2", + data: { input: { q: "b" } }, + }; + yield { + event: "on_tool_end", + name: "search", + run_id: "run-2", + data: { output: { content: "B-result" } }, + }; + } + + const mockRunnable = { + bindTools: vi.fn().mockReturnValue({ + streamEvents: vi.fn().mockResolvedValue(mockStreamEvents()), + }), + }; + + const adapter = new LangChainAdapter({ runnable: mockRunnable }); + const events: AgentEvent[] = []; + for await (const ev of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(ev); + } + + const calls = events.filter( + (e): e is AgentEvent & { type: "tool_call"; callId: string } => + e.type === "tool_call", + ); + const results = events.filter( + (e): e is AgentEvent & { type: "tool_result"; callId: string } => + e.type === "tool_result", + ); + expect(calls).toHaveLength(2); + expect(results).toHaveLength(2); + expect(calls[0].callId).not.toBe(calls[1].callId); + // Each result correlates with its call. + expect(results[0].callId).toBe(calls[0].callId); + expect(results[1].callId).toBe(calls[1].callId); + }); + + test("falls back to run_id as callId when on_tool_start has no accumulated match", async () => { + async function* mockStreamEvents() { + yield { + event: "on_tool_start", + name: "orphan_tool", + run_id: "run-orphan", + data: { input: { x: 1 } }, + }; + yield { + event: "on_tool_end", + run_id: "run-orphan", + data: { output: { content: "ok" } }, + }; + } + + const mockRunnable = { + bindTools: vi.fn().mockReturnValue({ + streamEvents: vi.fn().mockResolvedValue(mockStreamEvents()), + }), + }; + + const adapter = new LangChainAdapter({ runnable: mockRunnable }); + const events: AgentEvent[] = []; + for await (const ev of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(ev); + } + + const toolCall = events.find((e) => e.type === "tool_call"); + const toolResult = events.find((e) => e.type === "tool_result"); + expect(toolCall && "callId" in toolCall ? toolCall.callId : undefined).toBe( + "run-orphan", + ); + expect( + toolResult && "callId" in toolResult ? toolResult.callId : undefined, + ).toBe("run-orphan"); + }); +}); diff --git a/packages/appkit/src/agents/tests/vercel-ai.test.ts b/packages/appkit/src/agents/tests/vercel-ai.test.ts new file mode 100644 index 00000000..7280c9aa --- /dev/null +++ b/packages/appkit/src/agents/tests/vercel-ai.test.ts @@ -0,0 +1,190 @@ +import type { AgentEvent, AgentToolDefinition, Message } from "shared"; +import { describe, expect, test, vi } from "vitest"; +import { VercelAIAdapter } from "../vercel-ai"; + +vi.mock("ai", () => ({ + streamText: vi.fn(), + jsonSchema: vi.fn((schema: any) => schema), +})); + +function createTestMessages(): Message[] { + return [ + { + id: "1", + role: "user", + content: "Hello", + createdAt: new Date(), + }, + ]; +} + +function createTestTools(): AgentToolDefinition[] { + return [ + { + name: "analytics.query", + description: "Run SQL", + parameters: { + type: "object", + properties: { + query: { type: "string" }, + }, + required: ["query"], + }, + }, + ]; +} + +describe("VercelAIAdapter", () => { + test("yields status running on start", async () => { + const { streamText } = await import("ai"); + + async function* mockStream() { + yield { type: "text-delta", textDelta: "Hi" }; + } + + (streamText as any).mockReturnValue({ + fullStream: mockStream(), + }); + + const adapter = new VercelAIAdapter({ model: {} }); + const events: AgentEvent[] = []; + + const stream = adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { + executeTool: vi.fn(), + }, + ); + + for await (const event of stream) { + events.push(event); + } + + expect(events[0]).toEqual({ type: "status", status: "running" }); + expect(events[1]).toEqual({ type: "message_delta", content: "Hi" }); + }); + + test("maps tool-call and tool-result events", async () => { + const { streamText } = await import("ai"); + + async function* mockStream() { + yield { + type: "tool-call", + toolCallId: "c1", + toolName: "analytics.query", + args: { query: "SELECT 1" }, + }; + yield { + type: "tool-result", + toolCallId: "c1", + result: [{ value: 1 }], + }; + } + + (streamText as any).mockReturnValue({ + fullStream: mockStream(), + }); + + const adapter = new VercelAIAdapter({ model: {} }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events).toContainEqual({ + type: "tool_call", + callId: "c1", + name: "analytics.query", + args: { query: "SELECT 1" }, + }); + + expect(events).toContainEqual({ + type: "tool_result", + callId: "c1", + result: [{ value: 1 }], + }); + }); + + test("maps error events", async () => { + const { streamText } = await import("ai"); + + async function* mockStream() { + yield { type: "error", error: "API rate limited" }; + } + + (streamText as any).mockReturnValue({ + fullStream: mockStream(), + }); + + const adapter = new VercelAIAdapter({ model: {} }); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: [], + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "API rate limited", + }); + }); + + test("builds tools with execute functions that delegate to executeTool", async () => { + const { streamText } = await import("ai"); + + let capturedTools: Record = {}; + + (streamText as any).mockImplementation((opts: any) => { + capturedTools = opts.tools; + return { + fullStream: (async function* () {})(), + }; + }); + + const executeTool = vi.fn().mockResolvedValue({ count: 42 }); + const adapter = new VercelAIAdapter({ model: {} }); + + // Consume the stream to trigger streamText + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool }, + )) { + // drain + } + + expect(capturedTools["analytics.query"]).toBeDefined(); + expect(capturedTools["analytics.query"].description).toBe("Run SQL"); + + const result = await capturedTools["analytics.query"].execute({ + query: "SELECT 1", + }); + expect(executeTool).toHaveBeenCalledWith("analytics.query", { + query: "SELECT 1", + }); + expect(result).toEqual({ count: 42 }); + }); +}); diff --git a/packages/appkit/src/agents/vercel-ai.ts b/packages/appkit/src/agents/vercel-ai.ts new file mode 100644 index 00000000..ea77771a --- /dev/null +++ b/packages/appkit/src/agents/vercel-ai.ts @@ -0,0 +1,138 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, +} from "shared"; + +/** + * Adapter bridging the Vercel AI SDK (`ai` package) to the AppKit agent protocol. + * + * Converts `AgentToolDefinition[]` to Vercel AI tool format and maps + * `streamText().fullStream` events to `AgentEvent`. + * + * Requires `ai` as an optional peer dependency. + * + * @example + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { VercelAIAdapter } from "@databricks/appkit/agents/vercel-ai"; + * import { openai } from "@ai-sdk/openai"; + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: new VercelAIAdapter({ model: openai("gpt-4o") }), + * }), + * }, + * }), + * ], + * }); + * ``` + */ +export class VercelAIAdapter implements AgentAdapter { + private model: any; + + constructor(options: { model: any }) { + this.model = options.model; + } + + async *run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + const { streamText } = await import("ai"); + const { jsonSchema } = await import("ai"); + + const tools = this.buildTools(input.tools, context, jsonSchema); + + const messages = input.messages.map((m) => ({ + role: m.role as "user" | "assistant" | "system", + content: m.content, + })); + + yield { type: "status", status: "running" }; + + const result = streamText({ + model: this.model, + messages, + tools, + maxSteps: 10 as any, + abortSignal: input.signal, + } as any); + + for await (const part of (result as any).fullStream) { + if (context.signal?.aborted) break; + + switch (part.type) { + case "text-delta": + yield { type: "message_delta", content: part.textDelta }; + break; + + case "tool-call": + yield { + type: "tool_call", + callId: part.toolCallId, + name: part.toolName, + args: part.args, + }; + break; + + case "tool-result": + yield { + type: "tool_result", + callId: part.toolCallId, + result: part.result, + }; + break; + + case "reasoning": + if (part.textDelta) { + yield { type: "thinking", content: part.textDelta }; + } + break; + + case "error": + yield { + type: "status", + status: "error", + error: String(part.error), + }; + break; + } + } + } + + private buildTools( + definitions: AgentToolDefinition[], + context: AgentRunContext, + jsonSchema: any, + ): Record { + const tools: Record = {}; + + for (const def of definitions) { + tools[def.name] = { + description: def.description, + parameters: jsonSchema(def.parameters), + execute: async (args: unknown) => { + try { + return await context.executeTool(def.name, args); + } catch (error) { + return { + error: + error instanceof Error + ? error.message + : "Tool execution failed", + }; + } + }, + }; + } + + return tools; + } +} diff --git a/packages/appkit/tsdown.config.ts b/packages/appkit/tsdown.config.ts index 97698714..0e6a4b6b 100644 --- a/packages/appkit/tsdown.config.ts +++ b/packages/appkit/tsdown.config.ts @@ -4,7 +4,12 @@ export default defineConfig([ { publint: true, name: "@databricks/appkit", - entry: "src/index.ts", + entry: [ + "src/index.ts", + "src/agents/vercel-ai.ts", + "src/agents/langchain.ts", + "src/agents/databricks.ts", + ], outDir: "dist", hash: false, format: "esm", diff --git a/packages/shared/src/agent.ts b/packages/shared/src/agent.ts new file mode 100644 index 00000000..c4f76b29 --- /dev/null +++ b/packages/shared/src/agent.ts @@ -0,0 +1,212 @@ +import type { JSONSchema7 } from "json-schema"; + +// --------------------------------------------------------------------------- +// Tool definitions +// --------------------------------------------------------------------------- + +export interface ToolAnnotations { + readOnly?: boolean; + destructive?: boolean; + idempotent?: boolean; + requiresUserContext?: boolean; +} + +export interface AgentToolDefinition { + name: string; + description: string; + parameters: JSONSchema7; + annotations?: ToolAnnotations; +} + +export interface ToolProvider { + getAgentTools(): AgentToolDefinition[]; + executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise; +} + +// --------------------------------------------------------------------------- +// Messages & threads +// --------------------------------------------------------------------------- + +export interface Message { + id: string; + role: "user" | "assistant" | "system" | "tool"; + content: string; + toolCallId?: string; + toolCalls?: ToolCall[]; + createdAt: Date; +} + +export interface ToolCall { + id: string; + name: string; + args: unknown; +} + +export interface Thread { + id: string; + userId: string; + messages: Message[]; + createdAt: Date; + updatedAt: Date; +} + +// --------------------------------------------------------------------------- +// Thread store +// --------------------------------------------------------------------------- + +export interface ThreadStore { + create(userId: string): Promise; + get(threadId: string, userId: string): Promise; + list(userId: string): Promise; + addMessage(threadId: string, userId: string, message: Message): Promise; + delete(threadId: string, userId: string): Promise; +} + +// --------------------------------------------------------------------------- +// Agent events (SSE protocol) +// --------------------------------------------------------------------------- + +export type AgentEvent = + | { type: "message_delta"; content: string } + | { type: "message"; content: string } + | { type: "tool_call"; callId: string; name: string; args: unknown } + | { + type: "tool_result"; + callId: string; + result: unknown; + error?: string; + } + | { type: "thinking"; content: string } + | { + type: "status"; + status: "running" | "waiting" | "complete" | "error"; + error?: string; + } + | { type: "metadata"; data: Record }; + +// --------------------------------------------------------------------------- +// Responses API types (OpenAI-compatible wire format for HTTP boundary) +// Self-contained — no openai package dependency. +// --------------------------------------------------------------------------- + +export interface OutputTextContent { + type: "output_text"; + text: string; +} + +export interface ResponseOutputMessage { + type: "message"; + id: string; + status: "in_progress" | "completed"; + role: "assistant"; + content: OutputTextContent[]; +} + +export interface ResponseFunctionToolCall { + type: "function_call"; + id: string; + call_id: string; + name: string; + arguments: string; +} + +export interface ResponseFunctionCallOutput { + type: "function_call_output"; + id: string; + call_id: string; + output: string; +} + +export type ResponseOutputItem = + | ResponseOutputMessage + | ResponseFunctionToolCall + | ResponseFunctionCallOutput; + +export interface ResponseOutputItemAddedEvent { + type: "response.output_item.added"; + output_index: number; + item: ResponseOutputItem; + sequence_number: number; +} + +export interface ResponseOutputItemDoneEvent { + type: "response.output_item.done"; + output_index: number; + item: ResponseOutputItem; + sequence_number: number; +} + +export interface ResponseTextDeltaEvent { + type: "response.output_text.delta"; + item_id: string; + output_index: number; + content_index: number; + delta: string; + sequence_number: number; +} + +export interface ResponseCompletedEvent { + type: "response.completed"; + sequence_number: number; + response: Record; +} + +export interface ResponseErrorEvent { + type: "error"; + error: string; + sequence_number: number; +} + +export interface ResponseFailedEvent { + type: "response.failed"; + sequence_number: number; +} + +export interface AppKitThinkingEvent { + type: "appkit.thinking"; + content: string; + sequence_number: number; +} + +export interface AppKitMetadataEvent { + type: "appkit.metadata"; + data: Record; + sequence_number: number; +} + +export type ResponseStreamEvent = + | ResponseOutputItemAddedEvent + | ResponseOutputItemDoneEvent + | ResponseTextDeltaEvent + | ResponseCompletedEvent + | ResponseErrorEvent + | ResponseFailedEvent + | AppKitThinkingEvent + | AppKitMetadataEvent; + +// --------------------------------------------------------------------------- +// Adapter contract +// --------------------------------------------------------------------------- + +export interface AgentInput { + messages: Message[]; + tools: AgentToolDefinition[]; + threadId: string; + signal?: AbortSignal; +} + +export interface AgentRunContext { + executeTool: (name: string, args: unknown) => Promise; + signal?: AbortSignal; +} + +export interface AgentAdapter { + run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator; +} diff --git a/packages/shared/src/index.ts b/packages/shared/src/index.ts index 627d70d6..9829729a 100644 --- a/packages/shared/src/index.ts +++ b/packages/shared/src/index.ts @@ -1,3 +1,4 @@ +export * from "./agent"; export * from "./cache"; export * from "./execute"; export * from "./genie"; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9ca11b81..16079b1d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -326,7 +326,16 @@ importers: ws: specifier: 8.18.3 version: 8.18.3(bufferutil@4.0.9) + zod: + specifier: ^4.0.0 + version: 4.1.13 devDependencies: + '@ai-sdk/openai': + specifier: 4.0.0-beta.27 + version: 4.0.0-beta.27(zod@4.1.13) + '@langchain/core': + specifier: ^1.1.39 + version: 1.1.39(@opentelemetry/api@1.9.0)(@opentelemetry/exporter-trace-otlp-proto@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.6.0(@opentelemetry/api@1.9.0))(ws@8.18.3(bufferutil@4.0.9)) '@types/express': specifier: 4.17.25 version: 4.17.25 @@ -342,6 +351,9 @@ importers: '@vitejs/plugin-react': specifier: 5.1.1 version: 5.1.1(rolldown-vite@7.1.14(@types/node@25.2.3)(esbuild@0.25.10)(jiti@2.6.1)(terser@5.44.1)(tsx@4.20.6)(yaml@2.8.2)) + ai: + specifier: 7.0.0-beta.76 + version: 7.0.0-beta.76(zod@4.1.13) packages/appkit-ui: dependencies: @@ -561,16 +573,38 @@ packages: peerDependencies: zod: ^3.25.76 || ^4.1.8 + '@ai-sdk/gateway@4.0.0-beta.43': + resolution: {integrity: sha512-EGQe4If6jt1ZhENmwZn8UAeHbEc7DRiK7ff7dwgfNthwso2hdzLbgXzuTO+W/op+oDFQK1pKiAz5RrPsVQWiew==} + engines: {node: '>=18'} + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + + '@ai-sdk/openai@4.0.0-beta.27': + resolution: {integrity: sha512-7DpXCE4pcc4pVzuEc0whMrQN6Whi14Qsqjx97mLPGjpS6Lff48Zcn2322IFpWuhVJ10hIM1kEZNxUYvXt1O/yg==} + engines: {node: '>=18'} + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + '@ai-sdk/provider-utils@3.0.19': resolution: {integrity: sha512-W41Wc9/jbUVXVwCN/7bWa4IKe8MtxO3EyA0Hfhx6grnmiYlCvpI8neSYWFE0zScXJkgA/YK3BRybzgyiXuu6JA==} engines: {node: '>=18'} peerDependencies: zod: ^3.25.76 || ^4.1.8 + '@ai-sdk/provider-utils@5.0.0-beta.16': + resolution: {integrity: sha512-CyMV5go6libw5WaZ4m7nO0uRLTENxbIODiDrTXJNwxYIBR8p5aCGaxt9oj3prbvNkTt0Srh/Gyw+n2pR9hQ5Pg==} + engines: {node: '>=18'} + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + '@ai-sdk/provider@2.0.0': resolution: {integrity: sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==} engines: {node: '>=18'} + '@ai-sdk/provider@4.0.0-beta.10': + resolution: {integrity: sha512-E2O/LCWjqOxAUfpykQR4xLmcGXySIu6L+wYJjav2xiHu38otPq0qIexgH9ZKulBvBWkrtJ3fxz0kzHDlCBkwng==} + engines: {node: '>=18'} + '@ai-sdk/react@2.0.115': resolution: {integrity: sha512-Etu7gWSEi2dmXss1PoR5CAZGwGShXsF9+Pon1eRO6EmatjYaBMhq1CfHPyYhGzWrint8jJIK2VaAhiMef29qZw==} engines: {node: '>=18'} @@ -1520,6 +1554,9 @@ packages: resolution: {integrity: sha512-hAs5PPKPCQ3/Nha+1fo4A4/gL85fIfxZwHPehsjCJ+BhQH2/yw6/xReuaPA/RfNQr6iz1PcD7BZcE3ctyyl3EA==} cpu: [x64] + '@cfworker/json-schema@4.1.1': + resolution: {integrity: sha512-gAmrUZSGtKc3AiBL71iNWxDsyUC5uMaKKGdvzYsBoTW/xi42JQHl7eKV2OYzCUqvc+D2RCcf7EXY2iCyFIk6og==} + '@chevrotain/cst-dts-gen@11.0.3': resolution: {integrity: sha512-BvIKpRLeS/8UbfxXxgC33xOumsacaeCKAjAeLyOn7Pcp95HiRbrpl14S+9vaZLolnbssPIUuiUd8IvgkRyt6NQ==} @@ -2646,6 +2683,10 @@ packages: '@kwsites/file-exists@1.1.1': resolution: {integrity: sha512-m9/5YGR18lIwxSFDwfE3oA7bWuq9kdau6ugN4H2rJeyhFQZcG9AgSHkQtSD15a8WvTgfz9aikZMrKPHvbpqFiw==} + '@langchain/core@1.1.39': + resolution: {integrity: sha512-DP9c7TREy6iA7HnywstmUAsNyJNYTFpRg2yBfQ+6H0l1HnvQzei9GsQ36GeOLxgRaD3vm9K8urCcawSC7yQpCw==} + engines: {node: '>=20'} + '@leichtgewicht/ip-codec@2.0.5': resolution: {integrity: sha512-Vo+PSpZG2/fmgmiNzYK9qWRh8h/CHrwD0mo1h1DzL4yzHNSfWYujGTYsWGreD000gcgmZ7K4Ys6Tx9TxtsKdDw==} @@ -5166,6 +5207,10 @@ packages: resolution: {integrity: sha512-fnYhv671l+eTTp48gB4zEsTW/YtRgRPnkI2nT7x6qw5rkI1Lq2hTmQIpHPgyThI0znLK+vX2n9XxKdXZ7BUbbw==} engines: {node: '>= 20'} + '@vercel/oidc@3.2.0': + resolution: {integrity: sha512-UycprH3T6n3jH0k44NHMa7pnFHGu/N05MjojYr+Mc6I7obkoLIJujSWwin1pCvdy/eOxrI/l3uDLQsmcrOb4ug==} + engines: {node: '>= 20'} + '@vitejs/plugin-react@5.0.4': resolution: {integrity: sha512-La0KD0vGkVkSk6K+piWDKRUyg8Rl5iAIKRMH0vMJI0Eg47bq1eOxmoObAaQG37WMW9MSyk7Cs8EIWwJC1PtzKA==} engines: {node: ^20.19.0 || >=22.12.0} @@ -5321,6 +5366,12 @@ packages: peerDependencies: zod: ^3.25.76 || ^4.1.8 + ai@7.0.0-beta.76: + resolution: {integrity: sha512-yJMCqsnfUi8jnFOvxmXhjMZd0YVSCLk1E5PZpqmGWynvo3uADt1XADYYYRcj0I9Q2wsL4HbCLAKe01I8aswzJg==} + engines: {node: '>=18'} + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + ajv-formats@2.1.1: resolution: {integrity: sha512-Wx0Kx52hxE7C18hkMEggYlEifqWZtYaRgouJor+WMdPnQyEK13vgEWyVNup7SoeeoLMsr4kf5h6dOW11I15MUA==} peerDependencies: @@ -6453,6 +6504,10 @@ packages: supports-color: optional: true + decamelize@1.2.0: + resolution: {integrity: sha512-z2S+W9X73hAUUki+N+9Za2lBlun89zigOyGrsax+KUQ6wKW4ZoWpEYBkGhQjwAjjDCkWxhY0VKEhk8wzY7F5cA==} + engines: {node: '>=0.10.0'} + decimal.js-light@2.5.1: resolution: {integrity: sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==} @@ -8109,6 +8164,9 @@ packages: joi@17.13.3: resolution: {integrity: sha512-otDA4ldcIx+ZXsKHWmp0YizCweVRZG96J10b0FevjfuncLO1oX59THoAmHkNubYJ+9gWsYsp5k8v4ib6oDv1fA==} + js-tiktoken@1.0.21: + resolution: {integrity: sha512-biOj/6M5qdgx5TKjDnFT1ymSpM5tbd3ylwDtrQvFQSu0Z7bBYko2dF+W/aUkXUPuk6IVpRxk/3Q2sHOzGlS36g==} + js-tokens@4.0.0: resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} @@ -8238,6 +8296,26 @@ packages: resolution: {integrity: sha512-QJv/h939gDpvT+9SiLVlY7tZC3xB2qK57v0J04Sh9wpMb6MP1q8gB21L3WIo8T5P1MSMg3Ep14L7KkDCFG3y4w==} engines: {node: '>=16.0.0'} + langsmith@0.5.18: + resolution: {integrity: sha512-3zuZUWffTHQ+73EAwnodADtf534VNEZUpXr9jC12qyG8/IQuJET7PRsCpTb9wX2lmBspakwLUpqpj3tNm/0bVA==} + peerDependencies: + '@opentelemetry/api': '*' + '@opentelemetry/exporter-trace-otlp-proto': '*' + '@opentelemetry/sdk-trace-base': '*' + openai: '*' + ws: '>=7' + peerDependenciesMeta: + '@opentelemetry/api': + optional: true + '@opentelemetry/exporter-trace-otlp-proto': + optional: true + '@opentelemetry/sdk-trace-base': + optional: true + openai: + optional: true + ws: + optional: true + latest-version@7.0.0: resolution: {integrity: sha512-KvNT4XqAMzdcL6ka6Tl3i2lYeFDgXNCuIX+xNx6ZMVR1dFq+idXd9FLKNMOIx0t9mJ9/HudyX4oZWXZQ0UJHeg==} engines: {node: '>=14.16'} @@ -8910,6 +8988,10 @@ packages: resolution: {integrity: sha512-2eznPJP8z2BFLX50tf0LuODrpINqP1RVIm/CObbTcBRITQgmC/TjcREF1NeTBzIcR5XO/ukWo+YHOjBbFwIupg==} hasBin: true + mustache@4.2.0: + resolution: {integrity: sha512-71ippSywq5Yb7/tVYyGbkBggbU8H3u5Rz56fH60jGFgr8uHwxs+aSKeqmluIVzM0m0kB7xQjKS6qPfd0b2ZoqQ==} + hasBin: true + mute-stream@2.0.0: resolution: {integrity: sha512-WWdIxpyjEn+FhQJQQv9aQAYlHoNVdzIzUySNV1gHUPDSdZJ3yZn7pAAbQcV7B56Mvu881q9FZV+0Vx2xC44VWA==} engines: {node: ^18.17.0 || >=20.5.0} @@ -11463,6 +11545,10 @@ packages: resolution: {integrity: sha512-pMZTvIkT1d+TFGvDOqodOclx0QWkkgi6Tdoa8gC8ffGAAqz9pzPTZWAybbsHHoED/ztMtkv/VoYTYyShUn81hA==} engines: {node: '>= 0.4.0'} + uuid@10.0.0: + resolution: {integrity: sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==} + hasBin: true + uuid@11.1.0: resolution: {integrity: sha512-0/A9rDy9P7cJ+8w1c9WD9V//9Wj15Ce2MPz8Ri6032usz+NfePxx5AcN3bN+r6ZL6jEo066/yNYB3tn4pQEx+A==} hasBin: true @@ -11937,6 +12023,19 @@ snapshots: '@vercel/oidc': 3.0.5 zod: 4.1.13 + '@ai-sdk/gateway@4.0.0-beta.43(zod@4.1.13)': + dependencies: + '@ai-sdk/provider': 4.0.0-beta.10 + '@ai-sdk/provider-utils': 5.0.0-beta.16(zod@4.1.13) + '@vercel/oidc': 3.2.0 + zod: 4.1.13 + + '@ai-sdk/openai@4.0.0-beta.27(zod@4.1.13)': + dependencies: + '@ai-sdk/provider': 4.0.0-beta.10 + '@ai-sdk/provider-utils': 5.0.0-beta.16(zod@4.1.13) + zod: 4.1.13 + '@ai-sdk/provider-utils@3.0.19(zod@4.1.13)': dependencies: '@ai-sdk/provider': 2.0.0 @@ -11944,10 +12043,21 @@ snapshots: eventsource-parser: 3.0.6 zod: 4.1.13 + '@ai-sdk/provider-utils@5.0.0-beta.16(zod@4.1.13)': + dependencies: + '@ai-sdk/provider': 4.0.0-beta.10 + '@standard-schema/spec': 1.1.0 + eventsource-parser: 3.0.6 + zod: 4.1.13 + '@ai-sdk/provider@2.0.0': dependencies: json-schema: 0.4.0 + '@ai-sdk/provider@4.0.0-beta.10': + dependencies: + json-schema: 0.4.0 + '@ai-sdk/react@2.0.115(react@19.2.0)(zod@4.1.13)': dependencies: '@ai-sdk/provider-utils': 3.0.19(zod@4.1.13) @@ -13078,6 +13188,8 @@ snapshots: '@cdxgen/cdxgen-plugins-bin@2.0.2': optional: true + '@cfworker/json-schema@4.1.1': {} + '@chevrotain/cst-dts-gen@11.0.3': dependencies: '@chevrotain/gast': 11.0.3 @@ -14858,6 +14970,26 @@ snapshots: transitivePeerDependencies: - supports-color + '@langchain/core@1.1.39(@opentelemetry/api@1.9.0)(@opentelemetry/exporter-trace-otlp-proto@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.6.0(@opentelemetry/api@1.9.0))(ws@8.18.3(bufferutil@4.0.9))': + dependencies: + '@cfworker/json-schema': 4.1.1 + '@standard-schema/spec': 1.1.0 + ansi-styles: 5.2.0 + camelcase: 6.3.0 + decamelize: 1.2.0 + js-tiktoken: 1.0.21 + langsmith: 0.5.18(@opentelemetry/api@1.9.0)(@opentelemetry/exporter-trace-otlp-proto@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.6.0(@opentelemetry/api@1.9.0))(ws@8.18.3(bufferutil@4.0.9)) + mustache: 4.2.0 + p-queue: 6.6.2 + uuid: 11.1.0 + zod: 4.1.13 + transitivePeerDependencies: + - '@opentelemetry/api' + - '@opentelemetry/exporter-trace-otlp-proto' + - '@opentelemetry/sdk-trace-base' + - openai + - ws + '@leichtgewicht/ip-codec@2.0.5': {} '@mdx-js/mdx@3.1.1': @@ -17555,6 +17687,8 @@ snapshots: '@vercel/oidc@3.0.5': {} + '@vercel/oidc@3.2.0': {} + '@vitejs/plugin-react@5.0.4(vite@7.2.4(@types/node@24.7.2)(jiti@2.6.1)(lightningcss@1.30.2)(terser@5.44.1)(tsx@4.20.6)(yaml@2.8.2))': dependencies: '@babel/core': 7.28.5 @@ -17779,6 +17913,13 @@ snapshots: '@opentelemetry/api': 1.9.0 zod: 4.1.13 + ai@7.0.0-beta.76(zod@4.1.13): + dependencies: + '@ai-sdk/gateway': 4.0.0-beta.43(zod@4.1.13) + '@ai-sdk/provider': 4.0.0-beta.10 + '@ai-sdk/provider-utils': 5.0.0-beta.16(zod@4.1.13) + zod: 4.1.13 + ajv-formats@2.1.1(ajv@8.17.1): optionalDependencies: ajv: 8.17.1 @@ -19053,6 +19194,8 @@ snapshots: dependencies: ms: 2.1.3 + decamelize@1.2.0: {} + decimal.js-light@2.5.1: {} decimal.js@10.6.0: {} @@ -20873,6 +21016,10 @@ snapshots: '@sideway/formula': 3.0.1 '@sideway/pinpoint': 2.0.0 + js-tiktoken@1.0.21: + dependencies: + base64-js: 1.5.1 + js-tokens@4.0.0: {} js-tokens@9.0.1: {} @@ -21027,6 +21174,16 @@ snapshots: vscode-languageserver-textdocument: 1.0.12 vscode-uri: 3.0.8 + langsmith@0.5.18(@opentelemetry/api@1.9.0)(@opentelemetry/exporter-trace-otlp-proto@0.208.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.6.0(@opentelemetry/api@1.9.0))(ws@8.18.3(bufferutil@4.0.9)): + dependencies: + p-queue: 6.6.2 + uuid: 10.0.0 + optionalDependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/exporter-trace-otlp-proto': 0.208.0(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-trace-base': 2.6.0(@opentelemetry/api@1.9.0) + ws: 8.18.3(bufferutil@4.0.9) + latest-version@7.0.0: dependencies: package-json: 8.1.1 @@ -21964,6 +22121,8 @@ snapshots: dns-packet: 5.6.1 thunky: 1.1.0 + mustache@4.2.0: {} + mute-stream@2.0.0: {} nanoid@3.3.11: {} @@ -24753,6 +24912,8 @@ snapshots: utils-merge@1.0.1: {} + uuid@10.0.0: {} + uuid@11.1.0: {} uuid@13.0.0: {} From b45e3383e4f610ea58b92e8557f989a419cec251 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Tue, 21 Apr 2026 19:46:14 +0200 Subject: [PATCH 10/23] feat(appkit): tool primitives and ToolProvider surfaces on core plugins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Second layer of the agents feature. Adds the primitives for defining agent tools and implements them on every core ToolProvider plugin. - `tool(config)` — inline function tools backed by a Zod schema. Auto- generates JSON Schema for the LLM via `z.toJSONSchema()` (stripping the top-level `$schema` annotation that Gemini rejects), runtime- validates tool-call arguments, returns an LLM-friendly error string on validation failure so the model can self-correct. - `mcpServer(name, url)` — tiny factory for hosted custom MCP server configs. Replaces the verbose `{ type: "custom_mcp_server", custom_mcp_server: { app_name, app_url } }` wrapper. - `FunctionTool` / `HostedTool` types + `isFunctionTool` / `isHostedTool` type guards. `HostedTool` is a union of Genie, VectorSearch, custom MCP, and external-connection configs. - `ToolkitEntry` + `ToolkitOptions` types + `isToolkitEntry` guard. `AgentTool = FunctionTool | HostedTool | ToolkitEntry` is the canonical union later PRs spread into agent definitions. - `defineTool(config)` + `ToolRegistry` — plugin authors' internal shape for declaring a keyed set of tools with Zod-typed handlers. - `toolsFromRegistry()` — produces the `AgentToolDefinition[]` exposed via `ToolProvider.getAgentTools()`. - `executeFromRegistry()` — validates args then dispatches to the handler. Returns LLM-friendly errors on bad args. - `toToolJSONSchema()` — shared helper at `packages/appkit/src/plugins/agents/tools/json-schema.ts` that wraps `toJSONSchema()` and strips `$schema`. Used by `tool()`, `toolsFromRegistry()`, and `buildToolkitEntries()`. - `buildToolkitEntries(pluginName, registry, opts?)` — converts a plugin's internal `ToolRegistry` into a keyed record of `ToolkitEntry` markers, honoring `prefix` / `only` / `except` / `rename`. - `AppKitMcpClient` — minimal JSON-RPC 2.0 client over SSE, zero deps. Handles auth refresh, per-server connection pooling, and tool definition aggregation. - `resolveHostedTools()` — maps `HostedTool` configs to Databricks MCP endpoint URLs. - **analytics** — `query` tool (Zod-typed, asUser dispatch) - **files** — per-volume tool family: `${volumeKey}.{list,read,exists,metadata,upload,delete}` (dynamically named from the plugin's volume config) - **genie** — per-space tool family: `${alias}.{sendMessage,getConversation}` (dynamically named from the plugin's spaces config) - **lakebase** — `query` tool Each plugin gains `getAgentTools()` + `executeAgentTool()` satisfying the `ToolProvider` interface, plus a `.toolkit(opts?)` method that returns a record of `ToolkitEntry` markers for later spread into agent definitions. - 58 new tests across tool primitives + plugin ToolProvider surfaces - Full appkit vitest suite: 1212 tests passing - Typecheck clean - Build clean, publint clean Signed-off-by: MarioCadenas New `mcp-host-policy.ts` module enforces an allowlist on every MCP URL before the first byte is sent. Same-origin Databricks workspace URLs are admitted by default; any other host must be explicitly trusted via the new `AgentsPluginConfig.mcp.trustedHosts` field (added in a subsequent stack layer). - Rejects non-`http(s)` schemes and plaintext `http://` outside of localhost-in-dev. - Blocks link-local (`169.254/16` — cloud metadata), RFC1918, CGNAT, loopback (unless `allowLocalhost`), ULA, multicast, and IPv4-mapped IPv6 equivalents at DNS-resolve time. IP-literal URLs in these ranges are rejected without a DNS lookup. Malformed IPs fail-closed. - `AppKitMcpClient` constructor now takes the policy as a third arg. Workspace credentials (SP on `initialize`/`tools/list`; caller- supplied OBO on `tools/call`) are never attached to non-workspace hosts — `callTool` drops caller OBO overrides for external destinations, and `sendRpc`/`sendNotification` never invoke `authenticate()` when `forwardWorkspaceAuth` is false. - Constructor accepts optional `{ dnsLookup, fetchImpl }` for test DI. New tests: - `mcp-host-policy.test.ts` (42 tests): config builder, URL check, IP blocklist, DNS-backed resolution with split-DNS defense. - `mcp-client.test.ts` (8 tests): integrated client with recording fetch — verifies no fetch + no `authenticate()` call when URL is rejected, and that auth headers are scoped correctly per-destination. - `json-schema.ts`: biome formatting fix (pre-existing drift). - `packages/appkit/src/index.ts`: biome organizeImports fix (pre-existing sort order drift). Full appkit vitest suite: 1262 tests passing (+50 from security). Signed-off-by: MarioCadenas New `sql-policy.ts` module provides `classifyReadOnly(sql)` and `assertReadOnlySql(sql)` — a dependency-free tokenizer-based classifier that rejects any statement outside `SELECT | WITH | SHOW | EXPLAIN | DESCRIBE` at execution time. Also exports `wrapInReadOnlyTransaction(stmt)` which produces a `BEGIN READ ONLY … ROLLBACK` envelope for belt-and-suspenders enforcement on PostgreSQL. Why a hand-rolled tokenizer rather than `node-sql-parser` or `pgsql-parser`: - `node-sql-parser`'s Hive/Spark dialect coverage rejects common Databricks SQL patterns (three-part names, `SHOW TABLES IN`, `DESCRIBE EXTENDED`, `EXPLAIN`); its PostgreSQL grammar rejects the same meta-commands. - `pgsql-parser` (libpg_query) is a native binding and fails to install cleanly on Databricks App runtimes. - We only need statement-type classification, not full parsing. The tokenizer handles line/block comments (nested), single- and double-quoted literals, ANSI/backtick identifiers, PostgreSQL dollar-quoted strings, `E'..'` escape strings, and reports unterminated literals as fail-closed. 62 tests exercise evasion vectors (stacked writes, quoted keywords, comment-hidden writes, mismatched dollar-quote tags, unterminated strings). `analytics.query` was annotated `{ readOnly: true, requiresUserContext: true }` but the annotation was a claim only. A prompt-injected LLM could send `UPDATE`, `DELETE`, or `DROP` and the warehouse would run it subject to the end user's SQL grants. The tool now calls `assertReadOnlySql` before reaching `this.query()`. A rejection surfaces an LLM-friendly error the model can self-correct on; tests verify writes never reach the warehouse. Public `AppKit.analytics.query(...)` continues to accept arbitrary SQL — app authors use it intentionally; LLMs do not. `lakebase.query` previously shipped as an always-on agent tool with `{ readOnly: false, destructive: false, idempotent: false }` (`destructive: false` was an outright lie) and executed arbitrary LLM SQL against the SP-scoped pool, auto-inherited by every markdown agent. The plugin now registers **no** agent tool by default. Opt-in via: ```ts lakebase({ exposeAsAgentTool: { iUnderstandRunsAsServicePrincipal: true, readOnly: true, // default }, }); ``` The acknowledgement flag is required because the pool is bound to the service principal regardless of which end user invokes the tool — enabling the tool is a deliberate privilege grant. When opted in with `readOnly: true` (default): - Statement classified by `classifyReadOnly` (rejects non-SELECT with an LLM-friendly error). - Remaining statement executed inside `BEGIN READ ONLY; …; ROLLBACK` so PostgreSQL enforces server-side even if a side-effecting function slips past the classifier. - Annotations: `{ readOnly: true, destructive: false, idempotent: false }`. When opted in with `readOnly: false`: - Statement passed through unchanged. - Annotations: `{ readOnly: false, destructive: true, idempotent: false }`. The `destructive: true` signal will be honored by the agents plugin's HITL approval gate in PR #304. `LakebasePlugin` is now `export class` so tests can construct it directly. New test file `lakebase-agent-tool.test.ts` (9 tests) verifies defaults, opt-in, acknowledgement enforcement, readOnly rejection + wrap, and destructive pass-through. Full appkit vitest suite: 1340 tests passing (+78 from S-2 Layer 1+2). Signed-off-by: MarioCadenas Groundwork for flipping the unsafe `autoInheritTools: { file: true }` default into opt-in auto-inherit gated by per-tool consent. Adds an `autoInheritable?: boolean` field to `ToolEntry` (defined via `defineTool`) and propagates it through `buildToolkitEntries` onto the resulting `ToolkitEntry`. The agents plugin consumes this flag in a subsequent stack layer to filter auto-inherited tools — any tool whose plugin author has not explicitly marked `autoInheritable: true` never spreads into agents that enable auto-inherit. Default is `false` for defense-in-depth. Plugin authors must consciously mark tools that are genuinely safe for unscoped exposure. - `analytics.query`: `autoInheritable: true`. The tool is OBO-scoped and `readOnly: true` is enforced at runtime (S2). - `files.list` / `files.read` / `files.exists` / `files.metadata`: `autoInheritable: true`. Pure read operations, OBO-scoped. - `files.upload` / `files.delete`: NOT auto-inheritable. Destructive; must be explicitly wired by the agent author. - `genie.getConversation`: `autoInheritable: true` (read-only history). - `genie.sendMessage`: NOT auto-inheritable. State-mutating Genie conversation; user wires explicitly if desired. - `lakebase.query`: NOT auto-inheritable. The tool is already gated by the explicit `exposeAsAgentTool` acknowledgement (S2); auto-inherit remains closed as defense-in-depth. - `build-toolkit.test.ts` gains a case verifying propagation: explicit `true`, explicit `false`, and omitted (undefined) all flow through to the `ToolkitEntry` unchanged. Signed-off-by: MarioCadenas - **MCP caller abort signal composition.** `callTool` now accepts an optional `callerSignal` and composes it with the existing 30 s timeout via `AbortSignal.any([...])`. The agents plugin threads its stream signal through in a subsequent stack layer so that a `POST /cancel` or agent-run shutdown immediately propagates to the MCP fetch, rather than leaving in-flight MCP calls running on the remote server until they complete. Adds a regression test that aborts the caller signal mid-fetch and asserts the fetch rejects with `AbortError`. Signed-off-by: MarioCadenas Three issues flagged by the re-review pass (one correctness HIGH, two security MEDIUMs). - **Lakebase read-only path no longer emits a multi-statement batch.** `wrapInReadOnlyTransaction` returned `"BEGIN READ ONLY;\n;\ nROLLBACK;"` as one string passed to `pool.query(text, values)`. As soon as the agent supplied `values`, pg switched to the Extended Query protocol and PostgreSQL rejected the batch with `cannot insert multiple commands into a prepared statement`, silently breaking every parameterized read-only tool call in production. The mocked lakebase test concealed this. The helper is removed; `LakebasePlugin.runReadOnlyStatement` now acquires a dedicated client from the pool and runs three separate `client.query` calls on the same connection (`BEGIN READ ONLY`, user statement with values, `ROLLBACK`), with a `finally` that rolls back and releases the client even when the user statement throws. Four tests cover the new flow: dispatch-time rejection, statement ordering, parameter forwarding, and release-on-error. - **`isBlockedIpv6` link-local `fe80::/10` now covers the full range.** Previous regex `startsWith("fe80:") || startsWith("fe9")` only matched `fe80`–`fe9f`, leaving `fea0`–`febf` (valid link-local) passable. Replaced with `/^fe[89ab][0-9a-f]:/.test(lowered)` so the second hex nibble is checked against `8..b`. - **`::ffff::` IPv4-mapped IPv6 is now normalised.** The colon-hex form of an IPv4-mapped address (`::ffff:a9fe:a9fe` = 169.254.169.254) previously bypassed the IPv4 blocklist because `isIPv4("a9fe:a9fe")` is false. `hexPairToDottedIpv4` reassembles the trailing two hex groups into dotted form and delegates to `isBlockedIpv4`. Regression tests cover the metadata, 10/8, and 192.168/16 equivalents; a public IPv4 mapped to colon-hex still passes through. Signed-off-by: MarioCadenas --- packages/appkit/src/index.ts | 24 ++ .../src/plugins/agents/build-toolkit.ts | 63 +++ .../agents/tests/build-toolkit.test.ts | 101 +++++ .../plugins/agents/tests/define-tool.test.ts | 133 ++++++ .../agents/tests/function-tool.test.ts | 110 +++++ .../plugins/agents/tests/hosted-tools.test.ts | 131 ++++++ .../plugins/agents/tests/mcp-client.test.ts | 402 ++++++++++++++++++ .../agents/tests/mcp-host-policy.test.ts | 354 +++++++++++++++ .../agents/tests/mcp-server-helper.test.ts | 34 ++ .../plugins/agents/tests/sql-policy.test.ts | 227 ++++++++++ .../src/plugins/agents/tests/tool.test.ts | 110 +++++ .../src/plugins/agents/tools/define-tool.ts | 94 ++++ .../src/plugins/agents/tools/function-tool.ts | 33 ++ .../src/plugins/agents/tools/hosted-tools.ts | 102 +++++ .../appkit/src/plugins/agents/tools/index.ts | 20 + .../src/plugins/agents/tools/json-schema.ts | 20 + .../src/plugins/agents/tools/mcp-client.ts | 394 +++++++++++++++++ .../plugins/agents/tools/mcp-host-policy.ts | 299 +++++++++++++ .../src/plugins/agents/tools/sql-policy.ts | 317 ++++++++++++++ .../appkit/src/plugins/agents/tools/tool.ts | 53 +++ packages/appkit/src/plugins/agents/types.ts | 54 +++ .../appkit/src/plugins/analytics/analytics.ts | 58 ++- .../tests/analytics.readonly.test.ts | 133 ++++++ .../plugins/analytics/tests/analytics.test.ts | 18 + packages/appkit/src/plugins/files/plugin.ts | 130 +++++- .../src/plugins/files/tests/plugin.test.ts | 56 +++ packages/appkit/src/plugins/genie/genie.ts | 82 +++- .../src/plugins/genie/tests/genie.test.ts | 24 ++ .../appkit/src/plugins/lakebase/lakebase.ts | 125 +++++- .../tests/lakebase-agent-tool.test.ts | 238 +++++++++++ packages/appkit/src/plugins/lakebase/types.ts | 43 ++ 31 files changed, 3970 insertions(+), 12 deletions(-) create mode 100644 packages/appkit/src/plugins/agents/build-toolkit.ts create mode 100644 packages/appkit/src/plugins/agents/tests/build-toolkit.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/define-tool.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/function-tool.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/hosted-tools.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/mcp-client.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/mcp-host-policy.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/mcp-server-helper.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/sql-policy.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/tool.test.ts create mode 100644 packages/appkit/src/plugins/agents/tools/define-tool.ts create mode 100644 packages/appkit/src/plugins/agents/tools/function-tool.ts create mode 100644 packages/appkit/src/plugins/agents/tools/hosted-tools.ts create mode 100644 packages/appkit/src/plugins/agents/tools/index.ts create mode 100644 packages/appkit/src/plugins/agents/tools/json-schema.ts create mode 100644 packages/appkit/src/plugins/agents/tools/mcp-client.ts create mode 100644 packages/appkit/src/plugins/agents/tools/mcp-host-policy.ts create mode 100644 packages/appkit/src/plugins/agents/tools/sql-policy.ts create mode 100644 packages/appkit/src/plugins/agents/tools/tool.ts create mode 100644 packages/appkit/src/plugins/agents/types.ts create mode 100644 packages/appkit/src/plugins/analytics/tests/analytics.readonly.test.ts create mode 100644 packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index a4666a49..6c6c6f5b 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -7,11 +7,20 @@ // Types from shared export type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, BasePluginConfig, CacheConfig, IAppRouter, + Message, PluginData, StreamExecutionSettings, + Thread, + ThreadStore, + ToolProvider, } from "shared"; export { isSQLTypeMarker, sql } from "shared"; export { CacheManager } from "./cache"; @@ -54,6 +63,21 @@ export { toPlugin, } from "./plugin"; export { analytics, files, genie, lakebase, server, serving } from "./plugins"; +export { + type FunctionTool, + type HostedTool, + isFunctionTool, + isHostedTool, + mcpServer, + type ToolConfig, + tool, +} from "./plugins/agents/tools"; +export { + type AgentTool, + isToolkitEntry, + type ToolkitEntry, + type ToolkitOptions, +} from "./plugins/agents/types"; // Files plugin types (for custom policy authoring) export type { FileAction, diff --git a/packages/appkit/src/plugins/agents/build-toolkit.ts b/packages/appkit/src/plugins/agents/build-toolkit.ts new file mode 100644 index 00000000..0140425d --- /dev/null +++ b/packages/appkit/src/plugins/agents/build-toolkit.ts @@ -0,0 +1,63 @@ +import type { AgentToolDefinition } from "shared"; +import type { ToolRegistry } from "./tools/define-tool"; +import { toToolJSONSchema } from "./tools/json-schema"; +import type { ToolkitEntry, ToolkitOptions } from "./types"; + +/** + * Converts a plugin's internal `ToolRegistry` into a keyed record of + * `ToolkitEntry` markers suitable for spreading into an `AgentDefinition.tools` + * record. + * + * The `opts` record controls shape and filtering: + * - `prefix` — overrides the default `${pluginName}.` prefix; `""` drops it. + * - `only` — allowlist of local tool names to include (post-prefix). + * - `except` — denylist of local names. + * - `rename` — per-tool key remapping (applied after prefix/filter). + * + * Each entry carries `pluginName` + `localName` so the agents plugin can + * dispatch back through `PluginContext.executeTool` for OBO + telemetry. + */ +export function buildToolkitEntries( + pluginName: string, + registry: ToolRegistry, + opts: ToolkitOptions = {}, +): Record { + const prefix = opts.prefix ?? `${pluginName}.`; + const only = opts.only ? new Set(opts.only) : null; + const except = opts.except ? new Set(opts.except) : null; + const rename = opts.rename ?? {}; + + const out: Record = {}; + + for (const [localName, entry] of Object.entries(registry)) { + if (only && !only.has(localName)) continue; + if (except?.has(localName)) continue; + + const keyAfterPrefix = `${prefix}${localName}`; + const key = rename[localName] ?? keyAfterPrefix; + + const parameters = toToolJSONSchema( + entry.schema, + ) as unknown as AgentToolDefinition["parameters"]; + + const def: AgentToolDefinition = { + name: key, + description: entry.description, + parameters, + }; + if (entry.annotations) { + def.annotations = entry.annotations; + } + + out[key] = { + __toolkitRef: true, + pluginName, + localName, + def, + annotations: entry.annotations, + autoInheritable: entry.autoInheritable, + }; + } + + return out; +} diff --git a/packages/appkit/src/plugins/agents/tests/build-toolkit.test.ts b/packages/appkit/src/plugins/agents/tests/build-toolkit.test.ts new file mode 100644 index 00000000..08f71da9 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/build-toolkit.test.ts @@ -0,0 +1,101 @@ +import { describe, expect, test } from "vitest"; +import { z } from "zod"; +import { buildToolkitEntries } from "../build-toolkit"; +import { defineTool, type ToolRegistry } from "../tools/define-tool"; +import { isToolkitEntry } from "../types"; + +const registry: ToolRegistry = { + query: defineTool({ + description: "Run a query", + schema: z.object({ sql: z.string() }), + handler: () => "ok", + }), + history: defineTool({ + description: "Get query history", + schema: z.object({}), + handler: () => [], + }), +}; + +describe("buildToolkitEntries", () => { + test("produces ToolkitEntry per registry item with default dotted prefix", () => { + const entries = buildToolkitEntries("analytics", registry); + expect(Object.keys(entries).sort()).toEqual([ + "analytics.history", + "analytics.query", + ]); + for (const entry of Object.values(entries)) { + expect(isToolkitEntry(entry)).toBe(true); + expect(entry.pluginName).toBe("analytics"); + } + }); + + test("respects prefix option (empty drops the namespace)", () => { + const entries = buildToolkitEntries("analytics", registry, { prefix: "" }); + expect(Object.keys(entries).sort()).toEqual(["history", "query"]); + }); + + test("respects custom prefix", () => { + const entries = buildToolkitEntries("analytics", registry, { + prefix: "db.", + }); + expect(Object.keys(entries).sort()).toEqual(["db.history", "db.query"]); + }); + + test("only filter keeps the listed local names", () => { + const entries = buildToolkitEntries("analytics", registry, { + only: ["query"], + }); + expect(Object.keys(entries)).toEqual(["analytics.query"]); + }); + + test("except filter drops the listed local names", () => { + const entries = buildToolkitEntries("analytics", registry, { + except: ["history"], + }); + expect(Object.keys(entries)).toEqual(["analytics.query"]); + }); + + test("rename remaps specific local names (overrides the prefix key)", () => { + const entries = buildToolkitEntries("analytics", registry, { + rename: { query: "sql" }, + }); + expect(Object.keys(entries).sort()).toEqual(["analytics.history", "sql"]); + }); + + test("exposes the original plugin+local name so dispatch can route", () => { + const entries = buildToolkitEntries("analytics", registry, { + prefix: "db.", + }); + const qEntry = entries["db.query"]; + expect(qEntry.pluginName).toBe("analytics"); + expect(qEntry.localName).toBe("query"); + expect(qEntry.def.name).toBe("db.query"); + }); + + test("propagates autoInheritable from the source registry", () => { + const mixed: ToolRegistry = { + readIt: defineTool({ + description: "safe read", + schema: z.object({}), + autoInheritable: true, + handler: () => "ok", + }), + writeIt: defineTool({ + description: "unsafe write", + schema: z.object({}), + autoInheritable: false, + handler: () => "ok", + }), + unmarked: defineTool({ + description: "default: not auto-inheritable", + schema: z.object({}), + handler: () => "ok", + }), + }; + const entries = buildToolkitEntries("p", mixed); + expect(entries["p.readIt"].autoInheritable).toBe(true); + expect(entries["p.writeIt"].autoInheritable).toBe(false); + expect(entries["p.unmarked"].autoInheritable).toBeUndefined(); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/define-tool.test.ts b/packages/appkit/src/plugins/agents/tests/define-tool.test.ts new file mode 100644 index 00000000..ef61e8c4 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/define-tool.test.ts @@ -0,0 +1,133 @@ +import { describe, expect, test, vi } from "vitest"; +import { z } from "zod"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../tools/define-tool"; + +describe("defineTool()", () => { + test("returns an entry matching the input config", () => { + const entry = defineTool({ + description: "echo", + schema: z.object({ msg: z.string() }), + annotations: { readOnly: true }, + handler: ({ msg }) => msg, + }); + + expect(entry.description).toBe("echo"); + expect(entry.annotations).toEqual({ readOnly: true }); + expect(typeof entry.handler).toBe("function"); + }); +}); + +describe("executeFromRegistry", () => { + const registry: ToolRegistry = { + echo: defineTool({ + description: "echo", + schema: z.object({ msg: z.string() }), + handler: ({ msg }) => `got ${msg}`, + }), + }; + + test("validates args and calls handler on success", async () => { + const result = await executeFromRegistry(registry, "echo", { msg: "hi" }); + expect(result).toBe("got hi"); + }); + + test("returns formatted error string on validation failure", async () => { + const result = await executeFromRegistry(registry, "echo", {}); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for echo"); + expect(result).toContain("msg"); + }); + + test("throws for unknown tool names", async () => { + await expect(executeFromRegistry(registry, "missing", {})).rejects.toThrow( + /Unknown tool: missing/, + ); + }); + + test("forwards AbortSignal to the handler", async () => { + const handler = vi.fn(async (_args: { x: string }, signal?: AbortSignal) => + signal?.aborted ? "aborted" : "ok", + ); + const reg: ToolRegistry = { + t: defineTool({ + description: "t", + schema: z.object({ x: z.string() }), + handler, + }), + }; + + const controller = new AbortController(); + controller.abort(); + await executeFromRegistry(reg, "t", { x: "hi" }, controller.signal); + + expect(handler).toHaveBeenCalledTimes(1); + expect(handler.mock.calls[0][1]).toBe(controller.signal); + }); +}); + +describe("toolsFromRegistry", () => { + test("produces AgentToolDefinition[] with JSON Schema parameters", () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "Execute a SQL query", + schema: z.object({ + query: z.string().describe("SQL query"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + handler: () => "ok", + }), + }; + + const defs = toolsFromRegistry(registry); + expect(defs).toHaveLength(1); + expect(defs[0].name).toBe("query"); + expect(defs[0].description).toBe("Execute a SQL query"); + expect(defs[0].parameters).toMatchObject({ + type: "object", + properties: { + query: { type: "string", description: "SQL query" }, + }, + required: ["query"], + }); + expect(defs[0].annotations).toEqual({ + readOnly: true, + requiresUserContext: true, + }); + }); + + test("preserves dotted names like uploads.list from the registry keys", () => { + const registry: ToolRegistry = { + "uploads.list": defineTool({ + description: "list uploads", + schema: z.object({}), + handler: () => [], + }), + "documents.list": defineTool({ + description: "list documents", + schema: z.object({}), + handler: () => [], + }), + }; + + const names = toolsFromRegistry(registry).map((d) => d.name); + expect(names).toContain("uploads.list"); + expect(names).toContain("documents.list"); + }); + + test("omits annotations when none are provided", () => { + const registry: ToolRegistry = { + plain: defineTool({ + description: "plain", + schema: z.object({}), + handler: () => "ok", + }), + }; + const [def] = toolsFromRegistry(registry); + expect(def.annotations).toBeUndefined(); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/function-tool.test.ts b/packages/appkit/src/plugins/agents/tests/function-tool.test.ts new file mode 100644 index 00000000..8e668d69 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/function-tool.test.ts @@ -0,0 +1,110 @@ +import { describe, expect, test } from "vitest"; +import { + functionToolToDefinition, + isFunctionTool, +} from "../tools/function-tool"; + +describe("isFunctionTool", () => { + test("returns true for valid FunctionTool", () => { + expect( + isFunctionTool({ + type: "function", + name: "greet", + execute: async () => "hello", + }), + ).toBe(true); + }); + + test("returns true for minimal FunctionTool", () => { + expect( + isFunctionTool({ + type: "function", + name: "x", + execute: () => "y", + }), + ).toBe(true); + }); + + test("returns false for null", () => { + expect(isFunctionTool(null)).toBe(false); + }); + + test("returns false for non-object", () => { + expect(isFunctionTool("function")).toBe(false); + }); + + test("returns false for wrong type", () => { + expect( + isFunctionTool({ + type: "genie-space", + name: "x", + execute: () => "y", + }), + ).toBe(false); + }); + + test("returns false when execute is missing", () => { + expect(isFunctionTool({ type: "function", name: "x" })).toBe(false); + }); + + test("returns false when name is missing", () => { + expect(isFunctionTool({ type: "function", execute: () => "y" })).toBe( + false, + ); + }); +}); + +describe("functionToolToDefinition", () => { + test("converts a FunctionTool with all fields", () => { + const def = functionToolToDefinition({ + type: "function", + name: "getWeather", + description: "Get current weather", + parameters: { + type: "object", + properties: { city: { type: "string" } }, + required: ["city"], + }, + execute: async () => "sunny", + }); + + expect(def.name).toBe("getWeather"); + expect(def.description).toBe("Get current weather"); + expect(def.parameters).toEqual({ + type: "object", + properties: { city: { type: "string" } }, + required: ["city"], + }); + }); + + test("uses name as fallback description", () => { + const def = functionToolToDefinition({ + type: "function", + name: "myTool", + execute: async () => "result", + }); + + expect(def.description).toBe("myTool"); + }); + + test("uses empty object schema when parameters are null", () => { + const def = functionToolToDefinition({ + type: "function", + name: "noParams", + parameters: null, + execute: async () => "ok", + }); + + expect(def.parameters).toEqual({ type: "object", properties: {} }); + }); + + test("uses empty object schema when parameters are omitted", () => { + const def = functionToolToDefinition({ + type: "function", + name: "noParams", + execute: async () => "ok", + }); + + expect(def.parameters).toEqual({ type: "object", properties: {} }); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/hosted-tools.test.ts b/packages/appkit/src/plugins/agents/tests/hosted-tools.test.ts new file mode 100644 index 00000000..d62b266b --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/hosted-tools.test.ts @@ -0,0 +1,131 @@ +import { describe, expect, test } from "vitest"; +import { isHostedTool, resolveHostedTools } from "../tools/hosted-tools"; + +describe("isHostedTool", () => { + test("returns true for genie-space", () => { + expect( + isHostedTool({ type: "genie-space", genie_space: { id: "abc" } }), + ).toBe(true); + }); + + test("returns true for vector_search_index", () => { + expect( + isHostedTool({ + type: "vector_search_index", + vector_search_index: { name: "cat.schema.idx" }, + }), + ).toBe(true); + }); + + test("returns true for custom_mcp_server", () => { + expect( + isHostedTool({ + type: "custom_mcp_server", + custom_mcp_server: { app_name: "my-app", app_url: "my-app-url" }, + }), + ).toBe(true); + }); + + test("returns true for external_mcp_server", () => { + expect( + isHostedTool({ + type: "external_mcp_server", + external_mcp_server: { connection_name: "conn1" }, + }), + ).toBe(true); + }); + + test("returns false for FunctionTool", () => { + expect( + isHostedTool({ type: "function", name: "x", execute: () => "y" }), + ).toBe(false); + }); + + test("returns false for null", () => { + expect(isHostedTool(null)).toBe(false); + }); + + test("returns false for unknown type", () => { + expect(isHostedTool({ type: "unknown" })).toBe(false); + }); + + test("returns false for non-object", () => { + expect(isHostedTool(42)).toBe(false); + }); +}); + +describe("resolveHostedTools", () => { + test("resolves genie-space to correct MCP endpoint", () => { + const configs = resolveHostedTools([ + { type: "genie-space", genie_space: { id: "space123" } }, + ]); + + expect(configs).toHaveLength(1); + expect(configs[0].name).toBe("genie-space123"); + expect(configs[0].url).toBe("/api/2.0/mcp/genie/space123"); + }); + + test("resolves vector_search_index with 3-part name", () => { + const configs = resolveHostedTools([ + { + type: "vector_search_index", + vector_search_index: { name: "catalog.schema.my_index" }, + }, + ]); + + expect(configs).toHaveLength(1); + expect(configs[0].name).toBe("vs-catalog-schema-my_index"); + expect(configs[0].url).toBe( + "/api/2.0/mcp/vector-search/catalog/schema/my_index", + ); + }); + + test("throws for invalid vector_search_index name", () => { + expect(() => + resolveHostedTools([ + { + type: "vector_search_index", + vector_search_index: { name: "bad.name" }, + }, + ]), + ).toThrow("3-part dotted"); + }); + + test("resolves custom_mcp_server", () => { + const configs = resolveHostedTools([ + { + type: "custom_mcp_server", + custom_mcp_server: { app_name: "my-app", app_url: "my-app-endpoint" }, + }, + ]); + + expect(configs[0].name).toBe("my-app"); + expect(configs[0].url).toBe("my-app-endpoint"); + }); + + test("resolves external_mcp_server", () => { + const configs = resolveHostedTools([ + { + type: "external_mcp_server", + external_mcp_server: { connection_name: "conn1" }, + }, + ]); + + expect(configs[0].name).toBe("conn1"); + expect(configs[0].url).toBe("/api/2.0/mcp/external/conn1"); + }); + + test("resolves multiple tools preserving order", () => { + const configs = resolveHostedTools([ + { type: "genie-space", genie_space: { id: "g1" } }, + { + type: "external_mcp_server", + external_mcp_server: { connection_name: "e1" }, + }, + ]); + + expect(configs).toHaveLength(2); + expect(configs[0].name).toBe("genie-g1"); + expect(configs[1].name).toBe("e1"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/mcp-client.test.ts b/packages/appkit/src/plugins/agents/tests/mcp-client.test.ts new file mode 100644 index 00000000..483fb5f4 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/mcp-client.test.ts @@ -0,0 +1,402 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { AppKitMcpClient } from "../tools/mcp-client"; +import type { DnsLookup, McpHostPolicy } from "../tools/mcp-host-policy"; + +const WORKSPACE = "https://test-workspace.cloud.databricks.com"; + +const workspacePolicy: McpHostPolicy = { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(), + allowLocalhost: false, +}; + +const trustedExternalPolicy: McpHostPolicy = { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(["mcp.example.com"]), + allowLocalhost: false, +}; + +const publicDnsLookup: DnsLookup = async () => [ + { address: "203.0.113.42", family: 4 }, +]; + +const workspaceAuth = async (): Promise> => ({ + Authorization: "Bearer SP-TOKEN", +}); + +type FetchCall = { + url: string; + init: RequestInit; +}; + +function recordingFetch( + responders: Array<(call: FetchCall) => Response | Promise>, +) { + const calls: FetchCall[] = []; + let n = 0; + const fetchImpl: typeof fetch = async (input, init) => { + const url = typeof input === "string" ? input : (input as URL).toString(); + const call: FetchCall = { url, init: init ?? {} }; + calls.push(call); + const responder = responders[n++] ?? responders[responders.length - 1]; + return Promise.resolve(responder(call)); + }; + return { fetchImpl, calls }; +} + +function jsonResponse(body: unknown, headers: Record = {}) { + return new Response(JSON.stringify(body), { + status: 200, + headers: { "content-type": "application/json", ...headers }, + }); +} + +describe("AppKitMcpClient — host allowlist", () => { + let authSpy: ReturnType; + + beforeEach(() => { + authSpy = vi.fn(workspaceAuth); + }); + + test("connect rejects a URL whose host is not allowlisted without making any fetch", async () => { + const { fetchImpl, calls } = recordingFetch([() => jsonResponse({})]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, workspacePolicy, { + fetchImpl, + dnsLookup: publicDnsLookup, + }); + await expect( + client.connect({ name: "evil", url: "https://attacker.example.com/mcp" }), + ).rejects.toThrow(/attacker\.example\.com/); + expect(calls).toHaveLength(0); + expect(authSpy).not.toHaveBeenCalled(); + }); + + test("connect rejects plaintext http:// for remote hosts", async () => { + const { fetchImpl, calls } = recordingFetch([() => jsonResponse({})]); + const client = new AppKitMcpClient( + WORKSPACE, + authSpy, + trustedExternalPolicy, + { fetchImpl, dnsLookup: publicDnsLookup }, + ); + await expect( + client.connect({ name: "plain", url: "http://mcp.example.com/mcp" }), + ).rejects.toThrow(/plaintext http/); + expect(calls).toHaveLength(0); + expect(authSpy).not.toHaveBeenCalled(); + }); + + test("connect rejects a URL whose DNS resolves to a blocked IP and never sends SP token", async () => { + const ssrfLookup: DnsLookup = async () => [ + { address: "169.254.169.254", family: 4 }, + ]; + const policy: McpHostPolicy = { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(["evil.example.com"]), + allowLocalhost: false, + }; + const { fetchImpl, calls } = recordingFetch([() => jsonResponse({})]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, policy, { + fetchImpl, + dnsLookup: ssrfLookup, + }); + await expect( + client.connect({ name: "evil", url: "https://evil.example.com/mcp" }), + ).rejects.toThrow(/169\.254\.169\.254/); + expect(calls).toHaveLength(0); + expect(authSpy).not.toHaveBeenCalled(); + }); + + test("connect to same-origin workspace forwards SP token on initialize + tools/list", async () => { + const { fetchImpl, calls } = recordingFetch([ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "echo", description: "Echo" }] }, + }), + ]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, workspacePolicy, { + fetchImpl, + dnsLookup: publicDnsLookup, + }); + + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + // initialize + notifications/initialized + tools/list all carry SP token + expect(calls.map((c) => c.url)).toEqual([ + `${WORKSPACE}/api/2.0/mcp/genie/abc`, + `${WORKSPACE}/api/2.0/mcp/genie/abc`, + `${WORKSPACE}/api/2.0/mcp/genie/abc`, + ]); + for (const call of calls) { + const headers = call.init.headers as Record; + expect(headers.Authorization).toBe("Bearer SP-TOKEN"); + } + expect(client.canForwardWorkspaceAuth("genie-1")).toBe(true); + }); + + test("connect to trusted external host does NOT forward SP token on any RPC", async () => { + const { fetchImpl, calls } = recordingFetch([ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "help" }] }, + }), + ]); + const client = new AppKitMcpClient( + WORKSPACE, + authSpy, + trustedExternalPolicy, + { fetchImpl, dnsLookup: publicDnsLookup }, + ); + + await client.connect({ name: "ext", url: "https://mcp.example.com/mcp" }); + + for (const call of calls) { + const headers = call.init.headers as Record; + expect(headers.Authorization).toBeUndefined(); + } + expect(authSpy).not.toHaveBeenCalled(); + expect(client.canForwardWorkspaceAuth("ext")).toBe(false); + }); +}); + +describe("AppKitMcpClient — callTool auth scoping", () => { + test("drops caller-supplied OBO token when destination is not workspace-origin", async () => { + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "do" }] }, + }), + ]; + const callResponder = () => + jsonResponse({ + jsonrpc: "2.0", + id: 4, + result: { content: [{ type: "text", text: "ok" }] }, + }); + const { fetchImpl, calls } = recordingFetch([ + ...connectResponders, + callResponder, + ]); + const client = new AppKitMcpClient( + WORKSPACE, + workspaceAuth, + trustedExternalPolicy, + { fetchImpl, dnsLookup: publicDnsLookup }, + ); + await client.connect({ name: "ext", url: "https://mcp.example.com/mcp" }); + + const output = await client.callTool( + "mcp.ext.do", + { x: 1 }, + { + Authorization: "Bearer OBO-USER-TOKEN", + }, + ); + expect(output).toBe("ok"); + + const toolCall = calls[calls.length - 1]; + const headers = toolCall.init.headers as Record; + expect(headers.Authorization).toBeUndefined(); + }); + + test("forwards caller-supplied OBO token when destination is workspace-origin", async () => { + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "do" }] }, + }), + ]; + const callResponder = () => + jsonResponse({ + jsonrpc: "2.0", + id: 4, + result: { content: [{ type: "text", text: "ok" }] }, + }); + const { fetchImpl, calls } = recordingFetch([ + ...connectResponders, + callResponder, + ]); + const client = new AppKitMcpClient( + WORKSPACE, + workspaceAuth, + workspacePolicy, + { + fetchImpl, + dnsLookup: publicDnsLookup, + }, + ); + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + await client.callTool( + "mcp.genie-1.do", + {}, + { + Authorization: "Bearer OBO-USER-TOKEN", + }, + ); + + const toolCall = calls[calls.length - 1]; + const headers = toolCall.init.headers as Record; + expect(headers.Authorization).toBe("Bearer OBO-USER-TOKEN"); + }); + + test("falls back to SP auth when no OBO override is provided and destination is workspace", async () => { + const authSpy = vi.fn(workspaceAuth); + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "do" }] }, + }), + ]; + const callResponder = () => + jsonResponse({ + jsonrpc: "2.0", + id: 4, + result: { content: [{ type: "text", text: "ok" }] }, + }); + const { fetchImpl, calls } = recordingFetch([ + ...connectResponders, + callResponder, + ]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, workspacePolicy, { + fetchImpl, + dnsLookup: publicDnsLookup, + }); + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + await client.callTool("mcp.genie-1.do", {}, undefined); + + const toolCall = calls[calls.length - 1]; + const headers = toolCall.init.headers as Record; + expect(headers.Authorization).toBe("Bearer SP-TOKEN"); + }); +}); + +describe("AppKitMcpClient — caller abort signal composition", () => { + test("callTool's fetch aborts when the caller signal fires", async () => { + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { "mcp-session-id": "sess-1" }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "slow" }] }, + }), + ]; + const callResponder = (call: FetchCall): Promise => { + const signal = call.init.signal as AbortSignal | undefined; + return new Promise((_, reject) => { + if (signal?.aborted) { + reject( + new DOMException( + signal.reason?.toString() ?? "aborted", + "AbortError", + ), + ); + return; + } + signal?.addEventListener( + "abort", + () => { + reject( + new DOMException( + signal.reason?.toString() ?? "aborted", + "AbortError", + ), + ); + }, + { once: true }, + ); + }); + }; + const { fetchImpl } = recordingFetch([...connectResponders, callResponder]); + const client = new AppKitMcpClient( + WORKSPACE, + workspaceAuth, + workspacePolicy, + { + fetchImpl, + dnsLookup: publicDnsLookup, + }, + ); + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + const controller = new AbortController(); + const pending = client + .callTool("mcp.genie-1.slow", {}, undefined, controller.signal) + .catch((e) => e); + // Let the fetch start + register its abort listener before we abort. + await new Promise((r) => setTimeout(r, 10)); + controller.abort(new Error("user cancelled")); + const error = (await pending) as Error; + expect(error).toBeInstanceOf(Error); + expect(error.name).toBe("AbortError"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/mcp-host-policy.test.ts b/packages/appkit/src/plugins/agents/tests/mcp-host-policy.test.ts new file mode 100644 index 00000000..06d98627 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/mcp-host-policy.test.ts @@ -0,0 +1,354 @@ +import { describe, expect, test, vi } from "vitest"; +import { + assertResolvedHostSafe, + buildMcpHostPolicy, + checkMcpUrl, + type DnsLookup, + isBlockedIp, + isLoopbackHost, + type McpHostPolicy, + type McpHostPolicyConfig, +} from "../tools/mcp-host-policy"; + +function stubLookup( + addresses: Array<{ address: string; family?: number }>, +): DnsLookup { + return vi + .fn() + .mockResolvedValue(addresses.map((a) => ({ family: 4, ...a }))); +} + +function failingLookup(message: string): DnsLookup { + return vi.fn().mockRejectedValue(new Error(message)); +} + +const WORKSPACE = "https://test-workspace.cloud.databricks.com"; + +function policy(overrides: Partial = {}): McpHostPolicy { + return { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(), + allowLocalhost: false, + ...overrides, + }; +} + +describe("buildMcpHostPolicy", () => { + test("extracts hostname from workspace URL", () => { + const p = buildMcpHostPolicy(undefined, WORKSPACE); + expect(p.workspaceHostname).toBe("test-workspace.cloud.databricks.com"); + }); + + test("lowercases and trims trustedHosts", () => { + const p = buildMcpHostPolicy( + { trustedHosts: ["Example.COM", " corp.internal ", "mcp.example.com"] }, + WORKSPACE, + ); + expect(p.trustedHosts).toEqual( + new Set(["example.com", "corp.internal", "mcp.example.com"]), + ); + }); + + test("allowLocalhost defaults to false in production", () => { + const prev = process.env.NODE_ENV; + process.env.NODE_ENV = "production"; + try { + const p = buildMcpHostPolicy(undefined, WORKSPACE); + expect(p.allowLocalhost).toBe(false); + } finally { + process.env.NODE_ENV = prev; + } + }); + + test("allowLocalhost defaults to true outside production", () => { + const prev = process.env.NODE_ENV; + process.env.NODE_ENV = "development"; + try { + const p = buildMcpHostPolicy(undefined, WORKSPACE); + expect(p.allowLocalhost).toBe(true); + } finally { + process.env.NODE_ENV = prev; + } + }); + + test("allowLocalhost respects explicit override", () => { + const prev = process.env.NODE_ENV; + process.env.NODE_ENV = "production"; + try { + const cfg: McpHostPolicyConfig = { allowLocalhost: true }; + const p = buildMcpHostPolicy(cfg, WORKSPACE); + expect(p.allowLocalhost).toBe(true); + } finally { + process.env.NODE_ENV = prev; + } + }); + + test("throws on invalid workspace host", () => { + expect(() => buildMcpHostPolicy(undefined, "not-a-url")).toThrow( + /Invalid workspace host/, + ); + }); +}); + +describe("checkMcpUrl", () => { + test("admits same-origin workspace https URL and forwards auth", () => { + const result = checkMcpUrl(`${WORKSPACE}/api/2.0/mcp/genie/abc`, policy()); + expect(result.ok).toBe(true); + if (result.ok) expect(result.forwardWorkspaceAuth).toBe(true); + }); + + test("admits trusted host but does NOT forward workspace auth", () => { + const p = policy({ trustedHosts: new Set(["mcp.example.com"]) }); + const result = checkMcpUrl("https://mcp.example.com/mcp", p); + expect(result.ok).toBe(true); + if (result.ok) expect(result.forwardWorkspaceAuth).toBe(false); + }); + + test("rejects host that is neither workspace nor trusted", () => { + const result = checkMcpUrl("https://attacker.example.com/mcp", policy()); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.reason).toMatch(/attacker\.example\.com/); + expect(result.reason).toMatch(/trustedHosts/); + } + }); + + test("rejects plaintext http:// for remote hosts even when trusted", () => { + const p = policy({ trustedHosts: new Set(["mcp.example.com"]) }); + const result = checkMcpUrl("http://mcp.example.com/mcp", p); + expect(result.ok).toBe(false); + if (!result.ok) expect(result.reason).toMatch(/plaintext http/); + }); + + test("rejects plaintext http://localhost when allowLocalhost is false", () => { + const result = checkMcpUrl("http://localhost:4000/mcp", policy()); + expect(result.ok).toBe(false); + }); + + test("admits http://localhost when allowLocalhost is true, no workspace auth", () => { + const p = policy({ allowLocalhost: true }); + const result = checkMcpUrl("http://localhost:4000/mcp", p); + expect(result.ok).toBe(true); + if (result.ok) expect(result.forwardWorkspaceAuth).toBe(false); + }); + + test("admits http://127.0.0.1 when allowLocalhost is true", () => { + const p = policy({ allowLocalhost: true }); + const result = checkMcpUrl("http://127.0.0.1:4000/mcp", p); + expect(result.ok).toBe(true); + if (result.ok) expect(result.forwardWorkspaceAuth).toBe(false); + }); + + test("rejects non-http(s) schemes", () => { + for (const url of [ + "file:///etc/passwd", + "ftp://host/x", + "gopher://host/x", + "javascript:alert(1)", + ]) { + const result = checkMcpUrl(url, policy()); + expect(result.ok).toBe(false); + } + }); + + test("rejects obviously invalid URLs", () => { + const result = checkMcpUrl("not-a-url", policy()); + expect(result.ok).toBe(false); + }); + + test("hostname comparison is case-insensitive", () => { + const result = checkMcpUrl( + "https://TEST-Workspace.CLOUD.Databricks.com/mcp", + policy(), + ); + expect(result.ok).toBe(true); + if (result.ok) expect(result.forwardWorkspaceAuth).toBe(true); + }); + + test("rejects same hostname on different scheme (http) even for workspace", () => { + const result = checkMcpUrl( + "http://test-workspace.cloud.databricks.com/mcp", + policy(), + ); + expect(result.ok).toBe(false); + }); +}); + +describe("isBlockedIp", () => { + test("blocks RFC1918 IPv4 ranges", () => { + for (const addr of [ + "10.0.0.1", + "10.255.255.255", + "172.16.0.1", + "172.31.255.255", + "192.168.0.1", + "192.168.255.255", + ]) { + expect(isBlockedIp(addr, true)).toBe(true); + } + }); + + test("blocks link-local 169.254.0.0/16 (covers cloud metadata 169.254.169.254)", () => { + expect(isBlockedIp("169.254.169.254", true)).toBe(true); + expect(isBlockedIp("169.254.0.1", true)).toBe(true); + }); + + test("blocks CGNAT 100.64.0.0/10", () => { + expect(isBlockedIp("100.64.0.1", true)).toBe(true); + expect(isBlockedIp("100.127.255.255", true)).toBe(true); + }); + + test("blocks 0.0.0.0/8 and multicast/reserved (>= 224.0.0.0)", () => { + expect(isBlockedIp("0.0.0.0", true)).toBe(true); + expect(isBlockedIp("0.1.2.3", true)).toBe(true); + expect(isBlockedIp("224.0.0.1", true)).toBe(true); + expect(isBlockedIp("255.255.255.255", true)).toBe(true); + }); + + test("blocks loopback when allowLocalhost is false", () => { + expect(isBlockedIp("127.0.0.1", false)).toBe(true); + expect(isBlockedIp("127.1.2.3", false)).toBe(true); + expect(isBlockedIp("::1", false)).toBe(true); + }); + + test("permits loopback when allowLocalhost is true", () => { + expect(isBlockedIp("127.0.0.1", true)).toBe(false); + expect(isBlockedIp("::1", true)).toBe(false); + }); + + test("blocks ULA (fc00::/7) and link-local (fe80::/10) IPv6", () => { + expect(isBlockedIp("fc00::1", true)).toBe(true); + expect(isBlockedIp("fd00::1", true)).toBe(true); + expect(isBlockedIp("fe80::1", true)).toBe(true); + }); + + test("blocks the full link-local /10 range fe80::–febf:: (regression: fea0/feb0)", () => { + // fe80::/10 spans 1111 1110 10.. — first hex pair `fe` + second nibble 8..b. + for (const addr of [ + "fe80::1", + "fe90::1", + "fea0::1", // regression: was passing the filter before + "feaf::1", // regression + "feb0::1", // regression + "febf::1", // regression + ]) { + expect(isBlockedIp(addr, true)).toBe(true); + } + // Outside /10 must not be blocked by this rule (belongs to routable-ish + // experimental ranges; nothing else in the module should match either). + expect(isBlockedIp("fec0::1", true)).toBe(false); + }); + + test("blocks IPv4-mapped IPv6 addresses in blocked ranges (dotted form)", () => { + expect(isBlockedIp("::ffff:169.254.169.254", true)).toBe(true); + expect(isBlockedIp("::ffff:10.0.0.1", true)).toBe(true); + }); + + test("blocks IPv4-mapped IPv6 addresses in colon-hex form (regression)", () => { + // ::ffff:a9fe:a9fe is the same destination as ::ffff:169.254.169.254. + // Before the fix this form slipped past the IPv4-mapped branch because + // isIPv4("a9fe:a9fe") is false and no other v6 rule matched. + expect(isBlockedIp("::ffff:a9fe:a9fe", true)).toBe(true); // 169.254.169.254 + expect(isBlockedIp("::ffff:0a00:0001", true)).toBe(true); // 10.0.0.1 + expect(isBlockedIp("::ffff:c0a8:0001", true)).toBe(true); // 192.168.0.1 + // A public IPv4 mapped to colon-hex must still pass through: 8.8.8.8 = 0808:0808 + expect(isBlockedIp("::ffff:0808:0808", true)).toBe(false); + }); + + test("allows public IPv4 and IPv6 addresses", () => { + expect(isBlockedIp("8.8.8.8", false)).toBe(false); + expect(isBlockedIp("1.1.1.1", false)).toBe(false); + expect(isBlockedIp("2001:4860:4860::8888", false)).toBe(false); + }); + + test("treats malformed IP strings as blocked (fail-closed)", () => { + expect(isBlockedIp("10.0.0", true)).toBe(true); + expect(isBlockedIp("abc.def.ghi.jkl", true)).toBe(true); + }); +}); + +describe("isLoopbackHost", () => { + test.each([ + "localhost", + "LOCALHOST", + "127.0.0.1", + "::1", + "[::1]", + "0:0:0:0:0:0:0:1", + ])("recognises %s as loopback", (host) => { + expect(isLoopbackHost(host)).toBe(true); + }); + + test("does not match other hosts", () => { + expect(isLoopbackHost("example.com")).toBe(false); + expect(isLoopbackHost("10.0.0.1")).toBe(false); + }); +}); + +describe("assertResolvedHostSafe", () => { + test("passes workspace hostname when resolved address is public", async () => { + const lookup = stubLookup([{ address: "203.0.113.42" }]); + await expect( + assertResolvedHostSafe( + "test-workspace.cloud.databricks.com", + policy(), + lookup, + ), + ).resolves.toBeUndefined(); + expect(lookup).toHaveBeenCalledWith("test-workspace.cloud.databricks.com", { + all: true, + }); + }); + + test("rejects hostname that resolves to link-local cloud metadata IP", async () => { + const lookup = stubLookup([{ address: "169.254.169.254" }]); + await expect( + assertResolvedHostSafe("evil.example.com", policy(), lookup), + ).rejects.toThrow(/169\.254\.169\.254/); + }); + + test("rejects hostname that resolves to RFC1918 IP", async () => { + const lookup = stubLookup([{ address: "10.0.0.1" }]); + await expect( + assertResolvedHostSafe("internal.example.com", policy(), lookup), + ).rejects.toThrow(/10\.0\.0\.1/); + }); + + test("rejects IP literal in blocked range without DNS lookup", async () => { + const lookup = stubLookup([{ address: "8.8.8.8" }]); + await expect( + assertResolvedHostSafe("169.254.169.254", policy(), lookup), + ).rejects.toThrow(/blocked IP range/); + expect(lookup).not.toHaveBeenCalled(); + }); + + test("rejects plain 'localhost' when allowLocalhost is false", async () => { + await expect(assertResolvedHostSafe("localhost", policy())).rejects.toThrow( + /localhost is not allowed/, + ); + }); + + test("surfaces DNS resolution failures", async () => { + const lookup = failingLookup("ENOTFOUND"); + await expect( + assertResolvedHostSafe("nonexistent.example.com", policy(), lookup), + ).rejects.toThrow(/could not be resolved/); + }); + + test("rejects if any resolved address is blocked (defense against split DNS)", async () => { + const lookup = stubLookup([ + { address: "8.8.8.8" }, + { address: "169.254.169.254" }, + ]); + await expect( + assertResolvedHostSafe("mixed.example.com", policy(), lookup), + ).rejects.toThrow(/169\.254\.169\.254/); + }); + + test("rejects hostname that resolves to empty DNS result", async () => { + const lookup = stubLookup([]); + await expect( + assertResolvedHostSafe("empty.example.com", policy(), lookup), + ).rejects.toThrow(/no DNS addresses/); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/mcp-server-helper.test.ts b/packages/appkit/src/plugins/agents/tests/mcp-server-helper.test.ts new file mode 100644 index 00000000..96ad8e38 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/mcp-server-helper.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, test } from "vitest"; +import { + isHostedTool, + mcpServer, + resolveHostedTools, +} from "../tools/hosted-tools"; + +describe("mcpServer()", () => { + test("returns a CustomMcpServerTool with correct shape", () => { + const result = mcpServer("my-app", "https://example.com/mcp"); + + expect(result).toEqual({ + type: "custom_mcp_server", + custom_mcp_server: { + app_name: "my-app", + app_url: "https://example.com/mcp", + }, + }); + }); + + test("isHostedTool recognizes mcpServer() output", () => { + expect(isHostedTool(mcpServer("x", "y"))).toBe(true); + }); + + test("resolveHostedTools resolves mcpServer() output to an endpoint config", () => { + const configs = resolveHostedTools([ + mcpServer("vector-search", "https://host/mcp/vs"), + ]); + + expect(configs).toHaveLength(1); + expect(configs[0].name).toBe("vector-search"); + expect(configs[0].url).toBe("https://host/mcp/vs"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/sql-policy.test.ts b/packages/appkit/src/plugins/agents/tests/sql-policy.test.ts new file mode 100644 index 00000000..fb81493e --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/sql-policy.test.ts @@ -0,0 +1,227 @@ +import { describe, expect, test } from "vitest"; +import { + assertReadOnlySql, + classifyReadOnly, + ReadOnlySqlViolation, +} from "../tools/sql-policy"; + +function ok(sql: string) { + const result = classifyReadOnly(sql); + if (!result.readOnly) { + throw new Error( + `Expected readOnly=true for ${JSON.stringify(sql)}, got reason: ${result.reason}`, + ); + } + return result; +} + +function rejected(sql: string) { + const result = classifyReadOnly(sql); + if (result.readOnly) { + throw new Error( + `Expected readOnly=false for ${JSON.stringify(sql)}, got readOnly=true`, + ); + } + return result; +} + +describe("classifyReadOnly: plain reads are admitted", () => { + test.each([ + "SELECT 1", + "select 1", + "SELECT * FROM users", + "SELECT * FROM main.sales.orders WHERE created_at > now() - interval '7 days'", + "SELECT COUNT(*) FROM main.sales.orders", + "WITH a AS (SELECT 1) SELECT * FROM a", + "WITH RECURSIVE t AS (SELECT 1) SELECT * FROM t", + "SHOW TABLES", + "SHOW TABLES IN main.sales", + "DESCRIBE EXTENDED main.sales.orders", + "DESC main.sales.orders", + "EXPLAIN SELECT 1", + "EXPLAIN ANALYZE SELECT 1", + ])("admits %s", (sql) => { + expect(ok(sql).statements).toBe(1); + }); +}); + +describe("classifyReadOnly: writes are rejected", () => { + test.each([ + ["DROP TABLE users", "DROP"], + ["UPDATE users SET email = 'x@y.com'", "UPDATE"], + ["DELETE FROM orders WHERE id = 1", "DELETE"], + ["INSERT INTO x VALUES (1)", "INSERT"], + ["CREATE TABLE x (id INT)", "CREATE"], + ["ALTER TABLE x ADD COLUMN y INT", "ALTER"], + ["TRUNCATE TABLE orders", "TRUNCATE"], + ["GRANT SELECT ON t TO u", "GRANT"], + ["REVOKE ALL ON t FROM u", "REVOKE"], + ["CALL sp_do_thing()", "CALL"], + ["COPY t FROM '/tmp/x'", "COPY"], + ["MERGE INTO t USING s", "MERGE"], + ["REFRESH TABLE t", "REFRESH"], + ["VACUUM t", "VACUUM"], + ])("rejects %s", (sql, keyword) => { + const result = rejected(sql); + expect(result.reason).toContain(keyword); + }); +}); + +describe("classifyReadOnly: stacked statements", () => { + test("rejects SELECT followed by DROP", () => { + const result = rejected("SELECT 1; DROP TABLE x"); + expect(result.reason).toMatch(/DROP/); + }); + + test("rejects DROP followed by SELECT (write comes first)", () => { + const result = rejected("DROP TABLE x; SELECT 1"); + expect(result.reason).toMatch(/DROP/); + }); + + test("admits multiple SELECTs", () => { + expect(ok("SELECT 1; SELECT 2").statements).toBe(2); + }); + + test("admits trailing semicolon on single statement", () => { + expect(ok("SELECT 1;").statements).toBe(1); + }); + + test("admits SELECT, SHOW, DESCRIBE batch", () => { + const result = ok("SELECT 1; SHOW TABLES; DESCRIBE x;"); + expect(result.statements).toBe(3); + }); +}); + +describe("classifyReadOnly: comment handling", () => { + test("admits SELECT with line comment hiding a write keyword", () => { + ok("SELECT 1 -- DROP TABLE x\n"); + }); + + test("admits SELECT preceded by line comment with write keyword", () => { + ok("-- DROP TABLE x\nSELECT 1"); + }); + + test("admits SELECT with block comment containing stacked write", () => { + ok("SELECT 1 /* ; DROP TABLE x */"); + }); + + test("handles nested block comments (PostgreSQL style)", () => { + ok("SELECT 1 /* outer /* inner */ still inside */"); + }); + + test("rejects when write is outside the comment", () => { + const result = rejected("/* SELECT 1 */ DROP TABLE x"); + expect(result.reason).toMatch(/DROP/); + }); + + test("empty after stripping comments is rejected", () => { + rejected("-- only a comment"); + rejected("/* nothing */"); + }); +}); + +describe("classifyReadOnly: string literal handling", () => { + test("admits SELECT with write keyword inside single-quoted string", () => { + ok("SELECT 'DROP TABLE x' AS msg"); + }); + + test("admits SELECT with semicolon inside single-quoted string", () => { + ok("SELECT 'value; DROP TABLE x' AS msg"); + }); + + test("admits SELECT with doubled-quote escape", () => { + ok("SELECT 'it''s ok; DROP' AS msg"); + }); + + test("admits SELECT with backslash escape inside string", () => { + ok("SELECT E'line\\'s end; DROP' AS msg"); + }); + + test("admits SELECT with dollar-quoted string hiding a write", () => { + ok("SELECT $body$ arbitrary ; DROP TABLE x $body$ AS msg"); + }); + + test("admits SELECT with untagged dollar quote", () => { + ok("SELECT $$hello; DROP$$ AS msg"); + }); + + test("admits SELECT with ANSI double-quoted identifier named drop", () => { + ok('SELECT * FROM "drop"'); + }); + + test("admits SELECT with doubled-quote inside ANSI identifier", () => { + ok('SELECT * FROM "weird""name"'); + }); + + test("admits SELECT with backtick identifier (Databricks)", () => { + ok("SELECT * FROM `my table`"); + }); +}); + +describe("classifyReadOnly: degenerate input", () => { + test("rejects empty string", () => { + rejected(""); + }); + + test("rejects whitespace-only", () => { + rejected(" \n\t "); + }); + + test("rejects semicolons only", () => { + rejected(";;;"); + }); + + test("rejects non-SQL garbage", () => { + rejected("-- this is just a comment\n-- nothing else"); + rejected("random garbage text"); + }); + + test("rejects a single empty statement between two selects", () => { + // "SELECT 1;; SELECT 2" — the middle empty statement is dropped by + // splitter; the surviving two statements are both SELECT, admitted. + ok("SELECT 1;; SELECT 2"); + }); +}); + +describe("classifyReadOnly: evasion-resistance", () => { + test("cannot hide DROP after a comment-ended newline", () => { + const result = rejected("-- intent\nDROP TABLE x"); + expect(result.reason).toMatch(/DROP/); + }); + + test("cannot hide DROP via concatenated strings (strings end cleanly)", () => { + rejected("'SELECT 1'; DROP TABLE x"); + }); + + test("bare DROP after unclosed string is still considered part of the string (defensive)", () => { + // An unclosed single quote eats the rest of the input — classifier + // sees the whole thing as one stripped, empty-ish statement and rejects. + rejected("SELECT 'unterminated ; DROP TABLE x"); + }); + + test("dollar-quoted literal with malicious tag is handled", () => { + ok("SELECT $tag$ DROP $tag$ AS harmless"); + }); + + test("mismatched dollar-quote tag is treated as unterminated", () => { + rejected("SELECT $a$ DROP TABLE x $b$"); + }); +}); + +describe("assertReadOnlySql", () => { + test("returns void on read-only SQL", () => { + expect(() => assertReadOnlySql("SELECT 1")).not.toThrow(); + }); + + test("throws ReadOnlySqlViolation with descriptive message on writes", () => { + expect(() => assertReadOnlySql("DROP TABLE x")).toThrow( + ReadOnlySqlViolation, + ); + try { + assertReadOnlySql("DROP TABLE x"); + } catch (e) { + expect((e as Error).message).toMatch(/SQL read-only policy violation/); + expect((e as Error).message).toMatch(/DROP/); + } + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/tool.test.ts b/packages/appkit/src/plugins/agents/tests/tool.test.ts new file mode 100644 index 00000000..3d47f3a9 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/tool.test.ts @@ -0,0 +1,110 @@ +import { describe, expect, test } from "vitest"; +import { z } from "zod"; +import { formatZodError, tool } from "../tools/tool"; + +describe("tool()", () => { + test("produces a FunctionTool with JSON Schema parameters from the Zod schema", () => { + const weather = tool({ + name: "get_weather", + description: "Get the current weather for a city", + schema: z.object({ + city: z.string().describe("City name"), + }), + execute: async ({ city }) => `Sunny in ${city}`, + }); + + expect(weather.type).toBe("function"); + expect(weather.name).toBe("get_weather"); + expect(weather.description).toBe("Get the current weather for a city"); + expect(weather.parameters).toMatchObject({ + type: "object", + properties: { + city: { type: "string", description: "City name" }, + }, + required: ["city"], + }); + }); + + test("execute receives typed args on valid input", async () => { + const echo = tool({ + name: "echo", + schema: z.object({ message: z.string() }), + execute: async ({ message }) => { + const _typed: string = message; + return `got ${_typed}`; + }, + }); + + const result = await echo.execute({ message: "hi" }); + expect(result).toBe("got hi"); + }); + + test("returns formatted error string (does not throw) when args are invalid", async () => { + const weather = tool({ + name: "get_weather", + schema: z.object({ city: z.string() }), + execute: async ({ city }) => `Sunny in ${city}`, + }); + + const result = await weather.execute({}); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for get_weather"); + expect(result).toContain("city"); + }); + + test("joins multiple validation errors with '; '", async () => { + const t = tool({ + name: "multi", + schema: z.object({ a: z.string(), b: z.number() }), + execute: async () => "ok", + }); + + const result = await t.execute({}); + expect(result).toContain("a:"); + expect(result).toContain("b:"); + expect(result).toContain(";"); + }); + + test("optional fields validate when absent", async () => { + const t = tool({ + name: "opt", + schema: z.object({ note: z.string().optional() }), + execute: async ({ note }) => note ?? "(no note)", + }); + + expect(await t.execute({})).toBe("(no note)"); + expect(await t.execute({ note: "hello" })).toBe("hello"); + }); + + test("description falls back to the tool name when omitted", () => { + const t = tool({ + name: "my_tool", + schema: z.object({}), + execute: async () => "ok", + }); + + expect(t.description).toBe("my_tool"); + expect(t.parameters).toBeDefined(); + }); +}); + +describe("formatZodError", () => { + test("formats a single issue with the tool name", () => { + const schema = z.object({ city: z.string() }); + const result = schema.safeParse({}); + if (result.success) throw new Error("expected failure"); + + const msg = formatZodError(result.error, "get_weather"); + expect(msg).toMatch(/^Invalid arguments for get_weather: /); + expect(msg).toContain("city:"); + }); + + test("joins multiple issues with '; '", () => { + const schema = z.object({ a: z.string(), b: z.number() }); + const result = schema.safeParse({}); + if (result.success) throw new Error("expected failure"); + + const msg = formatZodError(result.error, "t"); + expect(msg.split(";").length).toBeGreaterThanOrEqual(2); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tools/define-tool.ts b/packages/appkit/src/plugins/agents/tools/define-tool.ts new file mode 100644 index 00000000..dc269ba6 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/define-tool.ts @@ -0,0 +1,94 @@ +import type { AgentToolDefinition, ToolAnnotations } from "shared"; +import type { z } from "zod"; +import { toToolJSONSchema } from "./json-schema"; +import { formatZodError } from "./tool"; + +/** + * Single-tool entry for a plugin's internal tool registry. + * + * Plugins collect these into a `Record` keyed by the tool's + * public name and dispatch via `executeFromRegistry`. + */ +export interface ToolEntry { + description: string; + schema: S; + annotations?: ToolAnnotations; + /** + * Whether this tool is eligible for auto-inheritance into markdown or + * code-defined agents that enable `autoInheritTools`. Defaults to `false` + * (safe-by-default) — plugin authors must explicitly opt a tool in if they + * consider it safe enough to appear in every agent's tool record without an + * explicit `tools:` declaration. Destructive or privilege-sensitive tools + * should leave this unset so that they only reach agents that wire them + * explicitly (via `tools:`, `toolkits:`, or `fromPlugin({ only: [...] })`). + */ + autoInheritable?: boolean; + handler: ( + args: z.infer, + signal?: AbortSignal, + ) => unknown | Promise; +} + +export type ToolRegistry = Record; + +/** + * Defines a single tool entry for a plugin's internal registry. + * + * The generic `S` flows from `schema` through to the `handler` callback so + * `args` is fully typed from the Zod schema. Names are assigned by the + * registry key, so they are not repeated inside the entry. + */ +export function defineTool( + config: ToolEntry, +): ToolEntry { + return config; +} + +/** + * Validates tool-call arguments against the entry's schema and invokes its + * handler. On validation failure, returns an LLM-friendly error string + * (matching the behavior of `tool()`) rather than throwing, so the model + * can self-correct on its next turn. + */ +export async function executeFromRegistry( + registry: ToolRegistry, + name: string, + args: unknown, + signal?: AbortSignal, +): Promise { + const entry = registry[name]; + if (!entry) { + throw new Error(`Unknown tool: ${name}`); + } + const parsed = entry.schema.safeParse(args); + if (!parsed.success) { + return formatZodError(parsed.error, name); + } + return entry.handler(parsed.data, signal); +} + +/** + * Produces the `AgentToolDefinition[]` a ToolProvider exposes to the LLM, + * deriving `parameters` JSON Schema from each entry's Zod schema. + * + * Tool names come from registry keys (supports dotted names like + * `uploads.list` for dynamic plugins). + */ +export function toolsFromRegistry( + registry: ToolRegistry, +): AgentToolDefinition[] { + return Object.entries(registry).map(([name, entry]) => { + const parameters = toToolJSONSchema( + entry.schema, + ) as unknown as AgentToolDefinition["parameters"]; + const def: AgentToolDefinition = { + name, + description: entry.description, + parameters, + }; + if (entry.annotations) { + def.annotations = entry.annotations; + } + return def; + }); +} diff --git a/packages/appkit/src/plugins/agents/tools/function-tool.ts b/packages/appkit/src/plugins/agents/tools/function-tool.ts new file mode 100644 index 00000000..8ce634e0 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/function-tool.ts @@ -0,0 +1,33 @@ +import type { AgentToolDefinition } from "shared"; + +export interface FunctionTool { + type: "function"; + name: string; + description?: string | null; + parameters?: Record | null; + strict?: boolean | null; + execute: (args: Record) => Promise | string; +} + +export function isFunctionTool(value: unknown): value is FunctionTool { + if (typeof value !== "object" || value === null) return false; + const obj = value as Record; + return ( + obj.type === "function" && + typeof obj.name === "string" && + typeof obj.execute === "function" + ); +} + +export function functionToolToDefinition( + tool: FunctionTool, +): AgentToolDefinition { + return { + name: tool.name, + description: tool.description ?? tool.name, + parameters: (tool.parameters as AgentToolDefinition["parameters"]) ?? { + type: "object", + properties: {}, + }, + }; +} diff --git a/packages/appkit/src/plugins/agents/tools/hosted-tools.ts b/packages/appkit/src/plugins/agents/tools/hosted-tools.ts new file mode 100644 index 00000000..bce70c4f --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/hosted-tools.ts @@ -0,0 +1,102 @@ +export interface GenieTool { + type: "genie-space"; + genie_space: { id: string }; +} + +export interface VectorSearchIndexTool { + type: "vector_search_index"; + vector_search_index: { name: string }; +} + +export interface CustomMcpServerTool { + type: "custom_mcp_server"; + custom_mcp_server: { app_name: string; app_url: string }; +} + +export interface ExternalMcpServerTool { + type: "external_mcp_server"; + external_mcp_server: { connection_name: string }; +} + +export type HostedTool = + | GenieTool + | VectorSearchIndexTool + | CustomMcpServerTool + | ExternalMcpServerTool; + +const HOSTED_TOOL_TYPES = new Set([ + "genie-space", + "vector_search_index", + "custom_mcp_server", + "external_mcp_server", +]); + +export function isHostedTool(value: unknown): value is HostedTool { + if (typeof value !== "object" || value === null) return false; + const obj = value as Record; + return typeof obj.type === "string" && HOSTED_TOOL_TYPES.has(obj.type); +} + +export interface McpEndpointConfig { + name: string; + /** Absolute URL or path relative to workspace host */ + url: string; +} + +/** + * Resolves HostedTool configs into MCP endpoint configurations + * that the MCP client can connect to. + */ +function resolveHostedTool(tool: HostedTool): McpEndpointConfig { + switch (tool.type) { + case "genie-space": + return { + name: `genie-${tool.genie_space.id}`, + url: `/api/2.0/mcp/genie/${tool.genie_space.id}`, + }; + case "vector_search_index": { + const parts = tool.vector_search_index.name.split("."); + if (parts.length !== 3) { + throw new Error( + `vector_search_index name must be 3-part dotted (catalog.schema.index), got: ${tool.vector_search_index.name}`, + ); + } + return { + name: `vs-${parts.join("-")}`, + url: `/api/2.0/mcp/vector-search/${parts[0]}/${parts[1]}/${parts[2]}`, + }; + } + case "custom_mcp_server": + return { + name: tool.custom_mcp_server.app_name, + url: tool.custom_mcp_server.app_url, + }; + case "external_mcp_server": + return { + name: tool.external_mcp_server.connection_name, + url: `/api/2.0/mcp/external/${tool.external_mcp_server.connection_name}`, + }; + } +} + +export function resolveHostedTools(tools: HostedTool[]): McpEndpointConfig[] { + return tools.map(resolveHostedTool); +} + +/** + * Factory for declaring a custom MCP server tool. + * + * Replaces the verbose `{ type: "custom_mcp_server", custom_mcp_server: { app_name, app_url } }` + * wrapper with a concise positional call. + * + * Example: + * ```ts + * mcpServer("my-app", "https://my-app.databricksapps.com/mcp") + * ``` + */ +export function mcpServer(name: string, url: string): CustomMcpServerTool { + return { + type: "custom_mcp_server", + custom_mcp_server: { app_name: name, app_url: url }, + }; +} diff --git a/packages/appkit/src/plugins/agents/tools/index.ts b/packages/appkit/src/plugins/agents/tools/index.ts new file mode 100644 index 00000000..7b779d1c --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/index.ts @@ -0,0 +1,20 @@ +export { + defineTool, + executeFromRegistry, + type ToolEntry, + type ToolRegistry, + toolsFromRegistry, +} from "./define-tool"; +export { + type FunctionTool, + functionToolToDefinition, + isFunctionTool, +} from "./function-tool"; +export { + type HostedTool, + isHostedTool, + mcpServer, + resolveHostedTools, +} from "./hosted-tools"; +export { AppKitMcpClient } from "./mcp-client"; +export { type ToolConfig, tool } from "./tool"; diff --git a/packages/appkit/src/plugins/agents/tools/json-schema.ts b/packages/appkit/src/plugins/agents/tools/json-schema.ts new file mode 100644 index 00000000..c5c10dbf --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/json-schema.ts @@ -0,0 +1,20 @@ +import { toJSONSchema, type z } from "zod"; + +/** + * Converts a Zod schema to JSON Schema suitable for an LLM tool-call + * `parameters` field. + * + * Wraps `zod`'s `toJSONSchema()` and strips the top-level `$schema` annotation + * that Zod v4 emits by default (e.g. `"https://json-schema.org/draft/..."`). + * The Databricks Mosaic serving endpoint forwards tool schemas to Google's + * Gemini `function_declarations` format, which rejects any top-level key it + * doesn't explicitly recognize — including `$schema` — with a 400 + * `Invalid JSON payload received. Unknown name "$schema"` error. Other LLM + * providers either ignore the field or also trip on it, so stripping here is + * safe across backends. + */ +export function toToolJSONSchema(schema: z.ZodType): Record { + const raw = toJSONSchema(schema) as Record; + const { $schema: _ignored, ...rest } = raw; + return rest; +} diff --git a/packages/appkit/src/plugins/agents/tools/mcp-client.ts b/packages/appkit/src/plugins/agents/tools/mcp-client.ts new file mode 100644 index 00000000..fcb58d59 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/mcp-client.ts @@ -0,0 +1,394 @@ +import type { AgentToolDefinition } from "shared"; +import { createLogger } from "../../../logging/logger"; +import type { McpEndpointConfig } from "./hosted-tools"; +import { + assertResolvedHostSafe, + checkMcpUrl, + type DnsLookup, + type McpHostPolicy, +} from "./mcp-host-policy"; + +const logger = createLogger("agent:mcp"); + +interface JsonRpcRequest { + jsonrpc: "2.0"; + id: number; + method: string; + params?: Record; +} + +interface JsonRpcResponse { + jsonrpc: "2.0"; + id: number; + result?: unknown; + error?: { code: number; message: string; data?: unknown }; +} + +interface McpToolSchema { + name: string; + description?: string; + inputSchema?: Record; +} + +interface McpToolCallResult { + content: Array<{ type: string; text?: string }>; + isError?: boolean; +} + +interface McpServerConnection { + config: McpEndpointConfig; + resolvedUrl: string; + /** + * Whether workspace auth (SP / OBO) may be forwarded to this endpoint's URL. + * Decided at `connect()` time via {@link McpHostPolicy} and cached for the + * lifetime of the connection. + */ + forwardWorkspaceAuth: boolean; + tools: Map; +} + +/** + * Lightweight MCP client for Databricks-hosted MCP servers. + * + * Uses raw fetch() with JSON-RPC 2.0 over HTTP — no @modelcontextprotocol/sdk + * or LangChain dependency. Supports the Streamable HTTP transport only + * (POST with JSON-RPC request, single JSON-RPC response). Implements exactly + * four methods: `initialize`, `notifications/initialized`, `tools/list`, + * `tools/call`. No prompts/resources/completion/sampling. + * + * All outbound URLs are gated by an {@link McpHostPolicy}: unallowlisted hosts + * are rejected before the first byte is sent, and workspace credentials are + * only forwarded to the same-origin workspace. See `mcp-host-policy.ts`. + * + * ### Why not `@modelcontextprotocol/sdk`? + * + * Our security model — zero-trust host policy + per-destination auth scoping + * via `forwardWorkspaceAuth` — is the load-bearing anchor of this module, not + * the wire protocol. Wrapping the SDK to enforce the same invariants (custom + * transport + auth plumbing + policy checks before every request) ends up + * the same size as writing the narrow JSON-RPC subset directly, but with an + * opaque dependency that's harder to audit. Our entire target surface is + * Databricks-hosted MCP over Streamable HTTP; we never need stdio, SSE-only + * transport, or the rest of the protocol the SDK covers. + * + * Zero runtime deps matches the rest of AppKit's philosophy + * (`DatabricksAdapter`, `sql-policy`). Revisit the call if MCP adoption + * grows beyond Databricks-hosted servers with the transports above, or if + * the SDK gains transport-level auth hooks that cleanly match our + * per-destination auth-forwarding model. + */ +export class AppKitMcpClient { + private connections = new Map(); + private sessionIds = new Map(); + private requestId = 0; + private closed = false; + + constructor( + private workspaceHost: string, + private authenticate: () => Promise>, + private policy: McpHostPolicy, + private options: { dnsLookup?: DnsLookup; fetchImpl?: typeof fetch } = {}, + ) {} + + async connectAll(endpoints: McpEndpointConfig[]): Promise { + const results = await Promise.allSettled( + endpoints.map((ep) => this.connect(ep)), + ); + for (let i = 0; i < results.length; i++) { + if (results[i].status === "rejected") { + logger.error( + "Failed to connect MCP server %s: %O", + endpoints[i].name, + (results[i] as PromiseRejectedResult).reason, + ); + } + } + } + + private resolveUrl(endpoint: McpEndpointConfig): string { + if ( + endpoint.url.startsWith("http://") || + endpoint.url.startsWith("https://") + ) { + return endpoint.url; + } + return `${this.workspaceHost}${endpoint.url}`; + } + + async connect(endpoint: McpEndpointConfig): Promise { + const resolvedUrl = this.resolveUrl(endpoint); + const check = checkMcpUrl(resolvedUrl, this.policy); + if (!check.ok) { + throw new Error( + `MCP endpoint '${endpoint.name}' refused at connect: ${check.reason}`, + ); + } + await assertResolvedHostSafe( + check.url.hostname, + this.policy, + this.options.dnsLookup, + ); + + logger.info( + "Connecting to MCP server: %s at %s (forwardWorkspaceAuth=%s)", + endpoint.name, + resolvedUrl, + check.forwardWorkspaceAuth, + ); + + const initResponse = await this.sendRpc( + resolvedUrl, + "initialize", + { + protocolVersion: "2025-03-26", + capabilities: {}, + clientInfo: { name: "appkit-agent", version: "0.1.0" }, + }, + { forwardWorkspaceAuth: check.forwardWorkspaceAuth }, + ); + + if (initResponse.sessionId) { + this.sessionIds.set(endpoint.name, initResponse.sessionId); + } + const sessionId = this.sessionIds.get(endpoint.name); + + await this.sendNotification(resolvedUrl, "notifications/initialized", { + sessionId, + forwardWorkspaceAuth: check.forwardWorkspaceAuth, + }); + + const listResponse = await this.sendRpc( + resolvedUrl, + "tools/list", + {}, + { sessionId, forwardWorkspaceAuth: check.forwardWorkspaceAuth }, + ); + const toolList = + (listResponse.result as { tools?: McpToolSchema[] })?.tools ?? []; + + const tools = new Map(); + for (const tool of toolList) { + tools.set(tool.name, tool); + } + + this.connections.set(endpoint.name, { + config: endpoint, + resolvedUrl, + forwardWorkspaceAuth: check.forwardWorkspaceAuth, + tools, + }); + logger.info( + "Connected to MCP server %s: %d tools available", + endpoint.name, + tools.size, + ); + } + + getAllToolDefinitions(): AgentToolDefinition[] { + const defs: AgentToolDefinition[] = []; + for (const [serverName, conn] of this.connections) { + for (const [toolName, schema] of conn.tools) { + defs.push({ + name: `mcp.${serverName}.${toolName}`, + description: schema.description ?? toolName, + parameters: + (schema.inputSchema as AgentToolDefinition["parameters"]) ?? { + type: "object", + properties: {}, + }, + }); + } + } + return defs; + } + + /** + * Whether the named MCP server may receive workspace-scoped auth headers + * (e.g., an OBO bearer token from an end-user request). Callers should gate + * auth-forwarding decisions on this to prevent credential exfiltration to + * non-workspace hosts. + */ + canForwardWorkspaceAuth(serverName: string): boolean { + return this.connections.get(serverName)?.forwardWorkspaceAuth ?? false; + } + + async callTool( + qualifiedName: string, + args: unknown, + authHeaders?: Record, + callerSignal?: AbortSignal, + ): Promise { + const parts = qualifiedName.split("."); + if (parts.length < 3 || parts[0] !== "mcp") { + throw new Error(`Invalid MCP tool name: ${qualifiedName}`); + } + const serverName = parts[1]; + const toolName = parts.slice(2).join("."); + + const conn = this.connections.get(serverName); + if (!conn) { + throw new Error(`MCP server not connected: ${serverName}`); + } + + const sessionId = this.sessionIds.get(serverName); + // authHeaders are caller-supplied credentials (typically the OBO token). + // Only honor them if the destination URL was admitted with + // forwardWorkspaceAuth=true at connect time. + const scopedAuthOverride = conn.forwardWorkspaceAuth + ? authHeaders + : undefined; + + const rpcResult = await this.sendRpc( + conn.resolvedUrl, + "tools/call", + { name: toolName, arguments: args }, + { + authOverride: scopedAuthOverride, + sessionId, + forwardWorkspaceAuth: conn.forwardWorkspaceAuth, + callerSignal, + }, + ); + const result = rpcResult.result as McpToolCallResult; + + if (result.isError) { + const errText = (result.content ?? []) + .filter((c) => c.type === "text") + .map((c) => c.text) + .join("\n"); + throw new Error(errText || "MCP tool call failed"); + } + + return (result.content ?? []) + .filter((c) => c.type === "text") + .map((c) => c.text) + .join("\n"); + } + + async close(): Promise { + this.closed = true; + this.connections.clear(); + this.sessionIds.clear(); + } + + private async sendRpc( + url: string, + method: string, + params?: Record, + options?: { + authOverride?: Record; + sessionId?: string; + forwardWorkspaceAuth?: boolean; + /** + * Optional external abort signal (typically the agent's stream signal). + * Composed with the built-in 30 s timeout so `/cancel` or agent-run + * shutdown immediately propagates to the MCP fetch rather than waiting + * for the remote server to respond. + */ + callerSignal?: AbortSignal; + }, + ): Promise<{ result: unknown; sessionId?: string }> { + if (this.closed) throw new Error("MCP client is closed"); + + const request: JsonRpcRequest = { + jsonrpc: "2.0", + id: ++this.requestId, + method, + ...(params && { params }), + }; + + const authHeaders = await this.resolveAuthHeaders(options); + const headers: Record = { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + ...authHeaders, + }; + if (options?.sessionId) { + headers["Mcp-Session-Id"] = options.sessionId; + } + + const fetchImpl = this.options.fetchImpl ?? fetch; + const signals: AbortSignal[] = [AbortSignal.timeout(30_000)]; + if (options?.callerSignal) signals.push(options.callerSignal); + const response = await fetchImpl(url, { + method: "POST", + headers, + body: JSON.stringify(request), + signal: signals.length > 1 ? AbortSignal.any(signals) : signals[0], + }); + + if (!response.ok) { + throw new Error( + `MCP request to ${method} failed: ${response.status} ${response.statusText}`, + ); + } + + const contentType = response.headers.get("content-type") ?? ""; + let json: JsonRpcResponse; + + if (contentType.includes("text/event-stream")) { + const text = await response.text(); + const lastData = text + .split("\n") + .filter((line) => line.startsWith("data: ")) + .map((line) => line.slice(6)) + .pop(); + if (!lastData) { + throw new Error(`MCP SSE response for ${method} contained no data`); + } + json = JSON.parse(lastData) as JsonRpcResponse; + } else { + json = (await response.json()) as JsonRpcResponse; + } + + if (json.error) { + throw new Error(`MCP error (${json.error.code}): ${json.error.message}`); + } + + const sid = response.headers.get("mcp-session-id") ?? undefined; + return { result: json.result, sessionId: sid }; + } + + private async sendNotification( + url: string, + method: string, + options?: { + sessionId?: string; + forwardWorkspaceAuth?: boolean; + }, + ): Promise { + if (this.closed) return; + + const authHeaders = await this.resolveAuthHeaders(options); + const headers: Record = { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + ...authHeaders, + }; + if (options?.sessionId) { + headers["Mcp-Session-Id"] = options.sessionId; + } + + const fetchImpl = this.options.fetchImpl ?? fetch; + await fetchImpl(url, { + method: "POST", + headers, + body: JSON.stringify({ jsonrpc: "2.0", method }), + signal: AbortSignal.timeout(30_000), + }); + } + + /** + * Return the auth headers to send on an outbound request. Workspace auth + * (SP or OBO) is only resolved when `forwardWorkspaceAuth` is true; for + * non-workspace hosts no bearer token is attached. + */ + private async resolveAuthHeaders(options?: { + authOverride?: Record; + forwardWorkspaceAuth?: boolean; + }): Promise> { + if (!options?.forwardWorkspaceAuth) return {}; + if (options.authOverride) return options.authOverride; + return this.authenticate(); + } +} diff --git a/packages/appkit/src/plugins/agents/tools/mcp-host-policy.ts b/packages/appkit/src/plugins/agents/tools/mcp-host-policy.ts new file mode 100644 index 00000000..d970c83a --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/mcp-host-policy.ts @@ -0,0 +1,299 @@ +import { lookup as defaultLookup } from "node:dns/promises"; +import { isIP, isIPv4 } from "node:net"; + +/** + * DNS lookup function compatible with `dns/promises.lookup(host, { all: true })`. + * Exposed as an injection point so callers (tests, custom DNS resolvers) can + * override the default resolver. + */ +export type DnsLookup = ( + hostname: string, + options: { all: true }, +) => Promise>; + +/** + * Policy that decides whether a given MCP endpoint URL is allowed and whether + * Databricks workspace credentials (SP or OBO) may be forwarded to it. + * + * The default posture is zero-trust: only same-origin workspace URLs receive + * workspace credentials, and all other destinations must be explicitly + * allowlisted by the application developer. Private / link-local IP ranges + * are blocked outright to prevent SSRF into cloud metadata services. + */ +export interface McpHostPolicy { + /** Lowercased hostname of the Databricks workspace (same-origin target). */ + readonly workspaceHostname: string; + /** Additional allowlisted hostnames (lowercased). Workspace auth is NEVER forwarded to these. */ + readonly trustedHosts: ReadonlySet; + /** Permit `http://localhost`, `127.0.0.1`, `::1` URLs. Typically true only in development. */ + readonly allowLocalhost: boolean; +} + +/** + * Config shape accepted by {@link buildMcpHostPolicy}, matching the + * `mcp` field on `AgentsPluginConfig`. + */ +export interface McpHostPolicyConfig { + /** + * Additional hostnames that may host custom MCP servers beyond the same-origin + * workspace. Compared case-insensitively; bare hostnames only (no scheme or + * path). Workspace credentials (SP / OBO) are never forwarded to these hosts — + * they must handle authentication themselves. + */ + trustedHosts?: string[]; + /** + * Allow `http://localhost`, `127.0.0.1`, and `::1` MCP URLs for local + * development. Defaults to `true` when `NODE_ENV !== "production"`, + * otherwise `false`. Workspace credentials are never forwarded to localhost. + */ + allowLocalhost?: boolean; +} + +/** Build an {@link McpHostPolicy} from user config + the resolved workspace URL. */ +export function buildMcpHostPolicy( + config: McpHostPolicyConfig | undefined, + workspaceHost: string, +): McpHostPolicy { + const workspaceHostname = safeHostname(workspaceHost); + if (!workspaceHostname) { + throw new Error( + `Invalid workspace host for MCP policy: ${JSON.stringify(workspaceHost)}`, + ); + } + const trustedHosts = new Set( + (config?.trustedHosts ?? []).map((h) => h.trim().toLowerCase()), + ); + const allowLocalhost = + config?.allowLocalhost ?? process.env.NODE_ENV !== "production"; + return { workspaceHostname, trustedHosts, allowLocalhost }; +} + +type McpUrlCheck = + | { + readonly ok: true; + /** Whether it is safe to forward workspace-scoped credentials (SP/OBO) to this URL. */ + readonly forwardWorkspaceAuth: boolean; + /** Parsed URL for reuse by the caller. */ + readonly url: URL; + } + | { readonly ok: false; readonly reason: string }; + +/** + * Synchronously decide whether an MCP URL is allowed under the given policy + * and whether workspace credentials may be forwarded to it. + * + * Hard rejections: + * - Non-`http(s)` schemes. + * - `http://` unless the host is localhost AND `allowLocalhost` is true. + * - Hosts that are neither same-origin workspace, localhost (if allowed), + * nor in `trustedHosts`. + */ +export function checkMcpUrl( + rawUrl: string, + policy: McpHostPolicy, +): McpUrlCheck { + let url: URL; + try { + url = new URL(rawUrl); + } catch { + return { + ok: false, + reason: `MCP URL is not a valid absolute URL: ${rawUrl}`, + }; + } + + if (url.protocol !== "http:" && url.protocol !== "https:") { + return { + ok: false, + reason: `MCP URL scheme '${url.protocol}' is not allowed (http(s) only): ${rawUrl}`, + }; + } + + const host = url.hostname.toLowerCase(); + const isLoopback = isLoopbackHost(host); + + if (url.protocol === "http:" && !(isLoopback && policy.allowLocalhost)) { + return { + ok: false, + reason: `MCP URL uses plaintext http:// which forwards bearer tokens in cleartext: ${rawUrl}. Use https:// or enable allowLocalhost for a localhost dev server.`, + }; + } + + if (host === policy.workspaceHostname) { + return { ok: true, forwardWorkspaceAuth: true, url }; + } + + if (isLoopback) { + if (!policy.allowLocalhost) { + return { + ok: false, + reason: `MCP URL points to localhost but allowLocalhost is disabled: ${rawUrl}`, + }; + } + return { ok: true, forwardWorkspaceAuth: false, url }; + } + + if (policy.trustedHosts.has(host)) { + return { ok: true, forwardWorkspaceAuth: false, url }; + } + + return { + ok: false, + reason: `MCP host '${host}' is not allowed. Either use a same-origin workspace URL (${policy.workspaceHostname}) or add it to agents({ mcp: { trustedHosts: ['${host}'] } }).`, + }; +} + +/** + * Resolve `hostname` via DNS and assert that none of its addresses fall in a + * blocked IP range (loopback, RFC1918, link-local, CGNAT, cloud metadata). + * + * Throws with a descriptive error if any resolved address is blocked. Pass + * `allowLocalhost: true` to permit `127.0.0.1` / `::1` specifically. + * + * Note: this only guards against hosts that statically resolve to private + * ranges. Full SSRF protection requires socket-level IP pinning after + * resolution (DNS rebinding defense), which is out of scope here. + */ +export async function assertResolvedHostSafe( + hostname: string, + policy: McpHostPolicy, + lookup: DnsLookup = defaultLookup, +): Promise { + const lowered = hostname.toLowerCase(); + + if (isIP(lowered)) { + if (isBlockedIp(lowered, policy.allowLocalhost)) { + throw new Error(`MCP host ${lowered} is in a blocked IP range`); + } + return; + } + + if (lowered === "localhost") { + if (!policy.allowLocalhost) { + throw new Error( + `MCP host localhost is not allowed under the current policy`, + ); + } + return; + } + + let resolved: Array<{ address: string }>; + try { + resolved = await lookup(hostname, { all: true }); + } catch (cause) { + throw new Error( + `MCP host ${hostname} could not be resolved via DNS: ${cause instanceof Error ? cause.message : String(cause)}`, + ); + } + + if (resolved.length === 0) { + throw new Error(`MCP host ${hostname} returned no DNS addresses`); + } + + for (const { address } of resolved) { + if (isBlockedIp(address, policy.allowLocalhost)) { + throw new Error( + `MCP host ${hostname} resolved to blocked address ${address} (private / link-local ranges are not allowed)`, + ); + } + } +} + +/** Whether a raw hostname literal is one of the recognised loopback aliases. */ +export function isLoopbackHost(host: string): boolean { + const lowered = host.toLowerCase(); + return ( + lowered === "localhost" || + lowered === "127.0.0.1" || + lowered === "::1" || + lowered === "[::1]" || + lowered === "0:0:0:0:0:0:0:1" + ); +} + +/** + * Check whether a resolved IP address is in a range that should never receive + * workspace credentials. `allowLocalhost` carves out 127.0.0.0/8 and ::1. + */ +export function isBlockedIp(address: string, allowLocalhost: boolean): boolean { + if (isIPv4(address)) { + return isBlockedIpv4(address, allowLocalhost); + } + if (isIP(address) === 6) { + return isBlockedIpv6(address, allowLocalhost); + } + // Not a recognisable IP literal — fail-closed. + return true; +} + +function isBlockedIpv4(addr: string, allowLocalhost: boolean): boolean { + const parts = addr.split(".").map((p) => Number.parseInt(p, 10)); + if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) { + return true; + } + const [a, b] = parts; + if (a === 0) return true; + if (a === 127) return !allowLocalhost; + if (a === 10) return true; + if (a === 172 && b >= 16 && b <= 31) return true; + if (a === 192 && b === 168) return true; + if (a === 169 && b === 254) return true; + if (a === 100 && b >= 64 && b <= 127) return true; + if (a >= 224) return true; + return false; +} + +function isBlockedIpv6(addr: string, allowLocalhost: boolean): boolean { + const lowered = addr.toLowerCase().replace(/^\[|\]$/g, ""); + + if (lowered === "::") return true; + if (lowered === "::1" || lowered === "0:0:0:0:0:0:0:1") + return !allowLocalhost; + + // IPv4-mapped IPv6: `::ffff:` may be written in dotted form + // (`::ffff:169.254.169.254`) or colon-hex form (`::ffff:a9fe:a9fe`). Both + // route to the same destination, so we must normalise before delegating + // to the IPv4 blocklist. + if (lowered.startsWith("::ffff:")) { + const tail = lowered.slice("::ffff:".length); + if (isIPv4(tail)) return isBlockedIpv4(tail, allowLocalhost); + const hexV4 = hexPairToDottedIpv4(tail); + if (hexV4) return isBlockedIpv4(hexV4, allowLocalhost); + } + + // Unique Local Addresses (fc00::/7) — `fc` and `fd` only. + if (/^f[cd][0-9a-f]{2}:/.test(lowered)) return true; + // Link-local fe80::/10 — the first 10 bits are 1111111010, i.e. the + // second hex nibble must be 8-b. Matches fe80:..–febf:.. + if (/^fe[89ab][0-9a-f]:/.test(lowered)) return true; + // Multicast ff00::/8. + if (lowered.startsWith("ff")) return true; + return false; +} + +/** + * Parse the trailing two hex groups of an IPv4-mapped IPv6 address written + * in colon-hex form (e.g. `a9fe:a9fe`) into the equivalent dotted-quad IPv4 + * representation (`169.254.169.254`). Returns null for anything else. + */ +function hexPairToDottedIpv4(tail: string): string | null { + const match = tail.match(/^([0-9a-f]{1,4}):([0-9a-f]{1,4})$/); + if (!match) return null; + const hi = Number.parseInt(match[1], 16); + const lo = Number.parseInt(match[2], 16); + if (!Number.isFinite(hi) || !Number.isFinite(lo)) return null; + if (hi < 0 || hi > 0xffff || lo < 0 || lo > 0xffff) return null; + const a = (hi >> 8) & 0xff; + const b = hi & 0xff; + const c = (lo >> 8) & 0xff; + const d = lo & 0xff; + return `${a}.${b}.${c}.${d}`; +} + +function safeHostname(rawUrl: string): string | null { + try { + return new URL(rawUrl).hostname.toLowerCase(); + } catch { + return null; + } +} diff --git a/packages/appkit/src/plugins/agents/tools/sql-policy.ts b/packages/appkit/src/plugins/agents/tools/sql-policy.ts new file mode 100644 index 00000000..6f889d44 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/sql-policy.ts @@ -0,0 +1,317 @@ +/** + * Conservative SQL classifier used by agent-facing query tools to enforce + * `readOnly: true` annotations at execution time. + * + * Why a hand-rolled tokenizer rather than `node-sql-parser` or `pgsql-parser`: + * + * - `node-sql-parser`'s Hive/Spark dialect coverage rejects common Databricks + * SQL patterns (three-part `catalog.schema.table` names, `SHOW TABLES IN`, + * `DESCRIBE EXTENDED`, `EXPLAIN`) that must be allowed by a read-only + * classifier. Its PostgreSQL grammar rejects `SHOW`/`DESCRIBE` too. + * - `pgsql-parser` (libpg_query) is a native binding and fails to install + * cleanly on every Databricks App runtime we care about. + * + * We don't need to fully parse SQL — we only need to decide whether every + * statement in the batch starts with a read-only keyword. A small tokenizer + * that correctly strips strings, identifiers, and comments is enough and + * costs no extra dependencies. + * + * What this classifier guarantees (when it returns `readOnly: true`): + * + * - Every semicolon-separated statement outside a string, identifier, or + * comment begins with `SELECT`, `WITH`, `SHOW`, `EXPLAIN`, `DESCRIBE`, or + * `DESC`. + * - `SELECT 1; DROP TABLE x` is rejected (stacked write detected). + * - `SELECT 'value; DROP TABLE x'` passes (literal inside a string). + * - `-- DROP TABLE x\nSELECT 1` passes (comment stripped). + * - `SELECT 1 ` passes (comment stripped). + * + * What this classifier does NOT guarantee: + * + * - A `SELECT` statement may still have side effects via function calls + * (`SELECT pg_advisory_lock(...)`, `SELECT lo_import('/etc/passwd')`, CTEs + * with DML in Postgres 9.1+). Callers that need stronger guarantees should + * combine this check with a runtime mechanism: for PostgreSQL, execute the + * statement inside a dedicated client's `BEGIN READ ONLY … ROLLBACK` + * transaction (see `LakebasePlugin.runReadOnlyStatement`). A batched + * `pool.query("BEGIN READ ONLY; ; ROLLBACK")` cannot be used because + * the Postgres Extended Query protocol rejects multi-statement prepared + * queries, which silently breaks parameterized SQL. + */ + +const READ_ONLY_KEYWORDS = new Set([ + "SELECT", + "WITH", + "SHOW", + "EXPLAIN", + "DESCRIBE", + "DESC", +]); + +type SqlReadOnlyResult = + | { readOnly: true; statements: number } + | { readOnly: false; reason: string }; + +/** + * Classify a SQL string as read-only or not. See module docstring for the + * precise guarantee this offers. + */ +export function classifyReadOnly(sql: string): SqlReadOnlyResult { + const strip = stripCommentsAndQuoted(sql); + if (strip.unterminated) { + return { + readOnly: false, + reason: `SQL has an unterminated ${strip.unterminated} literal`, + }; + } + const statements = splitStatements(strip.cleaned); + + if (statements.length === 0) { + return { + readOnly: false, + reason: "SQL is empty or contains only comments", + }; + } + + for (let i = 0; i < statements.length; i++) { + const stmt = statements[i]; + const firstWord = firstKeyword(stmt); + if (!firstWord) { + return { + readOnly: false, + reason: `statement ${i + 1} of ${statements.length} is empty`, + }; + } + if (!READ_ONLY_KEYWORDS.has(firstWord.toUpperCase())) { + return { + readOnly: false, + reason: `statement starts with '${firstWord}'; only SELECT, WITH, SHOW, EXPLAIN, DESCRIBE, DESC are allowed in read-only mode`, + }; + } + } + + return { readOnly: true, statements: statements.length }; +} + +/** + * Assert `sql` is read-only or throw {@link ReadOnlySqlViolation}. Suitable + * for calling from agent-tool handlers where the thrown string surfaces back + * to the LLM as the tool's error output. + */ +export function assertReadOnlySql(sql: string): void { + const result = classifyReadOnly(sql); + if (!result.readOnly) { + throw new ReadOnlySqlViolation(result.reason); + } +} + +export class ReadOnlySqlViolation extends Error { + constructor(reason: string) { + super(`SQL read-only policy violation: ${reason}`); + this.name = "ReadOnlySqlViolation"; + } +} + +// --------------------------------------------------------------------------- +// Tokenizer helpers +// --------------------------------------------------------------------------- + +/** + * Walk `sql` character-by-character and replace every string literal, + * identifier quote, and comment body with a single space of equivalent + * length. Leaves structural tokens (semicolons, whitespace, identifiers, + * operators) in place. + * + * Handles: + * - `-- line comments` through end-of-line + * - SQL block comments (slash-star ... star-slash) with correct nesting (PostgreSQL) + * - `'single-quoted strings'` with `''` escape + * - `"double-quoted identifiers"` with `""` escape (ANSI) + * - `` `backtick identifiers` `` (Databricks) + * - `$tag$dollar quoted$tag$` strings (PostgreSQL) + * - `E'escape-style'` strings (PostgreSQL) + */ +type StripResult = { + cleaned: string; + /** Non-null if tokenization ended inside an unterminated literal or comment. */ + unterminated: + | null + | "string" + | "identifier" + | "block comment" + | "dollar-quoted string"; +}; + +function stripCommentsAndQuoted(sql: string): StripResult { + const out: string[] = []; + let i = 0; + const n = sql.length; + let unterminated: StripResult["unterminated"] = null; + + while (i < n) { + const ch = sql[i]; + const next = i + 1 < n ? sql[i + 1] : ""; + + if (ch === "-" && next === "-") { + out.push(" "); + i += 2; + while (i < n && sql[i] !== "\n") { + out.push(" "); + i++; + } + continue; + } + + if (ch === "/" && next === "*") { + out.push(" "); + i += 2; + let depth = 1; + while (i < n && depth > 0) { + if (sql[i] === "/" && sql[i + 1] === "*") { + out.push(" "); + i += 2; + depth++; + continue; + } + if (sql[i] === "*" && sql[i + 1] === "/") { + out.push(" "); + i += 2; + depth--; + continue; + } + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + if (depth > 0) { + unterminated = "block comment"; + } + continue; + } + + if ( + ch === "'" || + (ch === "E" && next === "'") || + (ch === "e" && next === "'") + ) { + if (ch === "E" || ch === "e") { + out.push(" "); + i++; + } + out.push(" "); + i++; + let closed = false; + while (i < n) { + if (sql[i] === "'" && sql[i + 1] === "'") { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === "\\" && sql[i + 1]) { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === "'") { + out.push(" "); + i++; + closed = true; + break; + } + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + if (!closed) unterminated = "string"; + continue; + } + + if (ch === '"') { + out.push(" "); + i++; + let closed = false; + while (i < n) { + if (sql[i] === '"' && sql[i + 1] === '"') { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === '"') { + out.push(" "); + i++; + closed = true; + break; + } + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + if (!closed) unterminated = "identifier"; + continue; + } + + if (ch === "`") { + out.push(" "); + i++; + let closed = false; + while (i < n) { + if (sql[i] === "`" && sql[i + 1] === "`") { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === "`") { + out.push(" "); + i++; + closed = true; + break; + } + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + if (!closed) unterminated = "identifier"; + continue; + } + + if (ch === "$") { + const tagMatch = sql.slice(i).match(/^\$([A-Za-z_][A-Za-z0-9_]*)?\$/); + if (tagMatch) { + const tag = tagMatch[0]; + out.push(" ".repeat(tag.length)); + i += tag.length; + const closeIdx = sql.indexOf(tag, i); + if (closeIdx === -1) { + while (i < n) { + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + unterminated = "dollar-quoted string"; + } else { + while (i < closeIdx) { + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + out.push(" ".repeat(tag.length)); + i += tag.length; + } + continue; + } + } + + out.push(ch); + i++; + } + + return { cleaned: out.join(""), unterminated }; +} + +/** Split on unquoted `;`, trim, drop empty segments. */ +function splitStatements(cleanedSql: string): string[] { + return cleanedSql + .split(";") + .map((s) => s.trim()) + .filter((s) => s.length > 0); +} + +/** Return the first bareword keyword of a statement, or null if empty. */ +function firstKeyword(stmt: string): string | null { + const match = stmt.match(/^\s*([A-Za-z_][A-Za-z0-9_]*)/); + return match ? match[1] : null; +} diff --git a/packages/appkit/src/plugins/agents/tools/tool.ts b/packages/appkit/src/plugins/agents/tools/tool.ts new file mode 100644 index 00000000..b5d4db65 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/tool.ts @@ -0,0 +1,53 @@ +import type { z } from "zod"; +import type { FunctionTool } from "./function-tool"; +import { toToolJSONSchema } from "./json-schema"; + +export interface ToolConfig { + name: string; + description?: string; + schema: S; + execute: (args: z.infer) => Promise | string; +} + +/** + * Factory for defining function tools with Zod schemas. + * + * - Generates JSON Schema (for the LLM) from the Zod schema via `z.toJSONSchema()`. + * - Infers the `execute` argument type from the schema. + * - Validates tool call arguments at runtime. On validation failure, returns + * a formatted error string to the LLM instead of throwing, so the model + * can self-correct on its next turn. + */ +export function tool(config: ToolConfig): FunctionTool { + const parameters = toToolJSONSchema(config.schema) as unknown as Record< + string, + unknown + >; + + return { + type: "function", + name: config.name, + description: config.description ?? config.name, + parameters, + execute: async (args: Record) => { + const parsed = config.schema.safeParse(args); + if (!parsed.success) { + return formatZodError(parsed.error, config.name); + } + return config.execute(parsed.data as z.infer); + }, + }; +} + +/** + * Formats a Zod validation error into an LLM-friendly string. + * + * Example: `Invalid arguments for get_weather: city: Invalid input: expected string, received undefined` + */ +export function formatZodError(error: z.ZodError, toolName: string): string { + const parts = error.issues.map((issue) => { + const field = issue.path.length > 0 ? issue.path.join(".") : "(root)"; + return `${field}: ${issue.message}`; + }); + return `Invalid arguments for ${toolName}: ${parts.join("; ")}`; +} diff --git a/packages/appkit/src/plugins/agents/types.ts b/packages/appkit/src/plugins/agents/types.ts new file mode 100644 index 00000000..086a0426 --- /dev/null +++ b/packages/appkit/src/plugins/agents/types.ts @@ -0,0 +1,54 @@ +import type { AgentToolDefinition, ToolAnnotations } from "shared"; +import type { FunctionTool } from "./tools/function-tool"; +import type { HostedTool } from "./tools/hosted-tools"; + +/** + * A tool reference produced by a plugin's `.toolkit()` call. The agents plugin + * recognizes the `__toolkitRef` brand and dispatches tool invocations through + * `PluginContext.executeTool(req, pluginName, localName, ...)`, preserving + * OBO (asUser) and telemetry spans. + */ +export interface ToolkitEntry { + readonly __toolkitRef: true; + pluginName: string; + localName: string; + def: AgentToolDefinition; + annotations?: ToolAnnotations; + /** + * Whether this tool is eligible for `autoInheritTools` spreading. Mirrors + * {@link ToolEntry.autoInheritable} from the source registry so the agents + * plugin can filter auto-inherited tools without re-walking the provider's + * internal registry. + */ + autoInheritable?: boolean; +} + +/** + * Any tool an agent can invoke: inline function tools (`tool()`), hosted MCP + * tools (`mcpServer()` / raw hosted), or toolkit references from plugins + * (`analytics().toolkit()`). + */ +export type AgentTool = FunctionTool | HostedTool | ToolkitEntry; + +export interface ToolkitOptions { + /** Key prefix to prepend to each tool's local name. Defaults to `${pluginName}.`. */ + prefix?: string; + /** Only include tools whose local name matches one of these. */ + only?: string[]; + /** Exclude tools whose local name matches one of these. */ + except?: string[]; + /** Remap specific local names to different keys (applied after prefix). */ + rename?: Record; +} + +/** + * Type guard for `ToolkitEntry` — used to differentiate toolkit references + * from inline tools in a mixed `tools` record. + */ +export function isToolkitEntry(value: unknown): value is ToolkitEntry { + return ( + typeof value === "object" && + value !== null && + (value as { __toolkitRef?: unknown }).__toolkitRef === true + ); +} diff --git a/packages/appkit/src/plugins/analytics/analytics.ts b/packages/appkit/src/plugins/analytics/analytics.ts index a9c688da..78d11d4d 100644 --- a/packages/appkit/src/plugins/analytics/analytics.ts +++ b/packages/appkit/src/plugins/analytics/analytics.ts @@ -1,16 +1,26 @@ import type { WorkspaceClient } from "@databricks/sdk-experimental"; import type express from "express"; import type { + AgentToolDefinition, IAppRouter, PluginExecuteConfig, SQLTypeMarker, StreamExecutionSettings, + ToolProvider, } from "shared"; +import { z } from "zod"; import { SQLWarehouseConnector } from "../../connectors"; import { getWarehouseId, getWorkspaceClient } from "../../context"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; +import { assertReadOnlySql } from "../agents/tools/sql-policy"; import { queryDefaults } from "./defaults"; import manifest from "./manifest.json"; import { QueryProcessor } from "./query"; @@ -22,7 +32,7 @@ import type { const logger = createLogger("analytics"); -export class AnalyticsPlugin extends Plugin { +export class AnalyticsPlugin extends Plugin implements ToolProvider { /** Plugin manifest declaring metadata and resource requirements */ static manifest = manifest as PluginManifest<"analytics">; @@ -262,6 +272,52 @@ export class AnalyticsPlugin extends Plugin { this.streamManager.abortAll(); } + private tools = { + query: defineTool({ + description: + "Execute a read-only SQL query against the Databricks SQL warehouse. Only SELECT, WITH, SHOW, EXPLAIN, and DESCRIBE statements are accepted; writes are rejected. Returns the query results as JSON.", + schema: z.object({ + query: z + .string() + .describe( + "The SQL query to execute. Must be a SELECT, WITH, SHOW, EXPLAIN, or DESCRIBE statement.", + ), + }), + annotations: { + readOnly: true, + requiresUserContext: true, + }, + autoInheritable: true, + handler: (args, signal) => { + assertReadOnlySql(args.query); + return this.query(args.query, undefined, undefined, signal); + }, + }), + }; + + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + /** + * Returns the plugin's tools as a keyed record of `ToolkitEntry` markers. + * Called by the agents plugin (via `resolveToolkitFromProvider`) to spread + * a filtered, renamed view of the plugin's tools into an agent's tool + * index. Most callers should go through `fromPlugin(analytics, opts)` at + * module scope instead of reaching for this directly. + */ + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, this.tools, opts); + } + /** * Returns the public exports for the analytics plugin. * Note: `asUser()` is automatically added by AppKit. diff --git a/packages/appkit/src/plugins/analytics/tests/analytics.readonly.test.ts b/packages/appkit/src/plugins/analytics/tests/analytics.readonly.test.ts new file mode 100644 index 00000000..42c9b516 --- /dev/null +++ b/packages/appkit/src/plugins/analytics/tests/analytics.readonly.test.ts @@ -0,0 +1,133 @@ +import { describe, expect, test, vi } from "vitest"; + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise) => + fn(), + ), + generateKey: vi.fn(() => "test-key"), + })), + }, +})); + +import { AnalyticsPlugin } from "../analytics"; + +/** + * Tests the read-only SQL enforcement on the analytics agent tool. + * + * The tool is annotated `{ readOnly: true, requiresUserContext: true }`; this + * suite verifies that the annotation is enforced at execution time — not just + * exposed as metadata to the LLM — by the `assertReadOnlySql` guard in the + * tool's handler. + */ + +function makePlugin(): AnalyticsPlugin { + return new AnalyticsPlugin({}); +} + +describe("AnalyticsPlugin.query agent tool — readOnly annotation", () => { + test("is advertised with readOnly:true and requiresUserContext:true", () => { + const plugin = makePlugin(); + const defs = plugin.getAgentTools(); + const query = defs.find((d) => d.name === "query"); + expect(query).toBeDefined(); + expect(query?.annotations).toEqual({ + readOnly: true, + requiresUserContext: true, + }); + }); +}); + +describe("AnalyticsPlugin.query agent tool — runtime enforcement", () => { + test("rejects a DROP statement before it reaches this.query", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [] } as any); + await expect( + plugin.executeAgentTool("query", { query: "DROP TABLE users" }), + ).rejects.toThrow(/read-only policy violation/i); + expect(spy).not.toHaveBeenCalled(); + }); + + test("rejects UPDATE, DELETE, INSERT, TRUNCATE, GRANT", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [] } as any); + for (const q of [ + "UPDATE users SET email='x'", + "DELETE FROM orders", + "INSERT INTO x VALUES (1)", + "TRUNCATE TABLE orders", + "GRANT SELECT ON t TO u", + ]) { + await expect( + plugin.executeAgentTool("query", { query: q }), + ).rejects.toThrow(/read-only policy violation/i); + } + expect(spy).not.toHaveBeenCalled(); + }); + + test("rejects a stacked SELECT + DROP", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [] } as any); + await expect( + plugin.executeAgentTool("query", { + query: "SELECT 1; DROP TABLE users", + }), + ).rejects.toThrow(/DROP/); + expect(spy).not.toHaveBeenCalled(); + }); + + test("passes a plain SELECT through to this.query", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [{ id: 1 }] } as any); + const result = await plugin.executeAgentTool("query", { + query: "SELECT * FROM main.sales.orders", + }); + expect(result).toEqual({ rows: [{ id: 1 }] }); + expect(spy).toHaveBeenCalledWith( + "SELECT * FROM main.sales.orders", + undefined, + undefined, + undefined, + ); + }); + + test("passes WITH … SELECT through", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [] } as any); + await plugin.executeAgentTool("query", { + query: "WITH a AS (SELECT 1) SELECT * FROM a", + }); + expect(spy).toHaveBeenCalledOnce(); + }); + + test("passes SHOW TABLES through", async () => { + const plugin = makePlugin(); + const spy = vi + .spyOn(plugin, "query") + // biome-ignore lint/suspicious/noExplicitAny: mocked return + .mockResolvedValue({ rows: [] } as any); + await plugin.executeAgentTool("query", { + query: "SHOW TABLES IN main.sales", + }); + expect(spy).toHaveBeenCalledOnce(); + }); +}); diff --git a/packages/appkit/src/plugins/analytics/tests/analytics.test.ts b/packages/appkit/src/plugins/analytics/tests/analytics.test.ts index 9a30440e..29157fff 100644 --- a/packages/appkit/src/plugins/analytics/tests/analytics.test.ts +++ b/packages/appkit/src/plugins/analytics/tests/analytics.test.ts @@ -608,4 +608,22 @@ describe("Analytics Plugin", () => { }); }); }); + + describe("toolkit()", () => { + test("produces ToolkitEntry records keyed by the plugin name", () => { + const plugin = new AnalyticsPlugin({ name: "analytics" }); + const entries = plugin.toolkit(); + expect(Object.keys(entries)).toContain("analytics.query"); + const entry = entries["analytics.query"]; + expect(entry.__toolkitRef).toBe(true); + expect(entry.pluginName).toBe("analytics"); + expect(entry.localName).toBe("query"); + }); + + test("respects prefix and only options", () => { + const plugin = new AnalyticsPlugin({ name: "analytics" }); + const entries = plugin.toolkit({ prefix: "", only: ["query"] }); + expect(Object.keys(entries)).toEqual(["query"]); + }); + }); }); diff --git a/packages/appkit/src/plugins/files/plugin.ts b/packages/appkit/src/plugins/files/plugin.ts index 75f2e14d..fc768c5a 100644 --- a/packages/appkit/src/plugins/files/plugin.ts +++ b/packages/appkit/src/plugins/files/plugin.ts @@ -2,19 +2,37 @@ import { STATUS_CODES } from "node:http"; import { Readable } from "node:stream"; import { ApiError } from "@databricks/sdk-experimental"; import type express from "express"; -import type { IAppRouter, PluginExecutionSettings } from "shared"; +import type { + AgentToolDefinition, + IAppRouter, + PluginExecutionSettings, + ToolProvider, +} from "shared"; +import { z } from "zod"; import { contentTypeFromPath, FilesConnector, isSafeInlineContentType, validateCustomContentTypes, } from "../../connectors/files"; -import { getCurrentUserId, getWorkspaceClient } from "../../context"; +import { + getCurrentUserId, + getExecutionContext, + getWorkspaceClient, +} from "../../context"; +import { isUserContext } from "../../context/user-context"; import { AuthenticationError } from "../../errors"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest, ResourceRequirement } from "../../registry"; import { ResourceType } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; import { FILES_DOWNLOAD_DEFAULTS, FILES_MAX_UPLOAD_SIZE, @@ -41,7 +59,7 @@ import type { const logger = createLogger("files"); -export class FilesPlugin extends Plugin { +export class FilesPlugin extends Plugin implements ToolProvider { name = "files"; /** Plugin manifest declaring metadata and resource requirements. */ @@ -52,6 +70,7 @@ export class FilesPlugin extends Plugin { private volumeConnectors: Record = {}; private volumeConfigs: Record = {}; private volumeKeys: string[] = []; + private tools: ToolRegistry = {}; /** * Scans `process.env` for `DATABRICKS_VOLUME_*` keys and merges them with @@ -226,6 +245,10 @@ export class FilesPlugin extends Plugin { }); } + for (const volumeKey of this.volumeKeys) { + Object.assign(this.tools, this._defineVolumeTools(volumeKey)); + } + // Warn at startup for volumes without an explicit policy for (const key of this.volumeKeys) { if (!volumes[key].policy) { @@ -1019,6 +1042,91 @@ export class FilesPlugin extends Plugin { }; } + /** + * Builds the agent-tool registry entries for a single volume. One set of + * tools per configured volume, keyed by `${volumeKey}.${method}`. + * + * Each handler resolves the caller's identity from the current execution + * context (OBO user when the agent run is wrapped in `asUser(req)`, service + * principal otherwise in local dev) and dispatches through + * `createVolumeAPI(volumeKey, user)` so the volume's policy is enforced + * uniformly for agent and HTTP callers. + */ + private _defineVolumeTools(volumeKey: string): ToolRegistry { + const buildUser = (): FilePolicyUser => { + const ctx = getExecutionContext(); + return isUserContext(ctx) + ? { id: ctx.userId } + : { id: ctx.serviceUserId, isServicePrincipal: true }; + }; + const api = () => this.createVolumeAPI(volumeKey, buildUser()); + return { + [`${volumeKey}.list`]: defineTool({ + description: `List files and directories in the "${volumeKey}" volume`, + schema: z.object({ + path: z + .string() + .optional() + .describe("Directory path to list (optional, defaults to root)"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().list(args.path), + }), + [`${volumeKey}.read`]: defineTool({ + description: `Read a text file from the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path to read"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().read(args.path), + }), + [`${volumeKey}.exists`]: defineTool({ + description: `Check if a file or directory exists in the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("Path to check"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().exists(args.path), + }), + [`${volumeKey}.metadata`]: defineTool({ + description: `Get metadata (size, type, last modified) for a file in the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().metadata(args.path), + }), + [`${volumeKey}.upload`]: defineTool({ + description: `Upload a text file to the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("Destination file path"), + contents: z.string().describe("File contents as a string"), + overwrite: z + .boolean() + .optional() + .describe("Whether to overwrite existing file"), + }), + annotations: { destructive: true, requiresUserContext: true }, + handler: (args) => + api().upload(args.path, args.contents, { + overwrite: args.overwrite, + }), + }), + [`${volumeKey}.delete`]: defineTool({ + description: `Delete a file from the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path to delete"), + }), + annotations: { destructive: true, requiresUserContext: true }, + handler: (args) => api().delete(args.path), + }), + }; + } + private inflightWrites = 0; private trackWrite(fn: () => Promise): Promise { @@ -1047,6 +1155,22 @@ export class FilesPlugin extends Plugin { this.streamManager.abortAll(); } + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, this.tools, opts); + } + /** * Returns the programmatic API for the Files plugin. * Callable with a volume key to get a volume-scoped handle. diff --git a/packages/appkit/src/plugins/files/tests/plugin.test.ts b/packages/appkit/src/plugins/files/tests/plugin.test.ts index a4b9bea2..bbaa1b98 100644 --- a/packages/appkit/src/plugins/files/tests/plugin.test.ts +++ b/packages/appkit/src/plugins/files/tests/plugin.test.ts @@ -205,6 +205,62 @@ describe("FilesPlugin", () => { }); }); + describe("getAgentTools / executeAgentTool", () => { + test("produces independent tool entries per volume", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const tools = plugin.getAgentTools(); + const names = tools.map((t) => t.name); + + expect(names).toContain("uploads.list"); + expect(names).toContain("uploads.read"); + expect(names).toContain("uploads.exists"); + expect(names).toContain("uploads.metadata"); + expect(names).toContain("uploads.upload"); + expect(names).toContain("uploads.delete"); + + expect(names).toContain("exports.list"); + expect(names).toContain("exports.read"); + expect(names).toContain("exports.delete"); + + expect(tools).toHaveLength(12); + }); + + test("dispatches to the correct volume API based on the tool name", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const asyncIterable = (items: { path: string }[]) => ({ + [Symbol.asyncIterator]: async function* () { + for (const item of items) yield item; + }, + }); + mockClient.files.listDirectoryContents.mockReturnValueOnce( + asyncIterable([{ path: "uploads-file" }]), + ); + mockClient.files.listDirectoryContents.mockReturnValueOnce( + asyncIterable([{ path: "exports-file" }]), + ); + + const uploadsResult = (await plugin.executeAgentTool( + "uploads.list", + {}, + )) as { path: string }[]; + const exportsResult = (await plugin.executeAgentTool( + "exports.list", + {}, + )) as { path: string }[]; + + expect(uploadsResult[0].path).toBe("uploads-file"); + expect(exportsResult[0].path).toBe("exports-file"); + }); + + test("returns LLM-friendly error string for invalid tool args", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const result = await plugin.executeAgentTool("uploads.read", {}); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for uploads.read"); + expect(result).toContain("path"); + }); + }); + describe("exports()", () => { test("returns a callable function with a .volume alias", () => { const plugin = new FilesPlugin(VOLUMES_CONFIG); diff --git a/packages/appkit/src/plugins/genie/genie.ts b/packages/appkit/src/plugins/genie/genie.ts index 712aadbf..3167794e 100644 --- a/packages/appkit/src/plugins/genie/genie.ts +++ b/packages/appkit/src/plugins/genie/genie.ts @@ -1,11 +1,24 @@ import { randomUUID } from "node:crypto"; import type express from "express"; -import type { IAppRouter, StreamExecutionSettings } from "shared"; +import type { + AgentToolDefinition, + IAppRouter, + StreamExecutionSettings, + ToolProvider, +} from "shared"; +import { z } from "zod"; import { GenieConnector } from "../../connectors"; import { getWorkspaceClient } from "../../context"; import { createLogger } from "../../logging"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; import { genieStreamDefaults } from "./defaults"; import manifest from "./manifest.json"; import type { @@ -17,7 +30,7 @@ import type { const logger = createLogger("genie"); -export class GeniePlugin extends Plugin { +export class GeniePlugin extends Plugin implements ToolProvider { static manifest = manifest as PluginManifest<"genie">; protected static description = @@ -25,6 +38,7 @@ export class GeniePlugin extends Plugin { protected declare config: IGenieConfig; private readonly genieConnector: GenieConnector; + private tools: ToolRegistry = {}; constructor(config: IGenieConfig) { super(config); @@ -36,6 +50,54 @@ export class GeniePlugin extends Plugin { timeout: this.config.timeout, maxMessages: 200, }); + + for (const alias of Object.keys(this.config.spaces ?? {})) { + Object.assign(this.tools, this._defineSpaceTools(alias)); + } + } + + /** + * Builds the registry entries for a single Genie space alias. + * One set of tools per configured space, keyed by `${alias}.${method}`. + */ + private _defineSpaceTools(alias: string): ToolRegistry { + return { + [`${alias}.sendMessage`]: defineTool({ + description: `Send a natural language question to the Genie space "${alias}" and get data analysis results`, + schema: z.object({ + content: z.string().describe("The natural language question to ask"), + conversationId: z + .string() + .optional() + .describe( + "Optional conversation ID to continue an existing conversation", + ), + }), + annotations: { requiresUserContext: true }, + handler: async (args) => { + const events: GenieStreamEvent[] = []; + for await (const event of this.sendMessage( + alias, + args.content, + args.conversationId, + )) { + events.push(event); + } + return events; + }, + }), + [`${alias}.getConversation`]: defineTool({ + description: `Retrieve the conversation history from the Genie space "${alias}"`, + schema: z.object({ + conversationId: z + .string() + .describe("The conversation ID to retrieve"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => this.getConversation(alias, args.conversationId), + }), + }; } private defaultSpaces(): Record { @@ -287,6 +349,22 @@ export class GeniePlugin extends Plugin { this.streamManager.abortAll(); } + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, this.tools, opts); + } + exports() { return { sendMessage: this.sendMessage, diff --git a/packages/appkit/src/plugins/genie/tests/genie.test.ts b/packages/appkit/src/plugins/genie/tests/genie.test.ts index 3cf0784d..672e6242 100644 --- a/packages/appkit/src/plugins/genie/tests/genie.test.ts +++ b/packages/appkit/src/plugins/genie/tests/genie.test.ts @@ -187,6 +187,30 @@ describe("Genie Plugin", () => { }); }); + describe("getAgentTools / executeAgentTool", () => { + test("produces independent tool entries per configured space", () => { + const plugin = new GeniePlugin(config); + const names = plugin.getAgentTools().map((t) => t.name); + + expect(names).toContain("myspace.sendMessage"); + expect(names).toContain("myspace.getConversation"); + expect(names).toContain("salesbot.sendMessage"); + expect(names).toContain("salesbot.getConversation"); + expect(names).toHaveLength(4); + }); + + test("returns LLM-friendly error string for invalid tool args", async () => { + const plugin = new GeniePlugin(config); + const result = await plugin.executeAgentTool( + "myspace.getConversation", + {}, + ); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for myspace.getConversation"); + expect(result).toContain("conversationId"); + }); + }); + describe("space alias resolution", () => { test("should return 404 for unknown alias", async () => { const plugin = new GeniePlugin(config); diff --git a/packages/appkit/src/plugins/lakebase/lakebase.ts b/packages/appkit/src/plugins/lakebase/lakebase.ts index 3071d539..aaf61b51 100644 --- a/packages/appkit/src/plugins/lakebase/lakebase.ts +++ b/packages/appkit/src/plugins/lakebase/lakebase.ts @@ -1,4 +1,6 @@ import type { Pool, QueryResult, QueryResultRow } from "pg"; +import type { AgentToolDefinition, ToolProvider } from "shared"; +import { z } from "zod"; import { createLakebasePool, getLakebaseOrmConfig, @@ -8,6 +10,13 @@ import { import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; +import { assertReadOnlySql } from "../agents/tools/sql-policy"; import manifest from "./manifest.json"; import type { ILakebaseConfig } from "./types"; @@ -30,18 +39,13 @@ const logger = createLogger("lakebase"); * const result = await AppKit.lakebase.query("SELECT * FROM users WHERE id = $1", [userId]); * ``` */ -class LakebasePlugin extends Plugin { +export class LakebasePlugin extends Plugin implements ToolProvider { /** Plugin manifest declaring metadata and resource requirements */ static manifest = manifest as PluginManifest<"lakebase">; protected declare config: ILakebaseConfig; private pool: Pool | null = null; - constructor(config: ILakebaseConfig) { - super(config); - this.config = config; - } - /** * Initializes the Lakebase connection pool. * Called automatically by AppKit during the plugin setup phase. @@ -79,6 +83,39 @@ class LakebasePlugin extends Plugin { return this.pool!.query(text, values); } + /** + * Execute a single statement inside a `BEGIN READ ONLY … ROLLBACK` + * transaction on a dedicated client. + * + * The three commands MUST share a connection — a naive + * `pool.query("BEGIN READ ONLY; ; ROLLBACK")` batch cannot accept + * parameter values (PostgreSQL's Extended Query protocol rejects multi- + * statement prepared queries), which would silently break every + * parameterized query the agent tool issues. + * + * Returns the raw `rows` array for the user's statement. Side effects the + * statement may attempt (writes, writable-function side effects) are + * rejected by PostgreSQL under the read-only transaction posture. + */ + private async runReadOnlyStatement( + text: string, + values?: unknown[], + ): Promise { + // biome-ignore lint/style/noNonNullAssertion: pool is guaranteed non-null after setup() + const client = await this.pool!.connect(); + try { + await client.query("BEGIN READ ONLY"); + const result = await client.query(text, values); + return result.rows; + } finally { + try { + await client.query("ROLLBACK"); + } finally { + client.release(); + } + } + } + /** * Gracefully drains and closes the connection pool. * Called automatically by AppKit during shutdown. @@ -102,6 +139,82 @@ class LakebasePlugin extends Plugin { * - `getOrmConfig()` — Returns a config object compatible with Drizzle, TypeORM, Sequelize, etc. * - `getPgConfig()` — Returns a `pg.PoolConfig` object for manual pool construction */ + + /** + * Agent tool registry. Empty by default — the Lakebase plugin does NOT + * expose its SQL connection to LLM agents unless the developer explicitly + * opts in via `config.exposeAsAgentTool`. See {@link buildQueryTool}. + */ + private tools: Record> = {}; + + constructor(config: ILakebaseConfig) { + super(config); + this.config = config; + if (config.exposeAsAgentTool) { + if (config.exposeAsAgentTool.iUnderstandRunsAsServicePrincipal !== true) { + throw new Error( + "lakebase.exposeAsAgentTool requires iUnderstandRunsAsServicePrincipal: true — this acknowledges that SQL statements authored by the LLM run with the application's service-principal credentials regardless of which end user initiated the request.", + ); + } + this.tools = { query: this.buildQueryTool(config.exposeAsAgentTool) }; + logger.warn( + "Lakebase agent tool is enabled (readOnly=%s). Every agent with access to this plugin can execute SQL against the Lakebase database as the service principal.", + config.exposeAsAgentTool.readOnly !== false, + ); + } + } + + private buildQueryTool( + opt: NonNullable, + ) { + const readOnly = opt.readOnly !== false; + return defineTool({ + description: readOnly + ? "Execute a read-only SQL query against the Lakebase PostgreSQL database. Only SELECT, WITH, SHOW, EXPLAIN, and DESCRIBE statements are accepted. Use $1, $2, etc. as placeholders and pass values separately. Runs as the application's service principal." + : "Execute a parameterized SQL statement against the Lakebase PostgreSQL database. Use $1, $2, etc. as placeholders and pass values separately. Runs as the application's service principal. This tool can modify data; every invocation requires explicit human approval.", + schema: z.object({ + text: z + .string() + .describe( + "SQL statement with $1, $2, ... placeholders for parameters", + ), + values: z + .array(z.unknown()) + .optional() + .describe("Parameter values corresponding to placeholders"), + }), + annotations: { + readOnly, + destructive: !readOnly, + idempotent: false, + }, + handler: async (args) => { + if (readOnly) { + assertReadOnlySql(args.text); + return this.runReadOnlyStatement(args.text, args.values); + } + const result = await this.query(args.text, args.values); + return result.rows; + }, + }); + } + + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, this.tools, opts); + } + exports() { return { // biome-ignore lint/style/noNonNullAssertion: pool is guaranteed non-null after setup(), which AppKit always awaits before exposing the plugin API diff --git a/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts b/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts new file mode 100644 index 00000000..8e59fb32 --- /dev/null +++ b/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts @@ -0,0 +1,238 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; + +/** + * Tests the agent-tool surface of the Lakebase plugin. + * + * The plugin defaults to **not** exposing an agent tool at all. Enabling the + * tool is an explicit opt-in (`exposeAsAgentTool` with an acknowledgement + * flag) because every invocation runs with the application's service- + * principal credentials regardless of which end user initiated the request. + */ + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise) => + fn(), + ), + generateKey: vi.fn(() => "test-key"), + })), + }, +})); + +// Client calls recorded by the read-only-statement test. The `connect()` +// mock returns a fresh client whose `query` pushes to this array so tests +// can assert the exact sequence of statements emitted on the dedicated +// connection. +const clientQueries: Array<{ text: string; values?: unknown[] }> = []; +const clientReleases: number[] = []; + +vi.mock("../../../connectors/lakebase", () => ({ + createLakebasePool: vi.fn(() => ({ + query: vi.fn(), + connect: vi.fn(async () => { + let releaseCalls = 0; + return { + query: vi.fn(async (text: string, values?: unknown[]) => { + clientQueries.push({ text, values }); + return { rows: [{ n: 1 }] }; + }), + release: vi.fn(() => { + releaseCalls += 1; + clientReleases.push(releaseCalls); + }), + }; + }), + end: vi.fn(), + })), + getLakebaseOrmConfig: vi.fn(() => ({})), + getLakebasePgConfig: vi.fn(() => ({})), + getUsernameWithApiLookup: vi.fn(async () => "test-user"), +})); + +import { LakebasePlugin } from "../lakebase"; + +function makePlugin( + config: ConstructorParameters[0], +): LakebasePlugin { + return new LakebasePlugin(config); +} + +describe("LakebasePlugin — agent tool opt-in", () => { + test("does not register an agent tool by default", () => { + const plugin = makePlugin({}); + expect(plugin.getAgentTools()).toEqual([]); + }); + + test("does not register a tool when `pool` is set but `exposeAsAgentTool` is absent", () => { + const plugin = makePlugin({ pool: {} }); + expect(plugin.getAgentTools()).toEqual([]); + }); + + test("throws when exposeAsAgentTool is set without the acknowledgement flag", () => { + expect(() => + makePlugin({ + exposeAsAgentTool: + // biome-ignore lint/suspicious/noExplicitAny: intentionally bypass the required flag for the negative case + {} as any, + }), + ).toThrow(/iUnderstandRunsAsServicePrincipal/); + }); + + test("registers a read-only tool when opted in with defaults", () => { + const plugin = makePlugin({ + exposeAsAgentTool: { iUnderstandRunsAsServicePrincipal: true }, + }); + const defs = plugin.getAgentTools(); + expect(defs).toHaveLength(1); + expect(defs[0].name).toBe("query"); + expect(defs[0].annotations).toEqual({ + readOnly: true, + destructive: false, + idempotent: false, + }); + }); + + test("registers a destructive tool when readOnly: false is explicit", () => { + const plugin = makePlugin({ + exposeAsAgentTool: { + iUnderstandRunsAsServicePrincipal: true, + readOnly: false, + }, + }); + const defs = plugin.getAgentTools(); + expect(defs[0].annotations).toEqual({ + readOnly: false, + destructive: true, + idempotent: false, + }); + }); +}); + +describe("LakebasePlugin — readOnly enforcement", () => { + let plugin: LakebasePlugin; + + beforeEach(async () => { + clientQueries.length = 0; + clientReleases.length = 0; + plugin = makePlugin({ + exposeAsAgentTool: { iUnderstandRunsAsServicePrincipal: true }, + }); + await plugin.setup(); + }); + + test("rejects DROP before acquiring a client", async () => { + await expect( + plugin.executeAgentTool("query", { text: "DROP TABLE users" }), + ).rejects.toThrow(/read-only policy violation/i); + expect(clientQueries).toHaveLength(0); + }); + + test("rejects UPDATE, DELETE, INSERT", async () => { + for (const text of [ + "UPDATE users SET email='x'", + "DELETE FROM orders", + "INSERT INTO x VALUES (1)", + ]) { + await expect(plugin.executeAgentTool("query", { text })).rejects.toThrow( + /read-only policy violation/i, + ); + } + expect(clientQueries).toHaveLength(0); + }); + + test("runs SELECT inside BEGIN READ ONLY / ROLLBACK on a dedicated client", async () => { + const rows = await plugin.executeAgentTool("query", { + text: "SELECT * FROM users", + }); + expect(rows).toEqual([{ n: 1 }]); + expect(clientQueries.map((c) => c.text)).toEqual([ + "BEGIN READ ONLY", + "SELECT * FROM users", + "ROLLBACK", + ]); + // Client must be released exactly once, regardless of outcome. + expect(clientReleases).toHaveLength(1); + }); + + test("forwards parameter values to the user statement only (the regression fix)", async () => { + // Prior to the fix this would have failed with "cannot insert multiple + // commands into a prepared statement" because pg's Extended Query + // protocol rejects multi-statement batches when values are supplied. + await plugin.executeAgentTool("query", { + text: "SELECT * FROM users WHERE id = $1", + values: [42], + }); + expect(clientQueries).toEqual([ + { text: "BEGIN READ ONLY", values: undefined }, + { text: "SELECT * FROM users WHERE id = $1", values: [42] }, + { text: "ROLLBACK", values: undefined }, + ]); + }); + + test("releases the client even when the user statement throws", async () => { + // Poison the client so the middle query throws (simulates a Postgres + // error like "cannot execute UPDATE in a read-only transaction"). + const { createLakebasePool } = await import("../../../connectors/lakebase"); + const connect = vi.fn(async () => ({ + query: vi + .fn() + .mockResolvedValueOnce({ rows: [] }) + .mockRejectedValueOnce(new Error("read-only violation")) + .mockResolvedValueOnce({ rows: [] }), + release: vi.fn(() => { + clientReleases.push(clientReleases.length + 1); + }), + })); + // biome-ignore lint/suspicious/noExplicitAny: test override + ( + createLakebasePool as unknown as { mockReturnValueOnce: any } + ).mockReturnValueOnce({ query: vi.fn(), connect, end: vi.fn() }); + + clientQueries.length = 0; + clientReleases.length = 0; + const leakyPlugin = makePlugin({ + exposeAsAgentTool: { iUnderstandRunsAsServicePrincipal: true }, + }); + await leakyPlugin.setup(); + + await expect( + leakyPlugin.executeAgentTool("query", { + text: "SELECT * FROM users", + }), + ).rejects.toThrow(/read-only violation/); + expect(clientReleases).toHaveLength(1); + }); +}); + +describe("LakebasePlugin — destructive mode", () => { + test("does NOT wrap in read-only transaction when readOnly: false", async () => { + const queryMock = vi.fn((_text: string, _values?: unknown[]) => + Promise.resolve({ rows: [] }), + ); + const plugin = makePlugin({ + exposeAsAgentTool: { + iUnderstandRunsAsServicePrincipal: true, + readOnly: false, + }, + }); + await plugin.setup(); + vi.spyOn(plugin, "query").mockImplementation(async (text, values) => { + queryMock(text, values); + return { rows: [] } as never; + }); + + await plugin.executeAgentTool("query", { + text: "UPDATE t SET x=1 WHERE id=$1", + values: [42], + }); + + expect(queryMock).toHaveBeenCalledWith( + "UPDATE t SET x=1 WHERE id=$1", + [42], + ); + }); +}); diff --git a/packages/appkit/src/plugins/lakebase/types.ts b/packages/appkit/src/plugins/lakebase/types.ts index ac6997c6..6703e425 100644 --- a/packages/appkit/src/plugins/lakebase/types.ts +++ b/packages/appkit/src/plugins/lakebase/types.ts @@ -1,6 +1,42 @@ import type { BasePluginConfig } from "shared"; import type { LakebasePoolConfig } from "../../connectors/lakebase"; +/** + * Opt-in configuration for exposing Lakebase as an agent-callable SQL tool. + * + * This tool executes LLM-authored SQL against the Lakebase pool. The pool is + * **always bound to the application's service-principal credentials**, so any + * agent that can call this tool effectively has full SP access to the database + * regardless of which end user initiated the request. Exposing it is a + * deliberate decision the developer must make explicitly — hence the required + * acknowledgement flag. + * + * When `readOnly: true` (default when opted in), every statement is: + * 1. Classified by {@link @databricks/appkit's sql-policy classifier}; anything + * that isn't a pure `SELECT`/`WITH`/`SHOW`/`EXPLAIN`/`DESCRIBE` is rejected. + * 2. Executed inside a `BEGIN READ ONLY … ROLLBACK` transaction so the + * PostgreSQL server rejects writes that slip past the classifier (e.g., a + * `SELECT` over a function with side effects). + * + * When `readOnly: false`, the tool is annotated `destructive: true` and the + * agents plugin will require human approval for every invocation (see + * `AgentsPluginConfig.approval`). + */ +export interface LakebaseExposeAsAgentTool { + /** + * Required acknowledgement that tool invocations run as the service principal + * and share that privilege across end users. Must be set to `true` to opt in. + */ + iUnderstandRunsAsServicePrincipal: true; + /** + * Enforce read-only execution. Defaults to `true`. Set to `false` to allow + * destructive statements — highly discouraged outside of tightly controlled + * single-user deployments. Combined with the `destructive: true` annotation, + * the agents plugin will require explicit human approval for each call. + */ + readOnly?: boolean; +} + /** * Configuration for the Lakebase plugin. * @@ -17,4 +53,11 @@ export interface ILakebaseConfig extends BasePluginConfig { * Common overrides: `max` (pool size), `connectionTimeoutMillis`, `idleTimeoutMillis`. */ pool?: Partial; + /** + * Opt-in to expose Lakebase as an agent-callable SQL tool. By default no + * agent tool is registered — the Lakebase plugin only exposes its API to + * application code. See {@link LakebaseExposeAsAgentTool} for the privilege + * implications of enabling this. + */ + exposeAsAgentTool?: LakebaseExposeAsAgentTool; } From 73a696a4a37d35ab3243069938a94e3d03eeb89b Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Thu, 23 Apr 2026 17:29:56 +0200 Subject: [PATCH 11/23] docs(appkit): explain hand-rolled AppKitMcpClient vs official MCP SDK Add a file-level rationale (policy/auth, narrow scope, zero extra deps) and point the class JSDoc at it to avoid duplicating the same story in two places. Signed-off-by: MarioCadenas --- .../src/plugins/agents/tools/mcp-client.ts | 42 ++++++++++++------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/packages/appkit/src/plugins/agents/tools/mcp-client.ts b/packages/appkit/src/plugins/agents/tools/mcp-client.ts index fcb58d59..49db7882 100644 --- a/packages/appkit/src/plugins/agents/tools/mcp-client.ts +++ b/packages/appkit/src/plugins/agents/tools/mcp-client.ts @@ -1,3 +1,27 @@ +/** + * Custom MCP over HTTP (Streamable) — not `@modelcontextprotocol/sdk` + * + * This module implements a tiny JSON-RPC 2.0 client on `fetch` for the subset + * of MCP we need: `initialize`, `notifications/initialized`, `tools/list`, + * `tools/call` over a single JSON request/response. We do not use the official + * SDK because: + * + * - **Policy and auth are the product** — every outbound URL is checked with + * {@link McpHostPolicy} (allowlist, DNS, private/blocked IP ranges) before + * the first byte is sent, and workspace tokens are only forwarded when + * `forwardWorkspaceAuth` is true for that destination. A generic transport + * from the SDK would still need the same hooks; re-wrapping it would be + * about as much code, with a larger third-party surface to audit. + * - **Narrow scope** — we only target Databricks-hosted MCP over Streamable + * HTTP, not stdio, full SSE sessions, or the rest of the protocol. A + * hand-rolled path keeps the call graph obvious in code review. + * - **Zero extra runtime dependency** for this path, consistent with other + * small, security-sensitive AppKit pieces. + * + * Revisit if we add more transports, or if the SDK ships a first-class way to + * inject our host policy and per-URL auth without fighting the default + * transport. + */ import type { AgentToolDefinition } from "shared"; import { createLogger } from "../../../logging/logger"; import type { McpEndpointConfig } from "./hosted-tools"; @@ -60,22 +84,8 @@ interface McpServerConnection { * are rejected before the first byte is sent, and workspace credentials are * only forwarded to the same-origin workspace. See `mcp-host-policy.ts`. * - * ### Why not `@modelcontextprotocol/sdk`? - * - * Our security model — zero-trust host policy + per-destination auth scoping - * via `forwardWorkspaceAuth` — is the load-bearing anchor of this module, not - * the wire protocol. Wrapping the SDK to enforce the same invariants (custom - * transport + auth plumbing + policy checks before every request) ends up - * the same size as writing the narrow JSON-RPC subset directly, but with an - * opaque dependency that's harder to audit. Our entire target surface is - * Databricks-hosted MCP over Streamable HTTP; we never need stdio, SSE-only - * transport, or the rest of the protocol the SDK covers. - * - * Zero runtime deps matches the rest of AppKit's philosophy - * (`DatabricksAdapter`, `sql-policy`). Revisit the call if MCP adoption - * grows beyond Databricks-hosted servers with the transports above, or if - * the SDK gains transport-level auth hooks that cleanly match our - * per-destination auth-forwarding model. + * Rationale for hand-rolling JSON-RPC instead of `@modelcontextprotocol/sdk`: + * see the file-level comment at the top of this module. */ export class AppKitMcpClient { private connections = new Map(); From 1055f0fe1f4702cbb9adffe940141430c9920432 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Thu, 23 Apr 2026 17:35:10 +0200 Subject: [PATCH 12/23] refactor(appkit): merge FilesPlugin ctor volume loops Single pass over volumes: connectors, toolkit tools, and policy warnings. Signed-off-by: MarioCadenas --- packages/appkit/src/plugins/files/plugin.ts | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/packages/appkit/src/plugins/files/plugin.ts b/packages/appkit/src/plugins/files/plugin.ts index fc768c5a..1a80e868 100644 --- a/packages/appkit/src/plugins/files/plugin.ts +++ b/packages/appkit/src/plugins/files/plugin.ts @@ -243,15 +243,11 @@ export class FilesPlugin extends Plugin implements ToolProvider { telemetry: config.telemetry, customContentTypes: mergedConfig.customContentTypes, }); - } - for (const volumeKey of this.volumeKeys) { - Object.assign(this.tools, this._defineVolumeTools(volumeKey)); - } + Object.assign(this.tools, this._defineVolumeTools(key)); - // Warn at startup for volumes without an explicit policy - for (const key of this.volumeKeys) { - if (!volumes[key].policy) { + // Warn at startup for volumes without an explicit policy + if (!volumeCfg.policy) { logger.warn( 'Volume "%s" has no explicit policy — defaulting to publicRead(). ' + "Set a policy in files({ volumes: { %s: { policy: ... } } }) to silence this warning.", From 9183ce66a894dba46387d06e0b7d6cff9de7ac5e Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Tue, 21 Apr 2026 19:48:00 +0200 Subject: [PATCH 13/23] =?UTF-8?q?feat(appkit):=20plugin=20infrastructure?= =?UTF-8?q?=20=E2=80=94=20attachContext=20lifecycle=20+=20PluginContext=20?= =?UTF-8?q?mediator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Third layer: the substrate every downstream PR relies on. No user- facing API changes here; the surface for this PR is the mediator pattern, lifecycle semantics, and factory stamping. `Plugin` constructors become pure — no `CacheManager.getInstanceSync()`, no `TelemetryManager.getProvider()`, no `PluginContext` wiring inside `constructor()`. That work moves to a new lifecycle method: ```ts interface BasePlugin { attachContext?(deps: { context?: unknown; telemetryConfig?: TelemetryOptions; }): void; } ``` `createApp` calls `attachContext()` on every plugin after all constructors have run, before `setup()`. This lets factories return `PluginData` tuples at module scope without pulling core services into the import graph — a prerequisite for later PRs that construct agent definitions before `createApp`. `packages/appkit/src/core/plugin-context.ts` — new class that mediates all inter-plugin communication: - **Route buffering**: `addRoute()` / `addMiddleware()` buffer until the server plugin calls `registerAsRouteTarget()`, then flush via `addExtension()`. Eliminates plugin-ordering fragility. - **ToolProvider registry**: `registerToolProvider(name, plugin)` + live `getToolProviders()`. Typed discovery of tool-exposing plugins. - **User-scoped tool execution**: `executeTool(req, pluginName, localName, args, signal?)` resolves the provider, wraps in `asUser(req)` for OBO, opens a telemetry span, applies a 30s timeout, dispatches, returns. - **Lifecycle hooks**: `onLifecycle('setup:complete' | 'server:ready' | 'shutdown', cb)` + `emitLifecycle(event)`. Callback errors don't block siblings. `packages/appkit/src/plugin/to-plugin.ts` — the factory now attaches a read-only `pluginName` property to the returned function. Later PRs' `fromPlugin(factory)` reads it to identify which plugin a factory refers to without needing to construct an instance. `NamedPluginFactory` type exported for consumers who want to type-constrain factories. `ServerPlugin.setup()` no longer calls `extendRoutes()` synchronously. It subscribes to the `setup:complete` lifecycle event via `PluginContext` and starts the HTTP server there. This ensures that any deferred-phase plugin (agents plugin in a later PR) has had a chance to register routes via `PluginContext.addRoute()` before the server binds. Removes the `plugins` field from `ServerConfig` (routes are now discovered via the context, not a config snapshot). - 25 new PluginContext tests (route buffering, tool provider registry, executeTool paths, lifecycle hooks, plugin metadata) - Updated AppKit lifecycle tests to inject `context` instead of `plugins` - Full appkit vitest suite: 1237 tests passing - Typecheck clean across all 8 workspace projects Signed-off-by: MarioCadenas --- packages/appkit/src/core/appkit.ts | 27 +- packages/appkit/src/core/plugin-context.ts | 287 ++++++++++++++++ .../appkit/src/core/tests/databricks.test.ts | 15 +- .../src/core/tests/plugin-context.test.ts | 325 ++++++++++++++++++ packages/appkit/src/plugin/index.ts | 2 +- packages/appkit/src/plugin/plugin.ts | 56 ++- packages/appkit/src/plugin/to-plugin.ts | 32 +- packages/appkit/src/plugins/server/index.ts | 25 +- .../src/plugins/server/tests/server.test.ts | 38 +- packages/shared/src/plugin.ts | 9 + 10 files changed, 781 insertions(+), 35 deletions(-) create mode 100644 packages/appkit/src/core/plugin-context.ts create mode 100644 packages/appkit/src/core/tests/plugin-context.test.ts diff --git a/packages/appkit/src/core/appkit.ts b/packages/appkit/src/core/appkit.ts index 607a1552..5d1dd455 100644 --- a/packages/appkit/src/core/appkit.ts +++ b/packages/appkit/src/core/appkit.ts @@ -14,16 +14,20 @@ import { createLogger } from "../logging/logger"; import { ResourceRegistry, ResourceType } from "../registry"; import type { TelemetryConfig } from "../telemetry"; import { TelemetryManager } from "../telemetry"; +import { isToolProvider, PluginContext } from "./plugin-context"; const logger = createLogger("appkit"); export class AppKit { #pluginInstances: Record = {}; #setupPromises: Promise[] = []; + #context: PluginContext; private constructor(config: { plugins: TPlugins }) { const { plugins, ...globalConfig } = config; + this.#context = new PluginContext(); + const pluginEntries = Object.entries(plugins); const corePlugins = pluginEntries.filter(([_, p]) => { @@ -38,20 +42,24 @@ export class AppKit { for (const [name, pluginData] of corePlugins) { if (pluginData) { - this.createAndRegisterPlugin(globalConfig, name, pluginData); + this.createAndRegisterPlugin(globalConfig, name, pluginData, { + context: this.#context, + }); } } for (const [name, pluginData] of normalPlugins) { if (pluginData) { - this.createAndRegisterPlugin(globalConfig, name, pluginData); + this.createAndRegisterPlugin(globalConfig, name, pluginData, { + context: this.#context, + }); } } for (const [name, pluginData] of deferredPlugins) { if (pluginData) { this.createAndRegisterPlugin(globalConfig, name, pluginData, { - plugins: this.#pluginInstances, + context: this.#context, }); } } @@ -73,8 +81,20 @@ export class AppKit { }; const pluginInstance = new Plugin(baseConfig); + if (typeof pluginInstance.attachContext === "function") { + pluginInstance.attachContext({ + context: this.#context, + telemetryConfig: baseConfig.telemetry, + }); + } + this.#pluginInstances[name] = pluginInstance; + this.#context.registerPlugin(name, pluginInstance); + if (isToolProvider(pluginInstance)) { + this.#context.registerToolProvider(name, pluginInstance); + } + this.#setupPromises.push(pluginInstance.setup()); const self = this; @@ -203,6 +223,7 @@ export class AppKit { const instance = new AppKit(mergedConfig); await Promise.all(instance.#setupPromises); + await instance.#context.emitLifecycle("setup:complete"); const handle = instance as unknown as PluginMap; diff --git a/packages/appkit/src/core/plugin-context.ts b/packages/appkit/src/core/plugin-context.ts new file mode 100644 index 00000000..c2801585 --- /dev/null +++ b/packages/appkit/src/core/plugin-context.ts @@ -0,0 +1,287 @@ +import type express from "express"; +import type { BasePlugin, ToolProvider } from "shared"; +import { createLogger } from "../logging/logger"; +import { TelemetryManager } from "../telemetry"; + +const logger = createLogger("plugin-context"); + +interface BufferedRoute { + method: string; + path: string; + handlers: express.RequestHandler[]; +} + +interface RouteTarget { + addExtension(fn: (app: express.Application) => void): void; +} + +interface ToolProviderEntry { + plugin: BasePlugin & ToolProvider; + name: string; +} + +type LifecycleEvent = "setup:complete" | "server:ready" | "shutdown"; + +/** + * Mediator for inter-plugin communication. + * + * Created by AppKit core and passed to every plugin. Plugins request + * capabilities from the context instead of holding direct references + * to sibling plugin instances. + * + * Capabilities: + * - Route mounting with buffering (order-independent) + * - Typed ToolProvider registry (live, not snapshot-based) + * - User-scoped tool execution with automatic telemetry + * - Lifecycle hooks for plugin coordination + */ +export class PluginContext { + private routeBuffer: BufferedRoute[] = []; + private routeTarget: RouteTarget | null = null; + private toolProviders = new Map(); + private plugins = new Map(); + private lifecycleHooks = new Map< + LifecycleEvent, + Set<() => void | Promise> + >(); + private telemetry = TelemetryManager.getProvider("plugin-context"); + + /** + * Register a route on the root Express application. + * + * If a route target (server plugin) has registered, the route is applied + * immediately. Otherwise it is buffered and flushed when a route target + * becomes available. + */ + addRoute( + method: string, + path: string, + ...handlers: express.RequestHandler[] + ): void { + if (this.routeTarget) { + this.applyRoute({ method, path, handlers }); + } else { + this.routeBuffer.push({ method, path, handlers }); + } + } + + /** + * Register middleware on the root Express application. + * + * Same buffering semantics as `addRoute`. + */ + addMiddleware(path: string, ...handlers: express.RequestHandler[]): void { + if (this.routeTarget) { + this.applyMiddleware(path, handlers); + } else { + this.routeBuffer.push({ method: "use", path, handlers }); + } + } + + /** + * Called by the server plugin to opt in as the route target. + * Flushes all buffered routes via the server's `addExtension`. + */ + registerAsRouteTarget(target: RouteTarget): void { + this.routeTarget = target; + + for (const route of this.routeBuffer) { + if (route.method === "use") { + this.applyMiddleware(route.path, route.handlers); + } else { + this.applyRoute(route); + } + } + this.routeBuffer = []; + } + + /** + * Register a plugin that implements the ToolProvider interface. + * Called by AppKit core after constructing each plugin. + */ + registerToolProvider(name: string, plugin: BasePlugin & ToolProvider): void { + this.toolProviders.set(name, { plugin, name }); + } + + /** + * Register a plugin instance. + * Called by AppKit core after constructing each plugin. + */ + registerPlugin(name: string, instance: BasePlugin): void { + this.plugins.set(name, instance); + } + + /** + * Returns all registered plugin instances keyed by name. + * Used by the server plugin for route injection, client config, + * and shutdown coordination. + */ + getPlugins(): Map { + return this.plugins; + } + + /** + * Returns all registered ToolProvider plugins. + * Always returns the current set — not a frozen snapshot. + */ + getToolProviders(): Array<{ name: string; provider: ToolProvider }> { + return Array.from(this.toolProviders.values()).map((entry) => ({ + name: entry.name, + provider: entry.plugin, + })); + } + + /** + * Execute a tool on a ToolProvider plugin with automatic user scoping + * and telemetry. + * + * The context: + * 1. Resolves the plugin by name + * 2. Calls `asUser(req)` for user-scoped execution + * 3. Wraps the call in a telemetry span with a 30s timeout + */ + async executeTool( + req: express.Request, + pluginName: string, + toolName: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + const entry = this.toolProviders.get(pluginName); + if (!entry) { + throw new Error( + `PluginContext: unknown plugin "${pluginName}". Available: ${Array.from(this.toolProviders.keys()).join(", ")}`, + ); + } + + const tracer = this.telemetry.getTracer(); + const operationName = `executeTool:${pluginName}.${toolName}`; + + return tracer.startActiveSpan(operationName, async (span) => { + const timeout = 30_000; + const timeoutSignal = AbortSignal.timeout(timeout); + const combinedSignal = signal + ? AbortSignal.any([signal, timeoutSignal]) + : timeoutSignal; + + try { + const userPlugin = (entry.plugin as any).asUser(req); + const result = await (userPlugin as ToolProvider).executeAgentTool( + toolName, + args, + combinedSignal, + ); + span.setStatus({ code: 0 }); + return result; + } catch (error) { + span.setStatus({ + code: 2, + message: + error instanceof Error ? error.message : "Tool execution failed", + }); + span.recordException( + error instanceof Error ? error : new Error(String(error)), + ); + throw error; + } finally { + span.end(); + } + }); + } + + /** + * Register a lifecycle hook callback. + */ + onLifecycle(event: LifecycleEvent, fn: () => void | Promise): void { + let hooks = this.lifecycleHooks.get(event); + if (!hooks) { + hooks = new Set(); + this.lifecycleHooks.set(event, hooks); + } + hooks.add(fn); + } + + /** + * Emit a lifecycle event, calling all registered callbacks. + * Errors in individual callbacks are logged but do not prevent + * other callbacks from running. + * + * @internal Called by AppKit core only. + */ + async emitLifecycle(event: LifecycleEvent): Promise { + const hooks = this.lifecycleHooks.get(event); + if (!hooks) return; + + if ( + event === "setup:complete" && + this.routeBuffer.length > 0 && + !this.routeTarget + ) { + logger.warn( + "%d buffered routes were never applied — no server plugin registered as route target", + this.routeBuffer.length, + ); + } + + for (const fn of hooks) { + try { + await fn(); + } catch (error) { + logger.error("Lifecycle hook '%s' failed: %O", event, error); + } + } + } + + /** + * Returns all registered plugin names. + */ + getPluginNames(): string[] { + return Array.from(this.plugins.keys()); + } + + /** + * Check if a plugin with the given name is registered. + */ + hasPlugin(name: string): boolean { + return this.plugins.has(name); + } + + private applyRoute(route: BufferedRoute): void { + if (!this.routeTarget) return; + this.routeTarget.addExtension((app) => { + const method = route.method.toLowerCase() as keyof express.Application; + if (typeof app[method] === "function") { + (app[method] as (...a: unknown[]) => void)( + route.path, + ...route.handlers, + ); + } + }); + } + + private applyMiddleware( + path: string, + handlers: express.RequestHandler[], + ): void { + if (!this.routeTarget) return; + this.routeTarget.addExtension((app) => { + app.use(path, ...handlers); + }); + } +} + +/** + * Type guard: checks whether a plugin implements the ToolProvider interface. + */ +export function isToolProvider( + plugin: unknown, +): plugin is BasePlugin & ToolProvider { + return ( + typeof plugin === "object" && + plugin !== null && + "getAgentTools" in plugin && + typeof (plugin as ToolProvider).getAgentTools === "function" && + "executeAgentTool" in plugin && + typeof (plugin as ToolProvider).executeAgentTool === "function" + ); +} diff --git a/packages/appkit/src/core/tests/databricks.test.ts b/packages/appkit/src/core/tests/databricks.test.ts index c05345a6..9d3fe5f8 100644 --- a/packages/appkit/src/core/tests/databricks.test.ts +++ b/packages/appkit/src/core/tests/databricks.test.ts @@ -109,11 +109,11 @@ class DeferredTestPlugin implements BasePlugin { name = "deferredTest"; setupCalled = false; injectedConfig: any; - injectedPlugins: any; + injectedContext: any; constructor(config: any) { this.injectedConfig = config; - this.injectedPlugins = config.plugins; + this.injectedContext = config.context; } async setup() { @@ -130,7 +130,7 @@ class DeferredTestPlugin implements BasePlugin { return { setupCalled: this.setupCalled, injectedConfig: this.injectedConfig, - injectedPlugins: this.injectedPlugins, + injectedContext: this.injectedContext, }; } } @@ -276,7 +276,7 @@ describe("AppKit", () => { expect(setupOrder).toEqual(["core", "normal", "deferred"]); }); - test("should provide plugin instances to deferred plugins", async () => { + test("should provide PluginContext to deferred plugins", async () => { const pluginData = [ { plugin: CoreTestPlugin, config: {}, name: "coreTest" }, { plugin: DeferredTestPlugin, config: {}, name: "deferredTest" }, @@ -284,10 +284,9 @@ describe("AppKit", () => { const instance = (await createApp({ plugins: pluginData })) as any; - // Deferred plugins receive plugin instances (not SDKs) for internal use - expect(instance.deferredTest.injectedPlugins).toBeDefined(); - expect(instance.deferredTest.injectedPlugins.coreTest).toBeInstanceOf( - CoreTestPlugin, + expect(instance.deferredTest.injectedContext).toBeDefined(); + expect(instance.deferredTest.injectedContext.hasPlugin("coreTest")).toBe( + true, ); }); diff --git a/packages/appkit/src/core/tests/plugin-context.test.ts b/packages/appkit/src/core/tests/plugin-context.test.ts new file mode 100644 index 00000000..276c5502 --- /dev/null +++ b/packages/appkit/src/core/tests/plugin-context.test.ts @@ -0,0 +1,325 @@ +import type { AgentToolDefinition } from "shared"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { isToolProvider, PluginContext } from "../plugin-context"; + +vi.mock("../../telemetry", () => ({ + TelemetryManager: { + getProvider: () => ({ + getTracer: () => ({ + startActiveSpan: (_name: string, fn: (span: any) => any) => { + const span = { + setStatus: vi.fn(), + recordException: vi.fn(), + end: vi.fn(), + }; + return fn(span); + }, + }), + }), + }, +})); + +vi.mock("../../logging/logger", () => ({ + createLogger: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), +})); + +function createMockToolProvider(tools: AgentToolDefinition[] = []) { + const mock = { + name: "mock-plugin", + setup: vi.fn().mockResolvedValue(undefined), + injectRoutes: vi.fn(), + getEndpoints: vi.fn().mockReturnValue({}), + getAgentTools: vi.fn().mockReturnValue(tools), + executeAgentTool: vi.fn().mockResolvedValue("tool-result"), + asUser: vi.fn().mockReturnThis(), + }; + return mock as any; +} + +describe("PluginContext", () => { + let ctx: PluginContext; + + beforeEach(() => { + ctx = new PluginContext(); + }); + + describe("route buffering", () => { + test("addRoute buffers when no route target exists", () => { + const handler = vi.fn(); + ctx.addRoute("post", "/invocations", handler); + + expect(ctx.getPluginNames()).toEqual([]); + }); + + test("flushRoutes applies buffered routes via addExtension", () => { + const handler = vi.fn(); + ctx.addRoute("post", "/invocations", handler); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { post: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.post).toHaveBeenCalledWith("/invocations", handler); + }); + + test("addRoute called after registerAsRouteTarget applies immediately", () => { + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + const handler = vi.fn(); + ctx.addRoute("get", "/health", handler); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { get: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.get).toHaveBeenCalledWith("/health", handler); + }); + + test("addRoute supports middleware chains", () => { + const auth = vi.fn(); + const handler = vi.fn(); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + ctx.addRoute("post", "/api", auth, handler); + + const extensionFn = addExtension.mock.calls[0][0]; + const mockApp = { post: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.post).toHaveBeenCalledWith("/api", auth, handler); + }); + + test("addMiddleware buffers and applies via use()", () => { + const handler = vi.fn(); + ctx.addMiddleware("/api", handler); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { use: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.use).toHaveBeenCalledWith("/api", handler); + }); + + test("multiple buffered routes are all applied on registration", () => { + const h1 = vi.fn(); + const h2 = vi.fn(); + ctx.addRoute("post", "/a", h1); + ctx.addRoute("get", "/b", h2); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(2); + }); + }); + + describe("ToolProvider registry", () => { + test("registerToolProvider makes provider visible via getToolProviders", () => { + const provider = createMockToolProvider([ + { + name: "query", + description: "Run query", + parameters: { type: "object" }, + }, + ]); + + ctx.registerToolProvider("analytics", provider); + + const providers = ctx.getToolProviders(); + expect(providers).toHaveLength(1); + expect(providers[0].name).toBe("analytics"); + expect(providers[0].provider.getAgentTools()).toHaveLength(1); + }); + + test("getToolProviders returns all registered providers", () => { + ctx.registerToolProvider("analytics", createMockToolProvider()); + ctx.registerToolProvider("files", createMockToolProvider()); + ctx.registerToolProvider("genie", createMockToolProvider()); + + expect(ctx.getToolProviders()).toHaveLength(3); + }); + + test("getToolProviders returns current set, not snapshot", () => { + const before = ctx.getToolProviders(); + expect(before).toHaveLength(0); + + ctx.registerToolProvider("analytics", createMockToolProvider()); + + const after = ctx.getToolProviders(); + expect(after).toHaveLength(1); + }); + }); + + describe("executeTool", () => { + test("calls asUser(req).executeAgentTool on the correct plugin", async () => { + const provider = createMockToolProvider(); + ctx.registerToolProvider("analytics", provider); + + const mockReq = { headers: {} } as any; + await ctx.executeTool(mockReq, "analytics", "query", { sql: "SELECT 1" }); + + expect(provider.asUser).toHaveBeenCalledWith(mockReq); + expect(provider.executeAgentTool).toHaveBeenCalledWith( + "query", + { sql: "SELECT 1" }, + expect.any(Object), + ); + }); + + test("throws for unknown plugin name", async () => { + const mockReq = { headers: {} } as any; + + await expect( + ctx.executeTool(mockReq, "nonexistent", "query", {}), + ).rejects.toThrow('unknown plugin "nonexistent"'); + }); + + test("propagates tool execution errors", async () => { + const provider = createMockToolProvider(); + (provider.executeAgentTool as any).mockRejectedValue( + new Error("Query failed"), + ); + ctx.registerToolProvider("analytics", provider); + + const mockReq = { headers: {} } as any; + + await expect( + ctx.executeTool(mockReq, "analytics", "query", {}), + ).rejects.toThrow("Query failed"); + }); + + test("passes abort signal to executeAgentTool", async () => { + const provider = createMockToolProvider(); + ctx.registerToolProvider("analytics", provider); + + const controller = new AbortController(); + const mockReq = { headers: {} } as any; + + await ctx.executeTool( + mockReq, + "analytics", + "query", + {}, + controller.signal, + ); + + const callArgs = (provider.executeAgentTool as any).mock.calls[0]; + expect(callArgs[2]).toBeDefined(); + }); + }); + + describe("lifecycle hooks", () => { + test("onLifecycle registers callback, emitLifecycle invokes it", async () => { + const fn = vi.fn(); + ctx.onLifecycle("setup:complete", fn); + + await ctx.emitLifecycle("setup:complete"); + + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("multiple callbacks for the same event all fire", async () => { + const fn1 = vi.fn(); + const fn2 = vi.fn(); + ctx.onLifecycle("setup:complete", fn1); + ctx.onLifecycle("setup:complete", fn2); + + await ctx.emitLifecycle("setup:complete"); + + expect(fn1).toHaveBeenCalledTimes(1); + expect(fn2).toHaveBeenCalledTimes(1); + }); + + test("callback error does not prevent other callbacks from running", async () => { + const fn1 = vi.fn().mockRejectedValue(new Error("fail")); + const fn2 = vi.fn(); + ctx.onLifecycle("shutdown", fn1); + ctx.onLifecycle("shutdown", fn2); + + await ctx.emitLifecycle("shutdown"); + + expect(fn1).toHaveBeenCalled(); + expect(fn2).toHaveBeenCalled(); + }); + + test("emitLifecycle with no registered hooks does nothing", async () => { + await expect(ctx.emitLifecycle("server:ready")).resolves.toBeUndefined(); + }); + }); + + describe("plugin metadata", () => { + const stubPlugin = { name: "stub" } as any; + + test("getPluginNames returns all registered names", () => { + ctx.registerPlugin("analytics", stubPlugin); + ctx.registerPlugin("server", stubPlugin); + ctx.registerPlugin("agent", stubPlugin); + + const names = ctx.getPluginNames(); + expect(names).toContain("analytics"); + expect(names).toContain("server"); + expect(names).toContain("agent"); + expect(names).toHaveLength(3); + }); + + test("hasPlugin returns true for registered plugins", () => { + ctx.registerPlugin("analytics", stubPlugin); + + expect(ctx.hasPlugin("analytics")).toBe(true); + expect(ctx.hasPlugin("nonexistent")).toBe(false); + }); + + test("getPlugins returns all registered instances", () => { + const p1 = { name: "analytics" } as any; + const p2 = { name: "server" } as any; + ctx.registerPlugin("analytics", p1); + ctx.registerPlugin("server", p2); + + const plugins = ctx.getPlugins(); + expect(plugins.size).toBe(2); + expect(plugins.get("analytics")).toBe(p1); + expect(plugins.get("server")).toBe(p2); + }); + }); +}); + +describe("isToolProvider", () => { + test("returns true for objects with getAgentTools and executeAgentTool", () => { + const provider = createMockToolProvider(); + expect(isToolProvider(provider)).toBe(true); + }); + + test("returns false for null", () => { + expect(isToolProvider(null)).toBe(false); + }); + + test("returns false for objects missing executeAgentTool", () => { + expect(isToolProvider({ getAgentTools: vi.fn() })).toBe(false); + }); + + test("returns false for objects missing getAgentTools", () => { + expect(isToolProvider({ executeAgentTool: vi.fn() })).toBe(false); + }); + + test("returns false for non-objects", () => { + expect(isToolProvider("string")).toBe(false); + expect(isToolProvider(42)).toBe(false); + expect(isToolProvider(undefined)).toBe(false); + }); +}); diff --git a/packages/appkit/src/plugin/index.ts b/packages/appkit/src/plugin/index.ts index 93765219..46a4eb94 100644 --- a/packages/appkit/src/plugin/index.ts +++ b/packages/appkit/src/plugin/index.ts @@ -1,4 +1,4 @@ export type { ToPlugin } from "shared"; export type { ExecutionResult } from "./execution-result"; export { Plugin } from "./plugin"; -export { toPlugin } from "./to-plugin"; +export { type NamedPluginFactory, toPlugin } from "./to-plugin"; diff --git a/packages/appkit/src/plugin/plugin.ts b/packages/appkit/src/plugin/plugin.ts index 5173cb61..4c9a0e64 100644 --- a/packages/appkit/src/plugin/plugin.ts +++ b/packages/appkit/src/plugin/plugin.ts @@ -19,6 +19,7 @@ import { ServiceContext, type UserContext, } from "../context"; +import type { PluginContext } from "../core/plugin-context"; import { AppKitError, AuthenticationError } from "../errors"; import { createLogger } from "../logging/logger"; import { StreamManager } from "../stream"; @@ -163,11 +164,12 @@ export abstract class Plugin< > implements BasePlugin { protected isReady = false; - protected cache: CacheManager; + protected cache!: CacheManager; protected app: AppManager; protected devFileReader: DevFileReader; protected streamManager: StreamManager; - protected telemetry: ITelemetry; + protected telemetry!: ITelemetry; + protected context?: PluginContext; /** Registered endpoints for this plugin */ private registeredEndpoints: PluginEndpointMap = {}; @@ -193,12 +195,58 @@ export abstract class Plugin< config.name ?? (this.constructor as { manifest?: { name: string } }).manifest?.name ?? "plugin"; - this.telemetry = TelemetryManager.getProvider(this.name, config.telemetry); this.streamManager = new StreamManager(); - this.cache = CacheManager.getInstanceSync(); this.app = new AppManager(); this.devFileReader = DevFileReader.getInstance(); + this.context = (config as Record).context as + | PluginContext + | undefined; + + // Eagerly bind telemetry + cache if the core services have already been + // initialized (normal createApp path, or tests that mock CacheManager). + // If they haven't, we leave these undefined and rely on `attachContext` + // being called later — this lets factories eagerly construct plugin + // instances at module top-level before `createApp` has run. + this.tryAttachContext(); + } + + private tryAttachContext(): void { + try { + this.cache = CacheManager.getInstanceSync(); + } catch { + return; + } + this.telemetry = TelemetryManager.getProvider( + this.name, + this.config.telemetry, + ); + this.isReady = true; + } + /** + * Binds runtime dependencies (telemetry provider, cache, plugin context) to + * this plugin. Called by `AppKit._createApp` after construction and before + * `setup()`. Idempotent: safe to call if the constructor already bound them + * eagerly. Kept separate so factories can eagerly construct plugin instances + * without running this before `TelemetryManager.initialize()` / + * `CacheManager.getInstance()` have run. + */ + attachContext( + deps: { + context?: unknown; + telemetryConfig?: BasePluginConfig["telemetry"]; + } = {}, + ): void { + if (!this.cache) { + this.cache = CacheManager.getInstanceSync(); + } + this.telemetry = TelemetryManager.getProvider( + this.name, + deps.telemetryConfig ?? this.config.telemetry, + ); + if (deps.context !== undefined) { + this.context = deps.context as PluginContext; + } this.isReady = true; } diff --git a/packages/appkit/src/plugin/to-plugin.ts b/packages/appkit/src/plugin/to-plugin.ts index 77725027..c882f300 100644 --- a/packages/appkit/src/plugin/to-plugin.ts +++ b/packages/appkit/src/plugin/to-plugin.ts @@ -1,19 +1,41 @@ import type { PluginConstructor, PluginData, ToPlugin } from "shared"; /** - * Wraps a plugin class so it can be passed to createApp with optional config. - * Infers config type from the constructor and plugin name from the static `name` property. + * Factory function produced by {@link toPlugin}. Carries a static + * `pluginName` field so tooling (e.g. `fromPlugin`) can identify which + * plugin a factory references without constructing an instance. + */ +export type NamedPluginFactory = { + readonly pluginName: Name; +}; + +/** + * Wraps a plugin class so it can be passed to `createApp` with optional + * config. Infers the config type from the constructor and the plugin name + * from the static `manifest.name` property, and stamps `pluginName` onto + * the returned factory function so `fromPlugin` can identify the plugin + * without needing to construct it. * * @internal */ export function toPlugin( plugin: T, -): ToPlugin[0], T["manifest"]["name"]> { +): ToPlugin[0], T["manifest"]["name"]> & + NamedPluginFactory { type Config = ConstructorParameters[0]; type Name = T["manifest"]["name"]; - return (config: Config = {} as Config): PluginData => ({ + const pluginName = plugin.manifest.name as Name; + const factory = ( + config: Config = {} as Config, + ): PluginData => ({ plugin: plugin as T, config: config as Config, - name: plugin.manifest.name as Name, + name: pluginName, + }); + Object.defineProperty(factory, "pluginName", { + value: pluginName, + writable: false, + enumerable: true, }); + return factory as ToPlugin & NamedPluginFactory; } diff --git a/packages/appkit/src/plugins/server/index.ts b/packages/appkit/src/plugins/server/index.ts index 8ed13cea..cc58cc0d 100644 --- a/packages/appkit/src/plugins/server/index.ts +++ b/packages/appkit/src/plugins/server/index.ts @@ -72,10 +72,15 @@ export class ServerPlugin extends Plugin { this.serverApplication = express(); this.server = null; this.serverExtensions = []; + } + + attachContext(deps: Parameters[0] = {}): void { + super.attachContext(deps); this.telemetry.registerInstrumentations([ instrumentations.http, instrumentations.express, ]); + this.context?.registerAsRouteTarget(this); } async setup() {} @@ -176,6 +181,16 @@ export class ServerPlugin extends Plugin { return this; } + /** + * Register a server extension from another plugin during setup. + * Unlike extend(), this does not guard on autoStart — it's designed + * for internal plugin-to-plugin coordination where extensions are + * registered before the server starts listening. + */ + addExtension(fn: (app: express.Application) => void) { + this.serverExtensions.push(fn); + } + /** * Setup the routes with the plugins. * @@ -190,14 +205,15 @@ export class ServerPlugin extends Plugin { const endpoints: PluginEndpoints = {}; const pluginConfigs: PluginClientConfigs = {}; - if (!this.config.plugins) return { endpoints, pluginConfigs }; + const plugins = this.context?.getPlugins(); + if (!plugins || plugins.size === 0) return { endpoints, pluginConfigs }; this.serverApplication.get("/health", (_, res) => { res.status(200).json({ status: "ok" }); }); this.registerEndpoint("health", "/health"); - for (const plugin of Object.values(this.config.plugins)) { + for (const plugin of plugins.values()) { if (EXCLUDED_PLUGINS.includes(plugin.name)) continue; if (plugin?.injectRoutes && typeof plugin.injectRoutes === "function") { @@ -346,8 +362,9 @@ export class ServerPlugin extends Plugin { } // 1. abort active operations from plugins - if (this.config.plugins) { - for (const plugin of Object.values(this.config.plugins)) { + const shutdownPlugins = this.context?.getPlugins(); + if (shutdownPlugins) { + for (const plugin of shutdownPlugins.values()) { if (plugin.abortActiveOperations) { try { plugin.abortActiveOperations(); diff --git a/packages/appkit/src/plugins/server/tests/server.test.ts b/packages/appkit/src/plugins/server/tests/server.test.ts index fae11fb5..76088348 100644 --- a/packages/appkit/src/plugins/server/tests/server.test.ts +++ b/packages/appkit/src/plugins/server/tests/server.test.ts @@ -1,4 +1,6 @@ +import type { BasePlugin } from "shared"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { PluginContext } from "../../../core/plugin-context"; // Use vi.hoisted for mocks that need to be available before module loading const { @@ -171,6 +173,14 @@ import { RemoteTunnelController } from "../remote-tunnel/remote-tunnel-controlle import { StaticServer } from "../static-server"; import { ViteDevServer } from "../vite-dev-server"; +function createContextWithPlugins(plugins: Record): PluginContext { + const ctx = new PluginContext(); + for (const [name, instance] of Object.entries(plugins)) { + ctx.registerPlugin(name, instance as BasePlugin); + } + return ctx; +} + describe("ServerPlugin", () => { let originalEnv: NodeJS.ProcessEnv; @@ -322,7 +332,7 @@ describe("ServerPlugin", () => { process.env.NODE_ENV = "production"; const injectRoutes = vi.fn(); - const plugins: any = { + const testPlugins: any = { "test-plugin": { name: "test-plugin", injectRoutes, @@ -330,7 +340,9 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ plugins }); + const plugin = new ServerPlugin({ + context: createContextWithPlugins(testPlugins), + } as any); await plugin.start(); const routerFn = (express as any).Router as ReturnType; @@ -368,7 +380,9 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ plugins }); + const plugin = new ServerPlugin({ + context: createContextWithPlugins(plugins), + } as any); await plugin.start(); expect(plugins["plugin-a"].clientConfig).toHaveBeenCalled(); @@ -395,7 +409,9 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ plugins }); + const plugin = new ServerPlugin({ + context: createContextWithPlugins(plugins), + } as any); await plugin.start(); expect(plugins["plugin-null"].clientConfig).toHaveBeenCalled(); @@ -426,7 +442,9 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ plugins }); + const plugin = new ServerPlugin({ + context: createContextWithPlugins(plugins), + } as any); await expect(plugin.start()).resolves.toBeDefined(); expect(mockLoggerError).toHaveBeenCalledWith( "Plugin '%s' clientConfig() failed, skipping its config: %O", @@ -581,19 +599,19 @@ describe("ServerPlugin", () => { .mockImplementation(((_code?: number) => undefined) as any); const plugin = new ServerPlugin({ - plugins: { + context: createContextWithPlugins({ ok: { name: "ok", abortActiveOperations: vi.fn(), - } as any, + }, bad: { name: "bad", abortActiveOperations: vi.fn(() => { throw new Error("boom"); }), - } as any, - }, - }); + }, + }), + } as any); // pretend started (plugin as any).server = mockHttpServer; diff --git a/packages/shared/src/plugin.ts b/packages/shared/src/plugin.ts index 9fa8066c..651840c7 100644 --- a/packages/shared/src/plugin.ts +++ b/packages/shared/src/plugin.ts @@ -26,6 +26,15 @@ export interface BasePlugin { exports?(): unknown; clientConfig?(): Record; + + /** + * Binds runtime dependencies (telemetry, cache, plugin context) after the + * plugin has been constructed. Called by the AppKit core before `setup()`. + */ + attachContext?(deps: { + context?: unknown; + telemetryConfig?: TelemetryOptions; + }): void; } /** Base configuration interface for AppKit plugins */ From aafb5117359a2fb28c762b4757a0a0b7ae9202ce Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Tue, 21 Apr 2026 19:51:21 +0200 Subject: [PATCH 14/23] feat(appkit): agents() plugin, createAgent(def), and markdown-driven agents MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The main product layer. Turns an AppKit app into an AI-agent host with markdown-driven agent discovery, code-defined agents, sub-agents, and a standalone run-without-HTTP executor. `packages/appkit/src/core/create-agent-def.ts`. Returns the passed-in definition after cycle-detecting the sub-agent graph. No adapter construction, no side effects — safe at module top-level. The returned `AgentDefinition` is plain data, consumable by either `agents({ agents })` or `runAgent(def, input)`. `packages/appkit/src/plugins/agents/agents.ts`. `AgentsPlugin` class: - Loads markdown agents from `config/agents/*.md` (configurable dir) via real YAML frontmatter parsing (`js-yaml`). Frontmatter schema: `endpoint`, `model`, `toolkits`, `tools`, `default`, `maxSteps`, `maxTokens`, `baseSystemPrompt`. Unknown keys logged, invalid YAML throws at boot. - Merges code-defined agents passed via `agents({ agents: { name: def } })`. Code wins on key collision. - For each agent, builds a per-agent tool index from: 1. Sub-agents (`agents: {...}`) — synthesized as `agent-` tools on the parent. 2. Explicit tool record entries — `ToolkitEntry`s, inline `FunctionTool`s, or `HostedTool`s. 3. Auto-inherit (if nothing explicit) — pulls every registered `ToolProvider` plugin's tools. Asymmetric default: markdown agents inherit (`file: true`), code-defined agents don't (`code: false`). - Mounts `POST /invocations` (OpenAI Responses compatible) + `POST /chat`, `POST /cancel`, `GET /threads/:id`, `DELETE /threads/:id`, `GET /info`. - SSE streaming via `executeStream`. Tool calls dispatch through `PluginContext.executeTool(req, pluginName, localName, args, signal)` for OBO, telemetry, and timeout. - Exposes `appkit.agent.{register, list, get, reload, getDefault, getThreads}` runtime helpers. `packages/appkit/src/core/run-agent.ts`. Runs an `AgentDefinition` without `createApp` or HTTP. Drives the adapter's event stream to completion, executing inline tools + sub-agents along the way. Aggregates events into `{ text, events }`. Useful for tests, CLI scripts, and offline pipelines. Hosted/MCP tools and plugin toolkits require the agents plugin and throw clear errors with guidance. - `AgentEventTranslator` — stateful converter from internal `AgentEvent`s to OpenAI Responses API `ResponseStreamEvent`s with sequence numbers and output indices. - `InMemoryThreadStore` — per-user conversation persistence. Nested `Map>`. Implements `ThreadStore` from shared types. - `buildBaseSystemPrompt` + `composeSystemPrompt` — formats the AppKit base prompt (with plugin names and tool names) and layers the agent's instructions on top. `load-agents.ts` — reads `*.md` files, parses YAML frontmatter with `js-yaml`, resolves `toolkits: [...]` entries against the plugin provider index at load time, wraps ambient tools (from `agents({ tools: {...} })`) for `tools: [...]` frontmatter references. - Adds `js-yaml` + `@types/js-yaml` deps. - Manifest mounts routes at `/api/agent/*` (singular — matches `appkit.agent.*` runtime handle). - Exports from the main barrel: `agents`, `createAgent`, `runAgent`, `AgentDefinition`, `AgentsPluginConfig`, `AgentTool`, `ToolkitEntry`, `ToolkitOptions`, `BaseSystemPromptOption`, `PromptContext`, `isToolkitEntry`, `loadAgentFromFile`, `loadAgentsFromDir`. - 60 new tests: agents plugin lifecycle, markdown loading, code-agent registration, auto-inherit asymmetry, sub-agent tool synthesis, cycle detection, event translator, thread store, system prompt composition, standalone `runAgent`. - Full appkit vitest suite: 1297 tests passing. - Typecheck clean across all 8 workspace projects. Signed-off-by: MarioCadenas `connectHostedTools` now builds an `McpHostPolicy` from the new `config.mcp` field (`trustedHosts`, `allowLocalhost`) and passes it to `AppKitMcpClient`. Same-origin workspace URLs are admitted with workspace auth; all other hosts must be explicitly trusted, and workspace credentials are never forwarded to them. - `AgentsPluginConfig.mcp?: McpHostPolicyConfig` added to the public config surface. See PR #302 for the policy definition and tests. - No behavioural change for the default case (same-origin Databricks workspace URLs): those continue to receive SP on setup and caller OBO on `tools/call`. - `knip.json`: ignore `packages/appkit/src/plugin/to-plugin.ts`. The `NamedPluginFactory` export introduced by this layer is consumed in a later stack layer; knip flags it as unused in the intermediate state. Signed-off-by: MarioCadenas Adds a secure-by-default HITL approval gate for any tool annotated `destructive: true`. Before executing such a tool the agents plugin: 1. Emits a new `appkit.approval_pending` SSE event carrying the `approval_id`, `stream_id`, `tool_name`, `args`, and `annotations`. 2. Awaits a matching `POST /chat/approve` from the same user who initiated the stream. 3. Auto-denies after the configurable `timeoutMs` (default 60 s). A denial returns a short denial string to the adapter as the tool output so the LLM can apologise / replan instead of crashing. ```ts agents({ approval: { requireForDestructive?: boolean, // default: true timeoutMs?: number, // default: 60_000 }, }); ``` - `event-channel.ts`: single-producer / single-consumer async queue used to merge adapter events with out-of-band events emitted by `executeTool` (same SSE stream, single `executeStream` sink). - `tool-approval-gate.ts`: state machine keyed by `approvalId`. Owns the pending promise + timeout, enforces ownership on submit, exposes `abortStream(streamId)` + `abortAll()` for clean teardown. - `AgentEvent` gains an `approval_pending` variant. - `ResponseStreamEvent` gains `AppKitApprovalPendingEvent`. - `AgentEventTranslator.translate()` handles both. - `POST /approve` route with zod validation, ownership check, and 404 / 403 / 200 semantics. - `POST /cancel` now enforces the same ownership invariant (`resolveUserId(req) === stream.userId`) and aborts any pending approval gates on the stream. - `event-channel.test.ts` (7): ordering, buffered-before-iter, close semantics, close-with-error rejection, interleave. - `tool-approval-gate.test.ts` (8): approve / deny / timeout / ownership / abortStream / abortAll / late-submit no-op. - `approval-route.test.ts` (8): schema validation, unknown stream, ownership refusal, unknown approvalId, approve happy path, deny happy path, cancel clears pending gates, cancel ownership refusal. Full appkit suite: 1448 tests passing (+23 from Layer 3). Signed-off-by: MarioCadenas `autoInheritTools` now defaults to `{ file: false, code: false }`. Markdown and code-defined agents with no explicit `tools:` declaration receive an empty tool index unless the developer explicitly opts in. When opted in (`autoInheritTools: { file: true }` or the boolean shorthand), `applyAutoInherit` now filters the spread strictly by each `ToolkitEntry.autoInheritable` flag (set on the source `ToolEntry` in PR #302). Any tool not marked `autoInheritable: true` is skipped and logged so the operator can see exactly what the safe default omits. Providers exposing tools only via `getAgentTools()` (no `toolkit()`) cannot be filtered per tool, so their entries are conservatively skipped during auto-inherit and must be wired explicitly via `tools:`. This removes the silent privilege-amplification path where registering a plugin implicitly granted its entire tool surface to every markdown agent. - New: safe default produces an empty tool index for both file and code agents even when an `autoInheritable: true` tool exists. - New: `autoInheritTools: { file: true }` spreads only the tool marked `autoInheritable: true`; an adjacent unmarked tool is skipped. - New: `autoInheritTools: true` (boolean shorthand) enables both origins and still filters by `autoInheritable`. - Updated: the prior "asymmetric default" test now validates the new safe-default semantics (empty index on both origins). Full appkit suite: 1451 tests passing (+3 from S-3 Layer 2). Signed-off-by: MarioCadenas Five small correctness + DX fixes rounding out the MVP blocker list. - **Plugin name `agent` → `agents`.** The manifest name now matches the public runtime key (`appkit.agents.*`) and the factory export. Previously the cast in the reference app masked a real typing gap. - **`maxSteps` / `maxTokens` frontmatter / AgentDefinition fields now flow into the adapter.** Previously `resolveAdapter` built `DatabricksAdapter.fromModelServing(source)` without passing either value, so the documented knobs were silent no-ops. Now threaded through as `adapterOptions` when AppKit constructs the adapter itself (string model or omitted model); user-supplied `AgentAdapter` instances own their own settings as before. - **Single-`message` assistant turns are now persisted to the thread store.** The stream accumulator previously only handled `message_delta`; any adapter that yields a single final `message` (notably LangChain's `on_chain_end` path) silently dropped the assistant turn from thread history, so multi-turn LangChain conversations lost context. The loop now accumulates both kinds, with `message` replacing previously-accumulated deltas. - **Void-tool return no longer coerced into a fake error.** A tool handler that legitimately returns `undefined` (side-effecting fire-and-forget tools) was being reported to the LLM as `Error: Tool "x" execution failed`. Now returns `""` so the model sees a successful-but-empty result. - **Default `InMemoryThreadStore` is now loud about being dev-only.** Constructor logs an INFO in development and a WARN in production when no `threadStore` is supplied. Docstring rewritten to state unambiguously that the default is intended for local development / demos only, and points at a follow-up for a capped variant. Real caps + a persistent implementation are tracked as follow-ups. Signed-off-by: MarioCadenas New `ephemeral?: boolean` field on `AgentDefinition` and the `ephemeral` markdown frontmatter key. When set, the thread created for a chat request against that agent is deleted from `ThreadStore` in the stream's `finally`. Intended for stateless one-shot agents (e.g. autocomplete) where each invocation is independent and retaining history would both poison future calls and accumulate unbounded state in the default `InMemoryThreadStore`. This closes the "autocomplete agent creates orphan thread per keystroke" regression flagged in the performance re-review (R1), which otherwise would have put an in-tree memory-leak demonstrator against the one perf finding S-2/S-3 consciously deferred. Cleanup errors in the finally block are logged at warn level so a late delete never masks the real response. `RegisteredAgent` mirrors the flag. `load-agents.ts` adds `ephemeral` to `ALLOWED_KEYS`. Signed-off-by: MarioCadenas Rewrote `event-translator.ts` to allocate the message's `output_index` lazily on the first `message_delta` or `message` and to close any open message before emitting a subsequent `tool_call` / `tool_result` item. The previous implementation hardcoded `output_index: 0` for messages and incremented a counter starting at 1 for tool items, so any ReAct-style flow (tool call before text) produced `output_item. added` at index 1 followed by `output_item.added` at index 0 — monotonicity violation that OpenAI's own Responses-API SDK parsers enforce. Also fixed the companion bug from the original review: `message` after preceding `message_delta`s no longer double-emits `output_item.added` (it just emits the `done`), and `handleToolResult` coalesces `undefined` to `""` at the translator layer so the wire shape is always a string for every adapter (not just the ones that funnel through `agents.ts` executeTool). Four new regression tests pin the invariants: tool_call → text ordering, message-interrupted-by-tool, no duplicate added on full- message, undefined-tool-result → empty-string output. The one remaining HIGH security item from the prior review is now closed. Minimal, static caps at the schema layer; configurable per-deployment caps at runtime. Schemas: - `chatRequestSchema.message`: `.max(64_000)` — ~16k tokens, well above any legitimate chat turn. - `invocationsRequestSchema.input`: string `.max(64_000)`, array `.max(100)` items, per-item `content` string `.max(64_000)` or array `.max(100)` items. Runtime limits (new `AgentsPluginConfig.limits`): - `maxConcurrentStreamsPerUser` (default 5): `_handleChat` counts the user's active streams before admitting and returns HTTP 429 + `Retry-After: 5` when at-limit. Per-user, not global. - `maxToolCalls` (default 50): per-run budget tracked in the `executeTool` closure across the top-level adapter and any sub-agent delegations. Exceeding aborts the stream. - `maxSubAgentDepth` (default 3): `runSubAgent` rejects before any adapter work when the recursion depth exceeds the limit. Protects against a prompt-injected agent that delegates to itself transitively. 15 new tests exercise body caps (6), per-user limit with and without override (3), defaults and overrides on `resolvedLimits` (2), sub-agent depth boundary + violation (2), plus the remaining schema checks (2). Full appkit vitest suite: 1475 tests passing (+19 from this pass). Signed-off-by: MarioCadenas --- knip.json | 1 + packages/appkit/package.json | 2 + packages/appkit/src/core/create-agent-def.ts | 53 + packages/appkit/src/core/run-agent.ts | 226 +++ packages/appkit/src/index.ts | 25 +- packages/appkit/src/plugins/agents/agents.ts | 1269 +++++++++++++++++ .../appkit/src/plugins/agents/defaults.ts | 12 + .../src/plugins/agents/event-channel.ts | 70 + .../src/plugins/agents/event-translator.ts | 291 ++++ packages/appkit/src/plugins/agents/index.ts | 22 + .../appkit/src/plugins/agents/load-agents.ts | 370 +++++ .../appkit/src/plugins/agents/manifest.json | 10 + packages/appkit/src/plugins/agents/schemas.ts | 69 + .../src/plugins/agents/system-prompt.ts | 40 + .../agents/tests/agents-plugin.test.ts | 373 +++++ .../agents/tests/approval-route.test.ts | 292 ++++ .../plugins/agents/tests/create-agent.test.ts | 75 + .../plugins/agents/tests/dos-limits.test.ts | 299 ++++ .../agents/tests/event-channel.test.ts | 78 + .../agents/tests/event-translator.test.ts | 332 +++++ .../plugins/agents/tests/load-agents.test.ts | 302 ++++ .../plugins/agents/tests/run-agent.test.ts | 120 ++ .../agents/tests/system-prompt.test.ts | 45 + .../plugins/agents/tests/thread-store.test.ts | 138 ++ .../agents/tests/tool-approval-gate.test.ts | 156 ++ .../appkit/src/plugins/agents/thread-store.ts | 66 + .../src/plugins/agents/tool-approval-gate.ts | 122 ++ packages/appkit/src/plugins/agents/types.ts | 177 ++- packages/shared/src/agent.ts | 36 +- pnpm-lock.yaml | 11 + 30 files changed, 5071 insertions(+), 11 deletions(-) create mode 100644 packages/appkit/src/core/create-agent-def.ts create mode 100644 packages/appkit/src/core/run-agent.ts create mode 100644 packages/appkit/src/plugins/agents/agents.ts create mode 100644 packages/appkit/src/plugins/agents/defaults.ts create mode 100644 packages/appkit/src/plugins/agents/event-channel.ts create mode 100644 packages/appkit/src/plugins/agents/event-translator.ts create mode 100644 packages/appkit/src/plugins/agents/index.ts create mode 100644 packages/appkit/src/plugins/agents/load-agents.ts create mode 100644 packages/appkit/src/plugins/agents/manifest.json create mode 100644 packages/appkit/src/plugins/agents/schemas.ts create mode 100644 packages/appkit/src/plugins/agents/system-prompt.ts create mode 100644 packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/approval-route.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/create-agent.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/dos-limits.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/event-channel.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/event-translator.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/load-agents.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/run-agent.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/system-prompt.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/thread-store.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/tool-approval-gate.test.ts create mode 100644 packages/appkit/src/plugins/agents/thread-store.ts create mode 100644 packages/appkit/src/plugins/agents/tool-approval-gate.ts diff --git a/knip.json b/knip.json index 13a43187..878dd3f5 100644 --- a/knip.json +++ b/knip.json @@ -20,6 +20,7 @@ "**/*.css", "packages/appkit/src/plugins/vector-search/**", "packages/appkit/src/plugin/index.ts", + "packages/appkit/src/plugin/to-plugin.ts", "packages/appkit/src/plugins/agents/index.ts", "packages/appkit/src/plugins/agents/tools/index.ts", "packages/appkit/src/plugins/agents/from-plugin.ts", diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 3d5b1ddf..83e62814 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -83,6 +83,7 @@ "@types/semver": "7.7.1", "dotenv": "16.6.1", "express": "4.22.0", + "js-yaml": "^4.1.1", "obug": "2.1.1", "pg": "8.18.0", "picocolors": "1.1.1", @@ -108,6 +109,7 @@ "@ai-sdk/openai": "4.0.0-beta.27", "@langchain/core": "^1.1.39", "@types/express": "4.17.25", + "@types/js-yaml": "^4.0.9", "@types/json-schema": "7.0.15", "@types/pg": "8.16.0", "@types/ws": "8.18.1", diff --git a/packages/appkit/src/core/create-agent-def.ts b/packages/appkit/src/core/create-agent-def.ts new file mode 100644 index 00000000..3e93371d --- /dev/null +++ b/packages/appkit/src/core/create-agent-def.ts @@ -0,0 +1,53 @@ +import { ConfigurationError } from "../errors"; +import type { AgentDefinition } from "../plugins/agents/types"; + +/** + * Pure factory for agent definitions. Returns the passed-in definition after + * cycle-detecting the sub-agent graph. Accepts the full `AgentDefinition` shape + * and is safe to call at module top-level. + * + * The returned value is a plain `AgentDefinition` — no adapter construction, + * no side effects. Register it with `agents({ agents: { name: def } })` or run + * it standalone via `runAgent(def, input)`. + * + * @example + * ```ts + * const support = createAgent({ + * instructions: "You help customers.", + * model: "databricks-claude-sonnet-4-5", + * tools: { + * get_weather: tool({ ... }), + * }, + * }); + * ``` + */ +export function createAgent(def: AgentDefinition): AgentDefinition { + detectCycles(def); + return def; +} + +/** + * Walks the `agents: { ... }` sub-agent tree via DFS and throws if a cycle is + * found. Cycles would cause infinite recursion at tool-invocation time. + */ +function detectCycles(def: AgentDefinition): void { + const visiting = new Set(); + const visited = new Set(); + + const walk = (current: AgentDefinition, path: string[]): void => { + if (visited.has(current)) return; + if (visiting.has(current)) { + throw new ConfigurationError( + `Agent sub-agent cycle detected: ${path.join(" -> ")}`, + ); + } + visiting.add(current); + for (const [childKey, child] of Object.entries(current.agents ?? {})) { + walk(child, [...path, childKey]); + } + visiting.delete(current); + visited.add(current); + }; + + walk(def, [def.name ?? "(root)"]); +} diff --git a/packages/appkit/src/core/run-agent.ts b/packages/appkit/src/core/run-agent.ts new file mode 100644 index 00000000..e83c2c9c --- /dev/null +++ b/packages/appkit/src/core/run-agent.ts @@ -0,0 +1,226 @@ +import { randomUUID } from "node:crypto"; +import type { + AgentAdapter, + AgentEvent, + AgentToolDefinition, + Message, +} from "shared"; +import { + type FunctionTool, + functionToolToDefinition, + isFunctionTool, +} from "../plugins/agents/tools/function-tool"; +import { isHostedTool } from "../plugins/agents/tools/hosted-tools"; +import type { + AgentDefinition, + AgentTool, + ToolkitEntry, +} from "../plugins/agents/types"; +import { isToolkitEntry } from "../plugins/agents/types"; + +export interface RunAgentInput { + /** Seed messages for the run. Either a single user string or a full message list. */ + messages: string | Message[]; + /** Abort signal for cancellation. */ + signal?: AbortSignal; +} + +export interface RunAgentResult { + /** Aggregated text output from all `message_delta` events. */ + text: string; + /** Every event the adapter yielded, in order. Useful for inspection/tests. */ + events: AgentEvent[]; +} + +/** + * Standalone agent execution without `createApp`. Resolves the adapter, binds + * inline tools, and drives the adapter's `run()` loop to completion. + * + * Limitations vs. running through the agents() plugin: + * - No OBO: there is no HTTP request, so plugin tools run as the service + * principal (when they work at all). + * - Plugin tools (`ToolkitEntry`) are not supported — they require a live + * `PluginContext` that only exists when registered in a `createApp` + * instance. This function throws a clear error if encountered. + * - Sub-agents (`agents: { ... }` on the def) are executed as nested + * `runAgent` calls with no shared thread state. + */ +export async function runAgent( + def: AgentDefinition, + input: RunAgentInput, +): Promise { + const adapter = await resolveAdapter(def); + const messages = normalizeMessages(input.messages, def.instructions); + const toolIndex = buildStandaloneToolIndex(def); + const tools = Array.from(toolIndex.values()).map((e) => e.def); + + const signal = input.signal; + + const executeTool = async (name: string, args: unknown): Promise => { + const entry = toolIndex.get(name); + if (!entry) throw new Error(`Unknown tool: ${name}`); + if (entry.kind === "function") { + return entry.tool.execute(args as Record); + } + if (entry.kind === "subagent") { + const subInput: RunAgentInput = { + messages: + typeof args === "object" && + args !== null && + typeof (args as { input?: unknown }).input === "string" + ? (args as { input: string }).input + : JSON.stringify(args), + signal, + }; + const res = await runAgent(entry.agentDef, subInput); + return res.text; + } + throw new Error( + `runAgent: tool "${name}" is a ${entry.kind} tool. ` + + "Plugin toolkits and MCP tools are only usable via createApp({ plugins: [..., agents(...)] }).", + ); + }; + + const events: AgentEvent[] = []; + let text = ""; + + const stream = adapter.run( + { + messages, + tools, + threadId: randomUUID(), + signal, + }, + { executeTool, signal }, + ); + + for await (const event of stream) { + if (signal?.aborted) break; + events.push(event); + if (event.type === "message_delta") { + text += event.content; + } else if (event.type === "message") { + text = event.content; + } + } + + return { text, events }; +} + +async function resolveAdapter(def: AgentDefinition): Promise { + const { model } = def; + if (!model) { + const { DatabricksAdapter } = await import("../agents/databricks"); + return DatabricksAdapter.fromModelServing(); + } + if (typeof model === "string") { + const { DatabricksAdapter } = await import("../agents/databricks"); + return DatabricksAdapter.fromModelServing(model); + } + return await model; +} + +function normalizeMessages( + input: string | Message[], + instructions: string, +): Message[] { + const systemMessage: Message = { + id: "system", + role: "system", + content: instructions, + createdAt: new Date(), + }; + if (typeof input === "string") { + return [ + systemMessage, + { + id: randomUUID(), + role: "user", + content: input, + createdAt: new Date(), + }, + ]; + } + return [systemMessage, ...input]; +} + +type StandaloneEntry = + | { + kind: "function"; + def: AgentToolDefinition; + tool: FunctionTool; + } + | { + kind: "subagent"; + def: AgentToolDefinition; + agentDef: AgentDefinition; + } + | { + kind: "toolkit"; + def: AgentToolDefinition; + entry: ToolkitEntry; + } + | { + kind: "hosted"; + def: AgentToolDefinition; + }; + +function buildStandaloneToolIndex( + def: AgentDefinition, +): Map { + const index = new Map(); + + for (const [key, tool] of Object.entries(def.tools ?? {})) { + index.set(key, classifyTool(key, tool)); + } + + for (const [childKey, child] of Object.entries(def.agents ?? {})) { + const toolName = `agent-${childKey}`; + index.set(toolName, { + kind: "subagent", + agentDef: { ...child, name: child.name ?? childKey }, + def: { + name: toolName, + description: + child.instructions.slice(0, 120) || + `Delegate to the ${childKey} sub-agent`, + parameters: { + type: "object", + properties: { + input: { + type: "string", + description: "Message to send to the sub-agent.", + }, + }, + required: ["input"], + }, + }, + }); + } + + return index; +} + +function classifyTool(key: string, tool: AgentTool): StandaloneEntry { + if (isToolkitEntry(tool)) { + return { kind: "toolkit", def: { ...tool.def, name: key }, entry: tool }; + } + if (isFunctionTool(tool)) { + return { + kind: "function", + tool, + def: { ...functionToolToDefinition(tool), name: key }, + }; + } + if (isHostedTool(tool)) { + return { + kind: "hosted", + def: { + name: key, + description: `Hosted tool: ${tool.type}`, + parameters: { type: "object", properties: {} }, + }, + }; + } + throw new Error(`runAgent: unrecognized tool shape at key "${key}"`); +} diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 6c6c6f5b..23f72216 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -43,6 +43,12 @@ export { } from "./connectors/lakebase"; export { getExecutionContext } from "./context"; export { createApp } from "./core"; +export { createAgent } from "./core/create-agent-def"; +export { + type RunAgentInput, + type RunAgentResult, + runAgent, +} from "./core/run-agent"; // Errors export { AppKitError, @@ -63,6 +69,19 @@ export { toPlugin, } from "./plugin"; export { analytics, files, genie, lakebase, server, serving } from "./plugins"; +export { + type AgentDefinition, + type AgentsPluginConfig, + type AgentTool, + agents, + type BaseSystemPromptOption, + isToolkitEntry, + loadAgentFromFile, + loadAgentsFromDir, + type PromptContext, + type ToolkitEntry, + type ToolkitOptions, +} from "./plugins/agents"; export { type FunctionTool, type HostedTool, @@ -72,12 +91,6 @@ export { type ToolConfig, tool, } from "./plugins/agents/tools"; -export { - type AgentTool, - isToolkitEntry, - type ToolkitEntry, - type ToolkitOptions, -} from "./plugins/agents/types"; // Files plugin types (for custom policy authoring) export type { FileAction, diff --git a/packages/appkit/src/plugins/agents/agents.ts b/packages/appkit/src/plugins/agents/agents.ts new file mode 100644 index 00000000..07a637fd --- /dev/null +++ b/packages/appkit/src/plugins/agents/agents.ts @@ -0,0 +1,1269 @@ +import { randomUUID } from "node:crypto"; +import path from "node:path"; +import type express from "express"; +import pc from "picocolors"; +import type { + AgentAdapter, + AgentEvent, + AgentRunContext, + AgentToolDefinition, + IAppRouter, + Message, + PluginPhase, + ResponseStreamEvent, + Thread, + ToolProvider, +} from "shared"; +import { createLogger } from "../../logging/logger"; +import { Plugin, toPlugin } from "../../plugin"; +import type { PluginManifest } from "../../registry"; +import { agentStreamDefaults } from "./defaults"; +import { EventChannel } from "./event-channel"; +import { AgentEventTranslator } from "./event-translator"; +import { loadAgentsFromDir } from "./load-agents"; +import manifest from "./manifest.json"; +import { + approvalRequestSchema, + chatRequestSchema, + invocationsRequestSchema, +} from "./schemas"; +import { buildBaseSystemPrompt, composeSystemPrompt } from "./system-prompt"; +import { InMemoryThreadStore } from "./thread-store"; +import { ToolApprovalGate } from "./tool-approval-gate"; +import { + AppKitMcpClient, + functionToolToDefinition, + isFunctionTool, + isHostedTool, + resolveHostedTools, +} from "./tools"; +import { buildMcpHostPolicy } from "./tools/mcp-host-policy"; +import type { + AgentDefinition, + AgentsPluginConfig, + BaseSystemPromptOption, + PromptContext, + RegisteredAgent, + ResolvedToolEntry, +} from "./types"; +import { isToolkitEntry } from "./types"; + +const logger = createLogger("agents"); + +const DEFAULT_AGENTS_DIR = "./config/agents"; + +/** + * Context flag recorded on the in-memory AgentDefinition to indicate whether + * it came from markdown (file) or from user code. Drives the asymmetric + * `autoInheritTools` default. + */ +interface AgentSource { + origin: "file" | "code"; +} + +export class AgentsPlugin extends Plugin implements ToolProvider { + static manifest = manifest as PluginManifest; + static phase: PluginPhase = "deferred"; + + protected declare config: AgentsPluginConfig; + + private agents = new Map(); + private defaultAgentName: string | null = null; + private activeStreams = new Map< + string, + { controller: AbortController; userId: string } + >(); + private mcpClient: AppKitMcpClient | null = null; + private threadStore; + private approvalGate = new ToolApprovalGate(); + + constructor(config: AgentsPluginConfig) { + super(config); + this.config = config; + if (config.threadStore) { + this.threadStore = config.threadStore; + } else { + this.threadStore = new InMemoryThreadStore(); + if (process.env.NODE_ENV === "production") { + logger.warn( + "InMemoryThreadStore is in use in a production build (NODE_ENV=production). " + + "Thread history is unbounded and lost on restart. " + + "Pass agents({ threadStore: }) for real deployments.", + ); + } else { + logger.info( + "Using default InMemoryThreadStore (dev-only — threads are lost on restart and grow without bound).", + ); + } + } + } + + /** Effective approval policy with defaults applied. */ + private get resolvedApprovalPolicy(): { + requireForDestructive: boolean; + timeoutMs: number; + } { + const cfg = this.config.approval ?? {}; + return { + requireForDestructive: cfg.requireForDestructive ?? true, + timeoutMs: cfg.timeoutMs ?? 60_000, + }; + } + + /** Effective DoS limits with defaults applied. */ + private get resolvedLimits(): { + maxConcurrentStreamsPerUser: number; + maxToolCalls: number; + maxSubAgentDepth: number; + } { + const cfg = this.config.limits ?? {}; + return { + maxConcurrentStreamsPerUser: cfg.maxConcurrentStreamsPerUser ?? 5, + maxToolCalls: cfg.maxToolCalls ?? 50, + maxSubAgentDepth: cfg.maxSubAgentDepth ?? 3, + }; + } + + /** Count active streams owned by a given user. */ + private countUserStreams(userId: string): number { + let n = 0; + for (const entry of this.activeStreams.values()) { + if (entry.userId === userId) n++; + } + return n; + } + + async setup() { + await this.loadAgents(); + this.mountInvocationsRoute(); + this.printRegistry(); + } + + /** + * Reload agents from the configured directory, preserving code-defined + * agents. Swaps the registry atomically at the end. + */ + async reload(): Promise { + this.agents.clear(); + this.defaultAgentName = null; + if (this.mcpClient) { + await this.mcpClient.close(); + this.mcpClient = null; + } + await this.loadAgents(); + } + + private async loadAgents() { + const { defs: fileDefs, defaultAgent: fileDefault } = + await this.loadFileDefinitions(); + + const codeDefs = this.config.agents ?? {}; + + for (const name of Object.keys(fileDefs)) { + if (codeDefs[name]) { + logger.warn( + "Agent '%s' defined in both code and a markdown file. Code definition takes precedence.", + name, + ); + } + } + + const merged: Record = + {}; + for (const [name, def] of Object.entries(fileDefs)) { + merged[name] = { def, src: { origin: "file" } }; + } + for (const [name, def] of Object.entries(codeDefs)) { + merged[name] = { def, src: { origin: "code" } }; + } + + if (Object.keys(merged).length === 0) { + logger.info( + "No agents registered (no files in %s, no code-defined agents)", + this.resolvedAgentsDir() ?? "", + ); + return; + } + + for (const [name, { def, src }] of Object.entries(merged)) { + try { + const registered = await this.buildRegisteredAgent(name, def, src); + this.agents.set(name, registered); + if (!this.defaultAgentName) this.defaultAgentName = name; + } catch (err) { + throw new Error( + `Failed to register agent '${name}' (${src.origin}): ${ + err instanceof Error ? err.message : String(err) + }`, + { cause: err instanceof Error ? err : undefined }, + ); + } + } + + if (this.config.defaultAgent) { + if (!this.agents.has(this.config.defaultAgent)) { + throw new Error( + `defaultAgent '${this.config.defaultAgent}' is not registered. Available: ${Array.from(this.agents.keys()).join(", ")}`, + ); + } + this.defaultAgentName = this.config.defaultAgent; + } else if (fileDefault && this.agents.has(fileDefault)) { + this.defaultAgentName = fileDefault; + } + } + + private resolvedAgentsDir(): string | null { + if (this.config.dir === false) return null; + const dir = this.config.dir ?? DEFAULT_AGENTS_DIR; + return path.isAbsolute(dir) ? dir : path.resolve(process.cwd(), dir); + } + + private async loadFileDefinitions(): Promise<{ + defs: Record; + defaultAgent: string | null; + }> { + const dir = this.resolvedAgentsDir(); + if (!dir) return { defs: {}, defaultAgent: null }; + + const pluginToolProviders = this.pluginProviderIndex(); + const ambient = this.config.tools ?? {}; + + const result = await loadAgentsFromDir(dir, { + defaultModel: this.config.defaultModel, + availableTools: ambient, + plugins: pluginToolProviders, + codeAgents: this.config.agents, + }); + + return result; + } + + /** + * Builds the map of plugin-name → toolkit that the markdown loader consults + * when resolving `toolkits:` frontmatter entries. + */ + private pluginProviderIndex(): Map< + string, + { toolkit: (opts?: unknown) => Record } + > { + const out = new Map(); + if (!this.context) return out; + for (const { name, provider } of this.context.getToolProviders()) { + const withToolkit = provider as ToolProvider & { + toolkit?: (opts?: unknown) => Record; + }; + if (typeof withToolkit.toolkit === "function") { + out.set(name, { + toolkit: withToolkit.toolkit.bind(withToolkit), + }); + } + } + return out; + } + + private async buildRegisteredAgent( + name: string, + def: AgentDefinition, + src: AgentSource, + ): Promise { + const adapter = await this.resolveAdapter(def, name); + const toolIndex = await this.buildToolIndex(name, def, src); + + return { + name, + instructions: def.instructions, + adapter, + toolIndex, + baseSystemPrompt: def.baseSystemPrompt, + maxSteps: def.maxSteps, + maxTokens: def.maxTokens, + ephemeral: def.ephemeral, + }; + } + + private async resolveAdapter( + def: AgentDefinition, + name: string, + ): Promise { + const source = def.model ?? this.config.defaultModel; + // Per-agent adapter knobs from `AgentDefinition` / markdown frontmatter. + // Only applied when AppKit builds the adapter itself (string or omitted + // model). Users who pass a pre-built `AgentAdapter` own these settings. + const adapterOptions: { maxSteps?: number; maxTokens?: number } = {}; + if (def.maxSteps !== undefined) adapterOptions.maxSteps = def.maxSteps; + if (def.maxTokens !== undefined) adapterOptions.maxTokens = def.maxTokens; + + if (!source) { + const { DatabricksAdapter } = await import("../../agents/databricks"); + try { + return await DatabricksAdapter.fromModelServing( + undefined, + adapterOptions, + ); + } catch (err) { + throw new Error( + `Agent '${name}' has no model configured and no DATABRICKS_AGENT_ENDPOINT default available`, + { cause: err instanceof Error ? err : undefined }, + ); + } + } + if (typeof source === "string") { + const { DatabricksAdapter } = await import("../../agents/databricks"); + return DatabricksAdapter.fromModelServing(source, adapterOptions); + } + return await source; + } + + /** + * Resolves an agent's tool record into a per-agent dispatch index. Connects + * hosted tools via MCP client. Applies `autoInheritTools` defaults when the + * definition has no declared tools/agents. + */ + private async buildToolIndex( + agentName: string, + def: AgentDefinition, + src: AgentSource, + ): Promise> { + const index = new Map(); + const hasExplicitTools = def.tools && Object.keys(def.tools).length > 0; + const hasExplicitSubAgents = + def.agents && Object.keys(def.agents).length > 0; + + const inheritDefaults = normalizeAutoInherit(this.config.autoInheritTools); + const shouldInherit = + !hasExplicitTools && + !hasExplicitSubAgents && + (src.origin === "file" ? inheritDefaults.file : inheritDefaults.code); + + if (shouldInherit) { + await this.applyAutoInherit(agentName, index); + } + + // 1. Sub-agents → agent- + for (const [childKey, childDef] of Object.entries(def.agents ?? {})) { + const toolName = `agent-${childKey}`; + index.set(toolName, { + source: "subagent", + agentName: childDef.name ?? childKey, + def: { + name: toolName, + description: + childDef.instructions.slice(0, 120) || + `Delegate to the ${childKey} sub-agent`, + parameters: { + type: "object", + properties: { + input: { + type: "string", + description: "Message to send to the sub-agent.", + }, + }, + required: ["input"], + }, + }, + }); + } + + // 2. Explicit tools (toolkit entries, function tools, hosted tools) + const hostedToCollect: import("./tools/hosted-tools").HostedTool[] = []; + for (const [key, tool] of Object.entries(def.tools ?? {})) { + if (isToolkitEntry(tool)) { + index.set(key, { + source: "toolkit", + pluginName: tool.pluginName, + localName: tool.localName, + def: { ...tool.def, name: key }, + }); + continue; + } + if (isFunctionTool(tool)) { + index.set(key, { + source: "function", + functionTool: tool, + def: { ...functionToolToDefinition(tool), name: key }, + }); + continue; + } + if (isHostedTool(tool)) { + hostedToCollect.push(tool); + continue; + } + throw new Error( + `Agent '${agentName}' tool '${key}' has an unrecognized shape`, + ); + } + + if (hostedToCollect.length > 0) { + await this.connectHostedTools(hostedToCollect, index); + } + + return index; + } + + private async applyAutoInherit( + agentName: string, + index: Map, + ): Promise { + if (!this.context) return; + const inherited: string[] = []; + const skippedByPlugin = new Map(); + const recordSkip = (pluginName: string, localName: string) => { + const list = skippedByPlugin.get(pluginName) ?? []; + list.push(localName); + skippedByPlugin.set(pluginName, list); + }; + + for (const { + name: pluginName, + provider, + } of this.context.getToolProviders()) { + if (pluginName === this.name) continue; + const withToolkit = provider as ToolProvider & { + toolkit?: (opts?: unknown) => Record; + }; + if (typeof withToolkit.toolkit === "function") { + const entries = withToolkit.toolkit() as Record; + for (const [key, maybeEntry] of Object.entries(entries)) { + if (!isToolkitEntry(maybeEntry)) continue; + if (maybeEntry.autoInheritable !== true) { + recordSkip(maybeEntry.pluginName, maybeEntry.localName); + continue; + } + index.set(key, { + source: "toolkit", + pluginName: maybeEntry.pluginName, + localName: maybeEntry.localName, + def: { ...maybeEntry.def, name: key }, + }); + inherited.push(key); + } + continue; + } + // Fallback: providers without a toolkit() still expose getAgentTools(). + // These cannot be selectively opted in per tool, so we conservatively + // skip them during auto-inherit and require explicit `tools:` wiring. + for (const tool of provider.getAgentTools()) { + recordSkip(pluginName, tool.name); + } + } + + if (inherited.length > 0) { + logger.info( + "[agent %s] auto-inherited %d tool(s): %s", + agentName, + inherited.length, + inherited.join(", "), + ); + } + if (skippedByPlugin.size > 0) { + const summary = Array.from(skippedByPlugin.entries()) + .map(([p, tools]) => `${p}(${tools.length})`) + .join(", "); + logger.info( + "[agent %s] auto-inherit skipped %d tool(s) not marked autoInheritable: %s. Wire them explicitly via `tools:` if needed.", + agentName, + Array.from(skippedByPlugin.values()).reduce( + (n, list) => n + list.length, + 0, + ), + summary, + ); + } + } + + private async connectHostedTools( + hostedTools: import("./tools/hosted-tools").HostedTool[], + index: Map, + ): Promise { + let host: string | undefined; + let authenticate: () => Promise>; + + try { + const { getWorkspaceClient } = await import("../../context"); + const wsClient = getWorkspaceClient(); + await wsClient.config.ensureResolved(); + host = wsClient.config.host; + authenticate = async () => { + const headers = new Headers(); + await wsClient.config.authenticate(headers); + return Object.fromEntries(headers.entries()); + }; + } catch { + host = process.env.DATABRICKS_HOST; + authenticate = async (): Promise> => { + const token = process.env.DATABRICKS_TOKEN; + return token ? { Authorization: `Bearer ${token}` } : {}; + }; + } + + if (!host) { + logger.warn( + "No Databricks host available — skipping %d hosted tool(s)", + hostedTools.length, + ); + return; + } + + if (!this.mcpClient) { + const policy = buildMcpHostPolicy(this.config.mcp, host); + this.mcpClient = new AppKitMcpClient(host, authenticate, policy); + } + + const endpoints = resolveHostedTools(hostedTools); + await this.mcpClient.connectAll(endpoints); + + for (const def of this.mcpClient.getAllToolDefinitions()) { + index.set(def.name, { + source: "mcp", + mcpToolName: def.name, + def, + }); + } + } + + // ----------------- ToolProvider (no tools of our own) -------------------- + + getAgentTools(): AgentToolDefinition[] { + return []; + } + + async executeAgentTool(): Promise { + throw new Error("AgentsPlugin does not expose executeAgentTool directly"); + } + + // ----------------- Route mounting and handlers --------------------------- + + private mountInvocationsRoute() { + if (!this.context) return; + this.context.addRoute( + "post", + "/invocations", + (req: express.Request, res: express.Response) => { + this._handleInvocations(req, res); + }, + ); + } + + injectRoutes(router: IAppRouter) { + this.route(router, { + name: "chat", + method: "post", + path: "/chat", + handler: async (req, res) => this._handleChat(req, res), + }); + this.route(router, { + name: "cancel", + method: "post", + path: "/cancel", + handler: async (req, res) => this._handleCancel(req, res), + }); + this.route(router, { + name: "approve", + method: "post", + path: "/approve", + handler: async (req, res) => this._handleApprove(req, res), + }); + this.route(router, { + name: "threads", + method: "get", + path: "/threads", + handler: async (req, res) => this._handleListThreads(req, res), + }); + this.route(router, { + name: "thread", + method: "get", + path: "/threads/:threadId", + handler: async (req, res) => this._handleGetThread(req, res), + }); + this.route(router, { + name: "deleteThread", + method: "delete", + path: "/threads/:threadId", + handler: async (req, res) => this._handleDeleteThread(req, res), + }); + this.route(router, { + name: "info", + method: "get", + path: "/info", + handler: async (_req, res) => { + res.json({ + agents: Array.from(this.agents.keys()), + defaultAgent: this.defaultAgentName, + }); + }, + }); + } + + clientConfig(): Record { + return { + agents: Array.from(this.agents.keys()), + defaultAgent: this.defaultAgentName, + }; + } + + private async _handleChat(req: express.Request, res: express.Response) { + const parsed = chatRequestSchema.safeParse(req.body); + if (!parsed.success) { + res.status(400).json({ + error: "Invalid request", + details: parsed.error.flatten().fieldErrors, + }); + return; + } + const { message, threadId, agent: agentName } = parsed.data; + + const registered = this.resolveAgent(agentName); + if (!registered) { + res.status(400).json({ + error: agentName + ? `Agent "${agentName}" not found` + : "No agent registered", + }); + return; + } + + const userId = this.resolveUserId(req); + + // Reject early (before allocating a thread) when the user is already at + // their concurrent-stream limit. Prevents a misbehaving client from + // churning thread rows while being denied elsewhere. + const limits = this.resolvedLimits; + if (this.countUserStreams(userId) >= limits.maxConcurrentStreamsPerUser) { + res.setHeader("Retry-After", "5"); + res.status(429).json({ + error: `Too many concurrent streams for this user (limit ${limits.maxConcurrentStreamsPerUser}). Wait for an existing stream to complete before starting another.`, + }); + return; + } + + let thread = threadId ? await this.threadStore.get(threadId, userId) : null; + if (threadId && !thread) { + res.status(404).json({ error: `Thread ${threadId} not found` }); + return; + } + if (!thread) { + thread = await this.threadStore.create(userId); + } + + const userMessage: Message = { + id: randomUUID(), + role: "user", + content: message, + createdAt: new Date(), + }; + await this.threadStore.addMessage(thread.id, userId, userMessage); + return this._streamAgent(req, res, registered, thread, userId); + } + + private async _handleInvocations( + req: express.Request, + res: express.Response, + ) { + const parsed = invocationsRequestSchema.safeParse(req.body); + if (!parsed.success) { + res.status(400).json({ + error: "Invalid request", + details: parsed.error.flatten().fieldErrors, + }); + return; + } + const { input } = parsed.data; + const registered = this.resolveAgent(); + if (!registered) { + res.status(400).json({ error: "No agent registered" }); + return; + } + const userId = this.resolveUserId(req); + const thread = await this.threadStore.create(userId); + + if (typeof input === "string") { + await this.threadStore.addMessage(thread.id, userId, { + id: randomUUID(), + role: "user", + content: input, + createdAt: new Date(), + }); + } else { + for (const item of input) { + const role = (item.role ?? "user") as Message["role"]; + const content = + typeof item.content === "string" + ? item.content + : JSON.stringify(item.content ?? ""); + if (!content) continue; + await this.threadStore.addMessage(thread.id, userId, { + id: randomUUID(), + role, + content, + createdAt: new Date(), + }); + } + } + + return this._streamAgent(req, res, registered, thread, userId); + } + + private async _streamAgent( + req: express.Request, + res: express.Response, + registered: RegisteredAgent, + thread: Thread, + userId: string, + ): Promise { + const abortController = new AbortController(); + const signal = abortController.signal; + const requestId = randomUUID(); + this.activeStreams.set(requestId, { controller: abortController, userId }); + + const tools = Array.from(registered.toolIndex.values()).map((e) => e.def); + const approvalPolicy = this.resolvedApprovalPolicy; + const limits = this.resolvedLimits; + const outboundEvents = new EventChannel(); + const translator = new AgentEventTranslator(); + // Per-run tool-call budget (shared across the top-level adapter and any + // sub-agents it delegates to). Counted pre-dispatch so a prompt-injected + // agent cannot drain the budget silently via denied calls. + let toolCallsUsed = 0; + + const executeTool = async ( + name: string, + args: unknown, + ): Promise => { + if (toolCallsUsed >= limits.maxToolCalls) { + abortController.abort( + new Error( + `Tool-call budget exhausted (limit ${limits.maxToolCalls}).`, + ), + ); + throw new Error( + `Tool-call budget exhausted (limit ${limits.maxToolCalls}). Raise agents({ limits: { maxToolCalls } }) or review the agent's tool-selection logic.`, + ); + } + toolCallsUsed++; + + const entry = registered.toolIndex.get(name); + if (!entry) throw new Error(`Unknown tool: ${name}`); + + if ( + approvalPolicy.requireForDestructive && + entry.def.annotations?.destructive === true + ) { + const approvalId = randomUUID(); + for (const ev of translator.translate({ + type: "approval_pending", + approvalId, + streamId: requestId, + toolName: name, + args, + annotations: entry.def.annotations, + })) { + outboundEvents.push(ev); + } + const decision = await this.approvalGate.wait({ + approvalId, + streamId: requestId, + userId, + timeoutMs: approvalPolicy.timeoutMs, + }); + if (decision === "deny") { + return `Tool execution denied by user approval gate (tool: ${name}).`; + } + } + + let result: unknown; + if (entry.source === "toolkit") { + if (!this.context) { + throw new Error( + "Plugin tool execution requires PluginContext; this should never happen through createApp", + ); + } + result = await this.context.executeTool( + req, + entry.pluginName, + entry.localName, + args, + signal, + ); + } else if (entry.source === "function") { + result = await entry.functionTool.execute( + args as Record, + ); + } else if (entry.source === "mcp") { + if (!this.mcpClient) throw new Error("MCP client not connected"); + const oboToken = req.headers["x-forwarded-access-token"]; + const mcpAuth = + typeof oboToken === "string" + ? { Authorization: `Bearer ${oboToken}` } + : undefined; + result = await this.mcpClient.callTool( + entry.mcpToolName, + args, + mcpAuth, + ); + } else if (entry.source === "subagent") { + const childAgent = this.agents.get(entry.agentName); + if (!childAgent) + throw new Error(`Sub-agent not found: ${entry.agentName}`); + result = await this.runSubAgent(req, childAgent, args, signal, 1); + } + + // A `void` / `undefined` return is a legitimate tool outcome (e.g., a + // "send notification" side-effecting tool). Return an empty string so + // the LLM sees a successful-but-empty result rather than a bogus + // "execution failed" error. + if (result === undefined) { + return ""; + } + const MAX = 50_000; + const serialized = + typeof result === "string" ? result : JSON.stringify(result); + if (serialized.length > MAX) { + return `${serialized.slice(0, MAX)}\n\n[Result truncated: ${serialized.length} chars exceeds ${MAX} limit]`; + } + return result; + }; + + // Drive the adapter and the approval-event side-channel concurrently. + // Outbound events from both sources flow through `outboundEvents`; the + // generator below drains the channel in order. executeTool pushes + // approval-pending events into the same channel before awaiting the gate. + const driver = (async () => { + try { + for (const evt of translator.translate({ + type: "metadata", + data: { threadId: thread.id }, + })) { + outboundEvents.push(evt); + } + + const pluginNames = this.context + ? this.context + .getPluginNames() + .filter((n) => n !== this.name && n !== "server") + : []; + const fullPrompt = composePromptForAgent( + registered, + this.config.baseSystemPrompt, + { + agentName: registered.name, + pluginNames, + toolNames: tools.map((t) => t.name), + }, + ); + + const messagesWithSystem: Message[] = [ + { + id: "system", + role: "system", + content: fullPrompt, + createdAt: new Date(), + }, + ...thread.messages, + ]; + + const stream = registered.adapter.run( + { + messages: messagesWithSystem, + tools, + threadId: thread.id, + signal, + }, + { executeTool, signal }, + ); + + // Accumulate assistant output from BOTH streaming and non-streaming + // adapters. Delta-based adapters (Databricks, Vercel AI) emit + // `message_delta` chunks that we concatenate; adapters that yield a + // single final assistant message (e.g. LangChain's `on_chain_end` + // path) emit a `message` event whose content replaces whatever + // deltas already arrived. Without the `message` branch, multi-turn + // LangChain conversations silently dropped the assistant turn from + // thread history. + let fullContent = ""; + for await (const event of stream) { + if (signal.aborted) break; + if (event.type === "message_delta") { + fullContent += event.content; + } else if (event.type === "message") { + fullContent = event.content; + } + for (const translated of translator.translate(event)) { + outboundEvents.push(translated); + } + } + + if (fullContent) { + await this.threadStore.addMessage(thread.id, userId, { + id: randomUUID(), + role: "assistant", + content: fullContent, + createdAt: new Date(), + }); + } + + for (const evt of translator.finalize()) outboundEvents.push(evt); + } catch (error) { + if (signal.aborted) { + outboundEvents.close(); + return; + } + logger.error("Agent chat error: %O", error); + outboundEvents.close(error); + return; + } finally { + // Any pending approval gates for this stream are auto-denied so the + // adapter can unwind if it was still waiting. + this.approvalGate.abortStream(requestId); + this.activeStreams.delete(requestId); + // Stateless agents (e.g. autocomplete) don't persist history; drop + // the thread so `InMemoryThreadStore` doesn't accumulate one record + // per request. Swallow delete errors — the stream has already + // finished and the client has the response. + if (registered.ephemeral) { + try { + await this.threadStore.delete(thread.id, userId); + } catch (err) { + logger.warn( + "Failed to delete ephemeral thread %s: %O", + thread.id, + err, + ); + } + } + } + outboundEvents.close(); + })(); + + await this.executeStream( + res, + async function* () { + try { + for await (const ev of outboundEvents) { + yield ev; + } + } finally { + await driver.catch(() => undefined); + } + }, + { + ...agentStreamDefaults, + stream: { ...agentStreamDefaults.stream, streamId: requestId }, + }, + ); + } + + /** + * Runs a sub-agent in response to an `agent-` tool call. Returns the + * concatenated text output to hand back to the parent adapter as the tool + * result. + * + * `depth` starts at 1 for a top-level sub-agent invocation (i.e. the + * outer `_streamAgent` calls `runSubAgent(..., 1)`) and increments on + * each nested `runSubAgent` call. Depths exceeding + * `limits.maxSubAgentDepth` are rejected before any adapter work. + */ + private async runSubAgent( + req: express.Request, + child: RegisteredAgent, + args: unknown, + signal: AbortSignal, + depth: number, + ): Promise { + const limits = this.resolvedLimits; + if (depth > limits.maxSubAgentDepth) { + throw new Error( + `Sub-agent depth exceeded (limit ${limits.maxSubAgentDepth}). ` + + `Raise agents({ limits: { maxSubAgentDepth } }) or break the delegation cycle.`, + ); + } + + const input = + typeof args === "object" && + args !== null && + typeof (args as { input?: unknown }).input === "string" + ? (args as { input: string }).input + : JSON.stringify(args); + const childTools = Array.from(child.toolIndex.values()).map((e) => e.def); + + const childExecute = async ( + name: string, + childArgs: unknown, + ): Promise => { + const entry = child.toolIndex.get(name); + if (!entry) throw new Error(`Unknown tool in sub-agent: ${name}`); + if (entry.source === "toolkit" && this.context) { + return this.context.executeTool( + req, + entry.pluginName, + entry.localName, + childArgs, + signal, + ); + } + if (entry.source === "function") { + return entry.functionTool.execute(childArgs as Record); + } + if (entry.source === "subagent") { + const grandchild = this.agents.get(entry.agentName); + if (!grandchild) + throw new Error(`Sub-agent not found: ${entry.agentName}`); + return this.runSubAgent(req, grandchild, childArgs, signal, depth + 1); + } + if (entry.source === "mcp" && this.mcpClient) { + const oboToken = req.headers["x-forwarded-access-token"]; + const mcpAuth = + typeof oboToken === "string" + ? { Authorization: `Bearer ${oboToken}` } + : undefined; + return this.mcpClient.callTool(entry.mcpToolName, childArgs, mcpAuth); + } + throw new Error(`Unsupported sub-agent tool source: ${entry.source}`); + }; + + const runContext: AgentRunContext = { executeTool: childExecute, signal }; + + const pluginNames = this.context + ? this.context + .getPluginNames() + .filter((n) => n !== this.name && n !== "server") + : []; + const systemPrompt = composePromptForAgent( + child, + this.config.baseSystemPrompt, + { + agentName: child.name, + pluginNames, + toolNames: childTools.map((t) => t.name), + }, + ); + + const messages: Message[] = [ + { + id: "system", + role: "system", + content: systemPrompt, + createdAt: new Date(), + }, + { + id: randomUUID(), + role: "user", + content: input, + createdAt: new Date(), + }, + ]; + + let output = ""; + const events: AgentEvent[] = []; + for await (const event of child.adapter.run( + { messages, tools: childTools, threadId: randomUUID(), signal }, + runContext, + )) { + events.push(event); + if (event.type === "message_delta") output += event.content; + else if (event.type === "message") output = event.content; + } + return output; + } + + private async _handleCancel(req: express.Request, res: express.Response) { + const { streamId } = req.body as { streamId?: string }; + if (!streamId) { + res.status(400).json({ error: "streamId is required" }); + return; + } + const entry = this.activeStreams.get(streamId); + if (!entry) { + // Stream is unknown or already completed — idempotent no-op. + res.json({ cancelled: true }); + return; + } + const userId = this.resolveUserId(req); + if (entry.userId !== userId) { + res.status(403).json({ error: "Forbidden" }); + return; + } + entry.controller.abort("Cancelled by user"); + this.activeStreams.delete(streamId); + this.approvalGate.abortStream(streamId); + res.json({ cancelled: true }); + } + + private async _handleApprove(req: express.Request, res: express.Response) { + const parsed = approvalRequestSchema.safeParse(req.body); + if (!parsed.success) { + res.status(400).json({ + error: "Invalid request", + details: parsed.error.flatten().fieldErrors, + }); + return; + } + const { streamId, approvalId, decision } = parsed.data; + + const streamEntry = this.activeStreams.get(streamId); + if (!streamEntry) { + // Stream has already completed or never existed. Return 404 so the UI + // knows the approval token is no longer valid (the waiter, if any, has + // already been timed out or aborted). + res.status(404).json({ error: "Stream not found or already completed" }); + return; + } + + const userId = this.resolveUserId(req); + if (streamEntry.userId !== userId) { + res.status(403).json({ error: "Forbidden" }); + return; + } + + const result = this.approvalGate.submit({ approvalId, userId, decision }); + if (!result.ok) { + if (result.reason === "forbidden") { + res.status(403).json({ error: "Forbidden" }); + return; + } + res.status(404).json({ error: "Approval not found or already settled" }); + return; + } + + res.json({ decision }); + } + + private async _handleListThreads( + req: express.Request, + res: express.Response, + ) { + const userId = this.resolveUserId(req); + const threads = await this.threadStore.list(userId); + res.json({ threads }); + } + + private async _handleGetThread(req: express.Request, res: express.Response) { + const userId = this.resolveUserId(req); + const thread = await this.threadStore.get(req.params.threadId, userId); + if (!thread) { + res.status(404).json({ error: "Thread not found" }); + return; + } + res.json(thread); + } + + private async _handleDeleteThread( + req: express.Request, + res: express.Response, + ) { + const userId = this.resolveUserId(req); + const deleted = await this.threadStore.delete(req.params.threadId, userId); + if (!deleted) { + res.status(404).json({ error: "Thread not found" }); + return; + } + res.json({ deleted: true }); + } + + private resolveAgent(name?: string): RegisteredAgent | null { + if (name) return this.agents.get(name) ?? null; + if (this.defaultAgentName) { + return this.agents.get(this.defaultAgentName) ?? null; + } + const first = this.agents.values().next(); + return first.done ? null : first.value; + } + + private printRegistry(): void { + if (this.agents.size === 0) return; + console.log(""); + console.log(` ${pc.bold("Agents")} ${pc.dim(`(${this.agents.size})`)}`); + console.log(` ${pc.dim("─".repeat(60))}`); + for (const [name, reg] of this.agents) { + const tools = reg.toolIndex.size; + const marker = name === this.defaultAgentName ? pc.green("●") : " "; + console.log( + ` ${marker} ${pc.bold(name.padEnd(24))} ${pc.dim(`${tools} tools`)}`, + ); + } + console.log(` ${pc.dim("─".repeat(60))}`); + console.log(""); + } + + async shutdown(): Promise { + this.approvalGate.abortAll(); + if (this.mcpClient) { + await this.mcpClient.close(); + this.mcpClient = null; + } + } + + exports() { + return { + register: (name: string, def: AgentDefinition) => + this.registerCodeAgent(name, def), + list: () => Array.from(this.agents.keys()), + get: (name: string) => this.agents.get(name) ?? null, + reload: () => this.reload(), + getDefault: () => this.defaultAgentName, + getThreads: (userId: string) => this.threadStore.list(userId), + }; + } + + private async registerCodeAgent( + name: string, + def: AgentDefinition, + ): Promise { + const registered = await this.buildRegisteredAgent(name, def, { + origin: "code", + }); + this.agents.set(name, registered); + if (!this.defaultAgentName) this.defaultAgentName = name; + } +} + +function normalizeAutoInherit(value: AgentsPluginConfig["autoInheritTools"]): { + file: boolean; + code: boolean; +} { + // Default is opt-out for both origins. A markdown agent or code-defined + // agent with no declared `tools:` gets an empty tool index unless the + // developer explicitly flips `autoInheritTools` on. Even then, only tools + // whose plugin author marked `autoInheritable: true` are spread — see + // `applyAutoInherit` for the filter. + if (value === undefined) return { file: false, code: false }; + if (typeof value === "boolean") return { file: value, code: value }; + return { file: value.file ?? false, code: value.code ?? false }; +} + +function composePromptForAgent( + registered: RegisteredAgent, + pluginLevel: BaseSystemPromptOption | undefined, + ctx: PromptContext, +): string { + const perAgent = registered.baseSystemPrompt; + const resolved = perAgent !== undefined ? perAgent : pluginLevel; + + let base = ""; + if (resolved === false) { + base = ""; + } else if (typeof resolved === "string") { + base = resolved; + } else if (typeof resolved === "function") { + base = resolved(ctx); + } else { + base = buildBaseSystemPrompt(ctx.pluginNames); + } + + return composeSystemPrompt(base, registered.instructions); +} + +/** + * Plugin factory for the agents plugin. Reads `config/agents/*.md` by default, + * resolves toolkits/tools from registered plugins, exposes `appkit.agents.*` + * runtime API and mounts `/invocations`. + * + * @example + * ```ts + * import { agents, analytics, createApp, server } from "@databricks/appkit"; + * + * await createApp({ + * plugins: [server(), analytics(), agents()], + * }); + * ``` + */ +export const agents = toPlugin(AgentsPlugin); diff --git a/packages/appkit/src/plugins/agents/defaults.ts b/packages/appkit/src/plugins/agents/defaults.ts new file mode 100644 index 00000000..4da11bef --- /dev/null +++ b/packages/appkit/src/plugins/agents/defaults.ts @@ -0,0 +1,12 @@ +import type { StreamExecutionSettings } from "shared"; + +export const agentStreamDefaults: StreamExecutionSettings = { + default: { + cache: { enabled: false }, + retry: { enabled: false }, + timeout: 300_000, + }, + stream: { + bufferSize: 200, + }, +}; diff --git a/packages/appkit/src/plugins/agents/event-channel.ts b/packages/appkit/src/plugins/agents/event-channel.ts new file mode 100644 index 00000000..c5b60463 --- /dev/null +++ b/packages/appkit/src/plugins/agents/event-channel.ts @@ -0,0 +1,70 @@ +/** + * Single-producer/single-consumer async queue used by the agents plugin to + * merge streams of SSE events from two concurrent sources: the adapter's + * `run()` generator, and out-of-band events emitted by `executeTool` (e.g. + * human-approval requests). + * + * The consumer drains the channel as an async iterable; the producer pushes + * events synchronously and closes the channel when the source has completed + * or errored. + */ +interface Waiter { + resolve: (value: IteratorResult) => void; + reject: (error: unknown) => void; +} + +export class EventChannel { + private queue: T[] = []; + private waiters: Array> = []; + private closed = false; + private error: unknown = undefined; + + /** Synchronously enqueue an event. Safe to call from non-async contexts. */ + push(value: T): void { + if (this.closed) return; + const waiter = this.waiters.shift(); + if (waiter) { + waiter.resolve({ value, done: false }); + } else { + this.queue.push(value); + } + } + + /** + * Close the channel. Any pending `next()` calls resolve with `done: true`. + * If `error` is supplied, pending `next()` calls reject with it and future + * calls do the same. + */ + close(error?: unknown): void { + if (this.closed) return; + this.closed = true; + this.error = error; + while (this.waiters.length > 0) { + const waiter = this.waiters.shift(); + if (!waiter) break; + if (error) { + waiter.reject(error); + } else { + waiter.resolve({ value: undefined as never, done: true }); + } + } + } + + [Symbol.asyncIterator](): AsyncIterator { + return { + next: (): Promise> => { + if (this.queue.length > 0) { + const value = this.queue.shift() as T; + return Promise.resolve({ value, done: false }); + } + if (this.closed) { + if (this.error) return Promise.reject(this.error); + return Promise.resolve({ value: undefined as never, done: true }); + } + return new Promise((resolve, reject) => { + this.waiters.push({ resolve, reject }); + }); + }, + }; + } +} diff --git a/packages/appkit/src/plugins/agents/event-translator.ts b/packages/appkit/src/plugins/agents/event-translator.ts new file mode 100644 index 00000000..54d749fb --- /dev/null +++ b/packages/appkit/src/plugins/agents/event-translator.ts @@ -0,0 +1,291 @@ +import { randomUUID } from "node:crypto"; +import type { + AgentEvent, + ResponseFunctionCallOutput, + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseStreamEvent, +} from "shared"; + +/** + * Translates internal `AgentEvent` stream into Responses API SSE events. + * + * Stateful: one instance per streaming request. Tracks sequence numbers and + * allocates `output_index` strictly monotonically — each emitted output + * item (message, function call, function call output) claims the next + * available index and, once claimed, never reuses an earlier one. This is a + * Responses-API contract that OpenAI's own SDK parsers enforce. + * + * A message is opened lazily on the first `message_delta` or `message` + * event. If a `tool_call` or `tool_result` arrives while a message is open + * (common ReAct flow: partial text → tool call → more text), the open + * message is closed (`response.output_item.done`) BEFORE the tool item is + * added, so subsequent text resumes as a new message item at a strictly + * later index. + */ +export class AgentEventTranslator { + private seqNum = 0; + private nextOutputIndex = 0; + private currentMessage: { + id: string; + text: string; + outputIndex: number; + } | null = null; + private finalized = false; + + translate(event: AgentEvent): ResponseStreamEvent[] { + switch (event.type) { + case "message_delta": + return this.handleMessageDelta(event.content); + case "message": + return this.handleFullMessage(event.content); + case "tool_call": + return this.handleToolCall(event.callId, event.name, event.args); + case "tool_result": + return this.handleToolResult(event.callId, event.result, event.error); + case "thinking": + return [ + { + type: "appkit.thinking", + content: event.content, + sequence_number: this.seqNum++, + }, + ]; + case "metadata": + return [ + { + type: "appkit.metadata", + data: event.data, + sequence_number: this.seqNum++, + }, + ]; + case "approval_pending": + return [ + { + type: "appkit.approval_pending", + approval_id: event.approvalId, + stream_id: event.streamId, + tool_name: event.toolName, + args: event.args, + annotations: event.annotations, + sequence_number: this.seqNum++, + }, + ]; + case "status": + return this.handleStatus(event.status, event.error); + } + } + + finalize(): ResponseStreamEvent[] { + if (this.finalized) return []; + this.finalized = true; + + const events: ResponseStreamEvent[] = []; + const closeEvent = this.closeCurrentMessage(); + if (closeEvent) events.push(closeEvent); + + events.push({ + type: "response.completed", + sequence_number: this.seqNum++, + response: {}, + }); + + return events; + } + + private handleMessageDelta(content: string): ResponseStreamEvent[] { + const events: ResponseStreamEvent[] = []; + + if (!this.currentMessage) { + const id = `msg_${randomUUID()}`; + const outputIndex = this.nextOutputIndex++; + this.currentMessage = { id, text: content, outputIndex }; + const item: ResponseOutputMessage = { + type: "message", + id, + status: "in_progress", + role: "assistant", + content: [], + }; + events.push({ + type: "response.output_item.added", + output_index: outputIndex, + item, + sequence_number: this.seqNum++, + }); + } else { + this.currentMessage.text += content; + } + + events.push({ + type: "response.output_text.delta", + item_id: this.currentMessage.id, + output_index: this.currentMessage.outputIndex, + content_index: 0, + delta: content, + sequence_number: this.seqNum++, + }); + + return events; + } + + private handleFullMessage(content: string): ResponseStreamEvent[] { + const events: ResponseStreamEvent[] = []; + + if (!this.currentMessage) { + // No prior deltas — open and immediately close. + const id = `msg_${randomUUID()}`; + const outputIndex = this.nextOutputIndex++; + this.currentMessage = { id, text: content, outputIndex }; + const addedItem: ResponseOutputMessage = { + type: "message", + id, + status: "in_progress", + role: "assistant", + content: [], + }; + events.push({ + type: "response.output_item.added", + output_index: outputIndex, + item: addedItem, + sequence_number: this.seqNum++, + }); + } else { + // Deltas already opened the item; `message` overrides the accumulated + // text (per adapter contract) and closes it. + this.currentMessage.text = content; + } + + const closeEvent = this.closeCurrentMessage(); + if (closeEvent) events.push(closeEvent); + return events; + } + + private handleToolCall( + callId: string, + name: string, + args: unknown, + ): ResponseStreamEvent[] { + const events: ResponseStreamEvent[] = []; + const closeEvent = this.closeCurrentMessage(); + if (closeEvent) events.push(closeEvent); + + const outputIndex = this.nextOutputIndex++; + const item: ResponseFunctionToolCall = { + type: "function_call", + id: `fc_${randomUUID()}`, + call_id: callId, + name, + arguments: typeof args === "string" ? args : JSON.stringify(args), + }; + + events.push( + { + type: "response.output_item.added", + output_index: outputIndex, + item, + sequence_number: this.seqNum++, + }, + { + type: "response.output_item.done", + output_index: outputIndex, + item, + sequence_number: this.seqNum++, + }, + ); + return events; + } + + private handleToolResult( + callId: string, + result: unknown, + error?: string, + ): ResponseStreamEvent[] { + const events: ResponseStreamEvent[] = []; + const closeEvent = this.closeCurrentMessage(); + if (closeEvent) events.push(closeEvent); + + const outputIndex = this.nextOutputIndex++; + // Coalesce `undefined` → "" so the wire shape is always a string (the + // Responses API contract). Non-string results are JSON-serialised. + let output: string; + if (error !== undefined) { + output = error; + } else if (typeof result === "string") { + output = result; + } else if (result === undefined) { + output = ""; + } else { + output = JSON.stringify(result); + } + const item: ResponseFunctionCallOutput = { + type: "function_call_output", + id: `fc_output_${randomUUID()}`, + call_id: callId, + output, + }; + + events.push( + { + type: "response.output_item.added", + output_index: outputIndex, + item, + sequence_number: this.seqNum++, + }, + { + type: "response.output_item.done", + output_index: outputIndex, + item, + sequence_number: this.seqNum++, + }, + ); + return events; + } + + /** + * Emit an `response.output_item.done` for the currently-open message, if + * any, and clear the state. Returns the event to the caller so it can be + * pushed at the right moment in the sequence. Returns `null` when there + * is no open message. + */ + private closeCurrentMessage(): ResponseStreamEvent | null { + if (!this.currentMessage) return null; + const { id, text, outputIndex } = this.currentMessage; + this.currentMessage = null; + const doneItem: ResponseOutputMessage = { + type: "message", + id, + status: "completed", + role: "assistant", + content: [{ type: "output_text", text }], + }; + return { + type: "response.output_item.done", + output_index: outputIndex, + item: doneItem, + sequence_number: this.seqNum++, + }; + } + + private handleStatus(status: string, error?: string): ResponseStreamEvent[] { + if (status === "error") { + return [ + { + type: "error", + error: error ?? "Unknown error", + sequence_number: this.seqNum++, + }, + { + type: "response.failed", + sequence_number: this.seqNum++, + }, + ]; + } + + if (status === "complete") { + return this.finalize(); + } + + return []; + } +} diff --git a/packages/appkit/src/plugins/agents/index.ts b/packages/appkit/src/plugins/agents/index.ts new file mode 100644 index 00000000..1adc41c1 --- /dev/null +++ b/packages/appkit/src/plugins/agents/index.ts @@ -0,0 +1,22 @@ +export { AgentsPlugin, agents } from "./agents"; +export { buildToolkitEntries } from "./build-toolkit"; +export { + type LoadContext, + type LoadResult, + loadAgentFromFile, + loadAgentsFromDir, + parseFrontmatter, +} from "./load-agents"; +export { + type AgentDefinition, + type AgentsPluginConfig, + type AgentTool, + type AutoInheritToolsConfig, + type BaseSystemPromptOption, + isToolkitEntry, + type PromptContext, + type RegisteredAgent, + type ResolvedToolEntry, + type ToolkitEntry, + type ToolkitOptions, +} from "./types"; diff --git a/packages/appkit/src/plugins/agents/load-agents.ts b/packages/appkit/src/plugins/agents/load-agents.ts new file mode 100644 index 00000000..2d82799e --- /dev/null +++ b/packages/appkit/src/plugins/agents/load-agents.ts @@ -0,0 +1,370 @@ +import fs from "node:fs"; +import path from "node:path"; +import yaml from "js-yaml"; +import type { AgentAdapter } from "shared"; +import { createLogger } from "../../logging/logger"; +import type { + AgentDefinition, + AgentTool, + BaseSystemPromptOption, + ToolkitEntry, + ToolkitOptions, +} from "./types"; +import { isToolkitEntry } from "./types"; + +const logger = createLogger("agents:loader"); + +interface ToolkitProvider { + toolkit: (opts?: ToolkitOptions) => Record; +} + +export interface LoadContext { + /** Default model when frontmatter has no `endpoint` and the def has no `model`. */ + defaultModel?: AgentAdapter | Promise | string; + /** Ambient tool library referenced by frontmatter `tools: [key1, key2]`. */ + availableTools?: Record; + /** Registered plugin toolkits referenced by frontmatter `toolkits: [...]`. */ + plugins?: Map; + /** + * Code-defined agents contributed by `agents({ agents: { ... } })`. The + * directory loader resolves `agents:` frontmatter references against + * these alongside sibling markdown files, so a markdown parent can + * delegate to a code-defined child. Code-defined names win on collision + * with markdown names, matching the plugin's top-level merge precedence. + */ + codeAgents?: Record; +} + +export interface LoadResult { + /** Agent definitions keyed by file-stem name. */ + defs: Record; + /** First file with `default: true` frontmatter, or `null`. */ + defaultAgent: string | null; +} + +interface Frontmatter { + endpoint?: string; + model?: string; + toolkits?: ToolkitSpec[]; + tools?: string[]; + /** + * Sibling file-stems to expose as sub-agents. Each becomes an + * `agent-` tool on this agent at runtime. Resolution happens at + * directory-load time in {@link loadAgentsFromDir}; the single-file + * {@link loadAgentFromFile} path rejects non-empty values since there + * are no siblings to resolve against. + */ + agents?: string[]; + maxSteps?: number; + maxTokens?: number; + default?: boolean; + baseSystemPrompt?: false | string; + ephemeral?: boolean; +} + +type ToolkitSpec = string | { [pluginName: string]: ToolkitOptions | string[] }; + +const ALLOWED_KEYS = new Set([ + "endpoint", + "model", + "toolkits", + "tools", + "agents", + "maxSteps", + "maxTokens", + "default", + "baseSystemPrompt", + "ephemeral", +]); + +/** + * Loads a single markdown agent file and resolves its frontmatter against + * registered plugin toolkits + ambient tool library. + * + * Rejects non-empty `agents:` frontmatter because single-file loads have + * no siblings to resolve sub-agent references against — callers must use + * {@link loadAgentsFromDir} when markdown agents delegate to one another. + */ +export async function loadAgentFromFile( + filePath: string, + ctx: LoadContext, +): Promise { + const raw = fs.readFileSync(filePath, "utf-8"); + const name = path.basename(filePath, ".md"); + const { data } = parseFrontmatter(raw, filePath); + if (Array.isArray(data?.agents) && data.agents.length > 0) { + throw new Error( + `Agent '${name}' (${filePath}) declares 'agents:' in frontmatter, ` + + `which requires loadAgentsFromDir to resolve sibling references. ` + + `Use loadAgentsFromDir, or wire sub-agents in code via createAgent({ agents: { ... } }).`, + ); + } + return buildDefinition(name, raw, filePath, ctx); +} + +/** + * Scans a directory for `*.md` files and produces an `AgentDefinition` record + * keyed by file-stem. Throws on frontmatter errors or unresolved references. + * Returns an empty map if the directory does not exist. + * + * Runs in two passes so sub-agent references in frontmatter (`agents: [...]`) + * can be resolved regardless of file-system iteration order: + * + * 1. Build every agent's definition from its own file. + * 2. Walk `agents:` references and wire `def.agents = { sibling: siblingDef }` + * by looking them up in the complete map. Dangling names and + * self-references fail loudly; mutual delegation is allowed and bounded + * at runtime by `limits.maxSubAgentDepth`. + */ +export async function loadAgentsFromDir( + dir: string, + ctx: LoadContext, +): Promise { + if (!fs.existsSync(dir)) { + return { defs: {}, defaultAgent: null }; + } + // Sort so `default: true` resolution is deterministic across platforms — + // `readdirSync` order is filesystem-dependent (macOS alphabetical, ext4 + // inode order, etc.). + const files = fs + .readdirSync(dir) + .filter((f) => f.endsWith(".md")) + .sort(); + const defs: Record = {}; + const subAgentRefs: Record = {}; + let defaultAgent: string | null = null; + + // Pass 1: build every agent's definition; collect unresolved sibling refs. + for (const file of files) { + const fullPath = path.join(dir, file); + const raw = fs.readFileSync(fullPath, "utf-8"); + const name = path.basename(file, ".md"); + defs[name] = buildDefinition(name, raw, fullPath, ctx); + const { data } = parseFrontmatter(raw, fullPath); + if (data?.agents !== undefined) { + subAgentRefs[name] = normalizeAgentsFrontmatter( + data.agents, + name, + fullPath, + ); + } + if (data?.default === true && !defaultAgent) { + defaultAgent = name; + } + } + + // Pass 2: resolve sibling references against the complete defs map. + // Code-defined agents (ctx.codeAgents) take precedence over markdown ones + // with the same name, matching the plugin's top-level merge behaviour. + for (const [name, refs] of Object.entries(subAgentRefs)) { + if (refs.length === 0) continue; + const children: Record = {}; + const missing: string[] = []; + for (const ref of refs) { + if (ref === name) { + throw new Error( + `Agent '${name}' (${path.join(dir, `${name}.md`)}) cannot reference itself in 'agents:'.`, + ); + } + const sibling = ctx.codeAgents?.[ref] ?? defs[ref]; + if (!sibling) { + missing.push(ref); + continue; + } + children[ref] = sibling; + } + if (missing.length > 0) { + const available = + [...Object.keys(ctx.codeAgents ?? {}), ...Object.keys(defs)] + .sort() + .join(", ") || ""; + throw new Error( + `Agent '${name}' references sub-agent(s) '${missing.join(", ")}' in 'agents:', ` + + `but no markdown or code agent(s) with those names exist. ` + + `Available: ${available}.`, + ); + } + defs[name].agents = children; + } + + return { defs, defaultAgent }; +} + +/** + * Validates that `agents:` frontmatter is an array of non-empty strings and + * returns it with duplicates removed. Throws with a clear per-file message + * on malformed input rather than silently ignoring. + */ +function normalizeAgentsFrontmatter( + value: unknown, + agentName: string, + filePath: string, +): string[] { + if (!Array.isArray(value)) { + throw new Error( + `Agent '${agentName}' (${filePath}) has invalid 'agents:' frontmatter: ` + + `expected an array of sibling file-stems, got ${typeof value}.`, + ); + } + const out: string[] = []; + const seen = new Set(); + for (const item of value) { + if (typeof item !== "string" || item.trim() === "") { + throw new Error( + `Agent '${agentName}' (${filePath}) has invalid 'agents:' entry: ` + + `expected non-empty string, got ${JSON.stringify(item)}.`, + ); + } + if (seen.has(item)) continue; + seen.add(item); + out.push(item); + } + return out; +} + +/** Exposed for tests. Parses `--- yaml ---\nbody` and validates frontmatter keys. */ +export function parseFrontmatter( + raw: string, + sourcePath?: string, +): { data: Frontmatter | null; content: string } { + const match = raw.match(/^---\r?\n([\s\S]*?)\r?\n---\r?\n?([\s\S]*)$/); + if (!match) { + return { data: null, content: raw.trim() }; + } + let parsed: unknown; + try { + parsed = yaml.load(match[1]); + } catch (err) { + const src = sourcePath ? ` (${sourcePath})` : ""; + throw new Error( + `Invalid YAML frontmatter${src}: ${err instanceof Error ? err.message : String(err)}`, + ); + } + if (parsed === null || parsed === undefined) { + return { data: {}, content: match[2].trim() }; + } + if (typeof parsed !== "object" || Array.isArray(parsed)) { + const src = sourcePath ? ` (${sourcePath})` : ""; + throw new Error(`Frontmatter must be a YAML object${src}`); + } + const data = parsed as Record; + for (const key of Object.keys(data)) { + if (!ALLOWED_KEYS.has(key)) { + logger.warn( + "Ignoring unknown frontmatter key '%s' in %s", + key, + sourcePath ?? "", + ); + } + } + return { data: data as Frontmatter, content: match[2].trim() }; +} + +function buildDefinition( + name: string, + raw: string, + filePath: string, + ctx: LoadContext, +): AgentDefinition { + const { data, content } = parseFrontmatter(raw, filePath); + const fm: Frontmatter = data ?? {}; + + const tools = resolveFrontmatterTools(name, fm, filePath, ctx); + const model = fm.model ?? fm.endpoint ?? ctx.defaultModel; + + let baseSystemPrompt: BaseSystemPromptOption | undefined; + if (fm.baseSystemPrompt === false) baseSystemPrompt = false; + else if (typeof fm.baseSystemPrompt === "string") + baseSystemPrompt = fm.baseSystemPrompt; + + return { + name, + instructions: content, + model, + tools: Object.keys(tools).length > 0 ? tools : undefined, + maxSteps: typeof fm.maxSteps === "number" ? fm.maxSteps : undefined, + maxTokens: typeof fm.maxTokens === "number" ? fm.maxTokens : undefined, + baseSystemPrompt, + ephemeral: typeof fm.ephemeral === "boolean" ? fm.ephemeral : undefined, + }; +} + +function resolveFrontmatterTools( + agentName: string, + fm: Frontmatter, + filePath: string, + ctx: LoadContext, +): Record { + const out: Record = {}; + const pluginIdx = ctx.plugins ?? new Map(); + + for (const spec of fm.toolkits ?? []) { + const [pluginName, opts] = parseToolkitSpec(spec, filePath, agentName); + const provider = pluginIdx.get(pluginName); + if (!provider) { + throw new Error( + `Agent '${agentName}' (${filePath}) references toolkit '${pluginName}', but plugin '${pluginName}' is not registered. Available: ${ + pluginIdx.size > 0 + ? Array.from(pluginIdx.keys()).join(", ") + : "" + }`, + ); + } + const entries = provider.toolkit(opts) as Record; + for (const [key, entry] of Object.entries(entries)) { + if (!isToolkitEntry(entry)) { + throw new Error( + `Plugin '${pluginName}'.toolkit() returned a value at key '${key}' that is not a ToolkitEntry`, + ); + } + out[key] = entry as ToolkitEntry; + } + } + + for (const key of fm.tools ?? []) { + const tool = ctx.availableTools?.[key]; + if (!tool) { + const available = ctx.availableTools + ? Object.keys(ctx.availableTools).join(", ") + : ""; + throw new Error( + `Agent '${agentName}' (${filePath}) references tool '${key}', which is not in the agents() plugin's tools field. Available: ${available}`, + ); + } + out[key] = tool; + } + + return out; +} + +function parseToolkitSpec( + spec: ToolkitSpec, + filePath: string, + agentName: string, +): [string, ToolkitOptions | undefined] { + if (typeof spec === "string") { + return [spec, undefined]; + } + if (typeof spec !== "object" || spec === null) { + throw new Error( + `Agent '${agentName}' (${filePath}) has invalid toolkit entry: ${JSON.stringify(spec)}`, + ); + } + const keys = Object.keys(spec); + if (keys.length !== 1) { + throw new Error( + `Agent '${agentName}' (${filePath}) toolkit entry must have exactly one key, got: ${keys.join(", ")}`, + ); + } + const pluginName = keys[0]; + const value = spec[pluginName]; + if (Array.isArray(value)) { + return [pluginName, { only: value }]; + } + if (typeof value === "object" && value !== null) { + return [pluginName, value as ToolkitOptions]; + } + throw new Error( + `Agent '${agentName}' (${filePath}) toolkit '${pluginName}' options must be an array of tool names or an options object`, + ); +} diff --git a/packages/appkit/src/plugins/agents/manifest.json b/packages/appkit/src/plugins/agents/manifest.json new file mode 100644 index 00000000..cb7a43f8 --- /dev/null +++ b/packages/appkit/src/plugins/agents/manifest.json @@ -0,0 +1,10 @@ +{ + "$schema": "https://databricks.github.io/appkit/schemas/plugin-manifest.schema.json", + "name": "agents", + "displayName": "Agents Plugin", + "description": "AI agents driven by markdown configs or code, with auto-tool-discovery from registered plugins", + "resources": { + "required": [], + "optional": [] + } +} diff --git a/packages/appkit/src/plugins/agents/schemas.ts b/packages/appkit/src/plugins/agents/schemas.ts new file mode 100644 index 00000000..cea6c6d6 --- /dev/null +++ b/packages/appkit/src/plugins/agents/schemas.ts @@ -0,0 +1,69 @@ +import { z } from "zod"; + +/** + * Static body cap for the `message` field on `POST /chat`. 64 000 characters + * is well above any legitimate chat turn (~16k tokens at 4 chars/token) and + * bounds the per-request cost of appending to `InMemoryThreadStore` without + * requiring per-deployment configuration. + */ +const MAX_MESSAGE_CHARS = 64_000; + +/** Cap applied to `/invocations` when `input` is a raw string. */ +const MAX_INVOCATIONS_INPUT_CHARS = 64_000; + +/** + * Cap on the number of items accepted in an `/invocations` `input` array + * (one element per seeded message). Protects against a single request + * seeding hundreds of messages into the thread store. + */ +const MAX_INVOCATIONS_INPUT_ITEMS = 100; + +/** Per-message `content` size cap (string form). */ +const MAX_INVOCATIONS_ITEM_CHARS = 64_000; + +/** Per-message `content` size cap (array form). */ +const MAX_INVOCATIONS_ITEM_ARRAY_ITEMS = 100; + +export const chatRequestSchema = z.object({ + message: z + .string() + .min(1, "message must not be empty") + .max( + MAX_MESSAGE_CHARS, + `message exceeds the ${MAX_MESSAGE_CHARS}-character limit`, + ), + threadId: z.string().optional(), + agent: z.string().optional(), +}); + +const messageItemSchema = z.object({ + role: z.enum(["user", "assistant", "system"]).optional(), + content: z + .union([ + z.string().max(MAX_INVOCATIONS_ITEM_CHARS), + z.array(z.any()).max(MAX_INVOCATIONS_ITEM_ARRAY_ITEMS), + ]) + .optional(), + type: z.string().optional(), +}); + +export const invocationsRequestSchema = z.object({ + input: z.union([ + z.string().min(1).max(MAX_INVOCATIONS_INPUT_CHARS), + z + .array(messageItemSchema) + .min(1) + .max( + MAX_INVOCATIONS_INPUT_ITEMS, + `input array exceeds the ${MAX_INVOCATIONS_INPUT_ITEMS}-item limit`, + ), + ]), + stream: z.boolean().optional().default(true), + model: z.string().optional(), +}); + +export const approvalRequestSchema = z.object({ + streamId: z.string().min(1, "streamId is required"), + approvalId: z.string().min(1, "approvalId is required"), + decision: z.enum(["approve", "deny"]), +}); diff --git a/packages/appkit/src/plugins/agents/system-prompt.ts b/packages/appkit/src/plugins/agents/system-prompt.ts new file mode 100644 index 00000000..634f49c5 --- /dev/null +++ b/packages/appkit/src/plugins/agents/system-prompt.ts @@ -0,0 +1,40 @@ +/** + * Builds the AppKit base system prompt from active plugin names. + * + * The base prompt provides guidelines and app context. It does NOT + * include individual tool descriptions — those are sent via the + * structured `tools` API parameter to the LLM. + */ +export function buildBaseSystemPrompt(pluginNames: string[]): string { + const lines: string[] = [ + "You are an AI assistant running on Databricks AppKit.", + ]; + + if (pluginNames.length > 0) { + lines.push(""); + lines.push(`Active plugins: ${pluginNames.join(", ")}`); + } + + lines.push(""); + lines.push("Guidelines:"); + lines.push("- Use Databricks SQL syntax when writing queries"); + lines.push( + "- When results are large, summarize key findings rather than dumping raw data", + ); + lines.push("- If a tool call fails, explain the error clearly to the user"); + lines.push("- When browsing files, verify the path exists before reading"); + + return lines.join("\n"); +} + +/** + * Compose the full system prompt from the base prompt and an optional + * per-agent user prompt. + */ +export function composeSystemPrompt( + basePrompt: string, + agentPrompt?: string, +): string { + if (!agentPrompt) return basePrompt; + return `${basePrompt}\n\n${agentPrompt}`; +} diff --git a/packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts b/packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts new file mode 100644 index 00000000..d9abb925 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts @@ -0,0 +1,373 @@ +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import type { + AgentAdapter, + AgentInput, + AgentRunContext, + AgentToolDefinition, + ToolProvider, +} from "shared"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { z } from "zod"; +import { CacheManager } from "../../../cache"; +// Import the class directly so we can construct it without a createApp +import { AgentsPlugin } from "../agents"; +import { buildToolkitEntries } from "../build-toolkit"; +import { defineTool, type ToolRegistry } from "../tools/define-tool"; +import type { AgentsPluginConfig, ToolkitEntry } from "../types"; +import { isToolkitEntry } from "../types"; + +interface FakeContext { + providers: Array<{ name: string; provider: ToolProvider }>; + getToolProviders(): Array<{ name: string; provider: ToolProvider }>; + getPluginNames(): string[]; + addRoute(): void; + executeTool: ( + req: unknown, + pluginName: string, + localName: string, + args: unknown, + ) => Promise; +} + +function fakeContext( + providers: Array<{ name: string; provider: ToolProvider }>, +): FakeContext { + return { + providers, + getToolProviders: () => providers, + getPluginNames: () => providers.map((p) => p.name), + addRoute: vi.fn(), + executeTool: vi.fn(async (_req, p, n, args) => ({ + plugin: p, + tool: n, + args, + })), + }; +} + +function stubAdapter(): AgentAdapter { + return { + async *run(_input: AgentInput, _ctx: AgentRunContext) { + yield { type: "message_delta", content: "" }; + }, + }; +} + +function makeToolProvider( + pluginName: string, + registry: ToolRegistry, +): ToolProvider & { + toolkit: (opts?: unknown) => Record; +} { + return { + getAgentTools(): AgentToolDefinition[] { + return Object.entries(registry).map(([name, entry]) => ({ + name, + description: entry.description, + parameters: { type: "object", properties: {} }, + })); + }, + async executeAgentTool(name, args) { + return { callFrom: pluginName, name, args }; + }, + toolkit: (opts) => buildToolkitEntries(pluginName, registry, opts as never), + }; +} + +let tmpDir: string; + +beforeEach(async () => { + tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "agents-plugin-")); + const storage = { + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + keys: vi.fn(), + healthCheck: vi.fn(async () => true), + close: vi.fn(async () => {}), + }; + // biome-ignore lint/suspicious/noExplicitAny: test-only CacheManager wiring + await CacheManager.getInstance({ storage: storage as any }); +}); + +afterEach(() => { + fs.rmSync(tmpDir, { recursive: true, force: true }); +}); + +function instantiate(config: AgentsPluginConfig, ctx?: FakeContext) { + const plugin = new AgentsPlugin({ ...config, name: "agent" }); + plugin.attachContext({ context: ctx as unknown as object }); + return plugin; +} + +describe("AgentsPlugin", () => { + test("registers code-defined agents and exposes them via exports", async () => { + const plugin = instantiate({ + dir: false, + agents: { + support: { + instructions: "You help customers.", + model: stubAdapter(), + }, + }, + }); + await plugin.setup(); + + const api = plugin.exports() as { + list: () => string[]; + getDefault: () => string | null; + }; + expect(api.list()).toEqual(["support"]); + expect(api.getDefault()).toBe("support"); + }); + + test("loads markdown agents from a directory", async () => { + fs.writeFileSync( + path.join(tmpDir, "assistant.md"), + "---\ndefault: true\n---\nYou are helpful.", + "utf-8", + ); + const plugin = instantiate({ + dir: tmpDir, + defaultModel: stubAdapter(), + }); + await plugin.setup(); + + const api = plugin.exports() as { + list: () => string[]; + getDefault: () => string | null; + }; + expect(api.list()).toEqual(["assistant"]); + expect(api.getDefault()).toBe("assistant"); + }); + + test("code definitions override markdown on key collision", async () => { + fs.writeFileSync( + path.join(tmpDir, "support.md"), + "---\n---\nFrom markdown.", + "utf-8", + ); + const plugin = instantiate({ + dir: tmpDir, + defaultModel: stubAdapter(), + agents: { + support: { + instructions: "From code", + model: stubAdapter(), + }, + }, + }); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { instructions: string } | null; + }; + expect(api.get("support")?.instructions).toBe("From code"); + }); + + test("auto-inherit default is safe (both file and code get nothing without an explicit opt-in)", async () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "q", + schema: z.object({ sql: z.string() }), + autoInheritable: true, // even with autoInheritable, no spread without opt-in + handler: () => "ok", + }), + }; + const provider = makeToolProvider("analytics", registry); + const ctx = fakeContext([{ name: "analytics", provider }]); + + fs.writeFileSync( + path.join(tmpDir, "assistant.md"), + "---\n---\nYou are helpful.", + "utf-8", + ); + + const plugin = instantiate( + { + dir: tmpDir, + defaultModel: stubAdapter(), + agents: { + manual: { + instructions: "Manual agent", + model: stubAdapter(), + }, + }, + }, + ctx, + ); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { toolIndex: Map } | null; + }; + const fileAgent = api.get("assistant"); + const codeAgent = api.get("manual"); + + expect(fileAgent?.toolIndex.size).toBe(0); + expect(codeAgent?.toolIndex.size).toBe(0); + }); + + test("opting in with autoInheritTools: { file: true } spreads only autoInheritable tools", async () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "read-only query", + schema: z.object({ sql: z.string() }), + autoInheritable: true, + handler: () => "ok", + }), + destructive: defineTool({ + description: "mutation", + schema: z.object({}), + // autoInheritable left unset → skipped even when opted in + handler: () => "ok", + }), + }; + const provider = makeToolProvider("analytics", registry); + const ctx = fakeContext([{ name: "analytics", provider }]); + + fs.writeFileSync( + path.join(tmpDir, "assistant.md"), + "---\n---\nYou are helpful.", + "utf-8", + ); + + const plugin = instantiate( + { + dir: tmpDir, + defaultModel: stubAdapter(), + autoInheritTools: { file: true }, + }, + ctx, + ); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { toolIndex: Map } | null; + }; + const fileAgent = api.get("assistant"); + const keys = Array.from(fileAgent?.toolIndex.keys() ?? []); + expect(keys).toEqual(["analytics.query"]); + }); + + test("autoInheritTools: true enables both origins but still filters by autoInheritable", async () => { + const registry: ToolRegistry = { + safe: defineTool({ + description: "safe", + schema: z.object({}), + autoInheritable: true, + handler: () => "ok", + }), + unsafe: defineTool({ + description: "unsafe", + schema: z.object({}), + handler: () => "ok", + }), + }; + const provider = makeToolProvider("p", registry); + const ctx = fakeContext([{ name: "p", provider }]); + + const plugin = instantiate( + { + dir: false, + defaultModel: stubAdapter(), + autoInheritTools: true, + agents: { + code1: { + instructions: "code agent", + model: stubAdapter(), + }, + }, + }, + ctx, + ); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { toolIndex: Map } | null; + }; + const codeAgent = api.get("code1"); + const keys = Array.from(codeAgent?.toolIndex.keys() ?? []); + expect(keys).toEqual(["p.safe"]); + }); + + test("file-loaded agent respects explicit toolkits (skips auto-inherit)", async () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "q", + schema: z.object({ sql: z.string() }), + handler: () => "ok", + }), + }; + const registry2: ToolRegistry = { + list: defineTool({ + description: "l", + schema: z.object({}), + handler: () => [], + }), + }; + const ctx = fakeContext([ + { name: "analytics", provider: makeToolProvider("analytics", registry) }, + { name: "files", provider: makeToolProvider("files", registry2) }, + ]); + + fs.writeFileSync( + path.join(tmpDir, "analyst.md"), + "---\ntoolkits: [analytics]\n---\nAnalyst.", + "utf-8", + ); + + const plugin = instantiate( + { dir: tmpDir, defaultModel: stubAdapter() }, + ctx, + ); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { toolIndex: Map } | null; + }; + const agent = api.get("analyst"); + const toolNames = Array.from(agent?.toolIndex.keys() ?? []); + expect(toolNames.some((n) => n.startsWith("analytics."))).toBe(true); + expect(toolNames.some((n) => n.startsWith("files."))).toBe(false); + }); + + test("registers sub-agents as agent- tools", async () => { + const plugin = instantiate({ + dir: false, + agents: { + supervisor: { + instructions: "Supervise", + model: stubAdapter(), + agents: { + worker: { + instructions: "Work", + model: stubAdapter(), + }, + }, + }, + }, + }); + await plugin.setup(); + + const api = plugin.exports() as { + get: (name: string) => { toolIndex: Map } | null; + }; + const sup = api.get("supervisor"); + expect(sup?.toolIndex.has("agent-worker")).toBe(true); + }); + + test("isToolkitEntry type guard recognizes toolkit entries", () => { + const entry: ToolkitEntry = { + __toolkitRef: true, + pluginName: "x", + localName: "y", + def: { name: "x.y", description: "", parameters: { type: "object" } }, + }; + expect(isToolkitEntry(entry)).toBe(true); + expect(isToolkitEntry({ foo: 1 })).toBe(false); + expect(isToolkitEntry(null)).toBe(false); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/approval-route.test.ts b/packages/appkit/src/plugins/agents/tests/approval-route.test.ts new file mode 100644 index 00000000..6e090bd2 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/approval-route.test.ts @@ -0,0 +1,292 @@ +import type express from "express"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { CacheManager } from "../../../cache"; +import { AgentsPlugin } from "../agents"; + +/** + * Focused tests for the `POST /approve` route and the associated + * ownership / error paths on `_handleApprove`. Covers: + * + * - Schema validation of the request body. + * - Ownership check: the user submitting the decision must be the same + * user who initiated the underlying chat stream. + * - 404 for unknown stream (already completed or never existed). + * - 404 for unknown approvalId even when the stream is active. + * - Happy-path resolution of a pending gate with `approve` and `deny`. + * - Cancel of an active stream denies every pending gate on that stream. + */ + +function mockReq(body: unknown, userId?: string): express.Request { + const headers: Record = {}; + if (userId) { + headers["x-forwarded-user"] = userId; + headers["x-forwarded-access-token"] = "fake-token"; + } + return { + body, + headers, + header: (name: string) => headers[name.toLowerCase()], + } as unknown as express.Request; +} + +function mockRes() { + const json = vi.fn(); + const end = vi.fn(); + let statusCode = 200; + const status = vi.fn((code: number) => { + statusCode = code; + return { json, end }; + }); + return { + res: { status, json, end } as unknown as express.Response, + get statusCode() { + return statusCode; + }, + json, + }; +} + +beforeEach(() => { + CacheManager.getInstanceSync = vi.fn(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise) => + fn(), + ), + generateKey: vi.fn(() => "test-key"), + })) as any; + process.env.NODE_ENV = "development"; +}); + +describe("POST /approve route handler", () => { + test("rejects invalid body shape with 400", async () => { + const plugin = new AgentsPlugin({ dir: false }); + const { res, json } = mockRes(); + await (plugin as any)._handleApprove(mockReq({}, "alice"), res); + expect(res.status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ error: "Invalid request" }), + ); + }); + + test("returns 404 when the streamId is unknown", async () => { + const plugin = new AgentsPlugin({ dir: false }); + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleApprove: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleApprove( + mockReq( + { streamId: "ghost", approvalId: "a1", decision: "approve" }, + "alice", + ), + res, + ); + expect(res.status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ error: expect.stringMatching(/not found/i) }), + ); + }); + + test("returns 403 when submitter is different from stream owner", async () => { + const plugin = new AgentsPlugin({ dir: false }); + (plugin as any).activeStreams.set("stream-x", { + controller: new AbortController(), + userId: "alice", + }); + const gate = (plugin as any).approvalGate; + const waiter = gate.wait({ + approvalId: "a1", + streamId: "stream-x", + userId: "alice", + timeoutMs: 60_000, + }); + + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleApprove: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleApprove( + mockReq( + { streamId: "stream-x", approvalId: "a1", decision: "approve" }, + "bob", + ), + res, + ); + expect(res.status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ error: "Forbidden" }), + ); + + // Settle the waiter to clean up. + gate.submit({ approvalId: "a1", userId: "alice", decision: "deny" }); + await expect(waiter).resolves.toBe("deny"); + }); + + test("returns 404 when approvalId is unknown on an active stream", async () => { + const plugin = new AgentsPlugin({ dir: false }); + (plugin as any).activeStreams.set("stream-y", { + controller: new AbortController(), + userId: "alice", + }); + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleApprove: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleApprove( + mockReq( + { streamId: "stream-y", approvalId: "unknown-a", decision: "approve" }, + "alice", + ), + res, + ); + expect(res.status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ + error: expect.stringMatching(/not found|already settled/i), + }), + ); + }); + + test("happy path: approve resolves pending gate with 'approve'", async () => { + const plugin = new AgentsPlugin({ dir: false }); + (plugin as any).activeStreams.set("stream-z", { + controller: new AbortController(), + userId: "alice", + }); + const gate = (plugin as any).approvalGate; + const waiter = gate.wait({ + approvalId: "a42", + streamId: "stream-z", + userId: "alice", + timeoutMs: 60_000, + }); + + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleApprove: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleApprove( + mockReq( + { streamId: "stream-z", approvalId: "a42", decision: "approve" }, + "alice", + ), + res, + ); + expect(res.status).not.toHaveBeenCalled(); + expect(json).toHaveBeenCalledWith({ decision: "approve" }); + await expect(waiter).resolves.toBe("approve"); + }); + + test("happy path: deny resolves pending gate with 'deny'", async () => { + const plugin = new AgentsPlugin({ dir: false }); + (plugin as any).activeStreams.set("stream-z", { + controller: new AbortController(), + userId: "alice", + }); + const gate = (plugin as any).approvalGate; + const waiter = gate.wait({ + approvalId: "a43", + streamId: "stream-z", + userId: "alice", + timeoutMs: 60_000, + }); + + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleApprove: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleApprove( + mockReq( + { streamId: "stream-z", approvalId: "a43", decision: "deny" }, + "alice", + ), + res, + ); + expect(json).toHaveBeenCalledWith({ decision: "deny" }); + await expect(waiter).resolves.toBe("deny"); + }); +}); + +describe("POST /cancel ownership + gate cleanup", () => { + test("cancelling a stream denies every pending approval on that stream", async () => { + const plugin = new AgentsPlugin({ dir: false }); + const controller = new AbortController(); + (plugin as any).activeStreams.set("stream-c", { + controller, + userId: "alice", + }); + const gate = (plugin as any).approvalGate; + const a = gate.wait({ + approvalId: "ca1", + streamId: "stream-c", + userId: "alice", + timeoutMs: 60_000, + }); + const b = gate.wait({ + approvalId: "ca2", + streamId: "stream-c", + userId: "alice", + timeoutMs: 60_000, + }); + + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleCancel: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleCancel(mockReq({ streamId: "stream-c" }, "alice"), res); + + expect(controller.signal.aborted).toBe(true); + expect(json).toHaveBeenCalledWith({ cancelled: true }); + await expect(a).resolves.toBe("deny"); + await expect(b).resolves.toBe("deny"); + }); + + test("cancel from a different user is refused with 403", async () => { + const plugin = new AgentsPlugin({ dir: false }); + const controller = new AbortController(); + (plugin as any).activeStreams.set("stream-d", { + controller, + userId: "alice", + }); + const { res, json } = mockRes(); + await ( + plugin as unknown as { + _handleCancel: ( + r: express.Request, + w: express.Response, + ) => Promise; + } + )._handleCancel(mockReq({ streamId: "stream-d" }, "bob"), res); + expect(res.status).toHaveBeenCalledWith(403); + expect(controller.signal.aborted).toBe(false); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ error: "Forbidden" }), + ); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/create-agent.test.ts b/packages/appkit/src/plugins/agents/tests/create-agent.test.ts new file mode 100644 index 00000000..3822897f --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/create-agent.test.ts @@ -0,0 +1,75 @@ +import { describe, expect, test } from "vitest"; +import { z } from "zod"; +import { createAgent } from "../../../core/create-agent-def"; +import { tool } from "../tools/tool"; +import type { AgentDefinition } from "../types"; + +describe("createAgent", () => { + test("returns the definition unchanged for a simple agent", () => { + const def: AgentDefinition = { + name: "support", + instructions: "You help customers.", + model: "endpoint-x", + }; + const result = createAgent(def); + expect(result).toBe(def); + }); + + test("accepts tools as a keyed record", () => { + const get_weather = tool({ + name: "get_weather", + description: "Get the weather", + schema: z.object({ city: z.string() }), + execute: async ({ city }) => `Sunny in ${city}`, + }); + + const def = createAgent({ + instructions: "...", + tools: { get_weather }, + }); + + expect(def.tools?.get_weather).toBe(get_weather); + }); + + test("accepts sub-agents in a keyed record", () => { + const researcher = createAgent({ instructions: "Research." }); + const supervisor = createAgent({ + instructions: "Supervise.", + agents: { researcher }, + }); + expect(supervisor.agents?.researcher).toBe(researcher); + }); + + test("throws on a direct self-cycle", () => { + const a: AgentDefinition = { instructions: "a" }; + // biome-ignore lint/suspicious/noExplicitAny: intentional cycle setup for test + (a as any).agents = { self: a }; + expect(() => createAgent(a)).toThrow(/cycle/i); + }); + + test("throws on an indirect cycle", () => { + const a: AgentDefinition = { instructions: "a" }; + const b: AgentDefinition = { instructions: "b" }; + a.agents = { b }; + b.agents = { a }; + expect(() => createAgent(a)).toThrow(/cycle/i); + }); + + test("accepts a DAG of sub-agents without throwing", () => { + const leaf: AgentDefinition = { instructions: "leaf" }; + const branchA: AgentDefinition = { + instructions: "a", + agents: { leaf }, + }; + const branchB: AgentDefinition = { + instructions: "b", + agents: { leaf }, + }; + const root = createAgent({ + instructions: "root", + agents: { branchA, branchB }, + }); + expect(root.agents?.branchA.agents?.leaf).toBe(leaf); + expect(root.agents?.branchB.agents?.leaf).toBe(leaf); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/dos-limits.test.ts b/packages/appkit/src/plugins/agents/tests/dos-limits.test.ts new file mode 100644 index 00000000..935fa240 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/dos-limits.test.ts @@ -0,0 +1,299 @@ +import type express from "express"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { CacheManager } from "../../../cache"; +import { AgentsPlugin } from "../agents"; +import { chatRequestSchema, invocationsRequestSchema } from "../schemas"; + +/** + * Exercises the four DoS caps landed for MVP: + * + * - `chatRequestSchema.message.max(64_000)` — body cap on `POST /chat`. + * - Per-user `maxConcurrentStreamsPerUser` — 429 with Retry-After. + * - Per-run `maxToolCalls` — aborts stream and throws in `executeTool`. + * - Per-delegation `maxSubAgentDepth` — rejects in `runSubAgent`. + * + * Route-level tests exercise the schemas + `_handleChat` directly via the + * mocked req/res pattern already used by approval-route.test.ts. + */ + +function mockReq(body: unknown, userId?: string): express.Request { + const headers: Record = {}; + if (userId) { + headers["x-forwarded-user"] = userId; + headers["x-forwarded-access-token"] = "fake-token"; + } + return { + body, + headers, + header: (name: string) => headers[name.toLowerCase()], + } as unknown as express.Request; +} + +function mockRes() { + const json = vi.fn(); + const setHeader = vi.fn(); + let statusCode = 200; + const status = vi.fn((code: number) => { + statusCode = code; + return { json }; + }); + return { + res: { status, json, setHeader } as unknown as express.Response, + get statusCode() { + return statusCode; + }, + json, + setHeader, + }; +} + +beforeEach(() => { + CacheManager.getInstanceSync = vi.fn(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise) => + fn(), + ), + generateKey: vi.fn(() => "test-key"), + // biome-ignore lint/suspicious/noExplicitAny: test mock + })) as any; + process.env.NODE_ENV = "development"; +}); + +describe("chatRequestSchema — body cap", () => { + test("accepts messages up to 64_000 characters", () => { + const result = chatRequestSchema.safeParse({ + message: "a".repeat(64_000), + }); + expect(result.success).toBe(true); + }); + + test("rejects messages over 64_000 characters", () => { + const result = chatRequestSchema.safeParse({ + message: "a".repeat(64_001), + }); + expect(result.success).toBe(false); + if (!result.success) { + expect(JSON.stringify(result.error.flatten())).toMatch(/64000/); + } + }); + + test("rejects empty message (existing contract)", () => { + expect(chatRequestSchema.safeParse({ message: "" }).success).toBe(false); + }); +}); + +describe("invocationsRequestSchema — input caps", () => { + test("accepts string input up to 64_000 characters", () => { + const result = invocationsRequestSchema.safeParse({ + input: "a".repeat(64_000), + }); + expect(result.success).toBe(true); + }); + + test("rejects string input over 64_000 characters", () => { + const result = invocationsRequestSchema.safeParse({ + input: "a".repeat(64_001), + }); + expect(result.success).toBe(false); + }); + + test("accepts array input up to 100 items", () => { + const items = Array.from({ length: 100 }, (_, i) => ({ + role: "user" as const, + content: `m${i}`, + })); + expect(invocationsRequestSchema.safeParse({ input: items }).success).toBe( + true, + ); + }); + + test("rejects array input over 100 items", () => { + const items = Array.from({ length: 101 }, (_, i) => ({ + role: "user" as const, + content: `m${i}`, + })); + const result = invocationsRequestSchema.safeParse({ input: items }); + expect(result.success).toBe(false); + }); + + test("rejects per-item content over 64_000 characters", () => { + const result = invocationsRequestSchema.safeParse({ + input: [{ role: "user", content: "a".repeat(64_001) }], + }); + expect(result.success).toBe(false); + }); +}); + +describe("POST /chat — per-user concurrent-stream limit", () => { + function seedPlugin( + overrides: ConstructorParameters[0] = { dir: false }, + ): AgentsPlugin { + const plugin = new AgentsPlugin(overrides); + // Seed the agents map directly so _handleChat can resolve "hello" + // without running setup() (which would require a live model). + // biome-ignore lint/suspicious/noExplicitAny: seeding private state + (plugin as any).agents.set("hello", { + name: "hello", + instructions: "hi", + adapter: { async *run() {} }, + toolIndex: new Map(), + }); + // biome-ignore lint/suspicious/noExplicitAny: seeding private state + (plugin as any).defaultAgentName = "hello"; + return plugin; + } + + test("rejects with 429 + Retry-After when user is at-limit (default 5)", async () => { + const plugin = seedPlugin(); + for (let i = 0; i < 5; i++) { + // biome-ignore lint/suspicious/noExplicitAny: seeding + (plugin as any).activeStreams.set(`s${i}`, { + controller: new AbortController(), + userId: "alice", + }); + } + + const { res, setHeader, json } = mockRes(); + await ( + plugin as unknown as { + _handleChat: (r: express.Request, w: express.Response) => Promise; + } + )._handleChat(mockReq({ message: "hi" }, "alice"), res); + + expect(res.status).toHaveBeenCalledWith(429); + expect(setHeader).toHaveBeenCalledWith("Retry-After", "5"); + expect(json).toHaveBeenCalledWith( + expect.objectContaining({ + error: expect.stringMatching(/Too many concurrent streams/), + }), + ); + }); + + test("does not reject when another user is at-limit (per-user, not global)", async () => { + const plugin = seedPlugin(); + for (let i = 0; i < 5; i++) { + // biome-ignore lint/suspicious/noExplicitAny: seeding + (plugin as any).activeStreams.set(`s${i}`, { + controller: new AbortController(), + userId: "alice", + }); + } + + // Carol's request must not see a 429 even though alice is at-limit. + // Don't bother running the full stream — we assert only that 429 is + // not the response status. + const { res } = mockRes(); + // biome-ignore lint/suspicious/noExplicitAny: stub _streamAgent to avoid needing a real adapter + (plugin as any)._streamAgent = vi.fn(async () => undefined); + + await ( + plugin as unknown as { + _handleChat: (r: express.Request, w: express.Response) => Promise; + } + )._handleChat(mockReq({ message: "hi" }, "carol"), res); + + expect(res.status).not.toHaveBeenCalledWith(429); + }); + + test("honours agents({ limits: { maxConcurrentStreamsPerUser } })", async () => { + const plugin = seedPlugin({ + dir: false, + limits: { maxConcurrentStreamsPerUser: 2 }, + }); + for (let i = 0; i < 2; i++) { + // biome-ignore lint/suspicious/noExplicitAny: seeding + (plugin as any).activeStreams.set(`s${i}`, { + controller: new AbortController(), + userId: "alice", + }); + } + + const { res } = mockRes(); + await ( + plugin as unknown as { + _handleChat: (r: express.Request, w: express.Response) => Promise; + } + )._handleChat(mockReq({ message: "hi" }, "alice"), res); + + expect(res.status).toHaveBeenCalledWith(429); + }); +}); + +describe("resolvedLimits — default values", () => { + test("exposes the documented MVP defaults when unconfigured", () => { + const plugin = new AgentsPlugin({ dir: false }); + // biome-ignore lint/suspicious/noExplicitAny: read private getter + const limits = (plugin as any).resolvedLimits; + expect(limits).toEqual({ + maxConcurrentStreamsPerUser: 5, + maxToolCalls: 50, + maxSubAgentDepth: 3, + }); + }); + + test("lets callers override any subset", () => { + const plugin = new AgentsPlugin({ + dir: false, + limits: { maxToolCalls: 100 }, + }); + // biome-ignore lint/suspicious/noExplicitAny: read private + const limits = (plugin as any).resolvedLimits; + expect(limits.maxToolCalls).toBe(100); + expect(limits.maxConcurrentStreamsPerUser).toBe(5); + expect(limits.maxSubAgentDepth).toBe(3); + }); +}); + +describe("runSubAgent — depth guard", () => { + test("rejects when depth exceeds the configured maximum", async () => { + const plugin = new AgentsPlugin({ + dir: false, + limits: { maxSubAgentDepth: 2 }, + }); + // biome-ignore lint/suspicious/noExplicitAny: call private method directly + await expect( + (plugin as any).runSubAgent( + mockReq({}, "alice"), + { name: "child", toolIndex: new Map() }, + {}, + new AbortController().signal, + 3, // exceeds limit 2 + ), + ).rejects.toThrow(/Sub-agent depth exceeded \(limit 2\)/); + }); + + test("accepts at the boundary (depth === limit)", async () => { + // Use a stub adapter so we don't need a real model. + const plugin = new AgentsPlugin({ + dir: false, + limits: { maxSubAgentDepth: 3 }, + agents: {}, + }); + + const stubAdapter = { + // biome-ignore lint/suspicious/noExplicitAny: adapter shape not under test + async *run(): any { + yield { type: "message", content: "hello from depth-3" }; + }, + }; + const child = { + name: "child", + instructions: "test", + // biome-ignore lint/suspicious/noExplicitAny: stub shape + adapter: stubAdapter as any, + toolIndex: new Map(), + }; + + // biome-ignore lint/suspicious/noExplicitAny: call private + const result = await (plugin as any).runSubAgent( + mockReq({}, "alice"), + child, + { input: "test" }, + new AbortController().signal, + 3, // at the limit, not over + ); + expect(result).toBe("hello from depth-3"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/event-channel.test.ts b/packages/appkit/src/plugins/agents/tests/event-channel.test.ts new file mode 100644 index 00000000..d80d788d --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/event-channel.test.ts @@ -0,0 +1,78 @@ +import { describe, expect, test } from "vitest"; +import { EventChannel } from "../event-channel"; + +async function collect(ch: EventChannel): Promise { + const out: T[] = []; + for await (const v of ch) out.push(v); + return out; +} + +describe("EventChannel", () => { + test("yields pushed values in order", async () => { + const ch = new EventChannel(); + const p = collect(ch); + ch.push(1); + ch.push(2); + ch.push(3); + ch.close(); + await expect(p).resolves.toEqual([1, 2, 3]); + }); + + test("pushes before iteration start are buffered", async () => { + const ch = new EventChannel(); + ch.push("a"); + ch.push("b"); + ch.close(); + await expect(collect(ch)).resolves.toEqual(["a", "b"]); + }); + + test("waiting iterator is unblocked by subsequent push", async () => { + const ch = new EventChannel(); + const promise = collect(ch); + await new Promise((r) => setTimeout(r, 5)); + ch.push(42); + ch.close(); + await expect(promise).resolves.toEqual([42]); + }); + + test("close with no pending values terminates iteration", async () => { + const ch = new EventChannel(); + const p = collect(ch); + ch.close(); + await expect(p).resolves.toEqual([]); + }); + + test("push after close is a no-op (channel is closed)", async () => { + const ch = new EventChannel(); + ch.close(); + ch.push(1); + await expect(collect(ch)).resolves.toEqual([]); + }); + + test("close with error rejects the waiting iterator", async () => { + const ch = new EventChannel(); + const promise = collect(ch); + await new Promise((r) => setTimeout(r, 5)); + ch.close(new Error("boom")); + await expect(promise).rejects.toThrow(/boom/); + }); + + test("interleaved pushes and reads stream through", async () => { + const ch = new EventChannel(); + const received: number[] = []; + const reader = (async () => { + for await (const v of ch) { + received.push(v); + if (received.length === 3) break; + } + })(); + ch.push(1); + await new Promise((r) => setTimeout(r, 0)); + ch.push(2); + await new Promise((r) => setTimeout(r, 0)); + ch.push(3); + await reader; + expect(received).toEqual([1, 2, 3]); + ch.close(); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/event-translator.test.ts b/packages/appkit/src/plugins/agents/tests/event-translator.test.ts new file mode 100644 index 00000000..050af001 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/event-translator.test.ts @@ -0,0 +1,332 @@ +import type { ResponseStreamEvent } from "shared"; +import { describe, expect, test } from "vitest"; +import { AgentEventTranslator } from "../event-translator"; + +describe("AgentEventTranslator", () => { + test("translates message_delta to output_item.added + output_text.delta on first delta", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "message_delta", + content: "Hello", + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("response.output_item.added"); + expect(events[1].type).toBe("response.output_text.delta"); + + if (events[1].type === "response.output_text.delta") { + expect(events[1].delta).toBe("Hello"); + } + }); + + test("subsequent message_delta only produces output_text.delta", () => { + const translator = new AgentEventTranslator(); + translator.translate({ type: "message_delta", content: "Hello" }); + const events = translator.translate({ + type: "message_delta", + content: " world", + }); + + expect(events).toHaveLength(1); + expect(events[0].type).toBe("response.output_text.delta"); + }); + + test("sequence_number is monotonically increasing", () => { + const translator = new AgentEventTranslator(); + const e1 = translator.translate({ type: "message_delta", content: "a" }); + const e2 = translator.translate({ type: "message_delta", content: "b" }); + const e3 = translator.finalize(); + + const allSeqs = [...e1, ...e2, ...e3].map((e) => + "sequence_number" in e ? e.sequence_number : -1, + ); + + for (let i = 1; i < allSeqs.length; i++) { + expect(allSeqs[i]).toBeGreaterThan(allSeqs[i - 1]); + } + }); + + test("translates tool_call to paired output_item.added + output_item.done", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "tool_call", + callId: "call_1", + name: "analytics.query", + args: { sql: "SELECT 1" }, + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("response.output_item.added"); + expect(events[1].type).toBe("response.output_item.done"); + + if (events[0].type === "response.output_item.added") { + expect(events[0].item.type).toBe("function_call"); + if (events[0].item.type === "function_call") { + expect(events[0].item.name).toBe("analytics.query"); + expect(events[0].item.call_id).toBe("call_1"); + } + } + }); + + test("translates tool_result to paired output_item events", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "tool_result", + callId: "call_1", + result: { rows: 42 }, + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("response.output_item.added"); + + if (events[0].type === "response.output_item.added") { + expect(events[0].item.type).toBe("function_call_output"); + } + }); + + test("translates tool_result error", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "tool_result", + callId: "call_1", + result: null, + error: "Query failed", + }); + + if ( + events[0].type === "response.output_item.added" && + events[0].item.type === "function_call_output" + ) { + expect(events[0].item.output).toBe("Query failed"); + } + }); + + test("translates thinking to appkit.thinking extension event", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "thinking", + content: "Let me think about this...", + }); + + expect(events).toHaveLength(1); + expect(events[0].type).toBe("appkit.thinking"); + if (events[0].type === "appkit.thinking") { + expect(events[0].content).toBe("Let me think about this..."); + } + }); + + test("translates metadata to appkit.metadata extension event", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "metadata", + data: { threadId: "t-123" }, + }); + + expect(events).toHaveLength(1); + expect(events[0].type).toBe("appkit.metadata"); + if (events[0].type === "appkit.metadata") { + expect(events[0].data.threadId).toBe("t-123"); + } + }); + + test("status:complete triggers finalize with response.completed", () => { + const translator = new AgentEventTranslator(); + translator.translate({ type: "message_delta", content: "Hi" }); + const events = translator.translate({ type: "status", status: "complete" }); + + const types = events.map((e) => e.type); + expect(types).toContain("response.output_item.done"); + expect(types).toContain("response.completed"); + }); + + test("status:error emits error + response.failed", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "status", + status: "error", + error: "Something broke", + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("error"); + expect(events[1].type).toBe("response.failed"); + + if (events[0].type === "error") { + expect(events[0].error).toBe("Something broke"); + } + }); + + test("finalize produces response.completed", () => { + const translator = new AgentEventTranslator(); + const events = translator.finalize(); + + expect(events.some((e) => e.type === "response.completed")).toBe(true); + }); + + test("finalize with accumulated message text produces output_item.done", () => { + const translator = new AgentEventTranslator(); + translator.translate({ type: "message_delta", content: "Hello " }); + translator.translate({ type: "message_delta", content: "world" }); + const events = translator.finalize(); + + const doneEvent = events.find( + (e) => e.type === "response.output_item.done", + ); + expect(doneEvent).toBeDefined(); + if ( + doneEvent?.type === "response.output_item.done" && + doneEvent.item.type === "message" + ) { + expect(doneEvent.item.content[0].text).toBe("Hello world"); + } + }); + + test("output_index increments for tool calls", () => { + const translator = new AgentEventTranslator(); + const e1 = translator.translate({ + type: "tool_call", + callId: "c1", + name: "tool1", + args: {}, + }); + const e2 = translator.translate({ + type: "tool_result", + callId: "c1", + result: "ok", + }); + + if ( + e1[0].type === "response.output_item.added" && + e2[0].type === "response.output_item.added" + ) { + expect(e2[0].output_index).toBeGreaterThan(e1[0].output_index); + } + }); +}); + +describe("AgentEventTranslator — monotonic output_index", () => { + /** + * Helper: every emitted `response.output_item.added`/`output_item.done` + * event's `output_index` must be >= every prior add/done `output_index`. + * This is the strict contract Responses-API clients (OpenAI's own SDK + * parser) enforce. + */ + function assertMonotonic(events: ResponseStreamEvent[]) { + let last = -1; + for (const ev of events) { + if ( + ev.type === "response.output_item.added" || + ev.type === "response.output_item.done" + ) { + expect(ev.output_index).toBeGreaterThanOrEqual(last); + last = ev.output_index; + } + } + } + + test("tool_call followed by message_delta emits monotonic indices (regression)", () => { + // Before the fix this produced: tool_call at index 1, then + // message_delta.added at 0 — monotonicity violated. + const t = new AgentEventTranslator(); + const all: ResponseStreamEvent[] = []; + all.push( + ...t.translate({ + type: "tool_call", + callId: "c1", + name: "lookup", + args: { q: "x" }, + }), + ); + all.push( + ...t.translate({ type: "tool_result", callId: "c1", result: "ok" }), + ); + all.push(...t.translate({ type: "message_delta", content: "Result: " })); + all.push(...t.translate({ type: "message_delta", content: "ok." })); + all.push(...t.finalize()); + + assertMonotonic(all); + + const added = all.filter((e) => e.type === "response.output_item.added"); + // Three items: tool_call, tool_result, message. Indices 0/1/2. + expect(added.map((e) => e.output_index)).toEqual([0, 1, 2]); + }); + + test("message interrupted by tool_call is closed before the tool_call opens", () => { + const t = new AgentEventTranslator(); + const all: ResponseStreamEvent[] = []; + all.push(...t.translate({ type: "message_delta", content: "thinking..." })); + all.push( + ...t.translate({ + type: "tool_call", + callId: "c1", + name: "lookup", + args: {}, + }), + ); + all.push( + ...t.translate({ type: "tool_result", callId: "c1", result: "ok" }), + ); + all.push(...t.translate({ type: "message_delta", content: "final" })); + all.push(...t.finalize()); + + assertMonotonic(all); + + // Structure: msg0.added, msg0.delta, msg0.done (closed before tool), + // tool_call.added/done, tool_result.added/done, msg1.added, msg1.delta, + // msg1.done (from finalize), response.completed. + const addedDone = all.filter( + (e) => + e.type === "response.output_item.added" || + e.type === "response.output_item.done", + ); + expect(addedDone.map((e) => `${e.type}@${e.output_index}`)).toEqual([ + "response.output_item.added@0", + "response.output_item.done@0", + "response.output_item.added@1", + "response.output_item.done@1", + "response.output_item.added@2", + "response.output_item.done@2", + "response.output_item.added@3", + "response.output_item.done@3", + ]); + }); + + test("full `message` event after deltas does not double-emit output_item.added", () => { + const t = new AgentEventTranslator(); + const all: ResponseStreamEvent[] = []; + all.push(...t.translate({ type: "message_delta", content: "partial" })); + all.push( + ...t.translate({ type: "message", content: "full final content" }), + ); + all.push(...t.finalize()); + + const added = all.filter((e) => e.type === "response.output_item.added"); + const done = all.filter((e) => e.type === "response.output_item.done"); + // Exactly one added (from the first delta) and one done (from the full + // message). finalize() must not emit a second done for the same item. + expect(added).toHaveLength(1); + expect(done).toHaveLength(1); + if (done[0].type === "response.output_item.done") { + const item = done[0].item; + if (item.type === "message") { + expect(item.content[0].text).toBe("full final content"); + } + } + }); + + test("tool_result coerces undefined result to empty-string output", () => { + const t = new AgentEventTranslator(); + const events = t.translate({ + type: "tool_result", + callId: "c1", + result: undefined, + }); + const done = events.find((e) => e.type === "response.output_item.done"); + if (done?.type === "response.output_item.done") { + const item = done.item; + if (item.type === "function_call_output") { + expect(item.output).toBe(""); + } + } + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/load-agents.test.ts b/packages/appkit/src/plugins/agents/tests/load-agents.test.ts new file mode 100644 index 00000000..3410f566 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/load-agents.test.ts @@ -0,0 +1,302 @@ +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; +import { z } from "zod"; +import { buildToolkitEntries } from "../build-toolkit"; +import { + loadAgentFromFile, + loadAgentsFromDir, + parseFrontmatter, +} from "../load-agents"; +import { defineTool, type ToolRegistry } from "../tools/define-tool"; +import { tool } from "../tools/tool"; +import type { AgentDefinition } from "../types"; + +let workDir: string; + +beforeEach(() => { + workDir = fs.mkdtempSync(path.join(os.tmpdir(), "agents-test-")); +}); + +afterEach(() => { + fs.rmSync(workDir, { recursive: true, force: true }); +}); + +function write(name: string, content: string) { + fs.writeFileSync(path.join(workDir, name), content, "utf-8"); + return path.join(workDir, name); +} + +describe("parseFrontmatter", () => { + test("parses a simple object", () => { + const { data, content } = parseFrontmatter( + "---\nendpoint: foo\ndefault: true\n---\nHello body", + ); + expect(data).toEqual({ endpoint: "foo", default: true }); + expect(content).toBe("Hello body"); + }); + + test("parses nested arrays", () => { + const { data } = parseFrontmatter( + "---\ntoolkits:\n - analytics\n - files: [uploads.list]\n---\nbody", + ); + expect(data?.toolkits).toEqual(["analytics", { files: ["uploads.list"] }]); + }); + + test("returns null data when no frontmatter", () => { + const { data, content } = parseFrontmatter("No frontmatter here"); + expect(data).toBeNull(); + expect(content).toBe("No frontmatter here"); + }); + + test("throws on invalid YAML", () => { + expect(() => parseFrontmatter("---\nkey: : : bad\n---\n")).toThrow(/YAML/); + }); +}); + +describe("loadAgentFromFile", () => { + test("returns AgentDefinition with body as instructions", async () => { + const p = write( + "assistant.md", + "---\nendpoint: e-1\n---\nYou are helpful.", + ); + const def = await loadAgentFromFile(p, {}); + expect(def.name).toBe("assistant"); + expect(def.instructions).toBe("You are helpful."); + expect(def.model).toBe("e-1"); + }); +}); + +describe("loadAgentsFromDir", () => { + test("returns empty map when dir doesn't exist", async () => { + const res = await loadAgentsFromDir("/nonexistent-for-tests", {}); + expect(res.defs).toEqual({}); + expect(res.defaultAgent).toBeNull(); + }); + + test("loads all .md files keyed by file-stem", async () => { + write("support.md", "---\nendpoint: e-1\n---\nSupport prompt."); + write("sales.md", "---\nendpoint: e-2\n---\nSales prompt."); + const res = await loadAgentsFromDir(workDir, {}); + expect(Object.keys(res.defs).sort()).toEqual(["sales", "support"]); + }); + + test("picks up default: true from frontmatter", async () => { + write("one.md", "---\nendpoint: a\n---\nOne."); + write("two.md", "---\nendpoint: b\ndefault: true\n---\nTwo."); + const res = await loadAgentsFromDir(workDir, {}); + expect(res.defaultAgent).toBe("two"); + }); + + test("throws when frontmatter references an unregistered plugin", async () => { + write( + "broken.md", + "---\nendpoint: e\ntoolkits: [missing]\n---\nBroken agent.", + ); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /references toolkit 'missing'/, + ); + }); + + test("throws when frontmatter references an unknown ambient tool", async () => { + write("broken.md", "---\nendpoint: e\ntools: [unknown_tool]\n---\nBroken."); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /references tool 'unknown_tool'/, + ); + }); + + test("resolves toolkits + ambient tools when provided", async () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "q", + schema: z.object({ sql: z.string() }), + handler: () => "ok", + }), + }; + const plugins = new Map< + string, + { toolkit: (opts?: unknown) => Record } + >([ + [ + "analytics", + { + toolkit: (opts) => + buildToolkitEntries("analytics", registry, opts as never), + }, + ], + ]); + + const weather = tool({ + name: "get_weather", + description: "Weather", + schema: z.object({ city: z.string() }), + execute: async () => "sunny", + }); + + write( + "analyst.md", + "---\nendpoint: e\ntoolkits:\n - analytics\ntools:\n - get_weather\n---\nBody.", + ); + const res = await loadAgentsFromDir(workDir, { + plugins, + availableTools: { get_weather: weather }, + }); + expect(res.defs.analyst.tools).toBeDefined(); + expect(Object.keys(res.defs.analyst.tools ?? {}).sort()).toEqual([ + "analytics.query", + "get_weather", + ]); + }); + + describe("agents: sibling sub-agent references", () => { + test("resolves sibling references into def.agents regardless of file order", async () => { + // Names chosen so alphabetical iteration puts `dispatcher` *before* + // its siblings — pass-1 populates defs in any order, pass-2 resolves. + write( + "dispatcher.md", + "---\nendpoint: e\nagents:\n - analyst\n - writer\n---\nRoute work.", + ); + write("analyst.md", "---\nendpoint: e\n---\nAnalyst."); + write("writer.md", "---\nendpoint: e\n---\nWriter."); + + const res = await loadAgentsFromDir(workDir, {}); + expect(Object.keys(res.defs.dispatcher.agents ?? {}).sort()).toEqual([ + "analyst", + "writer", + ]); + expect(res.defs.dispatcher.agents?.analyst).toBe(res.defs.analyst); + expect(res.defs.dispatcher.agents?.writer).toBe(res.defs.writer); + // Leaves with no `agents:` retain undefined — only declared keys wire. + expect(res.defs.analyst.agents).toBeUndefined(); + expect(res.defs.writer.agents).toBeUndefined(); + }); + + test("mutual delegation is allowed (runtime depth cap handles cycles)", async () => { + write("a.md", "---\nendpoint: e\nagents:\n - b\n---\nA."); + write("b.md", "---\nendpoint: e\nagents:\n - a\n---\nB."); + + const res = await loadAgentsFromDir(workDir, {}); + expect(res.defs.a.agents?.b).toBe(res.defs.b); + expect(res.defs.b.agents?.a).toBe(res.defs.a); + }); + + test("throws with available list when a sibling is missing", async () => { + write("dispatcher.md", "---\nendpoint: e\nagents:\n - ghost\n---\nD."); + write("analyst.md", "---\nendpoint: e\n---\nAnalyst."); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /references sub-agent\(s\) 'ghost'.*Available: analyst, dispatcher/s, + ); + }); + + test("reports every missing sibling in one error, not just the first", async () => { + write( + "dispatcher.md", + "---\nendpoint: e\nagents:\n - ghost1\n - ghost2\n---\nD.", + ); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /ghost1, ghost2/, + ); + }); + + test("throws on self-reference", async () => { + write("solo.md", "---\nendpoint: e\nagents:\n - solo\n---\nSolo."); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /'solo'.*cannot reference itself/s, + ); + }); + + test("throws on non-array 'agents:' value", async () => { + write("bad.md", "---\nendpoint: e\nagents: analyst\n---\nBad."); + write("analyst.md", "---\nendpoint: e\n---\nAnalyst."); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /invalid 'agents:' frontmatter/, + ); + }); + + test("throws on non-string entries in 'agents:'", async () => { + write("bad.md", "---\nendpoint: e\nagents:\n - 42\n---\nBad."); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /invalid 'agents:' entry/, + ); + }); + + test("deduplicates repeated entries silently", async () => { + write( + "dispatcher.md", + "---\nendpoint: e\nagents:\n - analyst\n - analyst\n---\nD.", + ); + write("analyst.md", "---\nendpoint: e\n---\nAnalyst."); + const res = await loadAgentsFromDir(workDir, {}); + expect(Object.keys(res.defs.dispatcher.agents ?? {})).toEqual([ + "analyst", + ]); + }); + + test("empty array yields no sub-agents (no-op)", async () => { + write("dispatcher.md", "---\nendpoint: e\nagents: []\n---\nD."); + const res = await loadAgentsFromDir(workDir, {}); + expect(res.defs.dispatcher.agents).toBeUndefined(); + }); + + test("resolves 'agents:' references against codeAgents when provided", async () => { + write("dispatcher.md", "---\nendpoint: e\nagents:\n - support\n---\nD."); + const support: AgentDefinition = { + name: "support", + instructions: "Code-defined support.", + }; + const res = await loadAgentsFromDir(workDir, { + codeAgents: { support }, + }); + expect(res.defs.dispatcher.agents?.support).toBe(support); + }); + + test("codeAgents takes precedence over markdown sibling with the same name", async () => { + write("dispatcher.md", "---\nendpoint: e\nagents:\n - support\n---\nD."); + write("support.md", "---\nendpoint: e\n---\nMarkdown support."); + const codeSupport: AgentDefinition = { + name: "support", + instructions: "Code support.", + }; + const res = await loadAgentsFromDir(workDir, { + codeAgents: { support: codeSupport }, + }); + // Reference binds to code version, matching the plugin's top-level + // `code wins` merge behaviour. + expect(res.defs.dispatcher.agents?.support).toBe(codeSupport); + expect(res.defs.dispatcher.agents?.support.instructions).toBe( + "Code support.", + ); + }); + + test("missing-sibling error lists both markdown and code agent names", async () => { + write("dispatcher.md", "---\nendpoint: e\nagents:\n - ghost\n---\nD."); + write("analyst.md", "---\nendpoint: e\n---\nAnalyst."); + const codeAgent: AgentDefinition = { + name: "writer", + instructions: "Writer.", + }; + await expect( + loadAgentsFromDir(workDir, { codeAgents: { writer: codeAgent } }), + ).rejects.toThrow(/Available: analyst, dispatcher, writer/); + }); + }); +}); + +describe("loadAgentFromFile — sub-agent refs rejected", () => { + test("throws when 'agents:' is non-empty in a single-file load", async () => { + const p = write( + "lonely.md", + "---\nendpoint: e\nagents:\n - ghost\n---\nLonely.", + ); + await expect(loadAgentFromFile(p, {})).rejects.toThrow( + /requires loadAgentsFromDir/, + ); + }); + + test("ignores empty 'agents:' array (treated as absent)", async () => { + const p = write("lonely.md", "---\nendpoint: e\nagents: []\n---\nLonely."); + const def = await loadAgentFromFile(p, {}); + expect(def.agents).toBeUndefined(); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/run-agent.test.ts b/packages/appkit/src/plugins/agents/tests/run-agent.test.ts new file mode 100644 index 00000000..1a974811 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/run-agent.test.ts @@ -0,0 +1,120 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, +} from "shared"; +import { describe, expect, test, vi } from "vitest"; +import { z } from "zod"; +import { createAgent } from "../../../core/create-agent-def"; +import { runAgent } from "../../../core/run-agent"; +import { tool } from "../tools/tool"; +import type { ToolkitEntry } from "../types"; + +function scriptedAdapter(events: AgentEvent[]): AgentAdapter { + return { + async *run(_input: AgentInput, _context: AgentRunContext) { + for (const event of events) { + yield event; + } + }, + }; +} + +describe("runAgent", () => { + test("drives the adapter and returns aggregated text", async () => { + const events: AgentEvent[] = [ + { type: "message_delta", content: "Hello " }, + { type: "message_delta", content: "world" }, + { type: "status", status: "complete" }, + ]; + const def = createAgent({ + instructions: "Say hello", + model: scriptedAdapter(events), + }); + + const result = await runAgent(def, { messages: "hi" }); + expect(result.text).toBe("Hello world"); + expect(result.events).toHaveLength(3); + }); + + test("prefers terminal 'message' event over deltas when present", async () => { + const events: AgentEvent[] = [ + { type: "message_delta", content: "partial" }, + { type: "message", content: "final answer" }, + ]; + const def = createAgent({ + instructions: "x", + model: scriptedAdapter(events), + }); + const result = await runAgent(def, { messages: "hi" }); + expect(result.text).toBe("final answer"); + }); + + test("invokes inline tools via executeTool callback", async () => { + const weatherFn = vi.fn(async () => "Sunny in NYC"); + const weather = tool({ + name: "get_weather", + description: "Weather", + schema: z.object({ city: z.string() }), + execute: weatherFn, + }); + + let capturedCtx: AgentRunContext | null = null; + const adapter: AgentAdapter = { + async *run(_input, context) { + capturedCtx = context; + yield { type: "message_delta", content: "" }; + }, + }; + + const def = createAgent({ + instructions: "x", + model: adapter, + tools: { get_weather: weather }, + }); + + await runAgent(def, { messages: "hi" }); + expect(capturedCtx).not.toBeNull(); + // biome-ignore lint/style/noNonNullAssertion: asserted above + const result = await capturedCtx!.executeTool("get_weather", { + city: "NYC", + }); + expect(result).toBe("Sunny in NYC"); + expect(weatherFn).toHaveBeenCalledWith({ city: "NYC" }); + }); + + test("throws a clear error when a ToolkitEntry is invoked", async () => { + const toolkitEntry: ToolkitEntry = { + __toolkitRef: true, + pluginName: "analytics", + localName: "query", + def: { + name: "analytics.query", + description: "SQL", + parameters: { type: "object", properties: {} }, + }, + }; + + let capturedCtx: AgentRunContext | null = null; + const adapter: AgentAdapter = { + async *run(_input, context) { + capturedCtx = context; + yield { type: "message_delta", content: "" }; + }, + }; + + const def = createAgent({ + instructions: "x", + model: adapter, + tools: { "analytics.query": toolkitEntry }, + }); + + await runAgent(def, { messages: "hi" }); + expect(capturedCtx).not.toBeNull(); + await expect( + // biome-ignore lint/style/noNonNullAssertion: asserted above + capturedCtx!.executeTool("analytics.query", {}), + ).rejects.toThrow(/only usable via createApp/); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/system-prompt.test.ts b/packages/appkit/src/plugins/agents/tests/system-prompt.test.ts new file mode 100644 index 00000000..83bf8e19 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/system-prompt.test.ts @@ -0,0 +1,45 @@ +import { describe, expect, test } from "vitest"; +import { buildBaseSystemPrompt, composeSystemPrompt } from "../system-prompt"; + +describe("buildBaseSystemPrompt", () => { + test("includes plugin names", () => { + const prompt = buildBaseSystemPrompt(["analytics", "files", "genie"]); + expect(prompt).toContain("Active plugins: analytics, files, genie"); + }); + + test("includes guidelines", () => { + const prompt = buildBaseSystemPrompt([]); + expect(prompt).toContain("Guidelines:"); + expect(prompt).toContain("Databricks SQL"); + expect(prompt).toContain("summarize key findings"); + }); + + test("works with no plugins", () => { + const prompt = buildBaseSystemPrompt([]); + expect(prompt).toContain("AI assistant running on Databricks AppKit"); + expect(prompt).not.toContain("Active plugins:"); + }); + + test("does NOT include individual tool names", () => { + const prompt = buildBaseSystemPrompt(["analytics"]); + expect(prompt).not.toContain("analytics.query"); + expect(prompt).not.toContain("Available tools:"); + }); +}); + +describe("composeSystemPrompt", () => { + test("concatenates base + agent prompt with double newline", () => { + const composed = composeSystemPrompt("Base prompt.", "Agent prompt."); + expect(composed).toBe("Base prompt.\n\nAgent prompt."); + }); + + test("returns base prompt alone when no agent prompt", () => { + const composed = composeSystemPrompt("Base prompt."); + expect(composed).toBe("Base prompt."); + }); + + test("returns base prompt when agent prompt is empty string", () => { + const composed = composeSystemPrompt("Base prompt.", ""); + expect(composed).toBe("Base prompt."); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/thread-store.test.ts b/packages/appkit/src/plugins/agents/tests/thread-store.test.ts new file mode 100644 index 00000000..ed4f70ba --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/thread-store.test.ts @@ -0,0 +1,138 @@ +import { describe, expect, test } from "vitest"; +import { InMemoryThreadStore } from "../thread-store"; + +describe("InMemoryThreadStore", () => { + test("create() returns a new thread with the given userId", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + expect(thread.id).toBeDefined(); + expect(thread.userId).toBe("user-1"); + expect(thread.messages).toEqual([]); + expect(thread.createdAt).toBeInstanceOf(Date); + expect(thread.updatedAt).toBeInstanceOf(Date); + }); + + test("get() returns the thread for the correct user", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const retrieved = await store.get(thread.id, "user-1"); + expect(retrieved).toEqual(thread); + }); + + test("get() returns null for wrong user", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const retrieved = await store.get(thread.id, "user-2"); + expect(retrieved).toBeNull(); + }); + + test("get() returns null for non-existent thread", async () => { + const store = new InMemoryThreadStore(); + const retrieved = await store.get("non-existent", "user-1"); + expect(retrieved).toBeNull(); + }); + + test("list() returns threads sorted by updatedAt desc", async () => { + const store = new InMemoryThreadStore(); + const t1 = await store.create("user-1"); + const t2 = await store.create("user-1"); + + // Make t1 more recently updated + await store.addMessage(t1.id, "user-1", { + id: "msg-1", + role: "user", + content: "hello", + createdAt: new Date(), + }); + + const threads = await store.list("user-1"); + expect(threads).toHaveLength(2); + expect(threads[0].id).toBe(t1.id); + expect(threads[1].id).toBe(t2.id); + }); + + test("list() returns empty for unknown user", async () => { + const store = new InMemoryThreadStore(); + await store.create("user-1"); + + const threads = await store.list("user-2"); + expect(threads).toEqual([]); + }); + + test("addMessage() appends to thread and updates timestamp", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + const originalUpdatedAt = thread.updatedAt; + + // Small delay to ensure timestamp differs + await new Promise((r) => setTimeout(r, 5)); + + await store.addMessage(thread.id, "user-1", { + id: "msg-1", + role: "user", + content: "hello", + createdAt: new Date(), + }); + + const updated = await store.get(thread.id, "user-1"); + expect(updated?.messages).toHaveLength(1); + expect(updated?.messages[0].content).toBe("hello"); + expect(updated?.updatedAt.getTime()).toBeGreaterThanOrEqual( + originalUpdatedAt.getTime(), + ); + }); + + test("addMessage() throws for non-existent thread", async () => { + const store = new InMemoryThreadStore(); + + await expect( + store.addMessage("non-existent", "user-1", { + id: "msg-1", + role: "user", + content: "hello", + createdAt: new Date(), + }), + ).rejects.toThrow("Thread non-existent not found"); + }); + + test("delete() removes a thread and returns true", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const deleted = await store.delete(thread.id, "user-1"); + expect(deleted).toBe(true); + + const retrieved = await store.get(thread.id, "user-1"); + expect(retrieved).toBeNull(); + }); + + test("delete() returns false for non-existent thread", async () => { + const store = new InMemoryThreadStore(); + const deleted = await store.delete("non-existent", "user-1"); + expect(deleted).toBe(false); + }); + + test("delete() returns false for wrong user", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const deleted = await store.delete(thread.id, "user-2"); + expect(deleted).toBe(false); + }); + + test("threads are isolated per user", async () => { + const store = new InMemoryThreadStore(); + await store.create("user-1"); + await store.create("user-1"); + await store.create("user-2"); + + const user1Threads = await store.list("user-1"); + const user2Threads = await store.list("user-2"); + + expect(user1Threads).toHaveLength(2); + expect(user2Threads).toHaveLength(1); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/tool-approval-gate.test.ts b/packages/appkit/src/plugins/agents/tests/tool-approval-gate.test.ts new file mode 100644 index 00000000..1e17ddf6 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/tool-approval-gate.test.ts @@ -0,0 +1,156 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { ToolApprovalGate } from "../tool-approval-gate"; + +describe("ToolApprovalGate", () => { + let gate: ToolApprovalGate; + + beforeEach(() => { + vi.useFakeTimers(); + gate = new ToolApprovalGate(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + test("resolves with 'approve' when a matching submit arrives", async () => { + const waiter = gate.wait({ + approvalId: "a1", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + expect(gate.size).toBe(1); + + const result = gate.submit({ + approvalId: "a1", + userId: "alice", + decision: "approve", + }); + + expect(result).toEqual({ ok: true }); + await expect(waiter).resolves.toBe("approve"); + expect(gate.size).toBe(0); + }); + + test("resolves with 'deny' on explicit deny", async () => { + const waiter = gate.wait({ + approvalId: "a2", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + gate.submit({ + approvalId: "a2", + userId: "alice", + decision: "deny", + }); + await expect(waiter).resolves.toBe("deny"); + }); + + test("auto-denies after timeoutMs with no submit", async () => { + const waiter = gate.wait({ + approvalId: "a3", + streamId: "s1", + userId: "alice", + timeoutMs: 1000, + }); + vi.advanceTimersByTime(1000); + await expect(waiter).resolves.toBe("deny"); + expect(gate.size).toBe(0); + }); + + test("refuses a submit from a different user (ownership check)", async () => { + const waiter = gate.wait({ + approvalId: "a4", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + const result = gate.submit({ + approvalId: "a4", + userId: "bob", + decision: "approve", + }); + expect(result).toEqual({ ok: false, reason: "forbidden" }); + expect(gate.size).toBe(1); + // Waiter is still pending; cleanup to let fake timers drain. + gate.submit({ + approvalId: "a4", + userId: "alice", + decision: "deny", + }); + await expect(waiter).resolves.toBe("deny"); + }); + + test("returns 'unknown' reason when approvalId is not registered", () => { + expect( + gate.submit({ approvalId: "nope", userId: "x", decision: "approve" }), + ).toEqual({ ok: false, reason: "unknown" }); + }); + + test("abortStream denies every pending gate for that stream", async () => { + const a = gate.wait({ + approvalId: "a5", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + const b = gate.wait({ + approvalId: "a6", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + const c = gate.wait({ + approvalId: "a7", + streamId: "s2", + userId: "alice", + timeoutMs: 60_000, + }); + gate.abortStream("s1"); + await expect(a).resolves.toBe("deny"); + await expect(b).resolves.toBe("deny"); + expect(gate.size).toBe(1); + // s2's waiter is still pending; settle it to clean up timers. + gate.submit({ approvalId: "a7", userId: "alice", decision: "deny" }); + await expect(c).resolves.toBe("deny"); + }); + + test("abortAll denies every pending gate", async () => { + const a = gate.wait({ + approvalId: "a8", + streamId: "s1", + userId: "alice", + timeoutMs: 60_000, + }); + const b = gate.wait({ + approvalId: "a9", + streamId: "s2", + userId: "bob", + timeoutMs: 60_000, + }); + gate.abortAll(); + await expect(a).resolves.toBe("deny"); + await expect(b).resolves.toBe("deny"); + expect(gate.size).toBe(0); + }); + + test("a timed-out approval cannot be resolved by a late submit", async () => { + const waiter = gate.wait({ + approvalId: "a10", + streamId: "s1", + userId: "alice", + timeoutMs: 500, + }); + vi.advanceTimersByTime(500); + await expect(waiter).resolves.toBe("deny"); + + const late = gate.submit({ + approvalId: "a10", + userId: "alice", + decision: "approve", + }); + expect(late).toEqual({ ok: false, reason: "unknown" }); + }); +}); diff --git a/packages/appkit/src/plugins/agents/thread-store.ts b/packages/appkit/src/plugins/agents/thread-store.ts new file mode 100644 index 00000000..7c4622cd --- /dev/null +++ b/packages/appkit/src/plugins/agents/thread-store.ts @@ -0,0 +1,66 @@ +import { randomUUID } from "node:crypto"; +import type { Message, Thread, ThreadStore } from "shared"; + +/** + * In-memory thread store backed by a nested Map. + * + * Outer key: userId, inner key: threadId. Thread history is retained for the + * lifetime of the process with no eviction, caps, or TTL — a chatty user will + * grow the in-memory footprint monotonically, and the server loses every + * thread on restart. **This implementation is intended for local development + * and single-process demos only.** + * + * For any real deployment, pass a persistent `ThreadStore` to `agents({ ... })` + * (e.g. a Lakebase- or Postgres-backed implementation). A bounded + * `InMemoryThreadStore` with eviction policies is tracked as a follow-up. + */ +export class InMemoryThreadStore implements ThreadStore { + private store = new Map>(); + + async create(userId: string): Promise { + const now = new Date(); + const thread: Thread = { + id: randomUUID(), + userId, + messages: [], + createdAt: now, + updatedAt: now, + }; + this.userMap(userId).set(thread.id, thread); + return thread; + } + + async get(threadId: string, userId: string): Promise { + return this.userMap(userId).get(threadId) ?? null; + } + + async list(userId: string): Promise { + return Array.from(this.userMap(userId).values()).sort( + (a, b) => b.updatedAt.getTime() - a.updatedAt.getTime(), + ); + } + + async addMessage( + threadId: string, + userId: string, + message: Message, + ): Promise { + const thread = this.userMap(userId).get(threadId); + if (!thread) throw new Error(`Thread ${threadId} not found`); + thread.messages.push(message); + thread.updatedAt = new Date(); + } + + async delete(threadId: string, userId: string): Promise { + return this.userMap(userId).delete(threadId); + } + + private userMap(userId: string): Map { + let map = this.store.get(userId); + if (!map) { + map = new Map(); + this.store.set(userId, map); + } + return map; + } +} diff --git a/packages/appkit/src/plugins/agents/tool-approval-gate.ts b/packages/appkit/src/plugins/agents/tool-approval-gate.ts new file mode 100644 index 00000000..669f30a9 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tool-approval-gate.ts @@ -0,0 +1,122 @@ +/** + * Server-side state for the human-in-the-loop approval gate on + * `destructive: true` agent tool calls. + * + * Lifecycle: + * + * 1. `wait(...)` is called from inside `executeTool` when a destructive tool + * is about to execute. A `Pending` record is registered and a timer is + * scheduled for auto-deny. The returned promise is what blocks the + * adapter until the decision arrives. + * 2. The client receives an `appkit.approval_pending` SSE event carrying the + * `approvalId` + `streamId` and posts a decision to `POST /chat/approve`. + * The route calls {@link ToolApprovalGate.submit} which resolves the + * pending promise and clears the timer. + * 3. If no submit arrives within `timeoutMs`, the timer fires and the + * promise resolves with `"deny"`. + * + * Security invariants: + * + * - `submit` verifies that the decider's user id matches the user who + * initiated the stream (set by `wait`). Mismatches are rejected without + * resolving the pending promise — this prevents a second user from + * approving (or denying) another user's destructive action. + * - `abort(streamId)` cancels every pending gate for a stream and denies + * each one. Used when the enclosing stream is cancelled or the plugin is + * shutting down. + */ +type ApprovalDecision = "approve" | "deny"; + +interface Pending { + resolve: (decision: ApprovalDecision) => void; + userId: string; + streamId: string; + timeout: ReturnType; +} + +type ApprovalSubmitResult = + | { ok: true } + | { ok: false; reason: "unknown" | "forbidden" }; + +export class ToolApprovalGate { + private pending = new Map(); + + /** + * Register a pending approval and return a promise that resolves with the + * user's decision or with `"deny"` when the timeout elapses. The returned + * promise never rejects. + */ + wait(args: { + approvalId: string; + streamId: string; + userId: string; + timeoutMs: number; + }): Promise { + const { approvalId, streamId, userId, timeoutMs } = args; + return new Promise((resolve) => { + const timeout = setTimeout(() => { + if (this.pending.delete(approvalId)) { + resolve("deny"); + } + }, timeoutMs); + this.pending.set(approvalId, { + resolve, + userId, + streamId, + timeout, + }); + }); + } + + /** + * Settle an approval with a user decision. Returns: + * - `{ ok: true }` if the pending record existed, the userId matched, and + * the promise was resolved. + * - `{ ok: false, reason: "unknown" }` if no pending record matches the id. + * - `{ ok: false, reason: "forbidden" }` if the userId does not match the + * user who initiated the stream. + */ + submit(args: { + approvalId: string; + userId: string; + decision: ApprovalDecision; + }): ApprovalSubmitResult { + const { approvalId, userId, decision } = args; + const p = this.pending.get(approvalId); + if (!p) return { ok: false, reason: "unknown" }; + if (p.userId !== userId) return { ok: false, reason: "forbidden" }; + clearTimeout(p.timeout); + this.pending.delete(approvalId); + p.resolve(decision); + return { ok: true }; + } + + /** + * Cancel all pending gates for a specific stream (e.g., when the user + * cancels the stream). Each gate resolves with `"deny"` so the adapter + * unwinds cleanly. + */ + abortStream(streamId: string): void { + for (const [id, p] of this.pending) { + if (p.streamId === streamId) { + clearTimeout(p.timeout); + this.pending.delete(id); + p.resolve("deny"); + } + } + } + + /** Cancel every pending gate. Used at plugin shutdown. */ + abortAll(): void { + for (const [id, p] of this.pending) { + clearTimeout(p.timeout); + this.pending.delete(id); + p.resolve("deny"); + } + } + + /** Number of pending approvals (test/diagnostic helper). */ + get size(): number { + return this.pending.size; + } +} diff --git a/packages/appkit/src/plugins/agents/types.ts b/packages/appkit/src/plugins/agents/types.ts index 086a0426..e18cc8f4 100644 --- a/packages/appkit/src/plugins/agents/types.ts +++ b/packages/appkit/src/plugins/agents/types.ts @@ -1,6 +1,13 @@ -import type { AgentToolDefinition, ToolAnnotations } from "shared"; +import type { + AgentAdapter, + AgentToolDefinition, + BasePluginConfig, + ThreadStore, + ToolAnnotations, +} from "shared"; import type { FunctionTool } from "./tools/function-tool"; import type { HostedTool } from "./tools/hosted-tools"; +import type { McpHostPolicyConfig } from "./tools/mcp-host-policy"; /** * A tool reference produced by a plugin's `.toolkit()` call. The agents plugin @@ -42,8 +49,172 @@ export interface ToolkitOptions { } /** - * Type guard for `ToolkitEntry` — used to differentiate toolkit references - * from inline tools in a mixed `tools` record. + * Context passed to `baseSystemPrompt` callbacks. + */ +export interface PromptContext { + agentName: string; + pluginNames: string[]; + toolNames: string[]; +} + +export type BaseSystemPromptOption = + | false + | string + | ((ctx: PromptContext) => string); + +export interface AgentDefinition { + /** Filled in from the enclosing key when used in `agents: { foo: def }`. */ + name?: string; + /** System prompt body. For markdown-loaded agents this is the file body. */ + instructions: string; + /** + * Model adapter (or endpoint-name string sugar for + * `DatabricksAdapter.fromServingEndpoint({ endpointName })`). Optional — + * falls back to the plugin's `defaultModel`. + */ + model?: AgentAdapter | Promise | string; + /** Per-agent tool record. Key is the LLM-visible tool-call name. */ + tools?: Record; + /** Sub-agents, exposed as `agent-` tools on this agent. */ + agents?: Record; + /** Override the plugin's baseSystemPrompt for this agent only. */ + baseSystemPrompt?: BaseSystemPromptOption; + maxSteps?: number; + maxTokens?: number; + /** + * When true, the thread used for a chat request against this agent is + * deleted from `ThreadStore` after the stream completes (success or + * failure). Use for stateless one-shot agents — e.g. autocomplete, where + * each request is independent and retaining history would both poison + * future calls and accumulate unbounded state in the default + * `InMemoryThreadStore`. Defaults to `false`. + */ + ephemeral?: boolean; +} + +/** + * Auto-inherit configuration. When enabled for a given agent origin, agents + * with no explicit `tools:` declaration receive every registered ToolProvider + * plugin tool whose author marked `autoInheritable: true`. Tools without that + * flag — destructive, state-mutating, or privilege-sensitive — never spread + * automatically and must be wired via `tools:`, `toolkits:`, or `fromPlugin`. + * + * Defaults are `false` for both origins (safe-by-default): developers must + * consciously opt an origin in to any auto-inherit behaviour. + */ +export interface AutoInheritToolsConfig { + /** Default for agents loaded from markdown files. Default: `false`. */ + file?: boolean; + /** Default for code-defined agents (via `agents: { foo: createAgent(...) }`). Default: `false`. */ + code?: boolean; +} + +export interface AgentsPluginConfig extends BasePluginConfig { + /** Directory to scan for markdown agent files. Default `./config/agents`. Set to `false` to disable. */ + dir?: string | false; + /** Code-defined agents, merged with file-loaded ones (code wins on key collision). */ + agents?: Record; + /** Agent used when clients don't specify one. Defaults to the first-registered agent or the file with `default: true` frontmatter. */ + defaultAgent?: string; + /** Default model for agents that don't specify their own (in code or frontmatter). */ + defaultModel?: AgentAdapter | Promise | string; + /** Ambient tool library. Keys may be referenced by markdown frontmatter via `tools: [key1, key2]`. */ + tools?: Record; + /** Whether to auto-inherit every ToolProvider plugin's toolkit. Accepts a boolean shorthand. */ + autoInheritTools?: boolean | AutoInheritToolsConfig; + /** Persistent thread store. Default: in-memory. */ + threadStore?: ThreadStore; + /** Customize or disable the AppKit base system prompt. */ + baseSystemPrompt?: BaseSystemPromptOption; + /** + * MCP server host policy. By default only same-origin Databricks workspace + * URLs may be used as MCP endpoints; custom hosts must be explicitly + * allowlisted here. Workspace credentials (SP / OBO) are never forwarded + * to non-workspace hosts. + */ + mcp?: McpHostPolicyConfig; + /** + * Human-in-the-loop approval gate for destructive tool calls. When enabled + * (the default), the agents plugin emits an `appkit.approval_pending` SSE + * event before executing any tool annotated `destructive: true` and waits + * for a `POST /chat/approve` decision from the same user who initiated the + * stream. A missing decision after `timeoutMs` auto-denies the call. + */ + approval?: { + /** Require human approval for tools annotated `destructive: true`. Default: `true`. */ + requireForDestructive?: boolean; + /** Milliseconds to wait before auto-denying. Default: 60_000. */ + timeoutMs?: number; + }; + /** + * Runtime resource limits applied during agent execution. Defaults are + * tuned to protect a single-instance deployment from a misbehaving user or + * a runaway prompt injection; tighten or relax as appropriate for the + * deployment's scale and trust model. Request-body caps (chat message + * size, invocations input size / length) are enforced statically by the + * Zod schemas and are not configurable here. + */ + limits?: { + /** + * Max concurrent chat streams a single user may have open. Subsequent + * `POST /chat` requests from that user while at-limit are rejected with + * HTTP 429. Default: `5`. + */ + maxConcurrentStreamsPerUser?: number; + /** + * Max tool invocations per agent run (across the full tool-call graph, + * including sub-agent invocations). A run that exceeds the budget is + * aborted with a terminal error event. Default: `50`. + */ + maxToolCalls?: number; + /** + * Max sub-agent recursion depth. Protects against a prompt-injected + * agent that delegates to a sub-agent which in turn delegates back to + * itself (directly or transitively). Default: `3`. + */ + maxSubAgentDepth?: number; + }; +} + +/** Internal tool-index entry after a tool record has been resolved to a dispatchable form. */ +export type ResolvedToolEntry = + | { + source: "toolkit"; + pluginName: string; + localName: string; + def: AgentToolDefinition; + } + | { + source: "function"; + functionTool: FunctionTool; + def: AgentToolDefinition; + } + | { + source: "mcp"; + mcpToolName: string; + def: AgentToolDefinition; + } + | { + source: "subagent"; + agentName: string; + def: AgentToolDefinition; + }; + +export interface RegisteredAgent { + name: string; + instructions: string; + adapter: AgentAdapter; + toolIndex: Map; + baseSystemPrompt?: BaseSystemPromptOption; + maxSteps?: number; + maxTokens?: number; + /** Mirrors `AgentDefinition.ephemeral` — skip thread persistence. */ + ephemeral?: boolean; +} + +/** + * Type guard for `ToolkitEntry` — used by the agents plugin to differentiate + * toolkit references from inline tools in a mixed `tools` record. */ export function isToolkitEntry(value: unknown): value is ToolkitEntry { return ( diff --git a/packages/shared/src/agent.ts b/packages/shared/src/agent.ts index c4f76b29..8e34d5fb 100644 --- a/packages/shared/src/agent.ts +++ b/packages/shared/src/agent.ts @@ -86,7 +86,21 @@ export type AgentEvent = status: "running" | "waiting" | "complete" | "error"; error?: string; } - | { type: "metadata"; data: Record }; + | { type: "metadata"; data: Record } + | { + /** + * Emitted by the agents plugin (not adapters) when a tool call annotated + * `destructive: true` is awaiting human approval. Clients should render + * an approval prompt and POST to `/chat/approve` with the matching + * `approvalId` and a `decision` of `approve` or `deny`. + */ + type: "approval_pending"; + approvalId: string; + streamId: string; + toolName: string; + args: unknown; + annotations?: ToolAnnotations; + }; // --------------------------------------------------------------------------- // Responses API types (OpenAI-compatible wire format for HTTP boundary) @@ -178,6 +192,23 @@ export interface AppKitMetadataEvent { sequence_number: number; } +/** + * Emitted when a destructive tool call is awaiting human approval. The client + * should render an approval UI and POST the decision to `/chat/approve` with + * `{ streamId, approvalId, decision: "approve" | "deny" }`. If no decision + * arrives before the server-side timeout, the call is auto-denied and the + * agent receives a denial string as the tool output. + */ +export interface AppKitApprovalPendingEvent { + type: "appkit.approval_pending"; + approval_id: string; + stream_id: string; + tool_name: string; + args: unknown; + annotations?: ToolAnnotations; + sequence_number: number; +} + export type ResponseStreamEvent = | ResponseOutputItemAddedEvent | ResponseOutputItemDoneEvent @@ -186,7 +217,8 @@ export type ResponseStreamEvent = | ResponseErrorEvent | ResponseFailedEvent | AppKitThinkingEvent - | AppKitMetadataEvent; + | AppKitMetadataEvent + | AppKitApprovalPendingEvent; // --------------------------------------------------------------------------- // Adapter contract diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 16079b1d..307f44cf 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -305,6 +305,9 @@ importers: express: specifier: 4.22.0 version: 4.22.0 + js-yaml: + specifier: ^4.1.1 + version: 4.1.1 obug: specifier: 2.1.1 version: 2.1.1 @@ -339,6 +342,9 @@ importers: '@types/express': specifier: 4.17.25 version: 4.17.25 + '@types/js-yaml': + specifier: ^4.0.9 + version: 4.0.9 '@types/json-schema': specifier: 7.0.15 version: 7.0.15 @@ -4989,6 +4995,9 @@ packages: '@types/istanbul-reports@3.0.4': resolution: {integrity: sha512-pk2B1NWalF9toCRu6gjBzR69syFjP4Od8WRAX+0mmf9lAjCRicLOWc+ZrxZHx/0XRjotgkF9t6iaMJ+aXcOdZQ==} + '@types/js-yaml@4.0.9': + resolution: {integrity: sha512-k4MGaQl5TGo/iipqb2UDG2UwjXziSWkh0uysQelTlJpX1qGlpUZYm8PnO4DxG1qBomtJUdYJ6qR6xdIah10JLg==} + '@types/jsesc@2.5.1': resolution: {integrity: sha512-9VN+6yxLOPLOav+7PwjZbxiID2bVaeq0ED4qSQmdQTdjnXJSaCVKTR58t15oqH1H5t8Ng2ZX1SabJVoN9Q34bw==} @@ -17421,6 +17430,8 @@ snapshots: dependencies: '@types/istanbul-lib-report': 3.0.3 + '@types/js-yaml@4.0.9': {} + '@types/jsesc@2.5.1': {} '@types/json-schema@7.0.15': {} From ddde87b265804dbe37409dfc01d75a84962552bc Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Thu, 23 Apr 2026 17:59:39 +0200 Subject: [PATCH 15/23] refactor(appkit): generalize default base system prompt Tool-agnostic guidelines instead of SQL/files-specific defaults; accept full PromptContext in buildBaseSystemPrompt for parity with custom callbacks. Signed-off-by: MarioCadenas --- packages/appkit/src/plugins/agents/agents.ts | 2 +- .../src/plugins/agents/system-prompt.ts | 32 +++++++++++++------ .../agents/tests/system-prompt.test.ts | 30 ++++++++++++----- 3 files changed, 45 insertions(+), 19 deletions(-) diff --git a/packages/appkit/src/plugins/agents/agents.ts b/packages/appkit/src/plugins/agents/agents.ts index 07a637fd..98b49f20 100644 --- a/packages/appkit/src/plugins/agents/agents.ts +++ b/packages/appkit/src/plugins/agents/agents.ts @@ -1246,7 +1246,7 @@ function composePromptForAgent( } else if (typeof resolved === "function") { base = resolved(ctx); } else { - base = buildBaseSystemPrompt(ctx.pluginNames); + base = buildBaseSystemPrompt(ctx); } return composeSystemPrompt(base, registered.instructions); diff --git a/packages/appkit/src/plugins/agents/system-prompt.ts b/packages/appkit/src/plugins/agents/system-prompt.ts index 634f49c5..01f3fe9b 100644 --- a/packages/appkit/src/plugins/agents/system-prompt.ts +++ b/packages/appkit/src/plugins/agents/system-prompt.ts @@ -1,28 +1,40 @@ +import type { PromptContext } from "./types"; + /** - * Builds the AppKit base system prompt from active plugin names. + * Default base system prompt: product identity, active AppKit plugins, and + * tool-agnostic behavior hints. * - * The base prompt provides guidelines and app context. It does NOT - * include individual tool descriptions — those are sent via the - * structured `tools` API parameter to the LLM. + * Individual tool definitions and JSON Schemas are still sent through the + * model's `tools` / function-calling channel — this string is not a second + * copy of that list. `ctx.toolNames` is available for custom + * `baseSystemPrompt` callbacks; the default text stays short and does not + * enumerate tools to avoid drift and token bloat. */ -export function buildBaseSystemPrompt(pluginNames: string[]): string { +export function buildBaseSystemPrompt(ctx: PromptContext): string { + const { pluginNames } = ctx; const lines: string[] = [ "You are an AI assistant running on Databricks AppKit.", ]; if (pluginNames.length > 0) { lines.push(""); - lines.push(`Active plugins: ${pluginNames.join(", ")}`); + lines.push(`Active AppKit plugins: ${pluginNames.join(", ")}`); } lines.push(""); lines.push("Guidelines:"); - lines.push("- Use Databricks SQL syntax when writing queries"); lines.push( - "- When results are large, summarize key findings rather than dumping raw data", + "- Be concise: for large or noisy tool output, summarize what matters and how to go deeper instead of pasting everything.", + ); + lines.push( + "- Use each tool as defined: pass required arguments and use the syntax, dialect, or path rules the target system expects (see each tool’s description and schema).", + ); + lines.push( + "- If a tool call fails, explain the error in plain language and suggest a fix or next step.", + ); + lines.push( + "- Respect tool metadata and app policy: read-only vs destructive tools, user/identity context, and any approval or safety flows the app provides.", ); - lines.push("- If a tool call fails, explain the error clearly to the user"); - lines.push("- When browsing files, verify the path exists before reading"); return lines.join("\n"); } diff --git a/packages/appkit/src/plugins/agents/tests/system-prompt.test.ts b/packages/appkit/src/plugins/agents/tests/system-prompt.test.ts index 83bf8e19..25724259 100644 --- a/packages/appkit/src/plugins/agents/tests/system-prompt.test.ts +++ b/packages/appkit/src/plugins/agents/tests/system-prompt.test.ts @@ -1,27 +1,41 @@ import { describe, expect, test } from "vitest"; import { buildBaseSystemPrompt, composeSystemPrompt } from "../system-prompt"; +const emptyCtx = { + agentName: "a", + pluginNames: [] as string[], + toolNames: [] as string[], +}; + describe("buildBaseSystemPrompt", () => { test("includes plugin names", () => { - const prompt = buildBaseSystemPrompt(["analytics", "files", "genie"]); - expect(prompt).toContain("Active plugins: analytics, files, genie"); + const prompt = buildBaseSystemPrompt({ + agentName: "assistant", + pluginNames: ["analytics", "files", "genie"], + toolNames: [], + }); + expect(prompt).toContain("Active AppKit plugins: analytics, files, genie"); }); test("includes guidelines", () => { - const prompt = buildBaseSystemPrompt([]); + const prompt = buildBaseSystemPrompt(emptyCtx); expect(prompt).toContain("Guidelines:"); - expect(prompt).toContain("Databricks SQL"); - expect(prompt).toContain("summarize key findings"); + expect(prompt).toContain("syntax, dialect, or path rules"); + expect(prompt).toContain("summarize what matters"); }); test("works with no plugins", () => { - const prompt = buildBaseSystemPrompt([]); + const prompt = buildBaseSystemPrompt(emptyCtx); expect(prompt).toContain("AI assistant running on Databricks AppKit"); - expect(prompt).not.toContain("Active plugins:"); + expect(prompt).not.toContain("Active AppKit plugins:"); }); test("does NOT include individual tool names", () => { - const prompt = buildBaseSystemPrompt(["analytics"]); + const prompt = buildBaseSystemPrompt({ + agentName: "a", + pluginNames: ["analytics"], + toolNames: ["analytics.query"], + }); expect(prompt).not.toContain("analytics.query"); expect(prompt).not.toContain("Available tools:"); }); From de8b72c6deb1929c082f983dfaa2a9b0001f5028 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Thu, 23 Apr 2026 18:01:42 +0200 Subject: [PATCH 16/23] feat(appkit): optional serving_endpoint on agents plugin manifest Register DATABRICKS_SERVING_ENDPOINT_NAME as optional CAN_QUERY so apps using Databricks-hosted agent models get resource wiring; optional when agents use only external adapters. Sync template/appkit.plugins.json. Signed-off-by: MarioCadenas --- .../appkit/src/plugins/agents/manifest.json | 16 ++++++++++++- template/appkit.plugins.json | 24 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/packages/appkit/src/plugins/agents/manifest.json b/packages/appkit/src/plugins/agents/manifest.json index cb7a43f8..f2766b80 100644 --- a/packages/appkit/src/plugins/agents/manifest.json +++ b/packages/appkit/src/plugins/agents/manifest.json @@ -5,6 +5,20 @@ "description": "AI agents driven by markdown configs or code, with auto-tool-discovery from registered plugins", "resources": { "required": [], - "optional": [] + "optional": [ + { + "type": "serving_endpoint", + "alias": "Model Serving (agents)", + "resourceKey": "agents-serving-endpoint", + "description": "Databricks Model Serving endpoint for agent runs that use workspace-hosted models (DatabricksAdapter, optional default serving endpoint env, or markdown configs that resolve to serving). Omit when agents rely only on external/custom model adapters.", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_SERVING_ENDPOINT_NAME", + "description": "Serving endpoint name used for agent LLM inference when configured for Databricks Model Serving" + } + } + } + ] } } diff --git a/template/appkit.plugins.json b/template/appkit.plugins.json index d1420d2e..acbcb9d0 100644 --- a/template/appkit.plugins.json +++ b/template/appkit.plugins.json @@ -2,6 +2,30 @@ "$schema": "https://databricks.github.io/appkit/schemas/template-plugins.schema.json", "version": "1.0", "plugins": { + "agents": { + "name": "agents", + "displayName": "Agents Plugin", + "description": "AI agents driven by markdown configs or code, with auto-tool-discovery from registered plugins", + "package": "@databricks/appkit", + "resources": { + "required": [], + "optional": [ + { + "type": "serving_endpoint", + "alias": "Model Serving (agents)", + "resourceKey": "agents-serving-endpoint", + "description": "Databricks Model Serving endpoint for agent runs that use workspace-hosted models (DatabricksAdapter, optional default serving endpoint env, or markdown configs that resolve to serving). Omit when agents rely only on external/custom model adapters.", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_SERVING_ENDPOINT_NAME", + "description": "Serving endpoint name used for agent LLM inference when configured for Databricks Model Serving" + } + } + } + ] + } + }, "analytics": { "name": "analytics", "displayName": "Analytics Plugin", From b2a7e95aff523232810f527ea822a3b7e3aa2cf4 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Thu, 23 Apr 2026 18:03:21 +0200 Subject: [PATCH 17/23] fix(appkit): agents manifest uses DATABRICKS_AGENT_ENDPOINT MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Align optional serving resource with `DatabricksAdapter.fromModelServing()`, which reads `DATABRICKS_AGENT_ENDPOINT` — not `DATABRICKS_SERVING_ENDPOINT_NAME` (serving plugin). Sync template. Signed-off-by: MarioCadenas --- packages/appkit/src/plugins/agents/manifest.json | 6 +++--- template/appkit.plugins.json | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/appkit/src/plugins/agents/manifest.json b/packages/appkit/src/plugins/agents/manifest.json index f2766b80..f3986c83 100644 --- a/packages/appkit/src/plugins/agents/manifest.json +++ b/packages/appkit/src/plugins/agents/manifest.json @@ -10,12 +10,12 @@ "type": "serving_endpoint", "alias": "Model Serving (agents)", "resourceKey": "agents-serving-endpoint", - "description": "Databricks Model Serving endpoint for agent runs that use workspace-hosted models (DatabricksAdapter, optional default serving endpoint env, or markdown configs that resolve to serving). Omit when agents rely only on external/custom model adapters.", + "description": "Databricks Model Serving endpoint for agents using workspace-hosted models (`DatabricksAdapter.fromModelServing`). Wire the same endpoint name AppKit reads from `DATABRICKS_AGENT_ENDPOINT` when no per-agent model is configured. Omit when agents use only external adapters.", "permission": "CAN_QUERY", "fields": { "name": { - "env": "DATABRICKS_SERVING_ENDPOINT_NAME", - "description": "Serving endpoint name used for agent LLM inference when configured for Databricks Model Serving" + "env": "DATABRICKS_AGENT_ENDPOINT", + "description": "Endpoint name passed to Model Serving when agents default to `DatabricksAdapter.fromModelServing()`" } } } diff --git a/template/appkit.plugins.json b/template/appkit.plugins.json index acbcb9d0..c8589f9a 100644 --- a/template/appkit.plugins.json +++ b/template/appkit.plugins.json @@ -14,12 +14,12 @@ "type": "serving_endpoint", "alias": "Model Serving (agents)", "resourceKey": "agents-serving-endpoint", - "description": "Databricks Model Serving endpoint for agent runs that use workspace-hosted models (DatabricksAdapter, optional default serving endpoint env, or markdown configs that resolve to serving). Omit when agents rely only on external/custom model adapters.", + "description": "Databricks Model Serving endpoint for agents using workspace-hosted models (`DatabricksAdapter.fromModelServing`). Wire the same endpoint name AppKit reads from `DATABRICKS_AGENT_ENDPOINT` when no per-agent model is configured. Omit when agents use only external adapters.", "permission": "CAN_QUERY", "fields": { "name": { - "env": "DATABRICKS_SERVING_ENDPOINT_NAME", - "description": "Serving endpoint name used for agent LLM inference when configured for Databricks Model Serving" + "env": "DATABRICKS_AGENT_ENDPOINT", + "description": "Endpoint name passed to Model Serving when agents default to `DatabricksAdapter.fromModelServing()`" } } } From c4b71349c92892df8e93690f51109798ee52207b Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Thu, 23 Apr 2026 18:20:05 +0200 Subject: [PATCH 18/23] feat(agents): folder-based markdown discovery (/agent.md) BREAKING CHANGE: top-level config/agents/*.md is no longer loaded. Use /agent.md. The skills directory name is reserved and skipped. Orphan top-level .md files error at load; subdirs without agent.md error. Export agentIdFromMarkdownPath for path-based id resolution. --- packages/appkit/src/index.ts | 1 + packages/appkit/src/plugins/agents/agents.ts | 2 +- packages/appkit/src/plugins/agents/index.ts | 1 + .../appkit/src/plugins/agents/load-agents.ts | 109 ++++++++----- .../agents/tests/agents-plugin.test.ts | 36 ++--- .../plugins/agents/tests/load-agents.test.ts | 146 ++++++++++++------ packages/appkit/src/plugins/agents/types.ts | 2 +- 7 files changed, 194 insertions(+), 103 deletions(-) diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 23f72216..ba4110b3 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -73,6 +73,7 @@ export { type AgentDefinition, type AgentsPluginConfig, type AgentTool, + agentIdFromMarkdownPath, agents, type BaseSystemPromptOption, isToolkitEntry, diff --git a/packages/appkit/src/plugins/agents/agents.ts b/packages/appkit/src/plugins/agents/agents.ts index 98b49f20..f1938ef2 100644 --- a/packages/appkit/src/plugins/agents/agents.ts +++ b/packages/appkit/src/plugins/agents/agents.ts @@ -1253,7 +1253,7 @@ function composePromptForAgent( } /** - * Plugin factory for the agents plugin. Reads `config/agents/*.md` by default, + * Plugin factory for the agents plugin. Reads `config/agents//agent.md` by default, * resolves toolkits/tools from registered plugins, exposes `appkit.agents.*` * runtime API and mounts `/invocations`. * diff --git a/packages/appkit/src/plugins/agents/index.ts b/packages/appkit/src/plugins/agents/index.ts index 1adc41c1..377a8776 100644 --- a/packages/appkit/src/plugins/agents/index.ts +++ b/packages/appkit/src/plugins/agents/index.ts @@ -1,6 +1,7 @@ export { AgentsPlugin, agents } from "./agents"; export { buildToolkitEntries } from "./build-toolkit"; export { + agentIdFromMarkdownPath, type LoadContext, type LoadResult, loadAgentFromFile, diff --git a/packages/appkit/src/plugins/agents/load-agents.ts b/packages/appkit/src/plugins/agents/load-agents.ts index 2d82799e..5b2999ca 100644 --- a/packages/appkit/src/plugins/agents/load-agents.ts +++ b/packages/appkit/src/plugins/agents/load-agents.ts @@ -36,9 +36,9 @@ export interface LoadContext { } export interface LoadResult { - /** Agent definitions keyed by file-stem name. */ + /** Agent definitions keyed by agent id (directory name under `dir`). */ defs: Record; - /** First file with `default: true` frontmatter, or `null`. */ + /** First agent with `default: true` frontmatter (sorted id order), or `null`. */ defaultAgent: string | null; } @@ -48,11 +48,10 @@ interface Frontmatter { toolkits?: ToolkitSpec[]; tools?: string[]; /** - * Sibling file-stems to expose as sub-agents. Each becomes an - * `agent-` tool on this agent at runtime. Resolution happens at - * directory-load time in {@link loadAgentsFromDir}; the single-file - * {@link loadAgentFromFile} path rejects non-empty values since there - * are no siblings to resolve against. + * Other agent ids to expose as sub-agents. Each becomes an `agent-` + * tool at runtime. Resolution happens at directory-load time in + * {@link loadAgentsFromDir}; the single-file {@link loadAgentFromFile} path + * rejects non-empty values since there are no siblings to resolve against. */ agents?: string[]; maxSteps?: number; @@ -64,6 +63,21 @@ interface Frontmatter { type ToolkitSpec = string | { [pluginName: string]: ToolkitOptions | string[] }; +/** + * Derives the logical agent id from a markdown path. When the file is named + * `agent.md`, the id is the parent directory name (folder-based layout); + * otherwise the id is the file stem (e.g. legacy single-file paths). + */ +export function agentIdFromMarkdownPath(filePath: string): string { + const normalized = path.normalize(filePath); + const base = path.basename(normalized); + const parent = path.basename(path.dirname(normalized)); + if (base === "agent.md" && parent && parent !== "." && parent !== "..") { + return parent; + } + return path.basename(normalized, ".md"); +} + const ALLOWED_KEYS = new Set([ "endpoint", "model", @@ -90,7 +104,7 @@ export async function loadAgentFromFile( ctx: LoadContext, ): Promise { const raw = fs.readFileSync(filePath, "utf-8"); - const name = path.basename(filePath, ".md"); + const name = agentIdFromMarkdownPath(filePath); const { data } = parseFrontmatter(raw, filePath); if (Array.isArray(data?.agents) && data.agents.length > 0) { throw new Error( @@ -103,15 +117,19 @@ export async function loadAgentFromFile( } /** - * Scans a directory for `*.md` files and produces an `AgentDefinition` record - * keyed by file-stem. Throws on frontmatter errors or unresolved references. - * Returns an empty map if the directory does not exist. + * Scans a directory for one subdirectory per agent, each containing + * `agent.md` (frontmatter + body). Produces an `AgentDefinition` record keyed + * by agent id (folder name). Throws on frontmatter errors or unresolved + * references. Returns an empty map if the directory does not exist. + * + * Legacy top-level `*.md` files are rejected with an error — migrate each to + * `/agent.md` under a sibling folder named for the agent id. * * Runs in two passes so sub-agent references in frontmatter (`agents: [...]`) - * can be resolved regardless of file-system iteration order: + * can be resolved regardless of directory iteration order: * - * 1. Build every agent's definition from its own file. - * 2. Walk `agents:` references and wire `def.agents = { sibling: siblingDef }` + * 1. Build every agent's definition from its own `agent.md`. + * 2. Walk `agents:` references and wire `def.agents = { child: childDef }` * by looking them up in the complete map. Dangling names and * self-references fail loudly; mutual delegation is allowed and bounded * at runtime by `limits.maxSubAgentDepth`. @@ -123,37 +141,56 @@ export async function loadAgentsFromDir( if (!fs.existsSync(dir)) { return { defs: {}, defaultAgent: null }; } - // Sort so `default: true` resolution is deterministic across platforms — - // `readdirSync` order is filesystem-dependent (macOS alphabetical, ext4 - // inode order, etc.). - const files = fs - .readdirSync(dir) - .filter((f) => f.endsWith(".md")) + + const entries = fs.readdirSync(dir, { withFileTypes: true }); + const orphanMd = entries + .filter((e) => e.isFile() && e.name.endsWith(".md")) + .map((e) => e.name) .sort(); + + if (orphanMd.length > 0) { + const hint = orphanMd + .map((f) => `${path.basename(f, ".md")}/agent.md`) + .join(", "); + throw new Error( + `Agents directory contains unsupported top-level markdown file(s): ${orphanMd.join(", ")}. ` + + `Use one folder per agent with a fixed entry file, e.g. ${hint}.`, + ); + } + + /** Reserved folder name until per-agent skills land; not an agent package. */ + const RESERVED_DIRS = new Set(["skills"]); + + const agentIds = entries + .filter((e) => e.isDirectory()) + .map((e) => e.name) + .filter((name) => !RESERVED_DIRS.has(name)) + .sort(); + const defs: Record = {}; const subAgentRefs: Record = {}; let defaultAgent: string | null = null; - // Pass 1: build every agent's definition; collect unresolved sibling refs. - for (const file of files) { - const fullPath = path.join(dir, file); - const raw = fs.readFileSync(fullPath, "utf-8"); - const name = path.basename(file, ".md"); - defs[name] = buildDefinition(name, raw, fullPath, ctx); - const { data } = parseFrontmatter(raw, fullPath); - if (data?.agents !== undefined) { - subAgentRefs[name] = normalizeAgentsFrontmatter( - data.agents, - name, - fullPath, + // Pass 1: build every agent's definition; collect sub-agent refs. + for (const id of agentIds) { + const agentPath = path.join(dir, id, "agent.md"); + if (!fs.existsSync(agentPath)) { + throw new Error( + `Agents subdirectory '${path.join(dir, id)}' must contain agent.md.`, ); } + const raw = fs.readFileSync(agentPath, "utf-8"); + defs[id] = buildDefinition(id, raw, agentPath, ctx); + const { data } = parseFrontmatter(raw, agentPath); + if (data?.agents !== undefined) { + subAgentRefs[id] = normalizeAgentsFrontmatter(data.agents, id, agentPath); + } if (data?.default === true && !defaultAgent) { - defaultAgent = name; + defaultAgent = id; } } - // Pass 2: resolve sibling references against the complete defs map. + // Pass 2: resolve sub-agent references against the complete defs map. // Code-defined agents (ctx.codeAgents) take precedence over markdown ones // with the same name, matching the plugin's top-level merge behaviour. for (const [name, refs] of Object.entries(subAgentRefs)) { @@ -163,7 +200,7 @@ export async function loadAgentsFromDir( for (const ref of refs) { if (ref === name) { throw new Error( - `Agent '${name}' (${path.join(dir, `${name}.md`)}) cannot reference itself in 'agents:'.`, + `Agent '${name}' (${path.join(dir, name, "agent.md")}) cannot reference itself in 'agents:'.`, ); } const sibling = ctx.codeAgents?.[ref] ?? defs[ref]; @@ -203,7 +240,7 @@ function normalizeAgentsFrontmatter( if (!Array.isArray(value)) { throw new Error( `Agent '${agentName}' (${filePath}) has invalid 'agents:' frontmatter: ` + - `expected an array of sibling file-stems, got ${typeof value}.`, + `expected an array of sibling agent ids, got ${typeof value}.`, ); } const out: string[] = []; diff --git a/packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts b/packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts index d9abb925..747ada48 100644 --- a/packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts +++ b/packages/appkit/src/plugins/agents/tests/agents-plugin.test.ts @@ -102,6 +102,12 @@ function instantiate(config: AgentsPluginConfig, ctx?: FakeContext) { return plugin; } +function writeMarkdownAgent(dir: string, id: string, content: string) { + const folder = path.join(dir, id); + fs.mkdirSync(folder, { recursive: true }); + fs.writeFileSync(path.join(folder, "agent.md"), content, "utf-8"); +} + describe("AgentsPlugin", () => { test("registers code-defined agents and exposes them via exports", async () => { const plugin = instantiate({ @@ -124,10 +130,10 @@ describe("AgentsPlugin", () => { }); test("loads markdown agents from a directory", async () => { - fs.writeFileSync( - path.join(tmpDir, "assistant.md"), + writeMarkdownAgent( + tmpDir, + "assistant", "---\ndefault: true\n---\nYou are helpful.", - "utf-8", ); const plugin = instantiate({ dir: tmpDir, @@ -144,11 +150,7 @@ describe("AgentsPlugin", () => { }); test("code definitions override markdown on key collision", async () => { - fs.writeFileSync( - path.join(tmpDir, "support.md"), - "---\n---\nFrom markdown.", - "utf-8", - ); + writeMarkdownAgent(tmpDir, "support", "---\n---\nFrom markdown."); const plugin = instantiate({ dir: tmpDir, defaultModel: stubAdapter(), @@ -179,11 +181,7 @@ describe("AgentsPlugin", () => { const provider = makeToolProvider("analytics", registry); const ctx = fakeContext([{ name: "analytics", provider }]); - fs.writeFileSync( - path.join(tmpDir, "assistant.md"), - "---\n---\nYou are helpful.", - "utf-8", - ); + writeMarkdownAgent(tmpDir, "assistant", "---\n---\nYou are helpful."); const plugin = instantiate( { @@ -228,11 +226,7 @@ describe("AgentsPlugin", () => { const provider = makeToolProvider("analytics", registry); const ctx = fakeContext([{ name: "analytics", provider }]); - fs.writeFileSync( - path.join(tmpDir, "assistant.md"), - "---\n---\nYou are helpful.", - "utf-8", - ); + writeMarkdownAgent(tmpDir, "assistant", "---\n---\nYou are helpful."); const plugin = instantiate( { @@ -313,10 +307,10 @@ describe("AgentsPlugin", () => { { name: "files", provider: makeToolProvider("files", registry2) }, ]); - fs.writeFileSync( - path.join(tmpDir, "analyst.md"), + writeMarkdownAgent( + tmpDir, + "analyst", "---\ntoolkits: [analytics]\n---\nAnalyst.", - "utf-8", ); const plugin = instantiate( diff --git a/packages/appkit/src/plugins/agents/tests/load-agents.test.ts b/packages/appkit/src/plugins/agents/tests/load-agents.test.ts index 3410f566..1a5b9523 100644 --- a/packages/appkit/src/plugins/agents/tests/load-agents.test.ts +++ b/packages/appkit/src/plugins/agents/tests/load-agents.test.ts @@ -5,6 +5,7 @@ import { afterEach, beforeEach, describe, expect, test } from "vitest"; import { z } from "zod"; import { buildToolkitEntries } from "../build-toolkit"; import { + agentIdFromMarkdownPath, loadAgentFromFile, loadAgentsFromDir, parseFrontmatter, @@ -23,11 +24,33 @@ afterEach(() => { fs.rmSync(workDir, { recursive: true, force: true }); }); -function write(name: string, content: string) { +/** Flat file under workDir (for legacy loadAgentFromFile tests). */ +function writeRoot(name: string, content: string) { fs.writeFileSync(path.join(workDir, name), content, "utf-8"); return path.join(workDir, name); } +/** Folder layout: `/agent.md`. */ +function writeAgent(id: string, content: string) { + const dir = path.join(workDir, id); + fs.mkdirSync(dir, { recursive: true }); + const p = path.join(dir, "agent.md"); + fs.writeFileSync(p, content, "utf-8"); + return p; +} + +describe("agentIdFromMarkdownPath", () => { + test("uses parent folder name when file is agent.md", () => { + expect(agentIdFromMarkdownPath("/foo/bar/assistant/agent.md")).toBe( + "assistant", + ); + }); + + test("uses file stem for other .md names", () => { + expect(agentIdFromMarkdownPath("/tmp/assistant.md")).toBe("assistant"); + }); +}); + describe("parseFrontmatter", () => { test("parses a simple object", () => { const { data, content } = parseFrontmatter( @@ -57,7 +80,7 @@ describe("parseFrontmatter", () => { describe("loadAgentFromFile", () => { test("returns AgentDefinition with body as instructions", async () => { - const p = write( + const p = writeRoot( "assistant.md", "---\nendpoint: e-1\n---\nYou are helpful.", ); @@ -66,6 +89,13 @@ describe("loadAgentFromFile", () => { expect(def.instructions).toBe("You are helpful."); expect(def.model).toBe("e-1"); }); + + test("derives agent id from folder when path ends with agent.md", async () => { + const p = writeAgent("router", "---\nendpoint: e-1\n---\nRoute traffic."); + const def = await loadAgentFromFile(p, {}); + expect(def.name).toBe("router"); + expect(def.instructions).toBe("Route traffic."); + }); }); describe("loadAgentsFromDir", () => { @@ -75,23 +105,44 @@ describe("loadAgentsFromDir", () => { expect(res.defaultAgent).toBeNull(); }); - test("loads all .md files keyed by file-stem", async () => { - write("support.md", "---\nendpoint: e-1\n---\nSupport prompt."); - write("sales.md", "---\nendpoint: e-2\n---\nSales prompt."); + test("loads each subdirectory with agent.md keyed by folder name", async () => { + writeAgent("support", "---\nendpoint: e-1\n---\nSupport prompt."); + writeAgent("sales", "---\nendpoint: e-2\n---\nSales prompt."); const res = await loadAgentsFromDir(workDir, {}); expect(Object.keys(res.defs).sort()).toEqual(["sales", "support"]); }); - test("picks up default: true from frontmatter", async () => { - write("one.md", "---\nendpoint: a\n---\nOne."); - write("two.md", "---\nendpoint: b\ndefault: true\n---\nTwo."); + test("throws when legacy top-level .md exists", async () => { + writeRoot("assistant.md", "---\nendpoint: e\n---\nLegacy flat file."); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /unsupported top-level markdown file\(s\): assistant\.md.*assistant\/agent\.md/s, + ); + }); + + test("throws when a subdirectory lacks agent.md", async () => { + fs.mkdirSync(path.join(workDir, "broken"), { recursive: true }); + await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( + /must contain agent\.md/, + ); + }); + + test("ignores reserved skills directory without agent.md", async () => { + fs.mkdirSync(path.join(workDir, "skills"), { recursive: true }); + writeAgent("solo", "---\nendpoint: e\n---\nOnly real agent."); + const res = await loadAgentsFromDir(workDir, {}); + expect(Object.keys(res.defs)).toEqual(["solo"]); + }); + + test("picks up default: true from frontmatter (deterministic sorted ids)", async () => { + writeAgent("one", "---\nendpoint: a\n---\nOne."); + writeAgent("two", "---\nendpoint: b\ndefault: true\n---\nTwo."); const res = await loadAgentsFromDir(workDir, {}); expect(res.defaultAgent).toBe("two"); }); test("throws when frontmatter references an unregistered plugin", async () => { - write( - "broken.md", + writeAgent( + "broken", "---\nendpoint: e\ntoolkits: [missing]\n---\nBroken agent.", ); await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( @@ -100,7 +151,10 @@ describe("loadAgentsFromDir", () => { }); test("throws when frontmatter references an unknown ambient tool", async () => { - write("broken.md", "---\nendpoint: e\ntools: [unknown_tool]\n---\nBroken."); + writeAgent( + "broken", + "---\nendpoint: e\ntools: [unknown_tool]\n---\nBroken.", + ); await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( /references tool 'unknown_tool'/, ); @@ -134,8 +188,8 @@ describe("loadAgentsFromDir", () => { execute: async () => "sunny", }); - write( - "analyst.md", + writeAgent( + "analyst", "---\nendpoint: e\ntoolkits:\n - analytics\ntools:\n - get_weather\n---\nBody.", ); const res = await loadAgentsFromDir(workDir, { @@ -150,15 +204,13 @@ describe("loadAgentsFromDir", () => { }); describe("agents: sibling sub-agent references", () => { - test("resolves sibling references into def.agents regardless of file order", async () => { - // Names chosen so alphabetical iteration puts `dispatcher` *before* - // its siblings — pass-1 populates defs in any order, pass-2 resolves. - write( - "dispatcher.md", + test("resolves sibling references into def.agents regardless of folder order", async () => { + writeAgent( + "dispatcher", "---\nendpoint: e\nagents:\n - analyst\n - writer\n---\nRoute work.", ); - write("analyst.md", "---\nendpoint: e\n---\nAnalyst."); - write("writer.md", "---\nendpoint: e\n---\nWriter."); + writeAgent("analyst", "---\nendpoint: e\n---\nAnalyst."); + writeAgent("writer", "---\nendpoint: e\n---\nWriter."); const res = await loadAgentsFromDir(workDir, {}); expect(Object.keys(res.defs.dispatcher.agents ?? {}).sort()).toEqual([ @@ -167,14 +219,13 @@ describe("loadAgentsFromDir", () => { ]); expect(res.defs.dispatcher.agents?.analyst).toBe(res.defs.analyst); expect(res.defs.dispatcher.agents?.writer).toBe(res.defs.writer); - // Leaves with no `agents:` retain undefined — only declared keys wire. expect(res.defs.analyst.agents).toBeUndefined(); expect(res.defs.writer.agents).toBeUndefined(); }); test("mutual delegation is allowed (runtime depth cap handles cycles)", async () => { - write("a.md", "---\nendpoint: e\nagents:\n - b\n---\nA."); - write("b.md", "---\nendpoint: e\nagents:\n - a\n---\nB."); + writeAgent("a", "---\nendpoint: e\nagents:\n - b\n---\nA."); + writeAgent("b", "---\nendpoint: e\nagents:\n - a\n---\nB."); const res = await loadAgentsFromDir(workDir, {}); expect(res.defs.a.agents?.b).toBe(res.defs.b); @@ -182,16 +233,16 @@ describe("loadAgentsFromDir", () => { }); test("throws with available list when a sibling is missing", async () => { - write("dispatcher.md", "---\nendpoint: e\nagents:\n - ghost\n---\nD."); - write("analyst.md", "---\nendpoint: e\n---\nAnalyst."); + writeAgent("dispatcher", "---\nendpoint: e\nagents:\n - ghost\n---\nD."); + writeAgent("analyst", "---\nendpoint: e\n---\nAnalyst."); await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( /references sub-agent\(s\) 'ghost'.*Available: analyst, dispatcher/s, ); }); test("reports every missing sibling in one error, not just the first", async () => { - write( - "dispatcher.md", + writeAgent( + "dispatcher", "---\nendpoint: e\nagents:\n - ghost1\n - ghost2\n---\nD.", ); await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( @@ -200,33 +251,33 @@ describe("loadAgentsFromDir", () => { }); test("throws on self-reference", async () => { - write("solo.md", "---\nendpoint: e\nagents:\n - solo\n---\nSolo."); + writeAgent("solo", "---\nendpoint: e\nagents:\n - solo\n---\nSolo."); await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( /'solo'.*cannot reference itself/s, ); }); test("throws on non-array 'agents:' value", async () => { - write("bad.md", "---\nendpoint: e\nagents: analyst\n---\nBad."); - write("analyst.md", "---\nendpoint: e\n---\nAnalyst."); + writeAgent("bad", "---\nendpoint: e\nagents: analyst\n---\nBad."); + writeAgent("analyst", "---\nendpoint: e\n---\nAnalyst."); await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( /invalid 'agents:' frontmatter/, ); }); test("throws on non-string entries in 'agents:'", async () => { - write("bad.md", "---\nendpoint: e\nagents:\n - 42\n---\nBad."); + writeAgent("bad", "---\nendpoint: e\nagents:\n - 42\n---\nBad."); await expect(loadAgentsFromDir(workDir, {})).rejects.toThrow( /invalid 'agents:' entry/, ); }); test("deduplicates repeated entries silently", async () => { - write( - "dispatcher.md", + writeAgent( + "dispatcher", "---\nendpoint: e\nagents:\n - analyst\n - analyst\n---\nD.", ); - write("analyst.md", "---\nendpoint: e\n---\nAnalyst."); + writeAgent("analyst", "---\nendpoint: e\n---\nAnalyst."); const res = await loadAgentsFromDir(workDir, {}); expect(Object.keys(res.defs.dispatcher.agents ?? {})).toEqual([ "analyst", @@ -234,13 +285,16 @@ describe("loadAgentsFromDir", () => { }); test("empty array yields no sub-agents (no-op)", async () => { - write("dispatcher.md", "---\nendpoint: e\nagents: []\n---\nD."); + writeAgent("dispatcher", "---\nendpoint: e\nagents: []\n---\nD."); const res = await loadAgentsFromDir(workDir, {}); expect(res.defs.dispatcher.agents).toBeUndefined(); }); test("resolves 'agents:' references against codeAgents when provided", async () => { - write("dispatcher.md", "---\nendpoint: e\nagents:\n - support\n---\nD."); + writeAgent( + "dispatcher", + "---\nendpoint: e\nagents:\n - support\n---\nD.", + ); const support: AgentDefinition = { name: "support", instructions: "Code-defined support.", @@ -252,8 +306,11 @@ describe("loadAgentsFromDir", () => { }); test("codeAgents takes precedence over markdown sibling with the same name", async () => { - write("dispatcher.md", "---\nendpoint: e\nagents:\n - support\n---\nD."); - write("support.md", "---\nendpoint: e\n---\nMarkdown support."); + writeAgent( + "dispatcher", + "---\nendpoint: e\nagents:\n - support\n---\nD.", + ); + writeAgent("support", "---\nendpoint: e\n---\nMarkdown support."); const codeSupport: AgentDefinition = { name: "support", instructions: "Code support.", @@ -261,8 +318,6 @@ describe("loadAgentsFromDir", () => { const res = await loadAgentsFromDir(workDir, { codeAgents: { support: codeSupport }, }); - // Reference binds to code version, matching the plugin's top-level - // `code wins` merge behaviour. expect(res.defs.dispatcher.agents?.support).toBe(codeSupport); expect(res.defs.dispatcher.agents?.support.instructions).toBe( "Code support.", @@ -270,8 +325,8 @@ describe("loadAgentsFromDir", () => { }); test("missing-sibling error lists both markdown and code agent names", async () => { - write("dispatcher.md", "---\nendpoint: e\nagents:\n - ghost\n---\nD."); - write("analyst.md", "---\nendpoint: e\n---\nAnalyst."); + writeAgent("dispatcher", "---\nendpoint: e\nagents:\n - ghost\n---\nD."); + writeAgent("analyst", "---\nendpoint: e\n---\nAnalyst."); const codeAgent: AgentDefinition = { name: "writer", instructions: "Writer.", @@ -285,7 +340,7 @@ describe("loadAgentsFromDir", () => { describe("loadAgentFromFile — sub-agent refs rejected", () => { test("throws when 'agents:' is non-empty in a single-file load", async () => { - const p = write( + const p = writeRoot( "lonely.md", "---\nendpoint: e\nagents:\n - ghost\n---\nLonely.", ); @@ -295,7 +350,10 @@ describe("loadAgentFromFile — sub-agent refs rejected", () => { }); test("ignores empty 'agents:' array (treated as absent)", async () => { - const p = write("lonely.md", "---\nendpoint: e\nagents: []\n---\nLonely."); + const p = writeRoot( + "lonely.md", + "---\nendpoint: e\nagents: []\n---\nLonely.", + ); const def = await loadAgentFromFile(p, {}); expect(def.agents).toBeUndefined(); }); diff --git a/packages/appkit/src/plugins/agents/types.ts b/packages/appkit/src/plugins/agents/types.ts index e18cc8f4..8d52d278 100644 --- a/packages/appkit/src/plugins/agents/types.ts +++ b/packages/appkit/src/plugins/agents/types.ts @@ -110,7 +110,7 @@ export interface AutoInheritToolsConfig { } export interface AgentsPluginConfig extends BasePluginConfig { - /** Directory to scan for markdown agent files. Default `./config/agents`. Set to `false` to disable. */ + /** Directory of agent packages (`/agent.md` each). Default `./config/agents`. Set to `false` to disable. */ dir?: string | false; /** Code-defined agents, merged with file-loaded ones (code wins on key collision). */ agents?: Record; From 2ca8074e45240c95644fe9d95830998f62af0df3 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Thu, 23 Apr 2026 21:37:29 +0200 Subject: [PATCH 19/23] refactor(appkit): promote MCP client + host policy to connectors/mcp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The MCP transport client and host policy aren't agents-specific; they are HTTP + JSON-RPC transport with URL/DNS allowlisting. Move them under packages/appkit/src/connectors/mcp/ so they sit alongside the other transport-layer modules (serving, genie, sql-warehouse, lakebase, …) and stop being reachable only through the agents plugin. - Move mcp-client.ts -> connectors/mcp/client.ts - Move mcp-host-policy.ts -> connectors/mcp/host-policy.ts - Move McpEndpointConfig type -> connectors/mcp/types.ts - Add connectors/mcp/index.ts barrel; re-export from connectors/index.ts - Move mcp-client / mcp-host-policy tests to connectors/mcp/tests/ - Agents plugin keeps hosted-tools.ts (HostedTool sugar + resolve) and imports connector types from ../../connectors/mcp. - tools/ barrel no longer re-exports AppKitMcpClient (never was public). No behaviour change. All existing tests pass against the new paths. --- packages/appkit/src/connectors/index.ts | 1 + .../tools/mcp-client.ts => connectors/mcp/client.ts} | 8 ++++---- .../mcp/host-policy.ts} | 0 packages/appkit/src/connectors/mcp/index.ts | 6 ++++++ .../mcp/tests/client.test.ts} | 4 ++-- .../mcp/tests/host-policy.test.ts} | 2 +- packages/appkit/src/connectors/mcp/types.ts | 12 ++++++++++++ packages/appkit/src/plugins/agents/agents.ts | 3 +-- .../appkit/src/plugins/agents/tools/hosted-tools.ts | 8 ++------ packages/appkit/src/plugins/agents/tools/index.ts | 1 - packages/appkit/src/plugins/agents/types.ts | 2 +- 11 files changed, 30 insertions(+), 17 deletions(-) rename packages/appkit/src/{plugins/agents/tools/mcp-client.ts => connectors/mcp/client.ts} (98%) rename packages/appkit/src/{plugins/agents/tools/mcp-host-policy.ts => connectors/mcp/host-policy.ts} (100%) create mode 100644 packages/appkit/src/connectors/mcp/index.ts rename packages/appkit/src/{plugins/agents/tests/mcp-client.test.ts => connectors/mcp/tests/client.test.ts} (98%) rename packages/appkit/src/{plugins/agents/tests/mcp-host-policy.test.ts => connectors/mcp/tests/host-policy.test.ts} (99%) create mode 100644 packages/appkit/src/connectors/mcp/types.ts diff --git a/packages/appkit/src/connectors/index.ts b/packages/appkit/src/connectors/index.ts index 54a24fa4..daae9439 100644 --- a/packages/appkit/src/connectors/index.ts +++ b/packages/appkit/src/connectors/index.ts @@ -2,5 +2,6 @@ export * from "./files"; export * from "./genie"; export * from "./lakebase"; export * from "./lakebase-v1"; +export * from "./mcp"; export * from "./sql-warehouse"; export * from "./vector-search"; diff --git a/packages/appkit/src/plugins/agents/tools/mcp-client.ts b/packages/appkit/src/connectors/mcp/client.ts similarity index 98% rename from packages/appkit/src/plugins/agents/tools/mcp-client.ts rename to packages/appkit/src/connectors/mcp/client.ts index 49db7882..4c8d058b 100644 --- a/packages/appkit/src/plugins/agents/tools/mcp-client.ts +++ b/packages/appkit/src/connectors/mcp/client.ts @@ -23,16 +23,16 @@ * transport. */ import type { AgentToolDefinition } from "shared"; -import { createLogger } from "../../../logging/logger"; -import type { McpEndpointConfig } from "./hosted-tools"; +import { createLogger } from "../../logging/logger"; import { assertResolvedHostSafe, checkMcpUrl, type DnsLookup, type McpHostPolicy, -} from "./mcp-host-policy"; +} from "./host-policy"; +import type { McpEndpointConfig } from "./types"; -const logger = createLogger("agent:mcp"); +const logger = createLogger("connector:mcp"); interface JsonRpcRequest { jsonrpc: "2.0"; diff --git a/packages/appkit/src/plugins/agents/tools/mcp-host-policy.ts b/packages/appkit/src/connectors/mcp/host-policy.ts similarity index 100% rename from packages/appkit/src/plugins/agents/tools/mcp-host-policy.ts rename to packages/appkit/src/connectors/mcp/host-policy.ts diff --git a/packages/appkit/src/connectors/mcp/index.ts b/packages/appkit/src/connectors/mcp/index.ts new file mode 100644 index 00000000..f9f32a41 --- /dev/null +++ b/packages/appkit/src/connectors/mcp/index.ts @@ -0,0 +1,6 @@ +export { AppKitMcpClient } from "./client"; +export { + buildMcpHostPolicy, + type McpHostPolicyConfig, +} from "./host-policy"; +export type { McpEndpointConfig } from "./types"; diff --git a/packages/appkit/src/plugins/agents/tests/mcp-client.test.ts b/packages/appkit/src/connectors/mcp/tests/client.test.ts similarity index 98% rename from packages/appkit/src/plugins/agents/tests/mcp-client.test.ts rename to packages/appkit/src/connectors/mcp/tests/client.test.ts index 483fb5f4..0cdffa29 100644 --- a/packages/appkit/src/plugins/agents/tests/mcp-client.test.ts +++ b/packages/appkit/src/connectors/mcp/tests/client.test.ts @@ -1,6 +1,6 @@ import { beforeEach, describe, expect, test, vi } from "vitest"; -import { AppKitMcpClient } from "../tools/mcp-client"; -import type { DnsLookup, McpHostPolicy } from "../tools/mcp-host-policy"; +import { AppKitMcpClient } from "../client"; +import type { DnsLookup, McpHostPolicy } from "../host-policy"; const WORKSPACE = "https://test-workspace.cloud.databricks.com"; diff --git a/packages/appkit/src/plugins/agents/tests/mcp-host-policy.test.ts b/packages/appkit/src/connectors/mcp/tests/host-policy.test.ts similarity index 99% rename from packages/appkit/src/plugins/agents/tests/mcp-host-policy.test.ts rename to packages/appkit/src/connectors/mcp/tests/host-policy.test.ts index 06d98627..451536ed 100644 --- a/packages/appkit/src/plugins/agents/tests/mcp-host-policy.test.ts +++ b/packages/appkit/src/connectors/mcp/tests/host-policy.test.ts @@ -8,7 +8,7 @@ import { isLoopbackHost, type McpHostPolicy, type McpHostPolicyConfig, -} from "../tools/mcp-host-policy"; +} from "../host-policy"; function stubLookup( addresses: Array<{ address: string; family?: number }>, diff --git a/packages/appkit/src/connectors/mcp/types.ts b/packages/appkit/src/connectors/mcp/types.ts new file mode 100644 index 00000000..d74f0a46 --- /dev/null +++ b/packages/appkit/src/connectors/mcp/types.ts @@ -0,0 +1,12 @@ +/** + * Input shape consumed by {@link AppKitMcpClient.connect}. Produced by the + * agents plugin from user-facing `HostedTool` declarations (see + * `plugins/agents/tools/hosted-tools.ts`) and accepted directly by the + * connector to keep its surface free of agent-layer concepts. + */ +export interface McpEndpointConfig { + /** Stable logical name used as the `mcp..*` tool prefix and in logs. */ + name: string; + /** Absolute URL (`https://…`) or workspace-relative path (`/api/2.0/mcp/…`). */ + url: string; +} diff --git a/packages/appkit/src/plugins/agents/agents.ts b/packages/appkit/src/plugins/agents/agents.ts index f1938ef2..b1594f44 100644 --- a/packages/appkit/src/plugins/agents/agents.ts +++ b/packages/appkit/src/plugins/agents/agents.ts @@ -14,6 +14,7 @@ import type { Thread, ToolProvider, } from "shared"; +import { AppKitMcpClient, buildMcpHostPolicy } from "../../connectors/mcp"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; @@ -31,13 +32,11 @@ import { buildBaseSystemPrompt, composeSystemPrompt } from "./system-prompt"; import { InMemoryThreadStore } from "./thread-store"; import { ToolApprovalGate } from "./tool-approval-gate"; import { - AppKitMcpClient, functionToolToDefinition, isFunctionTool, isHostedTool, resolveHostedTools, } from "./tools"; -import { buildMcpHostPolicy } from "./tools/mcp-host-policy"; import type { AgentDefinition, AgentsPluginConfig, diff --git a/packages/appkit/src/plugins/agents/tools/hosted-tools.ts b/packages/appkit/src/plugins/agents/tools/hosted-tools.ts index bce70c4f..c1f06767 100644 --- a/packages/appkit/src/plugins/agents/tools/hosted-tools.ts +++ b/packages/appkit/src/plugins/agents/tools/hosted-tools.ts @@ -1,3 +1,5 @@ +import type { McpEndpointConfig } from "../../../connectors/mcp"; + export interface GenieTool { type: "genie-space"; genie_space: { id: string }; @@ -37,12 +39,6 @@ export function isHostedTool(value: unknown): value is HostedTool { return typeof obj.type === "string" && HOSTED_TOOL_TYPES.has(obj.type); } -export interface McpEndpointConfig { - name: string; - /** Absolute URL or path relative to workspace host */ - url: string; -} - /** * Resolves HostedTool configs into MCP endpoint configurations * that the MCP client can connect to. diff --git a/packages/appkit/src/plugins/agents/tools/index.ts b/packages/appkit/src/plugins/agents/tools/index.ts index 7b779d1c..004c96b5 100644 --- a/packages/appkit/src/plugins/agents/tools/index.ts +++ b/packages/appkit/src/plugins/agents/tools/index.ts @@ -16,5 +16,4 @@ export { mcpServer, resolveHostedTools, } from "./hosted-tools"; -export { AppKitMcpClient } from "./mcp-client"; export { type ToolConfig, tool } from "./tool"; diff --git a/packages/appkit/src/plugins/agents/types.ts b/packages/appkit/src/plugins/agents/types.ts index 8d52d278..14366e9a 100644 --- a/packages/appkit/src/plugins/agents/types.ts +++ b/packages/appkit/src/plugins/agents/types.ts @@ -5,9 +5,9 @@ import type { ThreadStore, ToolAnnotations, } from "shared"; +import type { McpHostPolicyConfig } from "../../connectors/mcp"; import type { FunctionTool } from "./tools/function-tool"; import type { HostedTool } from "./tools/hosted-tools"; -import type { McpHostPolicyConfig } from "./tools/mcp-host-policy"; /** * A tool reference produced by a plugin's `.toolkit()` call. The agents plugin From 0f4893df7af131e9b0bab5cfe21fdab7822160bd Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Thu, 23 Apr 2026 21:50:54 +0200 Subject: [PATCH 20/23] refactor(appkit): static context import + SDK credential chain in agents MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two small cleanups to AgentsPlugin.connectHostedTools(): - Replace the dynamic `await import("../../context")` with a top-level `import { getWorkspaceClient } from "../../context"`, matching every other plugin (genie, serving, analytics, files, vector-search). - Drop the ad-hoc env-var fallback (DATABRICKS_HOST + DATABRICKS_TOKEN, PAT only). When ServiceContext is not initialized (test rigs, manual embeds) construct a bare `new WorkspaceClient({})` and let the SDK walk its own credential chain — env, ~/.databrickscfg profiles, DAB auth, OAuth, metadata service — before calling config.authenticate(). No behaviour change on the normal createApp path. The fallback branch now supports every SDK auth type instead of PAT only, and tells the user which setting to fix when no host can be resolved. --- packages/appkit/src/plugins/agents/agents.ts | 50 ++++++++++++-------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/packages/appkit/src/plugins/agents/agents.ts b/packages/appkit/src/plugins/agents/agents.ts index b1594f44..90a920f4 100644 --- a/packages/appkit/src/plugins/agents/agents.ts +++ b/packages/appkit/src/plugins/agents/agents.ts @@ -15,6 +15,7 @@ import type { ToolProvider, } from "shared"; import { AppKitMcpClient, buildMcpHostPolicy } from "../../connectors/mcp"; +import { getWorkspaceClient } from "../../context"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; @@ -474,35 +475,25 @@ export class AgentsPlugin extends Plugin implements ToolProvider { hostedTools: import("./tools/hosted-tools").HostedTool[], index: Map, ): Promise { - let host: string | undefined; - let authenticate: () => Promise>; - - try { - const { getWorkspaceClient } = await import("../../context"); - const wsClient = getWorkspaceClient(); - await wsClient.config.ensureResolved(); - host = wsClient.config.host; - authenticate = async () => { - const headers = new Headers(); - await wsClient.config.authenticate(headers); - return Object.fromEntries(headers.entries()); - }; - } catch { - host = process.env.DATABRICKS_HOST; - authenticate = async (): Promise> => { - const token = process.env.DATABRICKS_TOKEN; - return token ? { Authorization: `Bearer ${token}` } : {}; - }; - } + const wsClient = await this.resolveWorkspaceClient(); + await wsClient.config.ensureResolved(); + const host = wsClient.config.host; if (!host) { logger.warn( - "No Databricks host available — skipping %d hosted tool(s)", + "No Databricks host available — skipping %d hosted tool(s). " + + "Set DATABRICKS_HOST or configure a profile in ~/.databrickscfg.", hostedTools.length, ); return; } + const authenticate = async (): Promise> => { + const headers = new Headers(); + await wsClient.config.authenticate(headers); + return Object.fromEntries(headers.entries()); + }; + if (!this.mcpClient) { const policy = buildMcpHostPolicy(this.config.mcp, host); this.mcpClient = new AppKitMcpClient(host, authenticate, policy); @@ -520,6 +511,23 @@ export class AgentsPlugin extends Plugin implements ToolProvider { } } + /** + * Return the ambient workspace client from {@link getWorkspaceClient} when + * `ServiceContext` is initialized (the normal `createApp` path). Fall back + * to a fresh `WorkspaceClient()` that walks the SDK's credential chain — + * `DATABRICKS_HOST` / `DATABRICKS_TOKEN`, `~/.databrickscfg` profiles, + * DAB auth, OAuth, metadata service — for test rigs and manual embeds + * that never ran through `createApp`. + */ + private async resolveWorkspaceClient() { + try { + return getWorkspaceClient(); + } catch { + const { WorkspaceClient } = await import("@databricks/sdk-experimental"); + return new WorkspaceClient({}); + } + } + // ----------------- ToolProvider (no tools of our own) -------------------- getAgentTools(): AgentToolDefinition[] { From 54684f6dca734d5fd7fc3ebd7bb110b83a430b95 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Thu, 23 Apr 2026 22:06:39 +0200 Subject: [PATCH 21/23] refactor(appkit): extract normalizeToolResult, consumeAdapterStream, dispatchToolCall Three small helpers pulled out of the AgentsPlugin streaming path to cut duplication and shrink the two large methods. - normalize-result.ts: void->"", JSON-stringify, 50K truncation with a human-readable marker. Unit-testable (previously covered only via the HTTP path). - consume-adapter-stream.ts: the 'message_delta' + 'message' accumulation loop shared between _streamAgent and runSubAgent. Accepts an optional signal and per-event side-effect callback (for SSE translation). - tool-dispatch.ts: one place that fans out toolkit/function/mcp/subagent entries. 'never'-typed default forces exhaustiveness: adding a fifth source is now a compile error at every call site. _streamAgent: executeTool closure shrinks from ~60 lines of dispatch + normalize to a single dispatchToolCall + normalizeToolResult call. Stream consumption collapses to consumeAdapterStream. runSubAgent: childExecute shrinks from ~30 lines of if/else dispatch to one dispatchToolCall call. Adapter loop collapses to consumeAdapterStream. Behaviour change (minor): childExecute previously silently fell through to 'Unsupported sub-agent tool source' when mcpClient or PluginContext was missing; now it throws the same specific error as the main stream. Matches the main-path behaviour. Tests: 15 new unit tests for normalizeToolResult + consumeAdapterStream. dispatchToolCall is exercised transitively through the full agent suite (288 existing tests still pass, 303 total on this branch). --- packages/appkit/src/plugins/agents/agents.ts | 151 +++++------------- .../plugins/agents/consume-adapter-stream.ts | 52 ++++++ .../src/plugins/agents/normalize-result.ts | 33 ++++ .../tests/consume-adapter-stream.test.ts | 86 ++++++++++ .../agents/tests/normalize-result.test.ts | 63 ++++++++ .../src/plugins/agents/tool-dispatch.ts | 97 +++++++++++ 6 files changed, 372 insertions(+), 110 deletions(-) create mode 100644 packages/appkit/src/plugins/agents/consume-adapter-stream.ts create mode 100644 packages/appkit/src/plugins/agents/normalize-result.ts create mode 100644 packages/appkit/src/plugins/agents/tests/consume-adapter-stream.test.ts create mode 100644 packages/appkit/src/plugins/agents/tests/normalize-result.test.ts create mode 100644 packages/appkit/src/plugins/agents/tool-dispatch.ts diff --git a/packages/appkit/src/plugins/agents/agents.ts b/packages/appkit/src/plugins/agents/agents.ts index 90a920f4..de594848 100644 --- a/packages/appkit/src/plugins/agents/agents.ts +++ b/packages/appkit/src/plugins/agents/agents.ts @@ -4,7 +4,6 @@ import type express from "express"; import pc from "picocolors"; import type { AgentAdapter, - AgentEvent, AgentRunContext, AgentToolDefinition, IAppRouter, @@ -19,11 +18,13 @@ import { getWorkspaceClient } from "../../context"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { consumeAdapterStream } from "./consume-adapter-stream"; import { agentStreamDefaults } from "./defaults"; import { EventChannel } from "./event-channel"; import { AgentEventTranslator } from "./event-translator"; import { loadAgentsFromDir } from "./load-agents"; import manifest from "./manifest.json"; +import { normalizeToolResult } from "./normalize-result"; import { approvalRequestSchema, chatRequestSchema, @@ -32,6 +33,7 @@ import { import { buildBaseSystemPrompt, composeSystemPrompt } from "./system-prompt"; import { InMemoryThreadStore } from "./thread-store"; import { ToolApprovalGate } from "./tool-approval-gate"; +import { dispatchToolCall } from "./tool-dispatch"; import { functionToolToDefinition, isFunctionTool, @@ -777,57 +779,18 @@ export class AgentsPlugin extends Plugin implements ToolProvider { } } - let result: unknown; - if (entry.source === "toolkit") { - if (!this.context) { - throw new Error( - "Plugin tool execution requires PluginContext; this should never happen through createApp", - ); - } - result = await this.context.executeTool( - req, - entry.pluginName, - entry.localName, - args, - signal, - ); - } else if (entry.source === "function") { - result = await entry.functionTool.execute( - args as Record, - ); - } else if (entry.source === "mcp") { - if (!this.mcpClient) throw new Error("MCP client not connected"); - const oboToken = req.headers["x-forwarded-access-token"]; - const mcpAuth = - typeof oboToken === "string" - ? { Authorization: `Bearer ${oboToken}` } - : undefined; - result = await this.mcpClient.callTool( - entry.mcpToolName, - args, - mcpAuth, - ); - } else if (entry.source === "subagent") { - const childAgent = this.agents.get(entry.agentName); - if (!childAgent) - throw new Error(`Sub-agent not found: ${entry.agentName}`); - result = await this.runSubAgent(req, childAgent, args, signal, 1); - } - - // A `void` / `undefined` return is a legitimate tool outcome (e.g., a - // "send notification" side-effecting tool). Return an empty string so - // the LLM sees a successful-but-empty result rather than a bogus - // "execution failed" error. - if (result === undefined) { - return ""; - } - const MAX = 50_000; - const serialized = - typeof result === "string" ? result : JSON.stringify(result); - if (serialized.length > MAX) { - return `${serialized.slice(0, MAX)}\n\n[Result truncated: ${serialized.length} chars exceeds ${MAX} limit]`; - } - return result; + const raw = await dispatchToolCall(entry, args, { + req, + signal, + pluginContext: this.context, + mcpClient: this.mcpClient, + runSubAgent: (agentName, subArgs) => { + const childAgent = this.agents.get(agentName); + if (!childAgent) throw new Error(`Sub-agent not found: ${agentName}`); + return this.runSubAgent(req, childAgent, subArgs, signal, 1); + }, + }); + return normalizeToolResult(raw); }; // Drive the adapter and the approval-event side-channel concurrently. @@ -878,26 +841,14 @@ export class AgentsPlugin extends Plugin implements ToolProvider { { executeTool, signal }, ); - // Accumulate assistant output from BOTH streaming and non-streaming - // adapters. Delta-based adapters (Databricks, Vercel AI) emit - // `message_delta` chunks that we concatenate; adapters that yield a - // single final assistant message (e.g. LangChain's `on_chain_end` - // path) emit a `message` event whose content replaces whatever - // deltas already arrived. Without the `message` branch, multi-turn - // LangChain conversations silently dropped the assistant turn from - // thread history. - let fullContent = ""; - for await (const event of stream) { - if (signal.aborted) break; - if (event.type === "message_delta") { - fullContent += event.content; - } else if (event.type === "message") { - fullContent = event.content; - } - for (const translated of translator.translate(event)) { - outboundEvents.push(translated); - } - } + const fullContent = await consumeAdapterStream(stream, { + signal, + onEvent: (event) => { + for (const translated of translator.translate(event)) { + outboundEvents.push(translated); + } + }, + }); if (fullContent) { await this.threadStore.addMessage(thread.id, userId, { @@ -998,33 +949,17 @@ export class AgentsPlugin extends Plugin implements ToolProvider { ): Promise => { const entry = child.toolIndex.get(name); if (!entry) throw new Error(`Unknown tool in sub-agent: ${name}`); - if (entry.source === "toolkit" && this.context) { - return this.context.executeTool( - req, - entry.pluginName, - entry.localName, - childArgs, - signal, - ); - } - if (entry.source === "function") { - return entry.functionTool.execute(childArgs as Record); - } - if (entry.source === "subagent") { - const grandchild = this.agents.get(entry.agentName); - if (!grandchild) - throw new Error(`Sub-agent not found: ${entry.agentName}`); - return this.runSubAgent(req, grandchild, childArgs, signal, depth + 1); - } - if (entry.source === "mcp" && this.mcpClient) { - const oboToken = req.headers["x-forwarded-access-token"]; - const mcpAuth = - typeof oboToken === "string" - ? { Authorization: `Bearer ${oboToken}` } - : undefined; - return this.mcpClient.callTool(entry.mcpToolName, childArgs, mcpAuth); - } - throw new Error(`Unsupported sub-agent tool source: ${entry.source}`); + return dispatchToolCall(entry, childArgs, { + req, + signal, + pluginContext: this.context, + mcpClient: this.mcpClient, + runSubAgent: (agentName, args) => { + const grandchild = this.agents.get(agentName); + if (!grandchild) throw new Error(`Sub-agent not found: ${agentName}`); + return this.runSubAgent(req, grandchild, args, signal, depth + 1); + }, + }); }; const runContext: AgentRunContext = { executeTool: childExecute, signal }; @@ -1059,17 +994,13 @@ export class AgentsPlugin extends Plugin implements ToolProvider { }, ]; - let output = ""; - const events: AgentEvent[] = []; - for await (const event of child.adapter.run( - { messages, tools: childTools, threadId: randomUUID(), signal }, - runContext, - )) { - events.push(event); - if (event.type === "message_delta") output += event.content; - else if (event.type === "message") output = event.content; - } - return output; + return consumeAdapterStream( + child.adapter.run( + { messages, tools: childTools, threadId: randomUUID(), signal }, + runContext, + ), + { signal }, + ); } private async _handleCancel(req: express.Request, res: express.Response) { diff --git a/packages/appkit/src/plugins/agents/consume-adapter-stream.ts b/packages/appkit/src/plugins/agents/consume-adapter-stream.ts new file mode 100644 index 00000000..c4f3d07e --- /dev/null +++ b/packages/appkit/src/plugins/agents/consume-adapter-stream.ts @@ -0,0 +1,52 @@ +import type { AgentEvent } from "shared"; + +interface ConsumeAdapterStreamOptions { + /** + * Optional abort signal. When aborted, the loop stops consuming (the caller + * is expected to have forwarded the same signal to `adapter.run` to stop + * upstream work). `undefined` is valid — standalone `runAgent` runs without + * a signal. + */ + signal?: AbortSignal; + /** + * Side-effect callback invoked once per adapter event, after the content + * accumulator has been updated. Use to fan events out to SSE translators, + * collect a raw event list for tests, or emit telemetry. + */ + onEvent?: (event: AgentEvent) => void; +} + +/** + * Consume an adapter's event stream and aggregate the assistant's final text. + * + * Accumulation rule (shared across all agent-execution paths in AppKit): + * + * - `message_delta` events append their `content` to the running text. + * - A `message` event *replaces* the running text with its `content`. + * + * The two branches coexist because different adapters emit different shapes: + * streaming adapters (Databricks, Vercel AI) emit deltas chunk-by-chunk, + * while `LangChain`'s `on_chain_end` path emits a single final `message`. + * Without the replace branch, LangChain conversations silently dropped the + * assistant turn from thread history. + * + * Kept pure (no I/O, no mutable external state beyond the caller's `onEvent` + * side effect) so each execution path — HTTP streaming, sub-agents, and the + * standalone `runAgent` — can share one loop. + */ +export async function consumeAdapterStream( + stream: AsyncIterable, + opts: ConsumeAdapterStreamOptions = {}, +): Promise { + let text = ""; + for await (const event of stream) { + if (opts.signal?.aborted) break; + if (event.type === "message_delta") { + text += event.content; + } else if (event.type === "message") { + text = event.content; + } + opts.onEvent?.(event); + } + return text; +} diff --git a/packages/appkit/src/plugins/agents/normalize-result.ts b/packages/appkit/src/plugins/agents/normalize-result.ts new file mode 100644 index 00000000..6fe2362c --- /dev/null +++ b/packages/appkit/src/plugins/agents/normalize-result.ts @@ -0,0 +1,33 @@ +/** + * Maximum serialized length of a tool result before we truncate with a + * human-readable marker. 50k chars is roughly ~12k tokens — enough for + * reasonable SQL result sets and JSON blobs, well short of the per-call + * context limits on current frontier models. + */ +export const MAX_TOOL_RESULT_CHARS = 50_000; + +/** + * Normalise a raw tool-execution result for the LLM: + * + * - `undefined` → empty string. A `void` return is a legitimate outcome for + * side-effecting tools ("send notification"); surfacing `undefined` to the + * adapter would otherwise read as "execution failed". + * - strings are returned as-is. + * - everything else is JSON-stringified. + * - results longer than {@link MAX_TOOL_RESULT_CHARS} are truncated and + * annotated so the model sees the cut rather than silent data loss. + * + * Pure function; safe to unit-test in isolation. + */ +export function normalizeToolResult( + result: unknown, + maxChars: number = MAX_TOOL_RESULT_CHARS, +): unknown { + if (result === undefined) return ""; + const serialized = + typeof result === "string" ? result : JSON.stringify(result); + if (serialized.length > maxChars) { + return `${serialized.slice(0, maxChars)}\n\n[Result truncated: ${serialized.length} chars exceeds ${maxChars} limit]`; + } + return result; +} diff --git a/packages/appkit/src/plugins/agents/tests/consume-adapter-stream.test.ts b/packages/appkit/src/plugins/agents/tests/consume-adapter-stream.test.ts new file mode 100644 index 00000000..98863a62 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/consume-adapter-stream.test.ts @@ -0,0 +1,86 @@ +import type { AgentEvent } from "shared"; +import { describe, expect, test } from "vitest"; +import { consumeAdapterStream } from "../consume-adapter-stream"; + +async function* streamOf( + events: AgentEvent[], +): AsyncGenerator { + for (const event of events) { + yield event; + } +} + +describe("consumeAdapterStream", () => { + test("concatenates message_delta events into the final text", async () => { + const text = await consumeAdapterStream( + streamOf([ + { type: "message_delta", content: "Hello " }, + { type: "message_delta", content: "world" }, + ]), + ); + expect(text).toBe("Hello world"); + }); + + test("a `message` event replaces whatever deltas arrived so far", async () => { + const text = await consumeAdapterStream( + streamOf([ + { type: "message_delta", content: "partial" }, + { type: "message", content: "final answer" }, + ]), + ); + expect(text).toBe("final answer"); + }); + + test("invokes onEvent once per event, in order, with the raw event", async () => { + const seen: AgentEvent[] = []; + await consumeAdapterStream( + streamOf([ + { type: "message_delta", content: "a" }, + { type: "thinking", content: "…" }, + { type: "message_delta", content: "b" }, + ]), + { onEvent: (ev) => seen.push(ev) }, + ); + expect(seen.map((e) => e.type)).toEqual([ + "message_delta", + "thinking", + "message_delta", + ]); + }); + + test("stops iterating once the signal aborts", async () => { + const controller = new AbortController(); + const emitted: string[] = []; + await consumeAdapterStream( + (async function* () { + yield { type: "message_delta", content: "first" } as AgentEvent; + controller.abort(); + yield { type: "message_delta", content: "second" } as AgentEvent; + })(), + { + signal: controller.signal, + onEvent: (ev) => { + if (ev.type === "message_delta") emitted.push(ev.content); + }, + }, + ); + expect(emitted).toEqual(["first"]); + }); + + test("returns an empty string for a stream with no content events", async () => { + const text = await consumeAdapterStream( + streamOf([{ type: "thinking", content: "…" }]), + ); + expect(text).toBe(""); + }); + + test("works without a signal (standalone runAgent path)", async () => { + const text = await consumeAdapterStream( + streamOf([ + { type: "message_delta", content: "x" }, + { type: "message_delta", content: "y" }, + ]), + ); + expect(text).toBe("xy"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/normalize-result.test.ts b/packages/appkit/src/plugins/agents/tests/normalize-result.test.ts new file mode 100644 index 00000000..a0545d09 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/normalize-result.test.ts @@ -0,0 +1,63 @@ +import { describe, expect, test } from "vitest"; +import { + MAX_TOOL_RESULT_CHARS, + normalizeToolResult, +} from "../normalize-result"; + +describe("normalizeToolResult", () => { + test("maps undefined to empty string so void tools don't surface as errors", () => { + expect(normalizeToolResult(undefined)).toBe(""); + }); + + test("returns strings unchanged", () => { + expect(normalizeToolResult("hello")).toBe("hello"); + }); + + test("leaves non-string results intact (caller serialises)", () => { + const result = normalizeToolResult({ rows: 2, ok: true }); + expect(result).toEqual({ rows: 2, ok: true }); + }); + + test("returns an empty string input as an empty string (not undefined)", () => { + expect(normalizeToolResult("")).toBe(""); + }); + + test("preserves null without converting to empty string", () => { + expect(normalizeToolResult(null)).toBeNull(); + }); + + test("truncates long strings and appends a marker with the original length", () => { + const big = "x".repeat(MAX_TOOL_RESULT_CHARS + 1000); + const result = normalizeToolResult(big); + expect(typeof result).toBe("string"); + const s = result as string; + // Content portion is bounded to MAX_TOOL_RESULT_CHARS (plus the marker). + expect(s.slice(0, MAX_TOOL_RESULT_CHARS)).toBe( + "x".repeat(MAX_TOOL_RESULT_CHARS), + ); + expect(s).toMatch( + new RegExp( + `\\[Result truncated: ${big.length} chars exceeds ${MAX_TOOL_RESULT_CHARS} limit\\]`, + ), + ); + }); + + test("truncates long serialised objects the same way", () => { + const big = { blob: "x".repeat(MAX_TOOL_RESULT_CHARS + 10) }; + const result = normalizeToolResult(big); + expect(typeof result).toBe("string"); + expect(result as string).toMatch(/\[Result truncated:/); + }); + + test("honours a custom maxChars parameter", () => { + const result = normalizeToolResult("hello world", 5); + expect(result).toBe( + "hello\n\n[Result truncated: 11 chars exceeds 5 limit]", + ); + }); + + test("does not truncate at the boundary (exact length is fine)", () => { + const s = "x".repeat(MAX_TOOL_RESULT_CHARS); + expect(normalizeToolResult(s)).toBe(s); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tool-dispatch.ts b/packages/appkit/src/plugins/agents/tool-dispatch.ts new file mode 100644 index 00000000..a3e220bb --- /dev/null +++ b/packages/appkit/src/plugins/agents/tool-dispatch.ts @@ -0,0 +1,97 @@ +import type express from "express"; +import type { AppKitMcpClient } from "../../connectors/mcp"; +import type { PluginContext } from "../../core/plugin-context"; +import type { ResolvedToolEntry } from "./types"; + +interface ToolDispatchContext { + /** + * The originating HTTP request. Used by `toolkit` entries to scope execution + * to the caller's user context (`asUser(req)`) and by `mcp` entries to pick + * up the OBO bearer token from `x-forwarded-access-token`. + */ + req: express.Request; + /** Cancellation signal, forwarded to the tool implementation. */ + signal: AbortSignal; + /** + * PluginContext mediator — required to dispatch `toolkit` entries. Absent in + * unit tests that construct `AgentsPlugin` directly; callers may pass + * `null` / `undefined`, in which case toolkit calls throw a clear error. + */ + pluginContext?: PluginContext | null; + /** Live MCP client. Required for `mcp` entries. */ + mcpClient?: AppKitMcpClient | null; + /** + * Delegates a sub-agent invocation. The closure owns the recursion depth so + * the dispatcher itself remains depth-agnostic — the top-level caller + * passes `depth = 1`, and a sub-agent's inner dispatcher passes `depth + 1`. + */ + runSubAgent: (agentName: string, args: unknown) => Promise; +} + +/** + * Fan-out a resolved tool entry to the correct executor. One place to add a + * fifth `source` variant; `never`-typed default forces every caller to + * update in lockstep. + * + * This only handles dispatch — result normalisation (`normalizeToolResult`), + * budget counting, and approval gating remain at the call site, where each + * stream has different policies. + */ +export async function dispatchToolCall( + entry: ResolvedToolEntry, + args: unknown, + ctx: ToolDispatchContext, +): Promise { + switch (entry.source) { + case "toolkit": { + if (!ctx.pluginContext) { + throw new Error( + "Plugin tool execution requires PluginContext; " + + "this should never happen through createApp.", + ); + } + return ctx.pluginContext.executeTool( + ctx.req, + entry.pluginName, + entry.localName, + args, + ctx.signal, + ); + } + case "function": + return entry.functionTool.execute(args as Record); + case "mcp": { + if (!ctx.mcpClient) throw new Error("MCP client not connected"); + return ctx.mcpClient.callTool( + entry.mcpToolName, + args, + extractOboMcpAuth(ctx.req), + ); + } + case "subagent": + return ctx.runSubAgent(entry.agentName, args); + default: { + // Exhaustiveness guard: adding a new `source` to ResolvedToolEntry + // without teaching this switch breaks the build here. + const _exhaustive: never = entry; + throw new Error( + `Unsupported tool source: ${(_exhaustive as ResolvedToolEntry).source}`, + ); + } + } +} + +/** + * Extracts the caller's OBO bearer token from the standard Databricks Apps + * forwarded-auth header. MCP destinations that `forwardWorkspaceAuth` admits + * as same-origin will receive this header; non-workspace destinations drop + * it inside {@link AppKitMcpClient.callTool}. + */ +function extractOboMcpAuth( + req: express.Request, +): Record | undefined { + const oboToken = req.headers["x-forwarded-access-token"]; + return typeof oboToken === "string" + ? { Authorization: `Bearer ${oboToken}` } + : undefined; +} From 70804c3e1306ea1a24165cf3e27e639d18f1516e Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Fri, 24 Apr 2026 16:19:34 +0200 Subject: [PATCH 22/23] =?UTF-8?q?fix(agents):=20propagate=20tool=20annotat?= =?UTF-8?q?ions=20through=20tool()=20=E2=86=92=20FunctionTool=20=E2=86=92?= =?UTF-8?q?=20def?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `annotations` field (notably `destructive: true`) was silently dropped as tools flowed from `tool({...})` into the resolved `AgentToolDefinition`, so user-defined destructive tools never triggered the approval gate. - `ToolConfig` now accepts `annotations?: ToolAnnotations`. - `tool()` forwards it to the returned `FunctionTool`. - `FunctionTool` exposes `annotations` and `functionToolToDefinition` preserves it on the definition it builds. - `AgentsPlugin` reads the flag via `isDestructiveToolEntry()` (falls back to `functionTool.annotations` so a future divergence between def and function cannot re-introduce the bug) and emits the merged annotations via `combinedToolAnnotations()` on the `approval_pending` SSE payload. Covered by `tests/tool-approval-gate.test.ts` and `tests/function-tool.test.ts`. --- packages/appkit/src/plugins/agents/agents.ts | 33 +++++++++++++++++-- .../src/plugins/agents/tools/function-tool.ts | 11 ++++++- .../appkit/src/plugins/agents/tools/tool.ts | 10 ++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/packages/appkit/src/plugins/agents/agents.ts b/packages/appkit/src/plugins/agents/agents.ts index de594848..529a68fc 100644 --- a/packages/appkit/src/plugins/agents/agents.ts +++ b/packages/appkit/src/plugins/agents/agents.ts @@ -11,6 +11,7 @@ import type { PluginPhase, ResponseStreamEvent, Thread, + ToolAnnotations, ToolProvider, } from "shared"; import { AppKitMcpClient, buildMcpHostPolicy } from "../../connectors/mcp"; @@ -755,7 +756,7 @@ export class AgentsPlugin extends Plugin implements ToolProvider { if ( approvalPolicy.requireForDestructive && - entry.def.annotations?.destructive === true + isDestructiveToolEntry(entry) ) { const approvalId = randomUUID(); for (const ev of translator.translate({ @@ -764,7 +765,7 @@ export class AgentsPlugin extends Plugin implements ToolProvider { streamId: requestId, toolName: name, args, - annotations: entry.def.annotations, + annotations: combinedToolAnnotations(entry), })) { outboundEvents.push(ev); } @@ -1154,6 +1155,34 @@ export class AgentsPlugin extends Plugin implements ToolProvider { } } +/** + * True when the tool should go through the destructive approval gate. + * `def.annotations` is the normal path; for `function` tools we also read + * `functionTool.annotations` so a mismatch between the spread def and the + * original {@link FunctionTool} cannot drop the flag. + */ +function isDestructiveToolEntry(entry: ResolvedToolEntry): boolean { + if (entry.def.annotations?.destructive === true) return true; + if (entry.source === "function" && entry.functionTool.annotations) { + return entry.functionTool.annotations.destructive === true; + } + return false; +} + +/** Merged annotations for the approval SSE payload (client UI + debugging). */ +function combinedToolAnnotations( + entry: ResolvedToolEntry, +): ToolAnnotations | undefined { + if (entry.source === "function") { + const merged: ToolAnnotations = { + ...entry.functionTool.annotations, + ...entry.def.annotations, + }; + return Object.keys(merged).length > 0 ? merged : undefined; + } + return entry.def.annotations; +} + function normalizeAutoInherit(value: AgentsPluginConfig["autoInheritTools"]): { file: boolean; code: boolean; diff --git a/packages/appkit/src/plugins/agents/tools/function-tool.ts b/packages/appkit/src/plugins/agents/tools/function-tool.ts index 8ce634e0..7371d857 100644 --- a/packages/appkit/src/plugins/agents/tools/function-tool.ts +++ b/packages/appkit/src/plugins/agents/tools/function-tool.ts @@ -1,4 +1,4 @@ -import type { AgentToolDefinition } from "shared"; +import type { AgentToolDefinition, ToolAnnotations } from "shared"; export interface FunctionTool { type: "function"; @@ -6,6 +6,14 @@ export interface FunctionTool { description?: string | null; parameters?: Record | null; strict?: boolean | null; + /** + * Behavioural flags that drive the agents plugin's approval gate and + * auto-inherit filtering. `destructive: true` forces HITL approval + * before execute() runs; `readOnly: true` marks safe-by-default tools. + * Must be preserved through {@link functionToolToDefinition} so the + * plugin sees them when building agent tool indexes. + */ + annotations?: ToolAnnotations; execute: (args: Record) => Promise | string; } @@ -29,5 +37,6 @@ export function functionToolToDefinition( type: "object", properties: {}, }, + ...(tool.annotations ? { annotations: tool.annotations } : {}), }; } diff --git a/packages/appkit/src/plugins/agents/tools/tool.ts b/packages/appkit/src/plugins/agents/tools/tool.ts index b5d4db65..370a1d4b 100644 --- a/packages/appkit/src/plugins/agents/tools/tool.ts +++ b/packages/appkit/src/plugins/agents/tools/tool.ts @@ -1,3 +1,4 @@ +import type { ToolAnnotations } from "shared"; import type { z } from "zod"; import type { FunctionTool } from "./function-tool"; import { toToolJSONSchema } from "./json-schema"; @@ -6,6 +7,14 @@ export interface ToolConfig { name: string; description?: string; schema: S; + /** + * Behavioural flags forwarded to the resolved tool definition. Required + * for the agents plugin to gate destructive tools through the approval + * card, surface `readOnly` tools to auto-inherit, etc. Dropped silently + * before the fix that added this field — any tool wanting HITL must + * set `annotations: { destructive: true }` here. + */ + annotations?: ToolAnnotations; execute: (args: z.infer) => Promise | string; } @@ -29,6 +38,7 @@ export function tool(config: ToolConfig): FunctionTool { name: config.name, description: config.description ?? config.name, parameters, + ...(config.annotations ? { annotations: config.annotations } : {}), execute: async (args: Record) => { const parsed = config.schema.safeParse(args); if (!parsed.success) { From 6c7291b077311321379a12ceadc7266e0c34dafc Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Fri, 24 Apr 2026 17:45:27 +0200 Subject: [PATCH 23/23] =?UTF-8?q?feat(agents):=20semantic=20ToolEffect=20?= =?UTF-8?q?=E2=80=94=20write/update/destructive=20tiers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ToolAnnotations.destructive is binary and has started to mislead: "save_view" captures a screenshot and creates a new file, which is nothing like deleting a dashboard, yet both trip the same red "destructive" approval card. This adds a semantic `effect` enum with four tiers — `read`, `write`, `update`, `destructive` — so tool authors can tell the UI what blast radius they actually have. The approval gate fires for any mutating effect (`write`/`update`/ `destructive`) and continues to honour the legacy `destructive: true` flag so existing tools keep their current red treatment without migration. Callers consuming `annotations` over the wire (MCP clients, approval UIs) can now differentiate; the playground will ship a tiered approval card as a follow-up. --- packages/appkit/src/plugins/agents/agents.ts | 24 ++++++++++---- .../src/plugins/agents/tools/function-tool.ts | 12 ++++--- .../appkit/src/plugins/agents/tools/tool.ts | 11 ++++--- packages/shared/src/agent.ts | 32 +++++++++++++++++++ 4 files changed, 62 insertions(+), 17 deletions(-) diff --git a/packages/appkit/src/plugins/agents/agents.ts b/packages/appkit/src/plugins/agents/agents.ts index 529a68fc..ceed66e6 100644 --- a/packages/appkit/src/plugins/agents/agents.ts +++ b/packages/appkit/src/plugins/agents/agents.ts @@ -1156,16 +1156,26 @@ export class AgentsPlugin extends Plugin implements ToolProvider { } /** - * True when the tool should go through the destructive approval gate. - * `def.annotations` is the normal path; for `function` tools we also read - * `functionTool.annotations` so a mismatch between the spread def and the - * original {@link FunctionTool} cannot drop the flag. + * True when the tool should go through the approval gate. Historically + * scoped to `destructive: true` — hence the name — but now also fires for + * the semantic `effect` enum on {@link ToolAnnotations}. Any effect that + * mutates the world (`write` | `update` | `destructive`) gates; `read` and + * unannotated tools do not. `def.annotations` is the normal path; for + * `function` tools we also read `functionTool.annotations` so a mismatch + * between the spread def and the original {@link FunctionTool} cannot drop + * the hint. */ function isDestructiveToolEntry(entry: ResolvedToolEntry): boolean { - if (entry.def.annotations?.destructive === true) return true; - if (entry.source === "function" && entry.functionTool.annotations) { - return entry.functionTool.annotations.destructive === true; + const defAnn = entry.def.annotations; + const fnAnn = + entry.source === "function" ? entry.functionTool.annotations : undefined; + + const effect = defAnn?.effect ?? fnAnn?.effect; + if (effect === "write" || effect === "update" || effect === "destructive") { + return true; } + if (defAnn?.destructive === true) return true; + if (fnAnn?.destructive === true) return true; return false; } diff --git a/packages/appkit/src/plugins/agents/tools/function-tool.ts b/packages/appkit/src/plugins/agents/tools/function-tool.ts index 7371d857..19820f8f 100644 --- a/packages/appkit/src/plugins/agents/tools/function-tool.ts +++ b/packages/appkit/src/plugins/agents/tools/function-tool.ts @@ -7,11 +7,13 @@ export interface FunctionTool { parameters?: Record | null; strict?: boolean | null; /** - * Behavioural flags that drive the agents plugin's approval gate and - * auto-inherit filtering. `destructive: true` forces HITL approval - * before execute() runs; `readOnly: true` marks safe-by-default tools. - * Must be preserved through {@link functionToolToDefinition} so the - * plugin sees them when building agent tool indexes. + * Behavioural hints that drive the agents plugin's approval gate and the + * client's approval-card styling. Prefer setting `effect` (one of + * `"read" | "write" | "update" | "destructive"`) — any mutating value + * forces HITL approval before `execute()` runs. Legacy `destructive: true` + * is still honoured. Must be preserved through {@link + * functionToolToDefinition} so the plugin sees them when building agent + * tool indexes. */ annotations?: ToolAnnotations; execute: (args: Record) => Promise | string; diff --git a/packages/appkit/src/plugins/agents/tools/tool.ts b/packages/appkit/src/plugins/agents/tools/tool.ts index 370a1d4b..53305c23 100644 --- a/packages/appkit/src/plugins/agents/tools/tool.ts +++ b/packages/appkit/src/plugins/agents/tools/tool.ts @@ -8,11 +8,12 @@ export interface ToolConfig { description?: string; schema: S; /** - * Behavioural flags forwarded to the resolved tool definition. Required - * for the agents plugin to gate destructive tools through the approval - * card, surface `readOnly` tools to auto-inherit, etc. Dropped silently - * before the fix that added this field — any tool wanting HITL must - * set `annotations: { destructive: true }` here. + * Behavioural hints forwarded to the resolved tool definition. Prefer + * `effect` (`"read" | "write" | "update" | "destructive"`) — any mutating + * value forces the agents-plugin approval gate before `execute()` runs + * and the client's approval card will colour itself accordingly. Legacy + * `destructive: true` still gates. Dropped silently before the fix that + * added this field. */ annotations?: ToolAnnotations; execute: (args: z.infer) => Promise | string; diff --git a/packages/shared/src/agent.ts b/packages/shared/src/agent.ts index 8e34d5fb..5e22126b 100644 --- a/packages/shared/src/agent.ts +++ b/packages/shared/src/agent.ts @@ -4,8 +4,40 @@ import type { JSONSchema7 } from "json-schema"; // Tool definitions // --------------------------------------------------------------------------- +/** + * Semantic hint for what the tool does to the world. Drives both the + * agents-plugin approval gate and the client's approval-card styling. + * + * - `read` — observes only; never needs approval. + * - `write` — creates or appends new state (e.g. saving a new view). Approval + * required by default. Rendered as a low-severity "writes" card. + * - `update` — mutates existing state in place (e.g. renaming, toggling). + * Approval required. Rendered as a medium-severity "updates" card. + * - `destructive` — deletes or irreversibly mutates (e.g. dropping a view). + * Approval required. Rendered as a high-severity "destructive" card. + * + * Prefer this over the legacy `readOnly`/`destructive` booleans: it lets the + * UI distinguish "captured a screenshot" from "deleted a dashboard", both of + * which today are lumped under a single red "destructive" label. + */ +export type ToolEffect = "read" | "write" | "update" | "destructive"; + export interface ToolAnnotations { + /** + * Preferred semantic label. When set, drives both the approval gate (fires + * for `write`/`update`/`destructive`) and the approval-card styling. + */ + effect?: ToolEffect; + /** + * @deprecated Prefer {@link effect}. Retained for backward compatibility + * with tools authored against the original flags and for MCP interop. + */ readOnly?: boolean; + /** + * @deprecated Prefer {@link effect} with value `"destructive"`. Retained + * so existing annotations continue to force the approval gate, and so + * MCP-style consumers that only read `destructive` still see the hint. + */ destructive?: boolean; idempotent?: boolean; requiresUserContext?: boolean;