From 875003120e1be298c4cc46bcdde795fdbe58db9d Mon Sep 17 00:00:00 2001 From: lutz-grex Date: Wed, 15 Apr 2026 11:50:54 +0200 Subject: [PATCH 1/3] feat(router): support runtime disabling of tools Add methods to disable/enable tools at runtime. Disabled tools are hidden from listing, lookup, and execution, including in composed routers. Closes #477 --- crates/rmcp/src/handler/server/router.rs | 1 + crates/rmcp/src/handler/server/router/tool.rs | 68 ++++++++- crates/rmcp/tests/test_tool_routers.rs | 130 ++++++++++++++++++ 3 files changed, 194 insertions(+), 5 deletions(-) diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs index 08beb61d2..aeca04cbe 100644 --- a/crates/rmcp/src/handler/server/router.rs +++ b/crates/rmcp/src/handler/server/router.rs @@ -84,6 +84,7 @@ where match request { ClientRequest::CallToolRequest(request) => { if self.tool_router.has_route(request.params.name.as_ref()) + || self.tool_router.is_disabled(request.params.name.as_ref()) || !self.tool_router.transparent_when_not_found { let tool_call_context = crate::handler::server::tool::ToolCallContext::new( diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index 79a228ffe..e6d861a45 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -305,6 +305,8 @@ pub struct ToolRouter { pub map: std::collections::HashMap, ToolRoute>, pub transparent_when_not_found: bool, + + disabled: std::collections::HashSet>, } impl Default for ToolRouter { @@ -312,6 +314,7 @@ impl Default for ToolRouter { Self { map: std::collections::HashMap::new(), transparent_when_not_found: false, + disabled: std::collections::HashSet::new(), } } } @@ -320,6 +323,7 @@ impl Clone for ToolRouter { Self { map: self.map.clone(), transparent_when_not_found: self.transparent_when_not_found, + disabled: self.disabled.clone(), } } } @@ -329,7 +333,11 @@ impl IntoIterator for ToolRouter { type IntoIter = std::collections::hash_map::IntoValues, ToolRoute>; fn into_iter(self) -> Self::IntoIter { - self.map.into_values() + let mut map = self.map; + for name in &self.disabled { + map.remove(name); + } + map.into_values() } } @@ -341,6 +349,7 @@ where Self { map: std::collections::HashMap::new(), transparent_when_not_found: false, + disabled: std::collections::HashSet::new(), } } pub fn with_route(mut self, route: R) -> Self @@ -394,6 +403,7 @@ where } pub fn merge(&mut self, other: ToolRouter) { + self.disabled.extend(other.disabled); for item in other.map.into_values() { self.add_route(item); } @@ -401,17 +411,56 @@ where pub fn remove_route(&mut self, name: &str) { self.map.remove(name); + self.disabled.remove(name); } + pub fn has_route(&self, name: &str) -> bool { - self.map.contains_key(name) + self.map.contains_key(name) && !self.disabled.contains(name) + } + + /// Disable a tool by name so it is hidden from `list_all`, `get`, and + /// rejected by `call`. The tool remains in the router and can be + /// re-enabled later with [`enable_route`](Self::enable_route). + /// + /// The name is recorded even if no matching route exists yet, so routes + /// added later (via [`add_route`](Self::add_route) or + /// [`merge`](Self::merge)) will inherit the disabled state. + pub fn disable_route(&mut self, name: &str) { + self.disabled.insert(Cow::Owned(name.to_owned())); + } + + /// Re-enable a previously disabled tool. + pub fn enable_route(&mut self, name: &str) { + self.disabled.remove(name); + } + + /// Returns `true` if the tool exists in the router but is currently + /// disabled. + pub fn is_disabled(&self, name: &str) -> bool { + self.map.contains_key(name) && self.disabled.contains(name) + } + + /// Builder-style variant of [`disable_route`](Self::disable_route). + /// + /// The name is recorded even if no matching route has been added yet, + /// so it can be called before [`with_route`](Self::with_route) in a + /// builder chain. + pub fn with_disabled(mut self, name: impl Into>) -> Self { + self.disabled.insert(name.into()); + self } + pub async fn call( &self, context: ToolCallContext<'_, S>, ) -> Result { + let name = context.name(); + if self.disabled.contains(name) { + return Err(crate::ErrorData::invalid_params("tool not found", None)); + } let item = self .map - .get(context.name()) + .get(name) .ok_or_else(|| crate::ErrorData::invalid_params("tool not found", None))?; let result = (item.call)(context).await?; @@ -420,15 +469,24 @@ where } pub fn list_all(&self) -> Vec { - let mut tools: Vec<_> = self.map.values().map(|item| item.attr.clone()).collect(); + let mut tools: Vec<_> = self + .map + .values() + .filter(|item| !self.disabled.contains(&item.attr.name)) + .map(|item| item.attr.clone()) + .collect(); tools.sort_by(|a, b| a.name.cmp(&b.name)); tools } /// Get a tool definition by name. /// - /// Returns the tool if found, or `None` if no tool with the given name exists. + /// Returns the tool if found and enabled, or `None` if the tool does not + /// exist or is disabled. pub fn get(&self, name: &str) -> Option<&crate::model::Tool> { + if self.disabled.contains(name) { + return None; + } self.map.get(name).map(|r| &r.attr) } } diff --git a/crates/rmcp/tests/test_tool_routers.rs b/crates/rmcp/tests/test_tool_routers.rs index c10665064..d60aa5d32 100644 --- a/crates/rmcp/tests/test_tool_routers.rs +++ b/crates/rmcp/tests/test_tool_routers.rs @@ -84,3 +84,133 @@ fn test_tool_router_list_all_is_sorted() { "list_all() should return tools sorted alphabetically by name" ); } + +fn build_router() -> ToolRouter> { + ToolRouter::>::new() + .with_route((async_function_tool_attr(), async_function)) + .with_route((async_function2_tool_attr(), async_function2)) + + TestHandler::<()>::test_router_1() + + TestHandler::<()>::test_router_2() +} + +#[test] +fn test_disable_route() { + let mut router = build_router(); + assert_eq!(router.list_all().len(), 4); + assert!(router.has_route("async_function")); + assert!(router.get("async_function").is_some()); + + router.disable_route("async_function"); + + assert_eq!(router.list_all().len(), 3); + assert!(!router.has_route("async_function")); + assert!(router.get("async_function").is_none()); + assert!(router.is_disabled("async_function")); + + // other tools unaffected + assert!(router.has_route("async_function2")); + assert!(router.get("async_function2").is_some()); + assert!(!router.is_disabled("async_function2")); +} + +#[test] +fn test_enable_route() { + let mut router = build_router(); + router.disable_route("async_function"); + assert!(!router.has_route("async_function")); + + router.enable_route("async_function"); + assert!(router.has_route("async_function")); + assert!(router.get("async_function").is_some()); + assert!(!router.is_disabled("async_function")); + assert_eq!(router.list_all().len(), 4); +} + +#[test] +fn test_with_disabled_builder() { + let router = build_router() + .with_disabled("async_function") + .with_disabled("sync_method"); + + assert_eq!(router.list_all().len(), 2); + assert!(!router.has_route("async_function")); + assert!(!router.has_route("sync_method")); + assert!(router.has_route("async_function2")); + assert!(router.has_route("async_method")); +} + +#[test] +fn test_disabled_tools_survive_merge() { + let mut router_a = ToolRouter::>::new() + .with_route((async_function_tool_attr(), async_function)); + router_a.disable_route("async_function"); + + let router_b = ToolRouter::>::new() + .with_route((async_function2_tool_attr(), async_function2)); + + router_a.merge(router_b); + + assert_eq!(router_a.list_all().len(), 1); + assert!(router_a.is_disabled("async_function")); + assert!(router_a.has_route("async_function2")); +} + +#[test] +fn test_disable_nonexistent_tool() { + let mut router = build_router(); + // should not panic + router.disable_route("does_not_exist"); + assert_eq!(router.list_all().len(), 4); + // is_disabled returns false for tools not in the map + assert!(!router.is_disabled("does_not_exist")); +} + +#[test] +fn test_remove_route_clears_disabled_state() { + let mut router = build_router(); + router.disable_route("async_function"); + assert!(router.is_disabled("async_function")); + + router.remove_route("async_function"); + assert!(!router.is_disabled("async_function")); + assert!(!router.has_route("async_function")); +} + +#[test] +fn test_into_iter_skips_disabled() { + let router = build_router().with_disabled("async_function"); + let names: Vec<_> = router + .into_iter() + .map(|r| r.attr.name.to_string()) + .collect(); + assert_eq!(names.len(), 3); + assert!(!names.contains(&"async_function".to_string())); +} + +#[test] +fn test_pre_disable_before_add_route() { + // Disabling a name before adding a route with that name should + // result in the route being disabled once added. + let router = ToolRouter::>::new() + .with_disabled("async_function") + .with_route((async_function_tool_attr(), async_function)); + + assert_eq!(router.list_all().len(), 0); + assert!(router.is_disabled("async_function")); + assert!(!router.has_route("async_function")); +} + +#[test] +fn test_disabled_tool_invisible_across_all_queries() { + let router = build_router().with_disabled("async_function"); + + // Not listed + let names: Vec<_> = router.list_all().iter().map(|t| t.name.clone()).collect(); + assert!(!names.contains(&"async_function".into())); + // Not retrievable + assert!(router.get("async_function").is_none()); + // Not routable + assert!(!router.has_route("async_function")); + // But still known as disabled + assert!(router.is_disabled("async_function")); +} From 16f8b8c7808569ba566dd3f875d644eaa82ec2f1 Mon Sep 17 00:00:00 2001 From: lutz-grex Date: Fri, 17 Apr 2026 15:24:12 +0200 Subject: [PATCH 2/3] fix(router): simplify disable tool api --- crates/rmcp/src/handler/server/router.rs | 6 +- crates/rmcp/src/handler/server/router/tool.rs | 80 +++++++++++++++++-- crates/rmcp/tests/test_tool_routers.rs | 74 ++++++++++++++--- 3 files changed, 141 insertions(+), 19 deletions(-) diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs index aeca04cbe..19ed8a90c 100644 --- a/crates/rmcp/src/handler/server/router.rs +++ b/crates/rmcp/src/handler/server/router.rs @@ -83,8 +83,10 @@ where ) -> Result<::Resp, crate::ErrorData> { match request { ClientRequest::CallToolRequest(request) => { - if self.tool_router.has_route(request.params.name.as_ref()) - || self.tool_router.is_disabled(request.params.name.as_ref()) + if self + .tool_router + .map + .contains_key(request.params.name.as_ref()) || !self.tool_router.transparent_when_not_found { let tool_call_context = crate::handler::server::tool::ToolCallContext::new( diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index e6d861a45..b98982f49 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -409,11 +409,19 @@ where } } + /// Remove a tool route from the router. + /// + /// The disabled state is **preserved**: if the name was in the disabled + /// set, it stays there so that a future [`add_route`](Self::add_route) + /// or [`merge`](Self::merge) with the same name will inherit the + /// disabled state. To also clear the disabled marker, call + /// [`enable_route`](Self::enable_route) afterwards. pub fn remove_route(&mut self, name: &str) { self.map.remove(name); - self.disabled.remove(name); } + /// Returns `true` if the tool is registered **and** not currently + /// disabled. pub fn has_route(&self, name: &str) -> bool { self.map.contains_key(name) && !self.disabled.contains(name) } @@ -422,20 +430,30 @@ where /// rejected by `call`. The tool remains in the router and can be /// re-enabled later with [`enable_route`](Self::enable_route). /// + /// Returns `true` if the name was newly added to the disabled set. /// The name is recorded even if no matching route exists yet, so routes /// added later (via [`add_route`](Self::add_route) or /// [`merge`](Self::merge)) will inherit the disabled state. - pub fn disable_route(&mut self, name: &str) { - self.disabled.insert(Cow::Owned(name.to_owned())); + /// + /// Callers should send `Peer::notify_tool_list_changed` when the + /// visible tool list changes. Accepts `&'static str` or `String`; + /// for a non-static `&str`, call `.to_owned()` first. + pub fn disable_route(&mut self, name: impl Into>) -> bool { + self.disabled.insert(name.into()) } - /// Re-enable a previously disabled tool. - pub fn enable_route(&mut self, name: &str) { - self.disabled.remove(name); + /// Re-enable a previously disabled tool. Returns `true` if the name + /// was present in the disabled set and was removed. + /// + /// Callers should send `Peer::notify_tool_list_changed` when the + /// visible tool list changes. + pub fn enable_route(&mut self, name: &str) -> bool { + self.disabled.remove(name) } - /// Returns `true` if the tool exists in the router but is currently - /// disabled. + /// Returns `true` if the tool exists in the router **and** is currently + /// disabled. Returns `false` if the tool does not exist or if the name + /// was pre-disabled without a matching route. pub fn is_disabled(&self, name: &str) -> bool { self.map.contains_key(name) && self.disabled.contains(name) } @@ -511,3 +529,49 @@ where self.merge(other); } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::{ + RoleServer, + model::{CallToolRequestParams, ErrorCode, NumberOrString}, + service::{AtomicU32RequestIdProvider, Peer, RequestContext}, + }; + + struct DummyService; + impl crate::handler::server::ServerHandler for DummyService {} + + #[tokio::test] + async fn test_call_disabled_tool_returns_error() { + let service = DummyService; + let mut router = ToolRouter::new().with_route(ToolRoute::new_dyn( + crate::model::Tool::new("test_tool", "a test tool", Arc::new(Default::default())), + |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), + )); + router.disable_route("test_tool"); + + let id_provider: Arc = + Arc::new(AtomicU32RequestIdProvider::default()); + let (peer, _rx) = Peer::::new(id_provider, None); + let ctx = crate::handler::server::tool::ToolCallContext::new( + &service, + CallToolRequestParams { + meta: None, + name: Cow::Borrowed("test_tool"), + arguments: None, + task: None, + }, + RequestContext::new(NumberOrString::Number(1), peer), + ); + + let err = router + .call(ctx) + .await + .expect_err("disabled tool should reject"); + assert_eq!(err.code, ErrorCode::INVALID_PARAMS); + assert_eq!(err.message, "tool not found"); + } +} diff --git a/crates/rmcp/tests/test_tool_routers.rs b/crates/rmcp/tests/test_tool_routers.rs index d60aa5d32..0bb50e43d 100644 --- a/crates/rmcp/tests/test_tool_routers.rs +++ b/crates/rmcp/tests/test_tool_routers.rs @@ -100,7 +100,7 @@ fn test_disable_route() { assert!(router.has_route("async_function")); assert!(router.get("async_function").is_some()); - router.disable_route("async_function"); + assert!(router.disable_route("async_function")); assert_eq!(router.list_all().len(), 3); assert!(!router.has_route("async_function")); @@ -116,10 +116,10 @@ fn test_disable_route() { #[test] fn test_enable_route() { let mut router = build_router(); - router.disable_route("async_function"); + assert!(router.disable_route("async_function")); assert!(!router.has_route("async_function")); - router.enable_route("async_function"); + assert!(router.enable_route("async_function")); assert!(router.has_route("async_function")); assert!(router.get("async_function").is_some()); assert!(!router.is_disabled("async_function")); @@ -143,7 +143,7 @@ fn test_with_disabled_builder() { fn test_disabled_tools_survive_merge() { let mut router_a = ToolRouter::>::new() .with_route((async_function_tool_attr(), async_function)); - router_a.disable_route("async_function"); + assert!(router_a.disable_route("async_function")); let router_b = ToolRouter::>::new() .with_route((async_function2_tool_attr(), async_function2)); @@ -158,22 +158,42 @@ fn test_disabled_tools_survive_merge() { #[test] fn test_disable_nonexistent_tool() { let mut router = build_router(); - // should not panic - router.disable_route("does_not_exist"); + // should not panic; returns true because the name is newly added to disabled set + assert!(router.disable_route("does_not_exist")); assert_eq!(router.list_all().len(), 4); // is_disabled returns false for tools not in the map assert!(!router.is_disabled("does_not_exist")); } #[test] -fn test_remove_route_clears_disabled_state() { +fn test_remove_route_preserves_disabled_state() { let mut router = build_router(); - router.disable_route("async_function"); + assert!(router.disable_route("async_function")); assert!(router.is_disabled("async_function")); router.remove_route("async_function"); + assert!(!router.has_route("async_function")); + // Disabled marker is preserved — is_disabled returns false (no route in map) + // but re-adding will inherit the disabled state (tested separately) assert!(!router.is_disabled("async_function")); +} + +#[test] +fn test_remove_route_then_readd_stays_disabled() { + let mut router = build_router(); + assert!(router.disable_route("async_function")); + + router.remove_route("async_function"); + assert!(!router.has_route("async_function")); + + // Re-add the route — it should inherit the disabled state + let other = ToolRouter::>::new() + .with_route((async_function_tool_attr(), async_function)); + router.merge(other); + assert!(!router.has_route("async_function")); + assert!(router.is_disabled("async_function")); + assert!(router.get("async_function").is_none()); } #[test] @@ -211,6 +231,42 @@ fn test_disabled_tool_invisible_across_all_queries() { assert!(router.get("async_function").is_none()); // Not routable assert!(!router.has_route("async_function")); - // But still known as disabled + // But known as disabled assert!(router.is_disabled("async_function")); } + +#[test] +fn test_disable_route_then_add_route_blocks_tool() { + // Full pre-disable lifecycle via runtime mutation (not builder) + let mut router = ToolRouter::>::new(); + router.disable_route("async_function"); + + // Add route after disabling — tool should be blocked + let other = ToolRouter::>::new() + .with_route((async_function_tool_attr(), async_function)); + router.merge(other); + + assert!(router.is_disabled("async_function")); + assert!(!router.has_route("async_function")); + assert!(router.get("async_function").is_none()); + assert_eq!(router.list_all().len(), 0); +} + +#[test] +fn test_disable_enable_return_false_cases() { + let mut router = build_router(); + + // Repeated disable returns false + assert!(router.disable_route("async_function")); + assert!(!router.disable_route("async_function")); + + // Enable returns true, then false on repeat + assert!(router.enable_route("async_function")); + assert!(!router.enable_route("async_function")); + + // Enable on name never disabled returns false + assert!(!router.enable_route("async_function2")); + + // Enable on unknown name returns false + assert!(!router.enable_route("unknown")); +} From 30ac65c30a0a2487913609abd2cf9e9a48153ae1 Mon Sep 17 00:00:00 2001 From: lutz-grex Date: Tue, 21 Apr 2026 11:31:09 +0200 Subject: [PATCH 3/3] feat(router): auto-send tools/list_changed on disable/enable --- crates/rmcp/src/handler/server/router.rs | 92 +++++++++- crates/rmcp/src/handler/server/router/tool.rs | 109 ++++++++--- .../tests/test_tool_disable_notification.rs | 172 ++++++++++++++++++ crates/rmcp/tests/test_tool_routers.rs | 102 ++++++++++- 4 files changed, 450 insertions(+), 25 deletions(-) create mode 100644 crates/rmcp/tests/test_tool_disable_notification.rs diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs index 19ed8a90c..45ff9a586 100644 --- a/crates/rmcp/src/handler/server/router.rs +++ b/crates/rmcp/src/handler/server/router.rs @@ -6,7 +6,7 @@ use tool::{IntoToolRoute, ToolRoute}; use super::ServerHandler; use crate::{ RoleServer, Service, - model::{ClientRequest, ListPromptsResult, ListToolsResult, ServerResult}, + model::{ClientNotification, ClientRequest, ListPromptsResult, ListToolsResult, ServerResult}, service::NotificationContext, }; @@ -18,6 +18,7 @@ pub struct Router { pub tool_router: tool::ToolRouter, pub prompt_router: prompt::PromptRouter, pub service: Arc, + peer_slot: Arc>>, } impl Router @@ -25,10 +26,14 @@ where S: ServerHandler, { pub fn new(service: S) -> Self { + let (notifier, peer_slot) = tool::ToolRouter::::deferred_peer_notifier(); + let mut tool_router = tool::ToolRouter::new(); + tool_router.set_notifier(notifier); Self { - tool_router: tool::ToolRouter::new(), + tool_router, prompt_router: prompt::PromptRouter::new(), service: Arc::new(service), + peer_slot, } } @@ -72,6 +77,12 @@ where notification: ::PeerNot, context: NotificationContext, ) -> Result<(), crate::ErrorData> { + if matches!( + ¬ification, + ClientNotification::InitializedNotification(_) + ) { + let _ = self.peer_slot.set(context.peer.clone()); + } self.service .handle_notification(notification, context) .await @@ -137,6 +148,81 @@ where } fn get_info(&self) -> ::Info { - ServerHandler::get_info(&self.service) + let mut info = ServerHandler::get_info(&self.service); + info.capabilities + .tools + .get_or_insert_with(Default::default) + .list_changed = Some(true); + info + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::{ + model::{CallToolResult, ClientNotification, ServerNotification, Tool}, + service::{AtomicU32RequestIdProvider, Peer, PeerSinkMessage, RequestIdProvider}, + }; + + struct DummyHandler; + impl ServerHandler for DummyHandler {} + + async fn recv_notification( + rx: &mut tokio::sync::mpsc::Receiver>, + ) -> ServerNotification { + let msg = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv()) + .await + .expect("timed out") + .expect("channel closed"); + match msg { + PeerSinkMessage::Notification { + notification, + responder, + } => { + let _ = responder.send(Ok(())); + notification + } + other => panic!("expected notification, got {other:?}"), + } + } + + #[tokio::test] + async fn test_router_deferred_notifier_e2e() { + let mut router = Router::new(DummyHandler).with_tool(tool::ToolRoute::new_dyn( + Tool::new("my_tool", "test", Arc::new(Default::default())), + |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), + )); + + let id_provider: Arc = + Arc::new(AtomicU32RequestIdProvider::default()); + let (peer, mut rx) = Peer::::new(id_provider, None); + + let context = crate::service::NotificationContext { + peer: peer.clone(), + meta: Default::default(), + extensions: Default::default(), + }; + router + .handle_notification( + ClientNotification::InitializedNotification(Default::default()), + context, + ) + .await + .unwrap(); + + router.tool_router.disable_route("my_tool"); + assert!(matches!( + recv_notification(&mut rx).await, + ServerNotification::ToolListChangedNotification(_) + )); + + router.tool_router.enable_route("my_tool"); + assert!(matches!( + recv_notification(&mut rx).await, + ServerNotification::ToolListChangedNotification(_) + )); } } diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index b98982f49..b33f67b78 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -298,7 +298,6 @@ where self } } -#[derive(Debug)] #[non_exhaustive] pub struct ToolRouter { #[allow(clippy::type_complexity)] @@ -307,6 +306,22 @@ pub struct ToolRouter { pub transparent_when_not_found: bool, disabled: std::collections::HashSet>, + + notifier: Option>, +} + +impl std::fmt::Debug for ToolRouter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ToolRouter") + .field("map", &self.map) + .field( + "transparent_when_not_found", + &self.transparent_when_not_found, + ) + .field("disabled", &self.disabled) + .field("notifier", &self.notifier.as_ref().map(|_| "...")) + .finish() + } } impl Default for ToolRouter { @@ -315,15 +330,18 @@ impl Default for ToolRouter { map: std::collections::HashMap::new(), transparent_when_not_found: false, disabled: std::collections::HashSet::new(), + notifier: None, } } } + impl Clone for ToolRouter { fn clone(&self) -> Self { Self { map: self.map.clone(), transparent_when_not_found: self.transparent_when_not_found, disabled: self.disabled.clone(), + notifier: self.notifier.clone(), } } } @@ -346,11 +364,7 @@ where S: MaybeSend + 'static, { pub fn new() -> Self { - Self { - map: std::collections::HashMap::new(), - transparent_when_not_found: false, - disabled: std::collections::HashSet::new(), - } + Self::default() } pub fn with_route(mut self, route: R) -> Self where @@ -426,29 +440,30 @@ where self.map.contains_key(name) && !self.disabled.contains(name) } - /// Disable a tool by name so it is hidden from `list_all`, `get`, and - /// rejected by `call`. The tool remains in the router and can be - /// re-enabled later with [`enable_route`](Self::enable_route). + /// Disable a tool by name. Hidden from `list_all`, `get`, rejected by + /// `call`. Re-enable with [`enable_route`](Self::enable_route). /// /// Returns `true` if the name was newly added to the disabled set. /// The name is recorded even if no matching route exists yet, so routes - /// added later (via [`add_route`](Self::add_route) or - /// [`merge`](Self::merge)) will inherit the disabled state. - /// - /// Callers should send `Peer::notify_tool_list_changed` when the - /// visible tool list changes. Accepts `&'static str` or `String`; - /// for a non-static `&str`, call `.to_owned()` first. + /// added later will inherit the disabled state. pub fn disable_route(&mut self, name: impl Into>) -> bool { - self.disabled.insert(name.into()) + let name = name.into(); + let was_visible = self.map.contains_key(&name) && !self.disabled.contains(&name); + let newly_disabled = self.disabled.insert(name.clone()); + if was_visible && newly_disabled { + self.notify_if_visible(&name); + } + newly_disabled } /// Re-enable a previously disabled tool. Returns `true` if the name - /// was present in the disabled set and was removed. - /// - /// Callers should send `Peer::notify_tool_list_changed` when the - /// visible tool list changes. + /// was in the disabled set. pub fn enable_route(&mut self, name: &str) -> bool { - self.disabled.remove(name) + let removed = self.disabled.remove(name); + if removed { + self.notify_if_visible(name); + } + removed } /// Returns `true` if the tool exists in the router **and** is currently @@ -468,6 +483,58 @@ where self } + /// Install a callback invoked when the visible tool list changes. + pub fn set_notifier(&mut self, f: impl Fn() + Send + Sync + 'static) { + self.notifier = Some(Arc::new(f)); + } + + pub fn clear_notifier(&mut self) { + self.notifier = None; + } + + /// Install a notifier that sends `notifications/tools/list_changed` + /// via the given peer. + pub fn bind_peer_notifier(&mut self, peer: &crate::service::Peer) { + let peer = peer.clone(); + self.set_notifier(move || { + let peer = peer.clone(); + tokio::spawn(async move { + if let Err(e) = peer.notify_tool_list_changed().await { + tracing::warn!("failed to send tools/list_changed notification: {e}"); + } + }); + }); + } + + /// Deferred notifier: no-op until the peer slot is filled. + pub(crate) fn deferred_peer_notifier() -> ( + impl Fn() + Send + Sync + 'static, + Arc>>, + ) { + let peer_slot = + Arc::new(std::sync::OnceLock::>::new()); + let slot_clone = peer_slot.clone(); + let notifier = move || { + if let Some(peer) = slot_clone.get() { + let peer = peer.clone(); + tokio::spawn(async move { + if let Err(e) = peer.notify_tool_list_changed().await { + tracing::warn!("failed to send tools/list_changed notification: {e}"); + } + }); + } + }; + (notifier, peer_slot) + } + + fn notify_if_visible(&self, name: &str) { + if self.map.contains_key(name) { + if let Some(notifier) = &self.notifier { + (notifier)(); + } + } + } + pub async fn call( &self, context: ToolCallContext<'_, S>, diff --git a/crates/rmcp/tests/test_tool_disable_notification.rs b/crates/rmcp/tests/test_tool_disable_notification.rs new file mode 100644 index 000000000..84037b59a --- /dev/null +++ b/crates/rmcp/tests/test_tool_disable_notification.rs @@ -0,0 +1,172 @@ +//! Integration tests for tool list change notifications. +#![cfg(all(feature = "client", not(feature = "local")))] + +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; + +use rmcp::{ + ClientHandler, RoleClient, RoleServer, ServerHandler, ServiceExt, + handler::server::{router::tool::ToolRoute, tool::ToolCallContext}, + model::{CallToolResult, ServerCapabilities, ServerInfo, Tool}, + service::{MaybeSendFuture, NotificationContext}, +}; +use tokio::sync::{Notify, RwLock}; + +#[derive(Clone)] +struct TestToolServer { + router: Arc>>, + trigger_disable: Arc, + trigger_enable: Arc, +} + +impl TestToolServer { + fn new() -> Self { + let mut tool_router = rmcp::handler::server::router::tool::ToolRouter::::new(); + tool_router.add_route(ToolRoute::new_dyn( + Tool::new("tool_a", "Tool A", Arc::new(Default::default())), + |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), + )); + tool_router.add_route(ToolRoute::new_dyn( + Tool::new("tool_b", "Tool B", Arc::new(Default::default())), + |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), + )); + Self { + router: Arc::new(RwLock::new(tool_router)), + trigger_disable: Arc::new(Notify::new()), + trigger_enable: Arc::new(Notify::new()), + } + } +} + +impl ServerHandler for TestToolServer { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + } + + fn call_tool( + &self, + request: rmcp::model::CallToolRequestParams, + context: rmcp::service::RequestContext, + ) -> impl std::future::Future> + MaybeSendFuture + '_ + { + async move { + let router = self.router.read().await; + let tcc = ToolCallContext::new(self, request, context); + router.call(tcc).await + } + } + + fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> impl std::future::Future> + + MaybeSendFuture + + '_ { + async move { + let router = self.router.read().await; + Ok(rmcp::model::ListToolsResult { + tools: router.list_all(), + ..Default::default() + }) + } + } + + fn on_initialized( + &self, + context: NotificationContext, + ) -> impl std::future::Future + MaybeSendFuture + '_ { + let router = self.router.clone(); + let trigger_disable = self.trigger_disable.clone(); + let trigger_enable = self.trigger_enable.clone(); + let peer = context.peer.clone(); + + async move { + router.write().await.bind_peer_notifier(&peer); + + let router = router.clone(); + tokio::spawn(async move { + trigger_disable.notified().await; + { + let mut r = router.write().await; + r.disable_route("tool_a"); + } + + trigger_enable.notified().await; + { + let mut r = router.write().await; + r.enable_route("tool_a"); + } + }); + } + } +} + +#[derive(Clone)] +struct TestToolClient { + notification_count: Arc, + notify: Arc, +} + +impl TestToolClient { + fn new() -> Self { + Self { + notification_count: Arc::new(AtomicUsize::new(0)), + notify: Arc::new(Notify::new()), + } + } +} + +impl ClientHandler for TestToolClient { + fn on_tool_list_changed( + &self, + _context: NotificationContext, + ) -> impl std::future::Future + MaybeSendFuture + '_ { + self.notification_count.fetch_add(1, Ordering::SeqCst); + self.notify.notify_one(); + std::future::ready(()) + } +} + +#[tokio::test] +async fn test_disable_enable_sends_tool_list_changed() { + let server = TestToolServer::new(); + let trigger_disable = server.trigger_disable.clone(); + let trigger_enable = server.trigger_enable.clone(); + + let client = TestToolClient::new(); + let notification_count = client.notification_count.clone(); + let client_notify = client.notify.clone(); + + let (server_transport, client_transport) = tokio::io::duplex(4096); + + let server_handle = tokio::spawn(async move { server.serve(server_transport).await }); + let client_service = client.serve(client_transport).await.unwrap(); + + let tools = client_service.peer().list_tools(None).await.unwrap(); + assert_eq!(tools.tools.len(), 2); + + trigger_disable.notify_one(); + tokio::time::timeout(std::time::Duration::from_secs(5), client_notify.notified()) + .await + .expect("timed out waiting for tool_list_changed"); + assert_eq!(notification_count.load(Ordering::SeqCst), 1); + + let tools = client_service.peer().list_tools(None).await.unwrap(); + assert_eq!(tools.tools.len(), 1); + assert_eq!(tools.tools[0].name, "tool_b"); + + trigger_enable.notify_one(); + tokio::time::timeout(std::time::Duration::from_secs(5), client_notify.notified()) + .await + .expect("timed out waiting for tool_list_changed"); + assert_eq!(notification_count.load(Ordering::SeqCst), 2); + + let tools = client_service.peer().list_tools(None).await.unwrap(); + assert_eq!(tools.tools.len(), 2); + + client_service.cancel().await.unwrap(); + server_handle.abort(); +} diff --git a/crates/rmcp/tests/test_tool_routers.rs b/crates/rmcp/tests/test_tool_routers.rs index 0bb50e43d..f2e28b0f3 100644 --- a/crates/rmcp/tests/test_tool_routers.rs +++ b/crates/rmcp/tests/test_tool_routers.rs @@ -1,5 +1,8 @@ #![cfg(not(feature = "local"))] -use std::collections::HashMap; +use std::{ + collections::HashMap, + sync::atomic::{AtomicUsize, Ordering}, +}; use futures::future::BoxFuture; use rmcp::{ @@ -270,3 +273,100 @@ fn test_disable_enable_return_false_cases() { // Enable on unknown name returns false assert!(!router.enable_route("unknown")); } + +// ── Notifier tests ────────────────────────────────────────────────────── + +fn counter_notifier() -> ( + impl Fn() + Send + Sync + 'static, + std::sync::Arc, +) { + let counter = std::sync::Arc::new(AtomicUsize::new(0)); + let c = counter.clone(); + let notifier = move || { + c.fetch_add(1, Ordering::SeqCst); + }; + (notifier, counter) +} + +#[test] +fn test_notifier_fires_on_disable_and_enable() { + let (notifier, counter) = counter_notifier(); + let mut router = build_router(); + router.set_notifier(notifier); + + assert!(router.disable_route("async_function")); + assert_eq!(counter.load(Ordering::SeqCst), 1); + + assert!(!router.disable_route("async_function")); + assert_eq!(counter.load(Ordering::SeqCst), 1); + + assert!(router.enable_route("async_function")); + assert_eq!(counter.load(Ordering::SeqCst), 2); + + assert!(!router.enable_route("async_function")); + assert_eq!(counter.load(Ordering::SeqCst), 2); +} + +#[test] +fn test_notifier_skips_nonexistent_tools() { + let (notifier, counter) = counter_notifier(); + let mut router = build_router(); + router.set_notifier(notifier); + + assert!(router.disable_route("does_not_exist")); + assert_eq!(counter.load(Ordering::SeqCst), 0); + + assert!(router.enable_route("does_not_exist")); + assert_eq!(counter.load(Ordering::SeqCst), 0); + + assert!(router.disable_route("future_tool")); + assert_eq!(counter.load(Ordering::SeqCst), 0); + assert!(router.enable_route("future_tool")); + assert_eq!(counter.load(Ordering::SeqCst), 0); +} + +#[test] +fn test_no_notifier_no_panic() { + let mut router = build_router(); + assert!(router.disable_route("async_function")); + assert!(router.enable_route("async_function")); + assert!(router.disable_route("async_function")); + assert!(!router.disable_route("async_function")); +} + +#[test] +fn test_clone_shares_notifier() { + let (notifier, counter) = counter_notifier(); + let mut router = build_router(); + router.set_notifier(notifier); + let mut cloned = router.clone(); + + assert!(cloned.disable_route("async_function")); + assert_eq!(counter.load(Ordering::SeqCst), 1); + + assert!(router.disable_route("async_function")); + assert_eq!(counter.load(Ordering::SeqCst), 2); + + cloned.clear_notifier(); + assert!(cloned.enable_route("async_function")); + assert_eq!(counter.load(Ordering::SeqCst), 2); + + assert!(router.enable_route("async_function")); + assert_eq!(counter.load(Ordering::SeqCst), 3); +} + +#[test] +fn test_pre_init_disable_silent_but_correct() { + let mut router = build_router(); + + assert!(router.disable_route("async_function")); + assert_eq!(router.list_all().len(), 3); + assert!(!router.has_route("async_function")); + + let (notifier, counter) = counter_notifier(); + router.set_notifier(notifier); + assert_eq!(counter.load(Ordering::SeqCst), 0); + + assert!(router.enable_route("async_function")); + assert_eq!(counter.load(Ordering::SeqCst), 1); +}