From 3fab6a860a1ccecebda7db4f0380a55c3c0e6c0e Mon Sep 17 00:00:00 2001 From: Mahmoud Mabrouk Date: Fri, 19 Jun 2026 14:01:20 +0200 Subject: [PATCH] feat(sdk): add typed agent tool resolution --- sdks/python/agenta/sdk/agents/__init__.py | 80 ++- .../agenta/sdk/agents/adapters/harnesses.py | 46 +- sdks/python/agenta/sdk/agents/dtos.py | 192 ++++++- sdks/python/agenta/sdk/agents/errors.py | 3 + sdks/python/agenta/sdk/agents/mcp/__init__.py | 22 + sdks/python/agenta/sdk/agents/mcp/errors.py | 33 ++ .../agenta/sdk/agents/mcp/interfaces.py | 10 + sdks/python/agenta/sdk/agents/mcp/models.py | 57 ++ sdks/python/agenta/sdk/agents/mcp/parsing.py | 39 ++ sdks/python/agenta/sdk/agents/mcp/resolver.py | 68 +++ sdks/python/agenta/sdk/agents/mcp/wire.py | 17 + .../agenta/sdk/agents/tools/__init__.py | 75 +++ sdks/python/agenta/sdk/agents/tools/compat.py | 132 +++++ sdks/python/agenta/sdk/agents/tools/errors.py | 82 +++ .../agenta/sdk/agents/tools/interfaces.py | 20 + sdks/python/agenta/sdk/agents/tools/models.py | 221 ++++++++ .../python/agenta/sdk/agents/tools/parsing.py | 39 ++ .../agenta/sdk/agents/tools/resolver.py | 177 +++++++ sdks/python/agenta/sdk/agents/tools/wire.py | 15 + sdks/python/agenta/sdk/agents/ui_messages.py | 491 ++++++++++++++++++ sdks/python/agenta/sdk/agents/utils/wire.py | 3 + .../pytest/integration/agents/__init__.py | 1 + .../agents/test_transport_roundtrip.py | 113 ++++ .../agents/golden/run_request.claude.json | 3 +- .../unit/agents/golden/run_request.pi.json | 3 +- .../tests/pytest/unit/agents/mcp/__init__.py | 1 + .../pytest/unit/agents/mcp/test_resolver.py | 76 +++ .../unit/agents/test_dtos_agent_config.py | 16 +- .../unit/agents/test_dtos_harness_configs.py | 33 +- .../unit/agents/test_harness_adapters.py | 24 +- .../pytest/unit/agents/test_ui_messages.py | 429 +++++++++++++++ .../pytest/unit/agents/test_wire_contract.py | 77 +++ .../pytest/unit/agents/tools/__init__.py | 1 + .../pytest/unit/agents/tools/test_models.py | 63 +++ 
.../pytest/unit/agents/tools/test_parsing.py | 60 +++ .../pytest/unit/agents/tools/test_resolver.py | 131 +++++ 36 files changed, 2775 insertions(+), 78 deletions(-) create mode 100644 sdks/python/agenta/sdk/agents/mcp/__init__.py create mode 100644 sdks/python/agenta/sdk/agents/mcp/errors.py create mode 100644 sdks/python/agenta/sdk/agents/mcp/interfaces.py create mode 100644 sdks/python/agenta/sdk/agents/mcp/models.py create mode 100644 sdks/python/agenta/sdk/agents/mcp/parsing.py create mode 100644 sdks/python/agenta/sdk/agents/mcp/resolver.py create mode 100644 sdks/python/agenta/sdk/agents/mcp/wire.py create mode 100644 sdks/python/agenta/sdk/agents/tools/__init__.py create mode 100644 sdks/python/agenta/sdk/agents/tools/compat.py create mode 100644 sdks/python/agenta/sdk/agents/tools/errors.py create mode 100644 sdks/python/agenta/sdk/agents/tools/interfaces.py create mode 100644 sdks/python/agenta/sdk/agents/tools/models.py create mode 100644 sdks/python/agenta/sdk/agents/tools/parsing.py create mode 100644 sdks/python/agenta/sdk/agents/tools/resolver.py create mode 100644 sdks/python/agenta/sdk/agents/tools/wire.py create mode 100644 sdks/python/agenta/sdk/agents/ui_messages.py create mode 100644 sdks/python/oss/tests/pytest/integration/agents/__init__.py create mode 100644 sdks/python/oss/tests/pytest/integration/agents/test_transport_roundtrip.py create mode 100644 sdks/python/oss/tests/pytest/unit/agents/mcp/__init__.py create mode 100644 sdks/python/oss/tests/pytest/unit/agents/mcp/test_resolver.py create mode 100644 sdks/python/oss/tests/pytest/unit/agents/test_ui_messages.py create mode 100644 sdks/python/oss/tests/pytest/unit/agents/tools/__init__.py create mode 100644 sdks/python/oss/tests/pytest/unit/agents/tools/test_models.py create mode 100644 sdks/python/oss/tests/pytest/unit/agents/tools/test_parsing.py create mode 100644 sdks/python/oss/tests/pytest/unit/agents/tools/test_resolver.py diff --git a/sdks/python/agenta/sdk/agents/__init__.py 
b/sdks/python/agenta/sdk/agents/__init__.py index 38c5daca39..9df80c3fd9 100644 --- a/sdks/python/agenta/sdk/agents/__init__.py +++ b/sdks/python/agenta/sdk/agents/__init__.py @@ -48,9 +48,47 @@ TraceContext, to_messages, ) -from .errors import UnsupportedHarnessError +from .errors import ToolResolutionError, UnsupportedHarnessError from .interfaces import Backend, Environment, Harness, Sandbox, Session +from .mcp import ( + MCPConfigurationError, + MCPError, + MCPResolver, + MCPServerConfig, + MissingMCPSecretError, + ResolvedMCPServer, +) from .streaming import AgentRun +from .tools import ( + BuiltinToolConfig, + CallbackToolSpec, + ClientToolConfig, + ClientToolSpec, + CodeToolConfig, + CodeToolSpec, + DuplicateToolNameError, + EnvironmentToolSecretProvider, + GatewayToolResolver, + GatewayToolConfig, + GatewayToolResolution, + GatewayToolResolutionError, + MissingSecretPolicy, + MissingToolSecretError, + ResolvedToolSet, + ToolConfig, + ToolConfigError, + ToolConfigurationError, + ToolError, + ToolResolver, + ToolSecretProvider, + ToolSpec, + UnsupportedToolProviderError, + coerce_tool_config, + coerce_tool_configs, + parse_tool_config, + parse_tool_configs, +) +from .ui_messages import from_ui_messages, to_ui_message, ui_message_stream __all__ = [ # DTOs @@ -69,9 +107,48 @@ "AgentEvent", "AgentResult", "AgentRun", + # UI message codec (the /messages egress adapter) + "from_ui_messages", + "to_ui_message", + "ui_message_stream", "TraceContext", "ToolCallback", "PermissionPolicy", + # Canonical tools API + "ToolConfig", + "BuiltinToolConfig", + "GatewayToolConfig", + "CodeToolConfig", + "ClientToolConfig", + "ToolSpec", + "CallbackToolSpec", + "CodeToolSpec", + "ClientToolSpec", + "ResolvedToolSet", + "GatewayToolResolution", + "ToolResolver", + "ToolSecretProvider", + "GatewayToolResolver", + "EnvironmentToolSecretProvider", + "MissingSecretPolicy", + "parse_tool_config", + "parse_tool_configs", + "coerce_tool_config", + "coerce_tool_configs", + "ToolError", + 
"ToolConfigError", + "ToolConfigurationError", + "GatewayToolResolutionError", + "UnsupportedToolProviderError", + "MissingToolSecretError", + "DuplicateToolNameError", + # MCP is a sibling subsystem + "MCPServerConfig", + "ResolvedMCPServer", + "MCPResolver", + "MCPError", + "MCPConfigurationError", + "MissingMCPSecretError", # Interfaces (ports) "Backend", "Sandbox", @@ -80,6 +157,7 @@ "Harness", # Errors "UnsupportedHarnessError", + "ToolResolutionError", # Adapters "RivetBackend", "InProcessPiBackend", diff --git a/sdks/python/agenta/sdk/agents/adapters/harnesses.py b/sdks/python/agenta/sdk/agents/adapters/harnesses.py index 31e52d73da..0ac53076a8 100644 --- a/sdks/python/agenta/sdk/agents/adapters/harnesses.py +++ b/sdks/python/agenta/sdk/agents/adapters/harnesses.py @@ -28,6 +28,7 @@ SessionConfig, ) from ..interfaces import Environment, Harness +from ..tools.models import ToolSpec, coerce_tool_spec from .agenta_builtins import ( AGENTA_FORCED_SKILLS, compose_append_system, @@ -37,8 +38,6 @@ log = get_module_logger(__name__) -_EMPTY_OBJECT_SCHEMA: Dict[str, Any] = {"type": "object", "properties": {}} - def _opt_str(value: Any) -> Any: """Keep a harness option only if it is a non-empty string; otherwise drop it to ``None`` @@ -48,29 +47,9 @@ def _opt_str(value: Any) -> Any: return None -def _normalize_tool_specs(specs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Coerce resolved tool specs into the shape every harness expects. - - Drops malformed entries (no name) and fills the defaults the harness runtimes need: a - description (falls back to the name) and a JSON-Schema ``inputSchema`` (an empty object - when none was resolved). ``callRef`` is preserved so the call routes back to Agenta. 
- """ - normalized: List[Dict[str, Any]] = [] - for spec in specs or []: - if not isinstance(spec, dict): - continue - name = spec.get("name") - if not name: - continue - normalized.append( - { - "name": name, - "description": spec.get("description") or name, - "inputSchema": spec.get("inputSchema") or dict(_EMPTY_OBJECT_SCHEMA), - "callRef": spec.get("callRef"), - } - ) - return normalized +def _normalize_tool_specs(specs: List[Dict[str, Any]]) -> List[ToolSpec]: + """Compatibility helper for callers still supplying runner dictionaries.""" + return [coerce_tool_spec(spec) for spec in specs or []] class PiHarness(Harness): @@ -85,9 +64,10 @@ def _to_harness_config(self, config: SessionConfig) -> PiAgentConfig: return PiAgentConfig( agents_md=config.agent.instructions, model=config.agent.model, - builtin_tools=list(config.builtin_tools), - custom_tools=_normalize_tool_specs(config.custom_tools), + builtin_names=list(config.builtin_names), + tool_specs=list(config.tool_specs), tool_callback=config.tool_callback, + mcp_servers=list(config.mcp_servers), system=_opt_str(pi_options.get("system")), append_system=_opt_str(pi_options.get("append_system")), ) @@ -100,16 +80,17 @@ def _to_harness_config(self, config: SessionConfig) -> ClaudeAgentConfig: # Claude has no Pi built-in tools; drop them rather than ship a name Claude cannot # honor. Tools go over MCP, and Claude gates tool use, so the permission policy is # carried through. 
- if config.builtin_tools: + if config.builtin_names: log.warning( "ClaudeHarness ignores %d built-in tool(s); built-ins are a Pi concept", - len(config.builtin_tools), + len(config.builtin_names), ) return ClaudeAgentConfig( agents_md=config.agent.instructions, model=config.agent.model, - custom_tools=_normalize_tool_specs(config.custom_tools), + tool_specs=list(config.tool_specs), tool_callback=config.tool_callback, + mcp_servers=list(config.mcp_servers), permission_policy=config.permission_policy, ) @@ -130,9 +111,10 @@ def _to_harness_config(self, config: SessionConfig) -> AgentaAgentConfig: return AgentaAgentConfig( agents_md=compose_instructions(config.agent.instructions), model=config.agent.model, - builtin_tools=force_tools(list(config.builtin_tools)), - custom_tools=_normalize_tool_specs(config.custom_tools), + builtin_names=force_tools(list(config.builtin_names)), + tool_specs=list(config.tool_specs), tool_callback=config.tool_callback, + mcp_servers=list(config.mcp_servers), system=_opt_str(pi_options.get("system")), append_system=compose_append_system( _opt_str(pi_options.get("append_system")) diff --git a/sdks/python/agenta/sdk/agents/dtos.py b/sdks/python/agenta/sdk/agents/dtos.py index 578ece4901..0a050b4cb1 100644 --- a/sdks/python/agenta/sdk/agents/dtos.py +++ b/sdks/python/agenta/sdk/agents/dtos.py @@ -13,7 +13,16 @@ from enum import Enum from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union -from pydantic import BaseModel, Field +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, field_validator + +from .mcp import ( + MCPServerConfig, + ResolvedMCPServer, + mcp_servers_to_wire, + parse_mcp_server_configs, +) +from .tools import ToolCallback, ToolConfig, ToolSpec, coerce_tool_configs +from .tools.models import coerce_tool_spec # --------------------------------------------------------------------------- @@ -98,13 +107,26 @@ class ContentBlock(BaseModel): ``text`` is the only kind callers send today; 
``image`` and ``resource`` are plumbed so an image-capable harness can take them. A bare string normalizes to a single ``text`` block on the wire. + + ``tool_call`` / ``tool_result`` carriers (``tool_call_id``/``tool_name``/``input``/ + ``output``/``is_error``) hold a resolved tool turn for structured-message continuation: + the ``/messages`` egress folds inbound UIMessage tool/approval parts into these so a + cross-turn HITL reply replays as a real tool call plus its result, and the model resumes + from the result instead of re-asking. Mirrors ``ContentBlock`` in + ``services/agent/src/protocol.ts``. """ - type: str # "text" | "image" | "resource" + type: str # "text" | "image" | "resource" | "tool_call" | "tool_result" text: Optional[str] = None data: Optional[str] = None # base64 payload, used when type != "text" mime_type: Optional[str] = None uri: Optional[str] = None + # Tool-turn carriers (used by tool_call / tool_result blocks). + tool_call_id: Optional[str] = None + tool_name: Optional[str] = None + input: Optional[Any] = None + output: Optional[Any] = None + is_error: Optional[bool] = None def to_wire(self) -> Dict[str, Any]: block: Dict[str, Any] = {"type": self.type} @@ -116,6 +138,16 @@ def to_wire(self) -> Dict[str, Any]: block["mimeType"] = self.mime_type if self.uri is not None: block["uri"] = self.uri + if self.tool_call_id is not None: + block["toolCallId"] = self.tool_call_id + if self.tool_name is not None: + block["toolName"] = self.tool_name + if self.input is not None: + block["input"] = self.input + if self.output is not None: + block["output"] = self.output + if self.is_error is not None: + block["isError"] = self.is_error return block @classmethod @@ -132,6 +164,13 @@ def from_raw(cls, raw: Any) -> "ContentBlock": data=raw.get("data"), mime_type=raw.get("mimeType") or raw.get("mime_type"), uri=raw.get("uri"), + tool_call_id=raw.get("toolCallId") or raw.get("tool_call_id"), + tool_name=raw.get("toolName") or raw.get("tool_name"), + 
input=raw.get("input"), + output=raw.get("output"), + is_error=raw.get("isError") + if raw.get("isError") is not None + else raw.get("is_error"), ) return cls(type="text", text=str(raw)) @@ -232,18 +271,6 @@ def to_wire(self) -> Dict[str, Any]: } -class ToolCallback(BaseModel): - """How a harness routes a tool call back through Agenta's ``/tools/call``. The provider - key and connection auth stay server-side. Empty for a standalone run with no - Agenta-resolved tools.""" - - endpoint: str # full ``/tools/call`` URL - authorization: Optional[str] = None - - def to_wire(self) -> Dict[str, Any]: - return {"endpoint": self.endpoint, "authorization": self.authorization} - - # --------------------------------------------------------------------------- # Run result # --------------------------------------------------------------------------- @@ -282,11 +309,24 @@ class AgentConfig(BaseModel): ignores the rest; a key for a harness that is not running is simply never looked at. """ + model_config = ConfigDict(populate_by_name=True) + instructions: Optional[str] = None model: Optional[str] = None - tools: List[Any] = Field(default_factory=list) + tools: List[ToolConfig] = Field(default_factory=list) + mcp_servers: List[MCPServerConfig] = Field(default_factory=list) harness_options: Dict[str, Dict[str, Any]] = Field(default_factory=dict) + @field_validator("tools", mode="before") + @classmethod + def _coerce_tools(cls, value: Any) -> List[ToolConfig]: + return coerce_tool_configs(_as_list(value)).tool_configs + + @field_validator("mcp_servers", mode="before") + @classmethod + def _coerce_mcp_servers(cls, value: Any) -> List[MCPServerConfig]: + return parse_mcp_server_configs(_as_list(value)) + @classmethod def from_params( cls, @@ -308,6 +348,7 @@ def from_params( instructions=instructions, model=model, tools=_as_list(tools), + mcp_servers=_parse_mcp_servers_raw(params, base), harness_options=_parse_harness_options(params, base), ) @@ -354,11 +395,24 @@ class 
HarnessAgentConfig(BaseModel): fields for the ``/run`` payload. """ + model_config = ConfigDict(populate_by_name=True) + harness: ClassVar[HarnessType] agents_md: Optional[str] = None model: Optional[str] = None tool_callback: Optional[ToolCallback] = None + mcp_servers: List[ResolvedMCPServer] = Field(default_factory=list) + + @field_validator("mcp_servers", mode="before") + @classmethod + def _coerce_resolved_mcp_servers(cls, value: Any) -> List[ResolvedMCPServer]: + return [ + item + if isinstance(item, ResolvedMCPServer) + else ResolvedMCPServer.model_validate(item) + for item in value or [] + ] def wire_tools(self) -> Dict[str, Any]: """The tool + permission fields this harness contributes to the ``/run`` payload.""" @@ -369,6 +423,13 @@ def wire_prompt(self) -> Dict[str, Any]: by default; a harness that exposes prompt overrides (Pi) emits them here.""" return {} + def wire_mcp(self) -> Dict[str, Any]: + """The ``mcpServers`` field for the ``/run`` payload. Omitted when none are declared so + a tool-free run's payload is unchanged (the golden wire contract).""" + if not self.mcp_servers: + return {} + return {"mcpServers": mcp_servers_to_wire(self.mcp_servers)} + class PiAgentConfig(HarnessAgentConfig): """Pi's config. 
Built-in tools by name plus resolved specs delivered natively (Pi has no @@ -385,15 +446,34 @@ class PiAgentConfig(HarnessAgentConfig): harness: ClassVar[HarnessType] = HarnessType.PI - builtin_tools: List[str] = Field(default_factory=list) - custom_tools: List[Dict[str, Any]] = Field(default_factory=list) + builtin_names: List[str] = Field( + default_factory=list, + validation_alias=AliasChoices("builtin_names", "builtin_tools"), + ) + tool_specs: List[ToolSpec] = Field( + default_factory=list, + validation_alias=AliasChoices("tool_specs", "custom_tools"), + ) system: Optional[str] = None append_system: Optional[str] = None + @field_validator("tool_specs", mode="before") + @classmethod + def _coerce_tool_specs(cls, value: Any) -> List[ToolSpec]: + return [coerce_tool_spec(item) for item in value or []] + + @property + def builtin_tools(self) -> List[str]: + return list(self.builtin_names) + + @property + def custom_tools(self) -> List[Dict[str, Any]]: + return [tool_spec.to_wire() for tool_spec in self.tool_specs] + def wire_tools(self) -> Dict[str, Any]: return { - "tools": list(self.builtin_tools), - "customTools": list(self.custom_tools), + "tools": list(self.builtin_names), + "customTools": [tool_spec.to_wire() for tool_spec in self.tool_specs], "toolCallback": self.tool_callback.to_wire() if self.tool_callback else None, @@ -415,13 +495,25 @@ class ClaudeAgentConfig(HarnessAgentConfig): harness: ClassVar[HarnessType] = HarnessType.CLAUDE - custom_tools: List[Dict[str, Any]] = Field(default_factory=list) + tool_specs: List[ToolSpec] = Field( + default_factory=list, + validation_alias=AliasChoices("tool_specs", "custom_tools"), + ) permission_policy: PermissionPolicy = "auto" + @field_validator("tool_specs", mode="before") + @classmethod + def _coerce_tool_specs(cls, value: Any) -> List[ToolSpec]: + return [coerce_tool_spec(item) for item in value or []] + + @property + def custom_tools(self) -> List[Dict[str, Any]]: + return [tool_spec.to_wire() for tool_spec 
in self.tool_specs] + def wire_tools(self) -> Dict[str, Any]: return { "tools": [], # Claude has no Pi built-in tools - "customTools": list(self.custom_tools), + "customTools": [tool_spec.to_wire() for tool_spec in self.tool_specs], "toolCallback": self.tool_callback.to_wire() if self.tool_callback else None, @@ -460,14 +552,46 @@ class SessionConfig(BaseModel): empty for a bare standalone run). Sandbox is intentionally absent: it is a backend/environment concern.""" + model_config = ConfigDict(populate_by_name=True) + agent: AgentConfig secrets: Dict[str, str] = Field(default_factory=dict) permission_policy: PermissionPolicy = "auto" trace: Optional[TraceContext] = None session_id: Optional[str] = None - builtin_tools: List[str] = Field(default_factory=list) - custom_tools: List[Dict[str, Any]] = Field(default_factory=list) + builtin_names: List[str] = Field( + default_factory=list, + validation_alias=AliasChoices("builtin_names", "builtin_tools"), + ) + tool_specs: List[ToolSpec] = Field( + default_factory=list, + validation_alias=AliasChoices("tool_specs", "custom_tools"), + ) tool_callback: Optional[ToolCallback] = None + mcp_servers: List[ResolvedMCPServer] = Field(default_factory=list) + + @field_validator("tool_specs", mode="before") + @classmethod + def _coerce_tool_specs(cls, value: Any) -> List[ToolSpec]: + return [coerce_tool_spec(item) for item in value or []] + + @field_validator("mcp_servers", mode="before") + @classmethod + def _coerce_resolved_mcp_servers(cls, value: Any) -> List[ResolvedMCPServer]: + return [ + item + if isinstance(item, ResolvedMCPServer) + else ResolvedMCPServer.model_validate(item) + for item in value or [] + ] + + @property + def builtin_tools(self) -> List[str]: + return list(self.builtin_names) + + @property + def custom_tools(self) -> List[Dict[str, Any]]: + return [tool_spec.to_wire() for tool_spec in self.tool_specs] # --------------------------------------------------------------------------- @@ -483,6 +607,22 @@ def 
_as_list(raw: Any) -> List[Any]: return [] +def _parse_mcp_servers_raw( + params: Dict[str, Any], + defaults: AgentConfig, +) -> List[Any]: + """Pull the raw ``mcp_servers`` list from a request/config dict, falling back to defaults. + + Reads ``mcp_servers`` from the ``agent`` element when present, else the flat request. + Canonical validation happens on :class:`AgentConfig` construction.""" + agent = params.get("agent") + source = agent if isinstance(agent, dict) else params + raw = source.get("mcp_servers") + if raw is None: + return list(defaults.mcp_servers) + return _as_list(raw) + + def _parse_harness_options( params: Dict[str, Any], defaults: AgentConfig, @@ -530,8 +670,12 @@ def _parse_agent_fields( """Pull (instructions, model, tools) from a request/config dict, with fallbacks.""" agent = params.get("agent") if isinstance(agent, dict): + # ``agents_md`` is the field the playground/catalog schema exposes; ``instructions`` is + # the legacy key kept as a fallback so already-stored agent configs still resolve. 
return ( - agent.get("instructions") or defaults.instructions, + agent.get("agents_md") + or agent.get("instructions") + or defaults.instructions, agent.get("model") or defaults.model, agent.get("tools"), ) diff --git a/sdks/python/agenta/sdk/agents/errors.py b/sdks/python/agenta/sdk/agents/errors.py index dfe412253d..b9f136a472 100644 --- a/sdks/python/agenta/sdk/agents/errors.py +++ b/sdks/python/agenta/sdk/agents/errors.py @@ -5,6 +5,9 @@ from typing import TYPE_CHECKING from .dtos import HarnessType +from .tools.errors import ToolResolutionError + +__all__ = ["UnsupportedHarnessError", "ToolResolutionError"] if TYPE_CHECKING: from .interfaces import Backend diff --git a/sdks/python/agenta/sdk/agents/mcp/__init__.py b/sdks/python/agenta/sdk/agents/mcp/__init__.py new file mode 100644 index 0000000000..4881f30d52 --- /dev/null +++ b/sdks/python/agenta/sdk/agents/mcp/__init__.py @@ -0,0 +1,22 @@ +"""Public MCP configuration and resolution API.""" + +from .errors import MCPConfigurationError, MCPError, MissingMCPSecretError +from .interfaces import MCPSecretProvider +from .models import MCPServerConfig, ResolvedMCPServer +from .parsing import parse_mcp_server_config, parse_mcp_server_configs +from .resolver import MCPResolver +from .wire import mcp_server_to_wire, mcp_servers_to_wire + +__all__ = [ + "MCPServerConfig", + "ResolvedMCPServer", + "MCPSecretProvider", + "MCPResolver", + "parse_mcp_server_config", + "parse_mcp_server_configs", + "mcp_server_to_wire", + "mcp_servers_to_wire", + "MCPError", + "MCPConfigurationError", + "MissingMCPSecretError", +] diff --git a/sdks/python/agenta/sdk/agents/mcp/errors.py b/sdks/python/agenta/sdk/agents/mcp/errors.py new file mode 100644 index 0000000000..2d2ab05193 --- /dev/null +++ b/sdks/python/agenta/sdk/agents/mcp/errors.py @@ -0,0 +1,33 @@ +"""Errors raised while parsing and resolving MCP server configuration.""" + +from __future__ import annotations + +from typing import Any, Optional, Sequence + + +class 
MCPError(RuntimeError): + """Base error for the agent MCP subsystem.""" + + +class MCPConfigurationError(MCPError): + def __init__( + self, + message: str, + *, + index: Optional[int] = None, + value: Any = None, + ) -> None: + super().__init__(message) + self.index = index + self.value = value + + +class MissingMCPSecretError(MCPError): + def __init__(self, *, server_name: str, secret_names: Sequence[str]) -> None: + names = tuple(secret_names) + super().__init__( + f"MCP server '{server_name}' is missing required secret(s): " + f"{', '.join(names)}" + ) + self.server_name = server_name + self.secret_names = names diff --git a/sdks/python/agenta/sdk/agents/mcp/interfaces.py b/sdks/python/agenta/sdk/agents/mcp/interfaces.py new file mode 100644 index 0000000000..23c5c91522 --- /dev/null +++ b/sdks/python/agenta/sdk/agents/mcp/interfaces.py @@ -0,0 +1,10 @@ +"""Injected dependencies used by MCP resolution.""" + +from __future__ import annotations + +from typing import Mapping, Protocol, Sequence + + +class MCPSecretProvider(Protocol): + async def get_many(self, names: Sequence[str]) -> Mapping[str, str]: + """Return available values for the requested MCP secret names.""" diff --git a/sdks/python/agenta/sdk/agents/mcp/models.py b/sdks/python/agenta/sdk/agents/mcp/models.py new file mode 100644 index 0000000000..e4df7f87e5 --- /dev/null +++ b/sdks/python/agenta/sdk/agents/mcp/models.py @@ -0,0 +1,57 @@ +"""Canonical MCP server declarations and resolved runner configuration.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +class MCPServerConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + name: str = Field(min_length=1) + transport: Literal["stdio", "http"] = "stdio" + command: Optional[str] = None + args: List[str] = Field(default_factory=list) + env: Dict[str, str] = Field(default_factory=dict, repr=False) + url: Optional[str] = 
None + secrets: Dict[str, str] = Field(default_factory=dict) + tools: List[str] = Field(default_factory=list) + + @model_validator(mode="after") + def _validate_transport(self) -> "MCPServerConfig": + if self.transport == "stdio" and not self.command: + raise ValueError("stdio MCP server requires command") + if self.transport == "http" and not self.url: + raise ValueError("http MCP server requires url") + return self + + +class ResolvedMCPServer(BaseModel): + model_config = ConfigDict(extra="forbid", frozen=True) + + name: str + transport: Literal["stdio", "http"] = "stdio" + command: Optional[str] = None + args: List[str] = Field(default_factory=list) + env: Dict[str, str] = Field(default_factory=dict, repr=False) + url: Optional[str] = None + tools: List[str] = Field(default_factory=list) + + def to_wire(self) -> Dict[str, Any]: + wire: Dict[str, Any] = { + "name": self.name, + "transport": self.transport, + } + if self.command: + wire["command"] = self.command + if self.args: + wire["args"] = list(self.args) + if self.env: + wire["env"] = dict(self.env) + if self.url: + wire["url"] = self.url + if self.tools: + wire["tools"] = list(self.tools) + return wire diff --git a/sdks/python/agenta/sdk/agents/mcp/parsing.py b/sdks/python/agenta/sdk/agents/mcp/parsing.py new file mode 100644 index 0000000000..dfb5f169a6 --- /dev/null +++ b/sdks/python/agenta/sdk/agents/mcp/parsing.py @@ -0,0 +1,39 @@ +"""Strict parsing of MCP server configuration.""" + +from __future__ import annotations + +from typing import Any, Mapping, Sequence + +from pydantic import ValidationError + +from .errors import MCPConfigurationError +from .models import MCPServerConfig + + +def parse_mcp_server_config( + value: MCPServerConfig | Mapping[str, Any], +) -> MCPServerConfig: + try: + return MCPServerConfig.model_validate(value) + except ValidationError as exc: + raise MCPConfigurationError( + "Invalid MCP server configuration: " + f"{exc.errors(include_url=False, include_input=False)}", + 
value=value, + ) from exc + + +def parse_mcp_server_configs( + values: Sequence[MCPServerConfig | Mapping[str, Any]], +) -> list[MCPServerConfig]: + parsed: list[MCPServerConfig] = [] + for index, value in enumerate(values): + try: + parsed.append(parse_mcp_server_config(value)) + except MCPConfigurationError as exc: + raise MCPConfigurationError( + str(exc), + index=index, + value=value, + ) from exc + return parsed diff --git a/sdks/python/agenta/sdk/agents/mcp/resolver.py b/sdks/python/agenta/sdk/agents/mcp/resolver.py new file mode 100644 index 0000000000..6ce78162dd --- /dev/null +++ b/sdks/python/agenta/sdk/agents/mcp/resolver.py @@ -0,0 +1,68 @@ +"""Resolution of MCP server declarations into runner configuration.""" + +from __future__ import annotations + +from typing import Mapping, Sequence + +from agenta.sdk.agents.tools.models import MissingSecretPolicy + +from .errors import MissingMCPSecretError +from .interfaces import MCPSecretProvider +from .models import MCPServerConfig, ResolvedMCPServer + + +class MCPResolver: + def __init__( + self, + *, + secret_provider: MCPSecretProvider, + missing_secret_policy: MissingSecretPolicy = MissingSecretPolicy.ERROR, + ) -> None: + self._secret_provider = secret_provider + self._missing_secret_policy = missing_secret_policy + + async def resolve( + self, + server_configs: Sequence[MCPServerConfig], + ) -> list[ResolvedMCPServer]: + secret_names = sorted( + { + secret_name + for server_config in server_configs + for secret_name in server_config.secrets.values() + } + ) + secret_values: Mapping[str, str] = ( + await self._secret_provider.get_many(secret_names) if secret_names else {} + ) + + resolved: list[ResolvedMCPServer] = [] + for server_config in server_configs: + missing = [ + secret_name + for secret_name in server_config.secrets.values() + if secret_name not in secret_values + ] + if missing and self._missing_secret_policy == MissingSecretPolicy.ERROR: + raise MissingMCPSecretError( + 
server_name=server_config.name, + secret_names=missing, + ) + + env = dict(server_config.env) + for env_var, secret_name in server_config.secrets.items(): + if secret_name in secret_values: + env[env_var] = secret_values[secret_name] + + resolved.append( + ResolvedMCPServer( + name=server_config.name, + transport=server_config.transport, + command=server_config.command, + args=list(server_config.args), + env=env, + url=server_config.url, + tools=list(server_config.tools), + ) + ) + return resolved diff --git a/sdks/python/agenta/sdk/agents/mcp/wire.py b/sdks/python/agenta/sdk/agents/mcp/wire.py new file mode 100644 index 0000000000..f9c1a7cb68 --- /dev/null +++ b/sdks/python/agenta/sdk/agents/mcp/wire.py @@ -0,0 +1,17 @@ +"""Serialization of resolved MCP servers to the runner contract.""" + +from __future__ import annotations + +from typing import Any, Dict, Sequence + +from .models import ResolvedMCPServer + + +def mcp_server_to_wire(server: ResolvedMCPServer) -> Dict[str, Any]: + return server.to_wire() + + +def mcp_servers_to_wire( + servers: Sequence[ResolvedMCPServer], +) -> list[Dict[str, Any]]: + return [mcp_server_to_wire(server) for server in servers] diff --git a/sdks/python/agenta/sdk/agents/tools/__init__.py b/sdks/python/agenta/sdk/agents/tools/__init__.py new file mode 100644 index 0000000000..2b40dc082e --- /dev/null +++ b/sdks/python/agenta/sdk/agents/tools/__init__.py @@ -0,0 +1,75 @@ +"""Public agent-tool configuration and resolution API.""" + +from .compat import ( + ToolConfigDiagnostic, + ToolConfigParseResult, + coerce_tool_config, + coerce_tool_configs, +) +from .errors import ( + DuplicateToolNameError, + GatewayToolResolutionError, + MissingToolSecretError, + ToolConfigError, + ToolConfigurationError, + ToolError, + ToolResolutionError, + UnsupportedToolProviderError, +) +from .interfaces import GatewayToolResolver, ToolSecretProvider +from .models import ( + BuiltinToolConfig, + CallbackToolSpec, + ClientToolConfig, + ClientToolSpec, + 
CodeToolConfig, + CodeToolSpec, + GatewayToolConfig, + GatewayToolResolution, + MissingSecretPolicy, + ResolvedToolSet, + ToolCallback, + ToolConfig, + ToolConfigBase, + ToolSpec, +) +from .parsing import parse_tool_config, parse_tool_configs +from .resolver import EnvironmentToolSecretProvider, ToolResolver +from .wire import tool_spec_to_wire, tool_specs_to_wire + +__all__ = [ + "ToolConfigBase", + "ToolConfig", + "BuiltinToolConfig", + "GatewayToolConfig", + "CodeToolConfig", + "ClientToolConfig", + "ToolSpec", + "CallbackToolSpec", + "CodeToolSpec", + "ClientToolSpec", + "ToolCallback", + "ResolvedToolSet", + "GatewayToolResolution", + "MissingSecretPolicy", + "ToolResolver", + "ToolSecretProvider", + "GatewayToolResolver", + "EnvironmentToolSecretProvider", + "parse_tool_config", + "parse_tool_configs", + "coerce_tool_config", + "coerce_tool_configs", + "ToolConfigDiagnostic", + "ToolConfigParseResult", + "tool_spec_to_wire", + "tool_specs_to_wire", + "ToolError", + "ToolConfigError", + "ToolConfigurationError", + "ToolResolutionError", + "GatewayToolResolutionError", + "UnsupportedToolProviderError", + "MissingToolSecretError", + "DuplicateToolNameError", +] diff --git a/sdks/python/agenta/sdk/agents/tools/compat.py b/sdks/python/agenta/sdk/agents/tools/compat.py new file mode 100644 index 0000000000..e356abfdde --- /dev/null +++ b/sdks/python/agenta/sdk/agents/tools/compat.py @@ -0,0 +1,132 @@ +"""Compatibility conversion for legacy playground and persisted tool shapes.""" + +from __future__ import annotations + +from typing import Any, Literal, Optional, Sequence + +from pydantic import BaseModel, ConfigDict, Field + +from .errors import ToolConfigurationError +from .models import ( + BuiltinToolConfig, + ClientToolConfig, + CodeToolConfig, + GatewayToolConfig, + ToolConfig, +) +from .parsing import parse_tool_config + + +class ToolConfigDiagnostic(BaseModel): + model_config = ConfigDict(frozen=True) + + index: int + message: str + + +class 
ToolConfigParseResult(BaseModel): + model_config = ConfigDict(frozen=True) + + tool_configs: list[ToolConfig] = Field(default_factory=list) + diagnostics: list[ToolConfigDiagnostic] = Field(default_factory=list) + + +def _parse_gateway_slug(slug: Any) -> Optional[dict[str, Any]]: + if not isinstance(slug, str): + return None + parts = slug.replace("__", ".").split(".") + if len(parts) != 5 or parts[0] != "tools": + return None + return { + "type": "gateway", + "provider": parts[1], + "integration": parts[2], + "action": parts[3], + "connection": parts[4], + } + + +def _copy_tool_metadata( + source: dict[str, Any], target: dict[str, Any] +) -> dict[str, Any]: + result = dict(target) + if "needs_approval" in source: + result["needs_approval"] = bool(source["needs_approval"]) + if isinstance(source.get("render"), dict): + result["render"] = dict(source["render"]) + return result + + +def coerce_tool_config(value: Any) -> ToolConfig: + """Convert one supported legacy shape into canonical tool configuration.""" + if isinstance( + value, + ( + BuiltinToolConfig, + GatewayToolConfig, + CodeToolConfig, + ClientToolConfig, + ), + ): + return value + if isinstance(value, str): + return BuiltinToolConfig(name=value) + if not isinstance(value, dict): + raise ToolConfigurationError( + "Tool configuration must be a string or mapping", + value=value, + ) + + data = dict(value) + if data.get("type") == "composio": + data["type"] = "gateway" + data.setdefault("provider", "composio") + + if data.get("type") in {"builtin", "gateway", "code", "client"}: + return parse_tool_config(data) + + function = data.get("function") if isinstance(data.get("function"), dict) else {} + gateway = _parse_gateway_slug(function.get("name") or data.get("name")) + if gateway: + return parse_tool_config(_copy_tool_metadata(data, gateway)) + + if isinstance(data.get("name"), str) and "type" not in data: + return BuiltinToolConfig(name=data["name"]) + + raise ToolConfigurationError("Unsupported tool 
configuration shape", value=value) + + +def coerce_tool_configs( + values: Optional[Sequence[Any]], + *, + on_error: Literal["raise", "collect"] = "raise", +) -> ToolConfigParseResult: + """Convert legacy values, either raising or returning structured diagnostics.""" + tool_configs: list[ToolConfig] = [] + diagnostics: list[ToolConfigDiagnostic] = [] + for index, value in enumerate(values or []): + if value is None: + error = ToolConfigurationError( + "Tool configuration cannot be null", + index=index, + value=value, + ) + else: + try: + tool_configs.append(coerce_tool_config(value)) + continue + except ToolConfigurationError as exc: + error = ToolConfigurationError( + str(exc), + index=index, + value=value, + ) + + if on_error == "raise": + raise error + diagnostics.append(ToolConfigDiagnostic(index=index, message=str(error))) + + return ToolConfigParseResult( + tool_configs=tool_configs, + diagnostics=diagnostics, + ) diff --git a/sdks/python/agenta/sdk/agents/tools/errors.py b/sdks/python/agenta/sdk/agents/tools/errors.py new file mode 100644 index 0000000000..24d62614c4 --- /dev/null +++ b/sdks/python/agenta/sdk/agents/tools/errors.py @@ -0,0 +1,82 @@ +"""Errors raised while parsing and resolving agent tools.""" + +from __future__ import annotations + +from typing import Any, Optional, Sequence + + +class ToolError(RuntimeError): + """Base error for the agent tools domain.""" + + +class ToolConfigurationError(ToolError): + """Raised when tool configuration cannot be converted to a canonical model.""" + + def __init__( + self, + message: str, + *, + index: Optional[int] = None, + value: Any = None, + ) -> None: + super().__init__(message) + self.index = index + self.value = value + + +ToolConfigError = ToolConfigurationError + + +class ToolResolutionError(ToolError): + """Raised when tool configuration cannot become runnable specifications.""" + + def __init__( + self, + message: str, + *, + status: Optional[int] = None, + ref_count: Optional[int] = None, + 
class GatewayToolResolutionError(ToolResolutionError):
    """Resolution failure reported for a gateway-backed tool."""


class UnsupportedToolProviderError(ToolResolutionError):
    """No resolver is available for the gateway provider a tool names."""

    def __init__(self, provider: str) -> None:
        message = f"Unsupported tool provider: {provider}"
        super().__init__(message, provider=provider)


class MissingToolSecretError(ToolResolutionError):
    """A tool declares secrets that the secret provider did not return."""

    def __init__(self, *, tool_name: str, secret_names: Sequence[str]) -> None:
        # Snapshot the names as a tuple so the stored attribute is immutable.
        missing = tuple(secret_names)
        message = (
            f"Tool '{tool_name}' is missing required secret(s): {', '.join(missing)}"
        )
        super().__init__(message)
        self.tool_name = tool_name
        self.secret_names = missing


class DuplicateToolNameError(ToolResolutionError):
    """Two configured tools share one model-visible name."""

    def __init__(self, name: str) -> None:
        message = f"Duplicate tool name: {name}"
        super().__init__(message)
        self.name = name
def _empty_object_schema() -> Dict[str, Any]:
    """Return a fresh JSON-schema mapping for an object with no declared properties."""
    return {"type": "object", "properties": {}}


class ToolConfigBase(BaseModel):
    """Fields shared by every persisted tool declaration."""

    # Keys that are not declared fields are rejected instead of ignored.
    model_config = ConfigDict(extra="forbid")

    # Approval flag; the resolver copies it onto the resolved spec.
    needs_approval: bool = False
    # Optional rendering metadata; also copied onto the resolved spec.
    render: Optional[Dict[str, Any]] = None


class BuiltinToolConfig(ToolConfigBase):
    """A tool declared by name only (discriminator ``type="builtin"``)."""

    type: Literal["builtin"] = "builtin"
    # Must be a non-empty string.
    name: str = Field(min_length=1)


class GatewayToolConfig(ToolConfigBase):
    """A tool reached through a gateway provider (discriminator ``type="gateway"``)."""

    type: Literal["gateway"] = "gateway"
    # Provider defaults to "composio"; every segment must be non-empty.
    provider: str = Field(default="composio", min_length=1)
    integration: str = Field(min_length=1)
    action: str = Field(min_length=1)
    connection: str = Field(min_length=1)
    # Optional explicit name; when given it must be non-empty.
    name: Optional[str] = Field(default=None, min_length=1)

    @property
    def reference(self) -> str:
        """Return ``tools.<provider>.<integration>.<action>.<connection>``."""
        return (
            f"tools.{self.provider}.{self.integration}.{self.action}.{self.connection}"
        )
class ToolCallback(BaseModel):
    """Where callback tool calls are sent."""

    model_config = ConfigDict(frozen=True)

    endpoint: str
    # repr=False keeps the value out of the model's repr().
    authorization: Optional[str] = Field(default=None, repr=False)

    def to_wire(self) -> Dict[str, Any]:
        """Return the wire mapping; ``authorization`` is present even when ``None``."""
        return {
            "endpoint": self.endpoint,
            "authorization": self.authorization,
        }


class ToolSpecBase(BaseModel):
    """Fields shared by every resolved, runner-ready tool specification.

    Snake-case and camel-case spellings are both accepted on input; serialization by
    alias emits the camel-case spelling.
    """

    model_config = ConfigDict(
        extra="forbid",
        frozen=True,
        populate_by_name=True,
    )

    name: str
    description: str
    input_schema: Dict[str, Any] = Field(
        default_factory=_empty_object_schema,
        validation_alias=AliasChoices("input_schema", "inputSchema"),
        serialization_alias="inputSchema",
    )
    needs_approval: bool = Field(
        default=False,
        validation_alias=AliasChoices("needs_approval", "needsApproval"),
        serialization_alias="needsApproval",
    )
    render: Optional[Dict[str, Any]] = None

    def to_wire(self) -> Dict[str, Any]:
        """Return the camel-case wire mapping for this spec.

        ``None`` values are dropped, ``needsApproval`` is emitted only when true, and
        an empty or absent ``env`` is omitted.
        """
        wire = self.model_dump(
            mode="json",
            by_alias=True,
            exclude_none=True,
        )
        if not self.needs_approval:
            wire.pop("needsApproval", None)
        if not wire.get("env"):
            wire.pop("env", None)
        return wire


class CallbackToolSpec(ToolSpecBase):
    """A spec carrying a ``call_ref`` (``callRef`` on the wire)."""

    kind: Literal["callback"] = "callback"
    call_ref: str = Field(
        validation_alias=AliasChoices("call_ref", "callRef"),
        serialization_alias="callRef",
    )


class CodeToolSpec(ToolSpecBase):
    """A spec carrying source code, its runtime, and its environment values."""

    kind: Literal["code"] = "code"
    runtime: Literal["python", "node"] = "python"
    code: str
    # repr=False keeps the environment values out of the model's repr().
    env: Dict[str, str] = Field(default_factory=dict, repr=False)


class ClientToolSpec(ToolSpecBase):
    """A spec with no fields beyond the shared ones (``kind="client"``)."""

    kind: Literal["client"] = "client"


ToolSpec = Annotated[
    Union[CallbackToolSpec, CodeToolSpec, ClientToolSpec],
    Field(discriminator="kind"),
]
TOOL_SPEC_ADAPTER: TypeAdapter[ToolSpec] = TypeAdapter(ToolSpec)


def coerce_tool_spec(value: Any) -> ToolSpec:
    """Convert a typed spec or a wire/legacy mapping into a typed tool spec.

    A mapping without ``kind`` is classified by its content: a call reference makes
    it a callback spec, a non-null ``code`` makes it a code spec, anything else is a
    client spec. A missing ``description`` defaults to the tool name.

    Raises:
        TypeError: *value* is neither a typed spec nor a ``dict``.
        pydantic.ValidationError: the mapping does not validate as a tool spec.
    """
    if isinstance(value, (CallbackToolSpec, CodeToolSpec, ClientToolSpec)):
        return value
    if not isinstance(value, dict):
        raise TypeError("tool spec must be a mapping")
    data = dict(value)
    if not data.get("kind"):
        if data.get("callRef") or data.get("call_ref"):
            data["kind"] = "callback"
        elif data.get("code") is not None:
            data["kind"] = "code"
        else:
            data["kind"] = "client"
    data.setdefault("description", data.get("name"))
    # Default the schema only when neither spelling is present. Adding "inputSchema"
    # next to a caller-supplied "input_schema" leaves a second, unconsumed alias key,
    # which the spec models (extra="forbid") reject as an unexpected field.
    if "input_schema" not in data and "inputSchema" not in data:
        data["inputSchema"] = _empty_object_schema()
    return TOOL_SPEC_ADAPTER.validate_python(data)
def parse_tool_config(value: ToolConfig | Mapping[str, Any]) -> ToolConfig:
    """Strictly validate one canonical tool declaration.

    Raises:
        ToolConfigurationError: validation failed; the message lists the validation
            errors (without URLs or inputs) and the original error is chained.
    """
    try:
        parsed = TOOL_CONFIG_ADAPTER.validate_python(value)
    except ValidationError as failure:
        details = failure.errors(include_url=False, include_input=False)
        raise ToolConfigurationError(
            f"Invalid tool configuration: {details}",
            value=value,
        ) from failure
    return parsed


def parse_tool_configs(
    values: Sequence[ToolConfig | Mapping[str, Any]],
) -> list[ToolConfig]:
    """Strictly validate canonical tool declarations, in order.

    Raises:
        ToolConfigurationError: for the first invalid item, carrying its index.
    """

    def _parse_at(position: int, candidate: Any) -> ToolConfig:
        # Re-raise with the item's position so callers can locate the bad entry.
        try:
            return parse_tool_config(candidate)
        except ToolConfigurationError as failure:
            raise ToolConfigurationError(
                str(failure),
                index=position,
                value=candidate,
            ) from failure

    return [_parse_at(position, candidate) for position, candidate in enumerate(values)]
class EnvironmentToolSecretProvider:
    """Secret provider backed by the current process environment."""

    async def get_many(self, names: Sequence[str]) -> Mapping[str, str]:
        """Return the requested names that are set in ``os.environ``.

        Unset names are left out; a variable set to the empty string is kept.
        """
        found: dict[str, str] = {}
        for name in names:
            value = os.environ.get(name)
            if value is not None:
                found[name] = value
        return found


def _apply_tool_metadata(tool_spec: ToolSpec, tool_config: ToolConfig) -> ToolSpec:
    """Copy *tool_spec*, overriding approval and render metadata from *tool_config*."""
    overrides = {
        "needs_approval": tool_config.needs_approval,
        "render": tool_config.render,
    }
    return tool_spec.model_copy(update=overrides)


def _build_code_tool_spec(
    *,
    tool_config: CodeToolConfig,
    env: Mapping[str, str],
) -> CodeToolSpec:
    """Build the runnable spec for a code tool; description falls back to the name."""
    bare_spec = CodeToolSpec(
        name=tool_config.name,
        description=tool_config.description or tool_config.name,
        input_schema=tool_config.input_schema,
        runtime=tool_config.runtime,
        code=tool_config.script,
        env=dict(env),
    )
    return _apply_tool_metadata(bare_spec, tool_config)


def _build_client_tool_spec(*, tool_config: ClientToolConfig) -> ClientToolSpec:
    """Build the runnable spec for a client tool; description falls back to the name."""
    bare_spec = ClientToolSpec(
        name=tool_config.name,
        description=tool_config.description or tool_config.name,
        input_schema=tool_config.input_schema,
    )
    return _apply_tool_metadata(bare_spec, tool_config)
def _validate_unique_names(
    *,
    builtin_names: Sequence[str],
    tool_specs: Sequence[ToolSpec],
) -> None:
    """Raise ``DuplicateToolNameError`` for the first name that appears twice.

    Builtin names are checked first, then spec names, each in the given order.
    """
    ordered_names = list(builtin_names)
    ordered_names.extend(tool_spec.name for tool_spec in tool_specs)
    claimed: set[str] = set()
    for candidate in ordered_names:
        if candidate in claimed:
            raise DuplicateToolNameError(candidate)
        claimed.add(candidate)


class ToolResolver:
    """Turn canonical tool configuration into a ``ResolvedToolSet``.

    Secrets and gateway tools are resolved through injected adapters; without a
    secret provider the process environment is used.
    """

    def __init__(
        self,
        *,
        secret_provider: Optional[ToolSecretProvider] = None,
        gateway_resolver: Optional[GatewayToolResolver] = None,
        missing_secret_policy: MissingSecretPolicy = MissingSecretPolicy.ERROR,
    ) -> None:
        self._secret_provider = (
            secret_provider if secret_provider else EnvironmentToolSecretProvider()
        )
        self._gateway_resolver = gateway_resolver
        self._missing_secret_policy = missing_secret_policy

    async def resolve(self, tool_configs: Sequence[ToolConfig]) -> ResolvedToolSet:
        """Resolve *tool_configs* into builtin names, specs, and a callback target.

        Resolved specs are ordered gateway first, then code, then client, each group
        keeping its configuration order.

        Raises:
            MissingToolSecretError: a code tool lacks a declared secret and the
                policy is ``MissingSecretPolicy.ERROR``.
            UnsupportedToolProviderError: gateway tools are configured but no
                gateway resolver was injected.
            DuplicateToolNameError: two tools share a model-visible name.
        """
        # Partition by declaration kind, keeping the order within each kind.
        builtin_names: list[str] = []
        code_configs: list[CodeToolConfig] = []
        client_configs: list[ClientToolConfig] = []
        gateway_configs: list[GatewayToolConfig] = []
        for tool_config in tool_configs:
            if isinstance(tool_config, BuiltinToolConfig):
                builtin_names.append(tool_config.name)
            elif isinstance(tool_config, CodeToolConfig):
                code_configs.append(tool_config)
            elif isinstance(tool_config, ClientToolConfig):
                client_configs.append(tool_config)
            elif isinstance(tool_config, GatewayToolConfig):
                gateway_configs.append(tool_config)

        # Fetch every declared secret once, deduplicated and sorted.
        wanted: set[str] = set()
        for code_config in code_configs:
            wanted.update(code_config.secrets)
        secret_values: dict[str, str] = {}
        if wanted:
            secret_values = dict(await self._secret_provider.get_many(sorted(wanted)))

        strict = self._missing_secret_policy == MissingSecretPolicy.ERROR
        code_specs: list[ToolSpec] = []
        for code_config in code_configs:
            env: dict[str, str] = {}
            missing: list[str] = []
            for secret_name in code_config.secrets:
                if secret_name in secret_values:
                    env[secret_name] = secret_values[secret_name]
                else:
                    missing.append(secret_name)
            if missing and strict:
                raise MissingToolSecretError(
                    tool_name=code_config.name,
                    secret_names=missing,
                )
            code_specs.append(_build_code_tool_spec(tool_config=code_config, env=env))

        client_specs = [
            _build_client_tool_spec(tool_config=client_config)
            for client_config in client_configs
        ]

        gateway_specs: list[ToolSpec] = []
        tool_callback = None
        if gateway_configs:
            if self._gateway_resolver is None:
                raise UnsupportedToolProviderError(gateway_configs[0].provider)
            resolution = await self._gateway_resolver.resolve(gateway_configs)
            gateway_specs = list(resolution.tool_specs)
            tool_callback = resolution.tool_callback

        tool_specs = [*gateway_specs, *code_specs, *client_specs]
        _validate_unique_names(
            builtin_names=builtin_names,
            tool_specs=tool_specs,
        )
        return ResolvedToolSet(
            builtin_names=builtin_names,
            tool_specs=tool_specs,
            tool_callback=tool_callback,
        )
The neutral types in ``dtos.py`` stay the port; this module is one more +adapter behind that seam, so the ``/run`` runner wire (``utils/wire.py``, ``{role, content}``) +is unchanged — the Vercel shape lives only at the HTTP edge. + +Three directions: + +- :func:`from_ui_messages` — inbound ``UIMessage[]`` -> ``List[Message]``. Text and file parts + fold into content blocks; tool and approval parts are PRESERVED as structured ``tool_call`` / + ``tool_result`` :class:`~agenta.sdk.agents.dtos.ContentBlock`s (never flattened to text), so a + cross-turn human-in-the-loop reply replays as a real tool turn and the model resumes from the + result. The runner's message transcript renders these blocks into the cold replay. +- :func:`to_ui_message` — outbound ``AgentResult`` / ``Message`` -> one ``UIMessage`` dict, for + the ``load-session`` history. +- :func:`ui_message_stream` — the streaming edge: a live + :class:`~agenta.sdk.agents.streaming.AgentRun` -> Vercel UI Message Stream parts + (``start`` ... ``finish``). The SSE framing and the terminal ``data: [DONE]`` are added by the + routing layer (``_vercel_sse_stream``); this generator yields the part dicts only. +""" + +from __future__ import annotations + +from typing import Any, AsyncIterator, Dict, List, Optional + +from .dtos import AgentResult, ContentBlock, Message +from .streaming import AgentRun + +# Inbound UIMessage part type names handled specially (the rest of ``tool-*`` is a tool call). +_TOOL_APPROVAL_REQUEST = "tool-approval-request" +_TOOL_APPROVAL_RESPONSE = "tool-approval-response" +_TOOL_OUTPUT_AVAILABLE = "tool-output-available" + + +# --------------------------------------------------------------------------- +# Inbound: UIMessage[] -> List[Message] +# --------------------------------------------------------------------------- + + +def from_ui_messages(raw: Optional[List[Any]]) -> List[Message]: + """Coerce inbound Vercel ``UIMessage`` objects into neutral :class:`Message` objects. 
+ + The parts-aware sibling of :func:`agenta.sdk.agents.to_messages`. Tool and approval parts + are preserved as structured tool-call / tool-result content blocks (never dropped), so a + cross-turn human-in-the-loop reply resumes the pending interaction on the next turn. + """ + messages: List[Message] = [] + for item in raw or []: + message = _ui_message_to_message(item) + if message is not None: + messages.append(message) + return messages + + +def _ui_message_to_message(raw: Any) -> Optional[Message]: + if isinstance(raw, Message): + return raw + if not isinstance(raw, dict) or "role" not in raw: + return None + role = str(raw["role"]) + + parts = raw.get("parts") + if parts is None: + # Not a parts-based UIMessage — fall back to the {role, content} shape so a mixed + # history still parses. + return Message.from_raw(raw) + + blocks: List[ContentBlock] = [] + for part in parts or []: + blocks.extend(_part_to_blocks(part)) + + if not blocks: + return Message(role=role, content="") + # Collapse an all-text message to a plain string (the shape the runner replays); keep the + # block list when any structured (file / tool) content is present. 
+ if all(block.type == "text" for block in blocks): + return Message(role=role, content="".join(block.text or "" for block in blocks)) + return Message(role=role, content=blocks) + + +def _part_to_blocks(part: Any) -> List[ContentBlock]: + if not isinstance(part, dict): + return [] + ptype = str(part.get("type", "")) + + if ptype == "text": + text = part.get("text") + return [ContentBlock(type="text", text=text)] if text is not None else [] + + if ptype == "file": + media = part.get("mediaType") or part.get("mimeType") + kind = ( + "image" + if isinstance(media, str) and media.startswith("image/") + else "resource" + ) + return [ + ContentBlock( + type=kind, + uri=part.get("url") or part.get("uri"), + data=part.get("data"), + mime_type=media, + ) + ] + + if ptype == _TOOL_APPROVAL_REQUEST: + # The server's own request, echoed back in history; regenerated on replay, not input. + return [] + + if ptype == _TOOL_APPROVAL_RESPONSE: + return _approval_response_blocks(part) + + if ( + ptype == _TOOL_OUTPUT_AVAILABLE + or ptype == "dynamic-tool" + or ptype.startswith("tool-") + ): + return _tool_part_blocks(part, ptype) + + # reasoning / step-start / data-* parts are the assistant's own output or transient UI; + # they are not model input on replay, so they are dropped. + return [] + + +def _tool_part_blocks(part: Dict[str, Any], ptype: str) -> List[ContentBlock]: + """A Vercel tool part -> a ``tool_call`` block plus, when resolved, a ``tool_result``. + + Field names match what the runner's transcript renders: ``toolCallId`` / ``toolName`` / + ``input`` / ``output`` / ``isError`` (via :meth:`ContentBlock.to_wire`). 
+ """ + tool_call_id = part.get("toolCallId") or part.get("tool_call_id") + tool_name = part.get("toolName") or part.get("tool_name") + if ( + tool_name is None + and ptype.startswith("tool-") + and ptype != _TOOL_OUTPUT_AVAILABLE + ): + tool_name = ptype[len("tool-") :] + + blocks: List[ContentBlock] = [] + + # The call itself (a bare tool-output-available part carries only a result). + if ptype != _TOOL_OUTPUT_AVAILABLE or "input" in part: + blocks.append( + ContentBlock( + type="tool_call", + tool_call_id=tool_call_id, + tool_name=tool_name, + input=part.get("input"), + ) + ) + + state = part.get("state") + error_text = part.get("errorText") + if error_text is not None or state == "output-error": + blocks.append( + ContentBlock( + type="tool_result", + tool_call_id=tool_call_id, + tool_name=tool_name, + output=error_text if error_text is not None else part.get("output"), + is_error=True, + ) + ) + elif "output" in part or state == "output-available": + blocks.append( + ContentBlock( + type="tool_result", + tool_call_id=tool_call_id, + tool_name=tool_name, + output=part.get("output"), + is_error=False, + ) + ) + return blocks + + +def _approval_response_blocks(part: Dict[str, Any]) -> List[ContentBlock]: + """A cross-turn ``tool-approval-response`` reply -> a ``tool_result`` keyed by toolCallId, + so the runtime matches the pending interaction and resumes (the resolve step is joint).""" + tool_call_id = ( + part.get("toolCallId") or part.get("tool_call_id") or part.get("approvalId") + ) + output = part.get("output") + if output is None: + approved = part.get("approved") + output = {"approved": approved} if approved is not None else part.get("reason") + return [ContentBlock(type="tool_result", tool_call_id=tool_call_id, output=output)] + + +# --------------------------------------------------------------------------- +# Outbound (batch): AgentResult / Message -> one UIMessage dict +# --------------------------------------------------------------------------- + + 
+def to_ui_message(source: Any, *, message_id: str = "msg-1") -> Dict[str, Any]: + """Render an :class:`AgentResult` or :class:`Message` as one Vercel ``UIMessage`` dict + (the shape ``load-session`` returns and ``useChat`` takes as its initial messages).""" + if isinstance(source, AgentResult): + return { + "id": message_id, + "role": "assistant", + "parts": [{"type": "text", "text": source.output or ""}], + } + if isinstance(source, Message): + return { + "id": message_id, + "role": source.role, + "parts": _content_to_parts(source.content), + } + raise TypeError( + f"to_ui_message expects an AgentResult or Message, got {type(source).__name__!r}" + ) + + +def _content_to_parts(content: Any) -> List[Dict[str, Any]]: + if isinstance(content, str): + return [{"type": "text", "text": content}] if content else [] + parts: List[Dict[str, Any]] = [] + for block in content or []: + parts.extend(_block_to_parts(block)) + return parts + + +def _block_to_parts(block: ContentBlock) -> List[Dict[str, Any]]: + if block.type == "text": + return [{"type": "text", "text": block.text or ""}] + if block.type in ("image", "resource"): + part: Dict[str, Any] = {"type": "file"} + if block.uri is not None: + part["url"] = block.uri + if block.mime_type is not None: + part["mediaType"] = block.mime_type + if block.data is not None: + part["data"] = block.data + return [part] + if block.type == "tool_call": + return [ + { + "type": f"tool-{block.tool_name or 'tool'}", + "toolCallId": block.tool_call_id, + "state": "input-available", + "input": block.input, + } + ] + if block.type == "tool_result": + return [ + { + "type": f"tool-{block.tool_name or 'tool'}", + "toolCallId": block.tool_call_id, + "state": "output-error" if block.is_error else "output-available", + "output": block.output, + } + ] + return [] + + +# --------------------------------------------------------------------------- +# Streaming edge: a live AgentRun -> Vercel UI Message Stream parts +# 
--------------------------------------------------------------------------- + + +async def ui_message_stream( + run: AgentRun, + *, + session_id: Optional[str] = None, + message_id: str = "msg-1", + trace_id: Optional[str] = None, +) -> AsyncIterator[Dict[str, Any]]: + """Encode a live :class:`AgentRun` as Vercel UI Message Stream part dicts. + + Consumes the run's live ``AgentEvent`` stream and yields parts as they arrive: ``start`` + (carrying ``messageMetadata.sessionId``) first, then the body, then ``finish`` (carrying + ``messageMetadata.traceId`` so the client can open the run's OTel trace, RFC §6.1). The SSE + framing and the terminal ``data: [DONE]`` are added by the routing layer + (``_vercel_sse_stream``). On a terminal run failure the run raises while iterating; that is + surfaced as an ``error`` part (RFC §8.2) and the stream ends without a ``finish``. + """ + start: Dict[str, Any] = {"type": "start", "messageId": message_id} + if session_id is not None: + start["messageMetadata"] = {"sessionId": session_id} + yield start + yield {"type": "start-step"} + + text_seq = 0 + reasoning_seq = 0 + usage: Optional[Dict[str, Any]] = None + stop_reason: Optional[str] = None + + try: + async for event in run: + etype = event.type + data = event.data + + if etype == "message": + text_seq += 1 + tid = f"text-{text_seq}" + yield {"type": "text-start", "id": tid} + yield {"type": "text-delta", "id": tid, "delta": data.get("text", "")} + yield {"type": "text-end", "id": tid} + elif etype == "message_start": + yield {"type": "text-start", "id": data.get("id")} + elif etype == "message_delta": + yield { + "type": "text-delta", + "id": data.get("id"), + "delta": data.get("delta", ""), + } + elif etype == "message_end": + yield {"type": "text-end", "id": data.get("id")} + elif etype == "thought": + reasoning_seq += 1 + rid = f"reasoning-{reasoning_seq}" + yield {"type": "reasoning-start", "id": rid} + yield { + "type": "reasoning-delta", + "id": rid, + "delta": 
data.get("text", ""), + } + yield {"type": "reasoning-end", "id": rid} + elif etype == "reasoning_start": + yield {"type": "reasoning-start", "id": data.get("id")} + elif etype == "reasoning_delta": + yield { + "type": "reasoning-delta", + "id": data.get("id"), + "delta": data.get("delta", ""), + } + elif etype == "reasoning_end": + yield {"type": "reasoning-end", "id": data.get("id")} + elif etype == "tool_call": + tool_call_id = data.get("id") + tool_name = data.get("name") + yield { + "type": "tool-input-start", + "toolCallId": tool_call_id, + "toolName": tool_name, + } + available: Dict[str, Any] = { + "type": "tool-input-available", + "toolCallId": tool_call_id, + "toolName": tool_name, + "input": data.get("input"), + } + if data.get("render") is not None: + available["render"] = data["render"] + yield available + elif etype == "tool_result": + tool_call_id = data.get("id") + if data.get("denied"): + # A human denied the tool, so it never ran (RFC: emit tool-output-denied + # instead of tool-output-available; the FE renders the output-denied state). + yield { + "type": "tool-output-denied", + "toolCallId": tool_call_id, + } + elif data.get("isError"): + yield { + "type": "tool-output-error", + "toolCallId": tool_call_id, + "errorText": _as_text(data.get("output")), + } + else: + # Prefer the structured object (generative UI); fall back to the text form. 
+ structured = data.get("data") + out = structured if structured is not None else data.get("output") + available = { + "type": "tool-output-available", + "toolCallId": tool_call_id, + "output": out, + } + if data.get("render") is not None: + available["render"] = data["render"] + yield available + elif etype == "interaction_request": + yield _interaction_part(data) + elif etype == "data": + part: Dict[str, Any] = { + "type": f"data-{data.get('name', 'data')}", + "data": data.get("data"), + } + if data.get("transient"): + part["transient"] = True + yield part + elif etype == "file": + yield { + "type": "file", + "url": data.get("url"), + "mediaType": data.get("mediaType"), + } + elif etype == "usage": + usage = _usage_metadata(data) + elif etype == "error": + yield {"type": "error", "errorText": data.get("message", "")} + elif etype == "done": + stop_reason = data.get("stopReason") + # unknown event types are ignored + except Exception as exc: # AgentRun raises on a terminal ok=false result + yield {"type": "error", "errorText": str(exc)} + return + + # Pull usage and the trace id from the terminal result when not already known, the same + # fallback both lean on (RFC §6.1: the finish trace id matches the JSON response's). + if usage is None or trace_id is None: + result = _safe_result(run) + if result is not None: + if usage is None: + usage = _usage_metadata(result.usage or {}) + if stop_reason is None: + stop_reason = result.stop_reason + if trace_id is None: + trace_id = result.trace_id + + yield {"type": "finish-step"} + finish: Dict[str, Any] = {"type": "finish"} + if stop_reason is not None: + finish["finishReason"] = stop_reason + # usage and traceId coexist under messageMetadata; the client reads message.metadata.traceId. 
def _interaction_part(data: Dict[str, Any]) -> Dict[str, Any]:
    """Project an ``interaction_request`` event to a Vercel part.

    - ``kind == "permission"`` -> an approval request part.
    - ``kind == "input"`` -> a forward-looking ``data-input-request`` part.
    - any other kind (e.g. ``client_tool``) -> a generic ``data-interaction`` part, so the
      request is surfaced instead of being dropped.

    A missing or falsy ``payload`` is treated as an empty dict.
    """
    kind = data.get("kind")
    payload = data.get("payload") or {}
    if kind == "permission":
        return {
            "type": _TOOL_APPROVAL_REQUEST,
            "approvalId": data.get("id"),
            # REQUIRED alongside approvalId (RFC / AI SDK chunk): the gated tool's call id, so
            # the FE binds the approval to its existing tool part and the inbound
            # tool-approval-response (keyed by toolCallId) correlates back for the cross-turn
            # resume.
            "toolCallId": _approval_tool_call_id(payload),
            "availableReplies": payload.get("availableReplies"),
            "toolCall": payload.get("toolCall"),
        }
    if kind == "input":
        return {"type": "data-input-request", "id": data.get("id"), "data": payload}
    return {
        "type": "data-interaction",
        "id": data.get("id"),
        "data": {"kind": kind, "payload": payload},
    }


def _approval_tool_call_id(payload: Dict[str, Any]) -> Optional[Any]:
    """Return the gated tool's call id for a ``tool-approval-request``.

    Prefers the top-level ``toolCallId`` on the permission payload. When that is absent
    (``None``), looks inside the nested ``toolCall`` dict, trying ``id`` first and then
    ``toolCallId``. Returns ``None`` when neither location yields a value.
    """
    tool_call_id = payload.get("toolCallId")
    if tool_call_id is not None:
        return tool_call_id
    tool_call = payload.get("toolCall")
    if isinstance(tool_call, dict):
        # `or` (not an `is None` test) on purpose: an empty ``id`` falls through to
        # ``toolCallId`` rather than being returned as the call id.
        return tool_call.get("id") or tool_call.get("toolCallId")
    return None


def _usage_metadata(data: Dict[str, Any]) -> Dict[str, Any]:
    """Return the usage counters from ``data`` that are actually set.

    Only the keys ``input``, ``output``, ``total`` and ``cost`` are considered; a key that
    is missing or ``None`` is omitted, so the result may be an empty dict.
    """
    return {
        key: data[key]
        for key in ("input", "output", "total", "cost")
        if data.get(key) is not None
    }


def _as_text(value: Any) -> str:
    """Render a value as text (used for the ``errorText`` of a tool-output-error part).

    - ``None`` -> ``""``.
    - ``str`` -> returned unchanged.
    - ``dict`` / ``list`` -> JSON text, so a structured error reads as JSON on the wire
      instead of a Python repr (``{'a': None}``).
    - anything else, or a container JSON cannot encode -> ``str(value)``.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, (dict, list)):
        import json  # local import: this helper is the only user in the block

        try:
            return json.dumps(value, ensure_ascii=False)
        except (TypeError, ValueError):
            # Not JSON-encodable (e.g. a set inside, or a circular reference):
            # keep the previous best-effort behaviour below.
            pass
    return str(value)


def _safe_result(run: "AgentRun") -> "Optional[AgentResult]":
    """Return ``run.result()``, or ``None`` if obtaining it raises.

    Deliberately best-effort: the result may be unavailable (stream not fully consumed, or
    the run failed), and callers treat ``None`` as "no terminal result to fall back on".
    """
    try:
        return run.result()
    except Exception:  # result not available (stream not fully consumed / failed)
        return None
""" return { "backend": engine, @@ -51,6 +53,7 @@ def request_to_wire( "trace": trace.to_wire() if trace else None, **config.wire_tools(), **config.wire_prompt(), + **config.wire_mcp(), } diff --git a/sdks/python/oss/tests/pytest/integration/agents/__init__.py b/sdks/python/oss/tests/pytest/integration/agents/__init__.py new file mode 100644 index 0000000000..de6d92eeaf --- /dev/null +++ b/sdks/python/oss/tests/pytest/integration/agents/__init__.py @@ -0,0 +1 @@ +# Integration tests for the agent runtime: the real wire + transport against a fake runner. diff --git a/sdks/python/oss/tests/pytest/integration/agents/test_transport_roundtrip.py b/sdks/python/oss/tests/pytest/integration/agents/test_transport_roundtrip.py new file mode 100644 index 0000000000..a73c30eecc --- /dev/null +++ b/sdks/python/oss/tests/pytest/integration/agents/test_transport_roundtrip.py @@ -0,0 +1,113 @@ +"""End-to-end through the real wire and transport, against a fake runner. + +This is the Python-only stand-in for a live ``/invoke``: a tiny script plays the runner, +echoing the latest turn. The whole runtime path is real -- harness translation, the cold +environment lifecycle, ``request_to_wire``, the subprocess transport, and ``result_from_wire`` +-- only the runner program (which would be the TS + Pi + LLM stack) is faked. So it catches +serialization or transport drift that per-side unit tests cannot, with no TS and no LLM. +""" + +from __future__ import annotations + +import sys + +import pytest + +from agenta.sdk.agents import ( + AgentConfig, + Environment, + InProcessPiBackend, + Message, + PiHarness, + SessionConfig, +) + +pytestmark = pytest.mark.integration + + +# A runner that reads the /run request on stdin and echoes the latest user turn as a full +# AgentRunResult on stdout (the camelCase wire shape result_from_wire parses). 
+_ECHO_RUNNER = """ +import sys, json + +req = json.load(sys.stdin) +text = "" +for message in reversed(req.get("messages") or []): + if message.get("role") == "user": + content = message.get("content") + if isinstance(content, str): + text = content + else: + text = "".join( + block.get("text", "") + for block in content + if isinstance(block, dict) and block.get("type") == "text" + ) + if text: + break + +out = { + "ok": True, + "output": "echo: " + text, + "messages": [{"role": "assistant", "content": "echo: " + text}], + "events": [ + {"type": "message", "text": "echo: " + text}, + {"type": "done", "stopReason": "end_turn"}, + ], + "usage": {"input": 1, "output": 1, "total": 2, "cost": 0.0}, + "stopReason": "end_turn", + "capabilities": {"textMessages": True, "mcpTools": False}, + "sessionId": "sess-fake", + "model": req.get("model"), +} +sys.stdout.write(json.dumps(out)) +""" + +_FAIL_RUNNER = """ +import sys, json +json.load(sys.stdin) +sys.stdout.write(json.dumps({"ok": False, "error": "model exploded"})) +""" + +_SILENT_RUNNER = """ +import sys, json +json.load(sys.stdin) +""" + + +def _backend(tmp_path, body: str) -> InProcessPiBackend: + runner = tmp_path / "fake_runner.py" + runner.write_text(body, encoding="utf-8") + return InProcessPiBackend(command=[sys.executable, str(runner)], cwd=str(tmp_path)) + + +async def test_prompt_round_trips_through_the_real_transport(tmp_path): + harness = PiHarness(Environment(_backend(tmp_path, _ECHO_RUNNER))) + config = SessionConfig(agent=AgentConfig(instructions="hi", model="gpt-5.5")) + + result = await harness.prompt(config, [Message(role="user", content="ping")]) + + # The runner saw the wired turn and model, and the result parsed back cleanly. 
+ assert result.output == "echo: ping" + assert result.model == "gpt-5.5" + assert [e.type for e in result.events] == ["message", "done"] + assert result.capabilities is not None and result.capabilities.mcp_tools is False + # The session id is parsed and carried forward for a follow-up turn. + assert result.session_id == "sess-fake" + assert config.session_id == "sess-fake" + + +async def test_runner_failure_surfaces_as_runtime_error(tmp_path): + harness = PiHarness(Environment(_backend(tmp_path, _FAIL_RUNNER))) + config = SessionConfig(agent=AgentConfig(instructions="hi")) + + with pytest.raises(RuntimeError, match="model exploded"): + await harness.prompt(config, [Message(role="user", content="hi")]) + + +async def test_runner_empty_output_raises(tmp_path): + harness = PiHarness(Environment(_backend(tmp_path, _SILENT_RUNNER))) + config = SessionConfig(agent=AgentConfig(instructions="hi")) + + with pytest.raises(RuntimeError, match="no output"): + await harness.prompt(config, [Message(role="user", content="hi")]) diff --git a/sdks/python/oss/tests/pytest/unit/agents/golden/run_request.claude.json b/sdks/python/oss/tests/pytest/unit/agents/golden/run_request.claude.json index 9c6315110e..318722efe5 100644 --- a/sdks/python/oss/tests/pytest/unit/agents/golden/run_request.claude.json +++ b/sdks/python/oss/tests/pytest/unit/agents/golden/run_request.claude.json @@ -16,7 +16,8 @@ "name": "get_user", "description": "Get a user", "inputSchema": {"type": "object", "properties": {}}, - "callRef": "tools__composio__github__GET_THE_AUTHENTICATED_USER__github-tvn" + "callRef": "tools__composio__github__GET_THE_AUTHENTICATED_USER__github-tvn", + "kind": "callback" } ], "toolCallback": { diff --git a/sdks/python/oss/tests/pytest/unit/agents/golden/run_request.pi.json b/sdks/python/oss/tests/pytest/unit/agents/golden/run_request.pi.json index ae1dbae468..ebfb966479 100644 --- a/sdks/python/oss/tests/pytest/unit/agents/golden/run_request.pi.json +++ 
b/sdks/python/oss/tests/pytest/unit/agents/golden/run_request.pi.json @@ -22,7 +22,8 @@ "name": "get_user", "description": "Get a user", "inputSchema": {"type": "object", "properties": {}}, - "callRef": "tools__composio__github__GET_THE_AUTHENTICATED_USER__github-tvn" + "callRef": "tools__composio__github__GET_THE_AUTHENTICATED_USER__github-tvn", + "kind": "callback" } ], "toolCallback": { diff --git a/sdks/python/oss/tests/pytest/unit/agents/mcp/__init__.py b/sdks/python/oss/tests/pytest/unit/agents/mcp/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/sdks/python/oss/tests/pytest/unit/agents/mcp/__init__.py @@ -0,0 +1 @@ + diff --git a/sdks/python/oss/tests/pytest/unit/agents/mcp/test_resolver.py b/sdks/python/oss/tests/pytest/unit/agents/mcp/test_resolver.py new file mode 100644 index 0000000000..a8a97ab6f0 --- /dev/null +++ b/sdks/python/oss/tests/pytest/unit/agents/mcp/test_resolver.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Mapping, Sequence + +import pytest +from pydantic import ValidationError + +from agenta.sdk.agents.mcp import ( + MCPResolver, + MCPServerConfig, + MissingMCPSecretError, +) +from agenta.sdk.agents.tools import MissingSecretPolicy + + +class DictSecretProvider: + def __init__(self, values: Mapping[str, str]): + self.values = values + + async def get_many(self, names: Sequence[str]) -> Mapping[str, str]: + return {name: self.values[name] for name in names if name in self.values} + + +def test_transport_specific_fields_are_required(): + with pytest.raises(ValidationError, match="requires command"): + MCPServerConfig(name="stdio") + with pytest.raises(ValidationError, match="requires url"): + MCPServerConfig(name="remote", transport="http") + + +async def test_resolves_mcp_environment_in_sibling_subsystem(): + servers = await MCPResolver( + secret_provider=DictSecretProvider({"github_pat": "ghp"}) + ).resolve( + [ + MCPServerConfig( + name="github", + command="npx", + 
env={"LOG": "info"}, + secrets={"GITHUB_TOKEN": "github_pat"}, + ) + ] + ) + assert servers[0].to_wire()["env"] == { + "LOG": "info", + "GITHUB_TOKEN": "ghp", + } + + +async def test_missing_mcp_secret_is_explicit(): + with pytest.raises(MissingMCPSecretError): + await MCPResolver(secret_provider=DictSecretProvider({})).resolve( + [ + MCPServerConfig( + name="github", + command="npx", + secrets={"GITHUB_TOKEN": "missing"}, + ) + ] + ) + + +async def test_mcp_compatibility_policy_can_omit_missing_secret(): + servers = await MCPResolver( + secret_provider=DictSecretProvider({}), + missing_secret_policy=MissingSecretPolicy.OMIT, + ).resolve( + [ + MCPServerConfig( + name="github", + command="npx", + secrets={"GITHUB_TOKEN": "missing"}, + ) + ] + ) + assert "env" not in servers[0].to_wire() diff --git a/sdks/python/oss/tests/pytest/unit/agents/test_dtos_agent_config.py b/sdks/python/oss/tests/pytest/unit/agents/test_dtos_agent_config.py index 0b7c4744ee..f4bacd92d4 100644 --- a/sdks/python/oss/tests/pytest/unit/agents/test_dtos_agent_config.py +++ b/sdks/python/oss/tests/pytest/unit/agents/test_dtos_agent_config.py @@ -7,7 +7,11 @@ from __future__ import annotations -from agenta.sdk.agents import AgentConfig, RunSelection +from agenta.sdk.agents import ( + AgentConfig, + BuiltinToolConfig, + RunSelection, +) _DEFAULTS = AgentConfig(instructions="default-md", model="default-model", tools=["d"]) @@ -29,7 +33,7 @@ def test_from_params_agent_element_shape(): ) assert config.instructions == "I" assert config.model == "M" - assert config.tools == [{"type": "builtin", "name": "read"}] + assert config.tools == [BuiltinToolConfig(name="read")] assert config.harness_options == {"pi": {"system": "S"}} @@ -48,7 +52,7 @@ def test_from_params_prompt_template_shape(): ) assert config.instructions == "You are helpful." 
# system message -> instructions assert config.model == "M" - assert config.tools == ["t"] + assert config.tools == [BuiltinToolConfig(name="t")] def test_from_params_prompt_template_joins_multiple_system_messages(): @@ -76,19 +80,19 @@ def test_from_params_flat_shape(): ) assert config.instructions == "A" assert config.model == "M" - assert config.tools == [{"name": "x"}] + assert config.tools == [BuiltinToolConfig(name="x")] def test_from_params_falls_back_to_defaults(): config = AgentConfig.from_params({}, defaults=_DEFAULTS) assert config.instructions == "default-md" assert config.model == "default-model" - assert config.tools == ["d"] + assert config.tools == [BuiltinToolConfig(name="d")] def test_from_params_coerces_single_tool_dict_to_list(): config = AgentConfig.from_params({"agent": {"tools": {"name": "solo"}}}) - assert config.tools == [{"name": "solo"}] + assert config.tools == [BuiltinToolConfig(name="solo")] def test_harness_options_drops_malformed_and_lowercases_keys(): diff --git a/sdks/python/oss/tests/pytest/unit/agents/test_dtos_harness_configs.py b/sdks/python/oss/tests/pytest/unit/agents/test_dtos_harness_configs.py index 5d96bccad4..1d53c8f469 100644 --- a/sdks/python/oss/tests/pytest/unit/agents/test_dtos_harness_configs.py +++ b/sdks/python/oss/tests/pytest/unit/agents/test_dtos_harness_configs.py @@ -11,6 +11,7 @@ from agenta.sdk.agents import ( ClaudeAgentConfig, + ClientToolSpec, HarnessAgentConfig, PiAgentConfig, ToolCallback, @@ -22,12 +23,24 @@ def test_pi_wire_tools_is_native_and_never_gates(): config = PiAgentConfig( builtin_tools=["read"], - custom_tools=[{"name": "t"}], + tool_specs=[ + ClientToolSpec( + name="t", + description="t", + ) + ], tool_callback=_CALLBACK, ) assert config.wire_tools() == { "tools": ["read"], - "customTools": [{"name": "t"}], + "customTools": [ + { + "name": "t", + "description": "t", + "inputSchema": {"type": "object", "properties": {}}, + "kind": "client", + } + ], "toolCallback": { "endpoint": 
"https://api.example/tools/call", "authorization": "A", @@ -52,13 +65,25 @@ def test_pi_wire_prompt_emits_only_set_overrides(): def test_claude_wire_tools_has_no_builtins_and_carries_policy(): config = ClaudeAgentConfig( - custom_tools=[{"name": "t"}], + tool_specs=[ + ClientToolSpec( + name="t", + description="t", + ) + ], tool_callback=_CALLBACK, permission_policy="deny", ) wire = config.wire_tools() assert wire["tools"] == [] # Claude has no Pi built-ins - assert wire["customTools"] == [{"name": "t"}] + assert wire["customTools"] == [ + { + "name": "t", + "description": "t", + "inputSchema": {"type": "object", "properties": {}}, + "kind": "client", + } + ] assert wire["permissionPolicy"] == "deny" diff --git a/sdks/python/oss/tests/pytest/unit/agents/test_harness_adapters.py b/sdks/python/oss/tests/pytest/unit/agents/test_harness_adapters.py index d7a03aeed5..7e68d3af93 100644 --- a/sdks/python/oss/tests/pytest/unit/agents/test_harness_adapters.py +++ b/sdks/python/oss/tests/pytest/unit/agents/test_harness_adapters.py @@ -16,6 +16,7 @@ AgentConfig, ClaudeAgentConfig, ClaudeHarness, + ClientToolSpec, HarnessType, PiAgentConfig, PiHarness, @@ -204,11 +205,9 @@ def test_claude_no_warning_without_builtins(make_env, monkeypatch): # --------------------------------------------------------------- _normalize_tool_specs -def test_normalize_tool_specs_fills_defaults_and_drops_malformed(): +def test_compat_normalize_tool_specs_returns_typed_specs(): specs = [ {"name": "keep", "callRef": "r1"}, # missing description + inputSchema - {"description": "no name"}, # dropped: no name - "not a dict", # dropped: not a dict { "name": "full", "description": "d", @@ -219,14 +218,21 @@ def test_normalize_tool_specs_fills_defaults_and_drops_malformed(): out = _normalize_tool_specs(specs) - assert [s["name"] for s in out] == ["keep", "full"] + assert [spec.name for spec in out] == ["keep", "full"] # description falls back to the name; inputSchema falls back to an empty object schema. 
- assert out[0]["description"] == "keep" - assert out[0]["inputSchema"] == {"type": "object", "properties": {}} - assert out[0]["callRef"] == "r1" + assert out[0].description == "keep" + assert out[0].input_schema == {"type": "object", "properties": {}} + assert out[0].call_ref == "r1" # provided values are preserved. - assert out[1]["description"] == "d" - assert out[1]["inputSchema"]["properties"] == {"x": {}} + assert out[1].description == "d" + assert out[1].input_schema["properties"] == {"x": {}} + + +def test_harness_accepts_typed_tool_specs_without_normalizing_dicts(make_env): + harness = PiHarness(make_env(supported=[HarnessType.PI])) + spec = ClientToolSpec(name="pick", description="Pick") + result = harness._to_harness_config(_session_config(tool_specs=[spec])) + assert result.tool_specs == [spec] def test_normalize_tool_specs_empty(): diff --git a/sdks/python/oss/tests/pytest/unit/agents/test_ui_messages.py b/sdks/python/oss/tests/pytest/unit/agents/test_ui_messages.py new file mode 100644 index 0000000000..aaf9afec10 --- /dev/null +++ b/sdks/python/oss/tests/pytest/unit/agents/test_ui_messages.py @@ -0,0 +1,429 @@ +"""Tests for the UI message codec (``agenta.sdk.agents.ui_messages``), the ``/messages`` +egress adapter between the Vercel ``UIMessage`` shape and the neutral runtime types. + +Three directions: + +- ``from_ui_messages`` — inbound parts -> ``Message``; tool/approval parts are preserved as + structured ``tool_call`` / ``tool_result`` content blocks (the HITL reply channel). +- ``to_ui_message`` — outbound ``AgentResult`` / ``Message`` -> one ``UIMessage`` dict. +- ``ui_message_stream`` — a live ``AgentRun`` -> Vercel UI Message Stream parts. + +The stream tests fabricate an ``AgentRun`` from a fixed record list (the same trick +``test_streaming.py`` uses), so they are pure and need no backend. 
+""" + +from __future__ import annotations + +from typing import Any, Dict, List + +from agenta.sdk.agents import AgentRun, AgentResult, Message +from agenta.sdk.agents.ui_messages import ( + from_ui_messages, + to_ui_message, + ui_message_stream, +) + + +async def _from_list(records: List[Dict[str, Any]]): + for record in records: + yield record + + +def _run(events: List[Dict[str, Any]], result: Dict[str, Any]) -> AgentRun: + """An ``AgentRun`` over fabricated live events plus a terminal result record.""" + records = [{"kind": "event", "event": e} for e in events] + records.append({"kind": "result", "result": {"ok": True, **result}}) + return AgentRun(_from_list(records)) + + +async def _collect(run: AgentRun, **kwargs) -> List[Dict[str, Any]]: + return [part async for part in ui_message_stream(run, **kwargs)] + + +# --------------------------------------------------------------------------- +# from_ui_messages +# --------------------------------------------------------------------------- + + +class TestFromUIMessages: + def test_all_text_message_collapses_to_string(self): + msgs = from_ui_messages( + [{"id": "m1", "role": "user", "parts": [{"type": "text", "text": "hi"}]}] + ) + assert len(msgs) == 1 + assert msgs[0].role == "user" + assert msgs[0].content == "hi" + + def test_file_part_becomes_image_or_resource_block(self): + msgs = from_ui_messages( + [ + { + "id": "m1", + "role": "user", + "parts": [ + {"type": "text", "text": "look:"}, + {"type": "file", "url": "data:...", "mediaType": "image/png"}, + ], + } + ] + ) + blocks = msgs[0].content + assert [b.type for b in blocks] == ["text", "image"] + assert blocks[1].uri == "data:..." + assert blocks[1].mime_type == "image/png" + + def test_tool_part_is_preserved_as_structured_blocks(self): + # A resolved tool part -> a tool_call block plus a tool_result block, keyed by + # toolCallId, with the field names the runner transcript renders. 
+ msgs = from_ui_messages( + [ + { + "id": "m2", + "role": "assistant", + "parts": [ + { + "type": "tool-getWeather", + "toolCallId": "call_1", + "state": "output-available", + "input": {"city": "Paris"}, + "output": {"weather": "sunny"}, + } + ], + } + ] + ) + wire = [b.to_wire() for b in msgs[0].content] + assert wire == [ + { + "type": "tool_call", + "toolCallId": "call_1", + "toolName": "getWeather", + "input": {"city": "Paris"}, + }, + { + "type": "tool_result", + "toolCallId": "call_1", + "toolName": "getWeather", + "output": {"weather": "sunny"}, + "isError": False, + }, + ] + + def test_tool_error_part_sets_is_error(self): + msgs = from_ui_messages( + [ + { + "id": "m2", + "role": "assistant", + "parts": [ + { + "type": "tool-getWeather", + "toolCallId": "call_1", + "state": "output-error", + "input": {"city": "Paris"}, + "errorText": "boom", + } + ], + } + ] + ) + result_block = msgs[0].content[1] + assert result_block.type == "tool_result" + assert result_block.is_error is True + assert result_block.output == "boom" + + def test_approval_response_becomes_tool_result_keyed_by_call_id(self): + # The cross-turn HITL reply: a tool_result keyed by toolCallId so the runtime resumes. + msgs = from_ui_messages( + [ + { + "id": "m3", + "role": "user", + "parts": [ + { + "type": "tool-approval-response", + "toolCallId": "call_1", + "approved": True, + } + ], + } + ] + ) + block = msgs[0].content[0] + assert block.type == "tool_result" + assert block.tool_call_id == "call_1" + assert block.output == {"approved": True} + + def test_approval_request_part_is_dropped_on_replay(self): + # The server's own request, echoed back; regenerated on replay, not model input. 
+ msgs = from_ui_messages( + [ + { + "id": "m4", + "role": "assistant", + "parts": [ + {"type": "tool-approval-request", "approvalId": "p1"}, + {"type": "text", "text": "thinking"}, + ], + } + ] + ) + assert msgs[0].content == "thinking" + + def test_plain_role_content_message_still_parses(self): + # A non-parts {role, content} message in a mixed history falls back cleanly. + msgs = from_ui_messages([{"role": "user", "content": "hello"}]) + assert msgs[0].content == "hello" + + +# --------------------------------------------------------------------------- +# to_ui_message +# --------------------------------------------------------------------------- + + +class TestToUIMessage: + def test_agent_result_becomes_assistant_text_message(self): + ui = to_ui_message(AgentResult(output="Paris."), message_id="m9") + assert ui == { + "id": "m9", + "role": "assistant", + "parts": [{"type": "text", "text": "Paris."}], + } + + def test_message_with_tool_blocks_round_trips_to_parts(self): + from agenta.sdk.agents import ContentBlock + + msg = Message( + role="assistant", + content=[ + ContentBlock( + type="tool_call", + tool_call_id="c1", + tool_name="getWeather", + input={"city": "Paris"}, + ), + ], + ) + ui = to_ui_message(msg) + assert ui["role"] == "assistant" + assert ui["parts"][0]["type"] == "tool-getWeather" + assert ui["parts"][0]["toolCallId"] == "c1" + + +# --------------------------------------------------------------------------- +# ui_message_stream +# --------------------------------------------------------------------------- + + +class TestUIMessageStream: + async def test_full_turn_part_order(self): + run = _run( + events=[ + { + "type": "tool_call", + "id": "call_1", + "name": "getWeather", + "input": {"city": "Paris"}, + }, + { + "type": "tool_result", + "id": "call_1", + "output": "sunny", + "data": {"w": "sunny"}, + }, + {"type": "message_start", "id": "t1"}, + {"type": "message_delta", "id": "t1", "delta": "It is sunny."}, + {"type": "message_end", "id": 
"t1"}, + {"type": "usage", "input": 820, "output": 36, "cost": 0.004}, + {"type": "done", "stopReason": "end_turn"}, + ], + result={"output": "It is sunny.", "sessionId": "sess_123"}, + ) + parts = await _collect(run, session_id="sess_123") + + assert [p["type"] for p in parts] == [ + "start", + "start-step", + "tool-input-start", + "tool-input-available", + "tool-output-available", + "text-start", + "text-delta", + "text-end", + "finish-step", + "finish", + ] + # start carries the session id; tool output prefers the structured `data`. + assert parts[0]["messageMetadata"] == {"sessionId": "sess_123"} + assert parts[4]["output"] == {"w": "sunny"} + # finish carries the usage and the stop reason. + assert parts[-1]["finishReason"] == "end_turn" + assert parts[-1]["messageMetadata"]["usage"] == { + "input": 820, + "output": 36, + "cost": 0.004, + } + + async def test_usage_falls_back_to_terminal_result(self): + run = _run( + events=[ + {"type": "message", "text": "hi"}, + {"type": "done", "stopReason": "end_turn"}, + ], + result={"output": "hi", "usage": {"input": 10, "output": 2}}, + ) + parts = await _collect(run, session_id="s1") + assert parts[-1]["messageMetadata"]["usage"] == {"input": 10, "output": 2} + + async def test_coalesced_message_emits_text_block(self): + run = _run( + events=[{"type": "message", "text": "Paris."}, {"type": "done"}], + result={"output": "Paris."}, + ) + parts = await _collect(run, session_id="s1") + types = [p["type"] for p in parts] + assert "text-start" in types and "text-delta" in types and "text-end" in types + delta = next(p for p in parts if p["type"] == "text-delta") + assert delta["delta"] == "Paris." 
+ + async def test_permission_interaction_becomes_approval_request(self): + run = _run( + events=[ + { + "type": "interaction_request", + "id": "perm_1", + "kind": "permission", + "payload": { + "toolCallId": "call_1", + "availableReplies": ["once", "always", "reject"], + "toolCall": {"toolCallId": "call_1", "name": "deleteFile"}, + }, + }, + {"type": "done"}, + ], + result={"output": ""}, + ) + parts = await _collect(run, session_id="s1") + approval = next(p for p in parts if p["type"] == "tool-approval-request") + assert approval["approvalId"] == "perm_1" + # REQUIRED top-level toolCallId binds the approval to its tool part (RFC / AI SDK). + assert approval["toolCallId"] == "call_1" + assert approval["availableReplies"] == ["once", "always", "reject"] + assert approval["toolCall"] == {"toolCallId": "call_1", "name": "deleteFile"} + + async def test_permission_tool_call_id_falls_back_to_nested_tool_call(self): + # No top-level toolCallId on the payload: dig it out of the nested ACP toolCall detail. + run = _run( + events=[ + { + "type": "interaction_request", + "id": "perm_2", + "kind": "permission", + "payload": { + "availableReplies": ["once", "reject"], + "toolCall": {"id": "call_9", "name": "deleteFile"}, + }, + }, + {"type": "done"}, + ], + result={"output": ""}, + ) + parts = await _collect(run, session_id="s1") + approval = next(p for p in parts if p["type"] == "tool-approval-request") + assert approval["toolCallId"] == "call_9" + + async def test_tool_denial_becomes_output_denied(self): + # A human denied the tool: it never ran, so emit tool-output-denied (not -available). 
+ run = _run( + events=[ + {"type": "tool_call", "id": "c1", "name": "deleteFile", "input": {}}, + {"type": "tool_result", "id": "c1", "denied": True}, + {"type": "done"}, + ], + result={"output": ""}, + ) + parts = await _collect(run, session_id="s1") + denied = next(p for p in parts if p["type"] == "tool-output-denied") + assert denied["toolCallId"] == "c1" + # A denied result is neither output-available nor output-error. + types = [p["type"] for p in parts] + assert "tool-output-available" not in types + assert "tool-output-error" not in types + + async def test_finish_carries_trace_id_from_param(self): + run = _run( + events=[ + {"type": "message", "text": "hi"}, + {"type": "done", "stopReason": "end_turn"}, + ], + result={"output": "hi", "usage": {"input": 10, "output": 2}}, + ) + parts = await _collect(run, session_id="s1", trace_id="abc123") + # traceId and usage coexist under the finish messageMetadata. + assert parts[-1]["messageMetadata"]["traceId"] == "abc123" + assert parts[-1]["messageMetadata"]["usage"] == {"input": 10, "output": 2} + + async def test_finish_trace_id_falls_back_to_terminal_result(self): + run = _run( + events=[ + {"type": "message", "text": "hi"}, + {"type": "done", "stopReason": "end_turn"}, + ], + result={"output": "hi", "traceId": "trace_from_result"}, + ) + parts = await _collect(run, session_id="s1") + assert parts[-1]["messageMetadata"]["traceId"] == "trace_from_result" + + async def test_render_hint_passes_through_tool_parts(self): + render = {"kind": "component", "component": "WeatherCard"} + run = _run( + events=[ + { + "type": "tool_call", + "id": "c1", + "name": "w", + "input": {}, + "render": render, + }, + { + "type": "tool_result", + "id": "c1", + "data": {"w": "sunny"}, + "render": render, + }, + {"type": "done"}, + ], + result={"output": ""}, + ) + parts = await _collect(run, session_id="s1") + available = next(p for p in parts if p["type"] == "tool-input-available") + output = next(p for p in parts if p["type"] == 
"tool-output-available") + assert available["render"] == render + assert output["render"] == render + + async def test_tool_error_becomes_output_error(self): + run = _run( + events=[ + {"type": "tool_call", "id": "c1", "name": "w", "input": {}}, + {"type": "tool_result", "id": "c1", "output": "boom", "isError": True}, + {"type": "done"}, + ], + result={"output": ""}, + ) + parts = await _collect(run, session_id="s1") + err = next(p for p in parts if p["type"] == "tool-output-error") + assert err["toolCallId"] == "c1" + assert err["errorText"] == "boom" + + async def test_terminal_failure_emits_error_part_and_no_finish(self): + records = [ + {"kind": "event", "event": {"type": "message", "text": "partial"}}, + {"kind": "result", "result": {"ok": False, "error": "kaboom"}}, + ] + run = AgentRun(_from_list(records)) + parts = [part async for part in ui_message_stream(run, session_id="s1")] + types = [p["type"] for p in parts] + assert types[0] == "start" + assert "finish" not in types + error = next(p for p in parts if p["type"] == "error") + assert "kaboom" in error["errorText"] diff --git a/sdks/python/oss/tests/pytest/unit/agents/test_wire_contract.py b/sdks/python/oss/tests/pytest/unit/agents/test_wire_contract.py index 34687695ed..4aa24a86b1 100644 --- a/sdks/python/oss/tests/pytest/unit/agents/test_wire_contract.py +++ b/sdks/python/oss/tests/pytest/unit/agents/test_wire_contract.py @@ -40,6 +40,7 @@ "trace", "tools", "customTools", + "mcpServers", "toolCallback", "permissionPolicy", "systemPrompt", @@ -52,6 +53,7 @@ "description": "Get a user", "inputSchema": {"type": "object", "properties": {}}, "callRef": "tools__composio__github__GET_THE_AUTHENTICATED_USER__github-tvn", + "kind": "callback", } _CALLBACK = ToolCallback( endpoint="https://api.example/tools/call", authorization="Access tok-123" @@ -222,3 +224,78 @@ def test_result_from_wire_minimal_ok(): assert result.events == [] assert result.capabilities is None assert result.session_id is None + + +def 
test_request_to_wire_carries_code_client_and_mcp_specs(): + # The three-axes surface reaches the wire intact: a code spec keeps its executor fields + # (kind/runtime/code/env) and the orthogonal axes (needsApproval/render); a client spec + # has no callRef; user MCP servers ride `mcpServers`. + config = PiAgentConfig( + custom_tools=[ + { + "name": "calc", + "description": "calc", + "inputSchema": {"type": "object", "properties": {}}, + "kind": "code", + "runtime": "python", + "code": "def main(): return 1", + "env": {"STRIPE_API_KEY": "sk"}, + "needsApproval": True, + "render": {"kind": "component", "component": "Calc"}, + }, + { + "name": "pick", + "description": "pick", + "inputSchema": {"type": "object", "properties": {}}, + "kind": "client", + }, + ], + mcp_servers=[ + { + "name": "github", + "transport": "stdio", + "command": "npx", + "env": {"GITHUB_TOKEN": "ghp"}, + "tools": ["create_issue"], + } + ], + ) + payload = request_to_wire( + engine="pi", + harness=HarnessType.PI, + sandbox="local", + config=config, + messages=[Message(role="user", content="hi")], + ) + assert set(payload) <= KNOWN_REQUEST_KEYS + code = next(t for t in payload["customTools"] if t["name"] == "calc") + assert code["kind"] == "code" + assert code["runtime"] == "python" + assert code["code"] == "def main(): return 1" + assert code["env"] == {"STRIPE_API_KEY": "sk"} + assert code["needsApproval"] is True + assert code["render"] == {"kind": "component", "component": "Calc"} + client = next(t for t in payload["customTools"] if t["name"] == "pick") + assert client["kind"] == "client" + assert "callRef" not in client + assert payload["mcpServers"] == [ + { + "name": "github", + "transport": "stdio", + "command": "npx", + "env": {"GITHUB_TOKEN": "ghp"}, + "tools": ["create_issue"], + } + ] + + +def test_request_to_wire_omits_mcp_servers_when_none(): + # No declared servers -> no `mcpServers` key (keeps a tool-free payload byte-identical). 
+ payload = request_to_wire( + engine="pi", + harness=HarnessType.PI, + sandbox="local", + config=PiAgentConfig(), + messages=[Message(role="user", content="hi")], + ) + assert "mcpServers" not in payload diff --git a/sdks/python/oss/tests/pytest/unit/agents/tools/__init__.py b/sdks/python/oss/tests/pytest/unit/agents/tools/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/sdks/python/oss/tests/pytest/unit/agents/tools/__init__.py @@ -0,0 +1 @@ + diff --git a/sdks/python/oss/tests/pytest/unit/agents/tools/test_models.py b/sdks/python/oss/tests/pytest/unit/agents/tools/test_models.py new file mode 100644 index 0000000000..f823b4f32c --- /dev/null +++ b/sdks/python/oss/tests/pytest/unit/agents/tools/test_models.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from agenta.sdk.agents.tools import ( + CallbackToolSpec, + CodeToolConfig, + CodeToolSpec, +) + + +def test_canonical_config_forbids_unexpected_fields(): + with pytest.raises(ValidationError): + CodeToolConfig( + name="calc", + script="def main(): return 1", + unexpected=True, + ) + + +def test_code_spec_serializes_only_runner_fields(): + spec = CodeToolSpec( + name="calc", + description="Calculate", + input_schema={"type": "object", "properties": {}}, + runtime="python", + code="def main(): return 1", + env={"TOKEN": "secret"}, + needs_approval=True, + render={"kind": "component", "component": "Calculator"}, + ) + assert spec.to_wire() == { + "name": "calc", + "description": "Calculate", + "inputSchema": {"type": "object", "properties": {}}, + "kind": "code", + "runtime": "python", + "code": "def main(): return 1", + "env": {"TOKEN": "secret"}, + "needsApproval": True, + "render": {"kind": "component", "component": "Calculator"}, + } + + +def test_callback_spec_has_stable_typed_contract(): + spec = CallbackToolSpec( + name="get_user", + description="Get user", + call_ref="tools.composio.github.GET_USER.c1", + ) + 
assert spec.to_wire()["kind"] == "callback" + assert spec.to_wire()["callRef"] == "tools.composio.github.GET_USER.c1" + + +def test_secret_values_are_hidden_from_repr(): + spec = CodeToolSpec( + name="private", + description="private", + code="...", + env={"TOKEN": "do-not-print"}, + ) + assert "do-not-print" not in repr(spec) diff --git a/sdks/python/oss/tests/pytest/unit/agents/tools/test_parsing.py b/sdks/python/oss/tests/pytest/unit/agents/tools/test_parsing.py new file mode 100644 index 0000000000..ff6f212f9f --- /dev/null +++ b/sdks/python/oss/tests/pytest/unit/agents/tools/test_parsing.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import pytest + +from agenta.sdk.agents.tools import ( + BuiltinToolConfig, + GatewayToolConfig, + ToolConfigurationError, + coerce_tool_config, + coerce_tool_configs, + parse_tool_config, +) + + +def test_strict_parser_accepts_only_canonical_mapping(): + tool = parse_tool_config({"type": "builtin", "name": "read"}) + assert isinstance(tool, BuiltinToolConfig) + with pytest.raises(ToolConfigurationError): + parse_tool_config({"name": "read"}) + + +def test_compat_parser_accepts_legacy_shapes(): + assert coerce_tool_config("bash") == BuiltinToolConfig(name="bash") + gateway = coerce_tool_config( + { + "type": "composio", + "integration": "github", + "action": "GET_USER", + "connection": "c1", + } + ) + assert isinstance(gateway, GatewayToolConfig) + assert gateway.provider == "composio" + + +def test_compat_parser_accepts_playground_gateway_slug_and_metadata(): + gateway = coerce_tool_config( + { + "function": {"name": "tools__composio__github__GET_USER__c1"}, + "needs_approval": True, + "render": {"kind": "component", "component": "User"}, + } + ) + assert gateway.needs_approval is True + assert gateway.render == {"kind": "component", "component": "User"} + + +def test_collect_mode_reports_invalid_entries(): + result = coerce_tool_configs( + ["read", {"invalid": True}, None], + on_error="collect", + ) + assert 
result.tool_configs == [BuiltinToolConfig(name="read")] + assert [diagnostic.index for diagnostic in result.diagnostics] == [1, 2] + + +def test_default_compat_mode_raises_with_index(): + with pytest.raises(ToolConfigurationError) as caught: + coerce_tool_configs(["read", {"invalid": True}]) + assert caught.value.index == 1 diff --git a/sdks/python/oss/tests/pytest/unit/agents/tools/test_resolver.py b/sdks/python/oss/tests/pytest/unit/agents/tools/test_resolver.py new file mode 100644 index 0000000000..7c7ef58b46 --- /dev/null +++ b/sdks/python/oss/tests/pytest/unit/agents/tools/test_resolver.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from typing import Mapping, Sequence + +import pytest + +from agenta.sdk.agents.tools import ( + BuiltinToolConfig, + CallbackToolSpec, + ClientToolConfig, + CodeToolConfig, + DuplicateToolNameError, + GatewayToolConfig, + GatewayToolResolution, + MissingSecretPolicy, + MissingToolSecretError, + ToolCallback, + ToolResolver, + UnsupportedToolProviderError, +) + + +class DictSecretProvider: + def __init__(self, values: Mapping[str, str]): + self.values = values + self.requests: list[list[str]] = [] + + async def get_many(self, names: Sequence[str]) -> Mapping[str, str]: + self.requests.append(list(names)) + return {name: self.values[name] for name in names if name in self.values} + + +class FakeGatewayResolver: + async def resolve( + self, + tools: Sequence[GatewayToolConfig], + ) -> GatewayToolResolution: + return GatewayToolResolution( + tool_specs=[ + CallbackToolSpec( + name=tool.name or f"{tool.integration}__{tool.action}", + description=tool.name or tool.action, + call_ref=tool.reference, + needs_approval=tool.needs_approval, + render=tool.render, + ) + for tool in tools + ], + tool_callback=ToolCallback(endpoint="https://example/tools/call"), + ) + + +async def test_resolves_builtin_code_client_and_scopes_secrets(): + secrets = DictSecretProvider({"A": "a", "B": "b"}) + resolved = await 
ToolResolver(secret_provider=secrets).resolve( + [ + BuiltinToolConfig(name="read"), + CodeToolConfig(name="one", script="...", secrets=["A"]), + CodeToolConfig(name="two", script="...", secrets=["B"]), + ClientToolConfig(name="pick"), + ] + ) + assert resolved.builtin_names == ["read"] + assert secrets.requests == [["A", "B"]] + by_name = {spec.name: spec for spec in resolved.tool_specs} + assert by_name["one"].env == {"A": "a"} + assert by_name["two"].env == {"B": "b"} + assert by_name["pick"].kind == "client" + + +async def test_missing_declared_secret_fails_by_default(): + resolver = ToolResolver(secret_provider=DictSecretProvider({})) + with pytest.raises(MissingToolSecretError) as caught: + await resolver.resolve( + [CodeToolConfig(name="charge", script="...", secrets=["TOKEN"])] + ) + assert caught.value.secret_names == ("TOKEN",) + + +async def test_missing_secret_can_be_explicitly_omitted_for_compatibility(): + resolved = await ToolResolver( + secret_provider=DictSecretProvider({}), + missing_secret_policy=MissingSecretPolicy.OMIT, + ).resolve([CodeToolConfig(name="charge", script="...", secrets=["TOKEN"])]) + assert resolved.tool_specs[0].env == {} + + +async def test_gateway_requires_injected_adapter(): + with pytest.raises(UnsupportedToolProviderError): + await ToolResolver().resolve( + [ + GatewayToolConfig( + integration="github", + action="GET_USER", + connection="c1", + ) + ] + ) + + +async def test_gateway_metadata_survives_resolution(): + resolved = await ToolResolver(gateway_resolver=FakeGatewayResolver()).resolve( + [ + GatewayToolConfig( + integration="github", + action="GET_USER", + connection="c1", + needs_approval=True, + render={"kind": "component", "component": "User"}, + ) + ] + ) + spec = resolved.tool_specs[0] + assert spec.needs_approval is True + assert spec.render == {"kind": "component", "component": "User"} + + +@pytest.mark.parametrize( + "configs", + [ + [BuiltinToolConfig(name="read"), BuiltinToolConfig(name="read")], + [ + 
BuiltinToolConfig(name="same"), + ClientToolConfig(name="same"), + ], + [ClientToolConfig(name="same"), ClientToolConfig(name="same")], + ], +) +async def test_duplicate_model_visible_names_are_rejected(configs): + with pytest.raises(DuplicateToolNameError): + await ToolResolver().resolve(configs)