diff --git a/redisvl/mcp/config.py b/redisvl/mcp/config.py index 65f53060..014c4e46 100644 --- a/redisvl/mcp/config.py +++ b/redisvl/mcp/config.py @@ -282,9 +282,16 @@ class MCPSchemaOverrides(BaseModel): class MCPIndexBindingConfig(BaseModel): - """The sole configured v1 index binding.""" + """A single configured logical index binding. + + A server can configure one or many of these under ``indexes.``. Each + binding inspects and serves one existing Redis index independently, and + owns its own schema inspection, runtime mapping, and search validation. + """ redis_name: str = Field(..., min_length=1) + description: str | None = Field(default=None, min_length=1) + read_only: bool = False vectorizer: MCPVectorizerConfig | None = None search: MCPIndexSearchConfig runtime: MCPRuntimeConfig @@ -355,83 +362,9 @@ def _validate_capability_requirements(self) -> "MCPIndexBindingConfig": return self - -class MCPConfig(BaseModel): - """Validated MCP server configuration loaded from YAML.""" - - server: MCPServerConfig - indexes: dict[str, MCPIndexBindingConfig] - - @model_validator(mode="after") - def _validate_bindings(self) -> "MCPConfig": - """Validate that there is exactly one configured logical binding.""" - if len(self.indexes) != 1: - raise ValueError( - "indexes must contain exactly one configured index binding" - ) - - binding_id = next(iter(self.indexes)) - if not binding_id.strip(): - raise ValueError("indexes binding id must be non-blank") - return self - - @property - def binding_id(self) -> str: - """Return the single logical binding identifier configured for v1.""" - return next(iter(self.indexes)) - - @property - def binding(self) -> MCPIndexBindingConfig: - """Return the sole configured binding.""" - return self.indexes[self.binding_id] - - @property - def runtime(self) -> MCPRuntimeConfig: - """Expose the sole binding's runtime config for phase 1.""" - return self.binding.runtime - - @property - def vectorizer(self) -> MCPVectorizerConfig | None: - """Expose the sole binding's vectorizer config for phase 1.""" - return self.binding.vectorizer - - @property - def search(self) -> MCPIndexSearchConfig: - """Expose the sole binding's configured search behavior.""" - return self.binding.search - - @property - def uses_text_search(self) -> bool: - """Return whether configured search uses a text field.""" - return self.binding.uses_text_search - - @property - def uses_query_embedding(self) -> bool: - """Return whether configured search embeds user queries.""" - return self.binding.uses_query_embedding - - @property - def supports_vector_backed_upsert(self) -> bool: - """Return whether configured upserts manage a vector field.""" - return self.binding.supports_vector_backed_upsert - - @property - def supports_server_side_embedding(self) -> bool: - """Return whether configured upserts can generate embeddings.""" - return self.binding.supports_server_side_embedding - - @property - def requires_startup_vectorizer(self) -> bool: - """Return whether startup must initialize a vectorizer.""" - return self.binding.requires_startup_vectorizer - - @property - def redis_name(self) -> str: - """Return the existing Redis index name that must be inspected at startup.""" - return self.binding.redis_name - + @staticmethod def inspected_schema_from_index_info( - self, index_info: dict[str, Any] + index_info: dict[str, Any], ) -> dict[str, Any]: """Build a schema dict from FT.INFO while preserving discovered field identity. @@ -478,7 +411,7 @@ def merge_schema_overrides( if isinstance(field, dict) and "name" in field } - for override in self.binding.schema_overrides.fields: + for override in self.schema_overrides.fields: discovered = discovered_fields.get(override.name) if discovered is None: raise ValueError( @@ -575,6 +508,29 @@ def validate_search( ) +class MCPConfig(BaseModel): + """Validated MCP server configuration loaded from YAML. + + ``indexes`` is the canonical multi-binding map: a server may configure one + or many logical bindings. Single-index configs remain valid and unchanged; + each binding owns its own inspection, runtime mapping, and search behavior. + """ + + server: MCPServerConfig + indexes: dict[str, MCPIndexBindingConfig] + + @model_validator(mode="after") + def _validate_bindings(self) -> "MCPConfig": + """Require at least one binding and reject blank logical ids.""" + if not self.indexes: + raise ValueError("indexes must contain at least one configured binding") + + for binding_id in self.indexes: + if not binding_id.strip(): + raise ValueError("indexes binding id must be non-blank") + return self + + def _substitute_env(value: Any) -> Any: """Recursively resolve `${VAR}` and `${VAR:-default}` placeholders.""" if isinstance(value, dict): diff --git a/redisvl/mcp/runtime.py b/redisvl/mcp/runtime.py new file mode 100644 index 00000000..2564d899 --- /dev/null +++ b/redisvl/mcp/runtime.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from typing import Any + +from redisvl.index import AsyncSearchIndex +from redisvl.mcp.config import MCPIndexBindingConfig +from redisvl.schema import IndexSchema + + +@dataclass(frozen=True) +class BindingRuntime: + """Immutable per-binding runtime state assembled once at server startup. + + Each configured logical index becomes one ``BindingRuntime`` bundling the + binding config with the resources a tool call needs: the connected index, + its effective (inspected + overridden) schema, an optional vectorizer, the + resolved native-hybrid-search capability, and the effective write policy. + + Tools resolve a binding once via ``server.resolve_binding(index)`` and then + read these attributes directly instead of calling back into the server. + """ + + binding_id: str + binding: MCPIndexBindingConfig + index: AsyncSearchIndex + schema: IndexSchema + vectorizer: Any | None + supports_native_hybrid_search: bool + effective_read_only: bool diff --git a/redisvl/mcp/server.py b/redisvl/mcp/server.py index a67618df..2b014354 100644 --- a/redisvl/mcp/server.py +++ b/redisvl/mcp/server.py @@ -10,7 +10,9 @@ from redisvl.exceptions import RedisSearchError from redisvl.index import AsyncSearchIndex from redisvl.mcp.auth import build_auth_provider, resolve_auth_config -from redisvl.mcp.config import MCPConfig, load_mcp_config +from redisvl.mcp.config import MCPConfig, MCPIndexBindingConfig, load_mcp_config +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.runtime import BindingRuntime from redisvl.mcp.settings import MCPSettings from redisvl.mcp.tools.search import register_search_tool from redisvl.mcp.tools.upsert import register_upsert_tool @@ -47,17 +49,15 @@ class _LifecycleState(Enum): class RedisVLMCPServer(FastMCP): - """MCP server exposing RedisVL capabilities for one existing Redis index.""" + """MCP server exposing RedisVL capabilities for one or many existing indexes.""" _LifecycleState = _LifecycleState def __init__(self, settings: MCPSettings): - """Create a server shell with lazy config, index, and vectorizer state.""" + """Create a server shell with lazy config and per-binding runtime state.""" self.mcp_settings = settings self.config: MCPConfig | None = None - self._index: AsyncSearchIndex | None = None - self._vectorizer: Any | None = None - self._supports_native_hybrid_search: bool | None = None + self._bindings: dict[str, BindingRuntime] = {} self._semaphore: asyncio.Semaphore | None = None self._tools_registered = False @@ -90,12 +90,10 @@ async def startup(self) -> None: """Load config, inspect the configured index, and initialize dependencies.""" async with self._transition_lock: await self._begin_startup() - client = None try: - client = await self._initialize_runtime_resources() + await self._initialize_runtime_resources() await self._mark_running() except Exception: - await self._teardown_runtime(client) await self._mark_stopped() raise @@ -112,22 +110,64 @@ async def shutdown(self) -> None: finally: await self._mark_stopped() - async def get_index(self) -> AsyncSearchIndex: - """Return the initialized async index or fail if startup has not run.""" - if self._index is None: - raise RuntimeError("MCP server has not been started") - return self._index + def resolve_binding(self, index_id: str | None) -> BindingRuntime: + """Resolve the runtime for a logical index id, honoring single-index defaults. - async def get_vectorizer(self) -> Any: - """Return the initialized vectorizer or fail if startup has not run.""" - if self.config is None: + - ``None`` with exactly one configured binding returns that binding, + preserving backward-compatible single-index behavior. + - ``None`` with multiple bindings is an ``invalid_request``; the caller + must name an index. + - An unknown id is an ``invalid_request``. + + Write-availability is not enforced here; that is the upsert tool's job. + """ + if not self._bindings: raise RuntimeError("MCP server has not been started") - if self._vectorizer is None: - raise RuntimeError("MCP server vectorizer is not configured") - return self._vectorizer - async def run_guarded(self, operation_name: str, awaitable: Awaitable[Any]) -> Any: - """Run a coroutine under the configured concurrency and timeout limits.""" + if index_id is None: + if len(self._bindings) == 1: + return next(iter(self._bindings.values())) + available = ", ".join(sorted(self._bindings)) + raise RedisVLMCPError( + "index is required when multiple indexes are configured; " + f"available: {available}", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + + runtime = self._bindings.get(index_id) + if runtime is None: + available = ", ".join(sorted(self._bindings)) + raise RedisVLMCPError( + f"Unknown index '{index_id}'; available: {available}", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + return runtime + + async def get_index(self, index_id: str | None = None) -> AsyncSearchIndex: + """Return an initialized async index, defaulting to the sole binding.""" + return self.resolve_binding(index_id).index + + async def get_vectorizer(self, index_id: str | None = None) -> Any: + """Return an initialized vectorizer, defaulting to the sole binding.""" + runtime = self.resolve_binding(index_id) + if runtime.vectorizer is None: + raise RuntimeError("MCP server vectorizer is not configured") + return runtime.vectorizer + + async def run_guarded( + self, + operation_name: str, + awaitable: Awaitable[Any], + *, + timeout_seconds: float, + ) -> Any: + """Run a coroutine under the global concurrency cap and a request timeout. + + The timeout is sourced per-binding by the caller; the concurrency + semaphore is a single process-wide ceiling shared across all bindings. + """ del operation_name semaphore = self._semaphore if semaphore is None: @@ -140,8 +180,7 @@ async def run_guarded(self, operation_name: str, awaitable: Awaitable[Any]) -> A self._close_awaitable(awaitable) raise RuntimeError("MCP server is not running") - config = self.config - if config is None: + if self.config is None: self._close_awaitable(awaitable) raise RuntimeError("MCP server is not running") @@ -149,33 +188,32 @@ async def run_guarded(self, operation_name: str, awaitable: Awaitable[Any]) -> A self._active_requests_drained.clear() try: - return await asyncio.wait_for( - awaitable, - timeout=config.runtime.request_timeout_seconds, - ) + return await asyncio.wait_for(awaitable, timeout=timeout_seconds) finally: async with self._request_state_lock: self._active_requests -= 1 if self._active_requests == 0: self._active_requests_drained.set() - def _build_vectorizer(self) -> Any: - """Instantiate the configured vectorizer class from validated config.""" - if self.config is None: - raise RuntimeError("MCP server config not loaded") - if self.config.vectorizer is None: + @staticmethod + def _build_vectorizer(binding: MCPIndexBindingConfig) -> Any: + """Instantiate a binding's configured vectorizer class from its config.""" + if binding.vectorizer is None: raise RuntimeError("MCP server vectorizer is not configured") - vectorizer_class = resolve_vectorizer_class(self.config.vectorizer.class_name) - return vectorizer_class(**self.config.vectorizer.to_init_kwargs()) + vectorizer_class = resolve_vectorizer_class(binding.vectorizer.class_name) + return vectorizer_class(**binding.vectorizer.to_init_kwargs()) - def _validate_vectorizer_dims(self, schema: IndexSchema) -> None: + @staticmethod + def _validate_vectorizer_dims( + binding: MCPIndexBindingConfig, vectorizer: Any, schema: IndexSchema + ) -> None: """Fail startup when vectorizer dimensions disagree with schema dimensions.""" - if self.config is None or self._vectorizer is None: + if vectorizer is None: return - configured_dims = self.config.get_vector_field_dims(schema) - actual_dims = getattr(self._vectorizer, "dims", None) + configured_dims = binding.get_vector_field_dims(schema) + actual_dims = getattr(vectorizer, "dims", None) if ( configured_dims is not None and actual_dims is not None @@ -185,33 +223,32 @@ def _validate_vectorizer_dims(self, schema: IndexSchema) -> None: f"Vectorizer dims {actual_dims} do not match configured vector field dims {configured_dims}" ) - async def supports_native_hybrid_search(self) -> bool: - """Return whether the current runtime supports Redis native hybrid search.""" - if self._supports_native_hybrid_search is not None: - return self._supports_native_hybrid_search - if self._index is None: - raise RuntimeError("MCP server has not been started") + @staticmethod + async def _probe_native_hybrid_search(index: AsyncSearchIndex) -> bool: + """Probe whether a connected index supports Redis native hybrid search.""" if not is_version_gte(redis_py_version, "7.1.0"): - self._supports_native_hybrid_search = False return False - client = await self._index._get_client() + client = await index._get_client() info = await client.info("server") if not is_version_gte(info.get("redis_version", "0.0.0"), "8.4.0"): - self._supports_native_hybrid_search = False return False - self._supports_native_hybrid_search = hasattr( - client.ft(self._index.schema.index.name), "hybrid_search" - ) - return self._supports_native_hybrid_search + return hasattr(client.ft(index.schema.index.name), "hybrid_search") - def _register_tools(self, schema: IndexSchema) -> None: - """Register MCP tools once the server is ready.""" + def _register_tools(self) -> None: + """Register MCP tools once every binding is ready.""" if self._tools_registered or not hasattr(self, "tool"): return - register_search_tool(self, schema) + # The search description advertises schema-specific filter hints, which + # are only unambiguous for a single binding. With multiple bindings the + # caller selects an index per call, so fall back to the base description. + search_schema: IndexSchema | None = None + if len(self._bindings) == 1: + search_schema = next(iter(self._bindings.values())).schema + + register_search_tool(self, search_schema) if not self.mcp_settings.read_only: register_upsert_tool(self) self._tools_registered = True @@ -241,15 +278,15 @@ async def _server_lifespan(self, _server: Any): finally: await self.shutdown() - async def _teardown_runtime(self, client: Any | None = None) -> None: - """Release runtime resources and clear terminal state.""" - vectorizer = self._vectorizer - index = self._index - self._vectorizer = None - self._index = None - self.config = None - self._semaphore = None + @staticmethod + async def _close_resources( + *, index: Any | None, vectorizer: Any | None, client: Any | None = None + ) -> None: + """Close one binding's vectorizer and Redis connection. + A fully built binding owns its client through ``index``; a binding that + failed mid-startup may have a bare ``client`` and no index yet. + """ try: if vectorizer is not None: aclose = getattr(vectorizer, "aclose", None) @@ -259,12 +296,23 @@ async def _teardown_runtime(self, client: Any | None = None) -> None: elif callable(close): close() finally: - self._supports_native_hybrid_search = None if index is not None: await index.disconnect() elif client is not None: await client.aclose() + async def _teardown_runtime(self) -> None: + """Release every binding's runtime resources and clear terminal state.""" + bindings = list(self._bindings.values()) + self._bindings = {} + self.config = None + self._semaphore = None + + for runtime in bindings: + await self._close_resources( + index=runtime.index, vectorizer=runtime.vectorizer + ) + @staticmethod def _close_awaitable(awaitable: Awaitable[Any]) -> None: """Close coroutine objects we reject before awaiting to avoid warnings.""" @@ -296,6 +344,7 @@ async def _begin_shutdown(self) -> bool: ): self.config = None self._semaphore = None + self._bindings = {} self._lifecycle_state = _LifecycleState.STOPPED return True @@ -329,29 +378,63 @@ def _verify_auth_not_stale(self) -> None: "ensure the config file exists before constructing the server." ) - async def _initialize_runtime_resources(self) -> Any: - """Load config and initialize the Redis-backed runtime dependencies.""" + async def _initialize_runtime_resources(self) -> None: + """Load config and initialize every configured binding independently.""" self.config = load_mcp_config(self._config_path) self._verify_auth_not_stale() - self._semaphore = asyncio.Semaphore(self.config.runtime.max_concurrency) - self._supports_native_hybrid_search = None - timeout = self.config.runtime.startup_timeout_seconds + # The semaphore is a single process-wide concurrency ceiling shared by + # all bindings; take the max configured limit across bindings. + self._semaphore = asyncio.Semaphore( + max( + binding.runtime.max_concurrency + for binding in self.config.indexes.values() + ) + ) + self._bindings = {} + + try: + for binding_id, binding in self.config.indexes.items(): + self._bindings[binding_id] = await self._initialize_binding( + binding_id, binding + ) + self._register_tools() + except Exception: + # Tear down any bindings already built before re-raising so startup + # fails closed without leaking connections. + await self._teardown_runtime() + raise + async def _initialize_binding( + self, binding_id: str, binding: MCPIndexBindingConfig + ) -> BindingRuntime: + """Inspect, validate, and initialize a single configured binding.""" + timeout = binding.runtime.startup_timeout_seconds client = await self._connect_redis_client(timeout) + index: AsyncSearchIndex | None = None + vectorizer: Any | None = None try: - effective_schema = await self._load_effective_schema(client, timeout) - self._initialize_index(effective_schema, client) - self.config.validate_search( - schema=effective_schema, - supports_native_hybrid_search=await self.supports_native_hybrid_search(), + schema = await self._load_effective_schema(binding, client, timeout) + index = self._make_index(schema, client) + supports_native_hybrid = await self._probe_native_hybrid_search(index) + binding.validate_search( + schema=schema, + supports_native_hybrid_search=supports_native_hybrid, + ) + if binding.requires_startup_vectorizer: + vectorizer = await self._initialize_vectorizer(binding, schema, timeout) + return BindingRuntime( + binding_id=binding_id, + binding=binding, + index=index, + schema=schema, + vectorizer=vectorizer, + supports_native_hybrid_search=supports_native_hybrid, + effective_read_only=self.mcp_settings.read_only or binding.read_only, ) - if self.config.requires_startup_vectorizer: - await self._initialize_vectorizer(effective_schema, timeout) - self._register_tools(effective_schema) - return client except Exception: - if self._index is None: - await client.aclose() + await self._close_resources( + index=index, vectorizer=vectorizer, client=client + ) raise async def _connect_redis_client(self, timeout: int) -> Any: @@ -368,37 +451,41 @@ async def _connect_redis_client(self, timeout: int) -> Any: await asyncio.wait_for(client.info("server"), timeout=timeout) return client - async def _load_effective_schema(self, client: Any, timeout: int) -> IndexSchema: - """Inspect the configured Redis index and build the effective schema.""" - if self.config is None: - raise RuntimeError("MCP server config not loaded") - + async def _load_effective_schema( + self, binding: MCPIndexBindingConfig, client: Any, timeout: int + ) -> IndexSchema: + """Inspect a binding's Redis index and build its effective schema.""" try: index_info = await asyncio.wait_for( - AsyncSearchIndex._info(self.config.redis_name, client), + AsyncSearchIndex._info(binding.redis_name, client), timeout=timeout, ) except RedisSearchError as exc: if self._is_missing_index_error(exc): raise ValueError( - f"Configured Redis index '{self.config.redis_name}' does not exist" + f"Configured Redis index '{binding.redis_name}' does not exist" ) from exc raise - inspected_schema = self.config.inspected_schema_from_index_info(index_info) - return self.config.to_index_schema(inspected_schema) + inspected_schema = binding.inspected_schema_from_index_info(index_info) + return binding.to_index_schema(inspected_schema) - def _initialize_index(self, schema: IndexSchema, client: Any) -> None: - """Bind the inspected schema and Redis client into an async index.""" - self._index = AsyncSearchIndex(schema=schema, redis_client=client) + @staticmethod + def _make_index(schema: IndexSchema, client: Any) -> AsyncSearchIndex: + """Bind an inspected schema and Redis client into an async index.""" + index = AsyncSearchIndex(schema=schema, redis_client=client) # The server acquired this client explicitly during startup, so hand # ownership to the index for a single shutdown path. - self._index._owns_redis_client = True - - async def _initialize_vectorizer(self, schema: IndexSchema, timeout: int) -> None: - """Build the configured vectorizer and validate it against the schema.""" - self._vectorizer = await asyncio.wait_for( - asyncio.to_thread(self._build_vectorizer), + index._owns_redis_client = True + return index + + async def _initialize_vectorizer( + self, binding: MCPIndexBindingConfig, schema: IndexSchema, timeout: int + ) -> Any: + """Build a binding's vectorizer and validate it against the schema.""" + vectorizer = await asyncio.wait_for( + asyncio.to_thread(self._build_vectorizer, binding), timeout=timeout, ) - self._validate_vectorizer_dims(schema) + self._validate_vectorizer_dims(binding, vectorizer, schema) + return vectorizer diff --git a/redisvl/mcp/tools/search.py b/redisvl/mcp/tools/search.py index 89cfa00f..8ed94e03 100644 --- a/redisvl/mcp/tools/search.py +++ b/redisvl/mcp/tools/search.py @@ -51,10 +51,17 @@ def _build_return_fields_hint(schema: IndexSchema) -> str: def _build_search_tool_description( - schema: IndexSchema, base_description: str | None = None + schema: IndexSchema | None, base_description: str | None = None ) -> str: - """Build the `search-records` description from static text plus schema hints.""" + """Build the `search-records` description from static text plus schema hints. + + With multiple bindings configured the schema is ambiguous (the caller picks + an index per call via `list-indexes`), so `schema` is None and only the + base description is returned. + """ description = (base_description or DEFAULT_SEARCH_DESCRIPTION).strip() + if schema is None: + return description # `exists` is currently accepted for any schema field in the MCP object filter. exists_fields = [field.name for field in schema.fields.values()] @@ -79,18 +86,16 @@ def _validate_request( limit: int | None, offset: int, return_fields: list[str] | None, - server: Any, - index: Any, + runtime: Any, + schema: Any, ) -> tuple[int, list[str]]: """Validate a `search-records` request and resolve default projection. The MCP caller can only supply query text, pagination, filters, and return - fields. Search mode and tuning are sourced from config, so this validation - step focuses only on the public request contract. + fields. Search mode and tuning are sourced from the selected binding's + config, so this validation step focuses only on the public request contract. """ - runtime = server.config.runtime - if not isinstance(query, str) or not query.strip(): raise RedisVLMCPError( "query must be a non-empty string", @@ -125,17 +130,17 @@ def _validate_request( retryable=False, ) - schema_fields = set(index.schema.field_names) + schema_fields = set(schema.field_names) vector_field_names = { field_name - for field_name, field in index.schema.fields.items() + for field_name, field in schema.fields.items() if field.type == "vector" } if return_fields is None: fields = [ field_name - for field_name in index.schema.field_names + for field_name in schema.field_names if field_name not in vector_field_names ] else: @@ -224,12 +229,19 @@ async def _embed_query(vectorizer: Any, query: str) -> Any: return await asyncio.to_thread(embed, query) -def _get_configured_search(server: Any) -> tuple[str, dict[str, Any]]: - """Return the configured search mode and normalized query params.""" - search_config = server.config.search +def _get_configured_search(rt: Any) -> tuple[str, dict[str, Any]]: + """Return the binding's configured search mode and normalized query params.""" + search_config = rt.binding.search return search_config.type, search_config.to_query_params() +def _require_vectorizer(rt: Any) -> Any: + """Return the binding's vectorizer or fail when it is not configured.""" + if rt.vectorizer is None: + raise RuntimeError("MCP server vectorizer is not configured") + return rt.vectorizer + + def _build_native_hybrid_kwargs( *, query: str, @@ -311,29 +323,27 @@ def _build_fallback_hybrid_kwargs( async def _build_query( *, - server: Any, - index: Any, + rt: Any, query: str, limit: int, offset: int, filter_value: str | dict[str, Any] | None, return_fields: list[str], ) -> tuple[Any, str, str, str]: - """Build the RedisVL query object from configured search mode and params. + """Build the RedisVL query object from the binding's search mode and params. Returns the query instance, the raw score field to read from RedisVL results, the public MCP `score_type`, and the configured `search_type`. """ - runtime = server.config.runtime - search_type, search_params = _get_configured_search(server) + runtime = rt.binding.runtime + search_type, search_params = _get_configured_search(rt) num_results = limit + offset - filter_expression = parse_filter(filter_value, index.schema) + filter_expression = parse_filter(filter_value, rt.schema) if search_type == "vector": if runtime.vector_field_name is None: raise RuntimeError("Vector search requires a configured vector field") - vectorizer = await server.get_vectorizer() - embedding = await _embed_query(vectorizer, query) + embedding = await _embed_query(_require_vectorizer(rt), query) vector_kwargs = { "vector": embedding, "vector_field_name": runtime.vector_field_name, @@ -373,9 +383,8 @@ async def _build_query( search_type, ) - vectorizer = await server.get_vectorizer() - embedding = await _embed_query(vectorizer, query) - if await server.supports_native_hybrid_search(): + embedding = await _embed_query(_require_vectorizer(rt), query) + if rt.supports_native_hybrid_search: native_query = HybridQuery( **_build_native_hybrid_kwargs( query=query, @@ -423,20 +432,19 @@ async def search_records( filter: str | dict[str, Any] | None = None, return_fields: list[str] | None = None, ) -> dict[str, Any]: - """Execute `search-records` against the configured Redis index binding.""" + """Execute `search-records` against the selected Redis index binding.""" try: - index = await server.get_index() + rt = server.resolve_binding(None) effective_limit, effective_return_fields = _validate_request( query=query, limit=limit, offset=offset, return_fields=return_fields, - server=server, - index=index, + runtime=rt.binding.runtime, + schema=rt.schema, ) built_query, score_field, score_type, search_type = await _build_query( - server=server, - index=index, + rt=rt, query=query.strip(), limit=effective_limit, offset=offset, @@ -445,7 +453,8 @@ async def search_records( ) raw_results = await server.run_guarded( "search-records", - index.query(built_query), + rt.index.query(built_query), + timeout_seconds=rt.binding.runtime.request_timeout_seconds, ) sliced_results = raw_results[offset : offset + effective_limit] return { @@ -467,7 +476,7 @@ async def search_records( raise map_exception(exc) from exc -def register_search_tool(server: Any, schema: IndexSchema) -> None: +def register_search_tool(server: Any, schema: IndexSchema | None) -> None: """Register the MCP `search-records` tool with its config-owned contract.""" description = _build_search_tool_description( schema=schema, diff --git a/redisvl/mcp/tools/upsert.py b/redisvl/mcp/tools/upsert.py index dd5721e8..eb72e08d 100644 --- a/redisvl/mcp/tools/upsert.py +++ b/redisvl/mcp/tools/upsert.py @@ -14,14 +14,12 @@ def _validate_request( *, - server: Any, + runtime: Any, records: list[dict[str, Any]], id_field: str | None, skip_embedding_if_present: bool | None, ) -> bool: """Validate the public upsert request contract and resolve defaults.""" - runtime = server.config.runtime - if not isinstance(records, list) or not records: raise RedisVLMCPError( "records must be a non-empty list", @@ -183,9 +181,9 @@ async def _embed_many(vectorizer: Any, contents: list[str]) -> list[list[float]] return embeddings -def _vector_dtype(server: Any, index: Any) -> str: - """Resolve the configured vector field datatype as a lowercase string.""" - field = server.config.get_vector_field(index.schema) +def _vector_dtype(rt: Any) -> str: + """Resolve the binding's vector field datatype as a lowercase string.""" + field = rt.binding.get_vector_field(rt.schema) datatype = getattr(field.attrs.datatype, "value", field.attrs.datatype) return str(datatype).lower() @@ -225,12 +223,12 @@ def _validate_record( def _prepare_record_for_storage( record: dict[str, Any], *, - server: Any, - index: Any, + rt: Any, ) -> dict[str, Any]: """Validate records before serializing HASH vectors for storage.""" prepared = dict(record) - vector_field_name = server.config.runtime.vector_field_name + index = rt.index + vector_field_name = rt.binding.runtime.vector_field_name _validate_record(prepared, index=index, vector_field_name=vector_field_name) if vector_field_name is None: @@ -242,7 +240,7 @@ def _prepare_record_for_storage( if isinstance(vector_value, list): prepared[vector_field_name] = array_to_buffer( vector_value, - _vector_dtype(server, index), + _vector_dtype(rt), ) return prepared @@ -254,11 +252,13 @@ async def upsert_records( id_field: str | None = None, skip_embedding_if_present: bool | None = None, ) -> dict[str, Any]: - """Execute `upsert-records` against the configured Redis index.""" + """Execute `upsert-records` against the selected Redis index binding.""" try: - index = await server.get_index() + rt = server.resolve_binding(None) + index = rt.index + runtime = rt.binding.runtime effective_skip_embedding = _validate_request( - server=server, + runtime=runtime, records=records, id_field=id_field, skip_embedding_if_present=skip_embedding_if_present, @@ -266,14 +266,13 @@ async def upsert_records( # Copy caller-provided records before enriching them with embeddings or # storage-specific serialization so the MCP tool does not mutate inputs. prepared_records = [deepcopy(record) for record in records] - runtime = server.config.runtime for record in prepared_records: _validate_record( record, index=index, vector_field_name=runtime.vector_field_name, ) - if server.config.supports_server_side_embedding: + if rt.binding.supports_server_side_embedding: if ( runtime.default_embed_text_field is None or runtime.vector_field_name is None @@ -289,7 +288,9 @@ async def upsert_records( ) if embed_contents: - vectorizer = await server.get_vectorizer() + if rt.vectorizer is None: + raise RuntimeError("MCP server vectorizer is not configured") + vectorizer = rt.vectorizer # TODO: Avoid re-embedding records that already include vectors. # The current flow can regenerate embeddings for caller-supplied # vectors, which is wasteful and can add external service cost. @@ -319,14 +320,14 @@ async def upsert_records( ) loadable_records = [ - _prepare_record_for_storage(record, server=server, index=index) - for record in prepared_records + _prepare_record_for_storage(record, rt=rt) for record in prepared_records ] try: keys = await server.run_guarded( "upsert-records", index.load(loadable_records, id_field=id_field), + timeout_seconds=runtime.request_timeout_seconds, ) except Exception as exc: mapped = map_exception(exc) diff --git a/tests/integration/test_aggregation.py b/tests/integration/test_aggregation.py index fe26839b..06711464 100644 --- a/tests/integration/test_aggregation.py +++ b/tests/integration/test_aggregation.py @@ -8,13 +8,13 @@ @pytest.fixture -def index(multi_vector_data, redis_url, worker_id): +def index(multi_vector_data, redis_url, redis_test_name): index = SearchIndex.from_dict( { "index": { - "name": f"user_index_{worker_id}", - "prefix": f"v1_{worker_id}", + "name": redis_test_name("user_index"), + "prefix": redis_test_name("v1"), "storage_type": "hash", }, "fields": [ @@ -60,7 +60,7 @@ def index(multi_vector_data, redis_url, worker_id): ) # create the index (no data yet) - index.create(overwrite=True) + index.create(overwrite=True, drop=True) # prepare and load the data def hash_preprocess(item: dict) -> dict: diff --git a/tests/integration/test_hybrid.py b/tests/integration/test_hybrid.py index 95f9c481..3c505337 100644 --- a/tests/integration/test_hybrid.py +++ b/tests/integration/test_hybrid.py @@ -19,12 +19,12 @@ @pytest.fixture -def index_schema(worker_id): +def index_schema(redis_test_name): return IndexSchema.from_dict( { "index": { - "name": f"user_index_{worker_id}", - "prefix": f"v1_{worker_id}", + "name": redis_test_name("user_index"), + "prefix": redis_test_name("v1"), "storage_type": "hash", }, "fields": [ @@ -74,7 +74,7 @@ def index(index_schema, multi_vector_data, redis_url): index = SearchIndex(schema=index_schema, redis_url=redis_url) # create the index (no data yet) - index.create(overwrite=True) + index.create(overwrite=True, drop=True) # prepare and load the data def hash_preprocess(item: dict) -> dict: @@ -97,7 +97,7 @@ def hash_preprocess(item: dict) -> dict: @pytest.fixture async def async_index(index_schema, multi_vector_data, async_client): index = AsyncSearchIndex(schema=index_schema, redis_client=async_client) - await index.create(overwrite=True) + await index.create(overwrite=True, drop=True) def hash_preprocess(item: dict) -> dict: return { diff --git a/tests/integration/test_mcp/test_server_startup.py b/tests/integration/test_mcp/test_server_startup.py index 6b9c315a..170aa32c 100644 --- a/tests/integration/test_mcp/test_server_startup.py +++ b/tests/integration/test_mcp/test_server_startup.py @@ -6,6 +6,7 @@ import yaml from redisvl.index import AsyncSearchIndex +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError from redisvl.mcp.server import RedisVLMCPServer from redisvl.mcp.settings import MCPSettings from redisvl.redis.connection import is_version_gte @@ -154,10 +155,10 @@ async def test_server_startup_succeeds_for_fulltext_without_vectorizer( original_build_vectorizer = RedisVLMCPServer._build_vectorizer build_vectorizer_called = False - def tracked_build_vectorizer(self): + def tracked_build_vectorizer(binding): nonlocal build_vectorizer_called build_vectorizer_called = True - return original_build_vectorizer(self) + return original_build_vectorizer(binding) monkeypatch.setattr( "redisvl.mcp.server.resolve_vectorizer_class", @@ -166,7 +167,7 @@ def tracked_build_vectorizer(self): monkeypatch.setattr( RedisVLMCPServer, "_build_vectorizer", - tracked_build_vectorizer, + staticmethod(tracked_build_vectorizer), ) server = RedisVLMCPServer( MCPSettings( @@ -472,7 +473,9 @@ async def guarded_operation(): return "done" operation_task = asyncio.create_task( - server.run_guarded("drain-during-shutdown", guarded_operation()) + server.run_guarded( + "drain-during-shutdown", guarded_operation(), timeout_seconds=5 + ) ) await entered.wait() @@ -518,7 +521,9 @@ async def guarded_operation(): return "done" active_task = asyncio.create_task( - server.run_guarded("active-during-shutdown", guarded_operation()) + server.run_guarded( + "active-during-shutdown", guarded_operation(), timeout_seconds=5 + ) ) await entered.wait() @@ -528,7 +533,7 @@ async def guarded_operation(): future = asyncio.get_running_loop().create_future() future.set_result("later") with pytest.raises(RuntimeError, match="not running"): - await server.run_guarded("reject-after-stop", future) + await server.run_guarded("reject-after-stop", future, timeout_seconds=5) release.set() assert await active_task == "done" @@ -567,11 +572,13 @@ async def second_operation(): second_started.set() return "second" - first_task = asyncio.create_task(server.run_guarded("first-op", first_operation())) + first_task = asyncio.create_task( + server.run_guarded("first-op", first_operation(), timeout_seconds=5) + ) await first_entered.wait() second_task = asyncio.create_task( - server.run_guarded("second-op", second_operation()) + server.run_guarded("second-op", second_operation(), timeout_seconds=5) ) await asyncio.sleep(0) @@ -586,3 +593,133 @@ async def second_operation(): assert second_started.is_set() is False await shutdown_task + + +@pytest.fixture +def multi_index_config_path(tmp_path: Path, redis_url: str): + def factory(bindings: dict[str, dict[str, Any]]) -> str: + config = {"server": {"redis_url": redis_url}, "indexes": bindings} + config_path = tmp_path / "multi-index.yaml" + config_path.write_text(yaml.safe_dump(config), encoding="utf-8") + return str(config_path) + + return factory + + +def _binding_config(redis_name: str, *, read_only: bool = False) -> dict[str, Any]: + return { + "redis_name": redis_name, + "read_only": read_only, + "vectorizer": {"class": "FakeVectorizer", "model": "fake-model", "dims": 3}, + "search": {"type": "vector"}, + "runtime": { + "text_field_name": "content", + "vector_field_name": "embedding", + "default_embed_text_field": "content", + }, + } + + +@pytest.mark.asyncio +async def test_server_starts_with_multiple_bindings( + monkeypatch, existing_index, multi_index_config_path +): + knowledge = await existing_index(index_name="mcp-multi-knowledge") + tickets = await existing_index(index_name="mcp-multi-tickets") + monkeypatch.setattr( + "redisvl.mcp.server.resolve_vectorizer_class", + lambda class_name: FakeVectorizer, + ) + server = RedisVLMCPServer( + MCPSettings( + config=multi_index_config_path( + { + "knowledge": _binding_config(knowledge.name), + "tickets": _binding_config(tickets.name, read_only=True), + } + ) + ) + ) + + await server.startup() + + try: + assert sorted(server._bindings) == ["knowledge", "tickets"] + + knowledge_rt = server.resolve_binding("knowledge") + tickets_rt = server.resolve_binding("tickets") + + # Each binding is inspected and initialized independently. + assert knowledge_rt.index.schema.index.name == knowledge.name + assert tickets_rt.index.schema.index.name == tickets.name + assert knowledge_rt.index is not tickets_rt.index + + # Per-index write availability is respected. + assert knowledge_rt.effective_read_only is False + assert tickets_rt.effective_read_only is True + + # An omitted index is ambiguous when multiple bindings are configured. + with pytest.raises(RedisVLMCPError) as excinfo: + server.resolve_binding(None) + assert excinfo.value.code == MCPErrorCode.INVALID_REQUEST + finally: + await server.shutdown() + + +@pytest.mark.asyncio +async def test_server_global_read_only_overrides_all_bindings( + monkeypatch, existing_index, multi_index_config_path +): + knowledge = await existing_index(index_name="mcp-multi-ro-knowledge") + tickets = await existing_index(index_name="mcp-multi-ro-tickets") + monkeypatch.setattr( + "redisvl.mcp.server.resolve_vectorizer_class", + lambda class_name: FakeVectorizer, + ) + server = RedisVLMCPServer( + MCPSettings( + config=multi_index_config_path( + { + "knowledge": _binding_config(knowledge.name), + "tickets": _binding_config(tickets.name, read_only=True), + } + ), + read_only=True, + ) + ) + + await server.startup() + + try: + # Global read-only forces effective write availability false everywhere. + assert server.resolve_binding("knowledge").effective_read_only is True + assert server.resolve_binding("tickets").effective_read_only is True + finally: + await server.shutdown() + + +@pytest.mark.asyncio +async def test_server_startup_fails_when_one_binding_is_invalid( + monkeypatch, existing_index, multi_index_config_path +): + knowledge = await existing_index(index_name="mcp-multi-invalid") + monkeypatch.setattr( + "redisvl.mcp.server.resolve_vectorizer_class", + lambda class_name: FakeVectorizer, + ) + server = RedisVLMCPServer( + MCPSettings( + config=multi_index_config_path( + { + "knowledge": _binding_config(knowledge.name), + "missing": _binding_config("nonexistent-index-name"), + } + ) + ) + ) + + with pytest.raises(ValueError, match="does not exist"): + await server.startup() + + assert server._lifecycle_state.name == "STOPPED" + assert server._bindings == {} diff --git a/tests/unit/test_mcp/test_config.py b/tests/unit/test_mcp/test_config.py index 7631e368..512af828 100644 --- a/tests/unit/test_mcp/test_config.py +++ b/tests/unit/test_mcp/test_config.py @@ -90,11 +90,12 @@ def test_load_mcp_config_env_substitution(tmp_path: Path, monkeypatch): config = load_mcp_config(str(config_path)) assert config.server.redis_url == "redis://localhost:6379" - assert config.binding_id == "knowledge" - assert config.redis_name == "docs-index" - assert config.vectorizer.class_name == "FakeVectorizer" - assert config.vectorizer.model == "test-model" - assert config.vectorizer.extra_kwargs == {"api_config": {"api_key": "secret"}} + assert list(config.indexes) == ["knowledge"] + binding = config.indexes["knowledge"] + assert binding.redis_name == "docs-index" + assert binding.vectorizer.class_name == "FakeVectorizer" + assert binding.vectorizer.model == "test-model" + assert binding.vectorizer.extra_kwargs == {"api_config": {"api_key": "secret"}} def test_load_mcp_config_required_env_missing(tmp_path: Path, monkeypatch): @@ -132,24 +133,46 @@ def test_mcp_config_requires_server_redis_url(): MCPConfig.model_validate(config) -@pytest.mark.parametrize( - "indexes", - [ - {}, - { - "knowledge": deepcopy(_valid_config()["indexes"]["knowledge"]), - "other": deepcopy(_valid_config()["indexes"]["knowledge"]), - }, - ], -) -def test_mcp_config_validates_index_count(indexes): +def test_mcp_config_requires_at_least_one_binding(): config = _valid_config() - config["indexes"] = indexes + config["indexes"] = {} - with pytest.raises(ValueError, match="exactly one configured index binding"): + with pytest.raises(ValueError, match="at least one configured binding"): MCPConfig.model_validate(config) +def test_mcp_config_allows_multiple_bindings(): + config = _valid_config() + config["indexes"] = { + "knowledge": deepcopy(_valid_config()["indexes"]["knowledge"]), + "tickets": deepcopy(_valid_config()["indexes"]["knowledge"]), + } + + loaded = MCPConfig.model_validate(config) + + assert list(loaded.indexes) == ["knowledge", "tickets"] + assert loaded.indexes["tickets"].redis_name == "docs-index" + + +def test_mcp_config_binding_defaults_for_description_and_read_only(): + config = MCPConfig.model_validate(_valid_config()) + + binding = config.indexes["knowledge"] + assert binding.description is None + assert binding.read_only is False + + +def test_mcp_config_binding_accepts_description_and_read_only(): + config = _valid_config() + config["indexes"]["knowledge"]["description"] = "Product docs and runbooks" + config["indexes"]["knowledge"]["read_only"] = True + + binding = MCPConfig.model_validate(config).indexes["knowledge"] + + assert binding.description == "Product docs and runbooks" + assert binding.read_only is True + + def test_mcp_config_rejects_blank_binding_id(): config = _valid_config() config["indexes"] = {"": deepcopy(config["indexes"]["knowledge"])} @@ -166,25 +189,24 @@ def test_mcp_config_rejects_blank_redis_name(): MCPConfig.model_validate(config) -def test_mcp_config_binding_helpers(): +def test_mcp_config_binding_exposes_index_settings(): config = MCPConfig.model_validate(_valid_config()) - assert config.binding_id == "knowledge" - assert config.binding.redis_name == "docs-index" - assert config.binding.search.type == "vector" - assert config.runtime.default_embed_text_field == "content" - assert config.vectorizer.class_name == "FakeVectorizer" - assert config.redis_name == "docs-index" + binding = config.indexes["knowledge"] + assert binding.redis_name == "docs-index" + assert binding.search.type == "vector" + assert binding.runtime.default_embed_text_field == "content" + assert binding.vectorizer.class_name == "FakeVectorizer" def test_vector_search_config_can_omit_text_field_name(): config = _valid_config() del config["indexes"]["knowledge"]["runtime"]["text_field_name"] - loaded = MCPConfig.model_validate(config) + binding = MCPConfig.model_validate(config).indexes["knowledge"] - assert loaded.search.type == "vector" - assert loaded.runtime.text_field_name is None + assert binding.search.type == "vector" + assert binding.runtime.text_field_name is None def test_fulltext_config_can_omit_vector_settings_and_vectorizer(): @@ -194,12 +216,12 @@ def test_fulltext_config_can_omit_vector_settings_and_vectorizer(): del config["indexes"]["knowledge"]["runtime"]["vector_field_name"] del config["indexes"]["knowledge"]["runtime"]["default_embed_text_field"] - loaded = MCPConfig.model_validate(config) + binding = MCPConfig.model_validate(config).indexes["knowledge"] - assert loaded.search.type == "fulltext" - assert loaded.vectorizer is None - assert loaded.runtime.vector_field_name is None - assert loaded.runtime.default_embed_text_field is None + assert binding.search.type == "fulltext" + assert binding.vectorizer is None + assert binding.runtime.vector_field_name is None + assert binding.runtime.default_embed_text_field is None def test_mcp_config_merges_schema_overrides_into_inspection_result(): @@ -221,7 +243,7 @@ def test_mcp_config_merges_schema_overrides_into_inspection_result(): inspected["fields"][1]["attrs"] = {"algorithm": "flat"} config = MCPConfig.model_validate(config_dict) - schema = config.to_index_schema(inspected) + schema = config.indexes["knowledge"].to_index_schema(inspected) assert isinstance(schema, IndexSchema) assert schema.index.name == "docs-index" @@ -237,7 +259,7 @@ def test_mcp_config_rejects_override_for_unknown_field(): config = MCPConfig.model_validate(config_dict) with pytest.raises(ValueError, match="schema_overrides.fields.*missing"): - config.to_index_schema(_inspected_schema()) + config.indexes["knowledge"].to_index_schema(_inspected_schema()) def test_mcp_config_rejects_override_type_conflict(): @@ -248,7 +270,7 @@ def test_mcp_config_rejects_override_type_conflict(): config = MCPConfig.model_validate(config_dict) with pytest.raises(ValueError, match="cannot change discovered field type"): - config.to_index_schema(_inspected_schema()) + config.indexes["knowledge"].to_index_schema(_inspected_schema()) def test_mcp_config_rejects_override_path_conflict(): @@ -280,7 +302,7 @@ def test_mcp_config_rejects_override_path_conflict(): config = MCPConfig.model_validate(config_dict) with pytest.raises(ValueError, match="cannot change discovered field path"): - config.to_index_schema(inspected) + config.indexes["knowledge"].to_index_schema(inspected) def test_mcp_config_validates_runtime_mapping_against_effective_schema(): @@ -289,7 +311,7 @@ def test_mcp_config_validates_runtime_mapping_against_effective_schema(): config = MCPConfig.model_validate(config_dict) with pytest.raises(ValueError, match="runtime.vector_field_name"): - config.to_index_schema(_inspected_schema()) + config.indexes["knowledge"].to_index_schema(_inspected_schema()) def test_fulltext_config_does_not_require_vector_mapping_in_schema(): @@ -300,12 +322,12 @@ def test_fulltext_config_does_not_require_vector_mapping_in_schema(): del config_dict["indexes"]["knowledge"]["runtime"]["default_embed_text_field"] config = MCPConfig.model_validate(config_dict) - schema = config.to_index_schema(_inspected_schema()) + schema = config.indexes["knowledge"].to_index_schema(_inspected_schema()) assert isinstance(schema, IndexSchema) -def test_load_mcp_config_requires_exactly_one_binding(tmp_path: Path): +def test_load_mcp_config_requires_at_least_one_binding(tmp_path: Path): config_path = tmp_path / "mcp.yaml" config_path.write_text( yaml.safe_dump( @@ -317,7 +339,7 @@ def test_load_mcp_config_requires_exactly_one_binding(tmp_path: Path): encoding="utf-8", ) - with pytest.raises(ValueError, match="exactly one configured index binding"): + with pytest.raises(ValueError, match="at least one configured binding"): load_mcp_config(str(config_path)) @@ -328,8 +350,8 @@ def test_mcp_config_accepts_search_types(search_type): loaded = MCPConfig.model_validate(config) - assert loaded.binding.search.type == search_type - assert loaded.binding.search.params == {} + assert loaded.indexes["knowledge"].search.type == search_type + assert loaded.indexes["knowledge"].search.params == {} def test_mcp_config_requires_search_type(): @@ -441,8 +463,8 @@ def test_mcp_config_normalizes_hybrid_linear_text_weight(): loaded = MCPConfig.model_validate(config) - assert loaded.binding.search.type == "hybrid" - assert loaded.binding.search.params["linear_text_weight"] == 0.3 + assert loaded.indexes["knowledge"].search.type == "hybrid" + assert loaded.indexes["knowledge"].search.params["linear_text_weight"] == 0.3 def test_mcp_config_allows_linear_text_weight_without_explicit_combination_method(): @@ -456,8 +478,8 @@ def test_mcp_config_allows_linear_text_weight_without_explicit_combination_metho loaded = MCPConfig.model_validate(config) - assert loaded.binding.search.type == "hybrid" - assert loaded.binding.search.params["linear_text_weight"] == 0.3 + assert loaded.indexes["knowledge"].search.type == "hybrid" + assert loaded.indexes["knowledge"].search.params["linear_text_weight"] == 0.3 @pytest.mark.parametrize( @@ -476,10 +498,10 @@ def test_mcp_config_rejects_native_only_hybrid_runtime_params(params): } loaded = MCPConfig.model_validate(config) - schema = loaded.to_index_schema(_inspected_schema()) + schema = loaded.indexes["knowledge"].to_index_schema(_inspected_schema()) with pytest.raises(ValueError, match="native hybrid search support"): - loaded.validate_search( + loaded.indexes["knowledge"].validate_search( schema=schema, supports_native_hybrid_search=False, ) @@ -497,9 +519,9 @@ def test_mcp_config_allows_linear_hybrid_fallback_params(): } loaded = MCPConfig.model_validate(config) - schema = loaded.to_index_schema(_inspected_schema()) + schema = loaded.indexes["knowledge"].to_index_schema(_inspected_schema()) - loaded.validate_search( + loaded.indexes["knowledge"].validate_search( schema=schema, supports_native_hybrid_search=False, ) diff --git a/tests/unit/test_mcp/test_search_tool_unit.py b/tests/unit/test_mcp/test_search_tool_unit.py index d5a92244..830a0200 100644 --- a/tests/unit/test_mcp/test_search_tool_unit.py +++ b/tests/unit/test_mcp/test_search_tool_unit.py @@ -5,6 +5,7 @@ from redisvl.mcp.config import MCPConfig from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.runtime import BindingRuntime from redisvl.mcp.tools.search import ( _build_fallback_hybrid_kwargs, _build_search_tool_description, @@ -111,6 +112,17 @@ def __init__( self.registered_tools = [] self.native_hybrid_supported = False + def resolve_binding(self, index_id=None): + return BindingRuntime( + binding_id="knowledge", + binding=self.config.indexes["knowledge"], + index=self.index, + schema=self.index.schema, + vectorizer=self.vectorizer, + supports_native_hybrid_search=self.native_hybrid_supported, + effective_read_only=False, + ) + async def get_index(self): return self.index @@ -119,7 +131,7 @@ async def get_vectorizer(self): raise RuntimeError("MCP server vectorizer is not configured") return self.vectorizer - async def run_guarded(self, operation_name, awaitable): + async def run_guarded(self, operation_name, awaitable, *, timeout_seconds=None): return await awaitable async def supports_native_hybrid_search(self): @@ -652,7 +664,7 @@ async def test_validate_search_rejects_reserved_score_metadata_field_names( ) with pytest.raises(ValueError, match="MCP-reserved score metadata names"): - config.validate_search( + config.indexes["knowledge"].validate_search( schema=schema, supports_native_hybrid_search=supports_native, ) @@ -701,7 +713,7 @@ async def test_search_records_rejects_native_only_hybrid_runtime_params(monkeypa ) with pytest.raises(ValueError, match="native hybrid search support"): - server.config.validate_search( + server.config.indexes["knowledge"].validate_search( schema=_schema(), supports_native_hybrid_search=False, ) diff --git a/tests/unit/test_mcp/test_server.py b/tests/unit/test_mcp/test_server.py index e88b8a38..22d5239e 100644 --- a/tests/unit/test_mcp/test_server.py +++ b/tests/unit/test_mcp/test_server.py @@ -37,12 +37,18 @@ def _startup_schema() -> IndexSchema: ) -def _startup_config(): +def _binding_namespace( + *, requires_startup_vectorizer: bool = True, max_concurrency: int = 1 +): return SimpleNamespace( - runtime=SimpleNamespace(max_concurrency=1, startup_timeout_seconds=1), - server=SimpleNamespace(redis_url="redis://localhost:6379"), redis_name="idx", - requires_startup_vectorizer=True, + read_only=False, + requires_startup_vectorizer=requires_startup_vectorizer, + runtime=SimpleNamespace( + max_concurrency=max_concurrency, + startup_timeout_seconds=1, + request_timeout_seconds=1, + ), vectorizer=SimpleNamespace( class_name="FakeVectorizer", to_init_kwargs=lambda: {}, @@ -51,6 +57,22 @@ def _startup_config(): ) +def _startup_config(indexes=None): + return SimpleNamespace( + server=SimpleNamespace(redis_url="redis://localhost:6379"), + indexes=indexes or {"knowledge": _binding_namespace()}, + ) + + +def _patch_probe(monkeypatch, value: bool = False): + async def fake_probe(index): + return value + + monkeypatch.setattr( + RedisVLMCPServer, "_probe_native_hybrid_search", staticmethod(fake_probe) + ) + + @pytest.mark.asyncio async def test_server_registers_fastmcp_lifespan(monkeypatch): captured = {} @@ -97,15 +119,12 @@ async def test_run_guarded_rejects_before_startup(): future.set_result(None) with pytest.raises(RuntimeError, match="not running"): - await server.run_guarded("test", future) + await server.run_guarded("test", future, timeout_seconds=1) @pytest.mark.asyncio async def test_run_guarded_rejects_after_shutdown(): server = RedisVLMCPServer(_dummy_settings()) - server.config = SimpleNamespace( - runtime=SimpleNamespace(request_timeout_seconds=1, max_concurrency=1) - ) server._semaphore = asyncio.Semaphore(1) await server.shutdown() @@ -113,7 +132,66 @@ async def test_run_guarded_rejects_after_shutdown(): future = asyncio.get_running_loop().create_future() future.set_result(None) with pytest.raises(RuntimeError, match="not running"): - await server.run_guarded("test", future) + await server.run_guarded("test", future, timeout_seconds=1) + + +@pytest.mark.asyncio +async def test_run_guarded_uses_per_binding_timeout(monkeypatch): + server = RedisVLMCPServer(_dummy_settings()) + server._semaphore = asyncio.Semaphore(1) + server.config = _startup_config() + server._lifecycle_state = server._LifecycleState.RUNNING + + captured = {} + + async def fake_wait_for(awaitable, timeout): + captured["timeout"] = timeout + return await awaitable + + monkeypatch.setattr("redisvl.mcp.server.asyncio.wait_for", fake_wait_for) + + future = asyncio.get_running_loop().create_future() + future.set_result("ok") + + result = await server.run_guarded("test", future, timeout_seconds=42) + + assert result == "ok" + assert captured["timeout"] == 42 + + +@pytest.mark.asyncio +async def test_startup_sizes_semaphore_from_max_binding_concurrency(monkeypatch): + monkeypatch.setattr( + "redisvl.mcp.server.FastMCP.__init__", lambda self, *a, **k: None + ) + config = _startup_config( + indexes={ + "knowledge": _binding_namespace(max_concurrency=4), + "tickets": _binding_namespace(max_concurrency=9), + } + ) + monkeypatch.setattr("redisvl.mcp.server.load_mcp_config", lambda path: config) + + captured = {} + + def fake_semaphore(value): + captured["value"] = value + return SimpleNamespace(value=value) + + monkeypatch.setattr("redisvl.mcp.server.asyncio.Semaphore", fake_semaphore) + + async def fake_initialize_binding(self, binding_id, binding): + return SimpleNamespace(binding_id=binding_id) + + monkeypatch.setattr( + RedisVLMCPServer, "_initialize_binding", fake_initialize_binding + ) + monkeypatch.setattr(RedisVLMCPServer, "_register_tools", lambda self: None) + + server = RedisVLMCPServer(_dummy_settings()) + await server._initialize_runtime_resources() + + assert captured["value"] == 9 @pytest.mark.asyncio @@ -123,11 +201,7 @@ async def test_startup_failure_leaves_server_stopped(monkeypatch): ) monkeypatch.setattr( "redisvl.mcp.server.load_mcp_config", - lambda path: SimpleNamespace( - runtime=SimpleNamespace(max_concurrency=1, startup_timeout_seconds=1), - server=SimpleNamespace(redis_url="redis://localhost:6379"), - redis_name="idx", - ), + lambda path: _startup_config(), ) async def fail_connection(**kwargs): @@ -146,6 +220,7 @@ async def fail_connection(**kwargs): assert server._lifecycle_state.name == "STOPPED" assert server.config is None assert server._semaphore is None + assert server._bindings == {} @pytest.mark.asyncio @@ -170,7 +245,7 @@ async def aclose(self): async def fake_connect(self, timeout): return client - async def fail_load_schema(self, client, timeout): + async def fail_load_schema(self, binding, client, timeout): raise RuntimeError("schema load failed") monkeypatch.setattr(RedisVLMCPServer, "_connect_redis_client", fake_connect) @@ -185,7 +260,7 @@ async def fail_load_schema(self, client, timeout): assert server._lifecycle_state.name == "STOPPED" assert server.config is None assert server._semaphore is None - assert server._index is None + assert server._bindings == {} @pytest.mark.asyncio @@ -213,13 +288,10 @@ async def aclose(self): async def fake_connect(self, timeout): return client - async def fake_load_schema(self, client, timeout): + async def fake_load_schema(self, binding, client, timeout): return _startup_schema() - async def fake_supports_native_hybrid_search(self): - return False - - async def fail_vectorizer(self, schema, timeout): + async def fail_vectorizer(self, binding, schema, timeout): raise RuntimeError("vectorizer init failed") async def fake_disconnect(self): @@ -227,11 +299,7 @@ async def fake_disconnect(self): monkeypatch.setattr(RedisVLMCPServer, "_connect_redis_client", fake_connect) monkeypatch.setattr(RedisVLMCPServer, "_load_effective_schema", fake_load_schema) - monkeypatch.setattr( - RedisVLMCPServer, - "supports_native_hybrid_search", - fake_supports_native_hybrid_search, - ) + _patch_probe(monkeypatch, value=False) monkeypatch.setattr(RedisVLMCPServer, "_initialize_vectorizer", fail_vectorizer) monkeypatch.setattr( "redisvl.mcp.server.AsyncSearchIndex.disconnect", @@ -249,7 +317,7 @@ async def fake_disconnect(self): assert server._lifecycle_state.name == "STOPPED" assert server.config is None assert server._semaphore is None - assert server._index is None + assert server._bindings == {} @pytest.mark.asyncio @@ -269,7 +337,7 @@ async def aclose(self): async def fake_connect(self, timeout): return FakeClient() - async def fake_load_schema(self, client, timeout): + async def fake_load_schema(self, binding, client, timeout): return IndexSchema.from_dict( { "index": { @@ -295,11 +363,8 @@ async def fake_load_schema(self, client, timeout): } ) - async def fake_supports_native_hybrid_search(self): - return False - - async def fake_initialize_vectorizer(self, schema, timeout): - self._vectorizer = SimpleNamespace(dims=3) + async def fake_initialize_vectorizer(self, binding, schema, timeout): + return SimpleNamespace(dims=3) registered_schemas = [] @@ -311,11 +376,7 @@ async def fake_disconnect(self): monkeypatch.setattr(RedisVLMCPServer, "_connect_redis_client", fake_connect) monkeypatch.setattr(RedisVLMCPServer, "_load_effective_schema", fake_load_schema) - monkeypatch.setattr( - RedisVLMCPServer, - "supports_native_hybrid_search", - fake_supports_native_hybrid_search, - ) + _patch_probe(monkeypatch, value=False) monkeypatch.setattr( RedisVLMCPServer, "_initialize_vectorizer", fake_initialize_vectorizer ) diff --git a/tests/unit/test_mcp/test_server_unit.py b/tests/unit/test_mcp/test_server_unit.py index 13ba23d5..9e8af87a 100644 --- a/tests/unit/test_mcp/test_server_unit.py +++ b/tests/unit/test_mcp/test_server_unit.py @@ -2,6 +2,8 @@ import pytest +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.runtime import BindingRuntime from redisvl.mcp.server import RedisVLMCPServer @@ -29,14 +31,82 @@ async def _get_client(self): @pytest.mark.asyncio -async def test_supports_native_hybrid_search_caches_runtime_probe(monkeypatch): +async def test_probe_native_hybrid_search_detects_support(monkeypatch): client = FakeClient() - server = RedisVLMCPServer.__new__(RedisVLMCPServer) - server._index = FakeIndex(client) - server._supports_native_hybrid_search = None + index = FakeIndex(client) monkeypatch.setattr("redisvl.mcp.server.redis_py_version", "7.1.0") - assert await server.supports_native_hybrid_search() is True - assert await server.supports_native_hybrid_search() is True + assert await RedisVLMCPServer._probe_native_hybrid_search(index) is True assert client.info_calls == 1 + + +@pytest.mark.asyncio +async def test_probe_native_hybrid_search_false_for_old_redis_py(monkeypatch): + client = FakeClient() + index = FakeIndex(client) + + monkeypatch.setattr("redisvl.mcp.server.redis_py_version", "7.0.0") + + assert await RedisVLMCPServer._probe_native_hybrid_search(index) is False + # Old redis-py short-circuits before querying the server. + assert client.info_calls == 0 + + +def _binding_runtime(binding_id: str) -> BindingRuntime: + return BindingRuntime( + binding_id=binding_id, + binding=SimpleNamespace(), + index=SimpleNamespace(), + schema=SimpleNamespace(), + vectorizer=None, + supports_native_hybrid_search=False, + effective_read_only=False, + ) + + +def _server_with_bindings(*binding_ids: str) -> RedisVLMCPServer: + server = RedisVLMCPServer.__new__(RedisVLMCPServer) + server._bindings = {bid: _binding_runtime(bid) for bid in binding_ids} + return server + + +def test_resolve_binding_before_startup_raises(): + server = RedisVLMCPServer.__new__(RedisVLMCPServer) + server._bindings = {} + + with pytest.raises(RuntimeError, match="not been started"): + server.resolve_binding(None) + + +def test_resolve_binding_defaults_to_sole_binding(): + server = _server_with_bindings("knowledge") + + assert server.resolve_binding(None).binding_id == "knowledge" + + +def test_resolve_binding_requires_index_when_multiple_configured(): + server = _server_with_bindings("knowledge", "tickets") + + with pytest.raises(RedisVLMCPError) as excinfo: + server.resolve_binding(None) + + assert excinfo.value.code == MCPErrorCode.INVALID_REQUEST + assert "knowledge" in str(excinfo.value) + assert "tickets" in str(excinfo.value) + + +def test_resolve_binding_routes_to_named_index(): + server = _server_with_bindings("knowledge", "tickets") + + assert server.resolve_binding("tickets").binding_id == "tickets" + + +def test_resolve_binding_rejects_unknown_index(): + server = _server_with_bindings("knowledge", "tickets") + + with pytest.raises(RedisVLMCPError) as excinfo: + server.resolve_binding("missing") + + assert excinfo.value.code == MCPErrorCode.INVALID_REQUEST + assert "missing" in str(excinfo.value) diff --git a/tests/unit/test_mcp/test_upsert_tool_unit.py b/tests/unit/test_mcp/test_upsert_tool_unit.py index 15d3d5ca..d47046ce 100644 --- a/tests/unit/test_mcp/test_upsert_tool_unit.py +++ b/tests/unit/test_mcp/test_upsert_tool_unit.py @@ -6,6 +6,7 @@ from redisvl.mcp.config import MCPConfig from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.runtime import BindingRuntime from redisvl.mcp.tools.upsert import register_upsert_tool, upsert_records from redisvl.redis.utils import array_to_buffer from redisvl.schema import IndexSchema @@ -164,6 +165,17 @@ def __init__( self.vectorizer = vectorizer or FakeVectorizer() if include_vectorizer else None self.registered_tools = [] + def resolve_binding(self, index_id=None): + return BindingRuntime( + binding_id="knowledge", + binding=self.config.indexes["knowledge"], + index=self.index, + schema=self.index.schema, + vectorizer=self.vectorizer, + supports_native_hybrid_search=False, + effective_read_only=False, + ) + async def get_index(self): return self.index @@ -172,7 +184,9 @@ async def get_vectorizer(self): raise RuntimeError("MCP server vectorizer is not configured") return self.vectorizer - async def run_guarded(self, operation_name: str, awaitable: Any): + async def run_guarded( + self, operation_name: str, awaitable: Any, *, timeout_seconds=None + ): return await awaitable def tool(self, name=None, description=None, **kwargs):