diff --git a/src/allow_list.rs b/src/allow_list.rs new file mode 100644 index 000000000..bd2f92d35 --- /dev/null +++ b/src/allow_list.rs @@ -0,0 +1,106 @@ +//! Pluggable source for the per-platform "allow-list" of user IDs. +//! +//! The default in-tree implementation is [`StaticAllowList`], which wraps the +//! static set produced from a platform's `allowed_users` config field. +//! Downstream forks may provide alternative implementations (e.g. a +//! file-watching impl that mirrors an IdP group) without modifying the +//! gate-check call sites in the adapters. + +use std::collections::HashSet; +use std::sync::Arc; + +/// Provides the current set of user IDs allowed to interact with the bot. +/// +/// Implementations must be cheap to call repeatedly: the dispatch path calls +/// [`AllowListSource::allowed_users`] once per inbound message. Returning an +/// `Arc>` lets implementations that hot-swap the underlying +/// set (e.g. via `arc_swap`) hand out a consistent snapshot to each caller +/// without taking a lock on the read path. +pub trait AllowListSource: Send + Sync { + /// Returns a snapshot of the currently-allowed user IDs. + fn allowed_users(&self) -> Arc>; +} + +/// In-tree default implementation: wraps a fixed set loaded once at startup +/// from configuration. Snapshots share a single `Arc`-backed allocation, so +/// the read path is allocation-free. +pub struct StaticAllowList { + users: Arc>, +} + +impl StaticAllowList { + pub fn new(users: HashSet) -> Self { + Self { + users: Arc::new(users), + } + } +} + +impl AllowListSource for StaticAllowList { + fn allowed_users(&self) -> Arc> { + Arc::clone(&self.users) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn static_round_trips_input() { + let input: HashSet = ["U1", "U2"].iter().map(|s| s.to_string()).collect(); + let source = StaticAllowList::new(input.clone()); + assert_eq!(*source.allowed_users(), input); + } + + #[test] + fn static_snapshots_share_allocation() { + let input: HashSet = ["U1"].iter().map(|s| s.to_string()).collect(); + let source = StaticAllowList::new(input); + let a = source.allowed_users(); + let b = source.allowed_users(); + assert!(Arc::ptr_eq(&a, &b)); + } + + /// Mock impl proving the seam supports downstream impls that swap the + /// underlying set at runtime. Mutex-guarded for test simplicity; a real + /// hot-reload impl would use `arc_swap::ArcSwap`. + struct SwappableSource { + inner: std::sync::Mutex>>, + } + + impl SwappableSource { + fn new(initial: HashSet) -> Self { + Self { + inner: std::sync::Mutex::new(Arc::new(initial)), + } + } + + fn swap(&self, next: HashSet) { + *self.inner.lock().unwrap() = Arc::new(next); + } + } + + impl AllowListSource for SwappableSource { + fn allowed_users(&self) -> Arc> { + Arc::clone(&self.inner.lock().unwrap()) + } + } + + #[test] + fn custom_source_can_hot_swap_through_trait_object() { + let initial: HashSet = ["U1"].iter().map(|s| s.to_string()).collect(); + let typed = Arc::new(SwappableSource::new(initial)); + let dyn_source: Arc = typed.clone(); + + let before = dyn_source.allowed_users(); + assert!(before.contains("U1")); + + let next: HashSet = ["U2"].iter().map(|s| s.to_string()).collect(); + typed.swap(next); + + let after = dyn_source.allowed_users(); + assert!(after.contains("U2")); + assert!(!after.contains("U1")); + } +} diff --git a/src/main.rs b/src/main.rs index 0252193b8..ed31f4396 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ mod acp; mod adapter; +mod allow_list; mod bot_turns; mod config; mod cron; @@ -382,6 +383,10 @@ async fn main() -> anyhow::Result<()> { )); dispatchers.lock().unwrap().push(slack_dispatcher.clone()); let slack_ctl_registry = ctl_registry.clone(); + let slack_allow_list: Arc = + Arc::new(allow_list::StaticAllowList::new( + slack_cfg.allowed_users.into_iter().collect(), + )); Some(tokio::spawn(async move { if let Err(e) = slack::run_slack_adapter( adapter, @@ -389,7 +394,7 @@ async fn main() -> anyhow::Result<()> { allow_all_channels, allow_all_users, slack_cfg.allowed_channels.into_iter().collect(), - slack_cfg.allowed_users.into_iter().collect(), + slack_allow_list, slack_cfg.allow_bot_messages, slack_cfg.trusted_bot_ids.into_iter().collect(), slack_cfg.allow_user_messages, diff --git a/src/slack.rs b/src/slack.rs index 42986d25f..bcd22c797 100644 --- a/src/slack.rs +++ b/src/slack.rs @@ -1,5 +1,6 @@ use crate::acp::ContentBlock; use crate::adapter::{ChannelRef, ChatAdapter, MessageRef, SenderContext}; +use crate::allow_list::AllowListSource; use crate::bot_turns::{BotTurnTracker, TurnAction, TurnSeverity}; use crate::config::{AllowBots, AllowUsers, SttConfig}; use crate::media; @@ -712,7 +713,7 @@ pub async fn run_slack_adapter( allow_all_channels: bool, allow_all_users: bool, allowed_channels: HashSet, - allowed_users: HashSet, + allowed_users: Arc, allow_bot_messages: AllowBots, trusted_bot_ids: HashSet, allow_user_messages: AllowUsers, @@ -827,7 +828,7 @@ pub async fn run_slack_adapter( let adapter = adapter.clone(); let bot_token = bot_token.clone(); let allowed_channels = allowed_channels.clone(); - let allowed_users = allowed_users.clone(); + let allowed_users = allowed_users.allowed_users(); let stt_config = stt_config.clone(); let dispatcher = dispatcher.clone(); let ctl_registry = ctl_registry.clone(); @@ -1080,7 +1081,7 @@ pub async fn run_slack_adapter( let adapter = adapter.clone(); let bot_token = bot_token.clone(); let allowed_channels = allowed_channels.clone(); - let allowed_users = allowed_users.clone(); + let allowed_users = allowed_users.allowed_users(); let stt_config = stt_config.clone(); let dispatcher = dispatcher.clone(); let ctl_registry = ctl_registry.clone();