From 195036ad19b989cacc7a94037e4303a7d921d5df Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: Thu, 18 Jun 2026 20:13:58 +0000 Subject: [PATCH 01/20] feat: restructure as Cargo workspace with unified binary feature flags MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements ADR #1116 — separate binaries with opt-in unified build. Workspace layout: - crates/openab-core: core library (Discord, Slack, ACP, Dispatcher) - crates/openab-gateway: gateway adapters (Telegram, LINE, Feishu, etc.) - Root Cargo.toml: workspace + final binary with feature flag routing Feature flags: - default = discord + slack (unchanged behavior) - unified = all gateway adapters compiled in - Granular: --no-default-features --features telegram (single adapter) Dockerfile: - BUILD_MODE=unified → all-in-one binary - FEATURES=telegram,line → custom adapter selection (with --no-default-features) --- Cargo.toml | 64 +- Dockerfile | 16 +- crates/openab-core/Cargo.toml | 54 + crates/openab-core/src/acp/agentcore.rs | 722 +++ crates/openab-core/src/acp/connection.rs | 937 ++++ crates/openab-core/src/acp/mod.rs | 9 + crates/openab-core/src/acp/pool.rs | 622 +++ crates/openab-core/src/acp/protocol.rs | 406 ++ crates/openab-core/src/adapter.rs | 1659 +++++++ crates/openab-core/src/bot_turns.rs | 368 ++ crates/openab-core/src/config.rs | 1500 +++++++ crates/openab-core/src/cron.rs | 1768 ++++++++ crates/openab-core/src/directives.rs | 314 ++ crates/openab-core/src/discord.rs | 3203 ++++++++++++++ crates/openab-core/src/dispatch.rs | 1727 ++++++++ crates/openab-core/src/error_display.rs | 323 ++ crates/openab-core/src/format.rs | 327 ++ crates/openab-core/src/gateway.rs | 1054 +++++ crates/openab-core/src/hooks.rs | 425 ++ crates/openab-core/src/lib.rs | 25 + crates/openab-core/src/markdown.rs | 349 ++ crates/openab-core/src/media.rs | 846 ++++ crates/openab-core/src/multibot_cache.rs | 85 + crates/openab-core/src/reactions.rs | 276 ++ 
crates/openab-core/src/remind.rs | 399 ++ crates/openab-core/src/secrets.rs | 479 ++ crates/openab-core/src/setup/config.rs | 157 + crates/openab-core/src/setup/mod.rs | 12 + crates/openab-core/src/setup/validate.rs | 78 + crates/openab-core/src/setup/wizard.rs | 667 +++ crates/openab-core/src/slack.rs | 2329 ++++++++++ crates/openab-core/src/stt.rs | 354 ++ crates/openab-core/src/timestamp.rs | 114 + crates/openab-gateway/Cargo.toml | 43 + crates/openab-gateway/src/adapters/feishu.rs | 3928 +++++++++++++++++ .../openab-gateway/src/adapters/googlechat.rs | 2470 +++++++++++ crates/openab-gateway/src/adapters/line.rs | 780 ++++ crates/openab-gateway/src/adapters/mod.rs | 12 + crates/openab-gateway/src/adapters/teams.rs | 877 ++++ .../openab-gateway/src/adapters/telegram.rs | 782 ++++ crates/openab-gateway/src/adapters/wecom.rs | 1654 +++++++ crates/openab-gateway/src/lib.rs | 4 + crates/openab-gateway/src/media.rs | 123 + crates/openab-gateway/src/schema.rs | 126 + crates/openab-gateway/src/store.rs | 132 + 45 files changed, 32560 insertions(+), 39 deletions(-) create mode 100644 crates/openab-core/Cargo.toml create mode 100644 crates/openab-core/src/acp/agentcore.rs create mode 100644 crates/openab-core/src/acp/connection.rs create mode 100644 crates/openab-core/src/acp/mod.rs create mode 100644 crates/openab-core/src/acp/pool.rs create mode 100644 crates/openab-core/src/acp/protocol.rs create mode 100644 crates/openab-core/src/adapter.rs create mode 100644 crates/openab-core/src/bot_turns.rs create mode 100644 crates/openab-core/src/config.rs create mode 100644 crates/openab-core/src/cron.rs create mode 100644 crates/openab-core/src/directives.rs create mode 100644 crates/openab-core/src/discord.rs create mode 100644 crates/openab-core/src/dispatch.rs create mode 100644 crates/openab-core/src/error_display.rs create mode 100644 crates/openab-core/src/format.rs create mode 100644 crates/openab-core/src/gateway.rs create mode 100644 crates/openab-core/src/hooks.rs 
create mode 100644 crates/openab-core/src/lib.rs create mode 100644 crates/openab-core/src/markdown.rs create mode 100644 crates/openab-core/src/media.rs create mode 100644 crates/openab-core/src/multibot_cache.rs create mode 100644 crates/openab-core/src/reactions.rs create mode 100644 crates/openab-core/src/remind.rs create mode 100644 crates/openab-core/src/secrets.rs create mode 100644 crates/openab-core/src/setup/config.rs create mode 100644 crates/openab-core/src/setup/mod.rs create mode 100644 crates/openab-core/src/setup/validate.rs create mode 100644 crates/openab-core/src/setup/wizard.rs create mode 100644 crates/openab-core/src/slack.rs create mode 100644 crates/openab-core/src/stt.rs create mode 100644 crates/openab-core/src/timestamp.rs create mode 100644 crates/openab-gateway/Cargo.toml create mode 100644 crates/openab-gateway/src/adapters/feishu.rs create mode 100644 crates/openab-gateway/src/adapters/googlechat.rs create mode 100644 crates/openab-gateway/src/adapters/line.rs create mode 100644 crates/openab-gateway/src/adapters/mod.rs create mode 100644 crates/openab-gateway/src/adapters/teams.rs create mode 100644 crates/openab-gateway/src/adapters/telegram.rs create mode 100644 crates/openab-gateway/src/adapters/wecom.rs create mode 100644 crates/openab-gateway/src/lib.rs create mode 100644 crates/openab-gateway/src/media.rs create mode 100644 crates/openab-gateway/src/schema.rs create mode 100644 crates/openab-gateway/src/store.rs diff --git a/Cargo.toml b/Cargo.toml index 162df3490..88b2456d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,6 @@ +[workspace] +members = ["crates/openab-core", "crates/openab-gateway"] + [package] name = "openab" version = "0.8.5" @@ -5,48 +8,33 @@ edition = "2021" license = "MIT" [dependencies] +openab-core = { path = "crates/openab-core", default-features = false } +openab-gateway = { path = "crates/openab-gateway", default-features = false, optional = true } tokio = { version = "1", features = ["full"] } 
-serde = { version = "1", features = ["derive"] } -serde_json = "1" -toml = "0.8" -toml_edit = "0.22" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } -serenity = { version = "0.12", default-features = false, features = ["client", "gateway", "model", "rustls_backend", "cache"] } -uuid = { version = "1", features = ["v4"] } -regex = "1" -anyhow = "1" -async-trait = "0.1" -futures-util = "0.3" -rand = "0.8" clap = { version = "4", features = ["derive"] } -rpassword = "7" -reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "multipart", "json", "blocking"] } -base64 = "0.22" -image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } -unicode-width = "0.2" -pulldown-cmark = { version = "0.13", default-features = false } -tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } -rustls = { version = "0.22", optional = true } -tokio-rustls = { version = "0.25", optional = true } -webpki-roots = { version = "0.26", optional = true } -cron = "0.16.0" -chrono = { version = "0.4.44", features = ["serde"] } -chrono-tz = "0.10.4" -sha2 = "0.10" -tempfile = "3.27.0" -aws-sdk-secretsmanager = { version = "1", optional = true } -aws-config = { version = "1", optional = true } -aws-sigv4 = { version = "1", optional = true } -aws-credential-types = { version = "1", optional = true } -urlencoding = { version = "2", optional = true } -hex = { version = "0.4", optional = true } -http = { version = "1", optional = true } +anyhow = "1" [features] -default = ["secrets-aws", "agentcore"] -secrets-aws = ["dep:aws-sdk-secretsmanager", "dep:aws-config"] -agentcore = ["dep:aws-config", "dep:aws-sigv4", "dep:aws-credential-types", "dep:urlencoding", "dep:hex", "dep:http", "dep:rustls", "dep:tokio-rustls", "dep:webpki-roots"] +# Default: core only (Discord + Slack). Gateway ships as separate binary. 
+default = ["discord", "slack", "secrets-aws", "agentcore"] + +# Opt-in: compile all gateway adapters into a single unified binary +unified = ["telegram", "line", "feishu", "googlechat", "wecom", "teams"] + +# Core adapters +discord = ["openab-core/discord"] +slack = ["openab-core/slack"] + +# Core optional features +secrets-aws = ["openab-core/secrets-aws"] +agentcore = ["openab-core/agentcore"] -[target.'cfg(unix)'.dependencies] -libc = "0.2" +# Gateway adapters (each pulls in the gateway crate) +telegram = ["dep:openab-gateway", "openab-gateway/telegram"] +line = ["dep:openab-gateway", "openab-gateway/line"] +feishu = ["dep:openab-gateway", "openab-gateway/feishu"] +googlechat = ["dep:openab-gateway", "openab-gateway/googlechat"] +wecom = ["dep:openab-gateway", "openab-gateway/wecom"] +teams = ["dep:openab-gateway", "openab-gateway/teams"] diff --git a/Dockerfile b/Dockerfile index cdfde7414..f3222aa72 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,10 +1,24 @@ # --- Build stage --- +ARG BUILD_MODE=default +ARG FEATURES="" + FROM rust:1-bookworm AS builder +ARG BUILD_MODE +ARG FEATURES + WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs && \ + if [ "$BUILD_MODE" = "unified" ]; then \ + cargo build --release --features unified; \ + elif [ -n "$FEATURES" ]; then \ + cargo build --release --no-default-features --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/crates/openab-core/Cargo.toml b/crates/openab-core/Cargo.toml new file mode 100644 index 000000000..5136e69a3 --- /dev/null +++ b/crates/openab-core/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "openab-core" +version = "0.8.5" +edition = "2021" +license = "MIT" + +[dependencies] +tokio = { version = "1", features = ["full"] } +serde = { 
version = "1", features = ["derive"] } +serde_json = "1" +toml = "0.8" +toml_edit = "0.22" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } +serenity = { version = "0.12", default-features = false, features = ["client", "gateway", "model", "rustls_backend", "cache"] } +uuid = { version = "1", features = ["v4"] } +regex = "1" +anyhow = "1" +async-trait = "0.1" +futures-util = "0.3" +rand = "0.8" +clap = { version = "4", features = ["derive"] } +rpassword = "7" +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "multipart", "json", "blocking"] } +base64 = "0.22" +image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } +unicode-width = "0.2" +pulldown-cmark = { version = "0.13", default-features = false } +tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } +rustls = { version = "0.22", optional = true } +tokio-rustls = { version = "0.25", optional = true } +webpki-roots = { version = "0.26", optional = true } +cron = "0.16.0" +chrono = { version = "0.4.44", features = ["serde"] } +chrono-tz = "0.10.4" +sha2 = "0.10" +tempfile = "3.27.0" +aws-sdk-secretsmanager = { version = "1", optional = true } +aws-config = { version = "1", optional = true } +aws-sigv4 = { version = "1", optional = true } +aws-credential-types = { version = "1", optional = true } +urlencoding = { version = "2", optional = true } +hex = { version = "0.4", optional = true } +http = { version = "1", optional = true } + +[target.'cfg(unix)'.dependencies] +libc = "0.2" + +[features] +default = ["discord", "slack", "secrets-aws", "agentcore"] +discord = [] +slack = [] +secrets-aws = ["dep:aws-sdk-secretsmanager", "dep:aws-config"] +agentcore = ["dep:aws-config", "dep:aws-sigv4", "dep:aws-credential-types", "dep:urlencoding", "dep:hex", "dep:http", "dep:rustls", "dep:tokio-rustls", "dep:webpki-roots"] diff --git a/crates/openab-core/src/acp/agentcore.rs 
b/crates/openab-core/src/acp/agentcore.rs new file mode 100644 index 000000000..86696d357 --- /dev/null +++ b/crates/openab-core/src/acp/agentcore.rs @@ -0,0 +1,722 @@ +//! AgentCore ACP bridge — stdin/stdout subprocess that bridges ACP JSON-RPC +//! to AgentCore's InvokeAgentRuntimeCommandShell WebSocket API. +//! +//! Invoked as: `openab --agentcore-bridge --runtime-arn ARN --region REGION` +//! +//! Opens a persistent PTY shell in the microVM, launches `kiro-cli acp +//! --trust-all-tools`, and forwards JSON-RPC bidirectionally. + +use anyhow::{anyhow, Result}; +use aws_credential_types::provider::ProvideCredentials; +use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings}; +use aws_sigv4::sign::v4; +use futures_util::{SinkExt, StreamExt}; +use serde_json::{json, Value}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::SystemTime; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::sync::Mutex; +use tokio_tungstenite::tungstenite::http; +use tokio_tungstenite::tungstenite::protocol::Message; +use tracing::info; + +const AGENT_CMD_PREFIX: &str = "stty -echo 2>/dev/null; mkdir -p /tmp/kiro-cli && cp -n /mnt/agent/.local/share/kiro-cli/data.sqlite3 /tmp/kiro-cli/ 2>/dev/null; export XDG_DATA_HOME=/tmp; exec "; + +/// WebSocket binary frame channel bytes (1-byte prefix protocol). +const CHANNEL_STDIN: u8 = 0x00; +const CHANNEL_STDOUT: u8 = 0x01; +const CHANNEL_STDERR: u8 = 0x02; + +/// Extract a complete JSON object from a line that may have PTY prefix noise. +/// Uses brace-counting to find matching `{}` pairs, robust against partial JSON +/// or embedded `{` in prompt text. 
+fn extract_json_object(line: &str) -> Option { + let bytes = line.as_bytes(); + let start = bytes.iter().position(|&b| b == b'{')?; + + let mut depth: i32 = 0; + let mut in_string = false; + let mut escape = false; + + for (i, &c) in bytes.iter().enumerate().skip(start) { + if escape { + escape = false; + continue; + } + if c == b'\\' && in_string { + escape = true; + continue; + } + if c == b'"' { + in_string = !in_string; + continue; + } + if in_string { + continue; + } + if c == b'{' { + depth += 1; + } else if c == b'}' { + depth -= 1; + if depth == 0 { + let candidate = &line[start..=i]; + // Validate it's actually valid JSON + if serde_json::from_str::(candidate).is_ok() { + return Some(candidate.to_string()); + } + // Not valid — try next `{` + return extract_json_object(&line[start + 1..]); + } + } + } + None +} + +/// Entry point for the agentcore bridge subprocess. +pub async fn run_bridge(runtime_arn: &str, region: &str, agent_command: &str) -> Result<()> { + let stdin = BufReader::new(tokio::io::stdin()); + let stdout = tokio::io::stdout(); + + let mut bridge = Bridge::new(runtime_arn, region, agent_command, stdin, stdout); + bridge.run().await +} + +struct Bridge { + runtime_arn: String, + region: String, + agent_command: String, + stdin: R, + stdout: W, + sessions: HashMap, + next_id: u64, +} + +struct ShellHandle { + /// Sender for writing to the WebSocket (stdin of shell) + ws_write: Arc>, + /// Buffered output from kiro-cli (stdout of shell via WebSocket) + line_rx: tokio::sync::mpsc::UnboundedReceiver, + /// Pump task handle + _pump: tokio::task::JoinHandle<()>, + /// Runtime session ID (for future reconnect support). 
+ #[allow(dead_code)] + runtime_session_id: String, + /// kiro-cli's internal ACP session ID + kiro_session_id: String, +} + +type WsSink = futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream>, + Message, +>; + +impl Bridge +where + R: AsyncBufReadExt + Unpin, + W: AsyncWriteExt + Unpin, +{ + fn new(runtime_arn: &str, region: &str, agent_command: &str, stdin: R, stdout: W) -> Self { + Self { + runtime_arn: runtime_arn.to_string(), + region: region.to_string(), + agent_command: agent_command.to_string(), + stdin, + stdout, + sessions: HashMap::new(), + next_id: 1000, + } + } + + fn alloc_id(&mut self) -> u64 { + self.next_id += 1; + self.next_id + } + + async fn write_msg(&mut self, msg: &Value) -> Result<()> { + let data = serde_json::to_string(msg)?; + self.stdout.write_all(data.as_bytes()).await?; + self.stdout.write_all(b"\n").await?; + self.stdout.flush().await?; + Ok(()) + } + + async fn write_response(&mut self, id: &Value, result: Value) -> Result<()> { + self.write_msg(&json!({"jsonrpc": "2.0", "id": id, "result": result})) + .await + } + + async fn write_error(&mut self, id: &Value, code: i32, message: &str) -> Result<()> { + self.write_msg( + &json!({"jsonrpc": "2.0", "id": id, "error": {"code": code, "message": message}}), + ) + .await + } + + async fn run(&mut self) -> Result<()> { + let mut line = String::new(); + loop { + line.clear(); + let n = self.stdin.read_line(&mut line).await?; + if n == 0 { + break; // EOF + } + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + let msg: Value = match serde_json::from_str(trimmed) { + Ok(v) => v, + Err(_) => continue, + }; + + let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or(""); + let id = msg.get("id").cloned().unwrap_or(Value::Null); + let params = msg.get("params").cloned().unwrap_or(json!({})); + + // Skip messages without a method (e.g. 
stray responses) — same fix as Python F1 + if method.is_empty() { + continue; + } + + match method { + "initialize" => { + self.write_response( + &id, + json!({ + "protocolVersion": 1, + "agentInfo": {"name": "agentcore-shell-bridge", "version": "0.2.0"}, + "agentCapabilities": {"loadSession": true} + }), + ) + .await?; + } + "session/new" => { + let ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + let acp_sid = format!("agentcore-{ts}"); + let runtime_sid = format!("oab-session-{ts:020}-{ts:013x}"); + + // Eagerly open shell + initialize the agent + match self.open_shell(&runtime_sid).await { + Ok(handle) => { + self.sessions.insert(acp_sid.clone(), handle); + self.write_response(&id, json!({"sessionId": acp_sid})) + .await?; + } + Err(e) => { + self.write_error(&id, -32000, &format!("shell init failed: {e}")) + .await?; + } + } + } + "session/load" => { + let acp_sid = params + .get("sessionId") + .and_then(|s| s.as_str()) + .unwrap_or("") + .to_string(); + self.write_response(&id, json!({"sessionId": acp_sid})) + .await?; + } + "session/prompt" => { + self.handle_prompt(&id, ¶ms).await?; + } + "session/cancel" | "cancel" => { + self.handle_cancel(¶ms).await; + } + "session/destroy" | "session/stop" => { + let acp_sid = params + .get("sessionId") + .and_then(|s| s.as_str()) + .unwrap_or("") + .to_string(); + self.sessions.remove(&acp_sid); + if id != Value::Null { + self.write_response(&id, json!({})).await?; + } + } + "session/request_permission" => { + if id != Value::Null { + self.write_response(&id, json!({"approved": true})).await?; + } + } + _ => { + if id != Value::Null { + self.write_error(&id, -32601, &format!("unknown method: {method}")) + .await?; + } + } + } + } + Ok(()) + } + + async fn handle_prompt(&mut self, id: &Value, params: &Value) -> Result<()> { + let acp_sid = params + .get("sessionId") + .and_then(|s| s.as_str()) + .unwrap_or("") + .to_string(); + + // Reconnect if session 
was lost (shell closed unexpectedly) + if !self.sessions.contains_key(&acp_sid) { + info!("session lost, reconnecting shell..."); + let ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + let runtime_sid = format!("oab-reconnect-{ts:020}-{ts:013x}"); + match self.open_shell(&runtime_sid).await { + Ok(handle) => { self.sessions.insert(acp_sid.clone(), handle); } + Err(e) => { + self.write_error(id, -32000, &format!("reconnect failed: {e}")).await?; + return Ok(()); + } + } + } + + // Allocate ID before borrowing sessions + let kiro_id = self.alloc_id(); + let kiro_sid = self.sessions.get(&acp_sid) + .map(|s| s.kiro_session_id.clone()) + .unwrap_or_default(); + let mut fwd_params = params.clone(); + if let Some(obj) = fwd_params.as_object_mut() { + obj.insert("sessionId".to_string(), json!(kiro_sid)); + } + let kiro_msg = json!({ + "jsonrpc": "2.0", + "id": kiro_id, + "method": "session/prompt", + "params": fwd_params, + }); + let data = format!("{}\n", serde_json::to_string(&kiro_msg)?); + + // Send prompt to kiro-cli + { + let shell = self.sessions.get_mut(&acp_sid).unwrap(); + let mut w = shell.ws_write.lock().await; + let mut frame = Vec::with_capacity(1 + data.len()); + frame.push(CHANNEL_STDIN); + frame.extend_from_slice(data.as_bytes()); + let _ = w.send(Message::Binary(frame)).await; + } + + // Read responses/notifications from kiro-cli until we get the response for our id. + // We take line_rx out of the session to avoid holding &mut self across await points. 
+ let mut line_rx = match self.sessions.get_mut(&acp_sid) { + Some(s) => std::mem::replace(&mut s.line_rx, tokio::sync::mpsc::unbounded_channel().1), + None => { + self.write_error(id, -32000, "session lost").await?; + return Ok(()); + } + }; + + let result = loop { + match tokio::time::timeout(std::time::Duration::from_secs(300), line_rx.recv()).await { + Ok(Some(line)) => { + let msg: Value = match serde_json::from_str(&line) { + Ok(v) => v, + Err(_) => continue, + }; + if msg.get("id").and_then(|i| i.as_u64()) == Some(kiro_id) { + if let Some(err) = msg.get("error") { + self.write_msg(&json!({"jsonrpc": "2.0", "id": id, "error": err})) + .await?; + } else { + let r = msg + .get("result") + .cloned() + .unwrap_or(json!({"type": "success"})); + self.write_response(id, r).await?; + } + break Some(line_rx); + } + if msg.get("method").is_some() { + self.write_msg(&msg).await?; + } + } + Ok(None) => { + self.write_error(id, -32000, "shell connection closed") + .await?; + self.sessions.remove(&acp_sid); + break None; + } + Err(_) => { + self.write_error(id, -32000, "timeout waiting for agent response") + .await?; + break Some(line_rx); + } + } + }; + + // Put line_rx back + if let Some(rx) = result { + if let Some(s) = self.sessions.get_mut(&acp_sid) { + s.line_rx = rx; + } + } + Ok(()) + } + + async fn handle_cancel(&mut self, params: &Value) { + let acp_sid = params + .get("sessionId") + .and_then(|s| s.as_str()) + .unwrap_or(""); + if let Some(shell) = self.sessions.get(acp_sid) { + let cancel_msg = json!({ + "jsonrpc": "2.0", + "method": "session/cancel", + "params": params, + }); + let data = format!( + "{}\n", + serde_json::to_string(&cancel_msg).unwrap_or_default() + ); + let mut frame = Vec::with_capacity(1 + data.len()); + frame.push(CHANNEL_STDIN); + frame.extend_from_slice(data.as_bytes()); + let mut w = shell.ws_write.lock().await; + let _ = w.send(Message::Binary(frame)).await; + } + } + + #[allow(dead_code)] + fn derive_runtime_session_id(&self, params: 
&Value) -> String { + // Try to extract from sender_context in prompt blocks + if let Some(blocks) = params.get("prompt").and_then(|p| p.as_array()) { + for block in blocks { + if let Some(text) = block.get("text").and_then(|t| t.as_str()) { + if let Some(start) = text.find("") { + if let Some(end) = text.find("") { + let ctx_str = &text[start + 16..end]; + if let Ok(ctx) = serde_json::from_str::(ctx_str.trim()) { + let platform = ctx + .get("channel") + .and_then(|c| c.as_str()) + .unwrap_or("unknown"); + let thread_id = ctx + .get("thread_id") + .or_else(|| ctx.get("channel_id")) + .and_then(|t| t.as_str()) + .unwrap_or(""); + let mut sid = format!("oab-{platform}-thread-{thread_id}"); + while sid.len() < 33 { + sid.push('0'); + } + return sid; + } + } + } + } + } + } + // Fallback + let mut sid = format!("oab-fallback-{}", uuid::Uuid::new_v4()); + while sid.len() < 33 { + sid.push('0'); + } + sid + } + + async fn open_shell(&self, session_id: &str) -> Result { + let (request, host) = build_signed_request(&self.runtime_arn, session_id, &self.region).await?; + + // Manual TLS connection — gives us full control, avoids connect_async host override + let tcp = tokio::net::TcpStream::connect(format!("{host}:443")) + .await + .map_err(|e| anyhow!("TCP connect to {host}:443 failed: {e}"))?; + + let connector = tokio_tungstenite::Connector::Rustls(std::sync::Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), + }) + .with_no_client_auth(), + )); + + let tls_stream = match connector { + tokio_tungstenite::Connector::Rustls(cfg) => { + let domain = rustls::pki_types::ServerName::try_from(host.as_str()) + .map_err(|e| anyhow!("bad DNS: {e}"))? + .to_owned(); + tokio_rustls::TlsConnector::from(cfg) + .connect(domain, tcp) + .await + .map_err(|e| anyhow!("TLS failed: {e}"))? 
+ } + _ => unreachable!(), + }; + + // client_async performs the WebSocket upgrade using our exact request + let (ws_stream, _) = tokio_tungstenite::client_async(request, tls_stream) + .await + .map_err(|e| anyhow!("WebSocket upgrade failed: {e}"))?; + + info!(session_id, "AgentCore shell connected"); + + let (ws_write, mut ws_read) = ws_stream.split(); + let ws_write = Arc::new(Mutex::new(ws_write)); + + // Send agent launch command + let shell_cmd = format!("{}{}\n", AGENT_CMD_PREFIX, self.agent_command); + { + let mut frame = Vec::with_capacity(1 + shell_cmd.len()); + frame.push(CHANNEL_STDIN); + frame.extend_from_slice(shell_cmd.as_bytes()); + let mut w = ws_write.lock().await; + w.send(Message::Binary(frame)) + .await + .map_err(|e| anyhow!("failed to send launch cmd: {e}"))?; + } + + // Channel for forwarding parsed JSON-RPC lines + let (line_tx, mut line_rx) = tokio::sync::mpsc::unbounded_channel::(); + + // Spawn reader pump + let pump = tokio::spawn(async move { + let mut buf = String::new(); + while let Some(Ok(msg)) = ws_read.next().await { + match msg { + Message::Binary(data) => { + if data.len() < 2 { + continue; + } + if data[0] == CHANNEL_STDOUT { + // stdout + if let Ok(s) = std::str::from_utf8(&data[1..]) { + buf.push_str(s); + while let Some(nl) = buf.find('\n') { + let line = buf[..nl].to_string(); + buf = buf[nl + 1..].to_string(); + let trimmed = line.trim().to_string(); + if trimmed.is_empty() { + continue; + } + // Extract JSON object using brace-counting (handles PTY prefix noise) + if let Some(json_str) = extract_json_object(&trimmed) { + if line_tx.send(json_str).is_err() { + return; // receiver dropped — exit pump + } + } + } + } + } else if data[0] == CHANNEL_STDERR { + // stderr — log + if let Ok(s) = std::str::from_utf8(&data[1..]) { + let t = s.trim(); + if !t.is_empty() { + eprintln!("[agentcore] {t}"); + } + } + } + } + Message::Close(_) => break, + _ => {} + } + } + }); + + // Send ACP initialize to the agent (it will respond once 
booted) + + let init_msg = serde_json::json!({ + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "openab-agentcore-bridge", "version": env!("CARGO_PKG_VERSION")} + } + }); + let init_data = format!("{}\n", serde_json::to_string(&init_msg)?); + + // Send initialize and wait for response (retry if agent hasn't booted yet) + let mut initialized = false; + for attempt in 0..5 { + { + let mut w = ws_write.lock().await; + let mut frame = Vec::with_capacity(1 + init_data.len()); + frame.push(CHANNEL_STDIN); + frame.extend_from_slice(init_data.as_bytes()); + if let Err(e) = w.send(Message::Binary(frame)).await { + if attempt < 4 { + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + continue; + } + return Err(anyhow!("failed to send initialize: {e}")); + } + } + // Wait up to 10s for response — skip notifications (lines without "id":0) + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(10); + loop { + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + info!(attempt, "no initialize response, retrying..."); + break; + } + match tokio::time::timeout(remaining, line_rx.recv()).await { + Ok(Some(line)) => { + // Check if this is the initialize response (has "id":0 or "id": 0) + if let Ok(v) = serde_json::from_str::(&line) { + if v.get("id").and_then(|i| i.as_u64()) == Some(0) && v.get("result").is_some() { + info!(attempt, "agent initialized"); + initialized = true; + break; + } + } + // Skip notifications and other non-response lines + continue; + } + Ok(None) => return Err(anyhow!("agent closed before initialize response")), + Err(_) => { + info!(attempt, "no initialize response, retrying..."); + break; + } + } + } + if initialized { break; } + } + if !initialized { + return Err(anyhow!("agent failed to respond to initialize after 5 attempts")); + } + + // Send session/new to 
kiro-cli to create a session + let sess_msg = serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "session/new", + "params": {"cwd": "/home/agent", "mcpServers": []} + }); + let sess_data = format!("{}\n", serde_json::to_string(&sess_msg)?); + { + let mut w = ws_write.lock().await; + let mut frame = Vec::with_capacity(1 + sess_data.len()); + frame.push(CHANNEL_STDIN); + frame.extend_from_slice(sess_data.as_bytes()); + w.send(Message::Binary(frame)).await?; + } + + // Wait for session/new response — skip notifications (up to 120s) + let kiro_session_id = { + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(120); + let mut sid = String::from("default"); + loop { + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + info!("session/new timed out, using default session"); + break; + } + match tokio::time::timeout(remaining, line_rx.recv()).await { + Ok(Some(line)) => { + if let Ok(v) = serde_json::from_str::(&line) { + if v.get("id").and_then(|i| i.as_u64()) == Some(1) { + sid = v.pointer("/result/sessionId") + .and_then(|s| s.as_str()) + .unwrap_or("default") + .to_string(); + info!(kiro_session_id = %sid, "agent session created"); + break; + } + } + // Skip notifications + continue; + } + Ok(None) => return Err(anyhow!("agent closed before session/new response")), + Err(_) => { + info!("session/new timed out, using default session"); + break; + } + } + } + sid + }; + + Ok(ShellHandle { + ws_write, + line_rx, + _pump: pump, + runtime_session_id: session_id.to_string(), + kiro_session_id, + }) + } +} + +/// Build a WebSocket upgrade request with SigV4 Authorization header. 
+async fn build_signed_request( + arn: &str, + session_id: &str, + region: &str, +) -> Result<(http::Request<()>, String)> { + let config = aws_config::defaults(aws_config::BehaviorVersion::latest()) + .region(aws_config::Region::new(region.to_string())) + .load() + .await; + + let creds = config + .credentials_provider() + .ok_or_else(|| anyhow!("No AWS credentials found"))? + .provide_credentials() + .await + .map_err(|e| anyhow!("Failed to get credentials: {e}"))?; + + let identity = creds.into(); + + let encoded_arn = urlencoding::encode(arn); + let host = format!("bedrock-agentcore.{region}.amazonaws.com"); + let path = format!("/runtimes/{encoded_arn}/ws/shells"); + + // Deterministic shell_id from session_id + let hash = Sha256::digest(session_id.as_bytes()); + let shell_id = format!("oab-{}", hex::encode(&hash[..8])); + + let query = format!("qualifier=DEFAULT&shellId={shell_id}"); + let uri = format!("https://{host}{path}?{query}"); + + // Header-based SigV4 auth + let mut settings = SigningSettings::default(); + settings.expires_in = None; + settings.uri_path_normalization_mode = + aws_sigv4::http_request::UriPathNormalizationMode::Enabled; + + let signing_params = v4::SigningParams::builder() + .identity(&identity) + .region(region) + .name("bedrock-agentcore") + .time(SystemTime::now()) + .settings(settings) + .build()?; + + let headers = [ + ("host", host.as_str()), + ("x-amzn-bedrock-agentcore-runtime-session-id", session_id), + ]; + let signable = SignableRequest::new("GET", &uri, headers.into_iter(), SignableBody::empty())?; + let (instructions, _sig) = sign(signable, &signing_params.into())?.into_parts(); + + let wss_uri = format!("wss://{host}{path}?{query}"); + + // Build request with auth headers + WebSocket headers + let mut builder = http::Request::builder() + .method("GET") + .uri(&wss_uri) + .header("host", &host) + .header("x-amzn-bedrock-agentcore-runtime-session-id", session_id) + .header("connection", "Upgrade") + .header("upgrade", 
"websocket") + .header("sec-websocket-version", "13") + .header("sec-websocket-key", tokio_tungstenite::tungstenite::handshake::client::generate_key()); + + // Add SigV4 auth headers (x-amz-date, authorization) + for (name, value) in instructions.headers() { + builder = builder.header(name, value); + } + + let request = builder.body(())?; + Ok((request, host)) +} diff --git a/crates/openab-core/src/acp/connection.rs b/crates/openab-core/src/acp/connection.rs new file mode 100644 index 000000000..8df3451f4 --- /dev/null +++ b/crates/openab-core/src/acp/connection.rs @@ -0,0 +1,937 @@ +use crate::acp::protocol::{ + parse_config_options, ConfigOption, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, +}; +use anyhow::{anyhow, Result}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use tokio::process::{Child, ChildStdin}; +use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::task::JoinHandle; +use tracing::{debug, error, info, trace}; + +/// Pick the most permissive selectable permission option from ACP options. +fn pick_best_option(options: &[Value]) -> Option { + let mut fallback: Option<&Value> = None; + + for kind in ["allow_always", "allow_once"] { + if let Some(option) = options + .iter() + .find(|option| option.get("kind").and_then(|k| k.as_str()) == Some(kind)) + { + return option + .get("optionId") + .and_then(|id| id.as_str()) + .map(str::to_owned); + } + } + + for option in options { + let kind = option.get("kind").and_then(|k| k.as_str()); + if kind == Some("reject_once") || kind == Some("reject_always") { + continue; + } + fallback = Some(option); + break; + } + + fallback + .and_then(|option| option.get("optionId")) + .and_then(|id| id.as_str()) + .map(str::to_owned) +} + +/// Build a spec-compliant permission response with backward-compatible fallback. 
+fn build_permission_response(params: Option<&Value>) -> Value {
+    let Some(options) = params
+        .and_then(|p| p.get("options"))
+        .and_then(|options| options.as_array())
+    else {
+        // No `options` array at all: reply with the fixed "allow_always" id
+        // (the backward-compatible fallback).
+        return json!({
+            "outcome": {
+                "outcome": "selected",
+                "optionId": "allow_always"
+            }
+        });
+    };
+
+    match pick_best_option(options) {
+        Some(option_id) => json!({
+            "outcome": {
+                "outcome": "selected",
+                "optionId": option_id
+            }
+        }),
+        // Options were offered but none is selectable (empty or reject-only).
+        None => json!({
+            "outcome": {
+                "outcome": "cancelled"
+            }
+        }),
+    }
+}
+
+/// Expand a value of the exact form `${NAME}` to the process environment
+/// variable `NAME` (empty string when unset or not valid Unicode). Any other
+/// value is returned unchanged; partial interpolation is not performed.
+fn expand_env(val: &str) -> String {
+    match val.strip_prefix("${").and_then(|rest| rest.strip_suffix('}')) {
+        Some(key) => std::env::var(key).unwrap_or_default(),
+        None => val.to_string(),
+    }
+}
+use tokio::time::Instant;
+
+/// A content block for the ACP prompt — either text or image.
+#[derive(Debug, Clone)]
+pub enum ContentBlock {
+    Text { text: String },
+    Image { media_type: String, data: String },
+}
+
+impl ContentBlock {
+    /// Serialize the block to the JSON object sent in `session/prompt`.
+    /// Note that `media_type` is emitted under the key `mimeType`.
+    pub fn to_json(&self) -> Value {
+        match self {
+            ContentBlock::Text { text } => json!({
+                "type": "text",
+                "text": text
+            }),
+            ContentBlock::Image { media_type, data } => json!({
+                "type": "image",
+                "data": data,
+                "mimeType": media_type
+            }),
+        }
+    }
+}
+
+/// One spawned agent subprocess plus the JSON-RPC plumbing around its stdio.
+pub struct AcpConnection {
+    /// Child process handle, kept so it lives as long as the connection.
+    _proc: Child,
+    /// PID of the direct child, used as the process group ID for cleanup.
+    child_pgid: Option<i32>,
+    /// Shared writer for the agent's stdin.
+    stdin: Arc<Mutex<ChildStdin>>,
+    /// Monotonic JSON-RPC request id counter.
+    next_id: AtomicU64,
+    /// In-flight requests: request id → responder resolved by the reader loop.
+    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcMessage>>>>,
+    /// Current prompt subscriber, if any, that receives forwarded messages.
+    notify_tx: Arc<Mutex<Option<mpsc::UnboundedSender<JsonRpcMessage>>>>,
+    pub acp_session_id: Option<String>,
+    pub supports_load_session: bool,
+    pub config_options: Vec<ConfigOption>,
+    pub last_active: Instant,
+    pub session_reset: bool,
+    _reader_handle: JoinHandle<()>,
+    _stderr_handle: Option<JoinHandle<()>>,
+}
+
+/// Build the final set of env vars for the agent subprocess.
+/// `explicit` ([agent].env) takes precedence over `inherit` ([agent].inherit_env).
+/// Returns (merged env map, list of keys that were inherited from the process).
+fn build_agent_env( + explicit: &std::collections::HashMap, + inherit_keys: &[String], +) -> (std::collections::HashMap, Vec) { + let mut result: std::collections::HashMap = std::collections::HashMap::new(); + let mut inherited: Vec = Vec::new(); + + for (k, v) in explicit { + result.insert(k.clone(), expand_env(v)); + } + + for key in inherit_keys { + if !result.contains_key(key) { + if let Ok(v) = std::env::var(key) { + result.insert(key.clone(), v); + inherited.push(key.clone()); + } + } + } + + (result, inherited) +} + +/// Reader loop body: reads JSON-RPC messages from `reader`, auto-replies +/// `session/request_permission` via `writer`, resolves pending responses, +/// and forwards notifications + stale id-bearing messages to the active +/// subscriber. Extracted as a free generic function so unit tests can drive +/// it with `tokio::io::duplex()` halves instead of a real child process. +pub(crate) async fn run_reader_loop( + reader: R, + writer: Arc>, + pending: Arc>>>, + notify_tx: Arc>>>, +) where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + let mut reader = BufReader::new(reader); + let mut line = String::new(); + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, // EOF + Ok(_) => {} + Err(e) => { + error!("reader error: {e}"); + break; + } + } + let msg: JsonRpcMessage = match serde_json::from_str(line.trim()) { + Ok(m) => m, + Err(_) => continue, + }; + debug!(line = line.trim(), "acp_recv"); + + // Auto-reply session/request_permission + if msg.method.as_deref() == Some("session/request_permission") { + if let Some(id) = msg.id { + let title = msg + .params + .as_ref() + .and_then(|p| p.get("toolCall")) + .and_then(|t| t.get("title")) + .and_then(|t| t.as_str()) + .unwrap_or("?"); + + let outcome = build_permission_response(msg.params.as_ref()); + info!(title, %outcome, "auto-respond permission"); + let reply = JsonRpcResponse::new(id, outcome); + if let Ok(data) = 
serde_json::to_string(&reply) { + let mut w = writer.lock().await; + let _ = w.write_all(format!("{data}\n").as_bytes()).await; + let _ = w.flush().await; + } + } + continue; + } + + // Response (has id) → resolve pending AND forward to subscriber + if let Some(id) = msg.id { + let mut map = pending.lock().await; + if let Some(tx) = map.remove(&id) { + // Forward to subscriber so they see the completion + let sub = notify_tx.lock().await; + if let Some(ntx) = sub.as_ref() { + // Clone the essential fields for the subscriber + let _ = ntx.send(JsonRpcMessage { + id: Some(id), + method: None, + result: msg.result.clone(), + error: msg.error.clone(), + params: None, + }); + } + let _ = tx.send(msg); + continue; + } + // Stale id (#732): pending was already abandoned. Falls through + // to subscriber forwarding; the adapter recv loop filters by + // request_id so it can't leak into the next prompt. + trace!(request_id = id, "stale id-bearing message after abandon"); + } + + // Notification → forward to subscriber + let sub = notify_tx.lock().await; + if let Some(tx) = sub.as_ref() { + let _ = tx.send(msg); + } + } + + // Connection closed — resolve all pending with error + let mut map = pending.lock().await; + for (_, tx) in map.drain() { + let _ = tx.send(JsonRpcMessage { + id: None, + method: None, + result: None, + error: Some(crate::acp::protocol::JsonRpcError { + code: -1, + message: "connection closed".into(), + data: None, + }), + params: None, + }); + } + // Close the notify channel so rx.recv() returns None + let mut sub = notify_tx.lock().await; + *sub = None; +} + +impl AcpConnection { + pub async fn spawn( + command: &str, + args: &[String], + working_dir: &str, + env: &std::collections::HashMap, + inherit_env: &[String], + ) -> Result { + info!(cmd = command, ?args, cwd = working_dir, "spawning agent"); + + let mut cmd = tokio::process::Command::new(command); + cmd.args(args) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + 
.stderr(std::process::Stdio::piped()) + .current_dir(working_dir); + // Create a new process group so we can kill the entire tree. + // SAFETY: setpgid is async-signal-safe (POSIX.1-2008) and called + // before exec. Return value checked — failure means the child won't + // have its own process group, so kill(-pgid) would be unsafe. + #[cfg(unix)] + unsafe { + cmd.pre_exec(|| { + if libc::setpgid(0, 0) != 0 { + return Err(std::io::Error::last_os_error()); + } + Ok(()) + }); + } + #[cfg(windows)] + { + cmd.creation_flags(0x00000200); // CREATE_NEW_PROCESS_GROUP + } + // Clear inherited env to prevent credential leakage (e.g. DISCORD_BOT_TOKEN). + // Only [agent].env values + essential baseline vars are passed through. + cmd.env_clear(); + // Preserve the real HOME so agents can find OAuth/auth files (~/.codex, + // ~/.claude, ~/.config/gh, etc.). working_dir is already set via + // current_dir() above and is not necessarily the user's home directory. + cmd.env( + "HOME", + std::env::var("HOME").unwrap_or_else(|_| working_dir.into()), + ); + cmd.env( + "PATH", + std::env::var("PATH").unwrap_or_else(|_| "/usr/local/bin:/usr/bin:/bin".into()), + ); + #[cfg(unix)] + { + cmd.env( + "USER", + std::env::var("USER").unwrap_or_else(|_| "agent".into()), + ); + } + #[cfg(windows)] + { + // Windows requires SystemRoot for DLL loading and basic OS functionality. + // USERPROFILE is the Windows equivalent of HOME. + cmd.env( + "USERPROFILE", + std::env::var("USERPROFILE").unwrap_or_else(|_| working_dir.into()), + ); + cmd.env( + "USERNAME", + std::env::var("USERNAME").unwrap_or_else(|_| "agent".into()), + ); + if let Ok(v) = std::env::var("SystemRoot") { + cmd.env("SystemRoot", v); + } + if let Ok(v) = std::env::var("SystemDrive") { + cmd.env("SystemDrive", v); + } + } + for (k, v) in env { + cmd.env(k, expand_env(v)); + } + // Inherit selected env vars from the OAB process (e.g. vars injected + // via Kubernetes envFrom). 
Keys already in [agent].env are skipped — + // explicit values take precedence. + let (agent_env, inherited_keys) = build_agent_env(env, inherit_env); + for (k, v) in &agent_env { + cmd.env(k, v); + } + if !agent_env.is_empty() { + let explicit_keys: Vec<&String> = env.keys().collect(); + tracing::warn!( + ?explicit_keys, + ?inherited_keys, + "[agent].env/inherit_env is set -- these values are accessible to the agent and could be exfiltrated via prompt injection" + ); + } + let mut proc = cmd + .spawn() + .map_err(|e| anyhow!("failed to spawn {command}: {e}"))?; + let child_pgid = proc.id().and_then(|pid| i32::try_from(pid).ok()); + + let stdout = proc.stdout.take().ok_or_else(|| anyhow!("no stdout"))?; + let stdin = proc.stdin.take().ok_or_else(|| anyhow!("no stdin"))?; + let stdin = Arc::new(Mutex::new(stdin)); + + // Capture agent stderr and log it (ACP spec: agents MAY write to stderr + // for logging; clients MAY capture or ignore it). + let stderr_handle = if let Some(stderr) = proc.stderr.take() { + let cmd_name = command.to_string(); + Some(tokio::spawn(async move { + let mut reader = BufReader::new(stderr); + let mut line = String::new(); + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + let trimmed = line.trim(); + if !trimmed.is_empty() { + let sanitized: String = trimmed.chars() + .filter(|c| !c.is_control() || *c == '\t') + .collect(); + if !sanitized.is_empty() { + tracing::warn!(agent = %cmd_name, "{sanitized}"); + } + } + } + Err(_) => break, + } + } + })) + } else { + None + }; + + let pending: Arc>>> = + Arc::new(Mutex::new(HashMap::new())); + let notify_tx: Arc>>> = + Arc::new(Mutex::new(None)); + + let reader_handle = tokio::spawn(run_reader_loop( + stdout, + stdin.clone(), + pending.clone(), + notify_tx.clone(), + )); + + Ok(Self { + _proc: proc, + child_pgid, + stdin, + next_id: AtomicU64::new(1), + pending, + notify_tx, + acp_session_id: None, + supports_load_session: false, + 
config_options: Vec::new(), + last_active: Instant::now(), + session_reset: false, + _reader_handle: reader_handle, + _stderr_handle: stderr_handle, + }) + } + + fn next_id(&self) -> u64 { + self.next_id.fetch_add(1, Ordering::Relaxed) + } + + pub(crate) async fn send_raw(&self, data: &str) -> Result<()> { + debug!(data = data.trim(), "acp_send"); + let mut w = self.stdin.lock().await; + w.write_all(data.as_bytes()).await?; + w.write_all(b"\n").await?; + w.flush().await?; + Ok(()) + } + + async fn send_request(&self, method: &str, params: Option) -> Result { + let id = self.next_id(); + let req = JsonRpcRequest::new(id, method, params); + let data = serde_json::to_string(&req)?; + + let (tx, rx) = oneshot::channel(); + self.pending.lock().await.insert(id, tx); + + self.send_raw(&data).await?; + + let timeout_secs = if method == "session/new" { 120 } else { 30 }; + let resp = tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), rx) + .await + .map_err(|_| anyhow!("timeout waiting for {method} response"))? 
+ .map_err(|_| anyhow!("channel closed waiting for {method}"))?; + + if let Some(err) = &resp.error { + return Err(anyhow!("{err}")); + } + Ok(resp) + } + + pub async fn initialize(&mut self) -> Result<()> { + let resp = self + .send_request( + "initialize", + Some(json!({ + "protocolVersion": 1, + "clientCapabilities": {}, + "clientInfo": {"name": "openab", "version": "0.1.0"}, + })), + ) + .await?; + + let result = resp.result.as_ref(); + let agent_name = result + .and_then(|r| r.get("agentInfo")) + .and_then(|a| a.get("name")) + .and_then(|n| n.as_str()) + .unwrap_or("unknown"); + self.supports_load_session = result + .and_then(|r| r.get("agentCapabilities")) + .and_then(|c| c.get("loadSession")) + .and_then(|v| v.as_bool()) + .unwrap_or(false); + info!( + agent = agent_name, + load_session = self.supports_load_session, + "initialized" + ); + Ok(()) + } + + pub async fn session_new(&mut self, cwd: &str) -> Result { + let resp = self + .send_request("session/new", Some(json!({"cwd": cwd, "mcpServers": []}))) + .await?; + + let session_id = resp + .result + .as_ref() + .and_then(|r| r.get("sessionId")) + .and_then(|s| s.as_str()) + .ok_or_else(|| anyhow!("no sessionId in session/new response"))? + .to_string(); + + info!(session_id = %session_id, "session created"); + self.acp_session_id = Some(session_id.clone()); + if let Some(result) = resp.result.as_ref() { + self.config_options = parse_config_options(result); + if !self.config_options.is_empty() { + info!(count = self.config_options.len(), "parsed configOptions"); + } + } + Ok(session_id) + } + + /// Set a config option (e.g. model, mode) via ACP session/set_config_option. + /// Returns the updated list of all config options. + pub async fn set_config_option( + &mut self, + config_id: &str, + value: &str, + ) -> Result> { + let session_id = self + .acp_session_id + .as_ref() + .ok_or_else(|| anyhow!("no session"))? 
+ .clone(); + + let resp = self + .send_request( + "session/set_config_option", + Some(json!({ + "sessionId": session_id, + "configId": config_id, + "value": value, + })), + ) + .await; + + match resp { + Ok(r) => { + if let Some(result) = r.result.as_ref() { + self.config_options = parse_config_options(result); + } + info!(config_id, value, "config option set"); + } + Err(_) => { + // Fall back: send as a slash command (e.g. "/model claude-sonnet-4") + let cmd = format!("/{config_id} {value}"); + info!( + cmd, + "set_config_option not supported, falling back to prompt" + ); + let _resp = self + .send_request( + "session/prompt", + Some(json!({ + "sessionId": session_id, + "prompt": [{"type": "text", "text": cmd}], + })), + ) + .await?; + for opt in &mut self.config_options { + if opt.id == config_id { + opt.current_value = value.to_string(); + } + } + } + } + + Ok(self.config_options.clone()) + } + + /// Send a prompt with content blocks (text and/or images) and return a receiver + /// for streaming notifications. The final message on the channel will have id set + /// (the prompt response). + pub async fn session_prompt( + &mut self, + content_blocks: Vec, + ) -> Result<(mpsc::UnboundedReceiver, u64)> { + self.last_active = Instant::now(); + + let session_id = self + .acp_session_id + .as_ref() + .ok_or_else(|| anyhow!("no session"))?; + + let (tx, rx) = mpsc::unbounded_channel(); + *self.notify_tx.lock().await = Some(tx); + + let id = self.next_id(); + + // Convert content blocks to JSON + let prompt_json: Vec = content_blocks.iter().map(|b| b.to_json()).collect(); + + let req = JsonRpcRequest::new( + id, + "session/prompt", + Some(json!({ + "sessionId": session_id, + "prompt": prompt_json, + })), + ); + let data = serde_json::to_string(&req)?; + + let (resp_tx, _resp_rx) = oneshot::channel(); + self.pending.lock().await.insert(id, resp_tx); + + self.send_raw(&data).await?; + Ok((rx, id)) + } + + /// Call after prompt streaming is done to clean up subscriber. 
+    pub async fn prompt_done(&mut self) {
+        *self.notify_tx.lock().await = None;
+        self.last_active = Instant::now();
+    }
+
+    /// Drop the pending entry for `request_id` and best-effort send
+    /// `session/cancel` as a JSON-RPC notification (no id; per ACP spec the
+    /// agent does not reply). Errors are swallowed: the agent process may
+    /// already be dead, in which case the stdin write fails harmlessly.
+    /// See #732.
+    pub async fn abandon_request(&self, request_id: u64) {
+        self.pending.lock().await.remove(&request_id);
+        // Without a session there is nothing to cancel on the agent side.
+        let Some(session_id) = self.acp_session_id.as_deref() else {
+            return;
+        };
+        let req = json!({
+            "jsonrpc": "2.0",
+            "method": "session/cancel",
+            "params": {"sessionId": session_id},
+        });
+        if let Ok(data) = serde_json::to_string(&req) {
+            let _ = self.send_raw(&data).await;
+        }
+    }
+
+    /// Return a clone of the stdin handle for lock-free cancel.
+    pub fn cancel_handle(&self) -> Arc<Mutex<ChildStdin>> {
+        Arc::clone(&self.stdin)
+    }
+
+    /// True while the stdout reader task is still running.
+    pub fn alive(&self) -> bool {
+        !self._reader_handle.is_finished()
+    }
+
+    /// Resume a previous session by ID. Returns Ok(()) if the agent accepted
+    /// the load, or an error if it failed (caller should fall back to session/new).
+    pub async fn session_load(&mut self, session_id: &str, cwd: &str) -> Result<()> {
+        // send_request turns a JSON-RPC `error` into Err, so any response
+        // that gets past `?` is a non-error reply and counts as success.
+        let resp = self
+            .send_request(
+                "session/load",
+                Some(json!({"sessionId": session_id, "cwd": cwd, "mcpServers": []})),
+            )
+            .await?;
+        info!(session_id, "session loaded");
+        self.acp_session_id = Some(session_id.to_string());
+        if let Some(result) = resp.result.as_ref() {
+            self.config_options = parse_config_options(result);
+        }
+        Ok(())
+    }
+
+    /// Kill the entire process group: SIGTERM → SIGKILL.
+    /// Uses std::thread (not tokio::spawn) so SIGKILL fires even during
+    /// runtime shutdown or panic unwinding.
+    fn kill_process_group(&mut self) {
+        // Nothing to do when no PID was recorded for the child.
+        let pgid = match self.child_pgid {
+            Some(pid) if pid > 0 => pid,
+            _ => return,
+        };
+        #[cfg(unix)]
+        {
+            // Stage 1: SIGTERM the process group
+            unsafe {
+                libc::kill(-pgid, libc::SIGTERM);
+            }
+            // Stage 2: SIGKILL after brief grace (std::thread survives runtime shutdown)
+            std::thread::spawn(move || {
+                std::thread::sleep(std::time::Duration::from_millis(1500));
+                unsafe {
+                    libc::kill(-pgid, libc::SIGKILL);
+                }
+            });
+        }
+        #[cfg(not(unix))]
+        {
+            let _ = pgid; // suppress unused warning on Windows
+        }
+    }
+}
+
+/// Dropping a connection aborts the stderr logger task and signals the
+/// agent's process group.
+impl Drop for AcpConnection {
+    fn drop(&mut self) {
+        if let Some(handle) = self._stderr_handle.take() {
+            handle.abort();
+        }
+        self.kill_process_group();
+    }
+}
+
+// Unit tests for the pure helpers: permission option selection, permission
+// response construction, and agent environment assembly.
+#[cfg(test)]
+mod tests {
+    use super::{build_agent_env, build_permission_response, pick_best_option};
+    use serde_json::json;
+
+    #[test]
+    fn picks_allow_always_over_other_options() {
+        let options = vec![
+            json!({"kind": "allow_once", "optionId": "once"}),
+            json!({"kind": "allow_always", "optionId": "always"}),
+            json!({"kind": "reject_once", "optionId": "reject"}),
+        ];
+
+        assert_eq!(pick_best_option(&options), Some("always".to_string()));
+    }
+
+    #[test]
+    fn falls_back_to_first_unknown_non_reject_kind() {
+        let options = vec![
+            json!({"kind": "reject_once", "optionId": "reject"}),
+            json!({"kind": "workspace_write", "optionId": "workspace-write"}),
+        ];
+
+        assert_eq!(
+            pick_best_option(&options),
+            Some("workspace-write".to_string())
+        );
+    }
+
+    // With several allow_always options, the first one listed wins.
+    #[test]
+    fn selects_bypass_permissions_for_exit_plan_mode() {
+        let options = vec![
+            json!({"optionId": "bypassPermissions", "kind": "allow_always"}),
+            json!({"optionId": "acceptEdits", "kind": "allow_always"}),
+            json!({"optionId": "default", "kind": "allow_once"}),
+            json!({"optionId": "plan", "kind": "reject_once"}),
+        ];
+
+        assert_eq!(
+            pick_best_option(&options),
+            Some("bypassPermissions".to_string())
+        );
+    }
+
+    #[test]
+    fn returns_none_when_only_reject_options_exist() {
+        let options = vec![
+            json!({"kind": "reject_once", "optionId": "reject-once"}),
+            json!({"kind": "reject_always", "optionId": "reject-always"}),
+        ];
+
+        assert_eq!(pick_best_option(&options), None);
+    }
+
+    #[test]
+    fn builds_cancelled_outcome_when_no_selectable_option_exists() {
+        let response = build_permission_response(Some(&json!({
+            "options": [
+                {"kind": "reject_once", "optionId": "reject-once"}
+            ]
+        })));
+
+        assert_eq!(response, json!({"outcome": {"outcome": "cancelled"}}));
+    }
+
+    #[test]
+    fn builds_cancelled_when_options_array_is_empty() {
+        let response = build_permission_response(Some(&json!({
+            "options": []
+        })));
+
+        assert_eq!(response, json!({"outcome": {"outcome": "cancelled"}}));
+    }
+
+    #[test]
+    fn falls_back_to_allow_always_when_options_are_missing() {
+        let response = build_permission_response(Some(&json!({
+            "toolCall": {"title": "legacy"}
+        })));
+
+        assert_eq!(
+            response,
+            json!({"outcome": {"outcome": "selected", "optionId": "allow_always"}})
+        );
+    }
+
+    #[test]
+    fn falls_back_to_allow_always_when_params_is_none() {
+        let response = build_permission_response(None);
+
+        assert_eq!(
+            response,
+            json!({"outcome": {"outcome": "selected", "optionId": "allow_always"}})
+        );
+    }
+
+    // The env tests below mutate the process environment; each uses its own
+    // variable name.
+    #[test]
+    fn explicit_env_takes_precedence_over_inherit_env() {
+        let key = "OAB_TEST_PRECEDENCE";
+        std::env::set_var(key, "from_process");
+        let mut explicit = std::collections::HashMap::new();
+        explicit.insert(key.to_string(), "from_config".to_string());
+        let inherit = vec![key.to_string()];
+
+        let (result, inherited) = build_agent_env(&explicit, &inherit);
+
+        assert_eq!(result.get(key).unwrap(), "from_config");
+        assert!(!inherited.contains(&key.to_string()));
+        std::env::remove_var(key);
+    }
+
+    #[test]
+    fn inherit_env_copies_from_process() {
+        let key = "OAB_TEST_INHERIT";
+        std::env::set_var(key, "process_value");
+        let explicit = std::collections::HashMap::new();
+        let inherit = vec![key.to_string()];
+
+        let (result, inherited) = build_agent_env(&explicit, &inherit);
+
+        assert_eq!(result.get(key).unwrap(), "process_value");
+        assert!(inherited.contains(&key.to_string()));
+        std::env::remove_var(key);
+    }
+
+    #[test]
+    fn inherit_env_skips_missing_vars() {
+        let explicit = std::collections::HashMap::new();
+        let inherit = vec!["OAB_TEST_NONEXISTENT_VAR_12345".to_string()];
+
+        let (result, inherited) = build_agent_env(&explicit, &inherit);
+
+        assert!(!result.contains_key("OAB_TEST_NONEXISTENT_VAR_12345"));
+        assert!(inherited.is_empty());
+    }
+}
+
+// Tests that drive `run_reader_loop` over in-memory duplex pipes.
+#[cfg(test)]
+mod reader_loop_tests {
+    use super::*;
+    use std::collections::HashMap;
+    use std::sync::Arc;
+    use tokio::io::{duplex, AsyncWriteExt};
+    use tokio::sync::{mpsc, oneshot, Mutex};
+
+    /// #732 stale-id path: when a response arrives for an id the broker has
+    /// already abandoned, the reader must (a) not crash, (b) leave `pending`
+    /// untouched, and (c) still forward the message to whoever is currently
+    /// subscribed — the adapter recv loop is responsible for filtering by
+    /// request_id so the stray response never leaks into the next prompt.
+    #[tokio::test]
+    async fn stale_id_response_is_forwarded_without_pending_entry() {
+        // In-memory pipes stand in for the agent's stdout and stdin.
+        let (mut agent_stdout_writer, agent_stdout_reader) = duplex(8 * 1024);
+        let (agent_stdin_writer, _agent_stdin_reader) = duplex(8 * 1024);
+
+        // `pending` starts empty, so id 42 below has no matching entry.
+        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcMessage>>>> =
+            Arc::new(Mutex::new(HashMap::new()));
+        let notify_tx: Arc<Mutex<Option<mpsc::UnboundedSender<JsonRpcMessage>>>> =
+            Arc::new(Mutex::new(None));
+
+        let (sub_tx, mut sub_rx) = mpsc::unbounded_channel();
+        *notify_tx.lock().await = Some(sub_tx);
+
+        let writer = Arc::new(Mutex::new(agent_stdin_writer));
+        let handle = tokio::spawn(run_reader_loop(
+            agent_stdout_reader,
+            writer,
+            pending.clone(),
+            notify_tx.clone(),
+        ));
+
+        let stale = b"{\"jsonrpc\":\"2.0\",\"id\":42,\"result\":{\"stopReason\":\"ok\"}}\n";
+        agent_stdout_writer.write_all(stale).await.unwrap();
+        agent_stdout_writer.flush().await.unwrap();
+
+        let forwarded = tokio::time::timeout(
+            std::time::Duration::from_secs(2),
+            sub_rx.recv(),
+        )
+        .await
+        .expect("subscriber should receive stale message before timeout")
+        .expect("subscriber channel should not be closed");
+        assert_eq!(forwarded.id, Some(42));
+        assert!(pending.lock().await.is_empty());
+
+        // Closing the writer yields EOF, which ends the reader loop.
+        drop(agent_stdout_writer);
+        handle.await.unwrap();
+    }
+
+    /// Matched-id path: when a response's id is in `pending`, the loop must
+    /// resolve the oneshot AND forward a copy to the subscriber so the
+    /// adapter's recv loop sees the completion. Guards against regressions
+    /// that would suppress the forward branch while keeping resolve.
+    #[tokio::test]
+    async fn matched_id_response_resolves_pending_and_forwards() {
+        // In-memory pipes stand in for the agent's stdout and stdin.
+        let (mut agent_stdout_writer, agent_stdout_reader) = duplex(8 * 1024);
+        let (agent_stdin_writer, _agent_stdin_reader) = duplex(8 * 1024);
+
+        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcMessage>>>> =
+            Arc::new(Mutex::new(HashMap::new()));
+        let notify_tx: Arc<Mutex<Option<mpsc::UnboundedSender<JsonRpcMessage>>>> =
+            Arc::new(Mutex::new(None));
+
+        // Register id 7 as in-flight before the response is written.
+        let (resp_tx, resp_rx) = oneshot::channel();
+        pending.lock().await.insert(7, resp_tx);
+
+        let (sub_tx, mut sub_rx) = mpsc::unbounded_channel();
+        *notify_tx.lock().await = Some(sub_tx);
+
+        let writer = Arc::new(Mutex::new(agent_stdin_writer));
+        let handle = tokio::spawn(run_reader_loop(
+            agent_stdout_reader,
+            writer,
+            pending.clone(),
+            notify_tx.clone(),
+        ));
+
+        let payload = b"{\"jsonrpc\":\"2.0\",\"id\":7,\"result\":{\"stopReason\":\"end_turn\"}}\n";
+        agent_stdout_writer.write_all(payload).await.unwrap();
+        agent_stdout_writer.flush().await.unwrap();
+
+        let resolved = tokio::time::timeout(std::time::Duration::from_secs(2), resp_rx)
+            .await
+            .expect("oneshot should resolve")
+            .expect("oneshot should not be cancelled");
+        assert_eq!(resolved.id, Some(7));
+
+        let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), sub_rx.recv())
+            .await
+            .expect("subscriber should receive forwarded copy")
+            .expect("subscriber channel should not be closed");
+        assert_eq!(forwarded.id, Some(7));
+        assert!(pending.lock().await.is_empty());
+
+        // Closing the writer yields EOF, which ends the reader loop.
+        drop(agent_stdout_writer);
+        handle.await.unwrap();
+    }
+}
diff --git a/crates/openab-core/src/acp/mod.rs b/crates/openab-core/src/acp/mod.rs
new file mode 100644
index 000000000..b6a60eaaf
--- /dev/null
+++ b/crates/openab-core/src/acp/mod.rs
@@ -0,0 +1,9 @@
+#[cfg(feature = "agentcore")]
+pub mod agentcore;
+pub mod connection;
+pub mod pool;
+pub mod protocol;
+
+pub use connection::ContentBlock;
+pub use pool::SessionPool;
+pub use protocol::{classify_notification, AcpEvent};
diff --git a/crates/openab-core/src/acp/pool.rs b/crates/openab-core/src/acp/pool.rs new
file mode 100644 index 000000000..d97397169 --- /dev/null +++ b/crates/openab-core/src/acp/pool.rs @@ -0,0 +1,622 @@ +use crate::acp::connection::AcpConnection; +use crate::acp::protocol::ConfigOption; +use crate::config::AgentConfig; +use anyhow::{anyhow, Result}; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tokio::sync::{Mutex, RwLock}; +use tokio::time::Instant; +use tracing::{info, warn}; + +/// Combined state protected by a single lock to prevent deadlocks. +/// Lock ordering: never await a per-connection mutex while holding `state`. +struct PoolState { + /// Active connections: thread_key → AcpConnection handle. + active: HashMap>>, + /// Lock-free cancel handles: thread_key → (stdin, session_id). + /// Stored separately so cancel can work without locking the connection. + cancel_handles: HashMap>, String)>, + /// Suspended sessions: thread_key → ACP sessionId. + /// Used at runtime to decide which thread can be resumed via `session/load` + /// because it no longer has a live in-memory connection. + suspended: HashMap, + /// Persisted resumable sessions: thread_key → ACP sessionId. + /// Includes both suspended sessions and active sessions so a process restart + /// can recover any live thread via `session/load`. + persisted: HashMap, + /// Serializes create/resume work per thread so rapid same-thread requests + /// cannot race each other into duplicate `session/load` attempts. + creating: HashMap>>, + /// Per-session working directory overrides (from control directives). + /// thread_key → canonical workspace path. 
+ session_workdirs: HashMap, +} + +pub struct SessionPool { + state: RwLock, + config: AgentConfig, + max_sessions: usize, + mapping_path: PathBuf, + meta_path: PathBuf, +} + +type EvictionCandidate = (String, Arc>, Instant, Option); + +fn remove_if_same_handle( + map: &mut HashMap>>, + key: &str, + expected: &Arc>, +) -> Option>> { + let should_remove = map + .get(key) + .is_some_and(|current| Arc::ptr_eq(current, expected)); + if should_remove { + map.remove(key) + } else { + None + } +} + +fn get_or_insert_gate(map: &mut HashMap>>, key: &str) -> Arc> { + map.entry(key.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone() +} + +impl SessionPool { + pub fn new(config: AgentConfig, max_sessions: usize) -> Self { + let openab_dir = std::env::var("HOME") + .map(PathBuf::from) + .unwrap_or_else(|_| PathBuf::from("/tmp")) + .join(".openab"); + let _ = std::fs::create_dir_all(&openab_dir); + let mapping_path = openab_dir.join("thread_map.json"); + let meta_path = openab_dir.join("session_meta.json"); + let suspended = Self::load_mapping(&mapping_path); + let session_workdirs = Self::load_mapping(&meta_path); + Self { + state: RwLock::new(PoolState { + active: HashMap::new(), + cancel_handles: HashMap::new(), + persisted: suspended.clone(), + suspended, + creating: HashMap::new(), + session_workdirs, + }), + config, + max_sessions, + mapping_path, + meta_path, + } + } + + fn load_mapping(path: &Path) -> HashMap { + match std::fs::read_to_string(path) { + Ok(data) => serde_json::from_str(&data).unwrap_or_else(|e| { + warn!(path = %path.display(), error = %e, "corrupt mapping file, starting fresh"); + HashMap::new() + }), + Err(_) => HashMap::new(), + } + } + + fn save_mapping(&self, persisted: &HashMap) { + let data = match serde_json::to_string_pretty(persisted) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to serialize thread mapping"); + return; + } + }; + let tmp = self.mapping_path.with_extension("json.tmp"); + if let Err(e) = + 
std::fs::write(&tmp, &data).and_then(|_| std::fs::rename(&tmp, &self.mapping_path)) + { + warn!(path = %self.mapping_path.display(), error = %e, "failed to persist thread mapping"); + } + } + + fn save_meta(&self, workdirs: &HashMap) { + let data = match serde_json::to_string_pretty(workdirs) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to serialize session metadata"); + return; + } + }; + let tmp = self.meta_path.with_extension("json.tmp"); + if let Err(e) = + std::fs::write(&tmp, &data).and_then(|_| std::fs::rename(&tmp, &self.meta_path)) + { + warn!(path = %self.meta_path.display(), error = %e, "failed to persist session metadata"); + } + } + + /// Check if session state exists for this thread (active, suspended, or persisted). + #[allow(dead_code)] + pub async fn has_active_session(&self, thread_id: &str) -> bool { + let state = self.state.read().await; + // Any of these means the thread already has session state. + if state.suspended.contains_key(thread_id) || state.persisted.contains_key(thread_id) { + return true; + } + if let Some(conn) = state.active.get(thread_id) { + match conn.try_lock() { + Ok(c) => return c.alive(), + Err(_) => return true, // lock held = connection busy streaming = alive + } + } + false + } + + pub async fn get_or_create( + &self, + thread_id: &str, + working_dir_override: Option<&str>, + ) -> Result { + let create_gate = { + let mut state = self.state.write().await; + get_or_insert_gate(&mut state.creating, thread_id) + }; + let _create_guard = create_gate.lock().await; + + let (existing, saved_session_id) = { + let state = self.state.read().await; + ( + state.active.get(thread_id).cloned(), + state.suspended.get(thread_id).cloned(), + ) + }; + + let had_existing = existing.is_some(); + let mut saved_session_id = saved_session_id; + if let Some(conn) = existing.clone() { + let conn = conn.lock().await; + if conn.alive() { + return Ok(false); + } + if saved_session_id.is_none() { + saved_session_id = 
conn.acp_session_id.clone(); + } + } + + // Snapshot active handles so we can inspect them outside the state lock. + let snapshot: Vec<(String, Arc>)> = { + let state = self.state.read().await; + state + .active + .iter() + .map(|(k, v)| (k.clone(), Arc::clone(v))) + .collect() + }; + + let mut eviction_candidate: Option = None; + let mut skipped_locked_candidates = 0usize; + for (key, conn) in snapshot { + if key == thread_id { + continue; + } + let conn_handle = Arc::clone(&conn); + let Ok(conn) = conn.try_lock() else { + skipped_locked_candidates += 1; + continue; + }; + let candidate = ( + key, + conn_handle, + conn.last_active, + conn.acp_session_id.clone(), + ); + match &eviction_candidate { + Some((_, _, oldest_last_active, _)) if candidate.2 >= *oldest_last_active => {} + _ => eviction_candidate = Some(candidate), + } + } + + // Resolve effective working directory: stored per-session > explicit override > global config. + // Stored value has highest priority to enforce immutability (ADR §4.5). + let stored_workdir = { + let state = self.state.read().await; + state.session_workdirs.get(thread_id).cloned() + }; + + let effective_workdir = if let Some(stored) = stored_workdir { + stored + } else if let Some(wd) = working_dir_override { + wd.to_string() + } else { + self.config.working_dir.clone() + }; + + // Build the replacement connection outside the state lock so one stuck + // initialization does not block all unrelated sessions. 
+ let mut new_conn = AcpConnection::spawn( + &self.config.command, + &self.config.args, + &effective_workdir, + &self.config.env, + &self.config.inherit_env, + ) + .await?; + + new_conn.initialize().await?; + + let mut resumed = false; + if let Some(ref sid) = saved_session_id { + if new_conn.supports_load_session { + match new_conn.session_load(sid, &effective_workdir).await { + Ok(()) => { + info!(thread_id, session_id = %sid, "session resumed via session/load"); + resumed = true; + } + Err(e) => { + warn!(thread_id, session_id = %sid, error = %e, "session/load failed, creating new session"); + } + } + } + } + + if !resumed { + new_conn.session_new(&effective_workdir).await?; + // Surface the reset banner both for restored sessions and for stale + // live entries that died before we could recover a resumable + // session id. In both cases the caller is continuing after an + // unexpected session loss. + if had_existing || saved_session_id.is_some() { + new_conn.session_reset = true; + } + } + + let cancel_handle = new_conn.cancel_handle(); + let cancel_session_id = new_conn.acp_session_id.clone().unwrap_or_default(); + let new_conn = Arc::new(Mutex::new(new_conn)); + + let mut state = self.state.write().await; + + // Another task may have created a healthy connection while we were + // initializing this one. 
+ if let Some(existing) = state.active.get(thread_id).cloned() { + let Ok(existing) = existing.try_lock() else { + return Ok(false); + }; + if existing.alive() { + return Ok(false); + } + warn!(thread_id, "stale connection, rebuilding"); + drop(existing); + state.active.remove(thread_id); + state.cancel_handles.remove(thread_id); + } + + if state.active.len() >= self.max_sessions { + if let Some((key, expected_conn, _, sid)) = eviction_candidate { + if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() { + state.cancel_handles.remove(&key); + info!(evicted = %key, "pool full, suspending oldest idle session"); + if let Some(sid) = sid { + state.persisted.insert(key.clone(), sid.clone()); + state.suspended.insert(key, sid); + } else { + state.persisted.remove(&key); + } + } else { + warn!(evicted = %key, "pool full but eviction candidate changed before removal"); + } + } else if skipped_locked_candidates > 0 { + warn!( + max_sessions = self.max_sessions, + skipped_locked_candidates, + "pool full but all other sessions were busy during eviction scan" + ); + } + } + + if state.active.len() >= self.max_sessions { + return Err(anyhow!("pool exhausted ({} sessions)", self.max_sessions)); + } + + if cancel_session_id.is_empty() { + state.persisted.remove(thread_id); + } else { + state + .persisted + .insert(thread_id.to_string(), cancel_session_id.clone()); + } + state.suspended.remove(thread_id); + state.active.insert(thread_id.to_string(), new_conn); + if !cancel_session_id.is_empty() { + state + .cancel_handles + .insert(thread_id.to_string(), (cancel_handle, cancel_session_id)); + } + self.save_mapping(&state.persisted); + + // Persist workspace override only after session spawn succeeded (口渡 F2). 
+ if working_dir_override.is_some() { + state + .session_workdirs + .entry(thread_id.to_string()) + .or_insert_with(|| effective_workdir.clone()); + self.save_meta(&state.session_workdirs); + } + + // Return true only for genuinely new sessions — not resumed or reconnected ones. + // A session with prior state (saved_session_id or had_existing) is a resume, + // even if we had to spawn a new ACP process. ADR §2.2: directives are first-message-only. + let is_fresh = !had_existing && saved_session_id.is_none(); + Ok(is_fresh) + } + + /// Get mutable access to a connection. Caller must have called get_or_create first. + /// + /// Only the per-connection `Mutex` is held during `f`; the pool-level + /// `RwLock` is acquired briefly (read-only) to look up the `Arc` and then + /// released, so other connections can be used concurrently. + pub async fn with_connection(&self, thread_id: &str, f: F) -> Result + where + F: for<'a> FnOnce( + &'a mut AcpConnection, + ) -> std::pin::Pin< + Box> + Send + 'a>, + >, + { + let conn = { + let state = self.state.read().await; + state + .active + .get(thread_id) + .cloned() + .ok_or_else(|| anyhow!("no connection for thread {thread_id}"))? + }; + + let mut conn = conn.lock().await; + f(&mut conn).await + } + + /// Get cached configOptions for a session (e.g. available models). + pub async fn get_config_options(&self, thread_id: &str) -> Vec { + let state = self.state.read().await; + let conn = match state.active.get(thread_id) { + Some(c) => c.clone(), + None => return Vec::new(), + }; + drop(state); + let conn = conn.lock().await; + conn.config_options.clone() + } + + /// Set a config option (e.g. model) via ACP and return updated options. + pub async fn set_config_option( + &self, + thread_id: &str, + config_id: &str, + value: &str, + ) -> Result> { + let conn = { + let state = self.state.read().await; + state + .active + .get(thread_id) + .cloned() + .ok_or_else(|| anyhow!("no connection for thread {thread_id}"))? 
+ }; + let mut conn = conn.lock().await; + conn.set_config_option(config_id, value).await + } + + /// Cancel the current in-flight operation for a session. + /// Uses pre-stored cancel handles to avoid locking the connection (which is held during streaming). + pub async fn cancel_session(&self, thread_id: &str) -> Result<()> { + let (stdin, session_id) = { + let state = self.state.read().await; + state + .cancel_handles + .get(thread_id) + .cloned() + .ok_or_else(|| anyhow!("no session for thread {thread_id}"))? + }; + let data = serde_json::to_string(&serde_json::json!({ + "jsonrpc": "2.0", + "method": "session/cancel", + "params": {"sessionId": session_id} + }))?; + tracing::info!(session_id, "sending session/cancel"); + use tokio::io::AsyncWriteExt; + let mut w = stdin.lock().await; + w.write_all(data.as_bytes()).await?; + w.write_all(b"\n").await?; + w.flush().await?; + Ok(()) + } + + /// Reset a session: cancel any in-flight operation, remove the active connection, + /// and clear all suspended state. The ACP process will be killed once the last + /// Arc reference is dropped (after streaming finishes). The next message will + /// trigger a fresh `get_or_create` with a new ACP session. + pub async fn reset_session(&self, thread_id: &str) -> Result<()> { + // Send session/cancel via the lock-free stdin handle first. + // This stops in-flight streaming even while with_connection() holds the + // connection mutex, so the old process finishes promptly. 
+ if let Some((stdin, session_id)) = { + let state = self.state.read().await; + state.cancel_handles.get(thread_id).cloned() + } { + let data = serde_json::to_string(&serde_json::json!({ + "jsonrpc": "2.0", + "method": "session/cancel", + "params": {"sessionId": session_id} + }))?; + tracing::info!(session_id, "reset: sending session/cancel"); + use tokio::io::AsyncWriteExt; + let mut w = stdin.lock().await; + let _ = w.write_all(data.as_bytes()).await; + let _ = w.write_all(b"\n").await; + let _ = w.flush().await; + } + + let mut state = self.state.write().await; + let had_active = state.active.remove(thread_id).is_some(); + state.cancel_handles.remove(thread_id); + state.suspended.remove(thread_id); + state.persisted.remove(thread_id); + state.creating.remove(thread_id); + state.session_workdirs.remove(thread_id); + self.save_mapping(&state.persisted); + self.save_meta(&state.session_workdirs); + if had_active { + info!(thread_id, "session reset"); + Ok(()) + } else { + Err(anyhow!("no session for thread {thread_id}")) + } + } + + pub async fn cleanup_idle(&self, ttl_secs: u64) { + let cutoff = Instant::now() - std::time::Duration::from_secs(ttl_secs); + + let snapshot: Vec<(String, Arc>)> = { + let state = self.state.read().await; + state + .active + .iter() + .map(|(k, v)| (k.clone(), Arc::clone(v))) + .collect() + }; + + let mut stale = Vec::new(); + for (key, conn) in snapshot { + // Skip active sessions for this cleanup round instead of waiting on + // their per-connection mutex. A busy session is not idle. 
+ let conn_handle = Arc::clone(&conn); + let Ok(conn) = conn.try_lock() else { + continue; + }; + if conn.last_active < cutoff || !conn.alive() { + stale.push((key, conn_handle, conn.acp_session_id.clone())); + } + } + + if stale.is_empty() { + return; + } + + let mut state = self.state.write().await; + for (key, expected_conn, sid) in stale { + if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() { + info!(thread_id = %key, "cleaning up idle session"); + state.cancel_handles.remove(&key); + if let Some(sid) = sid { + state.persisted.insert(key.clone(), sid.clone()); + state.suspended.insert(key, sid); + } else { + state.persisted.remove(&key); + state.session_workdirs.remove(&key); + } + } + } + self.save_mapping(&state.persisted); + self.save_meta(&state.session_workdirs); + } + + pub async fn shutdown(&self) { + // Snapshot active handles, then drop state lock before awaiting + // per-connection mutexes (lock ordering: never hold state while + // awaiting a connection lock). 
+ let snapshot: Vec<(String, Arc>)> = { + let state = self.state.read().await; + state + .active + .iter() + .map(|(k, v)| (k.clone(), Arc::clone(v))) + .collect() + }; + + let mut session_ids: Vec<(String, String)> = Vec::new(); + for (key, conn) in snapshot { + let conn = conn.lock().await; + if let Some(sid) = conn.acp_session_id.clone() { + session_ids.push((key, sid)); + } + } + + let mut state = self.state.write().await; + for (key, sid) in session_ids { + state.persisted.insert(key.clone(), sid.clone()); + state.suspended.insert(key, sid); + } + self.save_mapping(&state.persisted); + let count = state.active.len(); + state.active.clear(); + state.cancel_handles.clear(); + info!(count, "pool shutdown complete"); + } +} + +#[cfg(test)] +mod tests { + use super::{get_or_insert_gate, remove_if_same_handle}; + use std::collections::HashMap; + use std::sync::Arc; + use tokio::sync::Mutex; + + #[test] + fn remove_if_same_handle_removes_matching_entry() { + let expected = Arc::new(Mutex::new(1_u8)); + let mut map = HashMap::from([("thread".to_string(), Arc::clone(&expected))]); + + let removed = remove_if_same_handle(&mut map, "thread", &expected); + + assert!(removed.is_some()); + assert!(map.is_empty()); + } + + #[test] + fn remove_if_same_handle_keeps_replaced_entry() { + let stale = Arc::new(Mutex::new(1_u8)); + let fresh = Arc::new(Mutex::new(2_u8)); + let mut map = HashMap::from([("thread".to_string(), Arc::clone(&fresh))]); + + let removed = remove_if_same_handle(&mut map, "thread", &stale); + + assert!(removed.is_none()); + let current = map.get("thread").expect("entry should remain"); + assert!(Arc::ptr_eq(current, &fresh)); + } + + #[test] + fn get_or_insert_gate_reuses_gate_for_same_thread() { + let mut map = HashMap::new(); + + let first = get_or_insert_gate(&mut map, "thread"); + let second = get_or_insert_gate(&mut map, "thread"); + + assert!(Arc::ptr_eq(&first, &second)); + assert_eq!(map.len(), 1); + } + + #[test] + fn 
persisted_mapping_can_include_active_and_suspended_sessions() { + let persisted = HashMap::from([ + ("active-thread".to_string(), "session-active".to_string()), + ( + "suspended-thread".to_string(), + "session-suspended".to_string(), + ), + ]); + + let serialized = + serde_json::to_string_pretty(&persisted).expect("serialize persisted mapping"); + let roundtrip: HashMap = + serde_json::from_str(&serialized).expect("deserialize persisted mapping"); + + assert_eq!( + roundtrip.get("active-thread"), + Some(&"session-active".to_string()) + ); + assert_eq!( + roundtrip.get("suspended-thread"), + Some(&"session-suspended".to_string()) + ); + } +} diff --git a/crates/openab-core/src/acp/protocol.rs b/crates/openab-core/src/acp/protocol.rs new file mode 100644 index 000000000..099d98b71 --- /dev/null +++ b/crates/openab-core/src/acp/protocol.rs @@ -0,0 +1,406 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +// --- Outgoing --- + +#[derive(Debug, Serialize)] +pub struct JsonRpcRequest { + pub jsonrpc: &'static str, + pub id: u64, + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +impl JsonRpcRequest { + pub fn new(id: u64, method: impl Into, params: Option) -> Self { + Self { + jsonrpc: "2.0", + id, + method: method.into(), + params, + } + } +} + +#[derive(Debug, Serialize)] +pub struct JsonRpcResponse { + pub jsonrpc: &'static str, + pub id: u64, + pub result: Value, +} + +impl JsonRpcResponse { + pub fn new(id: u64, result: Value) -> Self { + Self { + jsonrpc: "2.0", + id, + result, + } + } +} + +// --- Incoming --- + +#[derive(Debug, Deserialize)] +pub struct JsonRpcMessage { + pub id: Option, + pub method: Option, + pub result: Option, + pub error: Option, + pub params: Option, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct JsonRpcError { + pub code: i64, + pub message: String, + /// Optional structured data from the agent (JSON-RPC `error.data`). 
+ /// Agents like codex-acp include `{"message": "...", "codex_error_info": "..."}`. + pub data: Option, +} + +impl JsonRpcError { + /// Extract a human-readable detail from `error.data.message` if present. + /// + /// The `"message"` key is a convention used by codex-acp and aligns with + /// common JSON-RPC practice, but is NOT mandated by the ACP spec. + /// Other agents may use `"detail"`, `"reason"`, etc. — extend here if needed. + pub fn data_message(&self) -> Option<&str> { + self.data + .as_ref() + .and_then(|d| d.get("message")) + .and_then(|m| m.as_str()) + } +} + +impl std::fmt::Display for JsonRpcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSON-RPC error {}: {}", self.code, self.message)?; + if let Some(detail) = self.data_message() { + write!(f, " — {detail}")?; + } + Ok(()) + } +} + +// --- ACP configOptions (session-level configuration) --- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConfigOptionValue { + pub value: String, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ConfigOption { + pub id: String, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub category: Option, + #[serde(rename = "type")] + pub option_type: String, + pub current_value: String, + pub options: Vec, +} + +/// Extract configOptions from a JSON-RPC result value. +/// Supports standard `configOptions` and kiro-cli's `models`/`modes` fallback. 
+pub fn parse_config_options(result: &Value) -> Vec { + if let Some(opts) = result + .get("configOptions") + .and_then(|v| serde_json::from_value::>(v.clone()).ok()) + { + if !opts.is_empty() { + return opts; + } + } + + // Kiro-cli fallback: parse models/modes format + let mut options = Vec::new(); + + if let Some(models) = result.get("models") { + let current = models + .get("currentModelId") + .and_then(|v| v.as_str()) + .unwrap_or(""); + if let Some(available) = models.get("availableModels").and_then(|v| v.as_array()) { + let values: Vec = available + .iter() + .filter_map(|m| { + let id = m + .get("modelId") + .or_else(|| m.get("id")) + .and_then(|v| v.as_str())?; + let name = m.get("name").and_then(|v| v.as_str()).unwrap_or(id); + Some(ConfigOptionValue { + value: id.to_string(), + name: name.to_string(), + description: m + .get("description") + .and_then(|v| v.as_str()) + .map(String::from), + }) + }) + .collect(); + if !values.is_empty() { + options.push(ConfigOption { + id: "model".to_string(), + name: "Model".to_string(), + description: Some("AI model selection".to_string()), + category: Some("model".to_string()), + option_type: "enum".to_string(), + current_value: current.to_string(), + options: values, + }); + } + } + } + + if let Some(modes) = result.get("modes") { + let current = modes + .get("currentModeId") + .and_then(|v| v.as_str()) + .unwrap_or(""); + if let Some(available) = modes.get("availableModes").and_then(|v| v.as_array()) { + let values: Vec = available + .iter() + .filter_map(|m| { + let id = m.get("id").and_then(|v| v.as_str())?; + let name = m.get("name").and_then(|v| v.as_str()).unwrap_or(id); + Some(ConfigOptionValue { + value: id.to_string(), + name: name.to_string(), + description: m + .get("description") + .and_then(|v| v.as_str()) + .map(String::from), + }) + }) + .collect(); + if !values.is_empty() { + options.push(ConfigOption { + id: "agent".to_string(), + name: "Agent".to_string(), + description: Some("Agent mode 
selection".to_string()), + category: Some("agent".to_string()), + option_type: "enum".to_string(), + current_value: current.to_string(), + options: values, + }); + } + } + } + + options +} + +// --- ACP notification classification --- + +#[derive(Debug)] +pub enum AcpEvent { + Text(String), + Thinking, + ToolStart { + id: String, + title: String, + }, + ToolDone { + id: String, + title: String, + status: String, + }, + ConfigUpdate { + options: Vec, + }, + Status, +} + +pub fn classify_notification(msg: &JsonRpcMessage) -> Option { + let params = msg.params.as_ref()?; + let update = params.get("update")?; + let session_update = update.get("sessionUpdate")?.as_str()?; + + // toolCallId is the stable identity across tool_call → tool_call_update + // events for the same tool invocation. claude-agent-acp emits the first + // event before the input fields are streamed in (so the title falls back + // to "Terminal" / "Edit" / etc.) and refines them in a later + // tool_call_update; without the id we can't tell those events belong to + // the same call and end up rendering placeholder + refined as two + // separate lines. 
+ let tool_id = update + .get("toolCallId") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + match session_update { + "agent_message_chunk" => { + let text = update.get("content")?.get("text")?.as_str()?; + Some(AcpEvent::Text(text.to_string())) + } + "agent_thought_chunk" => Some(AcpEvent::Thinking), + "tool_call" => { + let title = update + .get("title") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + Some(AcpEvent::ToolStart { id: tool_id, title }) + } + "tool_call_update" => { + let title = update + .get("title") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let status = update + .get("status") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + if status == "completed" || status == "failed" { + Some(AcpEvent::ToolDone { + id: tool_id, + title, + status, + }) + } else { + Some(AcpEvent::ToolStart { id: tool_id, title }) + } + } + "plan" => Some(AcpEvent::Status), + "config_option_update" => { + let options = parse_config_options(update); + Some(AcpEvent::ConfigUpdate { options }) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_standard_config_options() { + let result = json!({ + "configOptions": [{ + "id": "model", + "name": "Model", + "type": "enum", + "currentValue": "claude-sonnet-4", + "options": [ + {"value": "claude-sonnet-4", "name": "Sonnet 4"}, + {"value": "claude-opus-4", "name": "Opus 4"} + ] + }] + }); + let opts = parse_config_options(&result); + assert_eq!(opts.len(), 1); + assert_eq!(opts[0].id, "model"); + assert_eq!(opts[0].current_value, "claude-sonnet-4"); + assert_eq!(opts[0].options.len(), 2); + } + + #[test] + fn parse_kiro_models_fallback() { + let result = json!({ + "models": { + "currentModelId": "m1", + "availableModels": [ + {"modelId": "m1", "name": "Model One"}, + {"modelId": "m2", "name": "Model Two"} + ] + } + }); + let opts = parse_config_options(&result); + assert_eq!(opts.len(), 1); + 
assert_eq!(opts[0].id, "model"); + assert_eq!(opts[0].category.as_deref(), Some("model")); + assert_eq!(opts[0].current_value, "m1"); + assert_eq!(opts[0].options.len(), 2); + } + + #[test] + fn parse_kiro_modes_fallback() { + let result = json!({ + "modes": { + "currentModeId": "default", + "availableModes": [ + {"id": "default", "name": "Default"}, + {"id": "planner", "name": "Planner"} + ] + } + }); + let opts = parse_config_options(&result); + assert_eq!(opts.len(), 1); + assert_eq!(opts[0].id, "agent"); + assert_eq!(opts[0].category.as_deref(), Some("agent")); + assert_eq!(opts[0].current_value, "default"); + } + + #[test] + fn parse_kiro_models_and_modes() { + let result = json!({ + "models": { + "currentModelId": "m1", + "availableModels": [{"modelId": "m1", "name": "M1"}] + }, + "modes": { + "currentModeId": "default", + "availableModes": [{"id": "default", "name": "Default"}] + } + }); + let opts = parse_config_options(&result); + assert_eq!(opts.len(), 2); + assert_eq!(opts[0].id, "model"); + assert_eq!(opts[1].id, "agent"); + } + + #[test] + fn parse_standard_takes_precedence_over_kiro() { + let result = json!({ + "configOptions": [{ + "id": "model", + "name": "Model", + "type": "enum", + "currentValue": "standard", + "options": [{"value": "standard", "name": "Standard"}] + }], + "models": { + "currentModelId": "kiro", + "availableModels": [{"modelId": "kiro", "name": "Kiro"}] + } + }); + let opts = parse_config_options(&result); + assert_eq!(opts.len(), 1); + assert_eq!(opts[0].current_value, "standard"); + } + + #[test] + fn parse_empty_result() { + let opts = parse_config_options(&json!({})); + assert!(opts.is_empty()); + } + + #[test] + fn parse_empty_config_options_falls_through_to_kiro() { + let result = json!({ + "configOptions": [], + "models": { + "currentModelId": "m1", + "availableModels": [{"modelId": "m1", "name": "M1"}] + } + }); + let opts = parse_config_options(&result); + assert_eq!(opts.len(), 1); + assert_eq!(opts[0].id, "model"); + } 
+} diff --git a/crates/openab-core/src/adapter.rs b/crates/openab-core/src/adapter.rs new file mode 100644 index 000000000..8b77242b5 --- /dev/null +++ b/crates/openab-core/src/adapter.rs @@ -0,0 +1,1659 @@ +use anyhow::Result; +use async_trait::async_trait; +use serde::Serialize; +use std::sync::Arc; +use tracing::{error, warn}; + +use crate::acp::{classify_notification, AcpEvent, ContentBlock, SessionPool}; +use crate::config::{ReactionsConfig, ToolDisplay}; +use crate::error_display::{format_coded_error, format_user_error}; +use crate::format; +use crate::markdown::{self, TableMode}; +use crate::reactions::StatusReactionController; + +// --- Output directive parsing --- + +/// Parsed directives from agent output header block. +/// Consecutive `[[key:value]]` lines at the start of output are directives. +#[derive(Default, Debug)] +pub struct OutputDirectives { + /// Message ID to reply to (Discord: message_reference) + pub reply_to: Option, +} + +/// Parse `[[key:value]]` directives from the beginning of agent output. +/// Returns parsed directives and the remaining content (directives stripped). +pub fn parse_output_directives(content: &str) -> (OutputDirectives, String) { + let mut directives = OutputDirectives::default(); + let mut content_start = 0; + let mut trailing_content: Option<&str> = None; + + for line in content.lines() { + let trimmed = line.trim(); + // Try to match [[key:value]] at the start of the line (lenient: allows trailing content) + if let Some(after_open) = trimmed.strip_prefix("[[") { + if let Some(close_pos) = after_open.find("]]") { + let inner = &after_open[..close_pos]; + if let Some((key, value)) = inner.split_once(':') { + match key.trim() { + "reply_to" => { + let v = value.trim(); + // Validate: non-empty, reasonable length, no whitespace/control chars + if !v.is_empty() && v.len() <= 64 && v.chars().all(|c| c.is_ascii_alphanumeric() || c == '.' 
|| c == '-' || c == '_') { + directives.reply_to = Some(v.to_string()); + } + } + _ => { + tracing::debug!(key = key.trim(), "unknown output directive ignored"); + } + } + // Check for trailing content after ]] + let remainder = after_open[close_pos + 2..].trim(); + if !remainder.is_empty() { + trailing_content = Some(remainder); + // Advance past this line + content_start += line.len(); + if content.as_bytes().get(content_start) == Some(&b'\r') { + content_start += 1; + } + if content.as_bytes().get(content_start) == Some(&b'\n') { + content_start += 1; + } + break; // Trailing content ends directive header + } + // Advance past this line + its line ending (handles both \n and \r\n) + content_start += line.len(); + if content.as_bytes().get(content_start) == Some(&b'\r') { + content_start += 1; + } + if content.as_bytes().get(content_start) == Some(&b'\n') { + content_start += 1; + } + } else { + // [[X]] without colon — not a directive, stop parsing + break; + } + } else { + // No closing ]] found — not a directive, stop parsing + break; + } + } else { + break; + } + } + + let remaining = if let Some(trailing) = trailing_content { + if content_start < content.len() { + format!("{}\n{}", trailing, &content[content_start..]) + } else { + trailing.to_string() + } + } else if content_start < content.len() { + content[content_start..].to_string() + } else { + String::new() + }; + (directives, remaining) +} + +// --- Platform-agnostic types --- + +/// Identifies a channel or thread across platforms. +/// +/// Used for **routing**: `channel_id` is the ID the adapter sends messages to. +/// For Discord threads, this is the thread's own channel ID (Discord API +/// requires it for `say`/`edit`). Use `parent_id` to find the parent channel. +/// +/// Compare with `SenderContext`, which is **metadata for the agent**: there +/// `channel_id` is the parent channel and `thread_id` is the thread, +/// matching Slack's model for cross-platform consistency. 
+#[derive(Clone, Debug)] +pub struct ChannelRef { + pub platform: String, + pub channel_id: String, + /// Thread within a channel (e.g. Slack thread_ts, Telegram topic_id). + /// For Discord, threads are separate channels so this is None. + pub thread_id: Option, + /// Parent channel if this is a thread-as-channel (Discord). + pub parent_id: Option, + /// Originating gateway event ID, propagated back in `GatewayReply.reply_to` + /// so the gateway can correlate replies with inbound events (e.g. LINE reply tokens). + /// Excluded from Hash/Eq — two ChannelRefs pointing to the same channel are + /// equal regardless of which event they originated from. + pub origin_event_id: Option, +} + +impl PartialEq for ChannelRef { + fn eq(&self, other: &Self) -> bool { + self.platform == other.platform + && self.channel_id == other.channel_id + && self.thread_id == other.thread_id + && self.parent_id == other.parent_id + } +} + +impl Eq for ChannelRef {} + +impl std::hash::Hash for ChannelRef { + fn hash(&self, state: &mut H) { + self.platform.hash(state); + self.channel_id.hash(state); + self.thread_id.hash(state); + self.parent_id.hash(state); + } +} + +/// Identifies a message across platforms. +#[derive(Clone, Debug)] +pub struct MessageRef { + pub channel: ChannelRef, + pub message_id: String, +} + +/// Bundles per-message parameters for `AdapterRouter::handle_message`. +/// +/// Introduced to reduce parameter count and make the signature extensible +/// (e.g. streaming policy, rate limit hints) without breaking call sites. +pub struct MessageContext { + pub thread_channel: ChannelRef, + pub sender_json: String, + pub prompt: String, + pub extra_blocks: Vec, + pub trigger_msg: MessageRef, + pub other_bot_present: bool, +} + +/// Sender identity injected into prompts for downstream agent context. +/// +/// This is **metadata for the agent** — `channel_id` always refers to the +/// logical parent channel, and `thread_id` identifies the thread (if any). 
+/// This convention is consistent across platforms (Slack, Discord, Telegram). +/// +/// Compare with `ChannelRef`, which is used for **routing**: there +/// `channel_id` is the ID the adapter sends messages to (for Discord +/// threads, that's the thread's own channel ID, not the parent). +#[derive(Clone, Debug, Serialize)] +pub struct SenderContext { + pub schema: String, + pub sender_id: String, + pub sender_name: String, + pub display_name: String, + pub channel: String, + pub channel_id: String, + /// Thread identifier, if the message is inside a thread. + /// Slack: thread_ts. Discord: thread channel ID (channel_id holds the parent). + #[serde(skip_serializing_if = "Option::is_none")] + pub thread_id: Option, + pub is_bot: bool, + /// Platform message creation time (ISO 8601 UTC), if available. + /// Discord/Slack: platform timestamp. Gateway: broker receive time (best-effort). + /// Additive optional field — schema version stays openab.sender.v1 (no consumer + /// breakage). If future additions require breaking changes, bump to v1.1+. + #[serde(skip_serializing_if = "Option::is_none")] + pub timestamp: Option, + /// Platform message ID. Agents can use this to reply to a specific message + /// via the `[[reply_to:]]` output directive. + #[serde(skip_serializing_if = "Option::is_none")] + pub message_id: Option, + /// The platform user ID of the receiving bot/agent. + /// Enables agents to identify themselves when multiple agents share the same backend. + #[serde(skip_serializing_if = "Option::is_none")] + pub receiver_id: Option, +} + +// --- ChatAdapter trait --- + +#[async_trait] +pub trait ChatAdapter: Send + Sync + 'static { + /// Platform name for logging and session key namespacing. + fn platform(&self) -> &'static str; + + /// Maximum message length (chars) for this platform; the router splits longer + /// replies into multiple messages at this bound. Platform-specific (e.g. 2000 + /// for Discord; Slack uses its Block Kit `markdown` block cap). 
+ fn message_limit(&self) -> usize; + + /// Send a new message, returns a reference to the sent message. + async fn send_message(&self, channel: &ChannelRef, content: &str) -> Result; + + /// Create a thread from a trigger message, returns the thread channel ref. + async fn create_thread( + &self, + channel: &ChannelRef, + trigger_msg: &MessageRef, + title: &str, + ) -> Result; + + /// Add a reaction/emoji to a message. + async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()>; + + /// Remove a reaction/emoji from a message. + async fn remove_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()>; + + /// Edit an existing message in-place (for streaming updates). + /// Default: unsupported (send-once only). + async fn edit_message(&self, _msg: &MessageRef, _content: &str) -> Result<()> { + Err(anyhow::anyhow!("edit_message not supported")) + } + + /// Send a message as a reply to a specific message (Discord: message_reference). + /// Default: falls back to plain send_message (ignores reply_to). + async fn send_message_with_reply( + &self, + channel: &ChannelRef, + content: &str, + reply_to_message_id: &str, + ) -> Result { + let _ = reply_to_message_id; // unused in default impl + self.send_message(channel, content).await + } + + /// Rename the thread/channel title. Default: no-op (not all platforms support it). + async fn rename_thread(&self, _channel: &ChannelRef, _title: &str) -> Result<()> { + Ok(()) + } + + /// Delete a message. Used to remove streaming placeholders when reply_to is set. + /// Default: edits to zero-width space (fallback for platforms without delete support). + async fn delete_message(&self, msg: &MessageRef) -> Result<()> { + self.edit_message(msg, "\u{200b}").await + } + + /// Whether this adapter streams via a native streaming API (Slack + /// chat.startStream) rather than the post+edit loop. Default: false. 
+ /// `other_bot_present` lets adapters fall back to send-once in multi-bot + /// threads (mirrors `use_streaming`'s #534 rule). + fn uses_native_streaming(&self, _other_bot_present: bool) -> bool { + false + } + + /// Begin a native stream. The returned MessageRef is the handle for + /// subsequent `stream_append` / `stream_finish`. + /// Default: delegate to send_message (only called when uses_native_streaming). + /// `recipient` is the per-turn `(user_id, team_id)` for platforms (Slack) that + /// need it for the native stream open; ignored by the default impl. + async fn stream_begin( + &self, + channel: &ChannelRef, + _recipient: Option<(String, String)>, + ) -> Result { + self.send_message(channel, "…").await + } + + /// Append an INCREMENTAL delta to a native stream. + /// Default: best-effort edit (only called when uses_native_streaming). + async fn stream_append(&self, msg: &MessageRef, delta: &str) -> Result<()> { + self.edit_message(msg, delta).await + } + + /// Finish a native stream and write the COMPLETE final content. + /// Default: delegate to edit_message. + async fn stream_finish(&self, msg: &MessageRef, final_content: &str) -> Result<()> { + self.edit_message(msg, final_content).await + } + + /// Whether this adapter uses a status API (e.g. assistant.threads.setStatus) + /// instead of emoji reactions for thinking/tool indicators. Independent of + /// `uses_native_streaming` — status can work without content streaming. + /// Default: false. + fn uses_assistant_status(&self) -> bool { + false + } + + /// Set an ephemeral status line (e.g. "Thinking…", "Using …"). + /// Empty string clears it. Default: no-op (platforms without a status API). + async fn set_status(&self, _channel: &ChannelRef, _status: &str) -> Result<()> { + Ok(()) + } + + /// Whether this platform renders Markdown tables natively. 
When `true`, the + /// router skips the `convert_tables` pre-pass (which rewrites tables into + /// code blocks / bullet lists for platforms that cannot render them) and + /// lets the platform render the raw Markdown table itself. + /// Default: `false` (keep converting). Overridden by Slack (Block Kit + /// `markdown` blocks / `markdown_text` stream chunks render tables natively). + fn renders_native_tables(&self) -> bool { + false + } + + /// Whether this adapter should use streaming edit (true) or send-once (false). + /// `other_bot_present` indicates if another bot has posted in the current thread. + /// Streaming should be disabled in multi-bot threads to avoid edit interference. + /// NOTE: Slight race window exists — the multibot cache is checked before + /// handle_message, so a bot arriving between the check and the response will + /// not be detected until the next message. This is acceptable: the first + /// response may stream, but subsequent ones will correctly use send-once. + fn use_streaming(&self, other_bot_present: bool) -> bool; + + /// Whether to send the "…" placeholder message before streaming starts. + /// Default: true. Platforms using drafts (e.g. Telegram Rich Messages) can + /// return false to suppress the placeholder. + fn show_streaming_placeholder(&self) -> bool { + true + } +} + +// --- AdapterRouter --- + +/// Shared logic for routing messages to ACP agents, managing sessions, +/// streaming edits, and controlling reactions. Platform-independent. +pub struct AdapterRouter { + pool: Arc, + reactions_config: ReactionsConfig, + table_mode: TableMode, + prompt_hard_timeout: std::time::Duration, + /// Polling cadence for the recv-loop liveness check (#732). + liveness_check_interval: std::time::Duration, + /// Workspace aliases from `[workspace.aliases]` config. + workspace_aliases: std::collections::HashMap, + /// Bot home directory (security boundary for workspace directives). 
+ bot_home: std::path::PathBuf, +} + +impl AdapterRouter { + pub fn new( + pool: Arc, + reactions_config: ReactionsConfig, + table_mode: TableMode, + prompt_hard_timeout_secs: u64, + liveness_check_secs: u64, + workspace_aliases: std::collections::HashMap, + bot_home: std::path::PathBuf, + ) -> Self { + if liveness_check_secs >= prompt_hard_timeout_secs { + warn!( + liveness_check_secs, + prompt_hard_timeout_secs, + "pool.liveness_check_secs >= pool.prompt_hard_timeout_secs; \ + the hard ceiling will only fire after the next liveness tick \ + and may be effectively bypassed. Lower liveness_check_secs." + ); + } + Self { + pool, + reactions_config, + table_mode, + prompt_hard_timeout: std::time::Duration::from_secs(prompt_hard_timeout_secs), + liveness_check_interval: std::time::Duration::from_secs(liveness_check_secs), + workspace_aliases, + bot_home, + } + } + + /// Access the underlying session pool (e.g. for config option queries). + pub fn pool(&self) -> &Arc { + &self.pool + } + + /// Access the reactions config (used by dispatch.rs). + pub fn reactions_config(&self) -> &ReactionsConfig { + &self.reactions_config + } + + /// Workspace aliases for control directive resolution. + pub fn workspace_aliases_map(&self) -> std::collections::HashMap { + self.workspace_aliases.clone() + } + + /// Bot home path for workspace security boundary. + pub fn bot_home_path(&self) -> std::path::PathBuf { + self.bot_home.clone() + } + + /// Pack one arrival event into ContentBlocks. Per-arrival layout: + /// Text { "\n{json}\n" } <- delimiter + /// [Text blocks from extra_blocks (e.g. STT transcripts)] + /// Text { "{prompt}" } <- omitted if empty + /// [non-Text blocks from extra_blocks (e.g. Image)] + /// + /// The sender_context block stands alone so it can serve as a structural + /// delimiter between arrivals in batched dispatch — agents can scan for + /// `` openers to find arrival boundaries. 
Within an arrival, + /// transcript text precedes the typed prompt to match pre-batching adapter + /// behavior (voice content first), and images trail the prompt as before. + /// This is the single packing code path for both per-message and batched + /// dispatch (ADR §3.5). For a batch of N messages, call this N times and + /// concatenate. + pub fn pack_arrival_event( + sender_json: &str, + prompt: &str, + extra_blocks: Vec, + ) -> Vec { + let header = format!("\n{}\n", sender_json); + let (texts, others): (Vec<_>, Vec<_>) = extra_blocks + .into_iter() + .partition(|b| matches!(b, ContentBlock::Text { .. })); + let mut blocks = Vec::with_capacity(2 + texts.len() + others.len()); + blocks.push(ContentBlock::Text { text: header }); + blocks.extend(texts); + if !prompt.is_empty() { + blocks.push(ContentBlock::Text { + text: prompt.to_string(), + }); + } + blocks.extend(others); + blocks + } + + /// Handle an incoming user message. The adapter is responsible for + /// filtering, resolving the thread, and building the SenderContext. + /// This method handles sender context injection, session management, and streaming. + pub async fn handle_message( + &self, + adapter: &Arc, + ctx: MessageContext, + ) -> Result<()> { + tracing::debug!(platform = adapter.platform(), "processing message"); + + let content_blocks = + Self::pack_arrival_event(&ctx.sender_json, &ctx.prompt, ctx.extra_blocks); + + let thread_key = format!( + "{}:{}", + adapter.platform(), + ctx.thread_channel + .thread_id + .as_deref() + .unwrap_or(&ctx.thread_channel.channel_id) + ); + + if let Err(e) = self.pool.get_or_create(&thread_key, None).await { + let msg = format_user_error(&e.to_string()); + let _ = adapter + .send_message(&ctx.thread_channel, &format!("⚠️ {msg}")) + .await; + error!("pool error: {e}"); + return Err(e); + } + + // In assistant-status mode (e.g. 
Slack assistant_mode), status is conveyed + // via assistant.threads.setStatus, so the emoji-reaction lifecycle is skipped + // entirely — mirrors dispatch_batch so per-message and batched modes agree. + let assistant_status = adapter.uses_assistant_status(); + + let reactions = Arc::new(StatusReactionController::new( + self.reactions_config.enabled, + adapter.clone(), + ctx.trigger_msg.clone(), + self.reactions_config.emojis.clone(), + self.reactions_config.timing.clone(), + )); + if !assistant_status { + reactions.set_queued().await; + } + + let result = self + .stream_prompt( + adapter, + &thread_key, + content_blocks, + &ctx.thread_channel, + reactions.clone(), + ctx.other_bot_present, + ) + .await; + + if !assistant_status { + match &result { + Ok(()) => reactions.set_done().await, + Err(_) => reactions.set_error().await, + } + + let hold_ms = if result.is_ok() { + self.reactions_config.timing.done_hold_ms + } else { + self.reactions_config.timing.error_hold_ms + }; + if self.reactions_config.remove_after_reply { + let reactions = reactions; + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(hold_ms)).await; + reactions.clear().await; + }); + } + } + + if let Err(ref e) = result { + let _ = adapter + .send_message(&ctx.thread_channel, &format!("⚠️ {e}")) + .await; + } + + result + } + + async fn stream_prompt( + &self, + adapter: &Arc, + thread_key: &str, + content_blocks: Vec, + thread_channel: &ChannelRef, + reactions: Arc, + other_bot_present: bool, + ) -> Result<()> { + self.stream_prompt_blocks( + adapter, + thread_key, + content_blocks, + thread_channel, + reactions, + other_bot_present, + // handle_message path (e.g. cron) is never Slack assistant-mode native + // streaming, so no per-turn recipient — degrades to post+edit if it were. + None, + ) + .await + } + + /// Drive one ACP turn with the given pre-packed ContentBlocks. 
+ /// Called by both `handle_message` (per-message mode) and `dispatch::dispatch_batch` + /// (batched mode). + #[allow(clippy::too_many_arguments)] + pub async fn stream_prompt_blocks( + &self, + adapter: &Arc, + thread_key: &str, + content_blocks: Vec, + thread_channel: &ChannelRef, + reactions: Arc, + other_bot_present: bool, + recipient: Option<(String, String)>, + ) -> Result<()> { + let adapter = adapter.clone(); + let thread_channel = thread_channel.clone(); + let message_limit = adapter.message_limit(); + let streaming = adapter.use_streaming(other_bot_present); + let native = adapter.uses_native_streaming(other_bot_present); + let assistant_status = adapter.uses_assistant_status(); + // Platforms that render Markdown tables natively (e.g. Slack Block Kit + // `markdown` blocks / `markdown_text` stream chunks) skip the + // table→code/bullets pre-pass so the raw table renders natively. + let table_mode = if adapter.renders_native_tables() { + TableMode::Off + } else { + self.table_mode + }; + let tool_display = self.reactions_config.tool_display; + let prompt_hard_timeout = self.prompt_hard_timeout; + let liveness_check_interval = self.liveness_check_interval; + + self.pool + .with_connection(thread_key, |conn| { + let content_blocks = content_blocks.clone(); + Box::pin(async move { + let reset = conn.session_reset; + conn.session_reset = false; + + let (mut rx, request_id) = conn.session_prompt(content_blocks).await?; + if assistant_status { + let _ = adapter.set_status(&thread_channel, "Thinking…").await; + } else { + reactions.set_thinking().await; + } + + let mut text_buf = String::new(); + let mut tool_lines: Vec = Vec::new(); + + if reset { + text_buf.push_str("⚠️ _Session expired, starting fresh..._\n\n"); + } + + // Native streaming: defer stream_begin until first Text event + // so the thinking phase only shows set_status (no placeholder msg). 
+ let mut native_msg: Option = None; + // Once stream_begin fails, stop retrying for this turn to avoid + // hammering the API on transient failures. + let mut stream_begin_failed = false; + // Native delta coalescing state (used only when `native`). + let mut native_pending = String::new(); + let mut native_last_flush = tokio::time::Instant::now(); + const NATIVE_FLUSH_MS: u128 = 400; + + // Streaming edit: send placeholder, spawn edit loop + let (buf_tx, placeholder_msg, edit_handle) = if streaming && !native { + let initial = if reset { + "⚠️ _Session expired, starting fresh..._\n\n…".to_string() + } else { + "…".to_string() + }; + let msg = if adapter.show_streaming_placeholder() { + adapter.send_message(&thread_channel, &initial).await? + } else { + // Dummy ref for edit loop — gateway uses drafts, doesn't need real msg_id + MessageRef { + message_id: "draft".to_string(), + channel: thread_channel.clone(), + } + }; + let (tx, rx) = tokio::sync::watch::channel(initial); + let edit_adapter = adapter.clone(); + let edit_msg = msg.clone(); + let limit = message_limit; + let mut buf_rx = rx; + let edit_handle = tokio::spawn(async move { + let mut last = String::new(); + // Track consecutive edit failures so we can abort cosmetic + // streaming when the platform stops accepting edits (e.g. + // Feishu's 20-edits-per-message hard cap, errcode 230072). + // Once aborted, the final delivery path still runs and the + // user sees the complete content at turn end. 
+ let mut consecutive_failures: u32 = 0; + const MAX_CONSECUTIVE_FAILURES: u32 = 3; + loop { + tokio::time::sleep(std::time::Duration::from_millis(1500)).await; + if buf_rx.has_changed().unwrap_or(false) { + let content = buf_rx.borrow_and_update().clone(); + if content != last { + let display = if content.chars().count() > limit - 100 { + format!( + "…{}", + format::truncate_chars_tail(&content, limit - 100) + ) + } else { + content.clone() + }; + match edit_adapter + .edit_message(&edit_msg, &display) + .await + { + Ok(_) => { + consecutive_failures = 0; + last = content; + } + Err(e) => { + consecutive_failures += 1; + tracing::debug!( + message_id = %edit_msg.message_id, + platform = %edit_msg.channel.platform, + error = ?e, + consecutive_failures, + "mid-stream cosmetic edit failed" + ); + if consecutive_failures + >= MAX_CONSECUTIVE_FAILURES + { + tracing::warn!( + message_id = %edit_msg.message_id, + platform = %edit_msg.channel.platform, + consecutive_failures, + "mid-stream cosmetic edit aborted; \ + final content will be delivered at turn end" + ); + break; + } + } + } + } + } + if buf_rx.has_changed().is_err() { + break; + } + } + }); + (Some(tx), Some(msg), Some(edit_handle)) + } else { + (None, None, None) + }; + + // (#732) Liveness-aware recv loop. Filters stale id-bearing + // messages and abandons cleanly on dead agent / hard ceiling + // so late responses cannot leak into the next prompt. + let mut response_error: Option = None; + let prompt_start = tokio::time::Instant::now(); + loop { + let notification = tokio::select! { + msg = rx.recv() => match msg { + Some(n) => n, + // Reader saw EOF and already drained pending; nothing to abandon. 
+ None => break, + }, + _ = tokio::time::sleep(liveness_check_interval) => { + if !conn.alive() { + response_error = Some("Agent process died".into()); + conn.abandon_request(request_id).await; + break; + } + if prompt_start.elapsed() > prompt_hard_timeout { + response_error = Some(format!( + "Agent exceeded hard timeout ({}s)", + prompt_hard_timeout.as_secs(), + )); + conn.abandon_request(request_id).await; + break; + } + continue; + } + }; + if let Some(notification_id) = notification.id { + if notification_id != request_id { + // Stale response from a previously-abandoned prompt. + // No automated test seam: this path only triggers when a + // real subprocess emits a late response after the broker + // already called abandon_request — covered by manual + // repro against a live agent (see #732 PR description). + continue; + } + if let Some(ref err) = notification.error { + response_error = Some(format_coded_error(err.code, &err.message, err.data_message())); + } + break; + } + + if let Some(event) = classify_notification(¬ification) { + match event { + AcpEvent::Text(t) => { + text_buf.push_str(&t); + if native { + // Lazy stream_begin: open the stream on first text. 
+ if native_msg.is_none() && !stream_begin_failed { + match adapter.stream_begin(&thread_channel, recipient.clone()).await { + Ok(m) => { native_msg = Some(m); } + Err(e) => { + tracing::error!(error = ?e, "stream_begin failed on first text; will not retry this turn"); + stream_begin_failed = true; + } + } + } + if let Some(msg) = &native_msg { + native_pending.push_str(&t); + if native_last_flush.elapsed().as_millis() + >= NATIVE_FLUSH_MS + && !native_pending.is_empty() + { + let _ = adapter + .stream_append(msg, &native_pending) + .await; + native_pending.clear(); + native_last_flush = tokio::time::Instant::now(); + } + } + } else if let Some(tx) = &buf_tx { + let _ = tx.send(compose_display( + &tool_lines, + &text_buf, + true, + tool_display, + )); + } + } + AcpEvent::Thinking => { + if assistant_status { + let _ = adapter + .set_status(&thread_channel, "Thinking…") + .await; + } else { + reactions.set_thinking().await; + } + } + AcpEvent::ToolStart { id, title } if !title.is_empty() => { + // Live indicator: assistant status line vs emoji reaction. + if assistant_status { + let _ = adapter + .set_status( + &thread_channel, + &format!("Using {title}…"), + ) + .await; + } else { + reactions.set_tool(&title).await; + } + // Record the tool in BOTH modes so the finalized message keeps + // a tool summary (compose_display, gated by tool_display). In + // assistant_mode the status line is transient and cleared before + // the reply, so without this the message would retain no record + // of which tools ran. + let title = sanitize_title(&title); + if let Some(slot) = + tool_lines.iter_mut().find(|e| e.id == id) + { + slot.title = title; + slot.state = ToolState::Running; + } else { + tool_lines.push(ToolEntry { + id, + title, + state: ToolState::Running, + }); + } + // Post+edit live update (no-op under native streaming: buf_tx is None). 
+ if let Some(tx) = &buf_tx { + let _ = tx.send(compose_display( + &tool_lines, + &text_buf, + true, + tool_display, + )); + } + } + AcpEvent::ToolDone { id, title, status } => { + // Live indicator: assistant status line vs emoji reaction. + if assistant_status { + let _ = adapter + .set_status(&thread_channel, "Thinking…") + .await; + } else { + reactions.set_thinking().await; + } + // Update the tool's state in BOTH modes (see ToolStart) so the + // finalized message's tool summary reflects completion/failure. + let new_state = if status == "completed" { + ToolState::Completed + } else { + ToolState::Failed + }; + if let Some(slot) = + tool_lines.iter_mut().find(|e| e.id == id) + { + if !title.is_empty() { + slot.title = sanitize_title(&title); + } + slot.state = new_state; + } else if !title.is_empty() { + tool_lines.push(ToolEntry { + id, + title: sanitize_title(&title), + state: new_state, + }); + } + if let Some(tx) = &buf_tx { + let _ = tx.send(compose_display( + &tool_lines, + &text_buf, + true, + tool_display, + )); + } + } + AcpEvent::ConfigUpdate { options } => { + conn.config_options = options; + } + _ => {} + } + } + } + + conn.prompt_done().await; + // Stop the cosmetic edit loop before the finalize write path + // issues its authoritative edit. Dropping buf_tx closes the watch + // channel so the loop breaks on its next check, but it may be + // mid-edit (a single edit can now block up to the gateway response + // timeout). Without an explicit abort+join, a cosmetic edit issued + // just before close could land *after* the finalize edit and + // overwrite it with stale, mid-stream content (#1122 review NEW-1). + // + // abort() cancels any cosmetic edit that has not yet been put on + // the wire and interrupts the inter-flush sleep immediately; the + // await confirms the task is gone before we proceed. 
This narrows + // the race to near zero — it does NOT fully eliminate it: a PUT + // already flushed microseconds before abort cannot be recalled, + // and if finalize's PUT travels a different pooled connection the + // server-side arrival order is not strictly guaranteed. That + // residual window is display-only (stale tail briefly shown) and + // far narrower than before this join existed. + drop(buf_tx); + if let Some(handle) = edit_handle { + handle.abort(); + let _ = handle.await; + } + + // Parse output directives from raw text_buf BEFORE compose_display. + // Directives are agent meta-layer, not content — must be stripped + // before tool lines are composed into the display output. + let (directives, stripped_text) = parse_output_directives(&text_buf); + let text_buf = stripped_text; + + // Build final content + let final_content = + compose_display(&tool_lines, &text_buf, false, tool_display); + let final_content = if final_content.is_empty() { + if let Some(err) = response_error { + format!("⚠️ {err}") + } else { + "_(no response)_".to_string() + } + } else if let Some(err) = response_error { + format!("⚠️ {err}\n\n{final_content}") + } else { + final_content + }; + + let final_content = markdown::convert_tables(&final_content, table_mode); + let chunks = format::split_message(&final_content, message_limit); + // Track delivery health across all final write paths. Any failure + // here means the user's view is incomplete; we propagate Err at the + // end of the closure so dispatch surfaces set_error (❌) instead of + // silently calling set_done (🆗) over a half-delivered turn. + let mut delivery_failed = false; + // Clear the assistant status line before delivering the final message. 
+ if assistant_status { + let _ = adapter.set_status(&thread_channel, "").await; + } + if native { + if let Some(msg) = &native_msg { + if !native_pending.is_empty() { + if let Err(e) = + adapter.stream_append(msg, &native_pending).await + { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "native finalize stream_append failed"); + delivery_failed = true; + } + } + // Finalize the streamed message with the first chunk (full-replace), + // then post any overflow chunks as new in-thread messages — mirrors + // the post+edit path so long replies aren't truncated at message_limit. + // NOTE: the reply_to directive is intentionally NOT honored in native + // streaming mode — the streamed message is the in-thread reply. + match chunks.first() { + Some(first) => { + if let Err(e) = adapter.stream_finish(msg, first).await { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "native stream_finish failed"); + delivery_failed = true; + } + for chunk in chunks.iter().skip(1) { + if let Err(e) = + adapter.send_message(&thread_channel, chunk).await + { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "native overflow chunk send failed"); + delivery_failed = true; + } + } + } + None => { + if let Err(e) = + adapter.stream_finish(msg, &final_content).await + { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "native stream_finish (no chunks) failed"); + delivery_failed = true; + } + } + } + } else { + // native_msg is None — either no Text event ever arrived + // (tool-only or empty turn) so lazy stream_begin never + // fired, or stream_begin failed on the first Text event + // and we stopped retrying for this turn. 
In both cases no + // native stream was opened, so deliver the final content + // (which may be the "_(no response)_" sentinel, or the + // accumulated text_buf) as plain in-thread messages so + // the turn is never silently dropped. + for chunk in &chunks { + if let Err(e) = + adapter.send_message(&thread_channel, chunk).await + { + tracing::warn!(error = ?e, platform = %thread_channel.platform, "native fallback chunk send failed"); + delivery_failed = true; + } + } + } + } else if let Some(msg) = placeholder_msg { + if let Some(ref reply_id) = directives.reply_to { + // reply_to directive: send reply first, then delete placeholder. + // Only delete if send succeeds — preserves placeholder on failure. + let mut send_ok = false; + let mut first = true; + for chunk in &chunks { + if first { + match adapter.send_message_with_reply( + &thread_channel, + chunk, + reply_id, + ).await { + Ok(_) => { send_ok = true; } + Err(e) => { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "reply_to send failed; preserving placeholder"); + delivery_failed = true; + } + } + } else if let Err(e) = + adapter.send_message(&thread_channel, chunk).await + { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "reply_to overflow chunk send failed"); + delivery_failed = true; + } + first = false; + } + if send_ok { + if let Err(e) = adapter.delete_message(&msg).await { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "delete placeholder failed; placeholder will remain visible"); + } + } + } else if adapter.platform() == "discord" + && contains_bot_mention(&final_content) + { + // Discord-specific: bot mention detected. Delete placeholder + // and send as new message so Discord emits MESSAGE_CREATE — + // otherwise the mentioned bot won't receive the gateway + // event since MESSAGE_UPDATE skips notifications (#1110). 
+ let mut send_ok = false; + if let Some(first) = chunks.first() { + match adapter.send_message(&thread_channel, first).await { + Ok(_) => { + send_ok = true; + } + Err(e) => { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "discord bot-mention first chunk send failed"); + delivery_failed = true; + } + } + } + for chunk in chunks.iter().skip(1) { + if let Err(e) = adapter.send_message(&thread_channel, chunk).await { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "streaming overflow chunk send failed"); + delivery_failed = true; + } + } + if send_ok { + let _ = adapter.delete_message(&msg).await; + } + } else { + // Normal streaming: edit first chunk into placeholder, send rest. + // If placeholder is a dummy "draft" ref (no real message), send as + // new message instead — the gateway will persist via sendRichMessage. + if msg.message_id == "draft" { + for chunk in &chunks { + if let Err(e) = + adapter.send_message(&thread_channel, chunk).await + { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "draft placeholder fallback chunk send failed"); + delivery_failed = true; + } + } + } else if let Some(first) = chunks.first() { + // If the placeholder edit fails (e.g. Feishu's + // 20-edits-per-message cap was hit during + // cosmetic streaming and the gateway reports + // edit_cap_reached), fall back to deleting the + // half-edited placeholder and sending the first + // chunk as a fresh message so the user sees the + // complete reply without overlap. If delete + // fails the placeholder simply remains — same + // UX as pre-recovery, not a hard failure. 
+ if let Err(e) = adapter.edit_message(&msg, first).await { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "final streaming edit failed; deleting placeholder and sending fresh"); + if let Err(de) = adapter.delete_message(&msg).await { + tracing::warn!(error = ?de, platform = %thread_channel.platform, message_id = %msg.message_id, "delete placeholder failed; user will see overlap"); + } + if let Err(e2) = + adapter.send_message(&thread_channel, first).await + { + tracing::error!(error = ?e2, platform = %thread_channel.platform, message_id = %msg.message_id, "fallback send_message also failed"); + delivery_failed = true; + } + } + for chunk in chunks.iter().skip(1) { + if let Err(e) = + adapter.send_message(&thread_channel, chunk).await + { + tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "streaming overflow chunk send failed"); + delivery_failed = true; + } + } + } + } + } else { + // Send-once: all chunks as new messages + // First chunk uses reply_to directive if present + let mut first = true; + for chunk in &chunks { + if first { + if let Some(ref reply_id) = directives.reply_to { + if let Err(e) = adapter.send_message_with_reply( + &thread_channel, + chunk, + reply_id, + ).await { + tracing::warn!(error = ?e, platform = %thread_channel.platform, "send-once reply_to first chunk failed"); + delivery_failed = true; + } + } else if let Err(e) = + adapter.send_message(&thread_channel, chunk).await + { + tracing::warn!(error = ?e, platform = %thread_channel.platform, "send-once first chunk failed"); + delivery_failed = true; + } + } else if let Err(e) = + adapter.send_message(&thread_channel, chunk).await + { + tracing::warn!(error = ?e, platform = %thread_channel.platform, "send-once subsequent chunk failed"); + delivery_failed = true; + } + first = false; + } + } + + if delivery_failed { + Err(anyhow::anyhow!( + "streaming finalization had delivery failures; user view is 
incomplete" + )) + } else { + Ok(()) + } + }) + }) + .await + } +} + +/// Returns true if `content` contains a Discord user/bot mention (`<@123>`, `<@!123>`) +/// or a role mention (`<@&123>`). +/// Used to detect cross-bot mentions so the streaming path can switch from +/// edit (MESSAGE_UPDATE, no mention notification) to delete+send (MESSAGE_CREATE). +fn contains_bot_mention(content: &str) -> bool { + let mut i = 0; + let bytes = content.as_bytes(); + while i + 2 < bytes.len() { + if bytes[i] == b'<' && bytes[i + 1] == b'@' { + // Skip optional '!' (nickname mention) or '&' (role mention) + let start = if i + 2 < bytes.len() + && (bytes[i + 2] == b'!' || bytes[i + 2] == b'&') + { + i + 3 + } else { + i + 2 + }; + if start < bytes.len() && bytes[start].is_ascii_digit() { + if let Some(end) = content[start..].find('>') { + if content[start..start + end].chars().all(|c| c.is_ascii_digit()) { + return true; + } + } + } + i = start; + } else { + i += 1; + } + } + false +} + +/// Flatten a tool-call title into a single line safe for inline-code spans. +fn sanitize_title(title: &str) -> String { + title + .replace('\r', "") + .replace('\n', " ; ") + .replace('`', "'") +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ToolState { + Running, + Completed, + Failed, +} + +#[derive(Debug, Clone)] +struct ToolEntry { + id: String, + title: String, + state: ToolState, +} + +impl ToolEntry { + fn render(&self) -> String { + let icon = match self.state { + ToolState::Running => "🔧", + ToolState::Completed => "✅", + ToolState::Failed => "❌", + }; + let suffix = if self.state == ToolState::Running { + "..." + } else { + "" + }; + format!("{icon} `{}`{}", self.title, suffix) + } +} + +/// Maximum number of finished tool entries to show individually +/// during streaming before collapsing into a summary line. 
+const TOOL_COLLAPSE_THRESHOLD: usize = 3; + +fn compose_display( + tool_lines: &[ToolEntry], + text: &str, + streaming: bool, + tool_display: ToolDisplay, +) -> String { + let mut out = String::new(); + if !tool_lines.is_empty() && tool_display != ToolDisplay::None { + let done = tool_lines + .iter() + .filter(|e| e.state == ToolState::Completed) + .count(); + let failed = tool_lines + .iter() + .filter(|e| e.state == ToolState::Failed) + .count(); + let running = tool_lines + .iter() + .filter(|e| e.state == ToolState::Running) + .count(); + let finished = done + failed; + + match tool_display { + ToolDisplay::Compact => { + // Always show count summary, never per-tool details + let mut parts = Vec::new(); + if done > 0 { + parts.push(format!("✅ {done}")); + } + if failed > 0 { + parts.push(format!("❌ {failed}")); + } + if running > 0 { + parts.push(format!("🔧 {running}")); + } + if !parts.is_empty() { + out.push_str(&format!("{} tool(s)\n", parts.join(" · "))); + } + } + ToolDisplay::Full => { + if streaming { + let running_entries: Vec<_> = tool_lines + .iter() + .filter(|e| e.state == ToolState::Running) + .collect(); + + if finished <= TOOL_COLLAPSE_THRESHOLD { + for entry in tool_lines.iter().filter(|e| e.state != ToolState::Running) { + out.push_str(&entry.render()); + out.push('\n'); + } + } else { + let mut parts = Vec::new(); + if done > 0 { + parts.push(format!("✅ {done}")); + } + if failed > 0 { + parts.push(format!("❌ {failed}")); + } + out.push_str(&format!("{} tool(s) completed\n", parts.join(" · "))); + } + + if running_entries.len() <= TOOL_COLLAPSE_THRESHOLD { + for entry in &running_entries { + out.push_str(&entry.render()); + out.push('\n'); + } + } else { + let hidden = running_entries.len() - TOOL_COLLAPSE_THRESHOLD; + out.push_str(&format!("🔧 {hidden} more running\n")); + for entry in running_entries.iter().skip(hidden) { + out.push_str(&entry.render()); + out.push('\n'); + } + } + } else { + for entry in tool_lines { + 
out.push_str(&entry.render()); + out.push('\n'); + } + } + } + ToolDisplay::None => {} // guarded above, but safe no-op + } + if !out.is_empty() { + out.push('\n'); + } + } + out.push_str(text.trim_end()); + out +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Compile-time regression guard: use_streaming() is a required trait method + /// (no default). Any adapter that forgets to implement it will fail to compile. + /// This test documents the contract — see PR #503 / issue #502 for context. + #[test] + fn use_streaming_is_required_method() { + // If use_streaming() had a default impl, this test module would still + // compile even if an adapter forgot to override it. The real guard is + // the trait definition itself — this test exists as documentation and + // to catch if someone re-adds a default. + struct TestAdapter; + + #[async_trait] + impl ChatAdapter for TestAdapter { + fn platform(&self) -> &'static str { + "test" + } + fn message_limit(&self) -> usize { + 2000 + } + async fn send_message(&self, _: &ChannelRef, _: &str) -> Result { + unimplemented!() + } + async fn create_thread( + &self, + _: &ChannelRef, + _: &MessageRef, + _: &str, + ) -> Result { + unimplemented!() + } + async fn add_reaction(&self, _: &MessageRef, _: &str) -> Result<()> { + Ok(()) + } + async fn remove_reaction(&self, _: &MessageRef, _: &str) -> Result<()> { + Ok(()) + } + // use_streaming() MUST be declared — removing this line should fail compilation + fn use_streaming(&self, _other_bot_present: bool) -> bool { + false + } + } + + let adapter = TestAdapter; + // Verify the method is callable and returns the declared value + assert!(!adapter.use_streaming(false)); + // renders_native_tables defaults to false: platforms that don't override + // it keep the table→code/bullets conversion (e.g. Discord, Gateway). 
+ assert!(!adapter.renders_native_tables()); + } + + #[test] + fn origin_event_id_excluded_from_eq() { + let a = ChannelRef { + platform: "line".into(), + channel_id: "U123".into(), + thread_id: None, + parent_id: None, + origin_event_id: Some("evt_aaa".into()), + }; + let b = ChannelRef { + platform: "line".into(), + channel_id: "U123".into(), + thread_id: None, + parent_id: None, + origin_event_id: Some("evt_bbb".into()), + }; + assert_eq!(a, b, "same channel with different event IDs must be equal"); + } + + #[test] + fn origin_event_id_excluded_from_hash() { + use std::collections::HashMap; + let a = ChannelRef { + platform: "line".into(), + channel_id: "U123".into(), + thread_id: None, + parent_id: None, + origin_event_id: Some("evt_aaa".into()), + }; + let b = ChannelRef { + platform: "line".into(), + channel_id: "U123".into(), + thread_id: None, + parent_id: None, + origin_event_id: Some("evt_bbb".into()), + }; + let mut map = HashMap::new(); + map.insert(a, "first"); + // b should hit the same bucket and overwrite + map.insert(b, "second"); + assert_eq!(map.len(), 1); + assert_eq!(map.values().next(), Some(&"second")); + } + + #[test] + fn origin_event_id_survives_clone() { + let ch = ChannelRef { + platform: "line".into(), + channel_id: "U123".into(), + thread_id: None, + parent_id: None, + origin_event_id: Some("evt_abc".into()), + }; + // Simulates create_thread propagation: clone preserves origin_event_id + let thread_ch = ChannelRef { + thread_id: Some("topic_1".into()), + origin_event_id: ch.origin_event_id.clone(), + ..ch.clone() + }; + assert_eq!(thread_ch.origin_event_id.as_deref(), Some("evt_abc")); + } + + fn tool(id: &str, title: &str, state: ToolState) -> ToolEntry { + ToolEntry { + id: id.into(), + title: title.into(), + state, + } + } + + #[test] + fn compose_display_full_shows_complete_title() { + let tools = vec![tool( + "1", + "curl -s https://example.com", + ToolState::Completed, + )]; + let out = compose_display(&tools, "done", false, 
ToolDisplay::Full); + assert!(out.contains("`curl -s https://example.com`")); + } + + #[test] + fn compose_display_compact_shows_count_summary() { + let tools = vec![ + tool("1", "curl -s https://example.com", ToolState::Completed), + tool("2", "grep -r pattern src/", ToolState::Completed), + tool("3", "cat /etc/hosts", ToolState::Failed), + ]; + let out = compose_display(&tools, "done", false, ToolDisplay::Compact); + assert!(out.contains("✅ 2"), "expected completed count: {out}"); + assert!(out.contains("❌ 1"), "expected failed count: {out}"); + assert!(out.contains("tool(s)"), "expected tool(s) label: {out}"); + // Must NOT contain individual tool names + assert!(!out.contains("curl"), "should not show tool names: {out}"); + assert!(!out.contains("grep"), "should not show tool names: {out}"); + } + + #[test] + fn compose_display_compact_shows_running_count() { + let tools = vec![ + tool("1", "curl", ToolState::Completed), + tool("2", "npm install", ToolState::Running), + ]; + let out = compose_display(&tools, "", true, ToolDisplay::Compact); + assert!(out.contains("✅ 1"), "expected completed count: {out}"); + assert!(out.contains("🔧 1"), "expected running count: {out}"); + } + + #[test] + fn compose_display_none_hides_tools() { + let tools = vec![tool( + "1", + "curl -s https://example.com", + ToolState::Completed, + )]; + let out = compose_display(&tools, "response text", false, ToolDisplay::None); + assert_eq!(out, "response text"); + } + + #[test] + fn contains_bot_mention_user() { + assert!(contains_bot_mention("hello <@1234567890> world")); + } + + #[test] + fn contains_bot_mention_nickname() { + assert!(contains_bot_mention("hey <@!9876543210>")); + } + + #[test] + fn contains_bot_mention_role() { + assert!(contains_bot_mention("calling <@&1496247626675257384>")); + } + + #[test] + fn contains_bot_mention_no_match() { + assert!(!contains_bot_mention("hello world")); + assert!(!contains_bot_mention("email user@example.com")); + 
assert!(!contains_bot_mention("<@not_a_number>")); + assert!(!contains_bot_mention("<#123456>")); // channel mention + } + + #[test] + fn contains_bot_mention_embedded() { + assert!(contains_bot_mention("請問 <@1501788608439386172> 1+1=?")); + } +} + +#[cfg(test)] +mod directive_tests { + use super::parse_output_directives; + + #[test] + fn parse_reply_to_directive() { + let input = "[[reply_to:1502606076451885136]]\nHello world"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("1502606076451885136".to_string())); + assert_eq!(content, "Hello world"); + } + + #[test] + fn parse_no_directives() { + let input = "Just plain content\nwith multiple lines"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, None); + assert_eq!(content, input); + } + + #[test] + fn parse_multiple_directives() { + let input = "[[reply_to:123456]]\n[[unknown_key:value]]\nContent here"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("123456".to_string())); + assert_eq!(content, "Content here"); + } + + #[test] + fn parse_invalid_reply_to_rejects_whitespace() { + let input = "[[reply_to:has spaces]]\nContent"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, None); + assert_eq!(content, "Content"); + } + + #[test] + fn parse_slack_ts_format_accepted() { + let input = "[[reply_to:1234567890.123456]]\nContent"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("1234567890.123456".to_string())); + assert_eq!(content, "Content"); + } + + #[test] + fn parse_empty_reply_to() { + let input = "[[reply_to:]]\nContent"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, None); + assert_eq!(content, "Content"); + } + + #[test] + fn parse_crlf_line_endings() { + let input = "[[reply_to:999]]\r\nContent with CRLF"; + 
let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("999".to_string())); + assert_eq!(content, "Content with CRLF"); + } + + #[test] + fn parse_directive_only_no_content() { + let input = "[[reply_to:123]]"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("123".to_string())); + assert_eq!(content, ""); + } + + #[test] + fn parse_non_directive_line_stops_parsing() { + let input = "Normal first line\n[[reply_to:123]]\nMore content"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, None); + assert_eq!(content, input); + } + + #[test] + fn parse_duplicate_reply_to_last_wins() { + let input = "[[reply_to:111]]\n[[reply_to:222]]\nContent"; + let (directives, content) = parse_output_directives(input); + // Last value wins + assert_eq!(directives.reply_to, Some("222".to_string())); + assert_eq!(content, "Content"); + } + + #[test] + fn parse_crlf_multiple_directives() { + let input = "[[reply_to:456]]\r\n[[unknown:x]]\r\nContent after CRLF"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("456".to_string())); + assert_eq!(content, "Content after CRLF"); + } + + #[test] + fn parse_bracket_without_colon_preserved() { + // [[Summary]] has no colon — not a directive, preserved as content + let input = "[[Summary]]\nThis is body text"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, None); + assert_eq!(content, input); + } + + #[test] + fn parse_reply_to_with_inline_content() { + // Agent puts content on same line as directive — should still parse + let input = "[[reply_to:1502724086474870926]] @BOT I'm on standby"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("1502724086474870926".to_string())); + assert_eq!(content, "@BOT I'm on standby"); + } + + #[test] + fn
parse_reply_to_inline_with_more_lines() { + let input = "[[reply_to:123]] First line\nSecond line\nThird line"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("123".to_string())); + assert_eq!(content, "First line\nSecond line\nThird line"); + } + + #[test] + fn parse_reply_to_no_space_before_content() { + // No space between ]] and content + let input = "[[reply_to:1502724086474870926]]收到"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("1502724086474870926".to_string())); + assert_eq!(content, "收到"); + } + + #[test] + fn parse_reply_to_inline_with_mention() { + // Real-world case: directive followed by Discord mention + let input = "[[reply_to:1502724086474870926]] <@1490365068863606784> 我 standby"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("1502724086474870926".to_string())); + assert_eq!(content, "<@1490365068863606784> 我 standby"); + } + + #[test] + fn parse_reply_to_inline_only_spaces() { + // Trailing spaces only — no real content, should be empty + let input = "[[reply_to:123]] "; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("123".to_string())); + assert_eq!(content, ""); + } + + #[test] + fn parse_reply_to_with_brackets_in_content() { + // Content after ]] contains brackets — should not confuse parser + let input = "[[reply_to:456]] 看看 [[這個]] 怎麼樣"; + let (directives, content) = parse_output_directives(input); + assert_eq!(directives.reply_to, Some("456".to_string())); + assert_eq!(content, "看看 [[這個]] 怎麼樣"); + } +} diff --git a/crates/openab-core/src/bot_turns.rs b/crates/openab-core/src/bot_turns.rs new file mode 100644 index 000000000..130fa717b --- /dev/null +++ b/crates/openab-core/src/bot_turns.rs @@ -0,0 +1,368 @@ +//! Per-thread bot turn tracking for runaway-loop prevention. +//! +//! 
Shared between Discord and Slack adapters so both platforms apply the same +//! soft/hard limit semantics. Both counters reset on a human message in the +//! thread. Runs before self-check so a bot's own messages count too — this +//! means `soft_limit=20` caps the *total* bot messages in a thread, not per-bot. + +use std::collections::HashMap; + +/// Absolute per-thread cap on consecutive bot turns without human intervention. +/// A human message resets both soft and hard counters to 0, allowing bots to +/// resume. This is *not* a lifetime total — it guards against runaway loops +/// between human resets. +pub const HARD_BOT_TURN_LIMIT: u32 = 1000; + +/// Stable prefix used in all bot turn limit warning messages. +/// Referenced by the dedup check in the Discord adapter — changing this +/// string requires updating the dedup check too. +pub const BOT_TURN_LIMIT_WARNING_PREFIX: &str = "⚠️ Bot turn limit reached"; + +#[derive(Debug, PartialEq, Eq)] +pub enum TurnResult { + /// Counter below limits — continue normally. + Ok, + /// Counter == soft_limit — warn once, then stop. + SoftLimit(u32), + /// Counter > soft_limit — silently stop (already warned). + Throttled, + /// Counter == HARD_BOT_TURN_LIMIT — warn once, then stop. + HardLimit, + /// Counter > HARD_BOT_TURN_LIMIT — silently stop (already warned). 
+ Stopped, +} + +pub struct BotTurnTracker { + soft_limit: u32, + counts: HashMap, +} + +impl BotTurnTracker { + pub fn new(soft_limit: u32) -> Self { + Self { + soft_limit, + counts: HashMap::new(), + } + } + + pub fn on_bot_message(&mut self, thread_id: &str) -> TurnResult { + let (soft, hard) = self.counts.entry(thread_id.to_string()).or_insert((0, 0)); + *soft += 1; + *hard += 1; + if *hard > HARD_BOT_TURN_LIMIT { + TurnResult::Stopped + } else if *hard == HARD_BOT_TURN_LIMIT { + TurnResult::HardLimit + } else if *soft > self.soft_limit { + TurnResult::Throttled + } else if *soft == self.soft_limit { + TurnResult::SoftLimit(*soft) + } else { + TurnResult::Ok + } + } + + pub fn on_human_message(&mut self, thread_id: &str) { + if let Some((soft, hard)) = self.counts.get_mut(thread_id) { + *soft = 0; + *hard = 0; + } + } + + /// High-level decision for a bot message: increments the counter and + /// returns what the adapter should do. Collapses the warn-once semantics + /// and user-facing message formatting so Discord/Slack (and future adapters) + /// don't duplicate the match. + pub fn classify_bot_message(&mut self, thread_id: &str) -> TurnAction { + match self.on_bot_message(thread_id) { + TurnResult::Ok => TurnAction::Continue, + TurnResult::SoftLimit(n) => TurnAction::WarnAndStop { + severity: TurnSeverity::Soft, + turns: n, + user_message: format!( + "{} ({n}/{soft}). \ + A human must reply in this thread to continue bot-to-bot conversation.", + BOT_TURN_LIMIT_WARNING_PREFIX, + soft = self.soft_limit, + ), + }, + TurnResult::HardLimit => TurnAction::WarnAndStop { + severity: TurnSeverity::Hard, + turns: HARD_BOT_TURN_LIMIT, + user_message: format!( + "🛑 Hard bot turn limit reached ({HARD_BOT_TURN_LIMIT}). \ + A human must reply to continue." + ), + }, + TurnResult::Throttled | TurnResult::Stopped => TurnAction::SilentStop, + } + } +} + +/// Log severity hint for `TurnAction::WarnAndStop`. 
+#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum TurnSeverity { + /// Soft limit — typically logged at `info!`. + Soft, + /// Hard absolute cap — typically logged at `warn!`. + Hard, +} + +/// High-level action for a bot message after calling +/// [`BotTurnTracker::classify_bot_message`]. +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum TurnAction { + /// Safe to continue processing this bot message. + Continue, + /// Stop processing; if the message did not come from our own bot, the + /// caller should post `user_message` to the thread so humans see why + /// the bot went quiet. `turns` is the counter value at the warning + /// point — useful as a structured log field. + WarnAndStop { + severity: TurnSeverity, + turns: u32, + user_message: String, + }, + /// Stop processing silently — the warning was already sent on a previous + /// turn; further warnings would spam the thread. + SilentStop, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn bot_turns_increment() { + let mut t = BotTurnTracker::new(5); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + + #[test] + fn soft_limit_triggers() { + let mut t = BotTurnTracker::new(3); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); + } + + #[test] + fn human_resets_both_counters() { + let mut t = BotTurnTracker::new(3); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + t.on_human_message("t1"); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); + } + + #[test] + fn hard_limit_triggers() { + let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); + for _ in 0..HARD_BOT_TURN_LIMIT - 1 { + assert_eq!(t.on_bot_message("t1"), 
TurnResult::Ok); + } + assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit); + } + + #[test] + fn hard_limit_does_not_fire_at_legacy_100() { + let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); + for i in 1..=100 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok, "turn {i}"); + } + } + + #[test] + fn hard_limit_resets_on_human() { + let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); + for _ in 0..HARD_BOT_TURN_LIMIT - 1 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + t.on_human_message("t1"); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + + #[test] + fn hard_before_soft_when_equal() { + let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT); + for _ in 0..HARD_BOT_TURN_LIMIT - 1 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit); + } + + #[test] + fn threads_are_independent() { + let mut t = BotTurnTracker::new(3); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); + assert_eq!(t.on_bot_message("t2"), TurnResult::Ok); + } + + #[test] + fn human_on_unknown_thread_is_noop() { + let mut t = BotTurnTracker::new(5); + t.on_human_message("unknown"); + } + + #[test] + fn two_bot_pingpong_hits_soft_limit() { + let mut t = BotTurnTracker::new(20); + for i in 1..20 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok, "turn {i}"); + } + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(20)); + } + + #[test] + fn two_bot_pingpong_human_resets() { + let mut t = BotTurnTracker::new(20); + for _ in 0..15 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + t.on_human_message("t1"); + for _ in 0..15 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + for _ in 0..4 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(20)); + } + + #[test] + fn 
soft_limit_warn_once_semantics() { + let mut t = BotTurnTracker::new(20); + for _ in 0..19 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(20)); + assert_eq!(t.on_bot_message("t1"), TurnResult::Throttled); + assert_eq!(t.on_bot_message("t1"), TurnResult::Throttled); + } + + #[test] + fn hard_limit_warn_once_semantics() { + let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); + for _ in 0..HARD_BOT_TURN_LIMIT - 1 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit); + assert_eq!(t.on_bot_message("t1"), TurnResult::Stopped); + } + + // System messages (thread created, pin, etc.) must not reset the counter. + // Filtering happens at the call site; this verifies the counter stays put + // when on_human_message is never called. Regression for openabdev/openab#497. + #[test] + fn system_message_does_not_reset_counter() { + let mut t = BotTurnTracker::new(3); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); + } + + #[test] + fn classify_returns_continue_under_limits() { + let mut t = BotTurnTracker::new(5); + assert_eq!(t.classify_bot_message("t1"), TurnAction::Continue); + } + + #[test] + fn classify_returns_warn_and_stop_on_soft_limit() { + let mut t = BotTurnTracker::new(3); + let _ = t.classify_bot_message("t1"); + let _ = t.classify_bot_message("t1"); + assert_eq!( + t.classify_bot_message("t1"), + TurnAction::WarnAndStop { + severity: TurnSeverity::Soft, + turns: 3, + user_message: format!( + "{} (3/3). 
\ + A human must reply in this thread to continue bot-to-bot conversation.", + BOT_TURN_LIMIT_WARNING_PREFIX, + ), + }, + ); + } + + #[test] + fn classify_returns_silent_stop_past_soft_limit() { + let mut t = BotTurnTracker::new(2); + let _ = t.classify_bot_message("t1"); + let _ = t.classify_bot_message("t1"); + assert_eq!(t.classify_bot_message("t1"), TurnAction::SilentStop); + assert_eq!(t.classify_bot_message("t1"), TurnAction::SilentStop); + } + + #[test] + fn classify_returns_warn_and_stop_on_hard_limit() { + let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); + for _ in 0..HARD_BOT_TURN_LIMIT - 1 { + let _ = t.classify_bot_message("t1"); + } + assert_eq!( + t.classify_bot_message("t1"), + TurnAction::WarnAndStop { + severity: TurnSeverity::Hard, + turns: HARD_BOT_TURN_LIMIT, + user_message: format!( + "🛑 Hard bot turn limit reached ({HARD_BOT_TURN_LIMIT}). \ + A human must reply to continue." + ), + }, + ); + assert_eq!(t.classify_bot_message("t1"), TurnAction::SilentStop); + } + + #[test] + fn classify_is_per_thread_independent() { + let mut t = BotTurnTracker::new(2); + assert_eq!(t.classify_bot_message("t1"), TurnAction::Continue); + assert!(matches!( + t.classify_bot_message("t1"), + TurnAction::WarnAndStop { + severity: TurnSeverity::Soft, + .. + }, + )); + assert_eq!(t.classify_bot_message("t2"), TurnAction::Continue); + assert!(matches!( + t.classify_bot_message("t2"), + TurnAction::WarnAndStop { + severity: TurnSeverity::Soft, + .. + }, + )); + } + + // End-to-end: human message must fully reset classify behavior on the + // same thread, including unlocking new `Continue` responses. + #[test] + fn classify_resumes_after_human_message() { + let mut t = BotTurnTracker::new(2); + let _ = t.classify_bot_message("t1"); // Continue + assert!(matches!( + t.classify_bot_message("t1"), + TurnAction::WarnAndStop { .. }, + )); + // Without a human message, the next classify is silent. 
+ assert_eq!(t.classify_bot_message("t1"), TurnAction::SilentStop); + // Human resets — classify starts at Continue again. + t.on_human_message("t1"); + assert_eq!(t.classify_bot_message("t1"), TurnAction::Continue); + assert!(matches!( + t.classify_bot_message("t1"), + TurnAction::WarnAndStop { + severity: TurnSeverity::Soft, + turns: 2, + .. + }, + )); + } +} diff --git a/crates/openab-core/src/config.rs b/crates/openab-core/src/config.rs new file mode 100644 index 000000000..991071164 --- /dev/null +++ b/crates/openab-core/src/config.rs @@ -0,0 +1,1500 @@ +use crate::markdown::TableMode; +use regex::Regex; +use serde::Deserialize; +use std::collections::HashMap; +use std::path::Path; + +/// Controls how incoming messages are dispatched to ACP turns. +/// +/// - `Message` (default): each message becomes its own ACP turn (v0.8.2-beta.1 behaviour). +/// - `Thread`: one buffer per thread; all senders in a thread share a single batch and +/// produce one ACP turn per turn boundary. +/// - `Lane`: one buffer per (thread, sender); each sender batches independently and gets +/// its own ACP turn — no silent-drop risk when multiple senders address the same thread. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum MessageProcessingMode { + #[default] + Message, + Thread, + Lane, +} + +impl<'de> Deserialize<'de> for MessageProcessingMode { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + match s.to_lowercase().replace('-', "_").as_str() { + "per_message" => Ok(Self::Message), + "per_thread" => Ok(Self::Thread), + "per_lane" => Ok(Self::Lane), + other => Err(serde::de::Error::unknown_variant( + other, + &["per-message", "per-thread", "per-lane"], + )), + } + } +} + +/// Controls whether the bot processes messages from other Discord bots. 
+/// +/// Inspired by Hermes Agent's `DISCORD_ALLOW_BOTS` 3-value design: +/// - `Off` (default): ignore all bot messages (safe default, no behavior change) +/// - `Mentions`: only process bot messages that @mention this bot (natural loop breaker) +/// - `All`: process all bot messages (hard-capped at 1000 consecutive bot turns) +/// +/// The bot's own messages are always ignored regardless of this setting. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum AllowBots { + #[default] + Off, + Mentions, + All, +} + +impl<'de> Deserialize<'de> for AllowBots { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + match s.to_lowercase().as_str() { + "off" | "none" | "false" => Ok(Self::Off), + "mentions" => Ok(Self::Mentions), + "all" | "true" => Ok(Self::All), + other => Err(serde::de::Error::unknown_variant( + other, + &["off", "mentions", "all"], + )), + } + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct AgentCoreConfig { + /// AgentCore Runtime ARN (required) + pub runtime_arn: String, + /// ACP agent command to run in the PTY shell (default: kiro-cli acp --trust-all-tools) + #[serde(default = "default_agentcore_shell_command")] + pub shell_command: String, + /// Cancel strategy: "noop" or "stop" (default: stop) + #[serde(default = "default_agentcore_cancel_strategy")] + #[allow(dead_code)] + pub cancel_strategy: AgentCoreCancelStrategy, +} + +fn default_agentcore_shell_command() -> String { + "kiro-cli acp --trust-all-tools".to_string() +} + +impl AgentCoreConfig { + /// Extract region from ARN: arn:aws:bedrock-agentcore:REGION:ACCOUNT:runtime/ID + pub fn region(&self) -> String { + let parts: Vec<&str> = self.runtime_arn.split(':').collect(); + if parts.len() >= 4 && !parts[3].is_empty() { + return parts[3].to_string(); + } + "us-east-1".into() // fallback (should never hit with valid ARN) + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum AgentCoreCancelStrategy { + #[default] 
+ Stop, + Noop, +} + +impl<'de> Deserialize<'de> for AgentCoreCancelStrategy { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + match s.to_lowercase().as_str() { + "stop" => Ok(Self::Stop), + "noop" => Ok(Self::Noop), + other => Err(serde::de::Error::unknown_variant(other, &["stop", "noop"])), + } + } +} + +impl std::fmt::Display for AgentCoreCancelStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Stop => write!(f, "stop"), + Self::Noop => write!(f, "noop"), + } + } +} + +fn default_agentcore_cancel_strategy() -> AgentCoreCancelStrategy { + AgentCoreCancelStrategy::Stop +} + +#[derive(Debug, Deserialize)] +pub struct Config { + pub discord: Option, + pub slack: Option, + pub gateway: Option, + pub agentcore: Option, + #[serde(default)] + pub agent: AgentConfig, + #[serde(default)] + pub pool: PoolConfig, + #[serde(default)] + pub reactions: ReactionsConfig, + #[serde(default)] + pub stt: SttConfig, + #[serde(default)] + pub markdown: MarkdownConfig, + #[serde(default)] + pub cron: CronConfig, + #[serde(default)] + pub hooks: HooksConfig, + #[serde(default)] + pub workspace: WorkspaceConfig, + #[serde(default)] + pub secrets: SecretsConfig, +} + +#[derive(Debug, Clone, Default, Deserialize)] +pub struct WorkspaceConfig { + /// Workspace aliases: `name = "~/path/to/project"` + /// Used with `[[ws:@alias]]` control directives. + #[serde(default)] + pub aliases: std::collections::HashMap, +} + +#[derive(Debug, Clone, Default, Deserialize)] +pub struct SecretsConfig { + /// AWS Secrets Manager configuration. + #[serde(default)] + pub aws: AwsSecretsConfig, + /// Exec provider configuration. + #[serde(default)] + pub exec: ExecSecretsConfig, + /// Secret references: key = "aws-sm://..." or "exec://..." 
+ #[serde(default)] + pub refs: HashMap, +} + +#[derive(Debug, Clone, Default, Deserialize)] +pub struct AwsSecretsConfig { + /// Override AWS region (otherwise uses default credential chain). + pub region: Option, + /// Override endpoint URL (for LocalStack or VPC endpoints). + pub endpoint_url: Option, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ExecSecretsConfig { + /// Per-invocation timeout in seconds (default: 10). + #[serde(default = "default_exec_timeout")] + pub timeout_seconds: u64, +} + +impl Default for ExecSecretsConfig { + fn default() -> Self { + Self { timeout_seconds: 10 } + } +} + +fn default_exec_timeout() -> u64 { + 10 +} + +#[derive(Debug, Clone, Default, Deserialize)] +pub struct HooksConfig { + pub pre_boot: Option, + pub pre_shutdown: Option, +} + +/// Failure policy for a hook. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum OnFailure { + #[default] + Abort, + Warn, +} + +impl<'de> Deserialize<'de> for OnFailure { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + match s.to_lowercase().as_str() { + "abort" => Ok(Self::Abort), + "warn" => Ok(Self::Warn), + other => Err(serde::de::Error::unknown_variant(other, &["abort", "warn"])), + } + } +} + +/// Configuration for a single hook. Exactly one of `script`, `inline`, or `url` must be set. +#[derive(Debug, Clone, Deserialize)] +pub struct HookConfig { + /// Absolute path to an executable script. + pub script: Option, + /// Inline script content (written to temp file and executed). + pub inline: Option, + /// Remote script URL (fetched and executed). + pub url: Option, + /// SHA-256 checksum of the remote script (required with `url`). + pub sha256: Option, + /// Max wall-clock seconds. Default: 60. + #[serde(default = "default_hook_timeout")] + pub timeout_seconds: u64, + /// Failure policy. Default: abort. 
+ #[serde(default)] + pub on_failure: OnFailure, +} + +fn default_hook_timeout() -> u64 { + 60 +} + +#[derive(Debug, Clone, Default, Deserialize)] +pub struct CronConfig { + /// Enable usercron hot-reload (default: false). Must be explicitly set to true. + #[serde(default)] + pub usercron_enabled: bool, + /// Path to an external cronjob.toml for hot-reloadable user-managed schedules. + pub usercron_path: Option, + /// Baseline cronjob definitions: `[[cron.jobs]]` + #[serde(default)] + pub jobs: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct SttConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub api_key: String, + #[serde(default = "default_stt_model")] + pub model: String, + #[serde(default = "default_stt_base_url")] + pub base_url: String, + /// Echo the transcribed text back to the thread (no mentions) before + /// dispatching the prompt to the agent. Lets users verify STT accuracy. + #[serde(default = "default_echo_transcript")] + pub echo_transcript: bool, +} + +impl Default for SttConfig { + fn default() -> Self { + Self { + enabled: false, + api_key: String::new(), + model: default_stt_model(), + base_url: default_stt_base_url(), + echo_transcript: default_echo_transcript(), + } + } +} + +fn default_stt_model() -> String { + "whisper-large-v3-turbo".into() +} +fn default_stt_base_url() -> String { + "https://api.groq.com/openai/v1".into() +} +fn default_echo_transcript() -> bool { + false +} + +#[derive(Debug, Deserialize)] +pub struct DiscordConfig { + pub bot_token: String, + /// Explicit flag: true = allow all channels, false = check allowed_channels list. + /// When not set, auto-detected: non-empty list → false, empty list → true. + pub allow_all_channels: Option, + /// Explicit flag: true = allow all users, false = check allowed_users list. + /// When not set, auto-detected: non-empty list → false, empty list → true. 
+ pub allow_all_users: Option, + #[serde(default)] + pub allowed_channels: Vec, + #[serde(default)] + pub allowed_users: Vec, + #[serde(default)] + pub allow_bot_messages: AllowBots, + /// When non-empty, only bot messages from these IDs pass the bot gate. + /// Combines with `allow_bot_messages`: the mode check runs first, then + /// the allowlist filters further. Empty = allow any bot (mode permitting). + /// Only relevant when `allow_bot_messages` is `"mentions"` or `"all"`; + /// ignored when `"off"` since all bot messages are rejected before this check. + /// + /// **Admission override**: a trusted bot that explicitly @mentions this bot + /// bypasses the `allow_bot_messages` mode entirely (treated as human @mention). + /// This allows trusted bots to pull this bot into threads regardless of mode. + #[serde(default)] + pub trusted_bot_ids: Vec, + #[serde(default)] + pub allow_user_messages: AllowUsers, + /// Max consecutive bot turns (without human intervention) before throttling. + /// Human message resets the counter. Default: 100. + #[serde(default = "default_max_bot_turns")] + pub max_bot_turns: u32, + /// Role IDs that trigger the bot (same as direct @mention). + /// When a message mentions a role in this list, it is treated as a bot trigger. + /// Empty (default) = role mentions do not trigger the bot. + #[serde(default)] + pub allowed_role_ids: Vec, + /// Allow the bot to respond to Discord direct messages (DMs). + /// Default: false (opt-in). `allowed_users` still applies in DMs. + #[serde(default)] + pub allow_dm: bool, + /// Message dispatch mode. Default: per-message (v0.8.2-beta.1 behaviour). + #[serde(default)] + pub message_processing_mode: MessageProcessingMode, + /// Batched mode only: per-thread channel capacity. Default: 10. + #[serde(default = "default_max_buffered_messages")] + pub max_buffered_messages: usize, + /// Batched mode only: soft token cap for greedy drain. Default: 24000. 
+ #[serde(default = "default_max_batch_tokens")] + pub max_batch_tokens: usize, +} + +fn default_max_bot_turns() -> u32 { + 100 +} +fn default_max_buffered_messages() -> usize { + 10 +} +fn default_max_batch_tokens() -> usize { + 24_000 +} + +/// Controls whether the bot responds to user messages in threads without @mention. +/// +/// - `Involved`: respond to thread messages only if the bot has participated +/// in the thread (posted at least one message, or the thread parent @mentions the bot). +/// Channel/MPDM messages always require @mention. DMs always process (implicit mention). +/// - `Mentions`: always require @mention, even in threads the bot is participating in. +/// - `MultibotMentions` (default): same as `Involved` in single-bot threads; falls back to +/// `Mentions` when other bots have also posted in the thread. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum AllowUsers { + Involved, + Mentions, + #[default] + MultibotMentions, +} + +impl<'de> Deserialize<'de> for AllowUsers { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + match s.to_lowercase().replace('-', "_").as_str() { + "involved" => Ok(Self::Involved), + "mentions" => Ok(Self::Mentions), + "multibot_mentions" => Ok(Self::MultibotMentions), + other => Err(serde::de::Error::unknown_variant( + other, + &["involved", "mentions", "multibot-mentions"], + )), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct SlackConfig { + pub bot_token: String, + pub app_token: String, + /// Explicit flag: true = allow all channels, false = check allowed_channels list. + /// When not set, auto-detected: non-empty list → false, empty list → true. + pub allow_all_channels: Option, + /// Explicit flag: true = allow all users, false = check allowed_users list. + /// When not set, auto-detected: non-empty list → false, empty list → true. 
+ pub allow_all_users: Option, + #[serde(default)] + pub allowed_channels: Vec, + #[serde(default)] + pub allowed_users: Vec, + #[serde(default)] + pub allow_bot_messages: AllowBots, + /// Bot User IDs (U...) allowed to interact when allow_bot_messages is + /// "mentions" or "all". Find via Slack UI: click bot profile → Copy member ID. + /// Empty = allow any bot (mode permitting). + #[serde(default)] + pub trusted_bot_ids: Vec, + #[serde(default)] + pub allow_user_messages: AllowUsers, + /// Max consecutive bot turns (without human intervention) before throttling. + /// Human message resets the counter. Default: 100. + #[serde(default = "default_max_bot_turns")] + pub max_bot_turns: u32, + /// Message dispatch mode. Default: per-message. + #[serde(default)] + pub message_processing_mode: MessageProcessingMode, + /// Batched mode only: per-thread channel capacity. Default: 10. + #[serde(default = "default_max_buffered_messages")] + pub max_buffered_messages: usize, + /// Batched mode only: soft token cap for greedy drain. Default: 24000. + #[serde(default = "default_max_batch_tokens")] + pub max_batch_tokens: usize, + /// Slack "AI app / Assistant" mode: stream replies via chat.startStream + + /// assistant.threads.setStatus instead of post+edit + emoji reactions. + /// Requires the Slack app to be an AI app (assistant feature enabled) with + /// the `assistant:write` scope. Default: true — set to false for Slack apps + /// that are not AI apps (no `assistant:write`) to keep emoji-reaction status. + #[serde(default = "default_true")] + pub assistant_mode: bool, +} + +#[derive(Debug, Deserialize)] +pub struct GatewayConfig { + /// WebSocket URL of the custom gateway (e.g. ws://gateway:8080/ws) + pub url: String, + /// Platform name for session key namespacing (e.g. 
"telegram", "line") + #[serde(default = "default_gateway_platform")] + pub platform: String, + /// Shared token for WebSocket authentication (optional but recommended) + pub token: Option, + /// Bot username for @mention gating in groups (e.g. "my_bot") + pub bot_username: Option, + /// Explicit flag: true = allow all channels, false = check allowed_channels list. + /// When not set, auto-detected: non-empty list → false, empty list → true. + pub allow_all_channels: Option, + /// Explicit flag: true = allow all users, false = check allowed_users list. + /// When not set, auto-detected: non-empty list → false, empty list → true. + pub allow_all_users: Option, + #[serde(default)] + pub allowed_channels: Vec, + #[serde(default)] + pub allowed_users: Vec, + /// Enable streaming (typewriter) mode — requires gateway platform to support message editing. + #[serde(default)] + pub streaming: bool, + /// Show "…" placeholder at streaming start. Default: true. Set false for platforms using drafts. + #[serde(default = "default_true")] + pub streaming_placeholder: bool, + /// Message dispatch mode. Default: per-message. + #[serde(default)] + pub message_processing_mode: MessageProcessingMode, + /// Batched mode only: per-thread channel capacity. Default: 10. + #[serde(default = "default_max_buffered_messages")] + pub max_buffered_messages: usize, + /// Batched mode only: soft token cap for greedy drain. Default: 24000. + #[serde(default = "default_max_batch_tokens")] + pub max_batch_tokens: usize, +} + +fn default_gateway_platform() -> String { + "telegram".into() +} + +/// Raw intermediate struct for serde — uses `Option` to detect explicit fields. 
+#[derive(Debug, Deserialize)] +#[serde(default)] +struct AgentConfigRaw { + command: Option, + args: Option>, + working_dir: String, + env: HashMap, + inherit_env: Vec, +} + +impl Default for AgentConfigRaw { + fn default() -> Self { + Self { + command: None, + args: None, + working_dir: default_working_dir(), + env: HashMap::new(), + inherit_env: Vec::new(), + } + } +} + +#[derive(Debug)] +pub struct AgentConfig { + pub command: String, + pub args: Vec, + pub working_dir: String, + pub env: HashMap, + pub inherit_env: Vec, + /// Whether the command was explicitly set in config (vs defaulted from env/fallback). + pub command_explicit: bool, +} + +impl Default for AgentConfig { + fn default() -> Self { + Self { + command: default_agent_command(), + args: default_agent_args(), + working_dir: default_working_dir(), + env: HashMap::new(), + inherit_env: Vec::new(), + command_explicit: false, + } + } +} + +impl<'de> serde::Deserialize<'de> for AgentConfig { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let raw = AgentConfigRaw::deserialize(deserializer)?; + let cmd_explicit = raw.command.is_some(); + let command = raw.command.unwrap_or_else(default_agent_command); + // If command was explicitly set but args was not, default args to [] + // to avoid leaking env-var args into a custom command. 
+ let args = match (cmd_explicit, raw.args) { + (_, Some(args)) => args, // args explicitly set → use them + (true, None) => Vec::new(), // command set, args omitted → empty + (false, None) => default_agent_args(), // neither set → env var + }; + Ok(AgentConfig { + command, + args, + working_dir: raw.working_dir, + env: raw.env, + inherit_env: raw.inherit_env, + command_explicit: cmd_explicit, + }) + } +} + +#[derive(Debug, Deserialize)] +pub struct PoolConfig { + #[serde(default = "default_max_sessions")] + pub max_sessions: usize, + #[serde(default = "default_ttl_hours")] + pub session_ttl_hours: u64, + /// Hard ceiling for a single prompt (#732). Once exceeded, the broker + /// abandons the in-flight request, sends `session/cancel` to the agent, + /// and clears the pending entry so late responses cannot leak into the + /// next prompt's subscriber. + /// + /// Precision: checked every `liveness_check_secs`, so actual cutoff is + /// ±`liveness_check_secs` from this value. + #[serde(default = "default_prompt_hard_timeout_secs")] + pub prompt_hard_timeout_secs: u64, + /// Polling cadence (seconds) for the recv-loop liveness check (#732). + /// Lower = faster reaction to a dead agent / hard ceiling at the cost of + /// more wakeups while the agent is streaming normally. + #[serde(default = "default_liveness_check_secs")] + pub liveness_check_secs: u64, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct CronJobConfig { + /// Stable ID for usercron jobs that need scheduler writeback. 
+ pub id: Option, + /// Whether this cronjob is active (default: true) + #[serde(default = "default_true")] + pub enabled: bool, + /// Cron expression (5-field POSIX format) + pub schedule: String, + /// Target channel ID + pub channel: String, + /// Message to send to the agent + pub message: String, + /// Target platform (default: "discord") + #[serde(default = "default_cron_platform")] + pub platform: String, + /// Sender name for attribution (default: "openab-cron") + #[serde(default = "default_cron_sender")] + pub sender_name: String, + /// Optional thread ID (post to existing thread) + pub thread_id: Option, + /// Timezone (default: "UTC") + #[serde(default = "default_cron_timezone")] + pub timezone: String, + /// Usercron-only: command to run before firing. Exit 0 plus a matching + /// `disable_on_success_match` means the goal is complete and the scheduler + /// disables the job in the usercron file. + pub disable_on_success: Option, + /// Usercron-only: required output marker for `disable_on_success`. + pub disable_on_success_match: Option, + /// Usercron-only: timeout for `disable_on_success`. + #[serde(default = "default_disable_on_success_timeout_secs")] + pub disable_on_success_timeout_secs: u64, + /// Usercron-only: working directory for `disable_on_success`. + pub disable_on_success_working_dir: Option, +} + +fn default_cron_platform() -> String { + "discord".into() +} +fn default_cron_sender() -> String { + "openab-cron".into() +} +fn default_cron_timezone() -> String { + "UTC".into() +} +fn default_disable_on_success_timeout_secs() -> u64 { + 60 +} + +/// Controls how tool calls are rendered in chat messages. +/// +/// - `full`: show complete tool title including arguments (default, original behavior) +/// - `compact`: show only a count summary, e.g. 
`✅ 3 · 🔧 1 tool(s)` +/// - `none`: hide tool lines entirely, only show final response +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum ToolDisplay { + #[default] + Full, + Compact, + None, +} + +impl<'de> Deserialize<'de> for ToolDisplay { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + match s.to_lowercase().as_str() { + "full" => Ok(Self::Full), + "compact" => Ok(Self::Compact), + "none" | "off" | "hidden" => Ok(Self::None), + other => Err(serde::de::Error::unknown_variant( + other, + &["full", "compact", "none"], + )), + } + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ReactionsConfig { + #[serde(default = "default_true")] + pub enabled: bool, + #[serde(default)] + pub remove_after_reply: bool, + #[serde(default)] + pub tool_display: ToolDisplay, + #[serde(default)] + pub emojis: ReactionEmojis, + #[serde(default)] + pub timing: ReactionTiming, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ReactionEmojis { + #[serde(default = "emoji_queued")] + pub queued: String, + #[serde(default = "emoji_thinking")] + pub thinking: String, + #[serde(default = "emoji_tool")] + pub tool: String, + #[serde(default = "emoji_coding")] + pub coding: String, + #[serde(default = "emoji_web")] + pub web: String, + #[serde(default = "emoji_done")] + pub done: String, + #[serde(default = "emoji_error")] + pub error: String, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ReactionTiming { + #[serde(default = "default_debounce_ms")] + pub debounce_ms: u64, + #[serde(default = "default_stall_soft_ms")] + pub stall_soft_ms: u64, + #[serde(default = "default_stall_hard_ms")] + pub stall_hard_ms: u64, + #[serde(default = "default_done_hold_ms")] + pub done_hold_ms: u64, + #[serde(default = "default_error_hold_ms")] + pub error_hold_ms: u64, +} + +// --- defaults --- + +fn default_working_dir() -> String { + std::env::var("HOME").unwrap_or_else(|_| "/tmp".into()) +} +fn default_agent_command() -> String 
{ + if let Ok(val) = std::env::var("OPENAB_AGENT_COMMAND") { + if let Some(cmd) = val.split_whitespace().next() { + return cmd.to_string(); + } + } + "openab-agent".into() +} +fn default_agent_args() -> Vec { + if let Ok(val) = std::env::var("OPENAB_AGENT_COMMAND") { + let parts: Vec<&str> = val.split_whitespace().collect(); + if parts.len() > 1 { + return parts[1..].iter().map(|s| s.to_string()).collect(); + } + } + Vec::new() +} +fn default_max_sessions() -> usize { + 10 +} +fn default_ttl_hours() -> u64 { + 4 +} +pub(crate) fn default_prompt_hard_timeout_secs() -> u64 { + 30 * 60 +} +pub(crate) fn default_liveness_check_secs() -> u64 { + 30 +} +fn default_true() -> bool { + true +} + +fn emoji_queued() -> String { + "👀".into() +} +fn emoji_thinking() -> String { + "🤔".into() +} +fn emoji_tool() -> String { + "🔥".into() +} +fn emoji_coding() -> String { + "👨‍💻".into() +} +fn emoji_web() -> String { + "⚡".into() +} +fn emoji_done() -> String { + "🆗".into() +} +fn emoji_error() -> String { + "😱".into() +} + +fn default_debounce_ms() -> u64 { + 700 +} +fn default_stall_soft_ms() -> u64 { + 10_000 +} +fn default_stall_hard_ms() -> u64 { + 30_000 +} +fn default_done_hold_ms() -> u64 { + 1_500 +} +fn default_error_hold_ms() -> u64 { + 2_500 +} + +impl Default for PoolConfig { + fn default() -> Self { + Self { + max_sessions: default_max_sessions(), + session_ttl_hours: default_ttl_hours(), + prompt_hard_timeout_secs: default_prompt_hard_timeout_secs(), + liveness_check_secs: default_liveness_check_secs(), + } + } +} + +impl Default for ReactionsConfig { + fn default() -> Self { + Self { + enabled: true, + remove_after_reply: false, + tool_display: ToolDisplay::default(), + emojis: ReactionEmojis::default(), + timing: ReactionTiming::default(), + } + } +} + +impl Default for ReactionEmojis { + fn default() -> Self { + Self { + queued: emoji_queued(), + thinking: emoji_thinking(), + tool: emoji_tool(), + coding: emoji_coding(), + web: emoji_web(), + done: emoji_done(), + 
error: emoji_error(), + } + } +} + +impl Default for ReactionTiming { + fn default() -> Self { + Self { + debounce_ms: default_debounce_ms(), + stall_soft_ms: default_stall_soft_ms(), + stall_hard_ms: default_stall_hard_ms(), + done_hold_ms: default_done_hold_ms(), + error_hold_ms: default_error_hold_ms(), + } + } +} + +// --- markdown --- + +#[derive(Debug, Clone, Default, Deserialize)] +pub struct MarkdownConfig { + #[serde(default)] + pub tables: TableMode, +} + +// --- loading --- + +/// Resolve an allow_all flag: if explicitly set, use it; otherwise infer from the list. +/// Non-empty list → false (respect the list), empty list → true (allow all). +pub fn resolve_allow_all(flag: Option, list: &[String]) -> bool { + flag.unwrap_or(list.is_empty()) +} + +fn expand_env_vars(raw: &str) -> String { + let re = Regex::new(r"\$\{(\w+)\}").unwrap(); + re.replace_all(raw, |caps: ®ex::Captures| { + std::env::var(&caps[1]).unwrap_or_default() + }) + .into_owned() +} + +/// Load raw config text from a file path (env vars expanded but secrets NOT resolved). +pub fn load_config_raw(path: &Path) -> anyhow::Result { + let raw = std::fs::read_to_string(path) + .map_err(|e| anyhow::anyhow!("failed to read {}: {e}", path.display()))?; + Ok(expand_env_vars(&raw)) +} + +/// Load raw config text from a URL (env vars expanded but secrets NOT resolved). 
+pub async fn load_config_raw_from_url(url: &str) -> anyhow::Result { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(10)) + .build()?; + let resp = client + .get(url) + .send() + .await + .map_err(|e| anyhow::anyhow!("failed to fetch remote config from {url}: {e}"))?; + let status = resp.status(); + if !status.is_success() { + anyhow::bail!("remote config request to {url} returned HTTP {status}"); + } + let bytes = resp + .bytes() + .await + .map_err(|e| anyhow::anyhow!("failed to read response body from {url}: {e}"))?; + const MAX_CONFIG_BYTES: usize = 1024 * 1024; + if bytes.len() > MAX_CONFIG_BYTES { + anyhow::bail!( + "remote config from {url} exceeds 1 MiB limit ({} bytes)", + bytes.len() + ); + } + let raw = String::from_utf8(bytes.to_vec()) + .map_err(|e| anyhow::anyhow!("remote config from {url} is not valid UTF-8: {e}"))?; + Ok(expand_env_vars(&raw)) +} + +/// Parse config from already-expanded text. +pub fn parse_config_str(expanded: &str, source: &str) -> anyhow::Result { + parse_config_inner(expanded, source) +} + +#[cfg(test)] +fn parse_config(raw: &str, source: &str) -> anyhow::Result { + let expanded = expand_env_vars(raw); + parse_config_inner(&expanded, source) +} + +#[cfg(test)] +fn load_config(path: &Path) -> anyhow::Result { + let raw = std::fs::read_to_string(path) + .map_err(|e| anyhow::anyhow!("failed to read {}: {e}", path.display()))?; + parse_config(&raw, path.display().to_string().as_str()) +} + +#[cfg(test)] +async fn load_config_from_url(url: &str) -> anyhow::Result { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(10)) + .build()?; + let resp = client + .get(url) + .send() + .await + .map_err(|e| anyhow::anyhow!("failed to fetch remote config from {url}: {e}"))?; + let status = resp.status(); + if !status.is_success() { + anyhow::bail!("remote config request to {url} returned HTTP {status}"); + } + let bytes = resp + .bytes() + .await + .map_err(|e| 
anyhow::anyhow!("failed to read response body from {url}: {e}"))?; + let raw = String::from_utf8(bytes.to_vec()) + .map_err(|e| anyhow::anyhow!("remote config from {url} is not valid UTF-8: {e}"))?; + parse_config(&raw, url) +} + +fn parse_config_inner(expanded: &str, source: &str) -> anyhow::Result { + let mut config: Config = toml::from_str(expanded) + .map_err(|e| anyhow::anyhow!("failed to parse config from {source}: {e}"))?; + + // If [agentcore] is set and [agent] command was not explicitly provided, + // synthesize agent config to spawn the bundled agentcore-acp adapter. + if let Some(ref ac) = config.agentcore { + // Validate ARN format: arn:aws:bedrock-agentcore:REGION:ACCOUNT:runtime/ID + let parts: Vec<&str> = ac.runtime_arn.split(':').collect(); + anyhow::ensure!( + parts.len() >= 6 + && parts[0] == "arn" + && parts[2] == "bedrock-agentcore" + && !parts[3].is_empty() + && parts[5].starts_with("runtime/"), + "agentcore.runtime_arn is not a valid AgentCore Runtime ARN \ + (expected arn:aws:bedrock-agentcore:REGION:ACCOUNT:runtime/ID, got \"{}\")", + ac.runtime_arn + ); + + if !config.agent.command_explicit { + // Use native Rust bridge (agentcore feature) or fall back to Python adapter + #[cfg(feature = "agentcore")] + let (cmd, args) = { + let self_exe = std::env::current_exe() + .map(|p| p.to_string_lossy().to_string()) + .unwrap_or_else(|_| "openab".to_string()); + ( + self_exe, + vec![ + "agentcore-bridge".into(), + "--runtime-arn".into(), + ac.runtime_arn.clone(), + "--region".into(), + ac.region(), + "--command".into(), + ac.shell_command.clone(), + ], + ) + }; + #[cfg(not(feature = "agentcore"))] + let (cmd, args) = ( + "uv".to_string(), + vec![ + "run".into(), + "--script".into(), + "/opt/agentcore/acp/agentcore_acp.py".into(), + "--runtime-arn".into(), + ac.runtime_arn.clone(), + "--region".into(), + ac.region(), + "--cancel-strategy".into(), + ac.cancel_strategy.to_string(), + ], + ); + config.agent = AgentConfig { + command: cmd, + args, + 
working_dir: config.agent.working_dir.clone(), + env: config.agent.env.clone(), + inherit_env: config.agent.inherit_env.clone(), + command_explicit: true, // synthesized counts as explicit + }; + } + } + + // Validate max_buffered_messages > 0 (tokio::sync::mpsc::channel panics on 0) + // and max_batch_tokens > 0 (otherwise the consumer's token-cap check forces every + // batch to size 1 — functionally per-message via a confusing path). + if let Some(ref d) = config.discord { + anyhow::ensure!( + d.max_buffered_messages > 0, + "discord.max_buffered_messages must be > 0" + ); + anyhow::ensure!( + d.max_batch_tokens > 0, + "discord.max_batch_tokens must be > 0" + ); + } + if let Some(ref s) = config.slack { + anyhow::ensure!( + s.max_buffered_messages > 0, + "slack.max_buffered_messages must be > 0" + ); + anyhow::ensure!(s.max_batch_tokens > 0, "slack.max_batch_tokens must be > 0"); + } + if let Some(ref g) = config.gateway { + anyhow::ensure!( + g.max_buffered_messages > 0, + "gateway.max_buffered_messages must be > 0" + ); + anyhow::ensure!( + g.max_batch_tokens > 0, + "gateway.max_batch_tokens must be > 0" + ); + } + anyhow::ensure!( + config.pool.liveness_check_secs > 0, + "pool.liveness_check_secs must be > 0 (zero would spin the recv loop)" + ); + + Ok(config) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + + const MINIMAL_TOML: &str = r#" +[discord] +bot_token = "test-token" + +[agent] +command = "echo" +"#; + + #[test] + fn parse_minimal_config() { + let cfg = parse_config(MINIMAL_TOML, "test").unwrap(); + assert_eq!(cfg.discord.unwrap().bot_token, "test-token"); + assert_eq!(cfg.agent.command, "echo"); + assert_eq!(cfg.pool.max_sessions, 10); + assert!(cfg.reactions.enabled); + } + + #[test] + fn expand_env_vars_replaces_known_var() { + std::env::set_var("AB_TEST_VAR", "hello"); + let result = expand_env_vars("token=${AB_TEST_VAR}"); + assert_eq!(result, "token=hello"); + std::env::remove_var("AB_TEST_VAR"); + } + + #[test] + fn 
expand_env_vars_unknown_becomes_empty() { + let result = expand_env_vars("token=${AB_NONEXISTENT_12345}"); + assert_eq!(result, "token="); + } + + #[test] + fn expand_env_vars_in_config() { + std::env::set_var("AB_TEST_TOKEN", "secret-bot-token"); + let toml = r#" +[discord] +bot_token = "${AB_TEST_TOKEN}" + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert_eq!(cfg.discord.unwrap().bot_token, "secret-bot-token"); + std::env::remove_var("AB_TEST_TOKEN"); + } + + #[test] + fn parse_invalid_toml_returns_error() { + let result = parse_config("not valid toml {{{}}", "test"); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("failed to parse config from test")); + } + + #[test] + fn load_config_missing_file_returns_error() { + let result = load_config(Path::new("/tmp/agent-broker-nonexistent.toml")); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("failed to read")); + } + + #[test] + fn load_config_from_file() { + let mut tmp = tempfile::NamedTempFile::new().unwrap(); + write!(tmp, "{}", MINIMAL_TOML).unwrap(); + let cfg = load_config(tmp.path()).unwrap(); + assert_eq!(cfg.discord.unwrap().bot_token, "test-token"); + } + + #[tokio::test] + async fn load_config_from_url_invalid_host() { + let result = load_config_from_url("https://invalid.test.example/config.toml").await; + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("failed to fetch remote config")); + } + + #[test] + fn parse_gateway_config_defaults() { + let toml = r#" +[gateway] +url = "ws://gw:8080/ws" + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + let gw = cfg.gateway.unwrap(); + assert_eq!(gw.url, "ws://gw:8080/ws"); + assert_eq!(gw.platform, "telegram"); + assert!(gw.allowed_users.is_empty()); + assert!(gw.allowed_channels.is_empty()); + assert!(gw.allow_all_users.is_none()); + assert!(gw.allow_all_channels.is_none()); + // 
resolve_allow_all: empty lists → allow all + assert!(resolve_allow_all(gw.allow_all_users, &gw.allowed_users)); + assert!(resolve_allow_all( + gw.allow_all_channels, + &gw.allowed_channels + )); + } + + #[test] + fn parse_gateway_config_with_allowlists() { + let toml = r#" +[gateway] +url = "ws://gw:8080/ws" +platform = "line" +allowed_users = ["U1", "U2"] +allowed_channels = ["C1"] + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + let gw = cfg.gateway.unwrap(); + assert_eq!(gw.platform, "line"); + assert_eq!(gw.allowed_users, vec!["U1", "U2"]); + assert_eq!(gw.allowed_channels, vec!["C1"]); + // resolve_allow_all: non-empty lists → restricted + assert!(!resolve_allow_all(gw.allow_all_users, &gw.allowed_users)); + assert!(!resolve_allow_all( + gw.allow_all_channels, + &gw.allowed_channels + )); + } + + #[test] + fn tool_display_default_is_full() { + assert_eq!(ToolDisplay::default(), ToolDisplay::Full); + } + + #[test] + fn message_processing_mode_parses_per_message() { + let toml = r#" +[discord] +bot_token = "t" +message_processing_mode = "per-message" + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert_eq!( + cfg.discord.unwrap().message_processing_mode, + MessageProcessingMode::Message + ); + } + + #[test] + fn message_processing_mode_parses_per_thread() { + let toml = r#" +[discord] +bot_token = "t" +message_processing_mode = "per-thread" + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert_eq!( + cfg.discord.unwrap().message_processing_mode, + MessageProcessingMode::Thread + ); + } + + #[test] + fn message_processing_mode_parses_per_lane() { + let toml = r#" +[discord] +bot_token = "t" +message_processing_mode = "per-lane" + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert_eq!( + cfg.discord.unwrap().message_processing_mode, + MessageProcessingMode::Lane + ); + } + + // The legacy alias "batched" was 
removed: only per-message / per-thread / per-lane + // are accepted. Configs still using "batched" must migrate to an explicit value. + #[test] + fn message_processing_mode_batched_is_rejected() { + let toml = r#" +[discord] +bot_token = "t" +message_processing_mode = "batched" + +[agent] +command = "echo" +"#; + assert!(parse_config(toml, "test").is_err()); + } + + #[test] + fn message_processing_mode_default_is_per_message() { + let cfg = parse_config(MINIMAL_TOML, "test").unwrap(); + assert_eq!( + cfg.discord.unwrap().message_processing_mode, + MessageProcessingMode::Message + ); + } + + #[test] + fn message_processing_mode_unknown_value_errors() { + let toml = r#" +[discord] +bot_token = "t" +message_processing_mode = "bogus" + +[agent] +command = "echo" +"#; + assert!(parse_config(toml, "test").is_err()); + } + + #[test] + fn parse_gateway_config_explicit_allow_all_overrides_list() { + let toml = r#" +[gateway] +url = "ws://gw:8080/ws" +allow_all_users = true +allowed_users = ["U1"] + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + let gw = cfg.gateway.unwrap(); + // explicit flag overrides non-empty list + assert!(resolve_allow_all(gw.allow_all_users, &gw.allowed_users)); + } + + #[test] + fn stt_echo_transcript_defaults_to_false() { + let cfg = SttConfig::default(); + assert!( + !cfg.echo_transcript, + "echo_transcript should default to false" + ); + } + + #[test] + fn stt_echo_transcript_respects_explicit_false() { + let toml = r#" +[agent] +command = "echo" + +[stt] +enabled = true +api_key = "test" +echo_transcript = false +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert!(cfg.stt.enabled); + assert!(!cfg.stt.echo_transcript); + } + + #[test] + fn parse_secrets_config() { + let toml = r#" +[discord] +bot_token = "${secrets.discord_token}" + +[agent] +command = "echo" + +[secrets.refs] +discord_token = "aws-sm://openab/prod#discord_bot_token" +github_pat = "exec:///home/agent/.local/bin/get-secret.sh 
vault/openab github_pat" + +[secrets.aws] +region = "ap-northeast-1" +endpoint_url = "http://localhost:4566" + +[secrets.exec] +timeout_seconds = 15 +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert_eq!(cfg.secrets.refs.len(), 2); + assert_eq!( + cfg.secrets.refs.get("discord_token").unwrap(), + "aws-sm://openab/prod#discord_bot_token" + ); + assert_eq!( + cfg.secrets.refs.get("github_pat").unwrap(), + "exec:///home/agent/.local/bin/get-secret.sh vault/openab github_pat" + ); + assert_eq!(cfg.secrets.aws.region.as_deref(), Some("ap-northeast-1")); + assert_eq!( + cfg.secrets.aws.endpoint_url.as_deref(), + Some("http://localhost:4566") + ); + assert_eq!(cfg.secrets.exec.timeout_seconds, 15); + } + + #[test] + fn parse_secrets_config_defaults() { + let toml = r#" +[discord] +bot_token = "test" + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert!(cfg.secrets.refs.is_empty()); + assert!(cfg.secrets.aws.region.is_none()); + assert!(cfg.secrets.aws.endpoint_url.is_none()); + assert_eq!(cfg.secrets.exec.timeout_seconds, 10); + } + + #[test] + fn slack_assistant_mode_defaults_true_and_parses_false() { + let cfg: SlackConfig = toml::from_str("bot_token = \"x\"\napp_token = \"y\"\n").unwrap(); + assert!(cfg.assistant_mode, "assistant_mode must default to true"); + + let cfg2: SlackConfig = + toml::from_str("bot_token = \"x\"\napp_token = \"y\"\nassistant_mode = false\n") + .unwrap(); + assert!(!cfg2.assistant_mode); + } + + #[test] + fn agentcore_config_synthesizes_agent_command() { + let toml = r#" +[discord] +bot_token = "t" + +[agentcore] +runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/my-agent" +"#; + let cfg = parse_config(toml, "test").unwrap(); + #[cfg(feature = "agentcore")] + { + // With agentcore feature, spawns self with agentcore-bridge subcommand + assert!(cfg.agent.args.contains(&"agentcore-bridge".to_string())); + } + #[cfg(not(feature = "agentcore"))] + { + 
assert_eq!(cfg.agent.command, "uv"); + } + assert!(cfg.agent.args.contains(&"--runtime-arn".to_string())); + assert!(cfg + .agent + .args + .contains(&"arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/my-agent".to_string())); + } + + #[test] + fn agentcore_config_does_not_override_explicit_agent() { + let toml = r#" +[discord] +bot_token = "t" + +[agent] +command = "my-custom-agent" + +[agentcore] +runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/my-agent" +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert_eq!(cfg.agent.command, "my-custom-agent"); + } + + #[test] + fn agentcore_config_defaults() { + let toml = r#" +[discord] +bot_token = "t" + +[agentcore] +runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/test" +"#; + let cfg = parse_config(toml, "test").unwrap(); + let ac = cfg.agentcore.unwrap(); + assert_eq!(ac.region(), "us-east-1"); + assert_eq!(ac.cancel_strategy, AgentCoreCancelStrategy::Stop); + } + + #[test] + fn agentcore_rejects_invalid_arn() { + let toml = r#" +[discord] +bot_token = "t" + +[agentcore] +runtime_arn = "not-a-valid-arn" +"#; + let err = parse_config(toml, "test").unwrap_err(); + assert!(err.to_string().contains("not a valid AgentCore Runtime ARN")); + } + + #[test] + fn agentcore_rejects_arn_wrong_service() { + let toml = r#" +[discord] +bot_token = "t" + +[agentcore] +runtime_arn = "arn:aws:s3:us-east-1:123456789012:bucket/my-bucket" +"#; + let err = parse_config(toml, "test").unwrap_err(); + assert!(err.to_string().contains("not a valid AgentCore Runtime ARN")); + } + + #[test] + fn agentcore_rejects_arn_missing_runtime_prefix() { + let toml = r#" +[discord] +bot_token = "t" + +[agentcore] +runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:agent/my-agent" +"#; + let err = parse_config(toml, "test").unwrap_err(); + assert!(err.to_string().contains("not a valid AgentCore Runtime ARN")); + } + + #[test] + fn agentcore_rejects_invalid_cancel_strategy() { + let 
toml = r#" +[discord] +bot_token = "t" + +[agentcore] +runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/test" +cancel_strategy = "stopp" +"#; + let err = parse_config(toml, "test").unwrap_err(); + assert!(err.to_string().contains("unknown variant")); + } + + #[test] + fn agentcore_extracts_region_from_arn() { + let toml = r#" +[discord] +bot_token = "t" + +[agentcore] +runtime_arn = "arn:aws:bedrock-agentcore:ap-northeast-1:123456789012:runtime/tokyo-agent" +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert!(cfg.agent.args.contains(&"ap-northeast-1".to_string())); + } + + #[test] + fn agentcore_cancel_strategy_noop() { + let toml = r#" +[discord] +bot_token = "t" + +[agentcore] +runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/test" +cancel_strategy = "noop" +"#; + let cfg = parse_config(toml, "test").unwrap(); + let ac = cfg.agentcore.unwrap(); + assert_eq!(ac.cancel_strategy, AgentCoreCancelStrategy::Noop); + } +} diff --git a/crates/openab-core/src/cron.rs b/crates/openab-core/src/cron.rs new file mode 100644 index 000000000..db5828b22 --- /dev/null +++ b/crates/openab-core/src/cron.rs @@ -0,0 +1,1768 @@ +use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, SenderContext}; +use crate::config::CronJobConfig; +use crate::format; +use chrono::{Timelike, Utc}; +use chrono_tz::Tz; +use cron::Schedule; +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; +use std::str::FromStr; +use std::sync::Arc; +use std::time::SystemTime; +use tokio::process::Command; +use tokio::sync::Mutex; +use toml_edit::{value, DocumentMut}; +use tracing::{debug, error, info, warn}; + +/// Parse a 5-field POSIX cron expression into a `Schedule`. +/// +/// The `cron` crate expects a 6-field expression (with seconds), so we prepend "0". 
+/// +/// POSIX numeric day-of-week values (0..=7, where 0 or 7 = Sunday) are translated +/// to the `cron` crate's 1-based form (1..=7, where 1 = Sunday) before being handed +/// to the underlying parser. Without this, numeric day-of-week values are off by one +/// — e.g. `1-5` (Mon-Fri in POSIX) would be evaluated as Sun-Thu. See the +/// [`translate_posix_dow_field`] doc comment for details. +/// +/// Name-based day-of-week tokens (`Mon`, `Sun`, `Mon-Fri`, ...) are passed through +/// unchanged — the `cron` crate's internal name-to-ordinal map is consistent. +pub fn parse_cron_expr(expr: &str) -> Result { + let translated = translate_posix_cron_expr(expr)?; + let six_field = format!("0 {}", translated); + Schedule::from_str(&six_field).map_err(|e| e.to_string()) +} + +/// Translate a 5-field POSIX cron expression so the day-of-week field uses the +/// numeric convention of the `cron` crate. +/// +/// Only the 5th field (day-of-week) is rewritten; the other four fields pass +/// through unchanged. +fn translate_posix_cron_expr(expr: &str) -> Result { + let fields: Vec<&str> = expr.split_whitespace().collect(); + if fields.len() != 5 { + return Err(format!( + "expected 5 whitespace-separated cron fields, got {}: {:?}", + fields.len(), + expr + )); + } + let translated_dow = translate_posix_dow_field(fields[4])?; + Ok(format!( + "{} {} {} {} {}", + fields[0], fields[1], fields[2], fields[3], translated_dow + )) +} + +/// Translate a POSIX day-of-week field to the `cron` crate's numeric form. +/// +/// # Background +/// +/// POSIX cron (and Linux crontab, Kubernetes CronJob, GitHub Actions) uses +/// `0..=7` where `0` or `7` = Sunday, `1` = Monday, ..., `6` = Saturday. +/// +/// The `cron` crate uses `1..=7` where `1` = Sunday, `2` = Monday, ..., `7` = Saturday +/// (it matches via chrono's `Weekday::number_from_sunday()`). 
Without translation, +/// every numeric day-of-week value fires one day early: +/// +/// | POSIX intent | Without translation (cron crate reads as) | +/// |---------------|-------------------------------------------| +/// | `0`, `7` (Sun) | out-of-range / Sat | +/// | `1` (Mon) | Sun | +/// | `5` (Fri) | Thu | +/// | `1-5` (Mon-Fri) | Sun-Thu | +/// +/// # Algorithm +/// +/// 1. If the field contains any ASCII letter (e.g. `Mon-Fri`), pass it through — +/// the cron crate's name-to-ordinal map is internally consistent. +/// 2. Otherwise, expand each comma-separated component into the set of POSIX +/// day values it represents. Ranges (`a-b`) and step values (`a/s`, `a-b/s`, +/// `*/s`) are expanded here. `7` is normalized to `0` (both = Sunday) to +/// avoid duplication. +/// 3. If the resulting set covers all 7 days, emit `*` for brevity. +/// 4. Otherwise, shift each value by `+1` (POSIX `{0..=6}` → cron crate +/// `{1..=7}`) and emit as a comma-separated list, compacting contiguous +/// runs into ranges for readability. +/// +/// # Mixed numeric and name notation +/// +/// Mixing numeric and name tokens in the same field (e.g. `1,Mon`) is not +/// supported and will return an error. Use either all numeric (POSIX) or all +/// name-based notation. +fn translate_posix_dow_field(field: &str) -> Result { + use std::collections::BTreeSet; + + // Name-based notation is internally consistent in the cron crate — pass through. + // But reject mixed numeric+name notation (e.g. "1,Mon") which would leave the + // numeric part untranslated and silently wrong. 
+ let has_alpha = field.chars().any(|c| c.is_ascii_alphabetic()); + let has_digit = field.chars().any(|c| c.is_ascii_digit()); + if has_alpha && has_digit { + return Err(format!( + "mixed numeric and name notation is not supported in day-of-week field: {:?}", + field + )); + } + if has_alpha { + return Ok(field.to_string()); + } + + if field.is_empty() { + return Err("empty day-of-week field".to_string()); + } + + let mut days: BTreeSet = BTreeSet::new(); + + for part in field.split(',') { + if part.is_empty() { + return Err(format!("empty component in day-of-week field: {:?}", field)); + } + + // Split off optional step: `a/s`, `a-b/s`, `*/s`. + let (range_part, step) = match part.split_once('/') { + Some((r, s)) => { + let step_n: u32 = s + .parse() + .map_err(|_| format!("invalid step value in {:?}", part))?; + if step_n == 0 { + return Err(format!("step value cannot be zero in {:?}", part)); + } + (r, step_n) + } + None => (part, 1u32), + }; + + // Expand range_part to the list of POSIX day values it represents. + // Values may include 7 (Sunday alias for 0); normalization happens below. + let raw_values: Vec = if range_part == "*" { + (0..=6).collect() + } else if let Some((a, b)) = range_part.split_once('-') { + let a_n: u32 = a + .parse() + .map_err(|_| format!("invalid range start in {:?}", part))?; + let b_n: u32 = b + .parse() + .map_err(|_| format!("invalid range end in {:?}", part))?; + if a_n > 7 || b_n > 7 { + return Err(format!( + "day-of-week value out of range (0-7) in {:?}", + part + )); + } + if a_n > b_n { + return Err(format!("invalid range {:?}: start > end", part)); + } + (a_n..=b_n).collect() + } else { + let n: u32 = range_part + .parse() + .map_err(|_| format!("invalid number in {:?}", part))?; + if n > 7 { + return Err(format!("day-of-week value out of range (0-7): {}", n)); + } + if step > 1 { + // n/step means "from n through end-of-domain, stepping by step" + // Normalize 7 (Sunday alias) to 0 before expansion. 
+ let start = if n == 7 { 0 } else { n }; + (start..=6).collect() + } else { + vec![n] + } + }; + + // Apply step filter, normalize 7 → 0, collect into the set. + for (i, &v) in raw_values.iter().enumerate() { + if (i as u32).is_multiple_of(step) { + let normalized = if v == 7 { 0 } else { v }; + days.insert(normalized); + } + } + } + + if days.is_empty() { + return Err(format!("empty day-of-week field: {:?}", field)); + } + + // All 7 days → emit `*` for brevity. + if days.len() == 7 { + return Ok("*".to_string()); + } + + // Shift POSIX {0..=6} → cron crate {1..=7} and emit, compacting contiguous runs. + let shifted: Vec = days.iter().map(|d| d + 1).collect(); + Ok(compact_ordinal_set(&shifted)) +} + +/// Compact a sorted list of ordinals into cron-style comma-list with ranges, +/// e.g. `[2,3,4,5,6]` → `"2-6"`, `[1,3,5]` → `"1,3,5"`, `[1,2,4,5]` → `"1-2,4-5"`. +fn compact_ordinal_set(sorted: &[u32]) -> String { + if sorted.is_empty() { + return String::new(); + } + let mut out: Vec = Vec::new(); + let mut start = sorted[0]; + let mut end = sorted[0]; + for &v in &sorted[1..] { + if v == end + 1 { + end = v; + } else { + out.push(render_run(start, end)); + start = v; + end = v; + } + } + out.push(render_run(start, end)); + out.join(",") +} + +fn render_run(start: u32, end: u32) -> String { + if start == end { + format!("{}", start) + } else { + format!("{}-{}", start, end) + } +} + +/// Check whether a cron schedule should fire right now. +/// Truncates the current time to the minute boundary and checks if the +/// schedule has an event at exactly that minute. +pub fn should_fire(schedule: &Schedule, tz: Tz) -> bool { + let now = Utc::now().with_timezone(&tz); + let minute_start = now.with_second(0).unwrap().with_nanosecond(0).unwrap(); + let query_from = minute_start - chrono::Duration::seconds(1); + schedule + .after(&query_from) + .next() + .map(|next| next == minute_start) + .unwrap_or(false) +} + +/// Known platforms that have adapter support. 
+const VALID_PLATFORMS: &[&str] = &["discord", "slack"]; + +/// Validate all cronjob configs (fail-fast on bad cron expressions or timezones). +pub fn validate_cronjobs( + cronjobs: &[CronJobConfig], + configured_platforms: &[&str], +) -> anyhow::Result<()> { + for (i, job) in cronjobs.iter().enumerate() { + if !job.enabled { + continue; + } + parse_cron_expr(&job.schedule).map_err(|e| { + anyhow::anyhow!( + "cronjobs[{i}]: invalid cron expression {:?}: {e}", + job.schedule + ) + })?; + job.timezone.parse::().map_err(|e| { + anyhow::anyhow!("cronjobs[{i}]: invalid timezone {:?}: {e}", job.timezone) + })?; + if !VALID_PLATFORMS.contains(&job.platform.as_str()) { + anyhow::bail!( + "cronjobs[{i}]: unknown platform {:?} (expected one of: {VALID_PLATFORMS:?})", + job.platform + ); + } + if !configured_platforms.contains(&job.platform.as_str()) { + anyhow::bail!( + "cronjobs[{i}]: platform {:?} is not configured — add [{}] to config.toml", + job.platform, + job.platform + ); + } + if job.disable_on_success.is_some() { + anyhow::bail!( + "cronjobs[{i}]: disable_on_success is only supported in usercron [[jobs]], not baseline [[cron.jobs]]" + ); + } + } + Ok(()) +} + +// --------------------------------------------------------------------------- +// Usercron hot-reload +// --------------------------------------------------------------------------- + +/// Wrapper for deserializing cronjob.toml which contains `[[jobs]]`. +#[derive(serde::Deserialize)] +struct UsercronFile { + #[serde(default)] + jobs: Vec, +} + +/// Load and validate cronjobs from an external TOML file. +/// Returns an empty vec if the file doesn't exist. +/// Logs and skips individual invalid entries rather than failing entirely. 
+pub fn load_usercron_file(path: &Path, configured_platforms: &[&str]) -> Vec { + let content = match std::fs::read_to_string(path) { + Ok(c) => c, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => return vec![], + Err(e) => { + warn!(path = %path.display(), error = %e, "failed to read usercron file"); + return vec![]; + } + }; + let parsed: UsercronFile = match toml::from_str(&content) { + Ok(f) => f, + Err(e) => { + warn!(path = %path.display(), error = %e, "failed to parse usercron file, skipping all entries"); + return vec![]; + } + }; + // Validate each entry individually — keep valid ones, skip bad ones + parsed.jobs.into_iter().enumerate().filter(|(i, job)| { + if let Err(e) = parse_cron_expr(&job.schedule) { + warn!(index = i, schedule = %job.schedule, error = %e, "usercron: invalid cron expression, skipping"); + return false; + } + if job.timezone.parse::().is_err() { + warn!(index = i, timezone = %job.timezone, "usercron: invalid timezone, skipping"); + return false; + } + if !VALID_PLATFORMS.contains(&job.platform.as_str()) { + warn!(index = i, platform = %job.platform, "usercron: unknown platform, skipping"); + return false; + } + if !configured_platforms.contains(&job.platform.as_str()) { + warn!(index = i, platform = %job.platform, "usercron: platform not configured, skipping"); + return false; + } + if job.disable_on_success.as_deref().is_some_and(|s| !s.trim().is_empty()) { + if job.id.as_deref().is_none_or(|s| s.trim().is_empty()) { + warn!(index = i, "usercron: disable_on_success requires id, skipping"); + return false; + } + if job + .disable_on_success_match + .as_deref() + .is_none_or(|s| s.trim().is_empty()) + { + warn!(index = i, "usercron: disable_on_success requires disable_on_success_match, skipping"); + return false; + } + } + true + }).map(|(_, job)| job).collect() +} + +/// Get file mtime, returns None if file doesn't exist or metadata fails. 
+fn file_mtime(path: &Path) -> Option { + std::fs::metadata(path).ok().and_then(|m| m.modified().ok()) +} + +/// A parsed, ready-to-evaluate cron job. +struct ParsedJob { + schedule: Schedule, + tz: Tz, + config: CronJobConfig, + usercron_path: Option, +} + +/// Parse a list of CronJobConfig into ParsedJob, filtering out disabled/invalid entries. +fn parse_job_list( + configs: &[CronJobConfig], + source: &str, + usercron_path: Option<&Path>, +) -> Vec { + configs.iter().filter(|job| { + if !job.enabled { + info!(schedule = %job.schedule, channel = %job.channel, source, "cronjob disabled, skipping"); + } + job.enabled + }).filter_map(|job| { + let schedule = match parse_cron_expr(&job.schedule) { + Ok(s) => s, + Err(e) => { + error!(schedule = %job.schedule, error = %e, source, "invalid cron expression, skipping"); + return None; + } + }; + let tz: Tz = match job.timezone.parse() { + Ok(t) => t, + Err(e) => { + error!(timezone = %job.timezone, error = %e, source, "invalid timezone, skipping"); + return None; + } + }; + info!( + schedule = %job.schedule, timezone = %job.timezone, + channel = %job.channel, platform = %job.platform, + message = %job.message, source, + "cronjob registered" + ); + Some(ParsedJob { + schedule, + tz, + config: job.clone(), + usercron_path: usercron_path.map(Path::to_path_buf), + }) + }).collect() +} + +/// Run the internal cron scheduler. Evaluates cron expressions once per minute. +/// `usercron_path` enables hot-reload of an external cronjob.toml file. 
+pub async fn run_scheduler( + cronjobs: Vec, + usercron_path: Option, + configured_platforms: Vec, + router: Arc, + adapters: HashMap>, + mut shutdown_rx: tokio::sync::watch::Receiver, +) { + let platform_refs: Vec<&str> = configured_platforms.iter().map(|s| s.as_str()).collect(); + + // Parse baseline jobs from config.toml + let baseline_jobs = parse_job_list(&cronjobs, "config.toml", None); + + // Load initial usercron jobs + let mut usercron_jobs = if let Some(ref path) = usercron_path { + let configs = load_usercron_file(path, &platform_refs); + if !configs.is_empty() { + info!(count = configs.len(), path = %path.display(), "loaded usercron jobs"); + } + parse_job_list(&configs, "cronjob.toml", Some(path.as_path())) + } else { + vec![] + }; + let mut last_usercron_mtime: Option = usercron_path.as_deref().and_then(file_mtime); + + if baseline_jobs.is_empty() && usercron_jobs.is_empty() { + if usercron_path.is_some() { + info!( + "no cronjobs yet, but usercron_path is set — scheduler will watch for cronjob.toml" + ); + } else { + debug!("no cronjobs configured, scheduler not started"); + return; + } + } + + let total = baseline_jobs.len() + usercron_jobs.len(); + info!( + baseline = baseline_jobs.len(), + usercron = usercron_jobs.len(), + total, + "cron scheduler started" + ); + + let in_flight: Arc>> = Arc::new(Mutex::new(HashSet::new())); + // Serialize usercron read-modify-write updates so concurrent jobs do not + // overwrite each other's enabled/thread_id changes. 
+ let usercron_write_lock: Arc> = Arc::new(Mutex::new(())); + + // Align to next minute boundary + let now = Utc::now(); + let secs_into_minute = now.timestamp() % 60; + let align_delay = if secs_into_minute == 0 { + 0 + } else { + 60 - secs_into_minute as u64 + }; + if align_delay > 0 { + debug!(align_secs = align_delay, "aligning to next minute boundary"); + tokio::time::sleep(std::time::Duration::from_secs(align_delay)).await; + } + let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60)); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + let mut tasks: tokio::task::JoinSet<()> = tokio::task::JoinSet::new(); + + loop { + tokio::select! { + _ = ticker.tick() => { + // Hot-reload usercron file if mtime changed + if let Some(ref path) = usercron_path { + let current_mtime = file_mtime(path); + if current_mtime != last_usercron_mtime { + let configs = load_usercron_file(path, &platform_refs); + info!(count = configs.len(), path = %path.display(), "usercron file changed, reloading"); + // Keep in-flight indices across reload. A scheduler writeback + // (thread_id or enabled=false) changes mtime deterministically; + // clearing usercron indices here would allow the same job to + // overlap on the next tick while its previous run is still active. 
+ usercron_jobs = + parse_job_list(&configs, "cronjob.toml", Some(path.as_path())); + last_usercron_mtime = current_mtime; + } + } + + // Evaluate all jobs: baseline first, then usercron + let all_jobs = baseline_jobs.iter().chain(usercron_jobs.iter()); + for (idx, job) in all_jobs.enumerate() { + if !should_fire(&job.schedule, job.tz) { + continue; + } + { + let running = in_flight.lock().await; + if running.contains(&idx) { + warn!(schedule = %job.config.schedule, channel = %job.config.channel, "skipping cronjob, previous execution still running"); + continue; + } + } + info!( + schedule = %job.config.schedule, + channel = %job.config.channel, + platform = %job.config.platform, + message = %job.config.message, + sender = %job.config.sender_name, + "🔔 cronjob fired" + ); + in_flight.lock().await.insert(idx); + + let config = job.config.clone(); + let usercron_path = job.usercron_path.clone(); + let router = router.clone(); + let adapters = adapters.clone(); + let in_flight = in_flight.clone(); + let usercron_write_lock = usercron_write_lock.clone(); + tasks.spawn(async move { + fire_cronjob( + idx, + &config, + usercron_path, + &router, + &adapters, + in_flight, + usercron_write_lock, + ) + .await; + }); + } + while tasks.try_join_next().is_some() {} + } + _ = shutdown_rx.changed() => { + if *shutdown_rx.borrow() { + info!("cron scheduler shutting down, waiting for in-flight tasks"); + let drain = async { while tasks.join_next().await.is_some() {} }; + let _ = tokio::time::timeout(std::time::Duration::from_secs(30), drain).await; + return; + } + } + } + } +} + +/// RAII guard that removes a job index from the in-flight set on drop. 
struct InFlightGuard {
    /// Index of the job this guard stands for.
    idx: usize,
    /// Shared set of indices whose jobs are currently running.
    set: Arc<Mutex<HashSet<usize>>>,
}

impl Drop for InFlightGuard {
    fn drop(&mut self) {
        // Fast path: when nobody holds the lock, release the slot synchronously.
        if let Ok(mut running) = self.set.try_lock() {
            running.remove(&self.idx);
            return;
        }

        // Contended: finish the removal on the runtime. `tokio::spawn` panics
        // when called outside a runtime context, so look the handle up first.
        let idx = self.idx;
        let set = self.set.clone();
        match tokio::runtime::Handle::try_current() {
            Ok(handle) => {
                handle.spawn(async move {
                    set.lock().await.remove(&idx);
                });
            }
            Err(e) => {
                warn!(idx, error = %e, "no tokio runtime available to release in-flight cronjob slot");
            }
        }
    }
}

/// Execute one cron job: optionally run its `disable_on_success` check, post
/// the trigger message, make sure a thread exists, and hand the prompt to the
/// router.
///
/// `idx` is released from `in_flight` when this function returns, on every
/// path. Writes to the usercron file are serialized through
/// `usercron_write_lock`. Failures are logged (and, where a channel is
/// available, posted to it) rather than returned.
async fn fire_cronjob(
    idx: usize,
    job: &CronJobConfig,
    usercron_path: Option<PathBuf>,
    router: &Arc<AdapterRouter>,
    adapters: &HashMap<String, Arc<dyn ChatAdapter>>,
    in_flight: Arc<Mutex<HashSet<usize>>>,
    usercron_write_lock: Arc<Mutex<()>>,
) {
    let _guard = InFlightGuard {
        idx,
        set: in_flight,
    };

    let adapter = match adapters.get(&job.platform) {
        Some(a) => a.clone(),
        None => {
            error!(platform = %job.platform, "no adapter for platform, skipping cronjob");
            return;
        }
    };

    if let Some(command) = non_empty_opt(job.disable_on_success.as_deref()) {
        let marker = match non_empty_opt(job.disable_on_success_match.as_deref()) {
            Some(marker) => marker,
            None => {
                warn!(
                    id = job.id.as_deref().unwrap_or(""),
                    "disable_on_success configured without disable_on_success_match, treating as not achieved"
                );
                ""
            }
        };
        if !marker.is_empty() {
            match check_disable_on_success(job, command, marker).await {
                DisableOnSuccessResult::Achieved => {
                    let channel = ChannelRef {
                        platform: job.platform.clone(),
                        channel_id: job.channel.clone(),
                        thread_id: job.thread_id.clone(),
                        parent_id: None,
                        origin_event_id: None,
                    };
                    if let Err(e) = adapter
                        .send_message(
                            &channel,
                            &format!(
                                "✅ Goal achieved: `{}` matched `{}`. Disabling cronjob.",
                                command, marker
                            ),
                        )
                        .await
                    {
                        error!(channel = %job.channel, error = %e, "failed to send goal achieved message");
                    }

                    if let (Some(path), Some(id)) =
                        (usercron_path.as_deref(), non_empty_opt(job.id.as_deref()))
                    {
                        let _write_guard = usercron_write_lock.lock().await;
                        if let Err(e) = update_usercron_job(path, id, Some(false), None) {
                            error!(path = %path.display(), id, error = %e, "failed to disable completed usercron job");
                        }
                    } else {
                        warn!("completed disable_on_success job has no usercron path or id, cannot write enabled=false");
                    }
                    // Goal reached: the job's own message is not sent this time.
                    return;
                }
                DisableOnSuccessResult::NotAchieved(reason) => {
                    info!(
                        id = job.id.as_deref().unwrap_or(""),
                        reason,
                        "disable_on_success not achieved, firing cronjob normally"
                    );
                }
            }
        }
    }

    let thread_channel = ChannelRef {
        platform: job.platform.clone(),
        channel_id: job.channel.clone(),
        thread_id: job.thread_id.clone(),
        parent_id: None,
        origin_event_id: None,
    };

    let trigger_msg = match adapter
        .send_message(
            &thread_channel,
            &format!("🕐 [{}]: {}", job.sender_name, job.message),
        )
        .await
    {
        Ok(msg) => msg,
        Err(e) => {
            error!(channel = %job.channel, error = %e, "failed to send cron message");
            return;
        }
    };

    // Reuse the configured thread when there is one; otherwise open a thread
    // on the trigger message and, for usercron jobs with an id, persist it.
    let reply_channel = if job.thread_id.is_some() {
        thread_channel.clone()
    } else {
        let thread_name = format::shorten_thread_name(&job.message);
        match adapter
            .create_thread(&thread_channel, &trigger_msg, &thread_name)
            .await
        {
            Ok(ch) => {
                if let (Some(path), Some(id), Some(thread_id)) = (
                    usercron_path.as_deref(),
                    non_empty_opt(job.id.as_deref()),
                    ch.thread_id.as_deref().or(Some(ch.channel_id.as_str())),
                ) {
                    let _write_guard = usercron_write_lock.lock().await;
                    if let Err(e) = update_usercron_job(path, id, None, Some(thread_id)) {
                        warn!(path = %path.display(), id, error = %e, "failed to persist usercron thread_id");
                    }
                }
                ch
            }
            Err(e) => {
                error!(channel = %job.channel, error = %e, "failed to create cron thread");
                let _ = adapter
                    .send_message(
                        &thread_channel,
                        &format!("⚠️ cronjob: failed to create thread: {e}"),
                    )
                    .await;
                return;
            }
        }
    };

    let sender = SenderContext {
        schema: "openab.sender.v1".into(),
        sender_id: "openab-cron".into(),
        sender_name: job.sender_name.clone(),
        display_name: job.sender_name.clone(),
        channel: job.platform.clone(),
        channel_id: reply_channel
            .parent_id
            .as_deref()
            .unwrap_or(&reply_channel.channel_id)
            .to_string(),
        thread_id: reply_channel
            .thread_id
            .clone()
            .or(Some(reply_channel.channel_id.clone())),
        is_bot: true,
        timestamp: Some(Utc::now().to_rfc3339()),
        message_id: None,  // cron jobs don't originate from a message
        receiver_id: None, // cron jobs are self-triggered, no external receiver
    };
    let sender_json = match serde_json::to_string(&sender) {
        Ok(j) => j,
        Err(e) => {
            warn!(error = %e, "failed to serialize cron sender context, skipping");
            return;
        }
    };

    if let Err(e) = router
        .handle_message(
            &adapter,
            crate::adapter::MessageContext {
                thread_channel: reply_channel.clone(),
                sender_json,
                prompt: job.message.clone(),
                extra_blocks: vec![],
                trigger_msg,
                other_bot_present: false,
            },
        )
        .await
    {
        error!("cron handle_message error: {e}");
        let _ = adapter
            .send_message(&reply_channel, &format!("⚠️ cronjob error: {e}"))
            .await;
    }
}

/// Outcome of a `disable_on_success` check.
enum DisableOnSuccessResult {
    /// The command exited zero and its output contained the marker.
    Achieved,
    /// The goal was not reached; the payload is a short reason for the log.
    NotAchieved(&'static str),
}

/// Trim `value` and return it only if something is left; `None` and
/// whitespace-only strings both yield `None`.
fn non_empty_opt(value: Option<&str>) -> Option<&str> {
    value.and_then(|s| {
        let trimmed = s.trim();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed)
        }
    })
}

/// Run the job's `disable_on_success` shell command and report whether its
/// goal was reached.
///
/// The goal counts as achieved only when the command exits zero and `marker`
/// occurs in its stdout or stderr. One deadline of
/// `disable_on_success_timeout_secs` (at least 1 second) covers both waiting
/// for the exit and draining the pipes, so a leftover background process that
/// keeps a pipe open cannot block the job indefinitely. On timeout the child
/// is killed.
async fn check_disable_on_success(
    job: &CronJobConfig,
    command: &str,
    marker: &str,
) -> DisableOnSuccessResult {
    let timeout_secs = job.disable_on_success_timeout_secs.max(1);
    let mut cmd = shell_command(command);
    if let Some(dir) = non_empty_opt(job.disable_on_success_working_dir.as_deref()) {
        cmd.current_dir(dir);
    }
    cmd.stdout(std::process::Stdio::piped());
    cmd.stderr(std::process::Stdio::piped());

    let mut child = match cmd.spawn() {
        Ok(child) => child,
        Err(e) => {
            warn!(
                id = job.id.as_deref().unwrap_or(""),
                command,
                error = %e,
                "disable_on_success command failed to start"
            );
            return DisableOnSuccessResult::NotAchieved("command failed to start");
        }
    };

    // Take stdout/stderr handles and drain them concurrently to prevent pipe buffer deadlock.
    let stdout_handle = child.stdout.take();
    let stderr_handle = child.stderr.take();

    let mut stdout_task = tokio::spawn(async move {
        let mut buf = Vec::new();
        if let Some(mut out) = stdout_handle {
            let _ = tokio::io::AsyncReadExt::read_to_end(&mut out, &mut buf).await;
        }
        buf
    });
    let mut stderr_task = tokio::spawn(async move {
        let mut buf = Vec::new();
        if let Some(mut err) = stderr_handle {
            let _ = tokio::io::AsyncReadExt::read_to_end(&mut err, &mut buf).await;
        }
        buf
    });

    // Wait for the exit and, on success, for both drains. `Ok(None)` means the
    // command exited non-zero. The handles are awaited by reference so they
    // can still be aborted afterwards.
    let run = async {
        let status = child.wait().await?;
        if !status.success() {
            return Ok(None);
        }
        let stdout_buf = (&mut stdout_task).await.unwrap_or_default();
        let stderr_buf = (&mut stderr_task).await.unwrap_or_default();
        Ok::<_, std::io::Error>(Some((stdout_buf, stderr_buf)))
    };
    let outcome = tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), run).await;

    match outcome {
        Ok(Ok(Some((stdout_buf, stderr_buf)))) => {
            let stdout = String::from_utf8_lossy(&stdout_buf);
            let stderr = String::from_utf8_lossy(&stderr_buf);
            if stdout.contains(marker) || stderr.contains(marker) {
                DisableOnSuccessResult::Achieved
            } else {
                DisableOnSuccessResult::NotAchieved("success marker not found")
            }
        }
        Ok(Ok(None)) => {
            stdout_task.abort();
            stderr_task.abort();
            DisableOnSuccessResult::NotAchieved("command exited non-zero")
        }
        Ok(Err(e)) => {
            warn!(
                id = job.id.as_deref().unwrap_or(""),
                command,
                error = %e,
                "disable_on_success command wait failed"
            );
            stdout_task.abort();
            stderr_task.abort();
            DisableOnSuccessResult::NotAchieved("command wait failed")
        }
        Err(_elapsed) => {
            // Timeout — kill the child to avoid orphan processes.
            let _ = child.kill().await;
            stdout_task.abort();
            stderr_task.abort();
            warn!(
                id = job.id.as_deref().unwrap_or(""),
                command,
                timeout_secs,
                "disable_on_success command timed out"
            );
            DisableOnSuccessResult::NotAchieved("command timed out")
        }
    }
}

/// Build a `Command` that runs `command` through the platform shell
/// (`cmd /C` on Windows, `sh -c` elsewhere).
fn shell_command(command: &str) -> Command {
    #[cfg(windows)]
    {
        let mut child = Command::new("cmd");
        child.arg("/C").arg(command);
        child
    }
    #[cfg(not(windows))]
    {
        let mut child = Command::new("sh");
        child.arg("-c").arg(command);
        child
    }
}

/// Update the `[[jobs]]` entry whose `id` equals `id` in the usercron file.
///
/// `enabled` and `thread_id` are written only when `Some`. Only the first
/// matching entry is changed. Errors if the file cannot be read or parsed,
/// has no `[[jobs]]` array, or contains no entry with that id.
fn update_usercron_job(
    path: &Path,
    id: &str,
    enabled: Option<bool>,
    thread_id: Option<&str>,
) -> anyhow::Result<()> {
    let content = std::fs::read_to_string(path)?;
    let mut doc = content.parse::<DocumentMut>()?;
    let jobs = doc
        .get_mut("jobs")
        .and_then(|item| item.as_array_of_tables_mut())
        .ok_or_else(|| anyhow::anyhow!("usercron file has no [[jobs]] array"))?;

    let mut found = false;
    for table in jobs.iter_mut() {
        if table.get("id").and_then(|item| item.as_str()) != Some(id) {
            continue;
        }
        if let Some(enabled) = enabled {
            table["enabled"] = value(enabled);
        }
        if let Some(thread_id) = thread_id {
            table["thread_id"] = value(thread_id);
        }
        found = true;
        break;
    }

    if !found {
        anyhow::bail!("usercron job id {:?} not found", id);
    }

    // Atomic write: write to temp file then rename to avoid corruption on crash.
    let tmp = path.with_extension("toml.tmp");
    std::fs::write(&tmp, doc.to_string())?;
    std::fs::rename(&tmp, path)?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Datelike, Timelike};

    // --- POSIX day-of-week translator ---

    #[test]
    fn translate_dow_star_passes_through() {
        assert_eq!(translate_posix_dow_field("*").unwrap(), "*");
    }

    #[test]
    fn translate_dow_single_sunday_zero() {
        assert_eq!(translate_posix_dow_field("0").unwrap(), "1");
    }

    #[test]
    fn translate_dow_single_sunday_seven() {
        assert_eq!(translate_posix_dow_field("7").unwrap(), "1");
    }

    #[test]
    fn translate_dow_single_monday() {
        assert_eq!(translate_posix_dow_field("1").unwrap(), "2");
    }

    #[test]
    fn translate_dow_single_saturday() {
        assert_eq!(translate_posix_dow_field("6").unwrap(), "7");
    }

    #[test]
    fn translate_dow_weekday_range() {
        // POSIX 1-5 (Mon-Fri) -> cron crate 2-6
        assert_eq!(translate_posix_dow_field("1-5").unwrap(), "2-6");
    }

    #[test]
    fn translate_dow_all_days_zero_to_six() {
        assert_eq!(translate_posix_dow_field("0-6").unwrap(), "*");
    }

    #[test]
    fn translate_dow_all_days_zero_to_seven() {
        // POSIX `0-7` is a quirky but valid "all days" expression.
        assert_eq!(translate_posix_dow_field("0-7").unwrap(), "*");
    }

    #[test]
    fn translate_dow_all_days_one_to_seven() {
        // POSIX `1-7` covers Mon..Sun = all 7 days.
+ assert_eq!(translate_posix_dow_field("1-7").unwrap(), "*"); + } + + #[test] + fn translate_dow_range_three_to_five() { + // POSIX 3-5 (Wed-Fri) -> cron crate 4-6 + assert_eq!(translate_posix_dow_field("3-5").unwrap(), "4-6"); + } + + #[test] + fn translate_dow_list_dedupes_zero_and_seven() { + // Both 0 and 7 = Sunday; output is a single value. + assert_eq!(translate_posix_dow_field("0,7").unwrap(), "1"); + } + + #[test] + fn translate_dow_list_non_contiguous() { + // POSIX 1,3,5 (Mon,Wed,Fri) -> cron crate 2,4,6 + assert_eq!(translate_posix_dow_field("1,3,5").unwrap(), "2,4,6"); + } + + #[test] + fn translate_dow_list_compacts_contiguous_runs() { + // POSIX 1,2,4,5 -> cron crate 2,3,5,6 -> "2-3,5-6" + assert_eq!(translate_posix_dow_field("1,2,4,5").unwrap(), "2-3,5-6"); + } + + #[test] + fn translate_dow_step_from_star() { + // POSIX */2 = 0,2,4,6 = Sun,Tue,Thu,Sat -> cron crate 1,3,5,7 + assert_eq!(translate_posix_dow_field("*/2").unwrap(), "1,3,5,7"); + } + + #[test] + fn translate_dow_step_from_range() { + // POSIX 1-5/2 = 1,3,5 = Mon,Wed,Fri -> cron crate 2,4,6 + assert_eq!(translate_posix_dow_field("1-5/2").unwrap(), "2,4,6"); + } + + #[test] + fn translate_dow_names_pass_through() { + assert_eq!(translate_posix_dow_field("Mon-Fri").unwrap(), "Mon-Fri"); + assert_eq!( + translate_posix_dow_field("Mon,Wed,Fri").unwrap(), + "Mon,Wed,Fri" + ); + assert_eq!(translate_posix_dow_field("Sun").unwrap(), "Sun"); + } + + #[test] + fn translate_dow_step_from_singleton() { + // POSIX 1/2 = from Mon through Sat, step 2 = {1,3,5} = Mon,Wed,Fri -> cron crate 2,4,6 + assert_eq!(translate_posix_dow_field("1/2").unwrap(), "2,4,6"); + } + + #[test] + fn translate_dow_step_from_singleton_sunday() { + // POSIX 0/3 = from Sun through Sat, step 3 = {0,3,6} = Sun,Wed,Sat -> cron crate 1,4,7 + assert_eq!(translate_posix_dow_field("0/3").unwrap(), "1,4,7"); + } + + #[test] + fn translate_dow_step_from_singleton_seven() { + // POSIX 7/2 = Sunday alias, same as 0/2 = {0,2,4,6} = 
Sun,Tue,Thu,Sat -> cron crate 1,3,5,7 + assert_eq!(translate_posix_dow_field("7/2").unwrap(), "1,3,5,7"); + } + + #[test] + fn translate_dow_rejects_mixed_notation() { + assert!(translate_posix_dow_field("1,Mon").is_err()); + assert!(translate_posix_dow_field("Mon,1").is_err()); + assert!(translate_posix_dow_field("1-Fri").is_err()); + } + + #[test] + fn translate_dow_rejects_out_of_range() { + assert!(translate_posix_dow_field("8").is_err()); + assert!(translate_posix_dow_field("0-8").is_err()); + } + + #[test] + fn translate_dow_rejects_reversed_range() { + assert!(translate_posix_dow_field("5-3").is_err()); + } + + #[test] + fn translate_dow_rejects_empty() { + assert!(translate_posix_dow_field("").is_err()); + assert!(translate_posix_dow_field(",1").is_err()); + assert!(translate_posix_dow_field("1,").is_err()); + } + + #[test] + fn translate_dow_rejects_zero_step() { + assert!(translate_posix_dow_field("*/0").is_err()); + } + + // --- parse_cron_expr rejects wrong number of fields --- + + #[test] + fn parse_rejects_too_few_fields() { + assert!(parse_cron_expr("* * * *").is_err()); + } + + // --- POSIX-semantic Schedule behavior (regression for #784) --- + + #[test] + fn weekday_schedule_does_not_fire_on_sunday() { + use chrono::TimeZone; + // Regression for the reported bug: "0 7 * * 1-5" with timezone Asia/Taipei + // was firing on Sunday 2026-05-10 because the cron crate's `1-5` means + // Sun-Thu without translation. 
+ let schedule = parse_cron_expr("0 7 * * 1-5").unwrap(); + let tz: Tz = "Asia/Taipei".parse().unwrap(); + let sunday = tz.with_ymd_and_hms(2026, 5, 10, 7, 0, 0).unwrap(); + let before = sunday - chrono::Duration::seconds(1); + let next = schedule.after(&before).next(); + assert_ne!( + next, + Some(sunday), + "POSIX 1-5 must not fire on Sunday (got next = {:?})", + next + ); + } + + #[test] + fn weekday_schedule_fires_on_monday() { + use chrono::TimeZone; + let schedule = parse_cron_expr("0 7 * * 1-5").unwrap(); + let tz: Tz = "Asia/Taipei".parse().unwrap(); + let monday = tz.with_ymd_and_hms(2026, 5, 11, 7, 0, 0).unwrap(); + let before = monday - chrono::Duration::seconds(1); + let next = schedule.after(&before).next(); + assert_eq!(next, Some(monday), "POSIX 1-5 must fire on Monday"); + } + + #[test] + fn weekday_schedule_fires_on_friday_not_saturday() { + use chrono::TimeZone; + let schedule = parse_cron_expr("0 7 * * 1-5").unwrap(); + let tz: Tz = "Asia/Taipei".parse().unwrap(); + // 2026-05-15 is Friday + let friday = tz.with_ymd_and_hms(2026, 5, 15, 7, 0, 0).unwrap(); + let before = friday - chrono::Duration::seconds(1); + let next = schedule.after(&before).next(); + assert_eq!(next, Some(friday), "POSIX 1-5 must fire on Friday"); + + // 2026-05-16 is Saturday - should not fire + let saturday = tz.with_ymd_and_hms(2026, 5, 16, 7, 0, 0).unwrap(); + let before = saturday - chrono::Duration::seconds(1); + let next = schedule.after(&before).next(); + assert_ne!(next, Some(saturday), "POSIX 1-5 must not fire on Saturday"); + } + + #[test] + fn sunday_schedule_fires_on_sunday_via_zero() { + use chrono::TimeZone; + let schedule = parse_cron_expr("0 7 * * 0").unwrap(); + let tz: Tz = "Asia/Taipei".parse().unwrap(); + let sunday = tz.with_ymd_and_hms(2026, 5, 10, 7, 0, 0).unwrap(); + let before = sunday - chrono::Duration::seconds(1); + let next = schedule.after(&before).next(); + assert_eq!(next, Some(sunday), "POSIX `0` must fire on Sunday"); + } + + #[test] + fn 
sunday_schedule_fires_on_sunday_via_seven() { + use chrono::TimeZone; + let schedule = parse_cron_expr("0 7 * * 7").unwrap(); + let tz: Tz = "Asia/Taipei".parse().unwrap(); + let sunday = tz.with_ymd_and_hms(2026, 5, 10, 7, 0, 0).unwrap(); + let before = sunday - chrono::Duration::seconds(1); + let next = schedule.after(&before).next(); + assert_eq!(next, Some(sunday), "POSIX `7` must also fire on Sunday"); + } + + #[test] + fn saturday_schedule_fires_on_saturday_via_six() { + use chrono::TimeZone; + let schedule = parse_cron_expr("0 7 * * 6").unwrap(); + let tz: Tz = "Asia/Taipei".parse().unwrap(); + // 2026-05-16 is Saturday + let saturday = tz.with_ymd_and_hms(2026, 5, 16, 7, 0, 0).unwrap(); + let before = saturday - chrono::Duration::seconds(1); + let next = schedule.after(&before).next(); + assert_eq!(next, Some(saturday), "POSIX `6` must fire on Saturday"); + } + + #[test] + fn name_based_weekday_still_works() { + use chrono::TimeZone; + // Name-based notation should be unaffected by the translation. 
+ let schedule = parse_cron_expr("0 7 * * Mon-Fri").unwrap(); + let tz: Tz = "Asia/Taipei".parse().unwrap(); + let monday = tz.with_ymd_and_hms(2026, 5, 11, 7, 0, 0).unwrap(); + let before = monday - chrono::Duration::seconds(1); + let next = schedule.after(&before).next(); + assert_eq!(next, Some(monday)); + + let sunday = tz.with_ymd_and_hms(2026, 5, 10, 7, 0, 0).unwrap(); + let before = sunday - chrono::Duration::seconds(1); + let next = schedule.after(&before).next(); + assert_ne!(next, Some(sunday)); + } + + #[test] + fn parse_valid_cron_expression() { + let schedule = parse_cron_expr("0 9 * * 1-5").unwrap(); + let next = schedule.upcoming(chrono_tz::UTC).next(); + assert!(next.is_some()); + } + + #[test] + fn parse_every_minute_cron() { + let schedule = parse_cron_expr("* * * * *").unwrap(); + let next = schedule.upcoming(chrono_tz::UTC).next(); + assert!(next.is_some()); + } + + #[test] + fn parse_invalid_cron_expression() { + assert!(parse_cron_expr("not a cron").is_err()); + } + + #[test] + fn parse_invalid_cron_too_many_fields() { + assert!(parse_cron_expr("0 0 9 * * 1-5").is_err()); + } + + #[test] + fn valid_timezone_parses() { + assert!("Asia/Taipei".parse::().is_ok()); + } + + #[test] + fn invalid_timezone_fails() { + assert!("Mars/Olympus".parse::().is_err()); + } + + #[test] + fn utc_timezone_parses() { + assert!("UTC".parse::().is_ok()); + } + + #[test] + fn should_fire_every_minute_returns_true() { + let schedule = parse_cron_expr("* * * * *").unwrap(); + assert!(should_fire(&schedule, chrono_tz::UTC)); + } + + #[test] + fn should_fire_returns_false_for_distant_schedule() { + let schedule = parse_cron_expr("0 0 1 1 *").unwrap(); + let now = chrono::Utc::now(); + if now.month() != 1 || now.day() != 1 || now.hour() != 0 { + assert!(!should_fire(&schedule, chrono_tz::UTC)); + } + } + + #[test] + fn should_fire_respects_timezone() { + let schedule = parse_cron_expr("* * * * *").unwrap(); + let tz: Tz = "Asia/Taipei".parse().unwrap(); + 
assert!(should_fire(&schedule, tz)); + } + + #[test] + fn cronjob_config_defaults() { + let toml_str = r#" +[[jobs]] +schedule = "0 9 * * 1-5" +channel = "123" +message = "hello" +"#; + let cfg: UsercronFile = toml::from_str(toml_str).unwrap(); + let job = &cfg.jobs[0]; + assert_eq!(job.enabled, true); + assert_eq!(job.platform, "discord"); + assert_eq!(job.sender_name, "openab-cron"); + assert_eq!(job.timezone, "UTC"); + assert!(job.thread_id.is_none()); + assert!(job.id.is_none()); + assert!(job.disable_on_success.is_none()); + assert!(job.disable_on_success_match.is_none()); + assert_eq!(job.disable_on_success_timeout_secs, 60); + assert!(job.disable_on_success_working_dir.is_none()); + } + + #[test] + fn cronjob_config_disabled() { + let toml_str = r#" +[[jobs]] +enabled = false +schedule = "0 9 * * 1-5" +channel = "123" +message = "hello" +"#; + let cfg: UsercronFile = toml::from_str(toml_str).unwrap(); + assert_eq!(cfg.jobs[0].enabled, false); + } + + #[test] + fn cronjob_config_custom_values() { + let toml_str = r#" +[[jobs]] +schedule = "0 18 * * 1-5" +channel = "456" +message = "report" +platform = "slack" +sender_name = "DailyOps" +timezone = "Asia/Taipei" +thread_id = "789" +id = "daily-report" +disable_on_success = "npm test" +disable_on_success_match = "SUCCESS" +disable_on_success_timeout_secs = 30 +disable_on_success_working_dir = "/tmp/project" +"#; + let cfg: UsercronFile = toml::from_str(toml_str).unwrap(); + let job = &cfg.jobs[0]; + assert_eq!(job.platform, "slack"); + assert_eq!(job.sender_name, "DailyOps"); + assert_eq!(job.timezone, "Asia/Taipei"); + assert_eq!(job.thread_id.as_deref(), Some("789")); + assert_eq!(job.id.as_deref(), Some("daily-report")); + assert_eq!(job.disable_on_success.as_deref(), Some("npm test")); + assert_eq!(job.disable_on_success_match.as_deref(), Some("SUCCESS")); + assert_eq!(job.disable_on_success_timeout_secs, 30); + assert_eq!( + job.disable_on_success_working_dir.as_deref(), + Some("/tmp/project") + ); + } + + 
#[test] + fn load_usercron_nonexistent_returns_empty() { + let jobs = load_usercron_file(Path::new("/tmp/nonexistent-usercron.toml"), &["discord"]); + assert!(jobs.is_empty()); + } + + #[test] + fn load_usercron_valid_file() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("cronjob.toml"); + std::fs::write( + &path, + r#" +[[jobs]] +schedule = "* * * * *" +channel = "123" +message = "ping" +"#, + ) + .unwrap(); + let jobs = load_usercron_file(&path, &["discord"]); + assert_eq!(jobs.len(), 1); + assert_eq!(jobs[0].message, "ping"); + } + + #[test] + fn load_usercron_invalid_toml_returns_empty() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("cronjob.toml"); + std::fs::write(&path, "not valid toml {{{").unwrap(); + let jobs = load_usercron_file(&path, &["discord"]); + assert!(jobs.is_empty()); + } + + #[test] + fn load_usercron_skips_invalid_entries() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("cronjob.toml"); + std::fs::write( + &path, + r#" +[[jobs]] +schedule = "* * * * *" +channel = "123" +message = "good" + +[[jobs]] +schedule = "bad cron" +channel = "456" +message = "bad" +"#, + ) + .unwrap(); + let jobs = load_usercron_file(&path, &["discord"]); + assert_eq!(jobs.len(), 1); + assert_eq!(jobs[0].message, "good"); + } + + #[test] + fn load_usercron_skips_unconfigured_platform() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("cronjob.toml"); + std::fs::write( + &path, + r#" +[[jobs]] +schedule = "* * * * *" +channel = "123" +message = "discord job" + +[[jobs]] +schedule = "* * * * *" +channel = "456" +message = "slack job" +platform = "slack" +"#, + ) + .unwrap(); + // Only discord configured + let jobs = load_usercron_file(&path, &["discord"]); + assert_eq!(jobs.len(), 1); + assert_eq!(jobs[0].message, "discord job"); + } + + #[test] + fn load_usercron_skips_disable_on_success_without_id() { + let dir = tempfile::tempdir().unwrap(); + let path = 
dir.path().join("cronjob.toml"); + std::fs::write( + &path, + r#" +[[jobs]] +schedule = "* * * * *" +channel = "123" +message = "missing id" +disable_on_success = "echo SUCCESS" +disable_on_success_match = "SUCCESS" +"#, + ) + .unwrap(); + let jobs = load_usercron_file(&path, &["discord"]); + assert!(jobs.is_empty()); + } + + #[test] + fn load_usercron_skips_disable_on_success_without_match() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("cronjob.toml"); + std::fs::write( + &path, + r#" +[[jobs]] +id = "goal" +schedule = "* * * * *" +channel = "123" +message = "missing marker" +disable_on_success = "echo SUCCESS" +"#, + ) + .unwrap(); + let jobs = load_usercron_file(&path, &["discord"]); + assert!(jobs.is_empty()); + } + + #[test] + fn validate_cronjobs_rejects_baseline_disable_on_success() { + let jobs = vec![CronJobConfig { + id: Some("baseline-goal".into()), + enabled: true, + schedule: "* * * * *".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), + disable_on_success: Some("echo SUCCESS".into()), + disable_on_success_match: Some("SUCCESS".into()), + disable_on_success_timeout_secs: 60, + disable_on_success_working_dir: None, + }]; + let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); + assert!(err.to_string().contains("only supported in usercron")); + } + + #[test] + fn update_usercron_job_sets_enabled_and_thread_id_by_id() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("cronjob.toml"); + std::fs::write( + &path, + r#" +[[jobs]] +id = "goal-a" +enabled = true +schedule = "* * * * *" +channel = "123" +message = "a" + +[[jobs]] +id = "goal-b" +enabled = true +schedule = "* * * * *" +channel = "456" +message = "b" +"#, + ) + .unwrap(); + + update_usercron_job(&path, "goal-b", Some(false), Some("thread-456")).unwrap(); + + let updated = std::fs::read_to_string(&path).unwrap(); + let doc = 
updated.parse::().unwrap(); + let jobs = doc["jobs"].as_array_of_tables().unwrap(); + let job_a = jobs.iter().next().unwrap(); + let job_b = jobs.iter().nth(1).unwrap(); + assert_eq!(job_a["id"].as_str(), Some("goal-a")); + assert_eq!(job_a["enabled"].as_bool(), Some(true)); + assert!(job_a.get("thread_id").is_none()); + assert_eq!(job_b["id"].as_str(), Some("goal-b")); + assert_eq!(job_b["enabled"].as_bool(), Some(false)); + assert_eq!(job_b["thread_id"].as_str(), Some("thread-456")); + } + + #[test] + fn update_usercron_job_errors_for_missing_id() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("cronjob.toml"); + std::fs::write( + &path, + r#" +[[jobs]] +id = "goal-a" +schedule = "* * * * *" +channel = "123" +message = "a" +"#, + ) + .unwrap(); + let err = update_usercron_job(&path, "missing", Some(false), None).unwrap_err(); + assert!(err.to_string().contains("not found")); + } + + #[tokio::test] + async fn disable_on_success_requires_exit_zero_and_marker() { + let mut job = test_cron_job(); + job.disable_on_success_timeout_secs = 5; + + assert!(matches!( + check_disable_on_success(&job, "printf SUCCESS", "SUCCESS").await, + DisableOnSuccessResult::Achieved + )); + assert!(matches!( + check_disable_on_success(&job, "printf DONE", "SUCCESS").await, + DisableOnSuccessResult::NotAchieved("success marker not found") + )); + assert!(matches!( + check_disable_on_success(&job, "printf SUCCESS; exit 1", "SUCCESS").await, + DisableOnSuccessResult::NotAchieved("command exited non-zero") + )); + } + + #[tokio::test] + async fn disable_on_success_kills_child_on_timeout() { + let mut job = test_cron_job(); + job.disable_on_success_timeout_secs = 1; + + let result = check_disable_on_success(&job, "sleep 999", "SUCCESS").await; + assert!(matches!( + result, + DisableOnSuccessResult::NotAchieved("command timed out") + )); + } + + fn test_cron_job() -> CronJobConfig { + CronJobConfig { + id: Some("goal".into()), + enabled: true, + schedule: "* * * * 
*".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), + disable_on_success: Some("echo SUCCESS".into()), + disable_on_success_match: Some("SUCCESS".into()), + disable_on_success_timeout_secs: 60, + disable_on_success_working_dir: None, + } + } + + // --- validate_cronjobs tests --- + + #[test] + fn validate_cronjobs_valid_passes() { + let jobs = vec![CronJobConfig { + id: None, + enabled: true, + schedule: "0 9 * * 1-5".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), + disable_on_success: None, + disable_on_success_match: None, + disable_on_success_timeout_secs: 60, + disable_on_success_working_dir: None, + }]; + assert!(validate_cronjobs(&jobs, &["discord"]).is_ok()); + } + + #[test] + fn validate_cronjobs_invalid_cron_fails() { + let jobs = vec![CronJobConfig { + id: None, + enabled: true, + schedule: "bad".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), + disable_on_success: None, + disable_on_success_match: None, + disable_on_success_timeout_secs: 60, + disable_on_success_working_dir: None, + }]; + let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); + assert!(err.to_string().contains("invalid cron expression")); + } + + #[test] + fn validate_cronjobs_invalid_timezone_fails() { + let jobs = vec![CronJobConfig { + id: None, + enabled: true, + schedule: "* * * * *".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "Mars/Olympus".into(), + disable_on_success: None, + disable_on_success_match: None, + disable_on_success_timeout_secs: 60, + disable_on_success_working_dir: None, + }]; + let err = validate_cronjobs(&jobs, 
&["discord"]).unwrap_err(); + assert!(err.to_string().contains("invalid timezone")); + } + + #[test] + fn validate_cronjobs_unknown_platform_fails() { + let jobs = vec![CronJobConfig { + id: None, + enabled: true, + schedule: "* * * * *".into(), + channel: "123".into(), + message: "hi".into(), + platform: "telegram".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), + disable_on_success: None, + disable_on_success_match: None, + disable_on_success_timeout_secs: 60, + disable_on_success_working_dir: None, + }]; + let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); + assert!(err.to_string().contains("unknown platform")); + } + + #[test] + fn validate_cronjobs_unconfigured_platform_fails() { + let jobs = vec![CronJobConfig { + id: None, + enabled: true, + schedule: "* * * * *".into(), + channel: "123".into(), + message: "hi".into(), + platform: "slack".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), + disable_on_success: None, + disable_on_success_match: None, + disable_on_success_timeout_secs: 60, + disable_on_success_working_dir: None, + }]; + let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); + assert!(err.to_string().contains("not configured")); + } + + #[test] + fn validate_cronjobs_disabled_with_invalid_cron_passes() { + let jobs = vec![CronJobConfig { + id: None, + enabled: false, + schedule: "bad".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), + disable_on_success: None, + disable_on_success_match: None, + disable_on_success_timeout_secs: 60, + disable_on_success_working_dir: None, + }]; + assert!(validate_cronjobs(&jobs, &["discord"]).is_ok()); + } + + #[test] + fn validate_cronjobs_enabled_with_invalid_cron_still_fails() { + let jobs = vec![CronJobConfig { + id: None, + enabled: true, + schedule: "bad".into(), + channel: "123".into(), + message: 
"hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), + disable_on_success: None, + disable_on_success_match: None, + disable_on_success_timeout_secs: 60, + disable_on_success_working_dir: None, + }]; + assert!(validate_cronjobs(&jobs, &["discord"]).is_err()); + } + + // --- file_mtime tests --- + + #[test] + fn file_mtime_detects_change() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.toml"); + assert!(file_mtime(&path).is_none()); // doesn't exist yet + std::fs::write(&path, "v1").unwrap(); + let m1 = file_mtime(&path); + assert!(m1.is_some()); + // Sleep briefly to ensure mtime differs + std::thread::sleep(std::time::Duration::from_millis(50)); + std::fs::write(&path, "v2").unwrap(); + let m2 = file_mtime(&path); + assert!(m2.is_some()); + assert!(m2 != m1); + } + + // --- CronConfig TOML deserialization --- + + #[test] + fn cron_config_toml_parses() { + use crate::config::Config; + let toml_str = r#" +[agent] +command = "echo" + +[cron] +usercron_enabled = true +usercron_path = "cronjob.toml" + +[[cron.jobs]] +schedule = "0 9 * * 1-5" +channel = "123" +message = "hello" + +[[cron.jobs]] +schedule = "*/30 * * * *" +channel = "456" +message = "ping" +platform = "slack" +"#; + let cfg: Config = toml::from_str(toml_str).unwrap(); + assert!(cfg.cron.usercron_enabled); + assert_eq!(cfg.cron.usercron_path.as_deref(), Some("cronjob.toml")); + assert_eq!(cfg.cron.jobs.len(), 2); + assert_eq!(cfg.cron.jobs[0].message, "hello"); + assert_eq!(cfg.cron.jobs[1].platform, "slack"); + } + + #[test] + fn cron_config_defaults_when_omitted() { + use crate::config::Config; + let toml_str = r#" +[agent] +command = "echo" +"#; + let cfg: Config = toml::from_str(toml_str).unwrap(); + assert!(!cfg.cron.usercron_enabled); + assert!(cfg.cron.usercron_path.is_none()); + assert!(cfg.cron.jobs.is_empty()); + } + + // --- load_usercron empty file --- + + #[test] + fn 
load_usercron_empty_file_returns_empty() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("cronjob.toml"); + std::fs::write(&path, "").unwrap(); + let jobs = load_usercron_file(&path, &["discord"]); + assert!(jobs.is_empty()); + } +} diff --git a/crates/openab-core/src/directives.rs b/crates/openab-core/src/directives.rs new file mode 100644 index 000000000..c1d5c27f2 --- /dev/null +++ b/crates/openab-core/src/directives.rs @@ -0,0 +1,314 @@ +//! Control Directives parser (ADR: control-directives.md). +//! +//! Extracts leading `[[key:value]]` directives from the first message in a +//! session, strips them from the prompt, and returns structured metadata. + +use regex::Regex; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::LazyLock; +use tracing::warn; + +static DIRECTIVE_RE: LazyLock = + LazyLock::new(|| Regex::new(r"^\s*\[\[([a-z_]+):([^\]]*)\]\]").unwrap()); + +/// Parsed control directives from a session's first message. +#[derive(Debug, Clone, Default)] +pub struct SessionMetadata { + /// Resolved canonical workspace path (None = use default working_dir). + #[allow(dead_code)] + pub workspace: Option, + /// Thread title override (None = use generated title). + pub title: Option, + /// Raw directives map for forward-compatible unknown keys. + pub raw: HashMap, +} + +/// Result of parsing directives from a prompt. +pub struct ParseResult { + /// The prompt with leading directives stripped. + pub prompt: String, + /// Parsed session metadata. + pub metadata: SessionMetadata, +} + +/// Parse leading `[[key:value]]` directives from a prompt string. +/// +/// Directives must appear at the start of the message (after optional +/// whitespace). The first line/token that is not a directive stops parsing; +/// any `[[key:value]]` text after that point is preserved verbatim. 
+pub fn parse_directives(input: &str) -> ParseResult { + let mut raw: HashMap = HashMap::new(); + let mut remaining = input; + + loop { + remaining = remaining.trim_start_matches([' ', '\t']); + if remaining.starts_with('\n') || remaining.starts_with("\r\n") { + // A blank line after directives = end of header + let next = remaining.trim_start_matches(['\r', '\n']); + let next_trimmed = next.trim_start_matches([' ', '\t']); + if !next_trimmed.starts_with("[[") { + remaining = next; + break; + } + remaining = remaining.trim_start_matches(['\r', '\n']); + } + if let Some(caps) = DIRECTIVE_RE.captures(remaining) { + let full_match = caps.get(0).unwrap(); + let key = caps[1].to_string(); + let value = caps[2].to_string(); + // Last value wins for duplicate keys + raw.insert(key, value); + remaining = &remaining[full_match.end()..]; + } else { + break; + } + } + + let prompt = remaining.trim().to_string(); + let metadata = SessionMetadata { + workspace: None, // resolved later by resolve_workspace + title: raw.get("title").cloned(), + raw, + }; + + ParseResult { prompt, metadata } +} + +/// Resolve the `[[ws:...]]` directive value into a canonical path. +/// +/// Supports: +/// - Raw paths: `~/projects/foo` or `/home/bot/projects/foo` +/// - Aliases: `@alias_name` → looked up in `aliases` map +/// +/// Returns `Err` with a user-visible message on failure. +pub fn resolve_workspace( + raw_value: &str, + aliases: &HashMap, + bot_home: &Path, +) -> Result { + let path_str = if let Some(alias) = raw_value.strip_prefix('@') { + match aliases.get(alias) { + Some(resolved) => resolved.as_str(), + None => { + let available: Vec<&str> = aliases.keys().map(|s| s.as_str()).collect(); + return Err(format!( + "Unknown workspace alias `@{alias}`. 
Available: {}", + if available.is_empty() { + "(none configured)".to_string() + } else { + available.join(", ") + } + )); + } + } + } else { + raw_value + }; + + // Rule 1: reject relative paths + if !path_str.starts_with('~') && !path_str.starts_with('/') { + return Err(format!( + "Workspace path must be absolute (start with `~` or `/`): `{path_str}`" + )); + } + + // Rule 2: expand ~ + let expanded = if let Some(rest) = path_str.strip_prefix('~') { + let rest = rest.strip_prefix('/').unwrap_or(rest); + bot_home.join(rest) + } else { + PathBuf::from(path_str) + }; + + // Rule 3: canonicalize both paths + let canonical_home = bot_home.canonicalize().map_err(|e| { + warn!(path = %bot_home.display(), error = %e, "cannot canonicalize bot home"); + "Internal error: cannot resolve bot home directory".to_string() + })?; + + let canonical_target = expanded.canonicalize().map_err(|e| { + warn!(path = %expanded.display(), error = %e, "cannot canonicalize workspace path"); + format!( + "Workspace path does not exist: `{path_str}` (expanded to `{}`)", + expanded.display() + ) + })?; + + // Rule 4+5: verify within bot home subtree + if !canonical_target.starts_with(&canonical_home) { + return Err(format!( + "Workspace path is outside allowed directory: `{path_str}`" + )); + } + + // Rule 6: must be a directory (not a file) + if !canonical_target.is_dir() { + return Err(format!( + "Workspace path is not a directory: `{}`", + canonical_target.display() + )); + } + + Ok(canonical_target) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + #[test] + fn parse_basic_directives() { + let input = "[[ws:~/projects/foo]] [[title:Bug fix]]\ninvestigate the build failure"; + let result = parse_directives(input); + assert_eq!(result.prompt, "investigate the build failure"); + assert_eq!(result.metadata.raw.get("ws").unwrap(), "~/projects/foo"); + assert_eq!(result.metadata.title.as_deref(), Some("Bug fix")); + } + + #[test] + fn 
parse_directives_multiline_header() { + let input = "[[ws:@openab]]\n[[title:Review PR]]\nplease review this change"; + let result = parse_directives(input); + assert_eq!(result.prompt, "please review this change"); + assert_eq!(result.metadata.raw.get("ws").unwrap(), "@openab"); + assert_eq!(result.metadata.title.as_deref(), Some("Review PR")); + } + + #[test] + fn parse_preserves_body_directives() { + let input = "[[title:Test]]\nHere is some code with [[key:value]] in it"; + let result = parse_directives(input); + assert_eq!(result.prompt, "Here is some code with [[key:value]] in it"); + assert_eq!(result.metadata.title.as_deref(), Some("Test")); + assert!(!result.metadata.raw.contains_key("key")); + } + + #[test] + fn parse_no_directives() { + let input = "just a regular message"; + let result = parse_directives(input); + assert_eq!(result.prompt, "just a regular message"); + assert!(result.metadata.raw.is_empty()); + } + + #[test] + fn parse_duplicate_keys_last_wins() { + let input = "[[title:First]] [[title:Second]]\ndo stuff"; + let result = parse_directives(input); + assert_eq!(result.metadata.title.as_deref(), Some("Second")); + } + + #[test] + fn parse_empty_value() { + let input = "[[title:]]\ndo stuff"; + let result = parse_directives(input); + assert_eq!(result.metadata.title.as_deref(), Some("")); + } + + #[test] + fn parse_unknown_keys_ignored() { + let input = "[[foo:bar]] [[ws:~/x]]\ndo stuff"; + let result = parse_directives(input); + assert_eq!(result.metadata.raw.get("foo").unwrap(), "bar"); + assert_eq!(result.prompt, "do stuff"); + } + + #[test] + fn resolve_alias_success() { + let tmp = TempDir::new().unwrap(); + let projects = tmp.path().join("projects").join("openab"); + fs::create_dir_all(&projects).unwrap(); + + let mut aliases = HashMap::new(); + aliases.insert( + "openab".to_string(), + format!("{}/projects/openab", tmp.path().display()), + ); + + let result = resolve_workspace("@openab", &aliases, tmp.path()).unwrap(); + 
assert_eq!(result, projects.canonicalize().unwrap()); + } + + #[test] + fn resolve_alias_not_found() { + let tmp = TempDir::new().unwrap(); + let aliases = HashMap::new(); + let result = resolve_workspace("@nope", &aliases, tmp.path()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Unknown workspace alias")); + } + + #[test] + fn resolve_relative_path_rejected() { + let tmp = TempDir::new().unwrap(); + let aliases = HashMap::new(); + let result = resolve_workspace("relative/path", &aliases, tmp.path()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("must be absolute")); + } + + #[test] + fn resolve_outside_home_rejected() { + let tmp = TempDir::new().unwrap(); + let aliases = HashMap::new(); + let result = resolve_workspace("/tmp", &aliases, tmp.path()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("outside allowed directory")); + } + + #[test] + fn resolve_tilde_expansion() { + let tmp = TempDir::new().unwrap(); + let projects = tmp.path().join("myapp"); + fs::create_dir_all(&projects).unwrap(); + + let aliases = HashMap::new(); + let result = resolve_workspace("~/myapp", &aliases, tmp.path()).unwrap(); + assert_eq!(result, projects.canonicalize().unwrap()); + } + + #[test] + fn resolve_nonexistent_path() { + let tmp = TempDir::new().unwrap(); + let aliases = HashMap::new(); + let result = resolve_workspace("~/does_not_exist", &aliases, tmp.path()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("does not exist")); + } + + #[test] + fn parse_directives_leading_spaces_on_newline() { + let input = "[[ws:@openab]]\n [[title:Fix CI]]\nhelp me debug"; + let result = parse_directives(input); + assert_eq!(result.prompt, "help me debug"); + assert_eq!(result.metadata.raw.get("ws").unwrap(), "@openab"); + assert_eq!(result.metadata.title.as_deref(), Some("Fix CI")); + } + + #[test] + fn resolve_file_path_rejected() { + let tmp = TempDir::new().unwrap(); + let file_path = 
tmp.path().join("Cargo.toml"); + fs::write(&file_path, "").unwrap(); + + let aliases = HashMap::new(); + let result = resolve_workspace(&format!("{}", file_path.display()), &aliases, tmp.path()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("not a directory")); + } + + #[test] + fn resolve_error_shows_expanded_path() { + let tmp = TempDir::new().unwrap(); + let aliases = HashMap::new(); + let result = resolve_workspace("~/no_such_dir", &aliases, tmp.path()); + assert!(result.is_err()); + let err = result.unwrap_err(); + // Error should contain both the original and expanded path + assert!(err.contains("~/no_such_dir")); + assert!(err.contains(&tmp.path().display().to_string())); + } +} diff --git a/crates/openab-core/src/discord.rs b/crates/openab-core/src/discord.rs new file mode 100644 index 000000000..12281afad --- /dev/null +++ b/crates/openab-core/src/discord.rs @@ -0,0 +1,3203 @@ +use crate::acp::protocol::ConfigOption; +use crate::acp::ContentBlock; +use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef, SenderContext}; +use crate::bot_turns::{BotTurnTracker, TurnAction, TurnSeverity, BOT_TURN_LIMIT_WARNING_PREFIX}; +use crate::config::{AllowBots, AllowUsers, SttConfig}; +use crate::format; +use crate::media; +use crate::remind::{self, ReminderStore}; +use async_trait::async_trait; +use serenity::builder::{ + CreateActionRow, CreateAttachment, CreateButton, CreateCommand, CreateCommandOption, + CreateInteractionResponse, CreateInteractionResponseFollowup, CreateInteractionResponseMessage, + CreateSelectMenu, CreateSelectMenuKind, CreateSelectMenuOption, CreateThread, EditChannel, + EditMessage, GetMessages, +}; +use serenity::http::Http; +use serenity::model::application::ButtonStyle; +use serenity::model::application::{Command, CommandOptionType, ComponentInteractionDataKind, Interaction}; +use serenity::model::channel::{AutoArchiveDuration, Message, MessageType, ReactionType}; +use serenity::model::gateway::Ready; 
+use serenity::model::id::{ChannelId, MessageId, UserId}; +use serenity::prelude::*; +use std::collections::{HashMap, HashSet}; +use std::sync::LazyLock; +use std::sync::{Arc, OnceLock}; +use tracing::{debug, error, info, warn}; + +/// Hard cap on consecutive bot messages in a channel or thread. +/// Prevents runaway loops between multiple bots in "all" mode. +const MAX_CONSECUTIVE_BOT_TURNS: u32 = 1000; + +/// Maximum entries in the participation cache before eviction. +const PARTICIPATION_CACHE_MAX: usize = 1000; + +/// Discord StringSelectMenu hard limit on options. +const SELECT_MENU_PAGE_SIZE: usize = 25; + +/// Avoid unbounded Discord history exports from very large threads. +const THREAD_EXPORT_MESSAGE_LIMIT: usize = 5000; + +// --- DiscordAdapter: implements ChatAdapter for Discord via serenity --- + +pub struct DiscordAdapter { + http: Arc, +} + +impl DiscordAdapter { + pub fn new(http: Arc) -> Self { + Self { http } + } + + /// Resolve the effective Discord channel ID from a ChannelRef. + /// Discord threads are channels, so prefer thread_id when set. 
+ fn resolve_channel(channel: &ChannelRef) -> &str { + channel.thread_id.as_deref().unwrap_or(&channel.channel_id) + } +} + +#[async_trait] +impl ChatAdapter for DiscordAdapter { + fn platform(&self) -> &'static str { + "discord" + } + + fn message_limit(&self) -> usize { + 2000 + } + + async fn send_message( + &self, + channel: &ChannelRef, + content: &str, + ) -> anyhow::Result { + let ch_id: u64 = Self::resolve_channel(channel).parse()?; + let msg = ChannelId::new(ch_id).say(&self.http, content).await?; + Ok(MessageRef { + channel: channel.clone(), + message_id: msg.id.to_string(), + }) + } + + async fn send_message_with_reply( + &self, + channel: &ChannelRef, + content: &str, + reply_to_message_id: &str, + ) -> anyhow::Result { + let ch_id: u64 = Self::resolve_channel(channel).parse()?; + let msg_id: u64 = reply_to_message_id.parse().unwrap_or(0); + if msg_id == 0 { + // Invalid message ID, fall back to plain send + return self.send_message(channel, content).await; + } + let builder = serenity::builder::CreateMessage::new() + .content(content) + .reference_message((ChannelId::new(ch_id), MessageId::new(msg_id))); + match ChannelId::new(ch_id) + .send_message(&self.http, builder) + .await + { + Ok(msg) => Ok(MessageRef { + channel: channel.clone(), + message_id: msg.id.to_string(), + }), + Err(e) => { + // Fallback to plain send if reply fails (e.g. 
unknown message, cross-channel) + tracing::warn!(error = ?e, reply_to = reply_to_message_id, "reply_to failed, falling back to plain send"); + self.send_message(channel, content).await + } + } + } + + async fn delete_message(&self, msg: &MessageRef) -> anyhow::Result<()> { + let ch_id: u64 = Self::resolve_channel(&msg.channel).parse()?; + let msg_id: u64 = msg.message_id.parse()?; + self.http + .delete_message(ChannelId::new(ch_id), MessageId::new(msg_id), None) + .await?; + Ok(()) + } + + async fn edit_message(&self, msg: &MessageRef, content: &str) -> anyhow::Result<()> { + let ch_id: u64 = Self::resolve_channel(&msg.channel).parse()?; + let msg_id: u64 = msg.message_id.parse()?; + ChannelId::new(ch_id) + .edit_message( + &self.http, + MessageId::new(msg_id), + EditMessage::new().content(content), + ) + .await?; + Ok(()) + } + + fn use_streaming(&self, other_bot_present: bool) -> bool { + !other_bot_present + } + + async fn create_thread( + &self, + channel: &ChannelRef, + trigger_msg: &MessageRef, + title: &str, + ) -> anyhow::Result { + let ch_id: u64 = channel.channel_id.parse()?; + let msg_id: u64 = trigger_msg.message_id.parse()?; + let thread = ChannelId::new(ch_id) + .create_thread_from_message( + &self.http, + MessageId::new(msg_id), + CreateThread::new(title).auto_archive_duration(AutoArchiveDuration::OneDay), + ) + .await?; + Ok(ChannelRef { + platform: "discord".into(), + channel_id: thread.id.to_string(), + thread_id: None, + parent_id: Some(channel.channel_id.clone()), + origin_event_id: None, + }) + } + + async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> anyhow::Result<()> { + let ch_id: u64 = Self::resolve_channel(&msg.channel).parse()?; + let msg_id: u64 = msg.message_id.parse()?; + self.http + .create_reaction( + ChannelId::new(ch_id), + MessageId::new(msg_id), + &ReactionType::Unicode(emoji.to_string()), + ) + .await?; + Ok(()) + } + + async fn remove_reaction(&self, msg: &MessageRef, emoji: &str) -> anyhow::Result<()> { + let 
ch_id: u64 = Self::resolve_channel(&msg.channel).parse()?; + let msg_id: u64 = msg.message_id.parse()?; + self.http + .delete_reaction_me( + ChannelId::new(ch_id), + MessageId::new(msg_id), + &ReactionType::Unicode(emoji.to_string()), + ) + .await?; + Ok(()) + } + + async fn rename_thread(&self, channel: &ChannelRef, title: &str) -> anyhow::Result<()> { + let ch_id: u64 = Self::resolve_channel(channel).parse()?; + // Truncate at char boundary to avoid panic on multi-byte chars (中文/Emoji). + let truncated: &str = if title.chars().count() > 100 { + let end = title.char_indices().nth(100).map(|(i, _)| i).unwrap_or(title.len()); + &title[..end] + } else { + title + }; + ChannelId::new(ch_id) + .edit(&self.http, EditChannel::new().name(truncated)) + .await?; + Ok(()) + } +} + +// --- Handler: serenity EventHandler that delegates to AdapterRouter --- + +pub struct Handler { + pub router: Arc, + pub allow_all_channels: bool, + pub allow_all_users: bool, + pub allowed_channels: HashSet, + pub allowed_users: HashSet, + pub stt_config: SttConfig, + pub adapter: OnceLock>, + pub allow_bot_messages: AllowBots, + pub trusted_bot_ids: HashSet, + pub allow_user_messages: AllowUsers, + /// Role IDs that trigger the bot (same as direct @mention). + pub allowed_role_ids: HashSet, + /// Positive-only cache: thread channel_id → cached_at for threads where bot has participated. + pub participated_threads: tokio::sync::Mutex>, + /// Positive-only cache: thread channel_id → cached_at for threads where other bots have posted. + /// Like participation, a thread becoming multi-bot is irreversible (bot messages don't disappear). + pub multibot_threads: tokio::sync::Mutex>, + /// Persistent disk cache for multibot thread detection (survives restarts). + pub multibot_cache: crate::multibot_cache::MultibotCache, + /// TTL for participation cache entries (from pool.session_ttl_hours). 
+ pub session_ttl: std::time::Duration, + /// Configurable soft limit on bot turns per thread (reset by human message). + pub max_bot_turns: u32, + /// Per-thread bot turn tracker. Both counters reset on human msg. + pub bot_turns: tokio::sync::Mutex, + /// Allow the bot to respond to Discord DMs. + pub allow_dm: bool, + /// Per-thread dispatcher (Message mode uses cap=1 for FIFO; Thread/Lane use configured cap). + pub dispatcher: Arc, + /// Reminder store for /remind slash command. + pub reminder_store: ReminderStore, + /// Track scheduled reminder IDs to prevent duplicate scheduling on reconnect. + pub scheduled_ids: tokio::sync::Mutex>, +} + +impl Handler { + /// Check if the bot has participated in a Discord thread, and whether + /// other bots have also posted in it. + /// Returns `(involved, other_bot_present)`. + /// Fail-closed: returns `(false, false)` on API error. + /// Caches positive results only (both participation and multi-bot status are irreversible). + async fn bot_participated_in_thread( + &self, + http: &Http, + channel_id: ChannelId, + bot_id: UserId, + ) -> (bool, bool) { + let key = channel_id.to_string(); + + // Check positive caches + let cached_involved = { + let cache = self.participated_threads.lock().await; + cache + .get(&key) + .is_some_and(|ts| ts.elapsed() < self.session_ttl) + }; + let cached_multibot = { + let cache = self.multibot_threads.lock().await; + cache + .get(&key) + .is_some_and(|ts| ts.elapsed() < self.session_ttl) + } || self.multibot_cache.is_multibot(&key); + + // Both cached → skip fetch entirely + // With early detection from msg.author, multibot_threads is populated + // eagerly — no need to fetch just to check for other bots. 
+ if cached_involved { + return (true, cached_multibot); + } + + // Fetch recent messages + let messages = match channel_id + .messages(http, serenity::builder::GetMessages::new().limit(200)) + .await + { + Ok(msgs) => msgs, + Err(e) => { + tracing::warn!( + channel_id = %channel_id, + error = %e, + "failed to fetch thread messages for participation check, rejecting (fail-closed)" + ); + return (false, false); + } + }; + + let involved = cached_involved || messages.iter().any(|m| m.author.id == bot_id); + // other_bot_present relies solely on early detection + disk cache; + // no longer scanned from fetched messages (200-msg window was unreliable). + let other_bot_present = cached_multibot; + + if involved && !cached_involved { + let mut cache = self.participated_threads.lock().await; + cache.insert(key.clone(), tokio::time::Instant::now()); + + // Evict if over capacity + if cache.len() > PARTICIPATION_CACHE_MAX { + cache.retain(|_, ts| ts.elapsed() < self.session_ttl); + if cache.len() > PARTICIPATION_CACHE_MAX { + let mut entries: Vec<_> = cache.iter().map(|(k, v)| (k.clone(), *v)).collect(); + entries.sort_by_key(|(_, ts)| *ts); + let evict_count = entries.len() / 2; + for (k, _) in entries.into_iter().take(evict_count) { + cache.remove(&k); + } + } + } + } + + (involved, other_bot_present) + } +} + +#[serenity::async_trait] +impl EventHandler for Handler { + async fn message(&self, ctx: Context, msg: Message) { + let bot_id = ctx.cache.current_user().id; + + // Early multibot detection: cache that another bot is present. + // Runs before self-check and bot gating so we always detect other bots. 
(#481) + if msg.author.bot && msg.author.id != bot_id { + let key = msg.channel_id.to_string(); + { + let mut cache = self.multibot_threads.lock().await; + cache.entry(key.clone()).or_insert_with(tokio::time::Instant::now); + } + // Persist to disk — multibot is irreversible + self.multibot_cache.mark_multibot(&key).await; + } + + // Bot turn counting: runs before self-check so ALL bot messages + // (including own) count toward the per-thread limit. This means + // soft_limit=20 = 20 total bot messages in the thread (~10 per bot + // in a two-bot ping-pong). (#483) + { + let thread_key = msg.channel_id.to_string(); + let mut tracker = self.bot_turns.lock().await; + if msg.author.bot { + match tracker.classify_bot_message(&thread_key) { + TurnAction::Continue => {} + TurnAction::SilentStop => return, + TurnAction::WarnAndStop { + severity, + turns, + user_message, + } => { + match severity { + TurnSeverity::Hard => tracing::warn!( + channel_id = %msg.channel_id, + turns, + "hard bot turn limit reached", + ), + TurnSeverity::Soft => tracing::info!( + channel_id = %msg.channel_id, + turns, + max = self.max_bot_turns, + "soft bot turn limit reached", + ), + } + // Only post the warning if this bot is allowed in the channel/thread. + // Bot turn counting intentionally runs before channel gating so ALL + // bot messages are counted, but the *warning message* must respect + // channel permissions — otherwise bots that never participated in a + // thread will spam it with warnings. + // + // Must match the full thread allowlist semantics: a thread is allowed + // if its own channel_id OR its parent_id is in allowed_channels. + let ch = msg.channel_id.get(); + let in_allowed_channel = self.allowed_channels.contains(&ch); + let mut allowed_here = self.allow_all_channels || in_allowed_channel; + if !allowed_here { + // Reuse detect_thread() for thread allowlist semantics. + // Only called on the WarnAndStop path (once per soft/hard + // limit hit), not on every bot message. 
+ if let Ok(serenity::model::channel::Channel::Guild(gc)) = + msg.channel_id.to_channel(&ctx.http).await + { + let (in_thread, _) = detect_thread( + gc.thread_metadata.is_some(), + gc.parent_id.map(|id| id.get()), + gc.owner_id.map(|id| id.get()), + bot_id.get(), + &self.allowed_channels, + self.allow_all_channels, + in_allowed_channel, + ); + if in_thread { + allowed_here = true; + } + } + } + if msg.author.id != bot_id && allowed_here { + // Only warn if this bot actually participated in the + // thread — prevents uninvolved bots from spamming + // warnings in shared channels. (#727) + // Second value is `is_multibot`; not needed here. + let (participated, _) = self + .bot_participated_in_thread(&ctx.http, msg.channel_id, bot_id) + .await; + if participated { + // Dedup: skip if another bot already posted the same + // warning in this thread. Prevents N duplicate warnings + // when N bot processes each hit the soft limit. (#530) + let recent = msg + .channel_id + .messages( + &ctx.http, + serenity::builder::GetMessages::new().limit(10), + ) + .await + .unwrap_or_default(); + let pairs: Vec<(bool, &str)> = recent + .iter() + .map(|m| (m.author.bot, m.content.as_str())) + .collect(); + let already_warned = turn_limit_warning_present(&pairs); + if !already_warned { + let _ = msg.channel_id.say(&ctx.http, &user_message).await; + } + } + } + return; + } + } + } else if matches!(msg.kind, MessageType::Regular | MessageType::InlineReply) + && !msg.content.is_empty() + { + tracker.on_human_message(&thread_key); + } + } + + // Ignore own messages (after counting toward bot turns above) + if msg.author.id == bot_id { + return; + } + + let adapter = self + .adapter + .get_or_init(|| Arc::new(DiscordAdapter::new(ctx.http.clone()))) + .clone(); + + let channel_id = msg.channel_id.get(); + let in_allowed_channel = + self.allow_all_channels || self.allowed_channels.contains(&channel_id); + + let is_mentioned = msg.mentions_user_id(bot_id) + || 
msg.content.contains(&format!("<@{}>", bot_id)) + || (!self.allowed_role_ids.is_empty() + && msg + .mention_roles + .iter() + .any(|r| self.allowed_role_ids.contains(&r.get()))); + + // Bot message gating (from upstream #321) + if msg.author.bot { + // Trusted bot admission override: when a bot listed in `trusted_bot_ids` + // explicitly @mentions this bot, bypass the entire `allow_bot_messages` + // mode check. This treats the trusted bot's @mention identically to a + // human @mention — the bot becomes involved in the thread and the message + // is dispatched regardless of the `allow_bot_messages` setting. + // + // Rationale: `trusted_bot_ids` expresses admin-level trust. A trusted bot + // that @mentions this bot is performing a deliberate handoff/coordination + // action, equivalent to a human pulling the bot into a conversation. + // + // Safety: requires both (1) explicit @mention AND (2) sender in + // trusted_bot_ids. Messages from trusted bots without @mention still + // follow normal gating. Empty trusted_bot_ids (default) disables this + // entirely — no behavioral change for existing deployments. 
+ let trusted_mention = is_mentioned + && !self.trusted_bot_ids.is_empty() + && self.trusted_bot_ids.contains(&msg.author.id.get()); + + if !trusted_mention { + match self.allow_bot_messages { + AllowBots::Off => return, + AllowBots::Mentions => { + if !is_mentioned { + return; + } + } + AllowBots::All => { + let cap = MAX_CONSECUTIVE_BOT_TURNS as usize; + let limit = std::cmp::min(MAX_CONSECUTIVE_BOT_TURNS, 100) as u8; + let history = ctx + .cache + .channel_messages(msg.channel_id) + .map(|msgs| { + let mut recent: Vec<_> = msgs + .iter() + .filter(|(mid, _)| **mid < msg.id) + .map(|(_, m)| m.clone()) + .collect(); + recent.sort_unstable_by_key(|m| std::cmp::Reverse(m.id)); + recent.truncate(cap); + recent + }) + .filter(|msgs| !msgs.is_empty()); + + let recent = if let Some(cached) = history { + cached + } else { + match msg + .channel_id + .messages( + &ctx.http, + serenity::builder::GetMessages::new() + .before(msg.id) + .limit(limit), + ) + .await + { + Ok(msgs) => msgs, + Err(e) => { + tracing::warn!(channel_id = %msg.channel_id, error = %e, "failed to fetch history for bot turn cap, rejecting (fail-closed)"); + return; + } + } + }; + + let consecutive_bot = recent + .iter() + .take_while(|m| m.author.bot && m.author.id != bot_id) + .count(); + if consecutive_bot >= cap { + tracing::warn!(channel_id = %msg.channel_id, cap, "bot turn cap reached, ignoring"); + return; + } + } + } + + if !self.trusted_bot_ids.is_empty() + && !self.trusted_bot_ids.contains(&msg.author.id.get()) + { + tracing::debug!(bot_id = %msg.author.id, "bot not in trusted_bot_ids, ignoring"); + return; + } + } + } + + // Thread detection: single to_channel() call for both allowed and + // non-allowed channels. Uses thread_metadata (not parent_id) to + // identify threads — see detect_thread() doc comments for rationale. 
+ let (in_thread, bot_owns_thread, thread_parent_id, is_dm) = match msg + .channel_id + .to_channel(&ctx.http) + .await + { + Ok(serenity::model::channel::Channel::Guild(gc)) => { + let parent = gc.parent_id.map(|id| id.get().to_string()); + let result = detect_thread( + gc.thread_metadata.is_some(), + gc.parent_id.map(|id| id.get()), + gc.owner_id.map(|id| id.get()), + bot_id.get(), + &self.allowed_channels, + self.allow_all_channels, + in_allowed_channel, + ); + tracing::debug!( + channel_id = %msg.channel_id, + parent_id = ?gc.parent_id, + owner_id = ?gc.owner_id, + has_thread_metadata = gc.thread_metadata.is_some(), + in_thread = result.0, + bot_owns = ?result.1, + "thread check" + ); + ( + result.0, + result.1.unwrap_or(false), + if result.0 { parent } else { None }, + false, + ) + } + Ok(serenity::model::channel::Channel::Private(_)) => { + tracing::debug!(channel_id = %msg.channel_id, "DM channel"); + (false, false, None, true) + } + Ok(other) => { + tracing::debug!(channel_id = %msg.channel_id, kind = ?other, "not a guild thread"); + (false, false, None, false) + } + Err(e) => { + tracing::debug!(channel_id = %msg.channel_id, error = %e, "to_channel failed"); + (false, false, None, false) + } + }; + + // DM gating: allow_dm must be true, otherwise reject + if is_dm && !self.allow_dm { + tracing::debug!(channel_id = %msg.channel_id, "DM rejected (allow_dm=false)"); + return; + } + + if !is_dm && !in_allowed_channel && !in_thread { + return; + } + + // User message gating (mirrors Slack's AllowUsers logic). + // Mentions: always require @mention, even in bot's own threads. + // Involved (default): skip @mention if the bot owns the thread + // (Option A) OR has previously posted in it (Option B). + // MultibotMentions: same as Involved, but if other bots are also + // in the thread, require @mention to avoid all bots responding. + // DMs are treated as implicit @mention (mirrors Slack behavior). 
+ if !is_mentioned && !is_dm { + match self.allow_user_messages { + AllowUsers::Mentions => return, + AllowUsers::Involved => { + if !in_thread { + return; + } + let (involved, _) = if bot_owns_thread { + (true, false) // other_bot_present not needed for Involved mode + } else { + self.bot_participated_in_thread(&ctx.http, msg.channel_id, bot_id) + .await + }; + if !involved { + tracing::debug!(channel_id = %msg.channel_id, "bot not involved in thread, ignoring"); + return; + } + } + AllowUsers::MultibotMentions => { + if !in_thread { + return; + } + let (involved, other_bot) = if bot_owns_thread { + // Still need to check for other bots + let (_, other) = self + .bot_participated_in_thread(&ctx.http, msg.channel_id, bot_id) + .await; + (true, other) + } else { + self.bot_participated_in_thread(&ctx.http, msg.channel_id, bot_id) + .await + }; + if !involved { + tracing::debug!(channel_id = %msg.channel_id, "bot not involved in thread, ignoring"); + return; + } + if other_bot { + tracing::debug!(channel_id = %msg.channel_id, "multi-bot thread, requiring @mention"); + return; + } + } + } + } + + if is_denied_user( + msg.author.bot, + self.allow_all_users, + &self.allowed_users, + msg.author.id.get(), + ) { + tracing::info!(user_id = %msg.author.id, "denied user, ignoring"); + let msg_ref = discord_msg_ref(&msg); + let _ = adapter.add_reaction(&msg_ref, "🚫").await; + return; + } + + let prompt = resolve_mentions(&msg.content, bot_id, &self.allowed_role_ids); + + // No text and no attachments → skip + if prompt.is_empty() && msg.attachments.is_empty() { + return; + } + + let display_name = msg + .member + .as_ref() + .and_then(|m| m.nick.as_ref()) + .or(msg.author.global_name.as_ref()) + .unwrap_or(&msg.author.name); + let sender = build_sender_context( + &msg.author.id.to_string(), + &msg.author.name, + display_name, + &msg.channel_id.to_string(), + thread_parent_id.as_deref(), + msg.author.bot, + &msg.timestamp.to_rfc3339().unwrap_or_default(), + &msg.id.to_string(), 
+ &bot_id.to_string(), + ); + + // Build extra content blocks from attachments (audio -> STT, text -> inline, + // image -> encode, video -> URL for agent-side inspection). + let mut extra_blocks = Vec::new(); + let mut echo_entries: Vec = Vec::new(); + let mut failed_image_files: Vec = Vec::new(); + let mut text_file_bytes: u64 = 0; + let mut text_file_count: u32 = 0; + const TEXT_TOTAL_CAP: u64 = 1024 * 1024; // 1 MB total for all text file attachments + const TEXT_FILE_COUNT_CAP: u32 = 5; + + for attachment in &msg.attachments { + let mime = attachment.content_type.as_deref().unwrap_or(""); + if media::is_audio_mime(mime) { + if self.stt_config.enabled { + let mime_clean = mime.split(';').next().unwrap_or(mime).trim(); + match media::download_and_transcribe( + &attachment.url, + &attachment.filename, + mime_clean, + u64::from(attachment.size), + &self.stt_config, + None, + ) + .await + { + Some(transcript) => { + debug!(filename = %attachment.filename, chars = transcript.len(), "voice transcript injected"); + extra_blocks.insert( + 0, + ContentBlock::Text { + text: format!("[Voice message transcript]: {transcript}"), + }, + ); + echo_entries.push(crate::stt::EchoEntry::Success(transcript)); + } + None => { + warn!(filename = %attachment.filename, "STT failed for voice attachment"); + echo_entries.push(crate::stt::EchoEntry::Failed); + } + } + } else { + tracing::warn!(filename = %attachment.filename, "skipping audio attachment (STT disabled)"); + let msg_ref = discord_msg_ref(&msg); + let _ = adapter.add_reaction(&msg_ref, "🎤").await; + } + } else if media::is_text_file(&attachment.filename, attachment.content_type.as_deref()) + { + if text_file_count >= TEXT_FILE_COUNT_CAP { + tracing::warn!(filename = %attachment.filename, count = text_file_count, "text file count cap reached, skipping"); + continue; + } + // Pre-check with Discord-reported size (fast path, avoids unnecessary download). + // Running total uses actual downloaded bytes for accurate accounting. 
+ if text_file_bytes + u64::from(attachment.size) > TEXT_TOTAL_CAP { + tracing::warn!(filename = %attachment.filename, total = text_file_bytes, "text attachments total exceeds 1MB cap, skipping remaining"); + continue; + } + if let Some((block, actual_bytes)) = media::download_and_read_text_file( + &attachment.url, + &attachment.filename, + u64::from(attachment.size), + None, + ) + .await + { + text_file_bytes += actual_bytes; + text_file_count += 1; + debug!(filename = %attachment.filename, "adding text file attachment"); + extra_blocks.push(block); + } + } else { + match media::download_and_encode_image( + &attachment.url, + attachment.content_type.as_deref(), + &attachment.filename, + u64::from(attachment.size), + None, + ) + .await + { + Ok(block) => { + debug!(url = %attachment.url, filename = %attachment.filename, "adding image attachment"); + extra_blocks.push(block); + } + Err(media::MediaFetchError::NotAnImage) => { + if media::is_video_file( + &attachment.filename, + attachment.content_type.as_deref(), + ) { + debug!(url = %attachment.url, filename = %attachment.filename, "adding video attachment link"); + extra_blocks.push(video_attachment_block( + &attachment.filename, + attachment.content_type.as_deref(), + u64::from(attachment.size), + &attachment.url, + )); + } + } + Err(e) => { + tracing::warn!( + url = %attachment.url, + filename = %attachment.filename, + error = %e, + "image attachment failed" + ); + failed_image_files.push(attachment.filename.clone()); + } + } + } + } + + tracing::debug!( + num_extra_blocks = extra_blocks.len(), + num_attachments = msg.attachments.len(), + in_thread, + "processing" + ); + + let thread_channel = if in_thread || is_dm { + // DMs use the DM channel directly (no threads in DMs). 
+ ChannelRef { + platform: "discord".into(), + channel_id: msg.channel_id.get().to_string(), + thread_id: None, + parent_id: thread_parent_id.clone(), + origin_event_id: None, + } + } else { + match get_or_create_thread(&ctx, &adapter, &msg, &prompt).await { + Ok(ch) => ch, + Err(e) => { + error!("failed to create thread: {e}"); + return; + } + } + }; + + // Notify user if any images couldn't be processed. + if !failed_image_files.is_empty() { + let file_list = failed_image_files + .iter() + .map(|n| format!("`{}`", n.replace('`', "'"))) + .collect::>() + .join(", "); + let warn_msg = format!( + ":warning: I couldn't process the image(s) you shared ({}). \ + The files may be inaccessible or in an unsupported format (PNG/JPEG/GIF/WebP only).", + file_list + ); + if let Err(e) = adapter.send_message(&thread_channel, &warn_msg).await { + tracing::warn!(error = %e, "failed to send image warning to user"); + } + } + + let trigger_msg = discord_msg_ref(&msg); + + // Per-thread streaming: check if another bot is present in this thread + let other_bot_present_flag = { + let cache = self.multibot_threads.lock().await; + cache.contains_key(&msg.channel_id.to_string()) + } || self.multibot_cache.is_multibot(&msg.channel_id.to_string()); + + // Backfill thread_id: when OAB just created a new thread, the sender + // was built before the thread existed. Patch it so the agent sees + // thread_id on the very first turn. + let mut sender = sender; + if sender.thread_id.is_none() && thread_channel.parent_id.is_some() { + sender.thread_id = Some(thread_channel.channel_id.clone()); + } + + let dispatcher = self.dispatcher.clone(); + let stt_cfg = self.stt_config.clone(); + + tokio::spawn(async move { + // Best-effort echo before the agent reply so the user can verify STT. 
+ crate::stt::post_echo( + &adapter, + &thread_channel, + &trigger_msg, + &echo_entries, + &stt_cfg, + ) + .await; + + let sender_id = sender.sender_id.clone(); + let sender_name = sender.sender_name.clone(); + let sender_json = serde_json::to_string(&sender).unwrap(); + let thread_key = dispatcher.key("discord", &thread_channel.channel_id, &sender_id); + let estimated_tokens = crate::dispatch::estimate_tokens(&prompt, &extra_blocks); + let buf_msg = crate::dispatch::BufferedMessage { + sender_json, + sender_name, + prompt, + extra_blocks, + trigger_msg, + arrived_at: std::time::Instant::now(), + estimated_tokens, + other_bot_present: other_bot_present_flag, + recipient: None, // Slack-only (assistant mode); N/A for Discord + }; + if let Err(e) = dispatcher + .submit(thread_key, thread_channel, adapter, buf_msg) + .await + { + error!("dispatcher submit error: {e}"); + } + }); + } + + async fn ready(&self, ctx: Context, ready: Ready) { + info!(user = %ready.user.name, "discord bot connected"); + + // Build the shared command list once. + let commands = vec![ + CreateCommand::new("models").description("Select the AI model for this session"), + CreateCommand::new("agents").description("Select the agent mode for this session"), + CreateCommand::new("cancel").description("Cancel the current operation"), + CreateCommand::new("cancel-all") + .description("Cancel current operation and drop all buffered messages"), + CreateCommand::new("reset").description("Reset the conversation session"), + CreateCommand::new("remind") + .description("Set a one-shot reminder to mention users/roles after a delay") + .add_option(CreateCommandOption::new( + CommandOptionType::String, + "targets", + "Users/roles to mention (e.g. 
@user1 @role1)", + ).required(true)) + .add_option(CreateCommandOption::new( + CommandOptionType::String, + "message", + "Reminder message", + ).required(true)) + .add_option(CreateCommandOption::new( + CommandOptionType::String, + "delay", + "Delay before firing (e.g. 30m, 2h, 1d)", + ).required(true)), + CreateCommand::new("export-thread") + .description("Download this thread as a text file") + .add_option(CreateCommandOption::new( + CommandOptionType::Integer, + "limit", + "Export only the most recent N messages (1–5000)", + )) + .add_option(CreateCommandOption::new( + CommandOptionType::String, + "since", + "Export messages after this message ID", + )) + .add_option(CreateCommandOption::new( + CommandOptionType::Integer, + "days", + "Export messages from the last N days (1–365)", + )) + .add_option(CreateCommandOption::new( + CommandOptionType::Boolean, + "all", + "Export all messages (up to 5000). Default is last 100.", + )), + ]; + + // Register global commands only. Registering the same commands per-guild + // makes Discord show duplicate slash commands in guild command pickers. + if let Err(e) = Command::set_global_commands(&ctx.http, commands.clone()).await { + tracing::warn!(error = %e, "failed to register global slash commands"); + } else { + info!("registered global slash commands"); + } + + // One-time migration cleanup: older versions registered the same + // slash commands per-guild, and Discord persists those server-side. + // Keep guild command sets empty so only global commands are shown. + for guild in &ready.guilds { + let guild_id = guild.id; + if let Err(e) = guild_id.set_commands(&ctx.http, Vec::new()).await { + tracing::warn!( + %guild_id, + error = %e, + "failed to clear stale guild slash commands" + ); + } + } + + // Re-schedule any pending reminders that survived a restart. 
+ let pending = self.reminder_store.pending().await; + if !pending.is_empty() { + let mut scheduled = self.scheduled_ids.lock().await; + let mut count = 0; + for r in pending { + if scheduled.insert(r.id.clone()) { + remind::schedule_reminder(ctx.http.clone(), self.reminder_store.clone(), r); + count += 1; + } + } + if count > 0 { + info!(count, "re-scheduled pending reminders"); + } + } + } + + async fn interaction_create(&self, ctx: Context, interaction: Interaction) { + match interaction { + Interaction::Command(cmd) if cmd.data.name == "models" => { + self.handle_config_command(&ctx, &cmd, "model", "model") + .await; + } + Interaction::Command(cmd) if cmd.data.name == "agents" => { + self.handle_config_command(&ctx, &cmd, "agent", "agent") + .await; + } + Interaction::Command(cmd) if cmd.data.name == "cancel" => { + self.handle_cancel_command(&ctx, &cmd).await; + } + Interaction::Command(cmd) if cmd.data.name == "cancel-all" => { + self.handle_cancel_all_command(&ctx, &cmd).await; + } + Interaction::Command(cmd) if cmd.data.name == "reset" => { + self.handle_reset_command(&ctx, &cmd).await; + } + Interaction::Command(cmd) if cmd.data.name == "remind" => { + self.handle_remind_command(&ctx, &cmd).await; + } + Interaction::Command(cmd) if cmd.data.name == "export-thread" => { + self.handle_export_thread_command(&ctx, &cmd).await; + } + Interaction::Component(comp) if comp.data.custom_id.starts_with("acp_config_") => { + self.handle_config_select(&ctx, &comp).await; + } + Interaction::Component(comp) if comp.data.custom_id.starts_with("acp_pg:") => { + self.handle_pagination(&ctx, &comp).await; + } + _ => {} + } + } +} + +// --- Slash command & interaction handlers --- + +impl Handler { + /// Build a Discord select menu from ACP configOptions with the given category. + /// Paginates options in pages of 25 (Discord limit). The current selection is + /// always placed first so it appears on page 0. 
+ fn build_config_select( + options: &[ConfigOption], + category: &str, + page: usize, + ) -> Option { + let opt = options + .iter() + .find(|o| o.category.as_deref() == Some(category))?; + + // Put current selection first so it always lands on page 0, + // then fill remaining slots in original order. + let sorted: Vec<_> = opt + .options + .iter() + .filter(|o| o.value == opt.current_value) + .chain(opt.options.iter().filter(|o| o.value != opt.current_value)) + .collect(); + + let menu_options: Vec = sorted + .iter() + .skip(page * SELECT_MENU_PAGE_SIZE) + .take(SELECT_MENU_PAGE_SIZE) + .map(|o| { + let mut item = CreateSelectMenuOption::new(&o.name, &o.value); + if let Some(desc) = &o.description { + item = item.description(desc); + } + if o.value == opt.current_value { + item = item.default_selection(true); + } + item + }) + .collect(); + + if menu_options.is_empty() { + return None; + } + + let current_name = opt + .options + .iter() + .find(|o| o.value == opt.current_value) + .map(|o| o.name.as_str()) + .unwrap_or(&opt.current_value); + let total_pages = sorted.len().div_ceil(SELECT_MENU_PAGE_SIZE); + let placeholder = if total_pages > 1 { + format!( + "Current: {} (page {}/{})", + current_name, + page + 1, + total_pages + ) + } else { + format!("Current: {}", current_name) + }; + + Some( + CreateSelectMenu::new( + format!("acp_config_{}", opt.id), + CreateSelectMenuKind::String { + options: menu_options, + }, + ) + .placeholder(placeholder), + ) + } + + /// Build ◀/▶ pagination buttons. Returns None when only one page exists. 
+ fn build_pagination_buttons( + category: &str, + page: usize, + total_pages: usize, + ) -> Option { + if total_pages <= 1 { + return None; + } + let prev = CreateButton::new(format!("acp_pg:{}:{}", category, page.saturating_sub(1))) + .label("◀") + .style(ButtonStyle::Secondary) + .disabled(page == 0); + let next = CreateButton::new(format!("acp_pg:{}:{}", category, page + 1)) + .label("▶") + .style(ButtonStyle::Secondary) + .disabled(page + 1 >= total_pages); + let indicator = CreateButton::new("acp_pg_noop") + .label(format!("{}/{}", page + 1, total_pages)) + .style(ButtonStyle::Secondary) + .disabled(true); + Some(CreateActionRow::Buttons(vec![prev, indicator, next])) + } + + /// Build the full component rows (select menu + optional pagination) for a config category. + /// When `page` is `None`, auto-selects the page containing the current value. + fn build_config_components( + options: &[ConfigOption], + category: &str, + page: Option, + ) -> Option> { + let opt = options + .iter() + .find(|o| o.category.as_deref() == Some(category))?; + let total_pages = opt.options.len().div_ceil(SELECT_MENU_PAGE_SIZE); + let page = match page { + Some(p) => p.min(total_pages.saturating_sub(1)), + None => opt + .options + .iter() + .position(|o| o.value == opt.current_value) + .map(|i| i / SELECT_MENU_PAGE_SIZE) + .unwrap_or(0), + }; + + let select = Self::build_config_select(options, category, page)?; + let mut rows = vec![CreateActionRow::SelectMenu(select)]; + if let Some(buttons) = Self::build_pagination_buttons(category, page, total_pages) { + rows.push(buttons); + } + Some(rows) + } + + async fn handle_config_command( + &self, + ctx: &Context, + cmd: &serenity::model::application::CommandInteraction, + category: &str, + label: &str, + ) { + let thread_key = format!("discord:{}", cmd.channel_id.get()); + let config_options = self.router.pool().get_config_options(&thread_key).await; + + let response = match Self::build_config_components(&config_options, category, None) 
{ + Some(rows) => CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!("🔧 Select a {label}:")) + .components(rows) + .ephemeral(true), + ), + None => CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!("⚠️ No {label} options available. Start a conversation first by @mentioning the bot.")) + .ephemeral(true), + ), + }; + + if let Err(e) = cmd.create_response(&ctx.http, response).await { + tracing::error!(error = %e, category, "failed to respond to slash command"); + } + } + + async fn handle_cancel_command( + &self, + ctx: &Context, + cmd: &serenity::model::application::CommandInteraction, + ) { + let thread_key = format!("discord:{}", cmd.channel_id.get()); + let result = self.router.pool().cancel_session(&thread_key).await; + + let msg = match result { + Ok(()) => "🛑 Cancel signal sent.".to_string(), + Err(e) => format!("⚠️ {e}"), + }; + + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(msg) + .ephemeral(true), + ); + if let Err(e) = cmd.create_response(&ctx.http, response).await { + tracing::error!(error = %e, "failed to respond to /cancel command"); + } + } + + async fn handle_cancel_all_command( + &self, + ctx: &Context, + cmd: &serenity::model::application::CommandInteraction, + ) { + // /cancel-all is the nuclear escape hatch: stop the in-flight turn AND clear + // every lane's buffer in this thread, so a human can intervene from a clean slate. + let session_key = format!("discord:{}", cmd.channel_id.get()); + let dropped = self + .dispatcher + .cancel_buffered_thread("discord", &cmd.channel_id.get().to_string()); + + let cancel_result = self.router.pool().cancel_session(&session_key).await; + + // Buffer count is approximate (sweep races with new arrivals) so we surface + // a binary "cleared / nothing" signal rather than a misleading exact number. 
+ let msg = match (cancel_result, dropped) { + (Ok(()), 0) => "🛑 Cancel signal sent.".to_string(), + (Ok(()), _) => "🛑 Cancel signal sent. Buffered messages cleared.".to_string(), + (Err(_), 0) => { + "⚠️ Nothing to cancel — no active session and no buffered messages.".to_string() + } + (Err(_), _) => "🛑 Buffered messages cleared. No active session to cancel.".to_string(), + }; + + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(msg) + .ephemeral(true), + ); + if let Err(e) = cmd.create_response(&ctx.http, response).await { + tracing::error!(error = %e, "failed to respond to /cancel-all command"); + } + } + + async fn handle_reset_command( + &self, + ctx: &Context, + cmd: &serenity::model::application::CommandInteraction, + ) { + // /reset clears every lane's buffer in this thread and tears down the shared + // ACP session — the next message in the thread starts a fresh conversation. + let session_key = format!("discord:{}", cmd.channel_id.get()); + let dropped = self + .dispatcher + .cancel_buffered_thread("discord", &cmd.channel_id.get().to_string()); + + let result = self.router.pool().reset_session(&session_key).await; + + let msg = match result { + Ok(()) if dropped > 0 => { + format!("🔄 Session reset. Dropped {dropped} buffered message(s). Start a new conversation!") + } + Ok(()) => "🔄 Session reset. Start a new conversation!".to_string(), + Err(_) if dropped > 0 => { + format!("🔄 Dropped {dropped} buffered message(s). No active session to reset.") + } + Err(_) => { + "⚠️ No active session to reset. Start a conversation first by @mentioning the bot." 
+ .to_string() + } + }; + + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(msg) + .ephemeral(true), + ); + if let Err(e) = cmd.create_response(&ctx.http, response).await { + tracing::error!(error = %e, "failed to respond to /reset command"); + } + } + + async fn handle_remind_command( + &self, + ctx: &Context, + cmd: &serenity::model::application::CommandInteraction, + ) { + // Only humans can use /remind + if cmd.user.bot { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ Only humans can set reminders.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + // Extract options + let opts = &cmd.data.options; + let targets_raw = opts.iter() + .find(|o| o.name == "targets") + .and_then(|o| o.value.as_str()) + .unwrap_or(""); + let message = opts.iter() + .find(|o| o.name == "message") + .and_then(|o| o.value.as_str()) + .unwrap_or(""); + let delay_raw = opts.iter() + .find(|o| o.name == "delay") + .and_then(|o| o.value.as_str()) + .unwrap_or(""); + + if targets_raw.is_empty() || message.is_empty() || delay_raw.is_empty() { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ All fields (targets, message, delay) are required.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + // Parse delay + let delay_secs = match remind::parse_delay(delay_raw) { + Ok(s) => s, + Err(e) => { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!("⚠️ Invalid delay: {e}")) + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + }; + + if let Err(e) = remind::validate_message(message) { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!("⚠️ {e}")) + 
.ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + // Strip @everyone / @here to prevent unintended mass pings. + let message = remind::sanitize_message(message); + + // Extract mention strings from targets (keep raw — Discord renders them) + let targets: Vec = targets_raw + .split_whitespace() + .filter(|t| t.starts_with("<@") && t.ends_with('>')) + .map(|t| t.to_string()) + .collect(); + + if targets.is_empty() { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ No valid mentions found in targets. Use @user or @role.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + if targets.len() > remind::MAX_TARGETS { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!("⚠️ Too many targets (max {}). Use a @role instead.", remind::MAX_TARGETS)) + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + // F4: Per-user rate limit (max 5 active reminders) + let user_id = cmd.user.id.get(); + let pending = self.reminder_store.pending().await; + let user_count = pending.iter().filter(|r| r.sender_id == user_id).count(); + if user_count >= 5 { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ You already have 5 active reminders. 
Wait for some to fire before adding more.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + let fire_at = chrono::Utc::now() + chrono::Duration::seconds(delay_secs as i64); + let reminder = remind::Reminder { + id: uuid::Uuid::new_v4().to_string(), + channel_id: cmd.channel_id.get(), + sender_id: cmd.user.id.get(), + targets: targets.clone(), + message: message.clone(), + fire_at, + created_at: chrono::Utc::now(), + }; + + // Persist and schedule + self.reminder_store.add(reminder.clone()).await; + self.scheduled_ids.lock().await.insert(reminder.id.clone()); + remind::schedule_reminder(ctx.http.clone(), self.reminder_store.clone(), reminder); + + let delay_str = remind::format_delay(delay_secs); + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!( + "⏰ Reminder set! Will fire in **{delay_str}** and mention {}", + targets.join(" ") + )) + .ephemeral(true), + ); + if let Err(e) = cmd.create_response(&ctx.http, response).await { + tracing::error!(error = %e, "failed to respond to /remind command"); + } + } + + async fn handle_export_thread_command( + &self, + ctx: &Context, + cmd: &serenity::model::application::CommandInteraction, + ) { + if is_denied_user( + false, + self.allow_all_users, + &self.allowed_users, + cmd.user.id.get(), + ) { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("🚫 You are not allowed to use this bot.") + .ephemeral(true), + ); + if let Err(e) = cmd.create_response(&ctx.http, response).await { + tracing::error!(error = %e, "failed to deny /export-thread command"); + } + return; + } + + let channel_id = cmd.channel_id; + let (export_allowed, export_name) = match channel_id.to_channel(&ctx.http).await { + Ok(serenity::model::channel::Channel::Guild(gc)) => { + let in_allowed_channel = + self.allow_all_channels || self.allowed_channels.contains(&channel_id.get()); + let 
(in_thread, _) = detect_thread( + gc.thread_metadata.is_some(), + gc.parent_id.map(|id| id.get()), + gc.owner_id.map(|id| id.get()), + ctx.cache.current_user().id.get(), + &self.allowed_channels, + self.allow_all_channels, + in_allowed_channel, + ); + (in_thread, gc.name.clone()) + } + Ok(serenity::model::channel::Channel::Private(_)) => { + (self.allow_dm, "dm".to_string()) + } + Ok(_) => (false, "channel".to_string()), + Err(e) => { + tracing::warn!(channel_id = %channel_id, error = %e, "failed to inspect channel for export"); + (false, "channel".to_string()) + } + }; + + if !export_allowed { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ Run this command inside an allowed Discord thread or DM.") + .ephemeral(true), + ); + if let Err(e) = cmd.create_response(&ctx.http, response).await { + tracing::error!(error = %e, "failed to respond to /export-thread rejection"); + } + return; + } + + // --- Parse and validate filter params (mutual exclusion) --- + let opts = &cmd.data.options; + let limit_opt = opts.iter().find(|o| o.name == "limit").and_then(|o| o.value.as_i64()); + let since_opt = opts.iter().find(|o| o.name == "since").and_then(|o| o.value.as_str()); + let days_opt = opts.iter().find(|o| o.name == "days").and_then(|o| o.value.as_i64()); + let all_opt = opts.iter().find(|o| o.name == "all").and_then(|o| o.value.as_bool()).unwrap_or(false); + + let filter_count = limit_opt.is_some() as u8 + since_opt.is_some() as u8 + days_opt.is_some() as u8 + all_opt as u8; + if filter_count > 1 { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ Please specify only one filter: `limit`, `since`, `days`, or `all`.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + let filter = if all_opt { + ExportFilter::All + } else if let Some(n) = limit_opt { + if !(1..=5000).contains(&n) { + let response = 
CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ `limit` must be between 1 and 5000.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + ExportFilter::Limit(n as usize) + } else if let Some(id_str) = since_opt { + match id_str.parse::() { + Ok(id) if id > 0 => ExportFilter::After(MessageId::new(id)), + _ => { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ `since` must be a valid message ID (right-click a message → Copy Message ID).") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + } + } else if let Some(d) = days_opt { + if !(1..=365).contains(&d) { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ `days` must be between 1 and 365.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + let since_ts = chrono::Utc::now() - chrono::Duration::days(d); + let ts_ms = since_ts.timestamp_millis() as u64; + ExportFilter::After(timestamp_ms_to_snowflake(ts_ms)) + } else { + // Default: export last 100 messages (use limit:N or all:true for more) + ExportFilter::Limit(100) + }; + + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("Preparing thread export...") + .ephemeral(true), + ); + if let Err(e) = cmd.create_response(&ctx.http, response).await { + tracing::error!(error = %e, "failed to acknowledge /export-thread command"); + return; + } + + match export_channel_messages( + &ctx.http, + channel_id, + &export_name, + cmd.attachment_size_limit, + filter, + ) + .await + { + Ok(result) => { + let mut content = format!("Exported {} messages.", result.written); + if result.hit_cap { + content.push_str(&format!( + " Only the most recent {} messages were fetched — older messages were not included.", + result.fetched + 
)); + } + if result.byte_truncated { + content.push_str(&format!( + " Transcript truncated to fit Discord's attachment size limit ({} of {} fetched messages included).", + result.written, result.fetched + )); + } + let attachment = + CreateAttachment::bytes(result.transcript.into_bytes(), result.filename); + let followup = CreateInteractionResponseFollowup::new() + .content(content) + .add_file(attachment) + .ephemeral(true); + if let Err(e) = cmd.create_followup(&ctx.http, followup).await { + tracing::error!(error = %e, "failed to send /export-thread attachment"); + } + } + Err(e) => { + tracing::warn!(channel_id = %channel_id, error = %e, "failed to export thread"); + let followup = CreateInteractionResponseFollowup::new() + .content(format!("⚠️ Failed to export thread: {e}")) + .ephemeral(true); + if let Err(e) = cmd.create_followup(&ctx.http, followup).await { + tracing::error!(error = %e, "failed to send /export-thread error"); + } + } + } + } + + async fn handle_config_select( + &self, + ctx: &Context, + comp: &serenity::model::application::ComponentInteraction, + ) { + let config_id = comp + .data + .custom_id + .strip_prefix("acp_config_") + .unwrap_or("") + .to_string(); + + if config_id.is_empty() { + return; + } + + let selected_value = match &comp.data.kind { + ComponentInteractionDataKind::StringSelect { values } => match values.first() { + Some(v) => v.clone(), + None => return, + }, + _ => return, + }; + + let thread_key = format!("discord:{}", comp.channel_id.get()); + + let result = self + .router + .pool() + .set_config_option(&thread_key, &config_id, &selected_value) + .await; + + let response_msg = match result { + Ok(updated_options) => { + let display_name = updated_options + .iter() + .find(|o| o.id == config_id) + .and_then(|o| o.options.iter().find(|v| v.value == selected_value)) + .map(|v| v.name.as_str()) + .unwrap_or(&selected_value); + format!("✅ Switched to **{}**", display_name) + } + Err(e) => { + tracing::error!(error = %e, "failed 
to set config option"); + format!("❌ Failed to switch: {}", e) + } + }; + + let response = CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new() + .content(response_msg) + .components(vec![]), + ); + + if let Err(e) = comp.create_response(&ctx.http, response).await { + tracing::error!(error = %e, "failed to respond to config select"); + } + } + + async fn handle_pagination( + &self, + ctx: &Context, + comp: &serenity::model::application::ComponentInteraction, + ) { + // Parse custom_id format: acp_pg:{category}:{page} + let parts: Vec<&str> = comp.data.custom_id.splitn(3, ':').collect(); + let (category, page) = match parts.as_slice() { + [_, cat, pg] => match pg.parse::() { + Ok(p) => (*cat, p), + Err(_) => return, + }, + _ => return, + }; + + // Only allow known config categories. + if !matches!(category, "model" | "agent") { + return; + } + + let thread_key = format!("discord:{}", comp.channel_id.get()); + let config_options = self.router.pool().get_config_options(&thread_key).await; + + let response = match Self::build_config_components(&config_options, category, Some(page)) { + Some(rows) => CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new() + .content(format!("🔧 Select a {category}:")) + .components(rows), + ), + None => CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new() + .content(format!("⚠️ No {category} options available.")) + .components(vec![]), + ), + }; + + if let Err(e) = comp.create_response(&ctx.http, response).await { + tracing::error!(error = %e, category, "failed to respond to pagination"); + } + } +} + +// --- Discord-specific helpers --- + +fn discord_msg_ref(msg: &Message) -> MessageRef { + MessageRef { + channel: ChannelRef { + platform: "discord".into(), + channel_id: msg.channel_id.get().to_string(), + thread_id: None, + parent_id: None, + origin_event_id: None, + }, + message_id: msg.id.to_string(), + } +} + +struct ExportResult { + filename: 
String, + transcript: String, + /// Messages successfully pulled from Discord. + fetched: usize, + /// Messages that fit in the transcript (≤ `fetched`; differs when the + /// attachment-size limit truncates). + written: usize, + /// We stopped fetching because we hit the message cap and the thread still + /// has more messages we did not include. + hit_cap: bool, + /// Transcript was cut to keep the attachment under Discord's size limit. + byte_truncated: bool, +} + +/// Filter mode for export_channel_messages. +enum ExportFilter { + /// Fetch all messages (newest-first via `before`), capped at THREAD_EXPORT_MESSAGE_LIMIT. + All, + /// Fetch the most recent N messages (newest-first via `before`). + Limit(usize), + /// Fetch messages after a synthetic snowflake (newest-first via `before`, with boundary filtering). + After(MessageId), +} + +/// Discord epoch: 2015-01-01T00:00:00Z in milliseconds. +const DISCORD_EPOCH_MS: u64 = 1_420_070_400_000; + +/// Convert a UTC timestamp (in milliseconds since Unix epoch) to a synthetic +/// Discord snowflake suitable for use as an `after` cursor. +fn timestamp_ms_to_snowflake(timestamp_ms: u64) -> MessageId { + let discord_ms = timestamp_ms.saturating_sub(DISCORD_EPOCH_MS); + // Snowflake IDs use NonZeroU64 in serenity; ensure at least 1. + MessageId::new((discord_ms << 22).max(1)) +} + +async fn export_channel_messages( + http: &Http, + channel_id: ChannelId, + channel_name: &str, + attachment_size_limit: u32, + filter: ExportFilter, +) -> anyhow::Result { + let cap = match &filter { + ExportFilter::Limit(n) => *n, + _ => THREAD_EXPORT_MESSAGE_LIMIT, + }; + + let mut messages = Vec::new(); + let mut hit_cap = false; + + match &filter { + ExportFilter::All | ExportFilter::Limit(_) => { + // Fetch newest-first using `before` pagination, then reverse. 
+ let mut before = None; + loop { + if messages.len() >= cap { + hit_cap = true; + break; + } + let remaining = cap - messages.len(); + let limit = remaining.min(100) as u8; + let mut request = GetMessages::new().limit(limit); + if let Some(before_id) = before { + request = request.before(before_id); + } + let batch = channel_id.messages(http, request).await?; + if batch.is_empty() { + break; + } + before = batch.last().map(|m| m.id); + let batch_len = batch.len(); + messages.extend(batch); + if batch_len < limit as usize { + break; + } + } + // Probe to confirm we actually left messages behind. + if hit_cap { + let probe = GetMessages::new().limit(1); + let probe = if let Some(before_id) = before { + probe.before(before_id) + } else { + probe + }; + if matches!(channel_id.messages(http, probe).await, Ok(b) if b.is_empty()) { + hit_cap = false; + } + } + messages.reverse(); + } + ExportFilter::After(after_id) => { + // Fetch newest-first using `before` pagination, stop when we hit + // messages at or before the filter boundary. This ensures that when + // the cap is reached, we keep the *newest* messages in the window. + let mut before = None; + loop { + if messages.len() >= cap { + hit_cap = true; + break; + } + let remaining = cap - messages.len(); + let limit = remaining.min(100) as u8; + let mut request = GetMessages::new().limit(limit); + if let Some(before_id) = before { + request = request.before(before_id); + } + let batch = channel_id.messages(http, request).await?; + if batch.is_empty() { + break; + } + before = batch.last().map(|m| m.id); + let batch_len = batch.len(); + // Filter out messages at or before the boundary. + let filtered: Vec<_> = batch.into_iter().filter(|m| m.id > *after_id).collect(); + let hit_boundary = filtered.len() < batch_len; + messages.extend(filtered); + if hit_boundary { + // We've reached the time boundary; no need to fetch older. 
+ break; + } + if batch_len < limit as usize { + break; + } + } + // Probe only if we stopped due to cap (not boundary). + if hit_cap { + let probe = GetMessages::new().limit(1); + let probe = if let Some(before_id) = before { + probe.before(before_id) + } else { + probe + }; + if let Ok(batch) = channel_id.messages(http, probe).await { + // If the next message is beyond our filter boundary, + // we didn't actually leave relevant messages behind. + let has_more_in_window = batch.iter().any(|m| m.id > *after_id); + if !has_more_in_window { + hit_cap = false; + } + } + } + messages.reverse(); + } + } + + let filename = export_filename(channel_id, channel_name); + if attachment_size_limit < 2048 { + tracing::warn!(attachment_size_limit, "attachment_size_limit is very small; export will likely be truncated"); + } + let max_bytes = usize::try_from(attachment_size_limit) + .unwrap_or(8 * 1024 * 1024) + .saturating_sub(1024) + .max(1024); + let (transcript, written, byte_truncated) = + format_thread_export(channel_id, channel_name, &messages, max_bytes); + let fetched = messages.len(); + + Ok(ExportResult { + filename, + transcript, + fetched, + written, + hit_cap, + byte_truncated, + }) +} + +fn format_thread_export( + channel_id: ChannelId, + channel_name: &str, + messages: &[Message], + max_bytes: usize, +) -> (String, usize, bool) { + let header = format!( + "Discord thread export\nChannel: {channel_name} ({channel_id})\nMessages: {}\n\n", + messages.len() + ); + let entries: Vec = messages.iter().map(format_export_message).collect(); + assemble_export(&header, &entries, max_bytes) +} + +/// Build the transcript body from a pre-rendered header and a list of +/// already-formatted message entries, honouring `max_bytes`. +/// +/// Returns `(transcript, written, truncated)` where `written` is the number of +/// entries actually included. 
Split out from `format_thread_export` so the +/// truncation boundary logic can be unit-tested without constructing real +/// `serenity::model::channel::Message` values. +fn assemble_export(header: &str, entries: &[String], max_bytes: usize) -> (String, usize, bool) { + let mut out = String::from(header); + let mut written = 0; + let mut truncated = false; + + for entry in entries { + if out.len() + entry.len() > max_bytes { + truncated = true; + break; + } + out.push_str(entry); + written += 1; + } + + if truncated { + let note = "\n[Export truncated to fit Discord attachment size limit]\n"; + let room = max_bytes.saturating_sub(out.len()); + if room >= note.len() { + out.push_str(note); + } + } + + (out, written, truncated) +} + +fn format_export_message(msg: &Message) -> String { + let bot_marker = if msg.author.bot { " [bot]" } else { "" }; + let mut out = format!( + "[{}] {}{} ({})\n", + msg.timestamp, + msg.author.name, + bot_marker, + msg.author.id + ); + + if msg.content.is_empty() { + out.push_str("(no text)\n"); + } else { + out.push_str(&msg.content); + out.push('\n'); + } + + for attachment in &msg.attachments { + let mime = attachment.content_type.as_deref().unwrap_or("unknown"); + out.push_str(&format!( + "[attachment] {} ({} bytes, {}): {}\n", + attachment.filename, attachment.size, mime, attachment.url + )); + } + + out.push('\n'); + out +} + +fn export_filename(channel_id: ChannelId, channel_name: &str) -> String { + let safe_name = sanitize_filename_component(channel_name); + format!("discord-thread-{safe_name}-{channel_id}.txt") +} + +/// Reduce a free-form Discord channel/thread name to a safe ASCII filename +/// fragment. +/// +/// Non-ASCII characters are dropped silently — a purely-Chinese thread name +/// like "扈三娘的房間" yields a date-based fallback (e.g. `"20260512"`). 
+/// The caller appends the channel ID, which already guarantees uniqueness, +/// and an ASCII fragment plays nicer with downstream tools (mail attachments, +/// S3 keys, browser save-as dialogs). The 64-byte cap leaves room for the +/// `discord-thread-` prefix and the channel-ID suffix within typical +/// filesystem limits. +fn sanitize_filename_component(input: &str) -> String { + let mut safe = String::with_capacity(input.len()); + for ch in input.chars() { + if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_') { + safe.push(ch); + } else if ch.is_whitespace() || matches!(ch, '.' | '/') { + safe.push('-'); + } + } + let safe = safe.trim_matches('-'); + if safe.is_empty() { + // Use current date as a human-friendly fallback when the thread name + // is entirely non-ASCII. + chrono::Utc::now().format("%Y%m%d").to_string() + } else { + safe.chars().take(64).collect() + } +} + +async fn get_or_create_thread( + ctx: &Context, + adapter: &Arc, + msg: &Message, + prompt: &str, +) -> anyhow::Result { + let channel = msg.channel_id.to_channel(&ctx.http).await?; + if let serenity::model::channel::Channel::Guild(ref gc) = channel { + // Already in a thread — reuse it. Uses thread_metadata (see detect_thread()). + if gc.thread_metadata.is_some() { + return Ok(ChannelRef { + platform: "discord".into(), + channel_id: msg.channel_id.get().to_string(), + thread_id: None, + parent_id: None, + origin_event_id: None, + }); + } + } + + let thread_name = format::shorten_thread_name(prompt); + let parent = ChannelRef { + platform: "discord".into(), + channel_id: msg.channel_id.get().to_string(), + thread_id: None, + parent_id: None, + origin_event_id: None, + }; + let trigger_ref = discord_msg_ref(msg); + match adapter + .create_thread(&parent, &trigger_ref, &thread_name) + .await + { + Ok(ch) => Ok(ch), + Err(e) if is_thread_already_exists_error(&e) => { + // Another bot won the race from the same trigger message. 
Discord + // only allows one thread per message, so refetch the message and + // join the thread our sibling just created. + let refreshed = msg + .channel_id + .message(&ctx.http, msg.id) + .await + .map_err(|fe| { + anyhow::anyhow!("thread_already_exists (race), but refetch failed: {fe}") + })?; + let existing = refreshed.thread.ok_or_else(|| { + anyhow::anyhow!( + "thread_already_exists (race), but message has no thread after refetch" + ) + })?; + tracing::info!( + channel_id = %msg.channel_id, + thread_id = %existing.id, + "joining thread created by sibling bot from same trigger message" + ); + Ok(ChannelRef { + platform: "discord".into(), + channel_id: existing.id.to_string(), + thread_id: None, + parent_id: Some(msg.channel_id.get().to_string()), + origin_event_id: None, + }) + } + Err(e) => Err(e), + } +} + +/// Detect Discord's "A thread has already been created for this message" error +/// (JSON error code 160004). Triggered when two bots responding to the same +/// @-mention race to create a thread from the same trigger message. +/// +/// Uses string matching because serenity surfaces Discord API errors as +/// formatted strings — there is no structured error code we can match on. +/// Unit tests pin the expected patterns so serenity formatting changes are caught. +fn is_thread_already_exists_error(err: &anyhow::Error) -> bool { + let msg = err.to_string(); + msg.contains("160004") || msg.contains("already been created") +} + +static ROLE_MENTION_RE: LazyLock = + LazyLock::new(|| regex::Regex::new(r"<@&\d+>").unwrap()); + +fn resolve_mentions(content: &str, bot_id: UserId, allowed_role_ids: &HashSet) -> String { + // 1. Strip the bot's own trigger mention + let out = content + .replace(&format!("<@{}>", bot_id), "") + .replace(&format!("<@!{}>", bot_id), ""); + // 2. 
Strip allowed role mentions (they triggered the bot, not useful in prompt) + let out = if allowed_role_ids.is_empty() { + out + } else { + allowed_role_ids + .iter() + .fold(out, |s, id| s.replace(&format!("<@&{}>", id), "")) + }; + // 3. Other user mentions: keep <@UID> as-is so the LLM can mention back + // 4. Fallback: replace remaining role mentions only (user mentions are preserved) + let out = ROLE_MENTION_RE.replace_all(&out, "@(role)").to_string(); + out.trim().to_string() +} + +fn video_attachment_block( + filename: &str, + content_type: Option<&str>, + size: u64, + url: &str, +) -> ContentBlock { + ContentBlock::Text { + text: format!( + "[Video attachment]\nfilename: {}\ncontent_type: {}\nsize_bytes: {}\nurl: {}", + filename, + content_type.unwrap_or("unknown"), + size, + url + ), + } +} + +/// Build a `SenderContext` for Discord messages. +/// +/// Pure function extracted from `EventHandler::message` for testability. +/// When `thread_parent_id` is `Some`, the message is inside a thread: +/// - `channel_id` → parent channel (where the thread lives) +/// - `thread_id` → thread's own channel ID +/// +/// This mirrors Slack's model where `channel_id` is always the parent +/// channel and `thread_id` (thread_ts) identifies the thread. +/// +/// Note: `ChannelRef.channel_id` uses the *opposite* convention — it holds +/// the thread's channel ID for routing (Discord API sends to thread by its +/// channel ID). See `ChannelRef` doc comments for details. 
+#[allow(clippy::too_many_arguments)] +fn build_sender_context( + sender_id: &str, + sender_name: &str, + display_name: &str, + msg_channel_id: &str, + thread_parent_id: Option<&str>, + is_bot: bool, + timestamp: &str, + message_id: &str, + receiver_id: &str, +) -> SenderContext { + SenderContext { + schema: "openab.sender.v1".into(), + sender_id: sender_id.to_string(), + sender_name: sender_name.to_string(), + display_name: display_name.to_string(), + channel: "discord".into(), + channel_id: thread_parent_id.unwrap_or(msg_channel_id).to_string(), + thread_id: thread_parent_id.map(|_| msg_channel_id.to_string()), + is_bot, + timestamp: Some(timestamp.to_string()), + message_id: Some(message_id.to_string()), + receiver_id: Some(receiver_id.to_string()), + } +} + +/// Pure thread detection: determines whether a channel is a Discord thread +/// in an allowed parent, and whether the bot owns it. +/// +/// Returns `(in_allowed_thread, bot_owns)`: +/// - `in_allowed_thread`: true only if the channel IS a thread AND its parent +/// is permitted (via allowlist, `allow_all_channels`, or `in_allowed_channel`). +/// - `bot_owns`: `None` if the channel is not a thread (ownership is meaningless); +/// `Some(true/false)` if it IS a thread, indicating whether the bot owns it. +/// +/// Uses `thread_metadata.is_some()` — the canonical way to identify threads. +/// `parent_id` is NOT reliable for thread detection: category children also +/// have `parent_id` set. `parent_id` is only used here for the allowlist check. 
+/// +/// Discord API refs: +/// - Channel Object (parent_id / thread_metadata fields): +/// https://docs.discord.com/developers/resources/channel#channel-object +/// - Thread Metadata ("thread-specific fields not needed by other channels"): +/// https://docs.discord.com/developers/resources/channel#thread-metadata-object +fn detect_thread( + has_thread_metadata: bool, + parent_id: Option, + owner_id: Option, + bot_id: u64, + allowed_channels: &HashSet, + allow_all_channels: bool, + in_allowed_channel: bool, +) -> (bool, Option) { + if !has_thread_metadata { + return (false, None); + } + let in_allowed_thread = in_allowed_channel + || allow_all_channels + || parent_id.is_some_and(|pid| allowed_channels.contains(&pid)); + let bot_owns = owner_id.is_some_and(|oid| oid == bot_id); + (in_allowed_thread, Some(bot_owns)) +} + +/// Returns `true` if the author should be denied by the user allowlist. +/// Bot authors skip this check — they are gated by `allow_bot_messages` + `trusted_bot_ids`. +fn is_denied_user( + is_bot: bool, + allow_all_users: bool, + allowed_users: &HashSet, + user_id: u64, +) -> bool { + !is_bot && !allow_all_users && !allowed_users.contains(&user_id) +} + +/// Returns `true` if a bot message should bypass the `allow_bot_messages` mode check. +/// A trusted bot that @mentions this bot is treated the same as a human @mention — +/// it can pull the bot into a thread regardless of the `allow_bot_messages` setting. +#[cfg(test)] +fn is_trusted_bot_mention( + is_mentioned: bool, + trusted_bot_ids: &HashSet, + author_id: u64, +) -> bool { + is_mentioned && !trusted_bot_ids.is_empty() && trusted_bot_ids.contains(&author_id) +} + +/// Pure decision function: should a DM be processed? +/// Returns `true` if the DM should be processed (bot responds). 
+/// Mirrors the DM gating logic in EventHandler::message: +/// - `allow_dm` must be true +/// - `allowed_users` still applies (checked separately via `is_denied_user`) +/// - DMs bypass `allowed_channels` and `@mention` requirements +#[cfg(test)] +fn should_process_dm(allow_dm: bool) -> bool { + allow_dm +} + +/// Pure decision function: should thread creation be skipped? +/// Returns `true` when the message should reuse the current channel +/// directly (existing thread or DM), `false` when a new thread should +/// be created. Pins the invariant that DMs never call +/// `get_or_create_thread()` — Discord DM channels cannot create threads. +#[cfg(test)] +fn should_skip_thread_creation(in_thread: bool, is_dm: bool) -> bool { + in_thread || is_dm +} + +/// Pure decision function: should this message be processed or ignored? +/// Returns `true` if the message should be processed (bot responds). +/// Extracted from the EventHandler::message gating logic for testability. +#[cfg(test)] +fn should_process_user_message( + mode: AllowUsers, + is_mentioned: bool, + in_thread: bool, + involved: bool, + other_bot_present: bool, +) -> bool { + if is_mentioned { + return true; + } + match mode { + AllowUsers::Mentions => false, + AllowUsers::Involved => in_thread && involved, + AllowUsers::MultibotMentions => { + if !in_thread || !involved { + return false; + } + !other_bot_present + } + } +} + +/// Returns true if any bot message in `messages` contains a turn limit warning. +/// Used to dedup `WarnAndStop` across multiple bot processes sharing a thread. (#530) +/// Note: this is best-effort — a narrow race window exists where two bots fetch +/// simultaneously and both see no warning, resulting in a duplicate. For most +/// deployments this is acceptable; strict once-only semantics would require +/// shared state (e.g. gateway-owned emission or distributed lock). 
+/// +/// Accepts `(is_bot, content)` pairs so the logic can be unit-tested without +/// constructing `serenity::model::channel::Message` values (see existing test +/// boundary comment at `format_thread_export`). +fn turn_limit_warning_present(messages: &[(bool, &str)]) -> bool { + messages + .iter() + .any(|(is_bot, content)| *is_bot && content.contains(BOT_TURN_LIMIT_WARNING_PREFIX)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::bot_turns::{TurnResult, HARD_BOT_TURN_LIMIT, BOT_TURN_LIMIT_WARNING_PREFIX}; + + // --- resolve_mentions tests --- + + /// Bot's own <@UID> mention is stripped from the prompt. + #[test] + fn resolve_mentions_strips_bot_mention() { + let bot_id = UserId::new(111); + let result = resolve_mentions("hello <@111> world", bot_id, &HashSet::new()); + assert_eq!(result, "hello world"); + } + + /// Bot's own legacy <@!UID> mention is also stripped. + #[test] + fn resolve_mentions_strips_bot_mention_legacy() { + let bot_id = UserId::new(111); + let result = resolve_mentions("hello <@!111> world", bot_id, &HashSet::new()); + assert_eq!(result, "hello world"); + } + + /// Other users' <@UID> mentions are preserved so the LLM can mention them back. + #[test] + fn resolve_mentions_preserves_other_user_mentions() { + let bot_id = UserId::new(111); + let result = resolve_mentions("<@111> say hi to <@222>", bot_id, &HashSet::new()); + assert_eq!(result, "say hi to <@222>"); + } + + /// Role mentions <@&UID> are replaced with @(role) placeholder. + #[test] + fn resolve_mentions_replaces_role_mentions() { + let bot_id = UserId::new(111); + let result = resolve_mentions("hello <@&999>", bot_id, &HashSet::new()); + assert_eq!(result, "hello @(role)"); + } + + /// Message containing only the bot mention results in empty string. 
+ #[test] + fn resolve_mentions_empty_after_strip() { + let bot_id = UserId::new(111); + let result = resolve_mentions("<@111>", bot_id, &HashSet::new()); + assert_eq!(result, ""); + } + + /// Allowed role mentions are stripped from prompt (not replaced with @(role)). + #[test] + fn resolve_mentions_strips_allowed_role() { + let bot_id = UserId::new(111); + let roles: HashSet = [999].into_iter().collect(); + let result = resolve_mentions("hello <@&999> world", bot_id, &roles); + assert_eq!(result, "hello world"); + } + + /// Non-allowed role mentions are still replaced with @(role). + #[test] + fn resolve_mentions_keeps_other_roles_as_placeholder() { + let bot_id = UserId::new(111); + let roles: HashSet = [999].into_iter().collect(); + let result = resolve_mentions("<@&999> check <@&888>", bot_id, &roles); + assert_eq!(result, "check @(role)"); + } + + #[test] + fn video_attachment_block_includes_actionable_metadata() { + let block = video_attachment_block( + "demo.mp4", + Some("video/mp4"), + 12345, + "https://cdn.discordapp.com/attachments/demo.mp4", + ); + + let ContentBlock::Text { text } = block else { + panic!("video attachments must be forwarded as text metadata"); + }; + + assert!(text.contains("[Video attachment]")); + assert!(text.contains("filename: demo.mp4")); + assert!(text.contains("content_type: video/mp4")); + assert!(text.contains("size_bytes: 12345")); + assert!(text.contains("url: https://cdn.discordapp.com/attachments/demo.mp4")); + } + + // --- thread-race error detection --- + + /// Detects the Discord error code for "thread already exists" (160004). + #[test] + fn is_thread_already_exists_matches_code() { + let err = anyhow::Error::msg( + r#"HTTP error: {"code": 160004, "message": "A thread has already been created for this message."}"#, + ); + assert!(is_thread_already_exists_error(&err)); + } + + /// Detects the human-readable form of the error in case serenity renders + /// it without the numeric code. 
+ #[test] + fn is_thread_already_exists_matches_message() { + let err = anyhow::anyhow!("A thread has already been created for this message."); + assert!(is_thread_already_exists_error(&err)); + } + + /// Unrelated errors do not match — we don't want the fallback path + /// swallowing real failures like permission denied. + #[test] + fn is_thread_already_exists_ignores_other_errors() { + let err = anyhow::anyhow!("Missing Permissions"); + assert!(!is_thread_already_exists_error(&err)); + let err = anyhow::anyhow!("rate limit exceeded"); + assert!(!is_thread_already_exists_error(&err)); + } + + // --- thread export helpers --- + + #[test] + fn sanitize_filename_component_keeps_safe_ascii() { + assert_eq!( + sanitize_filename_component("release notes_v2"), + "release-notes_v2" + ); + } + + #[test] + fn sanitize_filename_component_falls_back_for_empty_result() { + let result = sanitize_filename_component("///..."); + // Fallback is a YYYYMMDD date string + assert_eq!(result.len(), 8); + assert!(result.chars().all(|c| c.is_ascii_digit())); + } + + // --- assemble_export --- + // Split out from format_thread_export so we can test the truncation + // boundary without constructing serenity::model::channel::Message values. + + #[test] + fn assemble_export_empty_entries_returns_header_only() { + let (out, written, truncated) = assemble_export("HDR\n", &[], 1024); + assert_eq!(out, "HDR\n"); + assert_eq!(written, 0); + assert!(!truncated); + } + + #[test] + fn assemble_export_single_oversized_entry_writes_zero_and_marks_truncated() { + let entries = vec!["x".repeat(200)]; + let (out, written, truncated) = assemble_export("h\n", &entries, 50); + assert_eq!(written, 0); + assert!(truncated); + // Footer needs ~56 bytes; max_bytes 50 leaves ≤48 of room, so it is + // intentionally omitted (it can't be appended without exceeding the + // limit). The header is still present. 
+        assert!(out.starts_with("h\n"));
+        assert!(!out.contains("xx"));
+    }
+
+    #[test]
+    fn assemble_export_entry_at_exact_boundary_is_included() {
+        // header(2) + entry(3) == max_bytes(5); the strict-greater check
+        // keeps the entry in.
+        let (out, written, truncated) = assemble_export("h\n", &["abc".to_string()], 5);
+        assert_eq!(written, 1);
+        assert!(!truncated);
+        assert_eq!(out, "h\nabc");
+    }
+
+    #[test]
+    fn assemble_export_entry_one_byte_over_boundary_is_excluded() {
+        // header(2) + entry(4) == 6 > max_bytes(5); entry is dropped.
+        let (out, written, truncated) = assemble_export("h\n", &["abcd".to_string()], 5);
+        assert_eq!(written, 0);
+        assert!(truncated);
+        assert!(out.starts_with("h\n"));
+        assert!(!out.contains("abcd"));
+    }
+
+    #[test]
+    fn assemble_export_appends_footer_when_room_remains() {
+        // First two short entries fit; the long third entry would overflow,
+        // and the remaining headroom is enough for the truncation footer.
+        let entries = vec!["a\n".to_string(), "b\n".to_string(), "c".repeat(500)];
+        let (out, written, truncated) = assemble_export("h\n", &entries, 200);
+        assert_eq!(written, 2);
+        assert!(truncated);
+        assert!(out.contains("[Export truncated"));
+    }
+
+    // --- snowflake conversion ---
+
+    #[test]
+    fn timestamp_ms_to_snowflake_known_value() {
+        // 2026-05-12 08:00:00 UTC = 1778572800000 ms since Unix epoch
+        // Discord ms = 1778572800000 - 1420070400000 = 358502400000
+        // Snowflake = 358502400000 << 22 = 1503668050329600000
+        let ts_ms: u64 = 1_778_572_800_000;
+        let snowflake = timestamp_ms_to_snowflake(ts_ms);
+        // Verify round-trip: extract timestamp back from snowflake
+        let extracted_ms = (snowflake.get() >> 22) + DISCORD_EPOCH_MS;
+        assert_eq!(extracted_ms, ts_ms);
+    }
+
+    #[test]
+    fn timestamp_ms_to_snowflake_at_discord_epoch_is_one() {
+        // At exactly the Discord epoch, discord_ms=0, shifted=0, clamped to 1
+        let snowflake = timestamp_ms_to_snowflake(DISCORD_EPOCH_MS);
+        assert_eq!(snowflake.get(), 1);
+    }
+ + #[test] + fn timestamp_ms_to_snowflake_before_epoch_saturates() { + // Timestamp before Discord epoch should saturate to 1 + let snowflake = timestamp_ms_to_snowflake(1_000_000_000_000); + assert_eq!(snowflake.get(), 1); + } + + // --- ExportFilter cap logic --- + + #[test] + fn export_filter_default_cap_is_100() { + // Default (no params) uses Limit(100) + let filter = ExportFilter::Limit(100); + let cap = match &filter { + ExportFilter::Limit(n) => *n, + _ => THREAD_EXPORT_MESSAGE_LIMIT, + }; + assert_eq!(cap, 100); + } + + #[test] + fn export_filter_all_cap_is_5000() { + let filter = ExportFilter::All; + let cap = match &filter { + ExportFilter::Limit(n) => *n, + _ => THREAD_EXPORT_MESSAGE_LIMIT, + }; + assert_eq!(cap, THREAD_EXPORT_MESSAGE_LIMIT); + assert_eq!(cap, 5000); + } + + #[test] + fn export_filter_limit_uses_custom_cap() { + let filter = ExportFilter::Limit(250); + let cap = match &filter { + ExportFilter::Limit(n) => *n, + _ => THREAD_EXPORT_MESSAGE_LIMIT, + }; + assert_eq!(cap, 250); + } + + #[test] + fn export_filter_after_uses_global_cap() { + let filter = ExportFilter::After(MessageId::new(123456789)); + let cap = match &filter { + ExportFilter::Limit(n) => *n, + _ => THREAD_EXPORT_MESSAGE_LIMIT, + }; + assert_eq!(cap, THREAD_EXPORT_MESSAGE_LIMIT); + } + + // --- should_process_user_message tests (GIVEN/WHEN/THEN) --- + // Tests the multibot-mentions gating logic extracted from EventHandler::message. + // The bug in #481 was that other bots' messages were filtered by bot gating + // before multibot detection could run, so the bot never learned the thread + // was multi-bot and responded without @mention. 
+
+    /// GIVEN: multibot-mentions mode, single-bot thread, bot is involved
+    /// WHEN: human sends message without @mention
+    /// THEN: bot responds (natural conversation)
+    #[test]
+    fn multibot_mentions_single_bot_thread_no_mention() {
+        assert!(should_process_user_message(
+            AllowUsers::MultibotMentions,
+            false, // is_mentioned
+            true,  // in_thread
+            true,  // involved
+            false, // other_bot_present
+        ));
+    }
+
+    /// GIVEN: multibot-mentions mode, multi-bot thread (other bot has posted)
+    /// WHEN: human sends message without @mention
+    /// THEN: bot does NOT respond (requires @mention in multi-bot thread)
+    /// This is the exact scenario from bug #481.
+    #[test]
+    fn multibot_mentions_multi_bot_thread_no_mention() {
+        assert!(!should_process_user_message(
+            AllowUsers::MultibotMentions,
+            false, // is_mentioned
+            true,  // in_thread
+            true,  // involved
+            true,  // other_bot_present ← another bot posted
+        ));
+    }
+
+    /// GIVEN: multibot-mentions mode, multi-bot thread
+    /// WHEN: human sends message WITH @mention
+    /// THEN: bot responds (explicit @mention always works)
+    #[test]
+    fn multibot_mentions_multi_bot_thread_with_mention() {
+        assert!(should_process_user_message(
+            AllowUsers::MultibotMentions,
+            true, // is_mentioned
+            true, // in_thread
+            true, // involved
+            true, // other_bot_present
+        ));
+    }
+
+    /// GIVEN: multibot-mentions mode, not in a thread (main channel)
+    /// WHEN: human sends message without @mention
+    /// THEN: bot does NOT respond (main channel always requires @mention)
+    #[test]
+    fn multibot_mentions_main_channel_no_mention() {
+        assert!(!should_process_user_message(
+            AllowUsers::MultibotMentions,
+            false, // is_mentioned
+            false, // in_thread (main channel)
+            false, // involved
+            false, // other_bot_present
+        ));
+    }
+
+    /// GIVEN: multibot-mentions mode, in thread but bot is NOT involved
+    /// WHEN: human sends message without @mention
+    /// THEN: bot does NOT respond (not participating in this thread)
+    #[test]
+    fn multibot_mentions_not_involved() {
+        assert!(!should_process_user_message(
+            AllowUsers::MultibotMentions,
+            false, // is_mentioned
+            true,  // in_thread
+            false, // involved ← bot hasn't posted here
+            false, // other_bot_present
+        ));
+    }
+
+    /// GIVEN: involved mode, multi-bot thread
+    /// WHEN: human sends message without @mention
+    /// THEN: bot responds (involved mode ignores multi-bot status)
+    #[test]
+    fn involved_mode_ignores_multibot() {
+        assert!(should_process_user_message(
+            AllowUsers::Involved,
+            false, // is_mentioned
+            true,  // in_thread
+            true,  // involved
+            true,  // other_bot_present ← ignored in involved mode
+        ));
+    }
+
+    /// GIVEN: mentions mode
+    /// WHEN: human sends message without @mention (even in own thread)
+    /// THEN: bot does NOT respond (always requires @mention)
+    #[test]
+    fn mentions_mode_always_requires_mention() {
+        assert!(!should_process_user_message(
+            AllowUsers::Mentions,
+            false, // is_mentioned
+            true,  // in_thread
+            true,  // involved
+            false, // other_bot_present
+        ));
+    }
+
+    /// After the soft limit fires once (n==20 → `SoftLimit(20)`), subsequent bot
+    /// messages return `Throttled`. The caller warns only on the exact hit and
+    /// stays silent afterwards, preventing warnings from ping-ponging between bots.
+    #[test]
+    fn soft_limit_warn_once_semantics() {
+        let mut t = BotTurnTracker::new(20);
+        for _ in 0..19 {
+            assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
+        }
+        // n==20: exact hit — caller should send warning
+        assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(20));
+        // n==21: past limit — caller should silently return (no warning)
+        assert_eq!(t.on_bot_message("t1"), TurnResult::Throttled);
+        // n==22: still past — still silent
+        assert_eq!(t.on_bot_message("t1"), TurnResult::Throttled);
+    }
+
+    /// Hard limit has the same warn-once semantics: `HardLimit` on the exact hit, `Stopped` after.
+ #[test] + fn hard_limit_warn_once_semantics() { + let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); // soft > hard so hard fires first + for _ in 0..HARD_BOT_TURN_LIMIT - 1 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + // Exact hit — warn + assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit); + // Past — silent + assert_eq!(t.on_bot_message("t1"), TurnResult::Stopped); + } + + /// Regression test for #497: system messages (thread created, pin, etc.) + /// should NOT reset the bot turn counter. The filtering happens at the + /// call site (MessageType check); this verifies the counter stays put + /// when on_human_message is never called. + #[test] + fn system_message_does_not_reset_counter() { + let mut t = BotTurnTracker::new(3); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + // No on_human_message (system message filtered out at call site) + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); + } + + // --- build_sender_context tests (regression for #581 → #584) --- + // PR #583 fixed SenderContext to use parent channel_id when in a thread. + // These tests verify the pure function extracted from EventHandler::message. + + /// In-thread message: channel_id = parent, thread_id = thread channel ID. + #[test] + fn build_sender_context_in_thread() { + let ctx = build_sender_context( + "user1", + "alice", + "Alice", + "thread_ch", + Some("parent_ch"), + false, + "2026-05-01T00:00:00Z", + "msg123", + "bot99", + ); + assert_eq!(ctx.channel_id, "parent_ch"); + assert_eq!(ctx.thread_id, Some("thread_ch".to_string())); + assert_eq!(ctx.channel, "discord"); + assert_eq!(ctx.sender_id, "user1"); + assert!(!ctx.is_bot); + assert_eq!(ctx.receiver_id, Some("bot99".to_string())); + } + + /// Non-thread message: channel_id = message channel, thread_id = None. 
+    #[test]
+    fn build_sender_context_not_in_thread() {
+        let ctx = build_sender_context(
+            "user1",
+            "alice",
+            "Alice",
+            "main_ch",
+            None,
+            false,
+            "2026-05-01T00:00:00Z",
+            "msg456",
+            "bot99",
+        );
+        assert_eq!(ctx.channel_id, "main_ch");
+        assert_eq!(ctx.thread_id, None);
+    }
+
+    /// Bot sender: is_bot flag propagated correctly.
+    #[test]
+    fn build_sender_context_bot_sender() {
+        let ctx = build_sender_context(
+            "bot1",
+            "mybot",
+            "MyBot",
+            "ch",
+            Some("parent"),
+            true,
+            "2026-05-01T00:00:00Z",
+            "msg789",
+            "bot99",
+        );
+        assert!(ctx.is_bot);
+        assert_eq!(ctx.channel_id, "parent");
+        assert_eq!(ctx.thread_id, Some("ch".to_string()));
+    }
+
+    // --- detect_thread tests (regression for #506 → #518 → #519) ---
+    // PR #506 used parent_id.is_some() to detect threads, but category text
+    // channels also have parent_id (pointing to the category). This caused
+    // the bot to skip thread creation for normal channels inside categories.
+    //
+    // detect_thread() uses thread_metadata.is_some() — the canonical check
+    // per Discord API docs. Table-driven to cover all channel scenarios.
+
+    const BOT: u64 = 1000;
+    const OTHER: u64 = 2000;
+    const PARENT_CH: u64 = 100;
+    const CATEGORY: u64 = 200;
+
+    /// Helper: build an allowed_channels set from a slice.
+    fn allowed(ids: &[u64]) -> HashSet<u64> {
+        ids.iter().copied().collect()
+    }
+
+    /// Table-driven: each row is a realistic Discord channel scenario.
+    #[test]
+    fn detect_thread_table() {
+        struct Case {
+            name: &'static str,
+            has_thread_metadata: bool,
+            parent_id: Option<u64>,
+            owner_id: Option<u64>,
+            bot_id: u64,
+            allowed_channels: HashSet<u64>,
+            allow_all: bool,
+            in_allowed: bool,
+            expect: (bool, Option<bool>), // (in_thread, bot_owns)
+        }
+
+        let cases = vec![
+            // --- Non-thread channels: thread_metadata = None ---
+            Case {
+                name: "text channel under category (regression #506)",
+                has_thread_metadata: false,
+                parent_id: Some(CATEGORY), // points to category, NOT a thread
+                owner_id: None,
+                bot_id: BOT,
+                allowed_channels: allowed(&[]),
+                allow_all: false,
+                in_allowed: true,
+                expect: (false, None),
+            },
+            Case {
+                name: "top-level text channel (no category)",
+                has_thread_metadata: false,
+                parent_id: None,
+                owner_id: None,
+                bot_id: BOT,
+                allowed_channels: allowed(&[]),
+                allow_all: false,
+                in_allowed: true,
+                expect: (false, None),
+            },
+            Case {
+                name: "voice channel under category",
+                has_thread_metadata: false,
+                parent_id: Some(CATEGORY),
+                owner_id: None,
+                bot_id: BOT,
+                allowed_channels: allowed(&[]),
+                allow_all: false,
+                in_allowed: false,
+                expect: (false, None),
+            },
+            // --- Thread channels: thread_metadata = Some ---
+            Case {
+                name: "public thread, parent in allowlist, bot owns",
+                has_thread_metadata: true,
+                parent_id: Some(PARENT_CH),
+                owner_id: Some(BOT),
+                bot_id: BOT,
+                allowed_channels: allowed(&[PARENT_CH]),
+                allow_all: false,
+                in_allowed: false,
+                expect: (true, Some(true)),
+            },
+            Case {
+                name: "public thread, parent in allowlist, other user owns",
+                has_thread_metadata: true,
+                parent_id: Some(PARENT_CH),
+                owner_id: Some(OTHER),
+                bot_id: BOT,
+                allowed_channels: allowed(&[PARENT_CH]),
+                allow_all: false,
+                in_allowed: false,
+                expect: (true, Some(false)),
+            },
+            Case {
+                name: "thread, parent NOT in allowlist, not allow_all",
+                has_thread_metadata: true,
+                parent_id: Some(PARENT_CH),
+                owner_id: Some(BOT),
+                bot_id: BOT,
+                allowed_channels: allowed(&[]),
+                allow_all: false,
+                in_allowed: false,
+                expect: (false, Some(true)),
+            },
+            Case {
+                name: "thread, allow_all_channels = true",
+                has_thread_metadata: true,
+                parent_id: Some(PARENT_CH),
+                owner_id: Some(OTHER),
+                bot_id: BOT,
+                allowed_channels: allowed(&[]),
+                allow_all: true,
+                in_allowed: false,
+                expect: (true, Some(false)),
+            },
+            Case {
+                name: "thread, in_allowed_channel = true (parent is the allowed channel)",
+                has_thread_metadata: true,
+                parent_id: Some(PARENT_CH),
+                owner_id: None,
+                bot_id: BOT,
+                allowed_channels: allowed(&[]),
+                allow_all: false,
+                in_allowed: true,
+                expect: (true, Some(false)),
+            },
+            // --- Defensive: partial data ---
+            Case {
+                name: "thread with parent_id = None (defensive, partial API data)",
+                has_thread_metadata: true,
+                parent_id: None,
+                owner_id: Some(BOT),
+                bot_id: BOT,
+                allowed_channels: allowed(&[PARENT_CH]),
+                allow_all: false,
+                in_allowed: false,
+                expect: (false, Some(true)), // can't verify parent → not allowed, but bot still owns
+            },
+        ];
+
+        for c in &cases {
+            let result = detect_thread(
+                c.has_thread_metadata,
+                c.parent_id,
+                c.owner_id,
+                c.bot_id,
+                &c.allowed_channels,
+                c.allow_all,
+                c.in_allowed,
+            );
+            assert_eq!(result, c.expect, "FAILED: {}", c.name);
+        }
+    }
+
+    // --- WarnAndStop regression test (#633) ---
+    // The WarnAndStop path now delegates to detect_thread(). This test pins
+    // the exact scenario from #633: a category child channel whose category
+    // ID is in another bot's allowed_channels must NOT be treated as allowed.
+    #[test]
+    fn detect_thread_rejects_category_child_in_warn_and_stop() {
+        let category_id: u64 = 200;
+        let allowed = HashSet::from([category_id]);
+        // Category child: has parent_id (the category) but NO thread_metadata.
+ let (in_thread, _) = + detect_thread(false, Some(category_id), None, 1000, &allowed, false, false); + assert!( + !in_thread, + "category child must not match allowed_channels via parent_id" + ); + } + + // --- Per-thread streaming tests (#534) --- + // Streaming ON by default, OFF when another bot is detected in the thread. + + /// Single bot thread: streaming enabled. + #[test] + fn discord_streams_when_no_other_bot() { + let adapter = super::DiscordAdapter::new(Arc::new(super::Http::new(""))); + assert!(adapter.use_streaming(false)); + } + + /// Multi-bot thread: send-once to avoid edit interference. + #[test] + fn discord_no_stream_when_other_bot_present() { + let adapter = super::DiscordAdapter::new(Arc::new(super::Http::new(""))); + assert!(!adapter.use_streaming(true)); + } + + // --- resolve_channel tests --- + + #[test] + fn resolve_channel_uses_channel_id_when_no_thread() { + let ch = ChannelRef { + platform: "discord".into(), + channel_id: "111".into(), + thread_id: None, + parent_id: None, + origin_event_id: None, + }; + assert_eq!(DiscordAdapter::resolve_channel(&ch), "111"); + } + + #[test] + fn resolve_channel_prefers_thread_id_when_set() { + let ch = ChannelRef { + platform: "discord".into(), + channel_id: "111".into(), + thread_id: Some("222".into()), + parent_id: None, + origin_event_id: None, + }; + assert_eq!(DiscordAdapter::resolve_channel(&ch), "222"); + } + + // --- is_denied_user tests (regression for #604) --- + + /// Human not in allowlist → denied. + #[test] + fn denied_user_human_not_in_allowlist() { + let allowed = HashSet::from([100]); + assert!(is_denied_user(false, false, &allowed, 999)); + } + + /// Human in allowlist → allowed. + #[test] + fn denied_user_human_in_allowlist() { + let allowed = HashSet::from([100]); + assert!(!is_denied_user(false, false, &allowed, 100)); + } + + /// Bot not in allowlist → allowed (bots skip user gate). This is the #604 fix. 
+    #[test]
+    fn denied_user_bot_skips_allowlist() {
+        let allowed = HashSet::from([100]);
+        assert!(!is_denied_user(true, false, &allowed, 999));
+    }
+
+    // --- Trusted bot mention bypass tests ---
+    // A trusted bot @mentioning this bot bypasses allow_bot_messages mode,
+    // treating the mention the same as a human @mention.
+
+    /// GIVEN: trusted bot @mentions this bot
+    /// THEN: bypass is granted (treated as human mention)
+    #[test]
+    fn trusted_bot_mention_bypasses_gate() {
+        let trusted = HashSet::from([42]);
+        assert!(is_trusted_bot_mention(true, &trusted, 42));
+    }
+
+    /// GIVEN: untrusted bot @mentions this bot
+    /// THEN: no bypass (normal bot gating applies)
+    #[test]
+    fn untrusted_bot_mention_no_bypass() {
+        let trusted = HashSet::from([42]);
+        assert!(!is_trusted_bot_mention(true, &trusted, 99));
+    }
+
+    /// GIVEN: trusted bot sends message WITHOUT @mention
+    /// THEN: no bypass (must explicitly @mention)
+    #[test]
+    fn trusted_bot_no_mention_no_bypass() {
+        let trusted = HashSet::from([42]);
+        assert!(!is_trusted_bot_mention(false, &trusted, 42));
+    }
+
+    /// GIVEN: empty trusted_bot_ids (feature not configured)
+    /// THEN: no bypass regardless of mention
+    #[test]
+    fn empty_trusted_ids_no_bypass() {
+        let trusted: HashSet<u64> = HashSet::new();
+        assert!(!is_trusted_bot_mention(true, &trusted, 42));
+    }
+
+    // --- Trusted bot admission integration tests ---
+    // These test the full bot gating decision path: allow_bot_messages mode +
+    // trusted_bot_ids + trusted mention bypass, mirroring the actual logic in
+    // EventHandler::message.
+
+    /// Simulates the bot admission decision from EventHandler::message.
+    /// Returns `true` if the bot message would be processed (not dropped).
+    fn should_admit_bot_message(
+        allow_bot_messages: AllowBots,
+        is_mentioned: bool,
+        trusted_bot_ids: &HashSet<u64>,
+        author_id: u64,
+    ) -> bool {
+        let trusted_mention = is_mentioned
+            && !trusted_bot_ids.is_empty()
+            && trusted_bot_ids.contains(&author_id);
+
+        if !trusted_mention {
+            match allow_bot_messages {
+                AllowBots::Off => return false,
+                AllowBots::Mentions => {
+                    if !is_mentioned {
+                        return false;
+                    }
+                }
+                AllowBots::All => {} // would check consecutive cap, skip for unit test
+            }
+
+            if !trusted_bot_ids.is_empty() && !trusted_bot_ids.contains(&author_id) {
+                return false;
+            }
+        }
+        true
+    }
+
+    /// GIVEN: allow_bot_messages=Off, trusted bot @mentions this bot
+    /// THEN: admitted (trusted mention overrides Off mode)
+    #[test]
+    fn bot_admission_trusted_mention_overrides_off() {
+        let trusted = HashSet::from([42]);
+        assert!(should_admit_bot_message(AllowBots::Off, true, &trusted, 42));
+    }
+
+    /// GIVEN: allow_bot_messages=Off, untrusted bot @mentions this bot
+    /// THEN: rejected (Off mode blocks)
+    #[test]
+    fn bot_admission_untrusted_mention_blocked_by_off() {
+        let trusted = HashSet::from([42]);
+        assert!(!should_admit_bot_message(AllowBots::Off, true, &trusted, 99));
+    }
+
+    /// GIVEN: allow_bot_messages=Off, trusted bot without @mention
+    /// THEN: rejected (no mention = no bypass)
+    #[test]
+    fn bot_admission_trusted_no_mention_blocked_by_off() {
+        let trusted = HashSet::from([42]);
+        assert!(!should_admit_bot_message(AllowBots::Off, false, &trusted, 42));
+    }
+
+    /// GIVEN: allow_bot_messages=Off, empty trusted_bot_ids, bot @mentions
+    /// THEN: rejected (feature not configured)
+    #[test]
+    fn bot_admission_empty_trusted_ids_off_mode() {
+        let trusted: HashSet<u64> = HashSet::new();
+        assert!(!should_admit_bot_message(AllowBots::Off, true, &trusted, 42));
+    }
+
+    /// GIVEN: allow_bot_messages=Mentions, trusted bot @mentions
+    /// THEN: admitted (would pass anyway, but bypass also works)
+    #[test]
+    fn bot_admission_mentions_mode_trusted_mention() {
+ let trusted = HashSet::from([42]); + assert!(should_admit_bot_message(AllowBots::Mentions, true, &trusted, 42)); + } + + /// GIVEN: allow_bot_messages=All, untrusted bot (not in trusted_bot_ids) + /// THEN: rejected by trusted_bot_ids filter + #[test] + fn bot_admission_all_mode_untrusted_bot_rejected() { + let trusted = HashSet::from([42]); + assert!(!should_admit_bot_message(AllowBots::All, false, &trusted, 99)); + } + + // --- DM gating tests (#656) --- + // DMs are gated by `allow_dm` config. When allowed, DMs bypass + // `allowed_channels` and treat the message as implicit @mention. + + /// GIVEN: allow_dm = false + /// WHEN: user sends a DM + /// THEN: DM is rejected + #[test] + fn dm_rejected_when_allow_dm_false() { + assert!(!should_process_dm(false)); + } + + /// GIVEN: allow_dm = true + /// WHEN: user sends a DM + /// THEN: DM is accepted + #[test] + fn dm_accepted_when_allow_dm_true() { + assert!(should_process_dm(true)); + } + + /// GIVEN: allow_dm = true, user NOT in allowed_users + /// WHEN: user sends a DM + /// THEN: user is denied (allowed_users still enforced in DMs) + #[test] + fn dm_denied_user_still_enforced() { + let allowed = HashSet::from([100]); + // DM passes allow_dm gate, but user gate still applies + assert!(should_process_dm(true)); + assert!(is_denied_user(false, false, &allowed, 999)); + } + + /// GIVEN: allow_dm = true, user in allowed_users + /// WHEN: user sends a DM + /// THEN: user is allowed + #[test] + fn dm_allowed_user_passes() { + let allowed = HashSet::from([100]); + assert!(should_process_dm(true)); + assert!(!is_denied_user(false, false, &allowed, 100)); + } + + /// DMs are treated as implicit @mention — should_process_user_message + /// is never called for DMs (the `!is_dm` guard skips it). + /// This test verifies the Involved mode would reject a non-thread, + /// non-mentioned message — confirming DMs MUST bypass this check. 
+ #[test] + fn dm_must_bypass_user_message_gating() { + // Without the `!is_dm` bypass, a DM would be rejected by Involved mode + // because is_mentioned=false and in_thread=false. + assert!(!should_process_user_message( + AllowUsers::Involved, + false, // is_mentioned (DMs don't have @mention) + false, // in_thread (DMs are not threads) + false, // involved + false, // other_bot_present + )); + } + + // --- Thread creation skip tests (regression for #656 DM bug) --- + // Pins the invariant: DMs must never call get_or_create_thread(). + // Discord DM channels do not support thread creation. + + /// GIVEN: is_dm = true, not in a thread + /// THEN: skip thread creation (use DM channel directly) + #[test] + fn dm_skips_thread_creation() { + assert!(should_skip_thread_creation(false, true)); + } + + /// GIVEN: already in a thread, not a DM + /// THEN: skip thread creation (reuse existing thread) + #[test] + fn existing_thread_skips_thread_creation() { + assert!(should_skip_thread_creation(true, false)); + } + + /// GIVEN: not in a thread, not a DM (normal channel message) + /// THEN: do NOT skip — create a new thread + #[test] + fn normal_channel_creates_thread() { + assert!(!should_skip_thread_creation(false, false)); + } + + // --- WarnAndStop dedup tests (#530) --- + + #[test] + fn dedup_detects_existing_bot_warning() { + let msg = format!("{} (20/20). A human must reply.", BOT_TURN_LIMIT_WARNING_PREFIX); + assert!(turn_limit_warning_present(&[(true, &msg)])); + } + + #[test] + fn dedup_ignores_human_warning_text() { + let msg = format!("{} (20/20). 
A human must reply.", BOT_TURN_LIMIT_WARNING_PREFIX); + assert!(!turn_limit_warning_present(&[(false, &msg)])); + } + + #[test] + fn dedup_returns_false_when_no_warning() { + assert!(!turn_limit_warning_present(&[(true, "hello"), (false, "world")])); + } + + #[test] + fn dedup_returns_false_for_empty_messages() { + assert!(!turn_limit_warning_present(&[])); + } +} diff --git a/crates/openab-core/src/dispatch.rs b/crates/openab-core/src/dispatch.rs new file mode 100644 index 000000000..97d5f25e3 --- /dev/null +++ b/crates/openab-core/src/dispatch.rs @@ -0,0 +1,1727 @@ +//! Turn-boundary message batching dispatcher. +//! +//! See ADR: docs/adr/turn-boundary-batching.md for full design rationale. +//! +//! # Invariants +//! - I1: First message after idle has zero added latency. +//! - I2: At most one in-flight ACP turn per thread. +//! - I3: Broker structural fidelity — no merging, splitting, reordering, or +//! semantic transformation of arrival events. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use async_trait::async_trait; +use tracing::{debug, error, info, info_span, warn}; + +use crate::acp::ContentBlock; +use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef}; +use crate::config::ReactionsConfig; +use crate::error_display::format_user_error; +use crate::reactions::StatusReactionController; + +// --------------------------------------------------------------------------- +// Public types +// --------------------------------------------------------------------------- + +/// One arrival event buffered for a future ACP turn. +pub struct BufferedMessage { + /// Serialised SenderContext JSON (already built by the platform adapter). + pub sender_json: String, + /// Author display name — denormalised from `sender_json` so observability + /// fields (per-event tracing in `dispatch_batch`) don't pay a JSON parse. 
/// Failure reported to callers of `Dispatcher::submit`.
#[derive(Debug)]
pub enum DispatchError {
    /// The per-thread consumer task was no longer accepting messages.
    ConsumerDead,
}

impl std::fmt::Display for DispatchError {
    /// Writes the fixed, human-readable description of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the text first so every variant goes through one write call.
        let description = match self {
            Self::ConsumerDead => "dispatch consumer exited unexpectedly",
        };
        f.write_str(description)
    }
}

// No source error to chain; the default trait methods are sufficient.
impl std::error::Error for DispatchError {}
+ /// Returns `true` if a new session was created, `false` if it already existed. + async fn ensure_session(&self, session_key: &str, working_dir: Option<&str>) -> Result; + + /// Destroy the session for `session_key` (used to rollback on directive failure). + async fn reset_session(&self, session_key: &str); + + /// Drive one ACP turn with the pre-packed `content_blocks`. + #[allow(clippy::too_many_arguments)] + async fn stream_prompt_blocks( + &self, + adapter: &Arc, + session_key: &str, + content_blocks: Vec, + thread_channel: &ChannelRef, + reactions: Arc, + other_bot_present: bool, + recipient: Option<(String, String)>, + ) -> Result<()>; +} + +#[async_trait] +impl DispatchTarget for AdapterRouter { + fn reactions_config(&self) -> &ReactionsConfig { + AdapterRouter::reactions_config(self) + } + + fn workspace_aliases(&self) -> std::collections::HashMap { + self.workspace_aliases_map() + } + + fn bot_home(&self) -> std::path::PathBuf { + self.bot_home_path() + } + + async fn ensure_session(&self, session_key: &str, working_dir: Option<&str>) -> Result { + self.pool().get_or_create(session_key, working_dir).await + } + + async fn reset_session(&self, session_key: &str) { + let _ = self.pool().reset_session(session_key).await; + } + + async fn stream_prompt_blocks( + &self, + adapter: &Arc, + session_key: &str, + content_blocks: Vec, + thread_channel: &ChannelRef, + reactions: Arc, + other_bot_present: bool, + recipient: Option<(String, String)>, + ) -> Result<()> { + AdapterRouter::stream_prompt_blocks( + self, + adapter, + session_key, + content_blocks, + thread_channel, + reactions, + other_bot_present, + recipient, + ) + .await + } +} + +// --------------------------------------------------------------------------- +// Dispatcher +// --------------------------------------------------------------------------- + +/// Default idle timeout for per-thread consumer tasks in batched modes (Thread / Lane). 
+/// When no message arrives within this window the consumer exits, allowing `per_thread` +/// map cleanup on the next `submit` (via `SendError` → `try_evict_locked`). Prevents +/// unbounded task/memory growth from one-shot thread keys (e.g. Slack non-thread messages). +/// +/// Batched modes need a longer window so a lane that's between trigger arrivals isn't +/// torn down and respawned on every message. +pub const DEFAULT_CONSUMER_IDLE_TIMEOUT: Duration = Duration::from_secs(300); + +/// Idle timeout for per-message mode (cap=1, no batching). Per-message dispatchers +/// don't benefit from holding consumers across message gaps — there is no batch +/// window to preserve — so a much shorter timeout reduces idle resource footprint +/// from one-shot thread keys (Little's Law: steady-state idle count = arrival rate +/// × idle window). +pub const PER_MESSAGE_CONSUMER_IDLE_TIMEOUT: Duration = Duration::from_secs(10); + +/// Resolve `(cap, grouping, idle_timeout)` for a given processing mode. +/// +/// Per-message mode forces cap=1 + Thread grouping + the short per-message idle +/// (one-shot threads shouldn't pin a consumer for 5 min); Thread / Lane modes +/// use the configured `max_buffered` and the default idle window. +pub fn dispatch_params( + mode: &crate::config::MessageProcessingMode, + max_buffered: usize, +) -> (usize, BatchGrouping, Duration) { + use crate::config::MessageProcessingMode; + match mode { + MessageProcessingMode::Message => { + (1, BatchGrouping::Thread, PER_MESSAGE_CONSUMER_IDLE_TIMEOUT) + } + MessageProcessingMode::Thread => ( + max_buffered, + BatchGrouping::Thread, + DEFAULT_CONSUMER_IDLE_TIMEOUT, + ), + MessageProcessingMode::Lane => ( + max_buffered, + BatchGrouping::Lane, + DEFAULT_CONSUMER_IDLE_TIMEOUT, + ), + } +} + +/// Per-thread message dispatcher for batched mode. +/// +/// Constructed once in `main.rs` and shared via `Arc`. Platform adapters call +/// `submit()` from their per-message `tokio::spawn`'d tasks. 
+pub struct Dispatcher { + /// std::sync::Mutex — critical section has no .await; tokio::Mutex buys nothing here. + per_thread: Mutex>, + /// Monotonic counter for `ThreadHandle.generation` (§2.5). Pre-fetched on + /// every `submit` and consumed only when a fresh handle is inserted; wasted + /// values are fine because generations need only be monotonic, not contiguous. + next_generation: AtomicU64, + target: Arc, + max_buffered_messages: usize, + max_batch_tokens: usize, + grouping: BatchGrouping, + idle_timeout: Duration, +} + +impl Dispatcher { + /// Construct a dispatcher with an explicit consumer idle timeout. Per-mode + /// callers in `main.rs` pass `PER_MESSAGE_CONSUMER_IDLE_TIMEOUT` for cap=1 + /// dispatchers and `DEFAULT_CONSUMER_IDLE_TIMEOUT` for batched modes. + pub fn with_idle_timeout( + target: Arc, + max_buffered_messages: usize, + max_batch_tokens: usize, + grouping: BatchGrouping, + idle_timeout: Duration, + ) -> Self { + Self { + per_thread: Mutex::new(HashMap::new()), + next_generation: AtomicU64::new(0), + target, + max_buffered_messages, + max_batch_tokens, + grouping, + idle_timeout, + } + } + + /// Build the dispatcher key for a (platform, thread, sender) tuple. + /// + /// In `Thread` mode the sender is ignored; in `Lane` mode the sender is appended + /// so each (thread, sender) pair gets its own mpsc and consumer. + /// + /// Note: this is the *dispatcher* key, not the *session pool* key. Session pool keys + /// are always `:` regardless of grouping (the ACP session is + /// shared per-thread by design). + pub fn key(&self, platform: &str, thread_id: &str, sender_id: &str) -> String { + match self.grouping { + BatchGrouping::Thread => format!("{platform}:{thread_id}"), + BatchGrouping::Lane => format!("{platform}:{thread_id}:{sender_id}"), + } + } + + /// Build the shared session pool key for a routed channel. + /// + /// Unlike dispatcher keys, session keys never include sender identity. 
+ /// They track the logical conversation thread across all grouping modes. + fn session_key(thread_channel: &ChannelRef) -> String { + let logical_thread_id = thread_channel + .thread_id + .as_deref() + .unwrap_or(&thread_channel.channel_id); + format!("{}:{}", thread_channel.platform, logical_thread_id) + } + + /// Submit one arrival event for the given thread. + /// + /// - If the thread has no active consumer, one is spawned lazily. + /// - If the channel is full, this future parks until space is available + /// (backpressure — no data loss, no error). + /// - If the consumer has died (`SendError`), surfaces ❌ + ⚠️ and returns + /// `Err(DispatchError::ConsumerDead)` (§2.5). + /// + /// `adapter` is passed per-call (not stored on `Dispatcher`) because the + /// Discord adapter is constructed inside serenity's `ready` callback via + /// `OnceLock` — after the Dispatcher is built in `main.rs`. + pub async fn submit( + &self, + thread_key: String, + thread_channel: ChannelRef, + adapter: Arc, + msg: BufferedMessage, + ) -> Result<(), DispatchError> { + let cap = self.max_buffered_messages; + let target = Arc::clone(&self.target); + let max_tokens = self.max_batch_tokens; + let idle_timeout = self.idle_timeout; + + // Pre-fetch a generation in case we end up inserting a fresh handle. + // Wasted if the entry already exists; generations need only be monotonic. + let next_g = self.next_generation.fetch_add(1, Ordering::Relaxed); + + let (tx, my_generation) = { + // SAFETY: no .await while this guard is held — guard drops at end of block. + let mut map = self.per_thread.lock().unwrap(); + + // Proactive stale-entry cleanup: if the consumer has exited (idle + // timeout or unexpected), remove the entry so `or_insert_with` + // creates a fresh one. Prevents map leak from one-shot thread keys + // and avoids the first-message-after-idle being treated as an error. 
+ if let Some(handle) = map.get(&thread_key) { + if handle.consumer.is_finished() { + map.remove(&thread_key); + } + } + + let entry = map.entry(thread_key.clone()).or_insert_with(|| { + let (tx, rx) = tokio::sync::mpsc::channel(cap); + let consumer = tokio::spawn(consumer_loop( + thread_key.clone(), + thread_channel.clone(), + rx, + Arc::clone(&target), + Arc::clone(&adapter), + cap, + max_tokens, + idle_timeout, + )); + ThreadHandle { + tx, + consumer, + generation: next_g, + channel_id: thread_channel.channel_id.clone(), + adapter_kind: adapter.platform().to_string(), + } + }); + (entry.tx.clone(), entry.generation) + }; + + if let Err(e) = tx.send(msg).await { + // Consumer has exited between our check and the send — race-safe + // eviction under lock (§2.5), then transparent retry once. + // + // Safe to re-acquire `per_thread` here: the first lock guard above + // was dropped before `tx.send().await`, so this acquisition cannot + // deadlock against the await point. The same property holds for the + // retry acquisition below. + { + // SAFETY: no .await while this guard is held. + let mut map = self.per_thread.lock().unwrap(); + Self::try_evict_locked(&mut map, &thread_key, my_generation); + } + let failed_msg = e.0; + + // Retry: spawn a fresh consumer and re-send. If this also fails, + // surface the error to the user. + let retry_g = self.next_generation.fetch_add(1, Ordering::Relaxed); + let (retry_tx, retry_gen) = { + // SAFETY: no .await while this guard is held — guard drops at end of block. 
+ let mut map = self.per_thread.lock().unwrap(); + let entry = map.entry(thread_key.clone()).or_insert_with(|| { + let (tx, rx) = tokio::sync::mpsc::channel(cap); + let consumer = tokio::spawn(consumer_loop( + thread_key.clone(), + thread_channel.clone(), + rx, + Arc::clone(&target), + Arc::clone(&adapter), + cap, + max_tokens, + idle_timeout, + )); + ThreadHandle { + tx, + consumer, + generation: retry_g, + channel_id: thread_channel.channel_id.clone(), + adapter_kind: adapter.platform().to_string(), + } + }); + (entry.tx.clone(), entry.generation) + }; + + if let Err(e2) = retry_tx.send(failed_msg).await { + // Retry also failed — truly unexpected. Surface error. + { + // SAFETY: no .await while this guard is held. + let mut map = self.per_thread.lock().unwrap(); + Self::try_evict_locked(&mut map, &thread_key, retry_gen); + } + let failed_msg = e2.0; + let _ = adapter + .add_reaction( + &failed_msg.trigger_msg, + &self.target.reactions_config().emojis.error, + ) + .await; + let _ = adapter + .send_message( + &thread_channel, + &format!( + "⚠️ {}", + format_user_error("dispatch consumer exited unexpectedly") + ), + ) + .await; + return Err(DispatchError::ConsumerDead); + } + } + Ok(()) + } + + /// Drop all per-thread handles whose key belongs to `(platform, thread_id)`, + /// regardless of grouping, and abort each consumer (§2.5 / §4.4). Returns + /// the total number of buffered messages discarded across all lanes. + /// + /// Matches both Thread keys (`:`) and Lane keys + /// (`::`). Used by `/reset` and + /// `/cancel-all` to clear the entire thread, not just one lane. + /// + /// Disjoint from SendError recovery: removal happens *before* abort, so any + /// fresh `submit` after this returns lands on a lazily-constructed new handle + /// instead of observing `SendError`. 
+ pub fn cancel_buffered_thread(&self, platform: &str, thread_id: &str) -> usize { + let prefix = format!("{platform}:{thread_id}"); + let lane_prefix = format!("{prefix}:"); + // SAFETY: no .await while this guard is held — function is sync. + let mut map = self.per_thread.lock().unwrap(); + let keys: Vec = map + .keys() + .filter(|k| k.as_str() == prefix || k.starts_with(&lane_prefix)) + .cloned() + .collect(); + let mut dropped = 0; + for k in keys { + if let Some(handle) = map.remove(&k) { + dropped += handle.pending_count(); + handle.consumer.abort(); + } + } + dropped + } + + /// §2.5 race-safe eviction. Caller must hold the `per_thread` mutex. + /// Removes the entry only if its generation matches `my_generation` — + /// protects against evicting a fresh handle that another `submit` lazily + /// inserted between this caller's failed `tx.send` and this call. + /// Returns true if the entry was removed. + fn try_evict_locked( + map: &mut HashMap, + thread_key: &str, + my_generation: u64, + ) -> bool { + if let Some(handle) = map.get(thread_key) { + if handle.generation == my_generation { + map.remove(thread_key); + return true; + } + } + false + } + + /// Remove map entries whose consumer task has finished (idle timeout or + /// unexpected exit). Called periodically from the cleanup task in main.rs + /// to prevent unbounded map growth from one-shot thread keys that never + /// receive a second `submit()`. Returns the number of entries swept. + pub fn sweep_stale(&self) -> usize { + // SAFETY: no .await while this guard is held — function is sync. + let mut map = self.per_thread.lock().unwrap(); + let before = map.len(); + map.retain(|_, handle| !handle.consumer.is_finished()); + before - map.len() + } + + /// Log buffered-message counts and drop all handles (called on SIGTERM). + pub fn shutdown(&self) { + // SAFETY: no .await while this guard is held — function is sync. 
+ let mut map = self.per_thread.lock().unwrap(); + for (thread_id, handle) in map.iter() { + let pending = handle.pending_count(); + if pending > 0 { + warn!( + thread_id = %thread_id, + channel = %handle.channel_id, + adapter = %handle.adapter_kind, + buffered_lost = pending, + "shutdown dropped pending messages without dispatch", + ); + } + handle.consumer.abort(); + } + map.clear(); + } +} + +// --------------------------------------------------------------------------- +// consumer_loop +// --------------------------------------------------------------------------- + +#[allow(clippy::too_many_arguments)] +async fn consumer_loop( + thread_key: String, + thread_channel: ChannelRef, + mut rx: tokio::sync::mpsc::Receiver, + target: Arc, + adapter: Arc, + max_batch: usize, + max_tokens: usize, + idle_timeout: Duration, +) { + // `pending` holds a message that exceeded the token cap for the current batch; + // it becomes the first message of the next batch, preserving FIFO. + let mut pending: Option = None; + + loop { + // I1: block until at least one message arrives (zero latency for first message). + // Idle timeout: if no message arrives within `idle_timeout` the consumer + // exits, freeing the task and mpsc. The next `submit` for this thread_key + // will observe `SendError`, evict the stale entry, and lazily spawn a + // fresh consumer (§2.5 generation check prevents mis-eviction). + let first = match pending.take() { + Some(msg) => msg, + None => match tokio::time::timeout(idle_timeout, rx.recv()).await { + Ok(Some(msg)) => msg, + Ok(None) => { + // All senders dropped → shutdown() or cancel_buffered_thread(). + break; + } + Err(_elapsed) => { + debug!( + thread_key = %thread_key, + channel = %thread_channel.channel_id, + "consumer idle timeout, exiting" + ); + break; + } + }, + }; + + // Greedy drain up to max_batch messages or max_tokens. 
+ let mut batch = vec![first]; + let mut cumulative_tokens = batch[0].estimated_tokens; + + while batch.len() < max_batch { + match rx.try_recv() { + Ok(more) => { + if cumulative_tokens + more.estimated_tokens > max_tokens { + // Token cap — save for next turn (FIFO preserved). + pending = Some(more); + break; + } + cumulative_tokens += more.estimated_tokens; + batch.push(more); + } + Err(_) => break, + } + } + + // §2.6: read the freshest snapshot in the batch (batch is non-empty). + let bot_present = batch.last().unwrap().other_bot_present; + + dispatch_batch( + &thread_key, + &thread_channel, + &target, + &adapter, + batch, + bot_present, + ) + .await; + } +} + +// --------------------------------------------------------------------------- +// dispatch_batch +// --------------------------------------------------------------------------- + +async fn dispatch_batch( + thread_key: &str, + thread_channel: &ChannelRef, + target: &Arc, + adapter: &Arc, + batch: Vec, + other_bot_present: bool, +) { + let dispatch_start = Instant::now(); + let batch_size = batch.len(); + let session_key = Dispatcher::session_key(thread_channel); + + // Apply 👀 reaction to every message in the batch before dispatch (§6.7). + // Skip when assistant status API is active — uses + // assistant.threads.setStatus instead of emoji reactions. + let assistant_status = adapter.uses_assistant_status(); + if !assistant_status { + let queued_emoji = &target.reactions_config().emojis.queued; + for msg in batch.iter() { + let _ = adapter.add_reaction(&msg.trigger_msg, queued_emoji).await; + } + } + + // Collect per-event observability data (before consuming the batch). + let tokens_per_event: Vec = batch.iter().map(|m| m.estimated_tokens).collect(); + let wait_ms: Vec = batch + .iter() + .map(|m| m.arrived_at.elapsed().as_millis()) + .collect(); + let senders: Vec = batch.iter().map(|m| m.sender_name.clone()).collect(); + + // Native-streaming recipient is bound to the turn (captured per-message). 
A + // batch attributes to the most recent sender; None for non-Slack/bot turns. + let recipient: Option<(String, String)> = batch.last().and_then(|m| m.recipient.clone()); + + // Anchor reactions on the last message in the batch (before consuming). + let trigger_msg = batch.last().unwrap().trigger_msg.clone(); + let dispatch_channel = ChannelRef { + // Reply correlation is event-scoped, but the dispatcher consumer is + // thread-scoped. Rebuild the per-dispatch channel from the stable + // thread route plus the freshest event ID so gateway replies (e.g. + // LINE reply-token lookup) target the current inbound event. + origin_event_id: trigger_msg.channel.origin_event_id.clone(), + ..thread_channel.clone() + }; + + // Pack all arrival events into one Vec (§3.3). + // Uses into_iter() to avoid deep-copying extra_blocks (may contain base64 image data). + let mut content_blocks: Vec = Vec::new(); + + // Parse control directives from the first message in the batch (ADR: control-directives). + // Directives are only processed on the session's first message (§2.2). + // + // Strategy: + // 1. Parse directives (cheap text extraction — no mutation, no I/O) + // 2. Attempt workspace resolution if [[ws:...]] present (may fail gracefully) + // 3. Call ensure_session with resolved workspace — returns created_now + // 4. Only strip prompt and apply title/workspace if created_now == true + // 5. If created_now == false, the [[...]] text is preserved verbatim + let mut batch = batch; + let parse_result = batch + .first() + .map(|first_msg| crate::directives::parse_directives(&first_msg.prompt)); + + // Tentatively resolve [[ws:...]] — if resolution fails and the session turns out to + // be new, we abort. If the session already existed, resolution failure is irrelevant. 
+ let ws_resolved: Option> = parse_result.as_ref().and_then(|pr| { + pr.metadata.raw.get("ws").map(|ws_value| { + let aliases = target.workspace_aliases(); + let bot_home = target.bot_home(); + crate::directives::resolve_workspace(ws_value, &aliases, &bot_home) + .map(|p| p.display().to_string()) + }) + }); + + // Extract workspace path for ensure_session (None if no directive or resolution failed). + let workspace_override: Option = + ws_resolved.as_ref().and_then(|r| r.as_ref().ok().cloned()); + + // Ensure session exists. The create_gate mutex inside get_or_create serializes + // concurrent callers — only the winner gets created_now == true. + let created_now = match target + .ensure_session(&session_key, workspace_override.as_deref()) + .await + { + Ok(created) => created, + Err(e) => { + let user_msg = format_user_error(&e.to_string()); + let _ = adapter + .send_message(&dispatch_channel, &format!("⚠️ {user_msg}")) + .await; + error!("pool error in dispatch_batch: {e}"); + return; + } + }; + + // Only apply directives if this is genuinely the first message (fresh session). + if created_now { + if let Some(pr) = parse_result { + if !pr.metadata.raw.is_empty() { + // Apply [[title:...]] independently — works regardless of ws outcome. + let title_to_apply = pr.metadata.title.clone(); + + // If workspace resolution failed on a NEW session, rollback and abort. + // Reset FIRST to minimize TOCTOU window (擺渡 F1), then rename. + if let Some(Err(e)) = ws_resolved { + target.reset_session(&session_key).await; + // Apply title after reset so the thread is identifiable. 
+ if let Some(ref title) = title_to_apply { + if !title.is_empty() { + let _ = adapter.rename_thread(&dispatch_channel, title).await; + } + } + let _ = adapter + .send_message(&dispatch_channel, &format!("⚠️ {e}")) + .await; + error!(session_key, error = %e, "workspace directive rejected"); + return; + } + + // Strip directives from the prompt + if let Some(first_msg) = batch.first_mut() { + first_msg.prompt = pr.prompt; + } + + // Apply title on success path. + if let Some(ref title) = title_to_apply { + if !title.is_empty() { + if let Err(e) = adapter.rename_thread(&dispatch_channel, title).await { + warn!(session_key, error = %e, "failed to apply title directive"); + } + } + } + } + } + } + + for msg in batch { + let mut event_blocks = + AdapterRouter::pack_arrival_event(&msg.sender_json, &msg.prompt, msg.extra_blocks); + content_blocks.append(&mut event_blocks); + } + let packed_block_count = content_blocks.len(); + + let reactions_config = target.reactions_config().clone(); + let reactions = Arc::new(StatusReactionController::new( + reactions_config.enabled, + adapter.clone(), + trigger_msg, + reactions_config.emojis.clone(), + reactions_config.timing.clone(), + )); + // 👀 already applied above; skip set_queued() to avoid double-reaction. + + let result = target + .stream_prompt_blocks( + adapter, + &session_key, + content_blocks, + &dispatch_channel, + reactions.clone(), + other_bot_present, + recipient, + ) + .await; + + // In assistant status mode, all status is conveyed via + // assistant.threads.setStatus — skip emoji reactions entirely. 
+ if !assistant_status { + match &result { + Ok(()) => reactions.set_done().await, + Err(_) => reactions.set_error().await, + } + + let hold_ms = if result.is_ok() { + reactions_config.timing.done_hold_ms + } else { + reactions_config.timing.error_hold_ms + }; + if reactions_config.remove_after_reply { + let reactions = reactions; + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(hold_ms)).await; + reactions.clear().await; + }); + } + } + + if let Err(ref e) = result { + let _ = adapter + .send_message(&dispatch_channel, &format!("⚠️ {e}")) + .await; + } + + let agent_dispatch_ms = dispatch_start.elapsed().as_millis(); + let span = info_span!( + "dispatch", + channel = %thread_channel.channel_id, + adapter = adapter.platform(), + ); + let _enter = span.enter(); + info!( + thread_key = %thread_key, + events_per_dispatch = batch_size, + packed_block_count = packed_block_count, + agent_dispatch_ms = agent_dispatch_ms, + tokens_per_event = ?tokens_per_event, + wait_ms = ?wait_ms, + senders = ?senders, + "batch dispatched", + ); +} + +// --------------------------------------------------------------------------- +// Token estimation +// --------------------------------------------------------------------------- + +/// Rough char-to-token ratio for English-ish text. Coarse on purpose — the goal +/// is a guard rail for `max_batch_tokens`, not an exact pre-flight. +const CHARS_PER_TOKEN_ESTIMATE: usize = 4; +/// Conservative per-image token budget. Larger than typical Claude image cost +/// so the cap trips before we hand the model an oversized batch. +const TOKENS_PER_IMAGE_ESTIMATE: usize = 512; + +/// Rough token estimate for a buffered message (used for `max_batch_tokens` cap). +/// Intentionally coarse — the goal is a guard rail, not an exact pre-flight. 
+pub fn estimate_tokens(prompt: &str, extra_blocks: &[ContentBlock]) -> usize { + let text_tokens = prompt.len() / CHARS_PER_TOKEN_ESTIMATE + 1; + let block_tokens: usize = extra_blocks + .iter() + .map(|b| match b { + ContentBlock::Text { text } => text.len() / CHARS_PER_TOKEN_ESTIMATE + 1, + ContentBlock::Image { .. } => TOKENS_PER_IMAGE_ESTIMATE, + }) + .sum(); + text_tokens + block_tokens +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn estimate_tokens_empty() { + assert!(estimate_tokens("", &[]) >= 1); + } + + #[test] + fn estimate_tokens_text() { + // 400 chars ≈ 100 tokens + let s = "a".repeat(400); + assert_eq!(estimate_tokens(&s, &[]), 101); + } + + #[test] + fn estimate_tokens_image_block() { + let blocks = vec![ContentBlock::Image { + media_type: "image/png".into(), + data: "base64data".into(), + }]; + assert_eq!(estimate_tokens("", &blocks), 1 + 512); + } + + #[test] + fn pack_arrival_event_single() { + let blocks = + AdapterRouter::pack_arrival_event(r#"{"schema":"openab.sender.v1"}"#, "hello", vec![]); + // sender_context delimiter + prompt = 2 blocks + assert_eq!(blocks.len(), 2); + if let ContentBlock::Text { text } = &blocks[0] { + assert!(text.contains("")); + assert!(text.contains("")); + // Header is delimiter only — prompt lives in its own block. 
+ assert!(!text.contains("hello")); + } else { + panic!("expected Text delimiter block"); + } + if let ContentBlock::Text { text } = &blocks[1] { + assert_eq!(text, "hello"); + } else { + panic!("expected Text prompt block"); + } + } + + #[test] + fn pack_arrival_event_with_extra_blocks() { + let extra = vec![ + ContentBlock::Text { + text: "[Voice transcript]: hi".into(), + }, + ContentBlock::Image { + media_type: "image/png".into(), + data: "abc".into(), + }, + ]; + let blocks = AdapterRouter::pack_arrival_event("{}", "prompt", extra); + // delimiter + transcript + prompt + image = 4 blocks + assert_eq!(blocks.len(), 4); + assert!( + matches!(&blocks[0], ContentBlock::Text { text } if text.contains("")) + ); + assert!( + matches!(&blocks[1], ContentBlock::Text { text } if text.contains("Voice transcript")) + ); + assert!(matches!(&blocks[2], ContentBlock::Text { text } if text == "prompt")); + assert!(matches!(&blocks[3], ContentBlock::Image { .. })); + } + + #[test] + fn pack_arrival_event_batch_n2() { + // Two arrival events concatenated → 2 (header + prompt) pairs = 4 blocks. + let mut all: Vec = Vec::new(); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"ts":"T1"}"#, + "msg1", + vec![], + )); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"ts":"T2"}"#, + "msg2", + vec![], + )); + assert_eq!(all.len(), 4); + if let ContentBlock::Text { text } = &all[0] { + assert!(text.contains(r#""ts":"T1""#)); + assert!(!text.contains("msg1")); + } + if let ContentBlock::Text { text } = &all[1] { + assert_eq!(text, "msg1"); + } + if let ContentBlock::Text { text } = &all[2] { + assert!(text.contains(r#""ts":"T2""#)); + assert!(!text.contains("msg2")); + } + if let ContentBlock::Text { text } = &all[3] { + assert_eq!(text, "msg2"); + } + } + + // ADR §3.6 Scenario B — text in one message, image in the next, same author. 
+ // Broker preserves structural truth: image stays in M2 alone, both messages + // carry the same sender_id so the agent can semantically link them. + #[test] + fn pack_arrival_event_scenario_b_image_in_separate_message() { + let mut all: Vec = Vec::new(); + // M1 (alice): "see this image" + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T1"}"#, + "see this image", + vec![], + )); + // M2 (alice): image, no text + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T2"}"#, + "", + vec![ContentBlock::Image { + media_type: "image/png".into(), + data: "imgB".into(), + }], + )); + // header(M1) + prompt(M1) + header(M2) + image(M2) = 4 blocks + // (M2 has empty prompt, so its prompt block is omitted) + assert_eq!(all.len(), 4); + if let ContentBlock::Text { text } = &all[0] { + assert!(text.contains(r#""sender_id":"A""#)); + assert!(text.contains(r#""ts":"T1""#)); + } else { + panic!("expected Text delimiter for M1"); + } + if let ContentBlock::Text { text } = &all[1] { + assert_eq!(text, "see this image"); + } else { + panic!("expected Text prompt for M1"); + } + if let ContentBlock::Text { text } = &all[2] { + assert!(text.contains(r#""ts":"T2""#)); + } else { + panic!("expected Text delimiter for M2"); + } + // M2's image follows immediately after its delimiter (no prompt block). + assert!(matches!(&all[3], ContentBlock::Image { .. })); + } + + // ADR §3.6 Scenario C — fragmented multi-author batch. + // Repeated sender_id is preserved across non-adjacent messages; bob's interjection + // is kept as-is (no silent drop, no temporal reorder). 
+ #[test] + fn pack_arrival_event_scenario_c_multi_author_interleaved() { + let mut all: Vec = Vec::new(); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T1"}"#, + "see this image", + vec![], + )); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"B","ts":"T2"}"#, + "what?", + vec![], + )); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T3"}"#, + "", + vec![ContentBlock::Image { + media_type: "image/png".into(), + data: "imgC".into(), + }], + )); + // M1: header + prompt = 2 blocks + // M2: header + prompt = 2 blocks + // M3: header + image = 2 blocks (empty prompt → no prompt block) + // total = 6 + assert_eq!(all.len(), 6); + let h1 = match &all[0] { + ContentBlock::Text { text } => text, + _ => panic!("expected Text delimiter for M1"), + }; + let p1 = match &all[1] { + ContentBlock::Text { text } => text, + _ => panic!("expected Text prompt for M1"), + }; + let h2 = match &all[2] { + ContentBlock::Text { text } => text, + _ => panic!("expected Text delimiter for M2"), + }; + let p2 = match &all[3] { + ContentBlock::Text { text } => text, + _ => panic!("expected Text prompt for M2"), + }; + let h3 = match &all[4] { + ContentBlock::Text { text } => text, + _ => panic!("expected Text delimiter for M3"), + }; + assert!(h1.contains(r#""sender_id":"A""#) && h1.contains(r#""ts":"T1""#)); + assert_eq!(p1, "see this image"); + assert!(h2.contains(r#""sender_id":"B""#) && h2.contains(r#""ts":"T2""#)); + assert_eq!(p2, "what?"); + assert!(h3.contains(r#""sender_id":"A""#) && h3.contains(r#""ts":"T3""#)); + // M3's image attached to M3 only. + assert!(matches!(&all[5], ContentBlock::Image { .. })); + } + + // ADR §3.6 Scenario D — voice-only message in a batch. + // Within each arrival, transcript Text blocks precede the prompt block so the + // agent sees voice content before any typed text. The sender_context delimiter + // still opens each arrival. 
+ #[test] + fn pack_arrival_event_scenario_d_voice_only() { + let mut all: Vec = Vec::new(); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T1"}"#, + "look at this", + vec![ContentBlock::Image { + media_type: "image/png".into(), + data: "scr".into(), + }], + )); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T2"}"#, + "", + vec![ContentBlock::Text { + text: "[Voice message transcript]: hey can we sync about the deploy".into(), + }], + )); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"B","ts":"T3"}"#, + "what?", + vec![], + )); + // M1: header + prompt + image = 3 + // M2: header + transcript = 2 (empty prompt → no prompt block) + // M3: header + prompt = 2 + // total = 7 + assert_eq!(all.len(), 7); + if let ContentBlock::Text { text } = &all[0] { + assert!(text.contains(r#""ts":"T1""#)); + assert!(!text.contains("look at this")); + } + if let ContentBlock::Text { text } = &all[1] { + assert_eq!(text, "look at this"); + } + assert!(matches!(&all[2], ContentBlock::Image { .. })); + if let ContentBlock::Text { text } = &all[3] { + assert!(text.contains(r#""ts":"T2""#)); + } + // Transcript precedes prompt (and prompt is omitted here because empty). + if let ContentBlock::Text { text } = &all[4] { + assert!(text.contains("Voice message transcript")); + assert!(text.contains("sync about the deploy")); + } else { + panic!("expected transcript Text block after M2 delimiter"); + } + if let ContentBlock::Text { text } = &all[5] { + assert!(text.contains(r#""sender_id":"B""#)); + } + if let ContentBlock::Text { text } = &all[6] { + assert_eq!(text, "what?"); + } + } + + // Token-cap math: a single message that already exceeds max_batch_tokens still + // dispatches alone (the consumer_loop logic admits the first message before + // checking the cap). Verifies estimate_tokens scales with input length. 
+ #[test] + fn estimate_tokens_oversized_single_message() { + // ~24k token text (96000 chars / 4 chars-per-token). + let big = "x".repeat(96_000); + let est = estimate_tokens(&big, &[]); + assert!(est > 24_000, "expected >24k tokens, got {est}"); + } + + // Cumulative token math: two messages whose sum exceeds max_batch_tokens. + // The consumer_loop reads first, then peeks at the next; if cumulative tokens + // > cap, the second is held over to the next batch (FIFO preserved). + #[test] + fn estimate_tokens_cumulative_exceeds_cap() { + let max_tokens = 24_000_usize; + let m1 = estimate_tokens(&"a".repeat(80_000), &[]); + let m2 = estimate_tokens(&"b".repeat(50_000), &[]); + assert!(m1 < max_tokens); + assert!(m1 + m2 > max_tokens, "{m1} + {m2} should exceed cap"); + } + + // ADR §2.5 race-safe eviction. The full SendError path requires a real + // AdapterRouter (concrete struct, not a trait — no easy mock seam), so we + // unit-test the eviction predicate in isolation. End-to-end consumer-death + // recovery is exercised by the manual staging smoke documented in the ADR. + fn dummy_handle(generation: u64) -> ThreadHandle { + let (tx, _rx) = tokio::sync::mpsc::channel::(1); + let consumer = tokio::spawn(async {}); + ThreadHandle { + tx, + consumer, + generation, + channel_id: "C".into(), + adapter_kind: "discord".into(), + } + } + + #[tokio::test] + async fn try_evict_locked_removes_when_generation_matches() { + let mut map: HashMap = HashMap::new(); + map.insert("t".into(), dummy_handle(7)); + assert!(Dispatcher::try_evict_locked(&mut map, "t", 7)); + assert!(map.is_empty()); + } + + // The bug §2.5 prevents: a stale producer (my_gen=7) observing SendError + // must not remove a freshly inserted handle (gen=8) created by another + // submit between the failed send and the eviction attempt. 
+ #[tokio::test] + async fn try_evict_locked_keeps_when_generation_differs() { + let mut map: HashMap = HashMap::new(); + map.insert("t".into(), dummy_handle(8)); + assert!(!Dispatcher::try_evict_locked(&mut map, "t", 7)); + assert_eq!(map.len(), 1); + assert_eq!(map.get("t").unwrap().generation, 8); + } + + #[tokio::test] + async fn try_evict_locked_returns_false_when_absent() { + let mut map: HashMap = HashMap::new(); + assert!(!Dispatcher::try_evict_locked(&mut map, "missing", 0)); + } + + // BatchGrouping → thread_key shape. + fn make_dispatcher(grouping: BatchGrouping) -> Dispatcher { + // The router is wrapped in Arc but never used by `key()` itself; we use + // a dummy AdapterRouter built via the same path main.rs would use. + // For a pure-keying test we'd ideally not need it, but the constructor demands one. + // Construct a minimal router via the public test helpers in adapter.rs if available; + // otherwise we fall back to building one with a dummy SessionPool. + use crate::acp::SessionPool; + let agent_cfg = crate::config::AgentConfig { + command: "/bin/true".into(), + args: vec![], + working_dir: "/tmp".into(), + env: std::collections::HashMap::new(), + inherit_env: vec![], + command_explicit: true, + }; + let pool = Arc::new(SessionPool::new(agent_cfg, 1)); + let router = Arc::new(AdapterRouter::new( + pool, + crate::config::ReactionsConfig::default(), + crate::markdown::TableMode::Off, + crate::config::default_prompt_hard_timeout_secs(), + crate::config::default_liveness_check_secs(), + std::collections::HashMap::new(), + std::path::PathBuf::from("/tmp"), + )); + Dispatcher::with_idle_timeout(router, 10, 24_000, grouping, DEFAULT_CONSUMER_IDLE_TIMEOUT) + } + + #[tokio::test] + async fn key_per_thread_ignores_sender() { + let d = make_dispatcher(BatchGrouping::Thread); + assert_eq!(d.key("discord", "T1", "userA"), "discord:T1"); + assert_eq!(d.key("discord", "T1", "userB"), "discord:T1"); + } + + #[tokio::test] + async fn 
key_per_lane_includes_sender() { + let d = make_dispatcher(BatchGrouping::Lane); + assert_eq!(d.key("discord", "T1", "userA"), "discord:T1:userA"); + assert_eq!(d.key("discord", "T1", "userB"), "discord:T1:userB"); + // Different threads remain distinct. + assert_eq!(d.key("slack", "T2", "userA"), "slack:T2:userA"); + } + + fn insert_dummy_handle(d: &Dispatcher, key: &str) { + let (tx, _rx) = tokio::sync::mpsc::channel::(10); + let consumer = tokio::spawn(async {}); + let handle = ThreadHandle { + tx, + consumer, + generation: 0, + channel_id: "c".into(), + adapter_kind: "discord".into(), + }; + d.per_thread.lock().unwrap().insert(key.to_string(), handle); + } + + #[tokio::test] + async fn cancel_buffered_thread_drops_per_thread_key() { + let d = make_dispatcher(BatchGrouping::Thread); + insert_dummy_handle(&d, "discord:T1"); + insert_dummy_handle(&d, "discord:T2"); // different thread, must survive + assert_eq!(d.cancel_buffered_thread("discord", "T1"), 0); // no buffered msgs + let map = d.per_thread.lock().unwrap(); + assert!(!map.contains_key("discord:T1")); + assert!(map.contains_key("discord:T2")); + } + + #[tokio::test] + async fn cancel_buffered_thread_drops_all_lanes() { + let d = make_dispatcher(BatchGrouping::Lane); + insert_dummy_handle(&d, "discord:T1:userA"); + insert_dummy_handle(&d, "discord:T1:userB"); + insert_dummy_handle(&d, "discord:T2:userA"); // different thread + insert_dummy_handle(&d, "slack:T1:userA"); // different platform + d.cancel_buffered_thread("discord", "T1"); + let map = d.per_thread.lock().unwrap(); + assert!(!map.contains_key("discord:T1:userA")); + assert!(!map.contains_key("discord:T1:userB")); + assert!(map.contains_key("discord:T2:userA")); + assert!(map.contains_key("slack:T1:userA")); + } + + #[tokio::test] + async fn cancel_buffered_thread_does_not_match_thread_id_prefix() { + // T1 must not match T10 / T11 (substring trap). 
+ let d = make_dispatcher(BatchGrouping::Lane); + insert_dummy_handle(&d, "discord:T1:userA"); + insert_dummy_handle(&d, "discord:T10:userA"); + d.cancel_buffered_thread("discord", "T1"); + let map = d.per_thread.lock().unwrap(); + assert!(!map.contains_key("discord:T1:userA")); + assert!(map.contains_key("discord:T10:userA")); + } + + // Long-running consumer that parks until aborted — used by sweep_stale / + // shutdown tests to exercise the "still alive" path. + fn alive_consumer_handle() -> ThreadHandle { + let (tx, _rx) = tokio::sync::mpsc::channel::(10); + let consumer = tokio::spawn(async { + std::future::pending::<()>().await; + }); + ThreadHandle { + tx, + consumer, + generation: 0, + channel_id: "c".into(), + adapter_kind: "discord".into(), + } + } + + #[tokio::test] + async fn sweep_stale_removes_finished_consumers() { + let d = make_dispatcher(BatchGrouping::Thread); + insert_dummy_handle(&d, "discord:T1"); + insert_dummy_handle(&d, "discord:T2"); + // Yield so the empty-body spawned tasks actually run to completion + // before is_finished() is checked. + tokio::time::sleep(Duration::from_millis(10)).await; + let swept = d.sweep_stale(); + assert_eq!(swept, 2); + assert!(d.per_thread.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn sweep_stale_keeps_running_consumers() { + let d = make_dispatcher(BatchGrouping::Thread); + let abort = { + let h = alive_consumer_handle(); + let a = h.consumer.abort_handle(); + d.per_thread.lock().unwrap().insert("alive".into(), h); + a + }; + let swept = d.sweep_stale(); + assert_eq!(swept, 0); + assert!(d.per_thread.lock().unwrap().contains_key("alive")); + // Cleanup so the parked task doesn't linger across tests. 
+ abort.abort(); + } + + #[tokio::test] + async fn shutdown_clears_all_handles() { + let d = make_dispatcher(BatchGrouping::Thread); + insert_dummy_handle(&d, "k1"); + insert_dummy_handle(&d, "k2"); + insert_dummy_handle(&d, "k3"); + d.shutdown(); + assert!(d.per_thread.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn shutdown_aborts_running_consumers() { + let d = make_dispatcher(BatchGrouping::Thread); + let abort = { + let h = alive_consumer_handle(); + let a = h.consumer.abort_handle(); + d.per_thread.lock().unwrap().insert("k".into(), h); + a + }; + d.shutdown(); + // Give the runtime a tick to process abort + map drop. + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(abort.is_finished()); + } + + // ----------------------------------------------------------------------- + // consumer_loop / dispatch_batch integration tests (NIT 2) + // + // These drive `consumer_loop` directly with a pre-populated mpsc, using + // `MockDispatchTarget` to record the calls that would otherwise hit a + // real `AdapterRouter` (and through it, ACP CLI subprocesses). This + // gives deterministic coverage of the orchestration paths the existing + // unit tests don't reach: greedy drain, token-cap overflow, idle timeout. + // ----------------------------------------------------------------------- + + /// One recorded `stream_prompt_blocks` invocation. + #[derive(Clone)] + struct RecordedDispatch { + block_count: usize, + other_bot_present: bool, + dispatch_channel: ChannelRef, + } + + /// Mock `DispatchTarget` — records calls; never touches a real session pool. + struct MockDispatchTarget { + reactions: ReactionsConfig, + calls: Mutex>, + /// If set, `ensure_session` returns this error once. + ensure_err: Mutex>, + /// If set, `stream_prompt_blocks` returns this error once. 
+ stream_err: Mutex>, + } + + impl MockDispatchTarget { + fn new() -> Self { + Self { + reactions: ReactionsConfig::default(), + calls: Mutex::new(Vec::new()), + ensure_err: Mutex::new(None), + stream_err: Mutex::new(None), + } + } + + fn calls(&self) -> Vec { + self.calls.lock().unwrap().clone() + } + } + + #[async_trait] + impl DispatchTarget for MockDispatchTarget { + fn reactions_config(&self) -> &ReactionsConfig { + &self.reactions + } + + fn workspace_aliases(&self) -> std::collections::HashMap { + std::collections::HashMap::new() + } + + fn bot_home(&self) -> std::path::PathBuf { + std::path::PathBuf::from("/tmp") + } + + async fn ensure_session( + &self, + _session_key: &str, + _working_dir: Option<&str>, + ) -> Result { + if let Some(msg) = self.ensure_err.lock().unwrap().take() { + return Err(anyhow::anyhow!(msg)); + } + Ok(true) + } + + async fn reset_session(&self, _session_key: &str) {} + + async fn stream_prompt_blocks( + &self, + _adapter: &Arc, + _session_key: &str, + content_blocks: Vec, + thread_channel: &ChannelRef, + _reactions: Arc, + other_bot_present: bool, + _recipient: Option<(String, String)>, + ) -> Result<()> { + self.calls.lock().unwrap().push(RecordedDispatch { + block_count: content_blocks.len(), + other_bot_present, + dispatch_channel: thread_channel.clone(), + }); + if let Some(msg) = self.stream_err.lock().unwrap().take() { + return Err(anyhow::anyhow!(msg)); + } + Ok(()) + } + } + + /// Mock `ChatAdapter` — every method is a no-op success. The dispatch loop + /// invokes `add_reaction` (queued 👀), `platform`, and on the error path + /// `send_message`; nothing else needs real behavior here. 
+ struct MockChatAdapter; + + #[async_trait] + impl ChatAdapter for MockChatAdapter { + fn platform(&self) -> &'static str { + "mock" + } + fn message_limit(&self) -> usize { + 2000 + } + + async fn send_message(&self, channel: &ChannelRef, _content: &str) -> Result { + Ok(MessageRef { + channel: channel.clone(), + message_id: "mock-msg".into(), + }) + } + + async fn create_thread( + &self, + channel: &ChannelRef, + _trigger_msg: &MessageRef, + _title: &str, + ) -> Result { + Ok(channel.clone()) + } + + async fn add_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { + Ok(()) + } + async fn remove_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { + Ok(()) + } + fn use_streaming(&self, _other_bot_present: bool) -> bool { + false + } + } + + fn make_channel(thread: &str) -> ChannelRef { + ChannelRef { + platform: "mock".into(), + channel_id: thread.into(), + thread_id: Some(thread.into()), + parent_id: None, + origin_event_id: None, + } + } + + fn make_msg(prompt: &str, tokens: usize) -> BufferedMessage { + BufferedMessage { + sender_json: r#"{"schema":"openab.sender.v1","sender_id":"u","sender_name":"u"}"# + .into(), + sender_name: "u".into(), + prompt: prompt.into(), + extra_blocks: vec![], + trigger_msg: MessageRef { + channel: make_channel("T"), + message_id: format!("m-{prompt}"), + }, + arrived_at: Instant::now(), + estimated_tokens: tokens, + other_bot_present: false, + recipient: None, + } + } + + /// Pre-load `msgs` into a fresh mpsc, drop the sender, and run + /// `consumer_loop` to completion. Returns the recorded dispatches. 
+ async fn run_consumer_with_messages( + msgs: Vec, + max_batch: usize, + max_tokens: usize, + ) -> Vec { + let mock = Arc::new(MockDispatchTarget::new()); + let target: Arc = mock.clone(); + let adapter: Arc = Arc::new(MockChatAdapter); + let (tx, rx) = tokio::sync::mpsc::channel::(msgs.len().max(1)); + for m in msgs { + tx.send(m).await.unwrap(); + } + drop(tx); + + consumer_loop( + "mock:T".into(), + make_channel("T"), + rx, + target, + adapter, + max_batch, + max_tokens, + Duration::from_secs(60), + ) + .await; + + mock.calls() + } + + #[tokio::test] + async fn consumer_dispatches_single_message_as_one_batch() { + let calls = run_consumer_with_messages(vec![make_msg("hi", 10)], 10, 24_000).await; + assert_eq!(calls.len(), 1); + // pack_arrival_event with no extra_blocks → delimiter + prompt = 2 blocks. + assert_eq!(calls[0].block_count, 2); + assert!(!calls[0].other_bot_present); + } + + #[tokio::test] + async fn consumer_greedy_drain_combines_queued_messages_into_one_batch() { + // 3 messages already in the queue when the consumer wakes → greedy + // drain pulls all 3, packs them into one batch, dispatches once. + let calls = run_consumer_with_messages( + vec![make_msg("a", 50), make_msg("b", 50), make_msg("c", 50)], + 10, + 24_000, + ) + .await; + assert_eq!(calls.len(), 1, "expected a single batched dispatch"); + // 3 arrivals × (delimiter + prompt) = 6 blocks. + assert_eq!(calls[0].block_count, 6); + } + + #[tokio::test] + async fn consumer_token_cap_splits_batch_preserving_fifo() { + // max_tokens=100, two 80-token messages → cumulative 160 > 100, so + // msg2 becomes `pending` and is dispatched in the next batch. + let calls = + run_consumer_with_messages(vec![make_msg("a", 80), make_msg("b", 80)], 10, 100).await; + assert_eq!(calls.len(), 2, "token cap should split into two batches"); + // Each batch holds one arrival → delimiter + prompt = 2 blocks. 
+ assert_eq!(calls[0].block_count, 2); + assert_eq!(calls[1].block_count, 2); + } + + #[tokio::test] + async fn consumer_dispatch_uses_last_event_origin_event_id_for_merged_batch() { + let mut first = make_msg("a", 80); + first.trigger_msg.channel.origin_event_id = Some("evt-first".into()); + let mut second = make_msg("b", 80); + second.trigger_msg.channel.origin_event_id = Some("evt-second".into()); + + let calls = run_consumer_with_messages(vec![first, second], 10, 200).await; + assert_eq!(calls.len(), 1); + assert_eq!( + calls[0].dispatch_channel.origin_event_id.as_deref(), + Some("evt-second") + ); + } + + #[tokio::test] + async fn consumer_dispatch_preserves_thread_route_while_refreshing_origin_event_id() { + let mock = Arc::new(MockDispatchTarget::new()); + let target: Arc = mock.clone(); + let adapter: Arc = Arc::new(MockChatAdapter); + let (tx, rx) = tokio::sync::mpsc::channel::(1); + + let mut msg = make_msg("hi", 10); + msg.trigger_msg.channel = ChannelRef { + platform: "mock".into(), + channel_id: "parent-channel".into(), + thread_id: None, + parent_id: None, + origin_event_id: Some("evt-fresh".into()), + }; + tx.send(msg).await.unwrap(); + drop(tx); + + consumer_loop( + "mock:topic-42".into(), + ChannelRef { + platform: "mock".into(), + channel_id: "topic-42".into(), + thread_id: Some("topic-42".into()), + parent_id: Some("parent-channel".into()), + origin_event_id: Some("evt-stale".into()), + }, + rx, + target, + adapter, + 10, + 24_000, + Duration::from_secs(60), + ) + .await; + + let calls = mock.calls(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].dispatch_channel.channel_id, "topic-42"); + assert_eq!( + calls[0].dispatch_channel.thread_id.as_deref(), + Some("topic-42") + ); + assert_eq!( + calls[0].dispatch_channel.parent_id.as_deref(), + Some("parent-channel") + ); + assert_eq!( + calls[0].dispatch_channel.origin_event_id.as_deref(), + Some("evt-fresh") + ); + } + + #[tokio::test] + async fn 
consumer_exits_after_idle_timeout_with_no_messages() { + // No messages ever arrive; consumer should exit once `idle_timeout` + // elapses. Keep `tx` alive so the exit path is the timeout, not the + // "all senders dropped" branch. + let mock = Arc::new(MockDispatchTarget::new()); + let target: Arc = mock.clone(); + let adapter: Arc = Arc::new(MockChatAdapter); + let (tx, rx) = tokio::sync::mpsc::channel::(1); + let consumer = tokio::spawn(consumer_loop( + "mock:T".into(), + make_channel("T"), + rx, + target, + adapter, + 10, + 24_000, + Duration::from_millis(50), + )); + // Wait enough for the timeout branch + a tick for the task to finish. + tokio::time::sleep(Duration::from_millis(150)).await; + assert!( + consumer.is_finished(), + "consumer should exit after idle timeout" + ); + // No dispatches should have been recorded. + assert!(mock.calls().is_empty()); + drop(tx); + } + + #[tokio::test] + async fn submit_evicts_dead_handle_and_retries_with_fresh_consumer() { + // §2.5: if `tx.send()` returns `SendError` (consumer's rx dropped + // mid-flight), `submit` evicts the stale entry under lock and spawns + // a fresh consumer. Manufacture this state by inserting a handle + // whose consumer is still parked but whose rx has been dropped. 
+ let mock = Arc::new(MockDispatchTarget::new()); + let target: Arc = mock.clone(); + let d = Dispatcher::with_idle_timeout( + target, + 10, + 24_000, + BatchGrouping::Thread, + DEFAULT_CONSUMER_IDLE_TIMEOUT, + ); + let adapter: Arc = Arc::new(MockChatAdapter); + + let key = "mock:T".to_string(); + let parked = { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + drop(rx); // closes the channel → next tx.send() yields SendError + let consumer = tokio::spawn(std::future::pending::<()>()); + let abort = consumer.abort_handle(); + let handle = ThreadHandle { + tx, + consumer, + generation: 999, + channel_id: "T".into(), + adapter_kind: "mock".into(), + }; + d.per_thread.lock().unwrap().insert(key.clone(), handle); + abort + }; + + d.submit(key, make_channel("T"), adapter, make_msg("hello", 10)) + .await + .expect("retry should spawn a fresh consumer"); + // Give the freshly spawned consumer time to drain + dispatch. + tokio::time::sleep(Duration::from_millis(50)).await; + + let calls = mock.calls(); + assert_eq!( + calls.len(), + 1, + "fresh consumer should have dispatched the retry" + ); + // pack_arrival_event with no extra_blocks → delimiter + prompt = 2 blocks. + assert_eq!(calls[0].block_count, 2); + + parked.abort(); + } +} diff --git a/crates/openab-core/src/error_display.rs b/crates/openab-core/src/error_display.rs new file mode 100644 index 000000000..c8826dcbf --- /dev/null +++ b/crates/openab-core/src/error_display.rs @@ -0,0 +1,323 @@ +/// Format any error for user display in Discord. +/// +/// Handles two error categories: +/// - **Coded errors** (code != 0): JSON-RPC or HTTP status codes from upstream agent. +/// - **Startup/connection errors** (code == 0): Errors from pool.rs or connection.rs +/// where only the message string is available. +/// +/// Provider-agnostic: no provider-specific strings, message text passed through verbatim. 
+pub fn format_user_error(message: &str) -> String { + let msg_lower = message.to_lowercase(); + + // Startup / connection errors (code == 0 from anyhow) + if msg_lower.contains("timeout waiting for") { + // Use msg_lower for extraction to stay case-insistent with the match above. + // msg_lower and message are the same length, so byte offsets are valid. + if let Some(start) = msg_lower.find("timeout waiting for ") { + let rest = &message[start + "timeout waiting for ".len()..]; + let method = rest.split_whitespace().next().unwrap_or("request"); + return format!( + "**Request Timeout**\nTimeout waiting for {}, please try again.", + method + ); + } + return "**Request Timeout**\nTimeout waiting for a response, please try again." + .to_string(); + } + if msg_lower.contains("connection closed") || msg_lower.contains("channel closed") { + return "**Connection Lost**\nThe connection to the agent was lost, please try again." + .to_string(); + } + if msg_lower.contains("failed to spawn") || msg_lower.contains("no such file") { + return "**Agent Not Found**\nCould not start the agent — please check your configuration." + .to_string(); + } + if msg_lower.contains("pool exhausted") { + return "**Service Busy**\nAll agent sessions are in use, please try again shortly." + .to_string(); + } + if msg_lower.contains("invalid api key") || msg_lower.contains("unauthorized") { + return "**Unauthorized**\nPlease check your API key configuration.".to_string(); + } + + // Unknown error — pass through as-is + if message.is_empty() { + "**Error**\nAn unknown error occurred.".to_string() + } else { + format!("**Error**\n{}", message) + } +} + +/// Format coded error from ACP agent for display in Discord. +/// Used for response errors that have a JSON-RPC or HTTP status code. +/// `data_message` is the optional detail extracted from `error.data.message`. +/// Public for reuse by other adapters (e.g. Slack). 
+pub fn format_coded_error(code: i64, message: &str, data_message: Option<&str>) -> String { + let prefix = match code { + 400 => "**Bad Request**", + 401 => "**Unauthorized**", + 403 => "**Forbidden**", + 404 => "**Not Found**", + 408 => "**Request Timeout**", + 429 => "**Rate Limited**", + 500 => "**Internal Server Error**", + 502 => "**Bad Gateway**", + 503 => "**Service Unavailable**", + 504 => "**Gateway Timeout**", + -32600 => "**Invalid Request**", + -32601 => "**Method Not Found**", + -32602 => "**Invalid Params**", + -32603 => "**Internal Error**", + -32099..=-32000 => "**Server Error**", + _ => "**Error**", + }; + let mut out = if message.is_empty() { + format!("{} (code: {})", prefix, code) + } else { + format!("{} (code: {})\n{}", prefix, code, message) + }; + let detail = data_message.filter(|s| !s.trim().is_empty()); + if let Some(detail) = detail { + if !message.contains(detail) { + out.push_str("\n> "); + out.push_str(detail); + } + } else if code == -32603 { + out.push_str( + "\n\n_The agent did not return any error details. 
\ + Please check the agent's own logs for more information._", + ); + } + out +} + +#[cfg(test)] +mod tests { + use super::*; + + // ─── format_user_error tests ───────────────────────────────────────────── + + #[test] + fn format_user_error_timeout() { + let result = format_user_error("timeout waiting for session/new response"); + assert!(result.contains("Request Timeout")); + assert!(result.contains("session/new")); + } + + #[test] + fn format_user_error_connection_closed() { + let result = format_user_error("connection closed"); + assert!(result.contains("Connection Lost")); + } + + #[test] + fn format_user_error_channel_closed() { + let result = format_user_error("channel closed"); + assert!(result.contains("Connection Lost")); + } + + #[test] + fn format_user_error_failed_to_spawn() { + let result = format_user_error("failed to spawn /some/path: No such file"); + assert!(result.contains("Agent Not Found")); + assert!(result.contains("the agent")); // generic, no provider name + } + + #[test] + fn format_user_error_no_such_file() { + let result = format_user_error("binary /usr/bin/nonexistent: no such file"); + assert!(result.contains("Agent Not Found")); + } + + #[test] + fn format_user_error_pool_exhausted() { + let result = format_user_error("pool exhausted (5 sessions)"); + assert!(result.contains("Service Busy")); + } + + #[test] + fn format_user_error_invalid_api_key() { + let result = format_user_error("invalid api key"); + assert!(result.contains("Unauthorized")); + } + + #[test] + fn format_user_error_unauthorized() { + let result = format_user_error("unauthorized: token rejected"); + assert!(result.contains("Unauthorized")); + } + + #[test] + fn format_user_error_unknown() { + let result = format_user_error("something went wrong"); + assert!(result.contains("Error")); + assert!(result.contains("something went wrong")); + } + + #[test] + fn format_user_error_empty() { + let result = format_user_error(""); + assert!(result.contains("Error")); + 
assert!(result.contains("unknown")); + } + + #[test] + fn format_user_error_case_insensitive() { + assert!(format_user_error("TIMEOUT WAITING FOR foo").contains("Timeout")); + assert!(format_user_error("CONNECTION CLOSED").contains("Connection")); + assert!(format_user_error("POOL EXHAUSTED").contains("Busy")); + } + + #[test] + fn format_user_error_mixed_case_timeout() { + // Case-insensitive matching should still extract method correctly + let result = format_user_error("Timeout Waiting For custom/method"); + assert!(result.contains("Request Timeout")); + assert!(result.contains("custom/method")); + } + + // ─── format_coded_error tests ─────────────────────────────────────────── + + #[test] + fn format_coded_error_401() { + let result = format_coded_error(401, "invalid token", None); + assert!(result.contains("Unauthorized")); + assert!(result.contains("401")); + assert!(result.contains("invalid token")); + } + + #[test] + fn format_coded_error_429() { + let result = format_coded_error(429, "", None); + assert!(result.contains("Rate Limited")); + assert!(result.contains("429")); + assert!(!result.contains("\n")); // no message, no newline + } + + #[test] + fn format_coded_error_503() { + let result = format_coded_error(503, "service unavailable", None); + assert!(result.contains("Service Unavailable")); + assert!(result.contains("503")); + assert!(result.contains("service unavailable")); + } + + #[test] + fn format_coded_error_json_rpc() { + let result = format_coded_error(-32602, "missing required parameter", None); + assert!(result.contains("Invalid Params")); + assert!(result.contains("-32602")); + } + + #[test] + fn format_coded_error_server_error_range() { + let result = format_coded_error(-32050, "internal failure", None); + assert!(result.contains("Server Error")); + assert!(result.contains("-32050")); + } + + #[test] + fn format_coded_error_connection_error() { + let result = format_coded_error(-32000, "connection refused", None); + 
assert!(result.contains("Server Error")); // -32000 falls in -32099..=-32000 range + assert!(result.contains("-32000")); + } + + #[test] + fn format_coded_error_unknown_code() { + let result = format_coded_error(999, "something happened", None); + assert!(result.contains("Error")); + assert!(result.contains("999")); + assert!(result.contains("something happened")); + } + + #[test] + fn format_coded_error_with_data_message() { + let result = format_coded_error(-32603, "Internal error", Some("model not supported")); + assert!(result.contains("Internal Error")); + assert!(result.contains("model not supported")); + } + + #[test] + fn format_coded_error_data_message_not_duplicated() { + // If data_message is already in message, don't repeat it + let result = format_coded_error(-32603, "model not supported", Some("model not supported")); + assert_eq!(result.matches("model not supported").count(), 1); + } + + #[test] + fn format_coded_error_32603_no_detail_shows_fallback() { + let result = format_coded_error(-32603, "Internal error", None); + assert!(result.contains("Internal Error")); + assert!(result.contains("did not return any error details")); + assert!(result.contains("agent's own logs")); + } + + #[test] + fn format_coded_error_32603_with_detail_no_fallback() { + let result = format_coded_error(-32603, "Internal error", Some("model not found")); + assert!(result.contains("model not found")); + assert!(!result.contains("did not return any error details")); + } + + #[test] + fn format_coded_error_32603_empty_detail_shows_fallback() { + let result = format_coded_error(-32603, "Internal error", Some("")); + assert!(result.contains("did not return any error details")); + } + + #[test] + fn format_coded_error_other_code_no_detail_no_fallback() { + // Fallback only applies to -32603 + let result = format_coded_error(-32602, "bad params", None); + assert!(!result.contains("did not return any error details")); + } + + #[test] + fn 
format_coded_error_32603_empty_message_still_shows_fallback() { + // Even when message is empty, fallback should appear + let result = format_coded_error(-32603, "", None); + assert!(result.contains("Internal Error")); + assert!(result.contains("did not return any error details")); + } + + #[test] + fn format_coded_error_32603_whitespace_detail_shows_fallback() { + // Whitespace-only detail should be treated as empty + let result = format_coded_error(-32603, "Internal error", Some(" ")); + assert!(result.contains("Internal Error")); + assert!(result.contains("did not return any error details")); + } + + #[test] + fn format_coded_error_500_no_detail_no_fallback() { + // HTTP 500 without detail should NOT get the ACP-specific hint + let result = format_coded_error(500, "server error", None); + assert!(result.contains("Internal Server Error")); + assert!(!result.contains("did not return any error details")); + } + + #[test] + fn format_coded_error_32603_fallback_does_not_duplicate_with_detail() { + // When detail is present, no fallback appears — mutually exclusive + let result = format_coded_error(-32603, "Internal error", Some("rate limit exceeded")); + assert!(result.contains("rate limit exceeded")); + assert!(!result.contains("did not return any error details")); + assert!(!result.contains("agent's own logs")); + } + + #[test] + fn format_coded_error_server_error_range_no_fallback() { + // Other JSON-RPC server error codes should NOT get the hint + let result = format_coded_error(-32099, "custom error", None); + assert!(!result.contains("did not return any error details")); + } + + #[test] + fn format_coded_error_32603_fallback_message_is_italic() { + // Verify Discord markdown italic formatting + let result = format_coded_error(-32603, "Internal error", None); + assert!(result.contains("_The agent did not return")); + assert!(result.ends_with("_")); + } +} diff --git a/crates/openab-core/src/format.rs b/crates/openab-core/src/format.rs new file mode 100644 index 
// (patch metadata, kept verbatim) 000000000..d39410f15 --- /dev/null +++ b/crates/openab-core/src/format.rs @@ -0,0 +1,327 @@

/// Split `text` into chunks at line boundaries, each at most `limit` Unicode
/// characters (UTF-8 safe). Discord's message limit counts Unicode characters,
/// not bytes.
///
/// Fenced code blocks (``` ... ```) are handled specially: if a split falls
/// inside a code block, the current chunk is closed with ``` and the next chunk
/// is reopened with the original opener (preserving the language tag), so each
/// chunk renders correctly.
///
/// A line that *opens* a fence is only appended to the current chunk when the
/// chunk still has room for the synthetic closing marker afterwards; otherwise
/// the opener starts the next chunk.
///
/// Intended invariant: every returned chunk satisfies
/// `chunk.chars().count() <= limit` (degenerate limits that cannot even hold a
/// fence opener plus its closing marker are not covered).
///
/// Returns a single-element vector containing `text` unchanged when it already
/// fits within `limit`.
pub fn split_message(text: &str, limit: usize) -> Vec<String> {
    if text.chars().count() <= limit {
        return vec![text.to_string()];
    }

    let mut chunks = Vec::new();
    let mut current = String::new();
    let mut current_len: usize = 0;
    // When inside a fenced code block, holds the full opener line (e.g. "```rust").
    let mut fence_opener: Option<String> = None;

    // Cost of appending "\n```" to close a fence before emitting a chunk.
    const CLOSE_COST: usize = 4; // '\n' + '`' + '`' + '`'

    for line in text.split('\n') {
        let line_chars = line.chars().count();
        let is_fence_line = line.starts_with("```");

        // Room for the closing marker must be kept free both for content lines
        // inside a fence and for the line that opens a fence: once the opener
        // is in the chunk, the chunk can only be emitted with "\n```" appended.
        // A closing fence line needs no reserve because it ends the fence.
        let opens_fence = is_fence_line && fence_opener.is_none();
        let in_fence_body = fence_opener.is_some() && !is_fence_line;
        let close_reserve = if opens_fence || in_fence_body {
            CLOSE_COST
        } else {
            0
        };

        // Check whether appending this line (+ newline separator + close reserve) overflows.
        if !current.is_empty() && current_len + 1 + line_chars + close_reserve > limit {
            // Emit current chunk, closing fence if needed.
            if let Some(ref opener) = fence_opener {
                if !is_fence_line {
                    current.push_str("\n```");
                }
                chunks.push(std::mem::take(&mut current));
                // Reopen fence in next chunk with full opener (preserves language tag).
                current.push_str(opener);
                current_len = opener.chars().count();

                if is_fence_line {
                    // The closing fence marker itself triggers the split.
                    fence_opener = None;
                    current.push('\n');
                    current_len += 1;
                    current.push_str(line);
                    current_len += line_chars;
                    continue;
                } else if current_len + 1 + line_chars + CLOSE_COST <= limit {
                    // Line fits in the reopened chunk (with room for \n + line + close marker).
                    current.push('\n');
                    current_len += 1;
                    current.push_str(line);
                    current_len += line_chars;
                    continue;
                }
                // Otherwise: line doesn't fit even in a fresh reopened chunk.
                // Fall through to the normal line-processing logic below,
                // which will hit the hard-split path.
            } else {
                // Not inside a fence (this also covers an opener that would not
                // leave room for its close marker): start a fresh chunk.
                chunks.push(std::mem::take(&mut current));
                current_len = 0;
            }
        }

        // Newline separator between lines within a chunk.
        if !current.is_empty() {
            current.push('\n');
            current_len += 1;
        }

        // Track fence state.
        if is_fence_line {
            if fence_opener.is_some() {
                fence_opener = None;
            } else {
                fence_opener = Some(line.to_string());
            }
        }

        // Hard-split: single line exceeds available space.
        // This triggers when the line itself is longer than limit, OR when the
        // line doesn't fit in the current chunk even after accounting for fence
        // close overhead (e.g. after a reopen where opener already consumed space).
        let effective_avail = if fence_opener.is_some() {
            limit.saturating_sub(current_len + CLOSE_COST)
        } else {
            limit.saturating_sub(current_len)
        };
        if line_chars > effective_avail {
            let overhead = if let Some(ref opener) = fence_opener {
                // opener + '\n' at start, '\n```' at end
                opener.chars().count() + 1 + CLOSE_COST
            } else {
                0
            };
            // If limit can't even fit overhead, fall back to unfenced hard-split.
            let capacity = limit.saturating_sub(overhead);
            if let Some(opener) = fence_opener.as_ref().filter(|_| capacity > 0) {
                // Fenced hard-split: each mid chunk = opener\n + chars + \n```
                let opener_len = opener.chars().count();
                let mut chars = line.chars().peekable();

                // Fill remaining space in current chunk first.
                let avail_first = if current_len > 0 {
                    limit.saturating_sub(current_len + CLOSE_COST)
                } else {
                    capacity
                };
                for _ in 0..avail_first {
                    if let Some(ch) = chars.next() {
                        current.push(ch);
                        current_len += 1;
                    } else {
                        break;
                    }
                }

                while chars.peek().is_some() {
                    // Close current fenced chunk.
                    current.push_str("\n```");
                    chunks.push(std::mem::take(&mut current));
                    // Reopen.
                    current.push_str(opener);
                    current.push('\n');
                    current_len = opener_len + 1;
                    for _ in 0..capacity {
                        if let Some(ch) = chars.next() {
                            current.push(ch);
                            current_len += 1;
                        } else {
                            break;
                        }
                    }
                }
            } else {
                // Plain hard-split (no fence or limit too small for fence wrapping).
                for ch in line.chars() {
                    if current_len >= limit {
                        chunks.push(std::mem::take(&mut current));
                        current_len = 0;
                    }
                    current.push(ch);
                    current_len += 1;
                }
            }
        } else {
            current.push_str(line);
            current_len += line_chars;
        }
    }

    if !current.is_empty() {
        // Close any trailing open fence.
        if fence_opener.is_some() {
            current.push_str("\n```");
        }
        chunks.push(current);
    }
    chunks
}

// Shorten a prompt into a thread title: collapse GitHub URLs and cap at 40 chars.
+pub fn shorten_thread_name(prompt: &str) -> String { + use std::sync::LazyLock; + static GH_RE: LazyLock = LazyLock::new(|| { + regex::Regex::new(r"https?://github\.com/([^/]+/[^/]+)/(issues|pull)/(\d+)").unwrap() + }); + // Strip @(role) and @(user) placeholders left by resolve_mentions() + let cleaned = prompt.replace("@(role)", "").replace("@(user)", ""); + let shortened = GH_RE.replace_all(cleaned.trim(), "$1#$3"); + let name: String = shortened.chars().take(40).collect(); + if name.len() < shortened.len() { + format!("{name}...") + } else { + name + } +} + +/// Truncate a string to at most `limit` Unicode characters, keeping the tail +/// (most recent output) for better streaming UX. +pub fn truncate_chars_tail(s: &str, limit: usize) -> String { + let total = s.chars().count(); + if total <= limit { + return s.to_string(); + } + s.chars().skip(total - limit).collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper: assert every chunk respects the limit. + fn assert_length_invariant(chunks: &[String], limit: usize) { + for (i, chunk) in chunks.iter().enumerate() { + let len = chunk.chars().count(); + assert!( + len <= limit, + "chunk {i} has {len} chars, exceeds limit {limit}:\n{chunk}" + ); + } + } + + #[test] + fn no_split_under_limit() { + let text = "hello\nworld"; + let chunks = split_message(text, 100); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0], text); + } + + #[test] + fn plain_text_split_respects_limit() { + let text = "aaaa\nbbbb\ncccc\ndddd"; + let chunks = split_message(text, 10); + assert_length_invariant(&chunks, 10); + assert!(chunks.len() > 1); + } + + #[test] + fn fenced_split_preserves_language_tag() { + // ```rust\n + 1990 chars of content + \n``` — should split + let content_line = "x".repeat(1990); + let text = format!("```rust\n{content_line}\nanother line here\n```"); + let chunks = split_message(&text, 2000); + assert_length_invariant(&chunks, 2000); + // First chunk should start with ```rust + 
assert!(chunks[0].starts_with("```rust")); + // If split happened, second chunk should reopen with ```rust + if chunks.len() > 1 { + assert!( + chunks[1].starts_with("```rust"), + "second chunk should reopen with language tag: {}", + &chunks[1][..chunks[1].len().min(20)] + ); + } + } + + #[test] + fn fenced_split_close_overhead_budgeted() { + // Construct a fenced block where content + close marker would overflow + // without proper budgeting. + // limit=50, opener="```" (3), close="\n```" (4) + // Available for content per chunk: 50 - 3 - 1 - 4 = 42 (with opener+newline+close) + let line1 = "a".repeat(40); + let line2 = "b".repeat(40); + let text = format!("```\n{line1}\n{line2}\n```"); + let chunks = split_message(&text, 50); + assert_length_invariant(&chunks, 50); + } + + #[test] + fn reopen_path_no_overflow() { + // Regression: limit=2000, fenced block with a 1996-char line. + // Old code would produce 2004-char chunk due to reopen + extra \n. + let content = "x".repeat(1990); + let text = format!("```rust\n{content}\nshort\n```"); + let chunks = split_message(&text, 2000); + assert_length_invariant(&chunks, 2000); + } + + #[test] + fn hard_split_fenced_respects_limit() { + // A single very long line inside a fence. + let long_line = "x".repeat(100); + let text = format!("```\n{long_line}\n```"); + let chunks = split_message(&text, 20); + assert_length_invariant(&chunks, 20); + // All content should be present + let total_x: usize = chunks + .iter() + .map(|c| c.chars().filter(|&ch| ch == 'x').count()) + .sum(); + assert_eq!(total_x, 100); + } + + #[test] + fn hard_split_plain_respects_limit() { + let long_line = "y".repeat(50); + let text = format!("before\n{long_line}\nafter"); + let chunks = split_message(&text, 10); + assert_length_invariant(&chunks, 10); + } + + #[test] + fn closing_fence_triggers_split() { + // The closing ``` itself pushes over the limit. 
+ let content = "a".repeat(44); + // "```\n" + 44 chars + "\n```" = 3 + 1 + 44 + 1 + 3 = 52 + let text = format!("```\n{content}\n```"); + let chunks = split_message(&text, 50); + assert_length_invariant(&chunks, 50); + } + + #[test] + fn multi_fence_blocks() { + let text = "text\n```python\ncode1\ncode2\n```\nmore text\n```js\ncode3\n```"; + let chunks = split_message(text, 25); + assert_length_invariant(&chunks, 25); + } + + #[test] + fn fence_balance_across_chunks() { + // Every chunk should have balanced fences (even number of ``` lines). + let content = (0..20) + .map(|i| format!("line {i}")) + .collect::>() + .join("\n"); + let text = format!("```\n{content}\n```"); + let chunks = split_message(&text, 30); + assert_length_invariant(&chunks, 30); + for (i, chunk) in chunks.iter().enumerate() { + let fence_count = chunk.lines().filter(|l| l.starts_with("```")).count(); + assert!( + fence_count % 2 == 0, + "chunk {i} has unbalanced fences ({fence_count}):\n{chunk}" + ); + } + } +} diff --git a/crates/openab-core/src/gateway.rs b/crates/openab-core/src/gateway.rs new file mode 100644 index 000000000..d73819ee4 --- /dev/null +++ b/crates/openab-core/src/gateway.rs @@ -0,0 +1,1054 @@ +use crate::acp::ContentBlock; +use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef, SenderContext}; +use anyhow::Result; +use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; +use tokio_tungstenite::tungstenite::Message; +use tracing::{error, info, warn}; + +/// Timeout for waiting on gateway reply acknowledgement. +const GATEWAY_REPLY_TIMEOUT_SECS: u64 = 5; + +/// Platforms whose gateway adapter emits a `GatewayResponse` for `edit_message` +/// so core can observe edit success or failure (used to gate the per-edit +/// response-wait below). 
+/// +/// Today only Feishu does, because it is the only adapter with a known +/// per-message edit cap (errcode 230072) that requires core-side recovery, and +/// the only one wired to ack edits. +/// +/// NOTE: this gates the `edit_message` response-wait only. `delete_message` is +/// unconditionally fire-and-forget (the recovery path sends fresh content +/// regardless of the delete outcome), so it does not consult this list. +/// +/// TECH DEBT: this is platform-identity standing in for a *capability*. The +/// right model is a capability handshake at gateway-connect time ("does this +/// adapter acknowledge edits?") rather than a hardcoded platform name. We +/// accept the hardcode now because there is no handshake protocol yet; when one +/// lands, replace this allowlist with a negotiated capability flag. Any new +/// adapter that wires request/response for edits MUST be added here, or its +/// edit failures stay invisible to core (silent failure mode). +const EDIT_RESPONSE_PLATFORMS: &[&str] = &["feishu"]; + +/// Whether `platform` acknowledges `edit_message` with a `GatewayResponse`. +/// See `EDIT_RESPONSE_PLATFORMS`. 
+fn platform_acks_writes(platform: &str) -> bool { + EDIT_RESPONSE_PLATFORMS.contains(&platform) +} + +// --- Gateway event/reply schemas (mirrors gateway service) --- + +#[derive(Clone, Debug, Deserialize)] +struct GatewayEvent { + #[allow(dead_code)] + schema: String, + event_id: String, + #[allow(dead_code)] + timestamp: String, + platform: String, + channel: GwChannel, + sender: GwSender, + content: GwContent, + #[serde(default)] + #[allow(dead_code)] + mentions: Vec, + message_id: String, +} + +#[derive(Clone, Debug, Deserialize)] +struct GwChannel { + id: String, + #[serde(rename = "type")] + channel_type: String, + thread_id: Option, +} + +#[derive(Clone, Debug, Deserialize)] +struct GwSender { + id: String, + name: String, + display_name: String, + is_bot: bool, +} + +#[derive(Clone, Debug, Deserialize)] +struct GwContent { + #[allow(dead_code)] + #[serde(rename = "type")] + content_type: String, + text: String, + #[serde(default)] + attachments: Vec, +} + +#[derive(Clone, Debug, Deserialize)] +struct GwAttachment { + #[serde(rename = "type")] + attachment_type: String, + filename: String, + mime_type: String, + #[serde(default)] + data: String, + #[allow(dead_code)] + size: u64, + /// Colocate mode: local file path (preferred over base64 `data` when present) + #[serde(default)] + path: Option, +} + +#[derive(Serialize)] +struct GatewayReply { + schema: String, + reply_to: String, + platform: String, + channel: ReplyChannel, + content: ReplyContent, + #[serde(skip_serializing_if = "Option::is_none")] + command: Option, + #[serde(skip_serializing_if = "Option::is_none")] + request_id: Option, + /// When set, the gateway should send this message as a reply/quote to the specified message ID. + /// Unlike `reply_to` (routing/dedup identifier for the triggering event), this field controls + /// the visual reply/quote UI on the platform. Falls back to plain send on failure. 
+ #[serde(skip_serializing_if = "Option::is_none")] + quote_message_id: Option, +} + +#[derive(Serialize)] +struct ReplyChannel { + id: String, + #[serde(skip_serializing_if = "Option::is_none")] + thread_id: Option, +} + +#[derive(Serialize)] +struct ReplyContent { + #[serde(rename = "type")] + content_type: String, + text: String, +} + +#[derive(Clone, Debug, Deserialize)] +struct GatewayResponse { + #[allow(dead_code)] + schema: String, + request_id: String, + success: bool, + thread_id: Option, + message_id: Option, + error: Option, +} + +// --- GatewayAdapter: ChatAdapter over WebSocket --- + +type PendingRequests = Arc>>>; +type SharedWsTx = Arc< + Mutex< + futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + Message, + >, + >, +>; + +pub struct GatewayAdapter { + ws_tx: SharedWsTx, + pending: PendingRequests, + platform_name: &'static str, + streaming: bool, + streaming_placeholder: bool, +} + +impl GatewayAdapter { + fn new( + ws_tx: SharedWsTx, + pending: PendingRequests, + platform_name: &'static str, + streaming: bool, + streaming_placeholder: bool, + ) -> Self { + Self { + ws_tx, + pending, + platform_name, + streaming, + streaming_placeholder, + } + } + + /// Internal helper for send_message / send_message_with_reply. 
+ async fn send_gateway_reply( + &self, + channel: &ChannelRef, + content: &str, + quote_message_id: Option<&str>, + ) -> Result { + let req_id = if self.streaming { + Some(format!("req_{}", uuid::Uuid::new_v4())) + } else { + None + }; + let pending_rx = if let Some(ref id) = req_id { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.pending.lock().await.insert(id.clone(), tx); + Some(rx) + } else { + None + }; + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: channel.origin_event_id.clone().unwrap_or_default(), + platform: channel.platform.clone(), + channel: ReplyChannel { + id: channel.channel_id.clone(), + thread_id: channel.thread_id.clone(), + }, + content: ReplyContent { + content_type: "text".into(), + text: content.into(), + }, + command: None, + request_id: req_id.clone(), + quote_message_id: quote_message_id.map(|s| s.to_string()), + }; + let json = serde_json::to_string(&reply)?; + if let Err(e) = self.ws_tx.lock().await.send(Message::Text(json)).await { + if let Some(ref id) = req_id { + self.pending.lock().await.remove(id); + } + return Err(e.into()); + } + let msg_id = if let (Some(rx), Some(ref id)) = (pending_rx, &req_id) { + match tokio::time::timeout(std::time::Duration::from_secs(GATEWAY_REPLY_TIMEOUT_SECS), rx).await { + Ok(Ok(resp)) if resp.success => resp.message_id.unwrap_or_else(|| "gw_sent".into()), + Ok(Ok(resp)) => { + // Gateway explicitly reported failure (success=false). Surface + // as Err so dispatch sets ❌ instead of 🆗 over an incomplete + // delivery. Examples: Feishu edit cap reached after append-new + // fallback also failed; chunked send delivered N/M chunks. 
+ let err_msg = resp.error.clone() + .unwrap_or_else(|| "gateway reported failure".to_string()); + tracing::warn!(request_id = %id, error = %err_msg, "gateway replied with failure"); + return Err(anyhow::anyhow!("gateway reported failure: {err_msg}")); + } + Ok(Err(_)) => { + // Channel closed (gateway shutting down or pending dropped). + // Maintain legacy behavior — adapters that don't implement + // GatewayResponse for all reply types (LINE, Teams) rely on + // this for non-failure outcomes. + tracing::warn!(request_id = %id, "gateway response channel closed"); + "gw_sent".into() + } + Err(_) => { + // Timeout. Many adapters (LINE, Teams) intentionally do not + // emit GatewayResponse for replies, so timeout is the expected + // path for them. Maintain legacy behavior to avoid breaking + // platforms that have not yet wired request/response feedback. + tracing::warn!(request_id = %id, "gateway reply timed out"); + self.pending.lock().await.remove(id); + "gw_sent".into() + } + } + } else { + "gw_sent".into() + }; + Ok(MessageRef { + channel: channel.clone(), + message_id: msg_id, + }) + } +} + +/// Send a fire-and-forget reply via the shared WebSocket (no request-response). +/// Used for slash command responses where we don't need message_id back. 
+async fn send_fire_and_forget( + ws_tx: &SharedWsTx, + channel: &ChannelRef, + content: &str, +) -> Result<()> { + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: channel.origin_event_id.clone().unwrap_or_default(), + platform: channel.platform.clone(), + channel: ReplyChannel { + id: channel.channel_id.clone(), + thread_id: channel.thread_id.clone(), + }, + content: ReplyContent { + content_type: "text".into(), + text: content.into(), + }, + command: None, + request_id: None, + quote_message_id: None, + }; + let json = serde_json::to_string(&reply)?; + ws_tx.lock().await.send(Message::Text(json)).await?; + Ok(()) +} + +/// Handle `/models` or `/agents` text commands for gateway platforms. +/// Returns the response message, or None if the command was not recognized. +/// +/// Supported syntax: +/// /model list — numbered list of available models +/// /model set — switch by exact name or number +/// /models — alias of /model list +/// /agent list — numbered list of available agents +/// /agent set — switch by exact name or number +/// /agents — alias of /agent list +async fn handle_config_command( + trimmed: &str, + router: &AdapterRouter, + thread_key: &str, +) -> Option { + // Parse command: /model or /models (alias) + let (category, label, action, arg) = if trimmed == "/models" { + ("model", "model", "list", "") + } else if trimmed == "/agents" { + ("agent", "agent", "list", "") + } else if trimmed.starts_with("/model ") { + let rest = trimmed.strip_prefix("/model ").unwrap().trim(); + let (action, arg) = rest.split_once(' ').unwrap_or((rest, "")); + ("model", "model", action, arg.trim()) + } else if trimmed.starts_with("/agent ") { + let rest = trimmed.strip_prefix("/agent ").unwrap().trim(); + let (action, arg) = rest.split_once(' ').unwrap_or((rest, "")); + ("agent", "agent", action, arg.trim()) + } else if trimmed == "/model" { + ("model", "model", "list", "") + } else if trimmed == "/agent" { + ("agent", "agent", "list", 
"") + } else { + return None; + }; + + // Support both "agent" and "mode" categories (kiro-cli vs cursor-agent) + let categories: &[&str] = if category == "agent" { + &["agent", "mode"] + } else { + &[category] + }; + + let options = router.pool().get_config_options(thread_key).await; + let filtered: Vec<_> = options + .iter() + .filter(|o| { + o.category + .as_deref() + .is_some_and(|c| categories.contains(&c)) + }) + .collect(); + + if filtered.is_empty() { + return Some(format!( + "⚠️ No {label} options available. Start a conversation first." + )); + } + + // Collect all values with index for numbered list / set-by-number + let mut all_values: Vec<(String, String, String, bool)> = Vec::new(); // (config_id, value, name, is_current) + for opt in &filtered { + for v in &opt.options { + all_values.push(( + opt.id.clone(), + v.value.clone(), + v.name.clone(), + v.value == opt.current_value, + )); + } + } + + match action { + "list" => { + let mut lines = vec![format!("🔧 Available {label}s:")]; + for (i, (_, _, name, is_current)) in all_values.iter().enumerate() { + let marker = if *is_current { " ✅" } else { "" }; + lines.push(format!(" {}. {}{}", i + 1, name, marker)); + } + lines.push(format!("\nUsage: /{label} set ")); + Some(lines.join("\n")) + } + "set" => { + if arg.is_empty() { + return Some(format!("Usage: /{label} set ")); + } + // Try number first + if let Ok(num) = arg.parse::() { + if num >= 1 && num <= all_values.len() { + let (ref config_id, ref value, ref name, _) = all_values[num - 1]; + return match router + .pool() + .set_config_option(thread_key, config_id, value) + .await + { + Ok(_) => Some(format!("✅ Switched to **{name}**")), + Err(e) => Some(format!("❌ Failed to switch: {e}")), + }; + } else { + return Some(format!("⚠️ Invalid number. 
Use 1–{}.", all_values.len())); + } + } + // Exact match on value or name + let arg_lower = arg.to_lowercase(); + for (config_id, value, name, _) in &all_values { + if value.to_lowercase() == arg_lower || name.to_lowercase() == arg_lower { + return match router + .pool() + .set_config_option(thread_key, config_id, value) + .await + { + Ok(_) => Some(format!("✅ Switched to **{name}**")), + Err(e) => Some(format!("❌ Failed to switch: {e}")), + }; + } + } + Some(format!( + "⚠️ No {label} matching \"{arg}\". Use /{label} list to see options." + )) + } + _ => Some(format!( + "Unknown action \"{action}\". Usage: /{label} list | /{label} set " + )), + } +} + +#[async_trait] +impl ChatAdapter for GatewayAdapter { + fn platform(&self) -> &'static str { + self.platform_name + } + + fn message_limit(&self) -> usize { + 4096 // Telegram limit + } + + async fn send_message(&self, channel: &ChannelRef, content: &str) -> Result { + self.send_gateway_reply(channel, content, None).await + } + + async fn send_message_with_reply( + &self, + channel: &ChannelRef, + content: &str, + reply_to_message_id: &str, + ) -> Result { + self.send_gateway_reply(channel, content, Some(reply_to_message_id)).await + } + + async fn create_thread( + &self, + channel: &ChannelRef, + _trigger_msg: &MessageRef, + title: &str, + ) -> Result { + // Send create_topic command to gateway + let req_id = format!("req_{}", uuid::Uuid::new_v4()); + let (tx, rx) = tokio::sync::oneshot::channel(); + self.pending.lock().await.insert(req_id.clone(), tx); + + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: String::new(), + platform: channel.platform.clone(), + channel: ReplyChannel { + id: channel.channel_id.clone(), + thread_id: None, + }, + content: ReplyContent { + content_type: "text".into(), + text: title.into(), + }, + command: Some("create_topic".into()), + request_id: Some(req_id.clone()), + quote_message_id: None, + }; + let json = serde_json::to_string(&reply)?; + 
self.ws_tx.lock().await.send(Message::Text(json)).await?; + + // Wait for response (5s timeout) + match tokio::time::timeout(std::time::Duration::from_secs(5), rx).await { + Ok(Ok(resp)) if resp.success => Ok(ChannelRef { + platform: channel.platform.clone(), + channel_id: channel.channel_id.clone(), + thread_id: resp.thread_id, + parent_id: None, + origin_event_id: channel.origin_event_id.clone(), + }), + Ok(Ok(resp)) => { + warn!(err = ?resp.error, "create_topic failed, falling back to same channel"); + Ok(channel.clone()) + } + _ => { + warn!("create_topic timeout, falling back to same channel"); + self.pending.lock().await.remove(&req_id); + Ok(channel.clone()) + } + } + } + + async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: msg.message_id.clone(), + platform: msg.channel.platform.clone(), + channel: ReplyChannel { + id: msg.channel.channel_id.clone(), + thread_id: msg.channel.thread_id.clone(), + }, + content: ReplyContent { + content_type: "text".into(), + text: emoji.into(), + }, + command: Some("add_reaction".into()), + quote_message_id: None, + request_id: None, + }; + let json = serde_json::to_string(&reply)?; + self.ws_tx.lock().await.send(Message::Text(json)).await?; + Ok(()) + } + + async fn remove_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: msg.message_id.clone(), + platform: msg.channel.platform.clone(), + channel: ReplyChannel { + id: msg.channel.channel_id.clone(), + thread_id: msg.channel.thread_id.clone(), + }, + content: ReplyContent { + content_type: "text".into(), + text: emoji.into(), + }, + command: Some("remove_reaction".into()), + quote_message_id: None, + request_id: None, + }; + let json = serde_json::to_string(&reply)?; + self.ws_tx.lock().await.send(Message::Text(json)).await?; + Ok(()) + } + + async fn edit_message(&self, 
msg: &MessageRef, content: &str) -> Result<()> { + // Use a short request/response cycle so we can react to platform-level + // edit failures (e.g. Feishu's 20-edits-per-message cap, errcode 230072). + // Without this, edit_message was fire-and-forget and core never saw cap + // signals — cosmetic streaming would keep flushing forever and the final + // edit fallback to send_message could not trigger. + // + // Scope intentionally limited to platforms that ack writes (see + // EDIT_RESPONSE_PLATFORMS). Other adapters (LINE, Teams, Slack, Discord, + // …) keep the original fire-and-forget path so cosmetic streaming on + // those platforms does not pay a response-wait penalty per flush. + const EDIT_RESPONSE_TIMEOUT_MS: u64 = 800; + let needs_response = self.streaming && platform_acks_writes(&msg.channel.platform); + + let req_id = if needs_response { + Some(format!("req_{}", uuid::Uuid::new_v4())) + } else { + None + }; + let pending_rx = if let Some(ref id) = req_id { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.pending.lock().await.insert(id.clone(), tx); + Some(rx) + } else { + None + }; + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: msg.message_id.clone(), + platform: msg.channel.platform.clone(), + channel: ReplyChannel { + id: msg.channel.channel_id.clone(), + thread_id: msg.channel.thread_id.clone(), + }, + content: ReplyContent { + content_type: "text".into(), + text: content.into(), + }, + command: Some("edit_message".into()), + quote_message_id: None, + request_id: req_id.clone(), + }; + let json = serde_json::to_string(&reply)?; + if let Err(e) = self.ws_tx.lock().await.send(Message::Text(json)).await { + if let Some(ref id) = req_id { + self.pending.lock().await.remove(id); + } + return Err(e.into()); + } + if let (Some(rx), Some(ref id)) = (pending_rx, &req_id) { + match tokio::time::timeout( + std::time::Duration::from_millis(EDIT_RESPONSE_TIMEOUT_MS), + rx, + ).await { + Ok(Ok(resp)) if resp.success => 
Ok(()), + Ok(Ok(resp)) => { + let err_msg = resp.error.clone() + .unwrap_or_else(|| "gateway reported edit failure".to_string()); + tracing::warn!(request_id = %id, error = %err_msg, "edit_message gateway replied failure"); + Err(anyhow::anyhow!("edit failure: {err_msg}")) + } + Ok(Err(_)) => { + tracing::debug!(request_id = %id, "edit_message gateway response channel closed"); + Ok(()) + } + Err(_) => { + // Timeout — feishu didn't respond within the window + // (probably a slow API). Treat as success to avoid + // false-positive ❌; the cap-reached path already short- + // circuits much faster (gateway returns immediately). + self.pending.lock().await.remove(id); + Ok(()) + } + } + } else { + // Non-feishu (or non-streaming): fire-and-forget, no added latency. + Ok(()) + } + } + + /// Override default delete_message (which falls back to edit-to-zero-width) + /// so platforms with native delete APIs (e.g. Feishu DELETE /im/v1/messages/{id}) + /// can perform real deletions. Critical for the streaming-edit-cap recovery + /// path: when Feishu's 20-edits-per-message cap is hit and we send full + /// content as a fresh message, we need to remove the half-edited placeholder + /// to avoid duplicated content. The default zero-width-edit fallback would + /// itself fail on a cap-reached message, leaving the placeholder visible. + /// + /// Fire-and-forget: gateway adapters that don't implement delete will simply + /// ignore the command. Failure is non-fatal — if delete fails, the user sees + /// the placeholder remain (same behavior as before this override). We do not + /// wait on a response here: the recovery path sends fresh content regardless + /// of whether the delete landed, so a response would only buy an extra log + /// line at the cost of a per-finalize wait. 
+ async fn delete_message(&self, msg: &MessageRef) -> Result<()> { + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: msg.message_id.clone(), + platform: msg.channel.platform.clone(), + channel: ReplyChannel { + id: msg.channel.channel_id.clone(), + thread_id: msg.channel.thread_id.clone(), + }, + content: ReplyContent { + content_type: "text".into(), + text: String::new(), + }, + command: Some("delete_message".into()), + quote_message_id: None, + request_id: None, + }; + let json = serde_json::to_string(&reply)?; + self.ws_tx.lock().await.send(Message::Text(json)).await?; + Ok(()) + } + + fn use_streaming(&self, _other_bot_present: bool) -> bool { + self.streaming + } + + fn show_streaming_placeholder(&self) -> bool { + self.streaming_placeholder + } +} + +// --- Run the gateway adapter (connects to gateway WS, routes events to AdapterRouter) --- + +/// Resolved gateway configuration passed to the adapter at startup. +pub struct GatewayParams { + pub url: String, + pub platform: String, + pub token: Option, + pub bot_username: Option, + pub allow_all_channels: bool, + pub allowed_channels: Vec, + pub allow_all_users: bool, + pub allowed_users: Vec, + pub streaming: bool, + pub streaming_placeholder: bool, + pub stt: crate::config::SttConfig, +} + +pub async fn run_gateway_adapter( + params: GatewayParams, + mut shutdown_rx: tokio::sync::watch::Receiver, + dispatcher: Arc, + router: Arc, +) -> Result<()> { + let platform: &'static str = Box::leak(params.platform.into_boxed_str()); + + // Append auth token as query param if configured + let gateway_url = params.url; + let bot_username = params.bot_username; + let allow_all_channels = params.allow_all_channels; + let allowed_channels = params.allowed_channels; + let allow_all_users = params.allow_all_users; + let allowed_users = params.allowed_users; + let streaming = params.streaming; + let streaming_placeholder = params.streaming_placeholder; + let stt_config = params.stt; + + let 
connect_url = match ¶ms.token { + Some(token) => { + let sep = if gateway_url.contains('?') { "&" } else { "?" }; + format!("{gateway_url}{sep}token={token}") + } + None => { + warn!("gateway.token not set — WebSocket connection is NOT authenticated"); + gateway_url.clone() + } + }; + let mut backoff_secs = 1u64; + const MAX_BACKOFF: u64 = 30; + + loop { + // Check shutdown before connecting + if *shutdown_rx.borrow() { + info!("gateway adapter shutting down"); + return Ok(()); + } + + info!(url = %gateway_url, "connecting to custom gateway"); + + let ws_stream = match tokio_tungstenite::connect_async(&connect_url).await { + Ok((stream, _)) => { + backoff_secs = 1; // reset on success + info!("connected to gateway"); + stream + } + Err(e) => { + error!(err = %e, backoff = backoff_secs, "gateway connection failed, retrying"); + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)) => {} + _ = shutdown_rx.changed() => { return Ok(()); } + } + backoff_secs = (backoff_secs * 2).min(MAX_BACKOFF); + continue; + } + }; + + let (ws_tx, mut ws_rx) = ws_stream.split(); + let ws_tx: SharedWsTx = Arc::new(Mutex::new(ws_tx)); + let pending: PendingRequests = Arc::new(Mutex::new(HashMap::new())); + let adapter: Arc = Arc::new(GatewayAdapter::new( + ws_tx.clone(), + pending.clone(), + platform, + streaming, + streaming_placeholder, + )); + let slash_ws_tx = ws_tx.clone(); // for fire-and-forget slash command responses + let mut tasks: tokio::task::JoinSet<()> = tokio::task::JoinSet::new(); + + loop { + tokio::select! 
{ + msg = ws_rx.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + let text_str: &str = &text; + + // Check if it's a response to a pending command + if let Ok(resp) = serde_json::from_str::(text_str) { + if resp.schema == "openab.gateway.response.v1" { + if let Some(tx) = pending.lock().await.remove(&resp.request_id) { + let _ = tx.send(resp); + } + continue; + } + } + + match serde_json::from_str::(text_str) { + Ok(event) => { + // TODO: gateway adapters (feishu) do their own bot filtering + // via AllowBots + trusted_bot_ids, but Telegram does not. + // When Feishu lifts the bot-to-bot delivery restriction, + // this guard needs to become adapter-aware (e.g. a field on + // GatewayEvent indicating the adapter already filtered bots). + if event.sender.is_bot { + continue; + } + + // Channel allowlist gate + if !allow_all_channels && !allowed_channels.contains(&event.channel.id) { + info!(channel = %event.channel.id, "gateway: channel not in allowed_channels, skipping"); + continue; + } + + // User allowlist gate + if !allow_all_users && !allowed_users.contains(&event.sender.id) { + info!(sender = %event.sender.id, "gateway: user not in allowed_users, skipping"); + continue; + } + + // @mention gating: in groups, only respond if bot is mentioned + // DMs (private) and thread replies always pass through + let is_group = event.channel.channel_type == "group" + || event.channel.channel_type == "supergroup"; + let in_thread = event.channel.thread_id.is_some(); + if is_group && !in_thread { + if let Some(ref bot_name) = bot_username { + let mentioned = event.mentions.iter().any(|m| m == bot_name); + if !mentioned { + continue; // skip non-mentioned group messages + } + } + } + + info!( + platform = %event.platform, + sender = %event.sender.name, + channel = %event.channel.id, + "gateway event received" + ); + + let channel = ChannelRef { + platform: event.platform.clone(), + channel_id: event.channel.id.clone(), + thread_id: 
event.channel.thread_id.clone(), + parent_id: None, + origin_event_id: Some(event.event_id.clone()), + }; + + let sender_ctx = SenderContext { + schema: "openab.sender.v1".into(), + sender_id: event.sender.id.clone(), + sender_name: event.sender.name.clone(), + display_name: event.sender.display_name.clone(), + channel: event.channel.channel_type.clone(), + channel_id: event.channel.id.clone(), + thread_id: event.channel.thread_id.clone(), + is_bot: event.sender.is_bot, + // Gateway: use event timestamp if available, else broker receive time + timestamp: Some(if event.timestamp.is_empty() { + crate::timestamp::now_iso8601() + } else { + event.timestamp.clone() + }), + message_id: if event.message_id.is_empty() { None } else { Some(event.message_id.clone()) }, + receiver_id: None, // gateway does not yet resolve receiver identity + }; + let sender_json = serde_json::to_string(&sender_ctx) + .unwrap_or_default(); + + let trigger_msg = MessageRef { + channel: channel.clone(), + message_id: event.message_id.clone(), + }; + + let adapter = adapter.clone(); + let prompt = event.content.text.clone(); + let sender_name = event.sender.name.clone(); + let sender_id = event.sender.id.clone(); + let dispatcher = dispatcher.clone(); + + // Convert gateway attachments to ContentBlocks + let mut extra_blocks = Vec::new(); + for att in &event.content.attachments { + // Read bytes: prefer file path (colocate), fallback to base64 + let bytes_result = if let Some(ref path) = att.path { + tokio::fs::read(path).await.map_err(|e| e.to_string()) + } else if !att.data.is_empty() { + use base64::Engine; + base64::engine::general_purpose::STANDARD + .decode(&att.data) + .map_err(|e| e.to_string()) + } else { + Err("no path or data".into()) + }; + + match att.attachment_type.as_str() { + "image" => { + match bytes_result { + Ok(bytes) => { + use base64::Engine; + let b64 = base64::engine::general_purpose::STANDARD.encode(&bytes); + extra_blocks.push(ContentBlock::Image { + media_type: 
att.mime_type.clone(), + data: b64, + }); + } + Err(e) => { + tracing::warn!(filename = %att.filename, error = %e, "gateway image read failed"); + } + } + } + "text_file" => { + if let Ok(bytes) = bytes_result { + let text = String::from_utf8_lossy(&bytes); + extra_blocks.push(ContentBlock::Text { + text: format!("```{}\n{}\n```", att.filename, text), + }); + } + } + "audio" if stt_config.enabled => { + match bytes_result { + Ok(bytes) => { + match crate::stt::transcribe( + &crate::media::HTTP_CLIENT, + &stt_config, + bytes, + att.filename.clone(), + &att.mime_type, + ).await { + Some(transcript) => { + extra_blocks.push(ContentBlock::Text { + text: format!("[Voice message transcript]: {transcript}"), + }); + } + None => { + tracing::warn!(filename = %att.filename, "gateway audio STT failed"); + extra_blocks.push(ContentBlock::Text { + text: format!( + "[Voice message — transcription failed for {}]", + att.filename + ), + }); + } + } + } + Err(e) => { + tracing::warn!(filename = %att.filename, error = %e, "gateway audio read failed"); + extra_blocks.push(ContentBlock::Text { + text: format!( + "[Voice message — read failed for {}]", + att.filename + ), + }); + } + } + } + "audio" => { + tracing::debug!(filename = %att.filename, "audio attachment skipped — STT not enabled"); + } + _ => {} + } + } + + // Slash command interception for gateway platforms + // (Feishu/LINE/Telegram don't have native slash commands) + // Use fire-and-forget send — slash command responses don't + // need message_id for streaming edits. + let trimmed = prompt.trim(); + if trimmed == "/reset" { + let thread_id_str = event.channel.thread_id.as_deref().unwrap_or(&event.channel.id); + let thread_key = format!("{}:{}", event.platform, thread_id_str); + let dropped = dispatcher.cancel_buffered_thread(event.platform.as_str(), thread_id_str); + let msg = match (router.pool().reset_session(&thread_key).await, dropped) { + (Ok(()), 0) => "🔄 Session reset. 
Start a new conversation!".to_string(), + (Ok(()), n) => format!("🔄 Session reset. Dropped {n} buffered message(s). Start a new conversation!"), + (Err(_), 0) => "⚠️ No active session to reset.".to_string(), + (Err(_), n) => format!("🔄 Dropped {n} buffered message(s). No active session to reset."), + }; + let _ = send_fire_and_forget(&slash_ws_tx, &channel, &msg).await; + continue; + } + if trimmed == "/cancel" { + let thread_key = format!("{}:{}", event.platform, event.channel.thread_id.as_deref().unwrap_or(&event.channel.id)); + let msg = match router.pool().cancel_session(&thread_key).await { + Ok(()) => "🛑 Cancel signal sent.".to_string(), + Err(e) => format!("⚠️ {e}"), + }; + let _ = send_fire_and_forget(&slash_ws_tx, &channel, &msg).await; + continue; + } + { + let thread_key = format!("{}:{}", event.platform, event.channel.thread_id.as_deref().unwrap_or(&event.channel.id)); + if let Some(msg) = handle_config_command(trimmed, &router, &thread_key).await { + let _ = send_fire_and_forget(&slash_ws_tx, &channel, &msg).await; + continue; + } + } + + tasks.spawn(async move { + // If supergroup with no thread_id, create a forum topic + let thread_channel = if event.channel.channel_type == "supergroup" + && channel.thread_id.is_none() + { + let title = crate::format::shorten_thread_name(&prompt); + match adapter.create_thread(&channel, &trigger_msg, &title).await { + Ok(tc) => tc, + Err(e) => { + warn!("create_thread failed, using channel: {e}"); + channel.clone() + } + } + } else { + channel.clone() + }; + + let thread_id = thread_channel + .thread_id + .as_deref() + .unwrap_or(&thread_channel.channel_id); + let thread_key = dispatcher.key( + &thread_channel.platform, + thread_id, + &sender_id, + ); + let estimated_tokens = + crate::dispatch::estimate_tokens(&prompt, &extra_blocks); + let buf_msg = crate::dispatch::BufferedMessage { + sender_json, + sender_name, + prompt, + extra_blocks, + trigger_msg, + arrived_at: std::time::Instant::now(), + estimated_tokens, + 
// TODO: implement gateway multibot detection + other_bot_present: false, + recipient: None, // Slack-only (assistant mode); N/A for gateway + }; + if let Err(e) = dispatcher + .submit(thread_key, thread_channel, adapter, buf_msg) + .await + { + error!("gateway dispatcher submit error: {e}"); + } + }); + } + Err(e) => warn!("invalid gateway event: {e}"), + } + } + Some(Ok(Message::Close(_))) | None => { + warn!("gateway WebSocket closed, will reconnect"); + break; + } + Some(Err(e)) => { + error!("gateway WebSocket error: {e}, will reconnect"); + break; + } + _ => {} + } + } + _ = shutdown_rx.changed() => { + if *shutdown_rx.borrow() { + info!("gateway adapter shutting down, waiting for {} in-flight tasks", tasks.len()); + while tasks.join_next().await.is_some() {} + return Ok(()); + } + } + } + } // inner loop — break here means reconnect + + // Drain in-flight tasks before reconnecting + while tasks.join_next().await.is_some() {} + + warn!(backoff = backoff_secs, "reconnecting to gateway"); + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)) => {} + _ = shutdown_rx.changed() => { return Ok(()); } + } + backoff_secs = (backoff_secs * 2).min(MAX_BACKOFF); + } // outer reconnect loop +} diff --git a/crates/openab-core/src/hooks.rs b/crates/openab-core/src/hooks.rs new file mode 100644 index 000000000..164ce39d0 --- /dev/null +++ b/crates/openab-core/src/hooks.rs @@ -0,0 +1,425 @@ +use crate::config::{HookConfig, OnFailure}; +use sha2::{Digest, Sha256}; +use std::io::Write; +use std::path::PathBuf; +use tokio::process::Command; +use tracing::{error, info, warn}; + +/// Maximum size for a remote hook script (1 MiB). +const MAX_SCRIPT_SIZE: usize = 1024 * 1024; + +/// Run a hook. Returns Ok(()) if the hook succeeds or is not configured. +/// Returns Err only if on_failure=abort and the hook fails. 
+pub async fn run_hook(name: &str, hook: &HookConfig) -> anyhow::Result<()> { + info!(hook = name, "running hook"); + + let resolved = match resolve_script(name, hook).await { + Ok(r) => r, + Err(e) => return handle_failure(name, hook.on_failure, e), + }; + + let result = execute(&resolved.path, hook.timeout_seconds).await; + + // Clean up temp files + if resolved.temp { + let _ = std::fs::remove_file(&resolved.path); + } + + match result { + Ok(()) => { + info!(hook = name, "hook completed successfully"); + Ok(()) + } + Err(e) => handle_failure(name, hook.on_failure, e), + } +} + +/// Validate hook config at parse time. +pub fn validate_hook(name: &str, hook: &HookConfig) -> anyhow::Result<()> { + let sources = [ + hook.script.is_some(), + hook.inline.is_some(), + hook.url.is_some(), + ]; + let count = sources.iter().filter(|&&b| b).count(); + if count == 0 { + anyhow::bail!("hooks.{name}: exactly one of script, inline, or url must be set"); + } + if count > 1 { + anyhow::bail!( + "hooks.{name}: only one of script, inline, or url may be set (found {count})" + ); + } + if hook.url.is_some() && hook.sha256.is_none() { + anyhow::bail!("hooks.{name}: sha256 is required when using url"); + } + if let Some(ref path) = hook.script { + if !PathBuf::from(path).is_absolute() { + anyhow::bail!("hooks.{name}: script path must be absolute, got: {path}"); + } + } + Ok(()) +} + +struct ResolvedScript { + path: PathBuf, + temp: bool, +} + +async fn resolve_script(name: &str, hook: &HookConfig) -> anyhow::Result { + if let Some(ref path) = hook.script { + let p = PathBuf::from(path); + if !p.exists() { + anyhow::bail!("hooks.{name}: script not found: {path}"); + } + return Ok(ResolvedScript { + path: p, + temp: false, + }); + } + + if let Some(ref content) = hook.inline { + let path = write_temp_script(name, content)?; + return Ok(ResolvedScript { path, temp: true }); + } + + if let Some(ref url) = hook.url { + let expected_hash = hook.sha256.as_deref().unwrap(); + let content = 
fetch_and_verify(url, expected_hash).await?; + let path = write_temp_script(name, &content)?; + return Ok(ResolvedScript { path, temp: true }); + } + + anyhow::bail!("hooks.{name}: no script source configured"); +} + +fn write_temp_script(name: &str, content: &str) -> anyhow::Result { + #[cfg(unix)] + let suffix = ".sh"; + #[cfg(windows)] + let suffix = ".cmd"; + + let prefix = format!("openab-hook-{name}-"); + let mut builder = tempfile::Builder::new(); + builder.prefix(prefix.as_str()).suffix(suffix); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + builder.permissions(std::fs::Permissions::from_mode(0o700)); + } + + let mut f = builder.tempfile()?; + f.write_all(content.as_bytes())?; + let path = f.into_temp_path().keep().map_err(|e| { + anyhow::anyhow!("failed to persist temp script: {}", e.error) + })?; + Ok(path) +} + +async fn fetch_and_verify(url: &str, expected_hex: &str) -> anyhow::Result { + info!(url = url, "fetching hook script from URL"); + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build()?; + let resp = client.get(url).send().await?; + if !resp.status().is_success() { + anyhow::bail!("hook url returned HTTP {}", resp.status()); + } + let content_length = resp.content_length().unwrap_or(0) as usize; + if content_length > MAX_SCRIPT_SIZE { + anyhow::bail!( + "hook script too large: {content_length} bytes (max {MAX_SCRIPT_SIZE})" + ); + } + let body = resp.bytes().await?; + if body.len() > MAX_SCRIPT_SIZE { + anyhow::bail!( + "hook script too large: {} bytes (max {MAX_SCRIPT_SIZE})", + body.len() + ); + } + + let mut hasher = Sha256::new(); + hasher.update(&body); + let actual_hex = format!("{:x}", hasher.finalize()); + + if actual_hex != expected_hex.to_lowercase() { + anyhow::bail!("hook sha256 mismatch: expected {expected_hex}, got {actual_hex}"); + } + + Ok(String::from_utf8(body.to_vec())?) 
+} + +async fn execute(path: &PathBuf, timeout_secs: u64) -> anyhow::Result<()> { + let mut cmd = Command::new(path); + cmd.env_clear(); + + // Baseline env (same as agent subprocess) + if let Ok(v) = std::env::var("HOME") { + cmd.env("HOME", &v); + } + if let Ok(v) = std::env::var("PATH") { + cmd.env("PATH", &v); + } + #[cfg(unix)] + if let Ok(v) = std::env::var("USER") { + cmd.env("USER", &v); + } + #[cfg(windows)] + { + if let Ok(v) = std::env::var("USERPROFILE") { + cmd.env("USERPROFILE", &v); + } + if let Ok(v) = std::env::var("USERNAME") { + cmd.env("USERNAME", &v); + } + if let Ok(v) = std::env::var("SystemRoot") { + cmd.env("SystemRoot", &v); + } + if let Ok(v) = std::env::var("SystemDrive") { + cmd.env("SystemDrive", &v); + } + } + + // Pass through cloud credential env vars for IAM-based auth (IRSA, Workload Identity, ECS task role) + for (key, val) in std::env::vars() { + let pass = key.starts_with("AWS_") + || key.starts_with("AMAZON_") + || key.starts_with("ECS_CONTAINER_METADATA_URI") + || key.starts_with("GOOGLE_") + || key.starts_with("GCLOUD_") + || key.starts_with("CLOUDSDK_") + || key.starts_with("AZURE_") + || key == "BOOTSTRAP_URI" + || key == "BOOTSTRAP_BASE_URI" + || key == "BOOTSTRAP_PERSONAL_URI" + || key == "STATE_BUCKET" + || key == "TASK_FAMILY" + || key == "OPENAB_AGENT_NAME" + || key == "OPENAB_BACKEND_AGENT"; + if pass { + cmd.env(&key, &val); + } + } + + let mut child = cmd.spawn()?; + + if timeout_secs == 0 { + let status = child.wait().await?; + if !status.success() { + anyhow::bail!("hook exited with {status}"); + } + return Ok(()); + } + + let timeout = std::time::Duration::from_secs(timeout_secs); + match tokio::time::timeout(timeout, child.wait()).await { + Ok(Ok(status)) => { + if !status.success() { + anyhow::bail!("hook exited with {status}"); + } + Ok(()) + } + Ok(Err(e)) => anyhow::bail!("hook process error: {e}"), + Err(_) => { + let _ = child.kill().await; + anyhow::bail!("hook timed out after {timeout_secs}s"); + } + } 
+} + +fn handle_failure(name: &str, policy: OnFailure, err: anyhow::Error) -> anyhow::Result<()> { + match policy { + OnFailure::Abort => { + error!(hook = name, error = %err, "hook failed (on_failure=abort)"); + Err(err) + } + OnFailure::Warn => { + warn!(hook = name, error = %err, "hook failed (on_failure=warn), continuing"); + Ok(()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{HookConfig, OnFailure}; + + fn hook_with_script(path: &str) -> HookConfig { + HookConfig { + script: Some(path.into()), + inline: None, + url: None, + sha256: None, + timeout_seconds: 60, + on_failure: OnFailure::Abort, + } + } + + fn hook_with_inline(content: &str) -> HookConfig { + HookConfig { + script: None, + inline: Some(content.into()), + url: None, + sha256: None, + timeout_seconds: 60, + on_failure: OnFailure::Abort, + } + } + + #[test] + fn validate_rejects_no_source() { + let hook = HookConfig { + script: None, + inline: None, + url: None, + sha256: None, + timeout_seconds: 60, + on_failure: OnFailure::Abort, + }; + assert!(validate_hook("test", &hook).is_err()); + } + + #[test] + fn validate_rejects_multiple_sources() { + let hook = HookConfig { + script: Some("/bin/true".into()), + inline: Some("echo hi".into()), + url: None, + sha256: None, + timeout_seconds: 60, + on_failure: OnFailure::Abort, + }; + assert!(validate_hook("test", &hook).is_err()); + } + + #[test] + fn validate_rejects_url_without_sha256() { + let hook = HookConfig { + script: None, + inline: None, + url: Some("https://example.com/script.sh".into()), + sha256: None, + timeout_seconds: 60, + on_failure: OnFailure::Abort, + }; + assert!(validate_hook("test", &hook).is_err()); + } + + #[test] + fn validate_rejects_relative_script_path() { + let hook = hook_with_script("relative/path.sh"); + assert!(validate_hook("test", &hook).is_err()); + } + + #[test] + fn validate_accepts_absolute_script_path() { + let hook = hook_with_script("/usr/local/bin/bootstrap.sh"); + 
assert!(validate_hook("test", &hook).is_ok()); + } + + #[test] + fn validate_accepts_inline() { + let hook = hook_with_inline("#!/bin/sh\necho hello"); + assert!(validate_hook("test", &hook).is_ok()); + } + + #[test] + fn validate_accepts_url_with_sha256() { + let hook = HookConfig { + script: None, + inline: None, + url: Some("https://example.com/script.sh".into()), + sha256: Some("abc123".into()), + timeout_seconds: 60, + on_failure: OnFailure::Abort, + }; + assert!(validate_hook("test", &hook).is_ok()); + } + + #[tokio::test] + async fn run_inline_script_success() { + let hook = hook_with_inline("#!/bin/sh\nexit 0"); + let result = run_hook("test", &hook).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn run_inline_script_failure_abort() { + let hook = hook_with_inline("#!/bin/sh\nexit 1"); + let result = run_hook("test", &hook).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn run_inline_script_failure_warn() { + let hook = HookConfig { + script: None, + inline: Some("#!/bin/sh\nexit 1".into()), + url: None, + sha256: None, + timeout_seconds: 60, + on_failure: OnFailure::Warn, + }; + let result = run_hook("test", &hook).await; + assert!(result.is_ok()); // warn mode continues + } + + #[tokio::test] + async fn run_inline_script_timeout() { + let hook = HookConfig { + script: None, + inline: Some("#!/bin/sh\nsleep 10".into()), + url: None, + sha256: None, + timeout_seconds: 1, + on_failure: OnFailure::Abort, + }; + let result = run_hook("test", &hook).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("timed out")); + } + + #[tokio::test] + async fn run_script_file_success() { + let dir = std::env::temp_dir(); + let path = dir.join("openab-test-hook-success.sh"); + std::fs::write(&path, "#!/bin/sh\nexit 0").unwrap(); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o700)).unwrap(); + } + let hook = 
hook_with_script(path.to_str().unwrap()); + let result = run_hook("test", &hook).await; + let _ = std::fs::remove_file(&path); + assert!(result.is_ok()); + } + + #[tokio::test] + async fn run_script_file_not_found() { + let hook = hook_with_script("/tmp/openab-nonexistent-hook-12345.sh"); + let result = run_hook("test", &hook).await; + assert!(result.is_err()); + } + + #[test] + fn config_parses_hooks() { + let toml_str = "[agent]\ncommand = \"echo\"\n\n[hooks.pre_boot]\ninline = \"echo hello\"\ntimeout_seconds = 30\non_failure = \"warn\"\n"; + let cfg: crate::config::Config = toml::from_str(toml_str).unwrap(); + let hook = cfg.hooks.pre_boot.unwrap(); + assert_eq!(hook.inline.unwrap(), "echo hello"); + assert_eq!(hook.timeout_seconds, 30); + assert_eq!(hook.on_failure, OnFailure::Warn); + } + + #[test] + fn config_parses_no_hooks() { + let toml_str = "[agent]\ncommand = \"echo\"\n"; + let cfg: crate::config::Config = toml::from_str(toml_str).unwrap(); + assert!(cfg.hooks.pre_boot.is_none()); + assert!(cfg.hooks.pre_shutdown.is_none()); + } +} diff --git a/crates/openab-core/src/lib.rs b/crates/openab-core/src/lib.rs new file mode 100644 index 000000000..f61540657 --- /dev/null +++ b/crates/openab-core/src/lib.rs @@ -0,0 +1,25 @@ +pub mod acp; +pub mod adapter; +pub mod bot_turns; +pub mod config; +pub mod cron; +pub mod directives; +pub mod dispatch; +pub mod error_display; +pub mod format; +pub mod gateway; +pub mod hooks; +pub mod markdown; +pub mod media; +pub mod multibot_cache; +pub mod reactions; +pub mod remind; +pub mod secrets; +pub mod setup; +pub mod stt; +pub mod timestamp; + +#[cfg(feature = "discord")] +pub mod discord; +#[cfg(feature = "slack")] +pub mod slack; diff --git a/crates/openab-core/src/markdown.rs b/crates/openab-core/src/markdown.rs new file mode 100644 index 000000000..32398cc25 --- /dev/null +++ b/crates/openab-core/src/markdown.rs @@ -0,0 +1,349 @@ +use pulldown_cmark::{Event, Options, Parser, Tag, TagEnd}; +use serde::Deserialize; 
+use std::fmt; +use unicode_width::UnicodeWidthStr; + +/// How to render markdown tables for a given channel. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum TableMode { + /// Wrap the table in a fenced code block (default). + #[default] + Code, + /// Convert each row into bullet points. + Bullets, + /// Pass through unchanged. + Off, +} + +impl fmt::Display for TableMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Code => write!(f, "code"), + Self::Bullets => write!(f, "bullets"), + Self::Off => write!(f, "off"), + } + } +} + +// ── IR types ──────────────────────────────────────────────────────── + +/// A parsed table: header row + data rows, each cell is plain text. +struct Table { + headers: Vec, + rows: Vec>, +} + +/// Segment of the document — either verbatim text or a parsed table. +enum Segment { + Text(String), + Table(Table), +} + +// ── Public API ────────────────────────────────────────────────────── + +/// Parse markdown, detect tables via pulldown-cmark, and render them +/// according to `mode`. Non-table content passes through unchanged. +pub fn convert_tables(markdown: &str, mode: TableMode) -> String { + if mode == TableMode::Off || markdown.is_empty() { + return markdown.to_string(); + } + + let segments = parse_segments(markdown); + + let mut out = String::with_capacity(markdown.len()); + for seg in segments { + match seg { + Segment::Text(t) => out.push_str(&t), + Segment::Table(table) => match mode { + TableMode::Code => render_table_code(&table, &mut out), + TableMode::Bullets => render_table_bullets(&table, &mut out), + TableMode::Off => unreachable!(), + }, + } + } + out +} + +// ── Parser ────────────────────────────────────────────────────────── + +/// Walk the markdown source with pulldown-cmark and split it into +/// text segments and parsed Table segments. 
+fn parse_segments(markdown: &str) -> Vec { + let mut opts = Options::empty(); + opts.insert(Options::ENABLE_TABLES); + + let mut segments: Vec = Vec::new(); + let mut in_table = false; + let mut in_head = false; + let mut headers: Vec = Vec::new(); + let mut rows: Vec> = Vec::new(); + let mut current_row: Vec = Vec::new(); + let mut cell_buf = String::new(); + let mut last_table_end: usize = 0; + + // We need byte offsets to grab non-table text verbatim. + let parser_with_offsets = Parser::new_ext(markdown, opts).into_offset_iter(); + + for (event, range) in parser_with_offsets { + match event { + Event::Start(Tag::Table(_)) => { + // Flush text before this table + let before = &markdown[last_table_end..range.start]; + if !before.is_empty() { + push_text(&mut segments, before); + } + in_table = true; + headers.clear(); + rows.clear(); + } + Event::End(TagEnd::Table) => { + let table = Table { + headers: std::mem::take(&mut headers), + rows: std::mem::take(&mut rows), + }; + segments.push(Segment::Table(table)); + in_table = false; + last_table_end = range.end; + } + Event::Start(Tag::TableHead) => { + in_head = true; + current_row.clear(); + } + Event::End(TagEnd::TableHead) => { + headers = std::mem::take(&mut current_row); + in_head = false; + } + Event::Start(Tag::TableRow) => { + current_row.clear(); + } + Event::End(TagEnd::TableRow) if !in_head => { + rows.push(std::mem::take(&mut current_row)); + } + Event::Start(Tag::TableCell) => { + cell_buf.clear(); + } + Event::End(TagEnd::TableCell) => { + current_row.push(cell_buf.trim().to_string()); + cell_buf.clear(); + } + Event::Text(t) if in_table => { + cell_buf.push_str(&t); + } + Event::Code(t) if in_table => { + cell_buf.push('`'); + cell_buf.push_str(&t); + cell_buf.push('`'); + } + // Inline markup inside cells: collect text, ignore tags + Event::SoftBreak if in_table => { + cell_buf.push(' '); + } + Event::HardBreak if in_table => { + cell_buf.push(' '); + } + // Start/End of inline tags (bold, italic, 
link, etc.) — skip the + // tag markers but keep processing their child text events above. + Event::Start(Tag::Emphasis) + | Event::Start(Tag::Strong) + | Event::Start(Tag::Strikethrough) + | Event::Start(Tag::Link { .. }) + | Event::End(TagEnd::Emphasis) + | Event::End(TagEnd::Strong) + | Event::End(TagEnd::Strikethrough) + | Event::End(TagEnd::Link) + if in_table => {} + _ => {} + } + } + + // Remaining text after last table + if last_table_end < markdown.len() { + let tail = &markdown[last_table_end..]; + if !tail.is_empty() { + push_text(&mut segments, tail); + } + } + + segments +} + +fn push_text(segments: &mut Vec, text: &str) { + if let Some(Segment::Text(ref mut prev)) = segments.last_mut() { + prev.push_str(text); + } else { + segments.push(Segment::Text(text.to_string())); + } +} + +// ── Renderers ─────────────────────────────────────────────────────── + +/// Render table as a fenced code block with aligned columns. +fn render_table_code(table: &Table, out: &mut String) { + let col_count = table + .headers + .len() + .max(table.rows.iter().map(|r| r.len()).max().unwrap_or(0)); + if col_count == 0 { + return; + } + + // Strip backticks from cells — inside a code fence they render as literals. 
+ let strip = |s: &str| s.replace('`', ""); + let headers: Vec = table.headers.iter().map(|h| strip(h)).collect(); + let rows: Vec> = table + .rows + .iter() + .map(|r| r.iter().map(|c| strip(c)).collect()) + .collect(); + + // Compute column widths (using display width for CJK/emoji) + let mut widths = vec![0usize; col_count]; + for (i, h) in headers.iter().enumerate() { + widths[i] = widths[i].max(UnicodeWidthStr::width(h.as_str())); + } + for row in &rows { + for (i, cell) in row.iter().enumerate() { + if i < col_count { + widths[i] = widths[i].max(UnicodeWidthStr::width(cell.as_str())); + } + } + } + // Minimum width 3 for the divider + for w in &mut widths { + *w = (*w).max(3); + } + + out.push_str("```\n"); + + // Header row + write_row(out, &headers, &widths, col_count); + // Divider + out.push('|'); + for w in &widths { + out.push(' '); + for _ in 0..*w { + out.push('-'); + } + out.push_str(" |"); + } + out.push('\n'); + // Data rows + for row in &rows { + write_row(out, row, &widths, col_count); + } + + out.push_str("```\n"); +} + +fn write_row(out: &mut String, cells: &[String], widths: &[usize], col_count: usize) { + out.push('|'); + for (i, w) in widths.iter().enumerate().take(col_count) { + out.push(' '); + let cell = cells.get(i).map(|s| s.as_str()).unwrap_or(""); + out.push_str(cell); + let display_width = UnicodeWidthStr::width(cell); + let pad = w.saturating_sub(display_width); + for _ in 0..pad { + out.push(' '); + } + out.push_str(" |"); + } + out.push('\n'); +} + +/// Render table as bullet points: `• header: value` per cell. 
+fn render_table_bullets(table: &Table, out: &mut String) { + for (row_idx, row) in table.rows.iter().enumerate() { + for (i, cell) in row.iter().enumerate() { + if cell.is_empty() { + continue; + } + out.push_str("• "); + if let Some(h) = table.headers.get(i) { + if !h.is_empty() { + out.push_str(h); + out.push_str(": "); + } + } + out.push_str(cell); + out.push('\n'); + } + // Blank line between rows, but not after the last one + if row_idx + 1 < table.rows.len() { + out.push('\n'); + } + } +} + +// ── Tests ─────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + const TABLE_MD: &str = "\ +Some text before. + +| Name | Age | +|-------|-----| +| Alice | 30 | +| Bob | 25 | + +Some text after. +"; + + #[test] + fn off_mode_passes_through() { + let result = convert_tables(TABLE_MD, TableMode::Off); + assert_eq!(result, TABLE_MD); + } + + #[test] + fn code_mode_wraps_in_codeblock() { + let result = convert_tables(TABLE_MD, TableMode::Code); + assert!(result.contains("```\n")); + assert!(result.contains("| Alice")); + assert!(result.contains("Some text before.")); + assert!(result.contains("Some text after.")); + } + + #[test] + fn bullets_mode_converts_to_bullets() { + let result = convert_tables(TABLE_MD, TableMode::Bullets); + assert!(result.contains("• Name: Alice")); + assert!(result.contains("• Age: 30")); + assert!(!result.contains("```")); + } + + #[test] + fn no_table_passes_through() { + let plain = "Hello world\nNo tables here."; + let result = convert_tables(plain, TableMode::Code); + assert_eq!(result, plain); + } + + #[test] + fn code_mode_strips_backticks_from_code_cells() { + let md = "| col |\n|-----|\n| `value` |\n"; + let result = convert_tables(md, TableMode::Code); + // The table is inside a ``` block — backtick wrapping must be stripped. + assert!(result.contains("value"), "cell content should be present"); + // Only the fence markers themselves should contain backticks. 
+ let inner = result.trim_start_matches("```\n").trim_end_matches("```\n"); + assert!( + !inner.contains('`'), + "no backticks should appear inside the code fence: {result:?}" + ); + } + + #[test] + fn bullets_mode_keeps_backticks_in_code_cells() { + let md = "| col |\n|-----|\n| `value` |\n"; + let result = convert_tables(md, TableMode::Bullets); + assert!( + result.contains("`value`"), + "backticks should be kept in bullets mode" + ); + } +} diff --git a/crates/openab-core/src/media.rs b/crates/openab-core/src/media.rs new file mode 100644 index 000000000..33ea59010 --- /dev/null +++ b/crates/openab-core/src/media.rs @@ -0,0 +1,846 @@ +use crate::acp::ContentBlock; +use crate::config::SttConfig; +use base64::engine::general_purpose::STANDARD as BASE64; +use base64::Engine; +use image::codecs::gif::GifDecoder; +use image::{AnimationDecoder, ImageReader}; +use std::io::Cursor; +use std::sync::LazyLock; +use tracing::{debug, error, warn}; + +/// Reusable HTTP client for downloading attachments (shared across adapters). +pub static HTTP_CLIENT: LazyLock = LazyLock::new(|| { + reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .expect("static HTTP client must build") +}); + +/// Maximum dimension (width or height) for resized images. +const IMAGE_MAX_DIMENSION_PX: u32 = 1200; + +/// JPEG quality for compressed output. +const IMAGE_JPEG_QUALITY: u8 = 75; + +/// Error variants for `download_and_encode_image`. +#[derive(Debug)] +pub enum MediaFetchError { + /// URL empty or MIME/filename doesn't indicate an image; skip silently. + NotAnImage, + /// HTTP response Content-Type is not a supported image format. + UnsupportedResponseType { actual: Option }, + /// Response body magic bytes don't match a supported image format. + InvalidImageBody { magic_prefix_hex: String }, + /// File exceeds the configured size limit. + SizeExceeded { actual: u64, limit: u64 }, + /// Network-level error (send or body-read). 
/// Format the first 8 bytes of a buffer as lowercase hex (no separator).
///
/// Buffers shorter than 8 bytes yield a correspondingly shorter string; an
/// empty buffer yields an empty string.
fn hex_prefix(body: &[u8]) -> String {
    // Collect straight into a `String`; no intermediate `Vec<String>` + `concat`.
    body.iter().take(8).map(|b| format!("{b:02x}")).collect()
}
+/// Generic types such as `application/octet-stream` and absent headers pass through +/// to the magic-byte check, which is the authoritative gate for image validity. +/// +/// Content-Type is filtered with a block-list (`text/*`) rather than an allow-list +/// (`image/*`) because CDNs commonly serve any file type as `application/octet-stream`; +/// rejecting that header would silently break real downloads. The magic-byte check +/// examines the actual bytes regardless of what the server claims. +fn validate_image_response( + content_type: Option<&str>, + body: &[u8], +) -> Result<(), MediaFetchError> { + // Reject explicitly-text responses early (e.g. Slack HTML login page at HTTP 200). + // application/octet-stream and other generic types pass through to magic-byte check. + if let Some(ct) = content_type { + let base = strip_mime_params(ct).to_lowercase(); + if base.starts_with("text/") { + return Err(MediaFetchError::UnsupportedResponseType { actual: Some(base) }); + } + } + + let reader = match ImageReader::new(Cursor::new(body)).with_guessed_format() { + Ok(r) => r, + Err(e) => { + error!(error = %e, "image format detection I/O error"); + return Err(MediaFetchError::InvalidImageBody { + magic_prefix_hex: hex_prefix(body), + }); + } + }; + + match reader.format() { + Some(image::ImageFormat::Png | image::ImageFormat::Jpeg | image::ImageFormat::WebP) => { + Ok(()) + } + Some(image::ImageFormat::Gif) => { + validate_gif_body(body).map_err(|e| { + warn!(error = %e, "GIF validation failed"); + MediaFetchError::InvalidImageBody { + magic_prefix_hex: hex_prefix(body), + } + })?; + Ok(()) + } + _ => Err(MediaFetchError::InvalidImageBody { + magic_prefix_hex: hex_prefix(body), + }), + } +} + +/// Validate a GIF body by attempting to decode exactly one frame. 
+/// +/// Decoding only the first frame is intentional: the GIF header and colour tables +/// must be valid before the first frame can be decoded, so this catches truncated +/// or corrupt payloads without the CPU/memory cost of decoding a large animated GIF +/// in full. +/// +/// Creates its own `Cursor` over `raw`; the caller can independently re-read the +/// same slice for resizing. +fn validate_gif_body(raw: &[u8]) -> image::ImageResult<()> { + let decoder = GifDecoder::new(Cursor::new(raw))?; + let mut frames = decoder.into_frames(); + frames.next().ok_or_else(|| { + image::ImageError::Decoding(image::error::DecodingError::new( + image::error::ImageFormatHint::Exact(image::ImageFormat::Gif), + "GIF has no frames", + )) + })??; + Ok(()) +} + +/// Download an image from a URL, resize/compress it, and return as a ContentBlock. +/// +/// Returns `Err(MediaFetchError::NotAnImage)` when the URL or MIME hint don't +/// indicate an image — callers should skip silently. Returns +/// `Err(MediaFetchError::SizeExceeded)` when the declared `size` exceeds the limit +/// before any request is made, or when the downloaded body exceeds the limit. Returns +/// other `Err` variants (`Network`, `HttpStatus`, `UnsupportedResponseType`, +/// `InvalidImageBody`) after a request attempt — callers should surface these to the user. Returns +/// `Err(MediaFetchError::ProcessingFailed)` when the body is a valid image but +/// resize/compression fails — callers should warn the user and skip. +/// +/// Pass `auth_token` for platforms that require authentication (e.g. Slack private files). 
+pub async fn download_and_encode_image( + url: &str, + mime_hint: Option<&str>, + filename: &str, + size: u64, + auth_token: Option<&str>, +) -> Result { + const MAX_SIZE: u64 = 10 * 1024 * 1024; // 10 MB + + if url.is_empty() { + return Err(MediaFetchError::NotAnImage); + } + + let mime = mime_hint.or_else(|| { + filename + .rsplit('.') + .next() + .and_then(|ext| match ext.to_lowercase().as_str() { + "png" => Some("image/png"), + "jpg" | "jpeg" => Some("image/jpeg"), + "gif" => Some("image/gif"), + "webp" => Some("image/webp"), + _ => None, + }) + }); + + let Some(mime) = mime else { + debug!(filename, "skipping non-image attachment"); + return Err(MediaFetchError::NotAnImage); + }; + let mime = mime.split(';').next().unwrap_or(mime).trim(); + if !mime.starts_with("image/") { + debug!(filename, mime, "skipping non-image attachment"); + return Err(MediaFetchError::NotAnImage); + } + + if size > MAX_SIZE { + error!(filename, size, "image exceeds 10MB limit"); + return Err(MediaFetchError::SizeExceeded { + actual: size, + limit: MAX_SIZE, + }); + } + + let mut req = HTTP_CLIENT.get(url); + if let Some(token) = auth_token { + req = req.header("Authorization", format!("Bearer {token}")); + } + + let response = match req.send().await { + Ok(resp) => resp, + Err(e) => { + error!(url, error = %e, "download failed"); + return Err(MediaFetchError::Network(e)); + } + }; + if !response.status().is_success() { + error!(url, status = %response.status(), "HTTP error downloading image"); + return Err(MediaFetchError::HttpStatus(response.status())); + } + + // Capture Content-Type BEFORE .bytes() consumes the response. 
+ let content_type = response + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .map(str::to_string); + + let bytes = match response.bytes().await { + Ok(b) => b, + Err(e) => { + error!(url, error = %e, "read failed"); + return Err(MediaFetchError::Network(e)); + } + }; + + if bytes.len() as u64 > MAX_SIZE { + error!( + filename, + size = bytes.len(), + "downloaded image exceeds limit" + ); + return Err(MediaFetchError::SizeExceeded { + actual: bytes.len() as u64, + limit: MAX_SIZE, + }); + } + + // Guard against HTTP 200 responses that are error pages (e.g. Slack auth redirect + // when files:read scope is missing), and against corrupted or mislabeled bodies. + if let Err(e) = validate_image_response(content_type.as_deref(), &bytes) { + error!( + filename, + mime_hint = mime, + content_type = content_type.as_deref().unwrap_or("none"), + magic = hex_prefix(&bytes), + error = %e, + "image validation failed — body is not a supported image" + ); + return Err(e); + } + + let (output_bytes, output_mime) = match resize_and_compress(&bytes) { + Ok(result) => result, + Err(e) => { + error!( + filename, + error = %e, + size = bytes.len(), + "resize failed after successful validation" + ); + return Err(MediaFetchError::ProcessingFailed(e)); + } + }; + + debug!( + filename, + original_size = bytes.len(), + compressed_size = output_bytes.len(), + "image processed" + ); + + let encoded = BASE64.encode(&output_bytes); + Ok(ContentBlock::Image { + media_type: output_mime, + data: encoded, + }) +} + +/// Download an audio file and transcribe it via the configured STT provider. +/// Pass `auth_token` for platforms that require authentication. 
+pub async fn download_and_transcribe( + url: &str, + filename: &str, + mime_type: &str, + size: u64, + stt_config: &SttConfig, + auth_token: Option<&str>, +) -> Option { + const MAX_SIZE: u64 = 25 * 1024 * 1024; // 25 MB (Whisper API limit) + + if size > MAX_SIZE { + error!(filename, size, "audio exceeds 25MB limit"); + return None; + } + + let mut req = HTTP_CLIENT.get(url); + if let Some(token) = auth_token { + req = req.header("Authorization", format!("Bearer {token}")); + } + + let resp = match req.send().await { + Ok(r) => r, + Err(e) => { + error!(url, error = %e, "audio download request failed"); + return None; + } + }; + if !resp.status().is_success() { + error!(url, status = %resp.status(), "audio download failed"); + return None; + } + let bytes = match resp.bytes().await { + Ok(b) => b.to_vec(), + Err(e) => { + error!(url, error = %e, "audio body read failed"); + return None; + } + }; + + if bytes.len() as u64 > MAX_SIZE { + error!(filename, size = bytes.len(), "downloaded audio exceeds 25MB limit"); + return None; + } + + crate::stt::transcribe( + &HTTP_CLIENT, + stt_config, + bytes, + filename.to_string(), + mime_type, + ) + .await +} + +/// Resize image so longest side <= IMAGE_MAX_DIMENSION_PX, then encode as JPEG. +/// GIFs are passed through unchanged to preserve animation. 
+pub fn resize_and_compress(raw: &[u8]) -> Result<(Vec, String), image::ImageError> { + let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?; + + let format = reader.format(); + + if format == Some(image::ImageFormat::Gif) { + return Ok((raw.to_vec(), "image/gif".to_string())); + } + + let img = reader.decode()?; + let (w, h) = (img.width(), img.height()); + + let img = if w > IMAGE_MAX_DIMENSION_PX || h > IMAGE_MAX_DIMENSION_PX { + let max_side = std::cmp::max(w, h); + let ratio = f64::from(IMAGE_MAX_DIMENSION_PX) / f64::from(max_side); + let new_w = (f64::from(w) * ratio) as u32; + let new_h = (f64::from(h) * ratio) as u32; + img.resize(new_w, new_h, image::imageops::FilterType::Lanczos3) + } else { + img + }; + + let mut buf = Cursor::new(Vec::new()); + let encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, IMAGE_JPEG_QUALITY); + img.write_with_encoder(encoder)?; + + Ok((buf.into_inner(), "image/jpeg".to_string())) +} + +/// Check if a MIME type is audio. +pub fn is_audio_mime(mime: &str) -> bool { + mime.starts_with("audio/") +} + +/// Check if an attachment is a video file. +pub fn is_video_file(filename: &str, content_type: Option<&str>) -> bool { + let mime = content_type.unwrap_or(""); + let mime_base = mime.split(';').next().unwrap_or(mime).trim(); + if mime_base.starts_with("video/") { + return true; + } + + filename + .rsplit('.') + .next() + .map(|ext| { + matches!( + ext.to_lowercase().as_str(), + "mp4" | "mov" | "m4v" | "webm" | "mkv" | "avi" + ) + }) + .unwrap_or(false) +} + +/// Extensions recognised as text-based files that can be inlined into the prompt. 
const TEXT_EXTENSIONS: &[&str] = &[
    "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", "rs", "py", "js",
    "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", "rb", "sh", "bash", "zsh", "fish",
    "ps1", "bat", "sql", "html", "css", "scss", "less", "ini", "cfg", "conf", "env",
];

/// Exact filenames (no extension) recognised as text files.
/// Entries are lowercase; lookups lowercase the candidate first.
const TEXT_FILENAMES: &[&str] = &[
    "dockerfile",
    "makefile",
    "justfile",
    "rakefile",
    "gemfile",
    "procfile",
    "vagrantfile",
    ".gitignore",
    ".dockerignore",
    ".editorconfig",
];

/// MIME types recognised as text-based (beyond `text/*`).
/// Entries are lowercase; lookups lowercase the candidate first.
const TEXT_MIME_TYPES: &[&str] = &[
    "application/json",
    "application/xml",
    "application/javascript",
    "application/x-yaml",
    "application/x-sh",
    "application/toml",
    "application/x-toml",
];

/// Check if a file is text-based and can be inlined into the prompt.
///
/// A file qualifies when any of the following holds:
/// * its MIME type (parameters stripped, compared case-insensitively) is
///   `text/*` or listed in `TEXT_MIME_TYPES`;
/// * the part of `filename` after the last `.` is listed in `TEXT_EXTENSIONS`
///   (case-insensitive);
/// * the whole `filename` is listed in `TEXT_FILENAMES` (case-insensitive).
pub fn is_text_file(filename: &str, content_type: Option<&str>) -> bool {
    // MIME types are case-insensitive, so normalise before comparing — the
    // image validation path in this module does the same.
    let mime_base = content_type
        .unwrap_or("")
        .split(';')
        .next()
        .unwrap_or("")
        .trim()
        .to_ascii_lowercase();
    if mime_base.starts_with("text/") || TEXT_MIME_TYPES.contains(&mime_base.as_str()) {
        return true;
    }

    let lowered = filename.to_lowercase();

    // Extension check: only meaningful when the name actually contains a dot.
    if let Some((_, ext)) = lowered.rsplit_once('.') {
        if TEXT_EXTENSIONS.contains(&ext) {
            return true;
        }
    }

    // Exact filename check (Dockerfile, Makefile, etc.)
    TEXT_FILENAMES.contains(&lowered.as_str())
}
+pub async fn download_and_read_text_file( + url: &str, + filename: &str, + size: u64, + auth_token: Option<&str>, +) -> Option<(ContentBlock, u64)> { + const MAX_SIZE: u64 = 512 * 1024; // 512 KB + + if size > MAX_SIZE { + tracing::warn!(filename, size, "text file exceeds 512KB limit, skipping"); + return None; + } + + let mut req = HTTP_CLIENT.get(url); + if let Some(token) = auth_token { + req = req.header("Authorization", format!("Bearer {token}")); + } + + let resp = match req.send().await { + Ok(r) => r, + Err(e) => { + tracing::warn!(url, error = %e, "text file download failed"); + return None; + } + }; + if !resp.status().is_success() { + tracing::warn!(url, status = %resp.status(), "text file download failed"); + return None; + } + let bytes = match resp.bytes().await { + Ok(b) => b, + Err(e) => { + tracing::warn!(url, error = %e, "text file body read failed"); + return None; + } + }; + let actual_size = bytes.len() as u64; + + // Defense-in-depth: verify actual download size + if actual_size > MAX_SIZE { + tracing::warn!( + filename, + size = actual_size, + "downloaded text file exceeds 512KB limit, skipping" + ); + return None; + } + + // from_utf8_lossy returns Cow::Borrowed for valid UTF-8 (zero-copy) + let text = String::from_utf8_lossy(&bytes).into_owned(); + + // Dynamic fence: keep adding backticks until the fence doesn't appear in content + let mut fence = "```".to_string(); + while text.contains(fence.as_str()) { + fence.push('`'); + } + + debug!(filename, bytes = text.len(), "text file inlined"); + Some(( + ContentBlock::Text { + text: format!("[File: {filename}]\n{fence}\n{text}\n{fence}"), + }, + actual_size, + )) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_png(width: u32, height: u32) -> Vec { + let img = image::RgbImage::new(width, height); + let mut buf = Cursor::new(Vec::new()); + img.write_to(&mut buf, image::ImageFormat::Png).unwrap(); + buf.into_inner() + } + + fn make_jpeg(width: u32, height: u32) -> Vec { + let img = 
image::RgbImage::new(width, height); + let mut buf = Cursor::new(Vec::new()); + img.write_to(&mut buf, image::ImageFormat::Jpeg).unwrap(); + buf.into_inner() + } + + fn make_gif() -> Vec { + vec![ + 0x47, 0x49, 0x46, 0x38, 0x39, 0x61, 0x01, 0x00, 0x01, 0x00, 0x80, 0x00, 0x00, 0x00, + 0x00, 0x00, 0xff, 0xff, 0xff, 0x2C, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, + 0x00, 0x02, 0x02, 0x44, 0x01, 0x00, 0x3B, + ] + } + + #[test] + fn large_image_resized_to_max_dimension() { + let png = make_png(3000, 2000); + let (compressed, mime) = resize_and_compress(&png).unwrap(); + + assert_eq!(mime, "image/jpeg"); + let result = image::load_from_memory(&compressed).unwrap(); + assert!(result.width() <= IMAGE_MAX_DIMENSION_PX); + assert!(result.height() <= IMAGE_MAX_DIMENSION_PX); + } + + #[test] + fn small_image_keeps_original_dimensions() { + let png = make_png(800, 600); + let (compressed, mime) = resize_and_compress(&png).unwrap(); + + assert_eq!(mime, "image/jpeg"); + let result = image::load_from_memory(&compressed).unwrap(); + assert_eq!(result.width(), 800); + assert_eq!(result.height(), 600); + } + + #[test] + fn landscape_image_respects_aspect_ratio() { + let png = make_png(4000, 2000); + let (compressed, _) = resize_and_compress(&png).unwrap(); + + let result = image::load_from_memory(&compressed).unwrap(); + assert_eq!(result.width(), 1200); + assert_eq!(result.height(), 600); + } + + #[test] + fn portrait_image_respects_aspect_ratio() { + let png = make_png(2000, 4000); + let (compressed, _) = resize_and_compress(&png).unwrap(); + + let result = image::load_from_memory(&compressed).unwrap(); + assert_eq!(result.width(), 600); + assert_eq!(result.height(), 1200); + } + + #[test] + fn compressed_output_is_smaller_than_original() { + let png = make_png(3000, 2000); + let (compressed, _) = resize_and_compress(&png).unwrap(); + + assert!( + compressed.len() < png.len(), + "compressed {} should be < original {}", + compressed.len(), + png.len() + ); + } + + #[test] + 
fn gif_passes_through_unchanged() { + let gif = make_gif(); + let (output, mime) = resize_and_compress(&gif).unwrap(); + + assert_eq!(mime, "image/gif"); + assert_eq!(output, gif); + } + + #[test] + fn invalid_data_returns_error() { + let garbage = vec![0x00, 0x01, 0x02, 0x03]; + assert!(resize_and_compress(&garbage).is_err()); + } + + #[test] + fn video_file_detects_mime_and_common_extensions() { + assert!(is_video_file("clip.bin", Some("video/mp4"))); + assert!(is_video_file("clip.mp4", None)); + assert!(is_video_file("clip.MOV", None)); + assert!(!is_video_file("notes.txt", Some("text/plain"))); + } + + // --- validate_image_response tests --- + + #[test] + fn validate_accepts_png_with_matching_content_type() { + let png = make_png(1, 1); + assert!(validate_image_response(Some("image/png"), &png).is_ok()); + } + + #[test] + fn validate_accepts_jpeg_with_matching_content_type() { + let jpeg = make_jpeg(1, 1); + assert!(validate_image_response(Some("image/jpeg"), &jpeg).is_ok()); + } + + #[test] + fn validate_accepts_gif_with_matching_content_type() { + let gif = make_gif(); + assert!(validate_image_response(Some("image/gif"), &gif).is_ok()); + } + + #[test] + fn validate_rejects_corrupt_gif_body() { + let corrupt_gif = b"GIF89a\x01\x00\x01\x00\x00\x00\x00"; + let result = validate_image_response(Some("image/gif"), corrupt_gif); + assert!(matches!( + result, + Err(MediaFetchError::InvalidImageBody { .. }) + )); + } + + #[test] + fn validate_accepts_missing_content_type_with_valid_png() { + // When Content-Type header is absent, fall back to magic-byte detection. + let png = make_png(1, 1); + assert!(validate_image_response(None, &png).is_ok()); + } + + #[test] + fn validate_content_type_strips_params() { + // "image/png; charset=binary" is a real header value — must be accepted. 
+ let png = make_png(1, 1); + assert!(validate_image_response(Some("image/png; charset=binary"), &png).is_ok()); + } + + /// Exact reproduction of issue #776: Slack serves the workspace login HTML + /// page at HTTP 200 when the bot token lacks the `files:read` scope. + /// The Slack file metadata says `mimetype: image/png`; the response body + /// magic bytes are `Slack login"; + let result = validate_image_response(Some("image/png"), html_body); + match result { + Err(MediaFetchError::InvalidImageBody { magic_prefix_hex }) => { + assert_eq!(magic_prefix_hex, "3c21444f43545950"); + } + other => panic!("expected InvalidImageBody, got {other:?}"), + } + } + + #[test] + fn validate_rejects_text_html_content_type() { + // Even if the body were a valid image, a text/html Content-Type must be rejected. + let png = make_png(1, 1); + let result = validate_image_response(Some("text/html; charset=utf-8"), &png); + assert!(matches!( + result, + Err(MediaFetchError::UnsupportedResponseType { .. }) + )); + } + + #[test] + fn validate_rejects_mixed_case_text_content_type() { + // Mixed-case Content-Type must be normalised before rejection. + let png = make_png(1, 1); + let result = validate_image_response(Some("Text/HTML; Charset=utf-8"), &png); + assert!(matches!( + result, + Err(MediaFetchError::UnsupportedResponseType { .. }) + )); + } + + /// Regression test for the application/octet-stream fix: CDNs and generic + /// file download endpoints commonly serve any file with this Content-Type. + /// The old allow-list incorrectly rejected it before magic-byte check. + #[test] + fn validate_accepts_octet_stream_with_valid_png() { + let png = make_png(1, 1); + assert!( + validate_image_response(Some("application/octet-stream"), &png).is_ok(), + "application/octet-stream must pass through to magic-byte check" + ); + } + + /// application/json body is rejected by magic bytes, not by Content-Type. 
+ #[test] + fn validate_rejects_json_body_by_magic_bytes() { + let json_body = b"{\"error\":\"invalid_auth\",\"ok\":false}"; + let result = validate_image_response(Some("application/json"), json_body); + assert!(matches!( + result, + Err(MediaFetchError::InvalidImageBody { .. }) + )); + } + + /// Missing Content-Type with invalid body: CDN stripping the header should + /// still be caught by magic-byte detection. + #[test] + fn validate_rejects_html_body_with_missing_content_type() { + let html_body = b"error page"; + let result = validate_image_response(None, html_body); + assert!(matches!( + result, + Err(MediaFetchError::InvalidImageBody { .. }) + )); + } + + #[test] + fn validate_rejects_empty_body() { + let result = validate_image_response(Some("image/png"), &[]); + assert!(matches!( + result, + Err(MediaFetchError::InvalidImageBody { .. }) + )); + } + + #[test] + fn validate_rejects_truncated_png_header() { + // PNG magic is 8 bytes; 4 bytes is not enough to identify the format. + let truncated = [0x89u8, 0x50, 0x4e, 0x47]; + let result = validate_image_response(Some("image/png"), &truncated); + assert!(matches!( + result, + Err(MediaFetchError::InvalidImageBody { .. }) + )); + } + + #[test] + fn truncated_png_body_must_not_produce_content_block() { + // Valid PNG magic bytes (8 bytes) + partial IHDR -- body is too short to decode. + // Previously: the <=1MB fallback in download_and_encode_image forwarded raw bytes + // after resize_and_compress failed, reproducing the #776 poisoning class. + // After removing the fallback, resize_and_compress failure must propagate as Err. 
+ let truncated: &[u8] = &[ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, // PNG magic + 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, // partial IHDR + ]; + assert!( + validate_image_response(Some("image/png"), truncated).is_ok(), + "magic-byte check still passes for truncated body" + ); + assert!( + resize_and_compress(truncated).is_err(), + "truncated PNG must fail at decode -- no raw-byte fallback allowed" + ); + } + + #[test] + fn media_fetch_error_display_renders() { + let _ = MediaFetchError::NotAnImage.to_string(); + let _ = MediaFetchError::UnsupportedResponseType { + actual: Some("text/html".into()), + } + .to_string(); + let s = MediaFetchError::UnsupportedResponseType { actual: None }.to_string(); + assert!(s.contains("none"), "None branch should render as 'none'"); + let _ = MediaFetchError::InvalidImageBody { + magic_prefix_hex: "3c21444f43545950".into(), + } + .to_string(); + let _ = MediaFetchError::SizeExceeded { + actual: 11_000_000, + limit: 10_000_000, + } + .to_string(); + let _ = MediaFetchError::HttpStatus(reqwest::StatusCode::UNAUTHORIZED).to_string(); + let _ = MediaFetchError::ProcessingFailed(image::ImageError::Unsupported( + image::error::UnsupportedError::from_format_and_kind( + image::error::ImageFormatHint::Unknown, + image::error::UnsupportedErrorKind::Color(image::ExtendedColorType::Rgba16), + ), + )) + .to_string(); + } + + #[test] + fn validate_accepts_webp_by_magic_bytes() { + let img = image::RgbImage::new(1, 1); + let mut buf = std::io::Cursor::new(Vec::new()); + img.write_to(&mut buf, image::ImageFormat::WebP).unwrap(); + let webp_body = buf.into_inner(); + assert!(validate_image_response(Some("image/webp"), &webp_body).is_ok()); + } + + #[test] + fn hex_prefix_formats_first_8_bytes() { + let bytes = b""; + assert_eq!(hex_prefix(bytes), "3c21444f43545950"); + } + + #[test] + fn hex_prefix_handles_short_buffer() { + let bytes = [0xffu8, 0xd8]; + assert_eq!(hex_prefix(&bytes), "ffd8"); + } +} diff --git 
a/crates/openab-core/src/multibot_cache.rs b/crates/openab-core/src/multibot_cache.rs new file mode 100644 index 000000000..a9bd8c82e --- /dev/null +++ b/crates/openab-core/src/multibot_cache.rs @@ -0,0 +1,85 @@ +//! Persistent disk cache for multibot thread detection. +//! +//! Once a thread is identified as multi-bot (irreversible), it is stored in +//! `~/.openab/cache/threads.json` so the detection survives restarts and +//! in-memory TTL expiry. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; +use tracing::{info, warn}; + +#[derive(Serialize, Deserialize, Clone)] +struct Entry { + detected_at: DateTime, +} + +/// Shared multibot thread cache with file persistence. +#[derive(Clone)] +pub struct MultibotCache { + threads: Arc>>, + path: PathBuf, +} + +impl MultibotCache { + /// Load or create the cache from `~/.openab/cache/threads.json`. + pub fn load(path: PathBuf) -> Self { + let threads = match std::fs::read_to_string(&path) { + Ok(data) => serde_json::from_str(&data).unwrap_or_else(|e| { + warn!(error = %e, "failed to parse threads.json, starting empty"); + HashMap::new() + }), + Err(_) => HashMap::new(), + }; + info!(count = threads.len(), path = %path.display(), "loaded multibot cache"); + Self { + threads: Arc::new(Mutex::new(threads)), + path, + } + } + + /// Check if a thread is known to be multi-bot. + pub fn is_multibot(&self, thread_id: &str) -> bool { + self.threads.lock().unwrap().contains_key(thread_id) + } + + /// Mark a thread as multi-bot and persist to disk (non-blocking). 
+ pub async fn mark_multibot(&self, thread_id: &str) { + let snapshot = { + let mut threads = self.threads.lock().unwrap(); + if threads.contains_key(thread_id) { + return; + } + threads.insert( + thread_id.to_string(), + Entry { + detected_at: Utc::now(), + }, + ); + threads.clone() + }; + let path = self.path.clone(); + tokio::task::spawn_blocking(move || persist(&path, &snapshot)).await.ok(); + } +} + +fn persist(path: &PathBuf, threads: &HashMap) { + if let Some(parent) = path.parent() { + if let Err(e) = std::fs::create_dir_all(parent) { + warn!(error = %e, "failed to create cache directory"); + return; + } + } + match serde_json::to_string_pretty(threads) { + Ok(data) => { + if let Err(e) = std::fs::write(path, data) { + warn!(error = %e, "failed to persist threads.json"); + } + } + Err(e) => { + warn!(error = %e, "failed to serialize multibot cache"); + } + } +} diff --git a/crates/openab-core/src/reactions.rs b/crates/openab-core/src/reactions.rs new file mode 100644 index 000000000..6e68f90b6 --- /dev/null +++ b/crates/openab-core/src/reactions.rs @@ -0,0 +1,276 @@ +use crate::adapter::{ChatAdapter, MessageRef}; +use crate::config::{ReactionEmojis, ReactionTiming}; +use std::sync::Arc; +use tokio::sync::Mutex; +use tokio::time::Duration; + +const CODING_TOKENS: &[&str] = &["exec", "process", "read", "write", "edit", "bash", "shell"]; +const WEB_TOKENS: &[&str] = &[ + "web_search", + "web_fetch", + "web-search", + "web-fetch", + "browser", +]; + +fn classify_tool<'a>(name: &str, emojis: &'a ReactionEmojis) -> &'a str { + let n = name.to_lowercase(); + if WEB_TOKENS.iter().any(|t| n.contains(t)) { + &emojis.web + } else if CODING_TOKENS.iter().any(|t| n.contains(t)) { + &emojis.coding + } else { + &emojis.tool + } +} + +struct Inner { + adapter: Arc, + message: MessageRef, + emojis: ReactionEmojis, + timing: ReactionTiming, + current: String, + finished: bool, + debounce_handle: Option>, + stall_soft_handle: Option>, + stall_hard_handle: Option>, +} + +pub 
struct StatusReactionController { + inner: Arc>, + enabled: bool, +} + +impl StatusReactionController { + pub fn new( + enabled: bool, + adapter: Arc, + message: MessageRef, + emojis: ReactionEmojis, + timing: ReactionTiming, + ) -> Self { + Self { + inner: Arc::new(Mutex::new(Inner { + adapter, + message, + emojis, + timing, + current: String::new(), + finished: false, + debounce_handle: None, + stall_soft_handle: None, + stall_hard_handle: None, + })), + enabled, + } + } + + pub async fn set_queued(&self) { + if !self.enabled { + return; + } + let emoji = { self.inner.lock().await.emojis.queued.clone() }; + self.apply_immediate(&emoji).await; + } + + pub async fn set_thinking(&self) { + if !self.enabled { + return; + } + let emoji = { self.inner.lock().await.emojis.thinking.clone() }; + self.schedule_debounced(&emoji).await; + } + + pub async fn set_tool(&self, tool_name: &str) { + if !self.enabled { + return; + } + let emoji = { + let inner = self.inner.lock().await; + classify_tool(tool_name, &inner.emojis).to_string() + }; + self.schedule_debounced(&emoji).await; + } + + pub async fn set_done(&self) { + if !self.enabled { + return; + } + let emoji = { self.inner.lock().await.emojis.done.clone() }; + self.finish(&emoji).await; + // Add a random mood face + let faces = ["😊", "😎", "🫡", "🤓", "😏", "✌️", "💪", "🦾"]; + let face = faces[rand::random::() % faces.len()]; + let inner = self.inner.lock().await; + let _ = inner.adapter.add_reaction(&inner.message, face).await; + } + + pub async fn set_error(&self) { + if !self.enabled { + return; + } + let emoji = { self.inner.lock().await.emojis.error.clone() }; + self.finish(&emoji).await; + } + + pub async fn clear(&self) { + if !self.enabled { + return; + } + let mut inner = self.inner.lock().await; + cancel_timers(&mut inner); + let current = inner.current.clone(); + if !current.is_empty() { + let _ = inner + .adapter + .remove_reaction(&inner.message, ¤t) + .await; + inner.current.clear(); + } + } + + async fn 
apply_immediate(&self, emoji: &str) { + let mut inner = self.inner.lock().await; + if inner.finished || emoji == inner.current { + return; + } + cancel_debounce(&mut inner); + let old = inner.current.clone(); + inner.current = emoji.to_string(); + let adapter = inner.adapter.clone(); + let msg = inner.message.clone(); + let new = emoji.to_string(); + drop(inner); + + let _ = adapter.add_reaction(&msg, &new).await; + if !old.is_empty() && old != new { + let _ = adapter.remove_reaction(&msg, &old).await; + } + self.reset_stall_timers().await; + } + + async fn schedule_debounced(&self, emoji: &str) { + let mut inner = self.inner.lock().await; + if inner.finished || emoji == inner.current { + self.reset_stall_timers_inner(&mut inner); + return; + } + cancel_debounce(&mut inner); + + let emoji = emoji.to_string(); + let ctrl = self.inner.clone(); + let debounce_ms = inner.timing.debounce_ms; + inner.debounce_handle = Some(tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(debounce_ms)).await; + let mut inner = ctrl.lock().await; + if inner.finished { + return; + } + let old = inner.current.clone(); + inner.current = emoji.clone(); + let adapter = inner.adapter.clone(); + let msg = inner.message.clone(); + drop(inner); + + let _ = adapter.add_reaction(&msg, &emoji).await; + if !old.is_empty() && old != emoji { + let _ = adapter.remove_reaction(&msg, &old).await; + } + })); + self.reset_stall_timers_inner(&mut inner); + } + + async fn finish(&self, emoji: &str) { + let mut inner = self.inner.lock().await; + if inner.finished { + return; + } + inner.finished = true; + cancel_timers(&mut inner); + + let old = inner.current.clone(); + inner.current = emoji.to_string(); + let adapter = inner.adapter.clone(); + let msg = inner.message.clone(); + let new = emoji.to_string(); + drop(inner); + + let _ = adapter.add_reaction(&msg, &new).await; + if !old.is_empty() && old != new { + let _ = adapter.remove_reaction(&msg, &old).await; + } + } + + async fn 
reset_stall_timers(&self) { + let mut inner = self.inner.lock().await; + self.reset_stall_timers_inner(&mut inner); + } + + fn reset_stall_timers_inner(&self, inner: &mut Inner) { + if let Some(h) = inner.stall_soft_handle.take() { + h.abort(); + } + if let Some(h) = inner.stall_hard_handle.take() { + h.abort(); + } + + let soft_ms = inner.timing.stall_soft_ms; + let hard_ms = inner.timing.stall_hard_ms; + let ctrl = self.inner.clone(); + + inner.stall_soft_handle = Some(tokio::spawn({ + let ctrl = ctrl.clone(); + async move { + tokio::time::sleep(Duration::from_millis(soft_ms)).await; + let mut inner = ctrl.lock().await; + if inner.finished { + return; + } + let old = inner.current.clone(); + inner.current = "🥱".to_string(); + let adapter = inner.adapter.clone(); + let msg = inner.message.clone(); + drop(inner); + let _ = adapter.add_reaction(&msg, "🥱").await; + if !old.is_empty() && old != "🥱" { + let _ = adapter.remove_reaction(&msg, &old).await; + } + } + })); + + inner.stall_hard_handle = Some(tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(hard_ms)).await; + let mut inner = ctrl.lock().await; + if inner.finished { + return; + } + let old = inner.current.clone(); + inner.current = "😨".to_string(); + let adapter = inner.adapter.clone(); + let msg = inner.message.clone(); + drop(inner); + let _ = adapter.add_reaction(&msg, "😨").await; + if !old.is_empty() && old != "😨" { + let _ = adapter.remove_reaction(&msg, &old).await; + } + })); + } +} + +fn cancel_debounce(inner: &mut Inner) { + if let Some(h) = inner.debounce_handle.take() { + h.abort(); + } +} + +fn cancel_timers(inner: &mut Inner) { + if let Some(h) = inner.debounce_handle.take() { + h.abort(); + } + if let Some(h) = inner.stall_soft_handle.take() { + h.abort(); + } + if let Some(h) = inner.stall_hard_handle.take() { + h.abort(); + } +} diff --git a/crates/openab-core/src/remind.rs b/crates/openab-core/src/remind.rs new file mode 100644 index 000000000..9472c53d8 --- /dev/null +++ 
b/crates/openab-core/src/remind.rs @@ -0,0 +1,399 @@ +//! One-shot `/remind` slash command — schedules a delayed mention in a Discord channel. +//! +//! Persistence: reminders are stored in `reminders.json` and reloaded on startup. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serenity::http::Http; +use serenity::model::id::ChannelId; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::Mutex; +use tracing::{error, info, warn}; + +/// A single pending reminder. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Reminder { + pub id: String, + pub channel_id: u64, + pub sender_id: u64, + /// Raw mention strings (e.g. "<@123>", "<@&456>") + pub targets: Vec, + pub message: String, + pub fire_at: DateTime, + pub created_at: DateTime, +} + +/// Shared reminder store with file persistence. +#[derive(Clone)] +pub struct ReminderStore { + reminders: Arc>>, + path: PathBuf, +} + +impl ReminderStore { + /// Load or create the reminder store from the given path. + pub fn load(path: PathBuf) -> Self { + let reminders = match std::fs::read_to_string(&path) { + Ok(data) => serde_json::from_str(&data).unwrap_or_else(|e| { + warn!(error = %e, "failed to parse reminders.json, starting empty"); + Vec::new() + }), + Err(_) => Vec::new(), + }; + info!(count = reminders.len(), path = %path.display(), "loaded reminders"); + Self { + reminders: Arc::new(Mutex::new(reminders)), + path, + } + } + + /// Add a reminder and persist to disk. + pub async fn add(&self, reminder: Reminder) { + let snapshot = { + let mut reminders = self.reminders.lock().await; + reminders.push(reminder); + reminders.clone() + }; + self.persist(&snapshot); + } + + /// Remove a reminder by ID and persist. + pub async fn remove(&self, id: &str) { + let snapshot = { + let mut reminders = self.reminders.lock().await; + reminders.retain(|r| r.id != id); + reminders.clone() + }; + self.persist(&snapshot); + } + + /// Get all pending reminders (for startup re-scheduling). 
+    pub async fn pending(&self) -> Vec<Reminder> {
+        // Clone while holding the lock so the caller receives an owned snapshot.
+        self.reminders.lock().await.clone()
+    }
+
+    /// Write `reminders` to `self.path` as pretty-printed JSON, creating the
+    /// parent directory first when the path has one.
+    ///
+    /// Failures (serialization, directory creation, write) are logged with
+    /// `error!` and are not returned to the caller.
+    ///
+    /// NOTE(review): this uses blocking `std::fs` calls and is invoked from the
+    /// async `add`/`remove` methods — confirm that is acceptable for the runtime.
+    fn persist(&self, reminders: &[Reminder]) {
+        match serde_json::to_string_pretty(reminders) {
+            Ok(data) => {
+                if let Some(parent) = self.path.parent() {
+                    if let Err(e) = std::fs::create_dir_all(parent) {
+                        error!(error = %e, "failed to create reminders directory");
+                        return;
+                    }
+                }
+                if let Err(e) = std::fs::write(&self.path, data) {
+                    error!(error = %e, "failed to persist reminders.json");
+                }
+            }
+            Err(e) => {
+                error!(error = %e, "failed to serialize reminders, skipping persist");
+            }
+        }
+    }
+}
+
+/// Maximum allowed message length for reminders.
+pub const MAX_MESSAGE_LEN: usize = 1800;
+
+/// Maximum number of mention targets per reminder.
+pub const MAX_TARGETS: usize = 10;
+
+/// Sanitize reminder message: neutralize @everyone/@here.
+///
+/// A zero-width space (U+200B) is inserted right after the `@` of every
+/// `@everyone` and `@here` occurrence; all other text is returned unchanged.
+pub fn sanitize_message(msg: &str) -> String {
+    msg.replace("@everyone", "@\u{200b}everyone")
+        .replace("@here", "@\u{200b}here")
+}
+
+/// Validate reminder message length.
+///
+/// Returns `Err` with a user-facing explanation when the message exceeds
+/// `MAX_MESSAGE_LEN`, `Ok(())` otherwise.
+///
+/// NOTE(review): the comparison uses `msg.len()`, i.e. the UTF-8 byte length,
+/// while the error text says "characters" — multi-byte text reaches the limit
+/// with fewer characters. Confirm which unit is intended.
+pub fn validate_message(msg: &str) -> Result<(), String> {
+    if msg.len() > MAX_MESSAGE_LEN {
+        Err(format!("message too long (max {MAX_MESSAGE_LEN} characters)"))
+    } else {
+        Ok(())
+    }
+}
+
+/// Parse a human delay string like "30m", "2h", "7d" into seconds.
+/// Supports combinations: "1h30m", "2d12h".
+/// Range: 1m (60s) to 30d (2_592_000s).
+///
+/// Input is trimmed and matched case-insensitively. A trailing number with no
+/// unit is treated as minutes. Returns `Err` with a user-facing message for
+/// empty input, a unit without a preceding number, an unknown unit, or a total
+/// outside the allowed range (including totals too large to fit in a `u64`).
+pub fn parse_delay(input: &str) -> Result<u64, String> {
+    // Arithmetic overflow can only happen far above the 30d cap, so it is
+    // reported with the same message as the ordinary range check.
+    const TOO_LONG: &str = "maximum delay is 30d";
+
+    let s = input.trim().to_lowercase();
+    if s.is_empty() {
+        return Err("empty delay".into());
+    }
+
+    let mut total_secs: u64 = 0;
+    let mut num_buf = String::new();
+
+    for ch in s.chars() {
+        if ch.is_ascii_digit() {
+            num_buf.push(ch);
+        } else {
+            // A unit character closes the number collected so far; an empty
+            // buffer (e.g. "m" or "abc") fails to parse and is rejected here.
+            let n: u64 = num_buf.parse().map_err(|_| format!("invalid number in delay: {input}"))?;
+            num_buf.clear();
+            let multiplier: u64 = match ch {
+                'm' => 60,
+                'h' => 3600,
+                'd' => 86400,
+                _ => return Err(format!("unknown unit '{ch}' in delay (use m/h/d)")),
+            };
+            // Checked arithmetic: huge inputs must become an error instead of
+            // panicking (debug) or wrapping into the valid range (release).
+            total_secs = n
+                .checked_mul(multiplier)
+                .and_then(|secs| total_secs.checked_add(secs))
+                .ok_or_else(|| TOO_LONG.to_string())?;
+        }
+    }
+
+    // Handle bare number (default to minutes)
+    if !num_buf.is_empty() {
+        let n: u64 = num_buf.parse().map_err(|_| format!("invalid number in delay: {input}"))?;
+        total_secs = n
+            .checked_mul(60) // default unit = minutes
+            .and_then(|secs| total_secs.checked_add(secs))
+            .ok_or_else(|| TOO_LONG.to_string())?;
+    }
+
+    if total_secs < 60 {
+        return Err("minimum delay is 1m".into());
+    }
+    if total_secs > 2_592_000 {
+        return Err(TOO_LONG.into());
+    }
+
+    Ok(total_secs)
+}
+
+/// Format seconds into a human-readable string like "2h 30m".
+///
+/// Only whole days, hours and minutes are shown (zero components are omitted,
+/// leftover seconds are dropped); anything under one minute yields "< 1m".
+pub fn format_delay(secs: u64) -> String {
+    let d = secs / 86400;
+    let h = (secs % 86400) / 3600;
+    let m = (secs % 3600) / 60;
+    let mut parts = Vec::new();
+    if d > 0 { parts.push(format!("{d}d")); }
+    if h > 0 { parts.push(format!("{h}h")); }
+    if m > 0 { parts.push(format!("{m}m")); }
+    if parts.is_empty() { "< 1m".into() } else { parts.join(" ") }
+}
+
+/// Spawn a tokio task that fires the reminder after the delay.
+pub fn schedule_reminder( + http: Arc, + store: ReminderStore, + reminder: Reminder, +) { + let now = Utc::now(); + let delay = if reminder.fire_at > now { + (reminder.fire_at - now).to_std().unwrap_or_default() + } else { + std::time::Duration::ZERO + }; + + let id = reminder.id.clone(); + tokio::spawn(async move { + tokio::time::sleep(delay).await; + + let targets_str = reminder.targets.join(" "); + let content = format!( + "⏰ **Reminder** from <@{}>:\n\"{}\"\ncc {}", + reminder.sender_id, reminder.message, targets_str + ); + + let channel = ChannelId::new(reminder.channel_id); + match channel.say(&http, &content).await { + Ok(_) => { + info!(id = %id, channel = reminder.channel_id, "reminder fired"); + store.remove(&id).await; + } + Err(e) => { + error!(error = %e, id = %id, "failed to send reminder — keeping for retry on next restart"); + } + } + }); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_delay_minutes() { + assert_eq!(parse_delay("5m").unwrap(), 300); + assert_eq!(parse_delay("1m").unwrap(), 60); + } + + #[test] + fn parse_delay_hours() { + assert_eq!(parse_delay("2h").unwrap(), 7200); + } + + #[test] + fn parse_delay_days() { + assert_eq!(parse_delay("1d").unwrap(), 86400); + assert_eq!(parse_delay("30d").unwrap(), 2_592_000); + } + + #[test] + fn parse_delay_combined() { + assert_eq!(parse_delay("1h30m").unwrap(), 5400); + assert_eq!(parse_delay("1d12h").unwrap(), 129_600); + } + + #[test] + fn parse_delay_bare_number_defaults_to_minutes() { + assert_eq!(parse_delay("10").unwrap(), 600); + } + + #[test] + fn parse_delay_too_short() { + assert!(parse_delay("0m").is_err()); + assert!(parse_delay("0h").is_err()); + } + + #[test] + fn parse_delay_too_long() { + assert!(parse_delay("31d").is_err()); + } + + #[test] + fn format_delay_basic() { + assert_eq!(format_delay(3600), "1h"); + assert_eq!(format_delay(5400), "1h 30m"); + assert_eq!(format_delay(90000), "1d 1h"); + } + + #[test] + fn parse_delay_empty() { + 
assert!(parse_delay("").is_err()); + assert!(parse_delay(" ").is_err()); + } + + #[test] + fn parse_delay_invalid_unit() { + assert!(parse_delay("2x").is_err()); + assert!(parse_delay("abc").is_err()); + assert!(parse_delay("5s").is_err()); + } + + #[test] + fn parse_delay_case_insensitive() { + assert_eq!(parse_delay("2H").unwrap(), 7200); + assert_eq!(parse_delay("1D30M").unwrap(), 88200); + } + + #[test] + fn parse_delay_whitespace_trimmed() { + assert_eq!(parse_delay(" 5m ").unwrap(), 300); + } + + #[test] + fn parse_delay_bare_number_boundary() { + assert_eq!(parse_delay("1").unwrap(), 60); // 1 min + assert_eq!(parse_delay("30").unwrap(), 1800); // 30 min + } + + #[test] + fn parse_delay_exact_boundaries() { + // Exactly 1m (minimum) + assert_eq!(parse_delay("1m").unwrap(), 60); + // Exactly 30d (maximum) + assert_eq!(parse_delay("30d").unwrap(), 2_592_000); + // Just over 30d + assert!(parse_delay("30d1m").is_err()); + } + + #[test] + fn format_delay_zero() { + assert_eq!(format_delay(0), "< 1m"); + } + + #[test] + fn format_delay_pure_units() { + assert_eq!(format_delay(86400), "1d"); + assert_eq!(format_delay(120), "2m"); + assert_eq!(format_delay(7200), "2h"); + } + + #[tokio::test] + async fn reminder_store_add_remove() { + let dir = std::env::temp_dir().join(format!("remind_test_{}", std::process::id())); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("reminders.json"); + + let store = ReminderStore::load(path.clone()); + assert_eq!(store.pending().await.len(), 0); + + let r = Reminder { + id: "test-1".into(), + channel_id: 123, + sender_id: 456, + targets: vec!["<@789>".into()], + message: "hello".into(), + fire_at: Utc::now() + chrono::Duration::hours(1), + created_at: Utc::now(), + }; + + store.add(r).await; + assert_eq!(store.pending().await.len(), 1); + + store.remove("test-1").await; + assert_eq!(store.pending().await.len(), 0); + + // Verify persistence + let store2 = ReminderStore::load(path.clone()); + 
assert_eq!(store2.pending().await.len(), 0); + + std::fs::remove_dir_all(&dir).ok(); + } + + #[tokio::test] + async fn reminder_store_persists_across_reload() { + let dir = std::env::temp_dir().join(format!("remind_test2_{}", std::process::id())); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("reminders.json"); + + let store = ReminderStore::load(path.clone()); + let r = Reminder { + id: "persist-1".into(), + channel_id: 100, + sender_id: 200, + targets: vec!["<@300>".into()], + message: "persist test".into(), + fire_at: Utc::now() + chrono::Duration::hours(2), + created_at: Utc::now(), + }; + store.add(r).await; + + // Reload from disk + let store2 = ReminderStore::load(path.clone()); + let pending = store2.pending().await; + assert_eq!(pending.len(), 1); + assert_eq!(pending[0].id, "persist-1"); + assert_eq!(pending[0].message, "persist test"); + + std::fs::remove_dir_all(&dir).ok(); + } + + #[test] + fn sanitize_message_strips_everyone_here() { + assert_eq!(sanitize_message("hello @everyone"), "hello @\u{200b}everyone"); + assert_eq!(sanitize_message("hey @here check"), "hey @\u{200b}here check"); + assert_eq!(sanitize_message("@everyone @here"), "@\u{200b}everyone @\u{200b}here"); + } + + #[test] + fn sanitize_message_no_change() { + assert_eq!(sanitize_message("normal message"), "normal message"); + assert_eq!(sanitize_message("<@123> hello"), "<@123> hello"); + } + + #[test] + fn validate_message_ok() { + assert!(validate_message("short message").is_ok()); + assert!(validate_message(&"a".repeat(1800)).is_ok()); + } + + #[test] + fn validate_message_too_long() { + assert!(validate_message(&"a".repeat(1801)).is_err()); + } + + #[test] + fn max_targets_constant() { + assert_eq!(MAX_TARGETS, 10); + } +} diff --git a/crates/openab-core/src/secrets.rs b/crates/openab-core/src/secrets.rs new file mode 100644 index 000000000..e6a7967fd --- /dev/null +++ b/crates/openab-core/src/secrets.rs @@ -0,0 +1,479 @@ +use std::collections::HashMap; +use 
tracing::{error, info}; + +use crate::config::SecretsConfig; + +/// Resolved secrets: mapping from key name to plaintext value. +pub type ResolvedSecrets = HashMap; + +/// Resolve all secret references in the [secrets] config table. +/// Returns a map of key → resolved value. +pub async fn resolve(cfg: &SecretsConfig) -> anyhow::Result { + let mut resolved = HashMap::new(); + + // Build AWS client once if any refs use aws-sm:// + #[cfg(feature = "secrets-aws")] + let aws_client = if cfg.refs.values().any(|v| v.starts_with("aws-sm://")) { + Some(build_aws_client(cfg).await) + } else { + None + }; + + for (key, uri) in &cfg.refs { + let value = if uri.starts_with("aws-sm://") { + #[cfg(feature = "secrets-aws")] + { + let client = aws_client.as_ref().ok_or_else(|| { + anyhow::anyhow!("secret '{key}': AWS client not initialized") + })?; + resolve_aws_sm(key, uri, client).await? + } + #[cfg(not(feature = "secrets-aws"))] + { + anyhow::bail!( + "secret '{key}' uses aws-sm:// but the 'secrets-aws' feature is not enabled" + ); + } + } else if uri.starts_with("exec://") { + resolve_exec(key, uri, cfg).await? 
+ } else { + anyhow::bail!( + "secret '{key}': unrecognized URI scheme in '{uri}' (expected aws-sm:// or exec://)" + ); + }; + resolved.insert(key.clone(), value); + } + + if !resolved.is_empty() { + info!(count = resolved.len(), "secrets resolved"); + } + Ok(resolved) +} + +// -- AWS Secrets Manager provider -- + +#[cfg(feature = "secrets-aws")] +async fn build_aws_client(cfg: &SecretsConfig) -> aws_sdk_secretsmanager::Client { + let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest()); + if let Some(ref region) = cfg.aws.region { + config_loader = config_loader.region(aws_config::Region::new(region.clone())); + } + if let Some(ref endpoint) = cfg.aws.endpoint_url { + config_loader = config_loader.endpoint_url(endpoint); + } + let sdk_config = config_loader.load().await; + aws_sdk_secretsmanager::Client::new(&sdk_config) +} + +#[cfg(feature = "secrets-aws")] +async fn resolve_aws_sm( + key: &str, + uri: &str, + client: &aws_sdk_secretsmanager::Client, +) -> anyhow::Result { + let (secret_id, json_key) = parse_aws_sm_uri(uri) + .ok_or_else(|| anyhow::anyhow!("secret '{key}': invalid aws-sm:// URI '{uri}' — expected aws-sm://#"))?; + + let resp = client + .get_secret_value() + .secret_id(&secret_id) + .send() + .await + .map_err(|e| { + error!(secret = key, secret_id = %secret_id, "AWS Secrets Manager error"); + anyhow::anyhow!("secret '{key}': failed to fetch '{secret_id}' from AWS Secrets Manager: {e}") + })?; + + let secret_string = resp + .secret_string() + .ok_or_else(|| anyhow::anyhow!("secret '{key}': '{secret_id}' has no string value (binary secrets not supported)"))?; + + // Parse as JSON and extract the key + let json: serde_json::Value = serde_json::from_str(secret_string) + .map_err(|e| anyhow::anyhow!("secret '{key}': '{secret_id}' is not valid JSON: {e}"))?; + + let value = json + .get(&json_key) + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("secret '{key}': JSON key '{json_key}' not found in '{secret_id}'"))?; 
+
+    Ok(value.to_owned())
+}
+
+/// Parse `aws-sm://secret-id#json-key` into (secret_id, json_key).
+///
+/// The split happens at the last `#`; `None` is returned when the prefix is
+/// missing, there is no `#`, or either side of it is empty.
+#[cfg(feature = "secrets-aws")]
+fn parse_aws_sm_uri(uri: &str) -> Option<(String, String)> {
+    let rest = uri.strip_prefix("aws-sm://")?;
+    let (secret_id, json_key) = rest.rsplit_once('#')?;
+    if secret_id.is_empty() || json_key.is_empty() {
+        return None;
+    }
+    Some((secret_id.to_owned(), json_key.to_owned()))
+}
+
+// -- Exec provider --
+// Note: script path is delimited by the first space. Paths containing spaces are not supported.
+
+/// Resolve an `exec://<script> [arg1 [arg2]]` reference by running the script
+/// and returning its stdout with trailing newlines removed.
+///
+/// The text after `exec://` is split on spaces into at most three parts, so
+/// everything after the second space is passed as one argument. The child runs
+/// with a cleared environment plus `HOME`, `PATH`, `USER` (unix) and the cloud
+/// credential variables listed below, and is bounded by
+/// `cfg.exec.timeout_seconds`.
+///
+/// Errors: malformed URI, timeout, spawn failure, non-zero exit status, or
+/// stdout that is not valid UTF-8.
+async fn resolve_exec(key: &str, uri: &str, cfg: &SecretsConfig) -> anyhow::Result<String> {
+    // Environment variable name prefixes forwarded for IAM-based auth.
+    const PASSTHROUGH_PREFIXES: &[&str] = &[
+        "AWS_",
+        "AMAZON_",
+        "ECS_CONTAINER_METADATA_URI",
+        "GOOGLE_",
+        "GCLOUD_",
+        "CLOUDSDK_",
+        "AZURE_",
+    ];
+
+    // Report a wrong scheme as an error instead of panicking on `unwrap`.
+    let rest = uri.strip_prefix("exec://").ok_or_else(|| {
+        anyhow::anyhow!("secret '{key}': expected an exec:// URI, got '{uri}'")
+    })?;
+    let mut parts_iter = rest.splitn(3, ' ');
+    let script = parts_iter.next().ok_or_else(|| {
+        anyhow::anyhow!("secret '{key}': exec:// URI missing script path")
+    })?;
+    if script.is_empty() {
+        anyhow::bail!("secret '{key}': exec:// URI has empty script path");
+    }
+
+    let mut cmd = tokio::process::Command::new(script);
+    cmd.kill_on_drop(true);
+
+    // Sanitized environment (same as pre_boot hooks — no unrelated tokens leak)
+    cmd.env_clear();
+    if let Ok(v) = std::env::var("HOME") {
+        cmd.env("HOME", &v);
+    }
+    if let Ok(v) = std::env::var("PATH") {
+        cmd.env("PATH", &v);
+    }
+    #[cfg(unix)]
+    if let Ok(v) = std::env::var("USER") {
+        cmd.env("USER", &v);
+    }
+    // Pass through cloud credential env vars for IAM-based auth.
+    // `vars_os` is used because `std::env::vars` panics as soon as any variable
+    // in the process environment is not valid Unicode. Names that are not
+    // valid Unicode cannot match a prefix and are skipped; values are
+    // forwarded as-is.
+    for (name, val) in std::env::vars_os() {
+        let pass = name
+            .to_str()
+            .map_or(false, |n| PASSTHROUGH_PREFIXES.iter().any(|p| n.starts_with(*p)));
+        if pass {
+            cmd.env(&name, &val);
+        }
+    }
+
+    // Pass remaining parts as arguments (key, attribute)
+    for arg in parts_iter {
+        if !arg.is_empty() {
+            cmd.arg(arg);
+        }
+    }
+
+    let timeout = std::time::Duration::from_secs(cfg.exec.timeout_seconds);
+    let output = tokio::time::timeout(timeout, cmd.output())
+        .await
+        .map_err(|_| {
+            anyhow::anyhow!(
+                "secret '{key}': exec script '{script}' timed out after {}s",
+                cfg.exec.timeout_seconds
+            )
+        })?
+        .map_err(|e| {
+            if e.kind() == std::io::ErrorKind::NotFound {
+                anyhow::anyhow!(
+                    "secret '{key}': exec script '{script}' not found — did [hooks.pre_boot] run successfully?"
+                )
+            } else {
+                anyhow::anyhow!("secret '{key}': failed to execute '{script}': {e}")
+            }
+        })?;
+
+    if !output.status.success() {
+        let stderr = String::from_utf8_lossy(&output.stderr);
+        error!(secret = key, script, %stderr, "exec provider failed");
+        anyhow::bail!(
+            "secret '{key}': exec script '{script}' exited with {}",
+            output.status
+        );
+    }
+
+    let value = String::from_utf8(output.stdout)
+        .map_err(|e| anyhow::anyhow!("secret '{key}': exec output is not valid UTF-8: {e}"))?;
+    Ok(value.trim_end_matches('\n').to_owned())
+}
+
+/// Substitute `${secrets.<key>}` references in the raw config text with resolved values.
+/// Uses single-pass replacement to avoid double-substitution if a secret value
+/// itself contains `${secrets.*}` patterns.
+/// Values are escaped for use within TOML double-quoted strings.
+/// References whose key is not present in `secrets` are left untouched.
+pub fn substitute(raw: &str, secrets: &ResolvedSecrets) -> String {
+    let re = regex::Regex::new(r"\$\{secrets\.([^}]+)\}").unwrap();
+    re.replace_all(raw, |caps: &regex::Captures| {
+        let key = &caps[1];
+        secrets
+            .get(key)
+            .map(|v| escape_toml_value(v))
+            .unwrap_or_else(|| caps[0].to_owned())
+    })
+    .into_owned()
+}
+
+/// Escape a string value so it is safe inside a TOML double-quoted string.
+///
+/// Backslash and double quote are backslash-escaped. Control characters
+/// (U+0000–U+001F and U+007F) are written as escapes, using the short forms
+/// `\b`, `\t`, `\n`, `\f`, `\r` where TOML defines them and `\uXXXX` otherwise.
+/// Every other character is copied through unchanged.
+fn escape_toml_value(s: &str) -> String {
+    let mut out = String::with_capacity(s.len());
+    for ch in s.chars() {
+        match ch {
+            '\\' => out.push_str("\\\\"),
+            '"' => out.push_str("\\\""),
+            '\n' => out.push_str("\\n"),
+            '\r' => out.push_str("\\r"),
+            '\t' => out.push_str("\\t"),
+            '\u{08}' => out.push_str("\\b"),
+            '\u{0C}' => out.push_str("\\f"),
+            // Remaining control characters may not appear raw in a TOML basic
+            // string; emit them as four-digit Unicode escapes.
+            c if c <= '\u{1F}' || c == '\u{7F}' => {
+                out.push_str(&format!("\\u{:04X}", c as u32));
+            }
+            _ => out.push(ch),
+        }
+    }
+    out
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[cfg(feature = "secrets-aws")]
+    #[test]
+    fn parse_aws_sm_uri_valid() {
+        let (id, key) = parse_aws_sm_uri("aws-sm://openab/prod#discord_bot_token").unwrap();
+        assert_eq!(id, "openab/prod");
+        assert_eq!(key, "discord_bot_token");
+    }
+
+    #[cfg(feature = "secrets-aws")]
+    #[test]
+    fn parse_aws_sm_uri_with_arn() {
+        let uri = "aws-sm://arn:aws:secretsmanager:us-east-1:123456789:secret:my-secret-abc123#api_key";
+        let (id, key) = parse_aws_sm_uri(uri).unwrap();
+        assert_eq!(id, "arn:aws:secretsmanager:us-east-1:123456789:secret:my-secret-abc123");
+        assert_eq!(key, "api_key");
+    }
+
+    #[cfg(feature = "secrets-aws")]
+    #[test]
+    fn parse_aws_sm_uri_missing_key() {
+        assert!(parse_aws_sm_uri("aws-sm://openab/prod").is_none());
+        assert!(parse_aws_sm_uri("aws-sm://openab/prod#").is_none());
+        assert!(parse_aws_sm_uri("aws-sm://#key").is_none());
+    }
+
+    #[test]
+    fn substitute_replaces_secrets() {
+        let mut secrets = HashMap::new();
+        secrets.insert("token".to_owned(), "my-secret-value".to_owned());
+        let input = r#"bot_token = "${secrets.token}""#;
+        let output = substitute(input, &secrets);
+        assert_eq!(output, r#"bot_token = "my-secret-value""#);
+    }
+
+    #[test]
+    fn substitute_escapes_special_chars() {
+        let mut secrets = HashMap::new();
+        secrets.insert("key".to_owned(), "has\"quotes\\and\nnewlines".to_owned());
+        let input = r#"value = "${secrets.key}""#;
+        let output = substitute(input, &secrets);
+        assert_eq!(output, r#"value = "has\"quotes\\and\nnewlines""#);
+    }
+
+    #[test]
+    fn substitute_no_match_unchanged() {
+        let secrets = HashMap::new();
+        let input = r#"bot_token =
"${DISCORD_BOT_TOKEN}""#; + let output = substitute(input, &secrets); + assert_eq!(output, input); + } + + #[test] + fn substitute_unknown_key_left_intact() { + let mut secrets = HashMap::new(); + secrets.insert("known".to_owned(), "val".to_owned()); + let input = r#"a = "${secrets.known}" b = "${secrets.unknown}""#; + let output = substitute(input, &secrets); + assert_eq!(output, r#"a = "val" b = "${secrets.unknown}""#); + } + + #[test] + fn substitute_no_double_replacement() { + let mut secrets = HashMap::new(); + // Secret value itself contains a ${secrets.*} pattern + secrets.insert("a".to_owned(), "${secrets.b}".to_owned()); + secrets.insert("b".to_owned(), "should-not-appear".to_owned()); + let input = r#"val = "${secrets.a}""#; + let output = substitute(input, &secrets); + // The literal ${secrets.b} should be escaped, not re-substituted + assert!(!output.contains("should-not-appear")); + } + + #[test] + fn substitute_multiple_refs_same_line() { + let mut secrets = HashMap::new(); + secrets.insert("user".to_owned(), "admin".to_owned()); + secrets.insert("pass".to_owned(), "s3cret".to_owned()); + let input = r#"dsn = "postgres://${secrets.user}:${secrets.pass}@localhost""#; + let output = substitute(input, &secrets); + assert_eq!(output, r#"dsn = "postgres://admin:s3cret@localhost""#); + } + + #[test] + fn escape_toml_value_basic() { + assert_eq!(escape_toml_value("hello"), "hello"); + assert_eq!(escape_toml_value(r#"a"b"#), r#"a\"b"#); + assert_eq!(escape_toml_value("a\\b"), "a\\\\b"); + assert_eq!(escape_toml_value("line1\nline2"), "line1\\nline2"); + assert_eq!(escape_toml_value("tab\there"), "tab\\there"); + assert_eq!(escape_toml_value("cr\rhere"), "cr\\rhere"); + } + + #[test] + fn escape_toml_value_combined() { + let input = "key=\"val\"\nnext"; + let output = escape_toml_value(input); + assert_eq!(output, "key=\\\"val\\\"\\nnext"); + } + + #[tokio::test] + async fn resolve_exec_success() { + use crate::config::{AwsSecretsConfig, ExecSecretsConfig, 
SecretsConfig}; + let cfg = SecretsConfig { + aws: AwsSecretsConfig::default(), + exec: ExecSecretsConfig { timeout_seconds: 5 }, + refs: HashMap::new(), + }; + // Use echo as a script + let result = resolve_exec("test", "exec:///bin/echo hello world", &cfg).await; + assert_eq!(result.unwrap(), "hello world"); + } + + #[tokio::test] + async fn resolve_exec_not_found() { + use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; + let cfg = SecretsConfig { + aws: AwsSecretsConfig::default(), + exec: ExecSecretsConfig { timeout_seconds: 5 }, + refs: HashMap::new(), + }; + let result = + resolve_exec("test", "exec:///nonexistent/script arg1 arg2", &cfg).await; + let err = result.unwrap_err().to_string(); + assert!(err.contains("not found")); + assert!(err.contains("pre_boot")); + } + + #[tokio::test] + async fn resolve_exec_nonzero_exit() { + use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; + let cfg = SecretsConfig { + aws: AwsSecretsConfig::default(), + exec: ExecSecretsConfig { timeout_seconds: 5 }, + refs: HashMap::new(), + }; + let result = resolve_exec("test", "exec:///bin/false", &cfg).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("exited with")); + } + + #[tokio::test] + async fn resolve_exec_strips_trailing_newline() { + use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; + let cfg = SecretsConfig { + aws: AwsSecretsConfig::default(), + exec: ExecSecretsConfig { timeout_seconds: 5 }, + refs: HashMap::new(), + }; + // printf adds no newline, echo does — test both + let result = resolve_exec("test", "exec:///bin/echo secret_value", &cfg).await; + assert_eq!(result.unwrap(), "secret_value"); + } + + #[tokio::test] + async fn resolve_empty_refs_returns_empty() { + use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; + let cfg = SecretsConfig { + aws: AwsSecretsConfig::default(), + exec: ExecSecretsConfig::default(), + refs: HashMap::new(), + }; 
+ let result = resolve(&cfg).await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn resolve_unknown_scheme_fails() { + use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; + let mut refs = HashMap::new(); + refs.insert("bad".to_owned(), "ftp://something".to_owned()); + let cfg = SecretsConfig { + aws: AwsSecretsConfig::default(), + exec: ExecSecretsConfig::default(), + refs, + }; + let result = resolve(&cfg).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("unrecognized URI scheme")); + } + + #[cfg(feature = "secrets-aws")] + #[test] + fn parse_aws_sm_uri_hash_in_secret_name() { + // rsplit_once('#') should split on the LAST # + let uri = "aws-sm://my#secret#api_key"; + let (id, key) = parse_aws_sm_uri(uri).unwrap(); + assert_eq!(id, "my#secret"); + assert_eq!(key, "api_key"); + } + + #[tokio::test] + async fn resolve_exec_sanitized_env() { + use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; + // Set a dummy env var that should NOT be visible to the exec script + std::env::set_var("OPENAB_TEST_LEAKED_SECRET", "should_not_leak"); + let cfg = SecretsConfig { + aws: AwsSecretsConfig::default(), + exec: ExecSecretsConfig { timeout_seconds: 5 }, + refs: HashMap::new(), + }; + // /usr/bin/env prints all env vars; grep for our dummy var + let result = resolve_exec("test", "exec:///usr/bin/env", &cfg).await.unwrap(); + assert!( + !result.contains("OPENAB_TEST_LEAKED_SECRET"), + "exec script should not see unrelated env vars" + ); + // HOME and PATH should still be present + assert!(result.contains("HOME="), "HOME should be in sanitized env"); + assert!(result.contains("PATH="), "PATH should be in sanitized env"); + std::env::remove_var("OPENAB_TEST_LEAKED_SECRET"); + } + + #[tokio::test] + async fn resolve_exec_timeout() { + use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; + let cfg = SecretsConfig { + aws: AwsSecretsConfig::default(), + exec: 
ExecSecretsConfig { timeout_seconds: 1 }, + refs: HashMap::new(), + }; + // sleep 10 will be killed after 1s timeout + let result = resolve_exec("test", "exec:///bin/sleep 10", &cfg).await; + let err = result.unwrap_err().to_string(); + assert!(err.contains("timed out"), "expected timeout error, got: {err}"); + } + + #[test] + fn substitute_and_reparse_integration() { + let mut secrets = HashMap::new(); + secrets.insert("token".to_owned(), "xoxb-secret-value".to_owned()); + secrets.insert("key".to_owned(), "sk-with\"special\\chars".to_owned()); + + let raw = r#" +[discord] +bot_token = "${secrets.token}" + +[agent] +command = "echo" +args = ["--key", "${secrets.key}"] +"#; + let substituted = substitute(raw, &secrets); + // Verify the substituted text is valid TOML that parses correctly + let cfg: crate::config::Config = toml::from_str(&substituted) + .expect("substituted config should be valid TOML"); + assert_eq!(cfg.discord.unwrap().bot_token, "xoxb-secret-value"); + assert_eq!(cfg.agent.args[1], "sk-with\"special\\chars"); + } +} diff --git a/crates/openab-core/src/setup/config.rs b/crates/openab-core/src/setup/config.rs new file mode 100644 index 000000000..ee1483650 --- /dev/null +++ b/crates/openab-core/src/setup/config.rs @@ -0,0 +1,157 @@ +//! Config generation and TOML serialization for the setup wizard. 
+ +/// Mask bot token in config output for preview +pub fn mask_bot_token(config: &str) -> String { + config + .lines() + .map(|line| { + if line.trim_start().starts_with("bot_token") { + "bot_token = \"***\"".to_string() + } else { + line.to_string() + } + }) + .collect::>() + .join("\n") +} + +#[derive(serde::Serialize)] +pub(crate) struct ConfigToml { + discord: DiscordConfigToml, + agent: AgentConfigToml, + pool: PoolConfigToml, + reactions: ReactionsConfigToml, +} + +#[derive(serde::Serialize)] +struct DiscordConfigToml { + bot_token: String, + allowed_channels: Vec, +} + +#[derive(serde::Serialize)] +struct AgentConfigToml { + command: String, + args: Vec, + working_dir: String, +} + +#[derive(serde::Serialize)] +struct PoolConfigToml { + max_sessions: usize, + session_ttl_hours: u64, +} + +#[derive(serde::Serialize)] +struct ReactionsConfigToml { + enabled: bool, + remove_after_reply: bool, + emojis: EmojisToml, + timing: TimingToml, +} + +#[derive(serde::Serialize)] +struct EmojisToml { + queued: String, + thinking: String, + tool: String, + coding: String, + web: String, + done: String, + error: String, +} + +#[derive(serde::Serialize)] +struct TimingToml { + debounce_ms: u64, + stall_soft_ms: u64, + stall_hard_ms: u64, + done_hold_ms: u64, + error_hold_ms: u64, +} + +pub fn generate_config( + bot_token: &str, + agent_command: &str, + channel_ids: Vec, + working_dir: &str, + max_sessions: usize, + session_ttl_hours: u64, +) -> String { + let config = ConfigToml { + discord: DiscordConfigToml { + bot_token: bot_token.to_string(), + allowed_channels: channel_ids, + }, + agent: { + let (command, args): (&str, Vec) = match agent_command { + "kiro" => ("kiro-cli", vec!["acp".into(), "--trust-all-tools".into()]), + "claude" => ("claude-agent-acp", vec![]), + "codex" => ("codex-acp", vec![]), + "gemini" => ("gemini", vec!["--acp".into()]), + other => (other, vec![]), + }; + AgentConfigToml { + command: command.to_string(), + args, + working_dir: 
working_dir.to_string(), + } + }, + pool: PoolConfigToml { + max_sessions, + session_ttl_hours, + }, + reactions: ReactionsConfigToml { + enabled: true, + remove_after_reply: false, + emojis: EmojisToml { + queued: "👀".into(), + thinking: "🤔".into(), + tool: "🔥".into(), + coding: "👨💻".into(), + web: "⚡".into(), + done: "🆗".into(), + error: "😱".into(), + }, + timing: TimingToml { + debounce_ms: 700, + stall_soft_ms: 10_000, + stall_hard_ms: 30_000, + done_hold_ms: 1_500, + error_hold_ms: 2_500, + }, + }, + }; + toml::to_string_pretty(&config).expect("TOML serialization failed") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn generate_config_contains_sections() { + let config = generate_config( + "my_token", + "claude", + vec!["123".to_string()], + "/home/agent", + 10, + 24, + ); + assert!(config.contains("[discord]")); + assert!(config.contains("[agent]")); + assert!(config.contains("[pool]")); + assert!(config.contains("[reactions]")); + assert!(config.contains("[reactions.emojis]")); + assert!(config.contains("[reactions.timing]")); + } + + #[test] + fn generate_config_kiro_working_dir() { + let config = generate_config("tok", "kiro", vec!["ch".to_string()], "/home/agent", 10, 24); + assert!(config.contains(r#"working_dir = "/home/agent""#)); + assert!(config.contains("acp")); + assert!(config.contains("--trust-all-tools")); + } +} diff --git a/crates/openab-core/src/setup/mod.rs b/crates/openab-core/src/setup/mod.rs new file mode 100644 index 000000000..96034f0ab --- /dev/null +++ b/crates/openab-core/src/setup/mod.rs @@ -0,0 +1,12 @@ +//! OpenAB interactive setup wizard. +//! +//! Modules: +//! - `validate` — input validation (bot token, channel ID, agent command) +//! - `config` — TOML config generation and serialization +//! 
- `wizard` — interactive TUI, Discord API client, and wizard entry point + +mod config; +mod validate; +mod wizard; + +pub use wizard::run_setup; diff --git a/crates/openab-core/src/setup/validate.rs b/crates/openab-core/src/setup/validate.rs new file mode 100644 index 000000000..b09401559 --- /dev/null +++ b/crates/openab-core/src/setup/validate.rs @@ -0,0 +1,78 @@ +//! Input validation functions for the setup wizard. + +/// Validate bot token format using allowlist (a-zA-Z0-9-./_) +pub fn validate_bot_token(token: &str) -> anyhow::Result<()> { + if token.is_empty() { + anyhow::bail!("Token cannot be empty"); + } + if !token.chars().all(|c| { + c.is_ascii_alphanumeric() + || c == '-' + || c == '.' + || c == '_' + || c == '/' + || c == '*' + || c == '=' + }) { + anyhow::bail!( + "Token must only contain ASCII letters, numbers, dash, period, underscore, slash, or equals" + ); + } + Ok(()) +} + +/// Validate agent command +#[cfg(test)] +pub fn validate_agent_command(cmd: &str) -> anyhow::Result<()> { + let valid = ["kiro", "claude", "codex", "gemini"]; + if !valid.contains(&cmd) { + anyhow::bail!("Agent must be one of: {}", valid.join(", ")); + } + Ok(()) +} + +/// Validate channel ID is numeric +pub fn validate_channel_id(id: &str) -> anyhow::Result<()> { + if id.is_empty() { + anyhow::bail!("Channel ID cannot be empty"); + } + if !id.chars().all(|c| c.is_ascii_digit()) { + anyhow::bail!("Channel ID must be numeric only"); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_bot_token_ok() { + assert!(validate_bot_token("simple_token").is_ok()); + assert!(validate_bot_token("token.with-dashes_123").is_ok()); + assert!(validate_bot_token("***/efgh").is_ok()); + } + + #[test] + fn validate_bot_token_reject_invalid() { + assert!(validate_bot_token("").is_err()); + assert!(validate_bot_token("token\nnewline").is_err()); + assert!(validate_bot_token("token\ttab").is_err()); + assert!(validate_bot_token("token with space").is_err()); + } 
+ + #[test] + fn validate_agent_command_known_and_unknown() { + for agent in &["kiro", "claude", "codex", "gemini"] { + assert!(validate_agent_command(agent).is_ok()); + } + assert!(validate_agent_command("invalid").is_err()); + } + + #[test] + fn validate_channel_id_accepts_numeric_rejects_invalid() { + assert!(validate_channel_id("1492329565824094370").is_ok()); + assert!(validate_channel_id("").is_err()); + assert!(validate_channel_id("abc123").is_err()); + } +} diff --git a/crates/openab-core/src/setup/wizard.rs b/crates/openab-core/src/setup/wizard.rs new file mode 100644 index 000000000..f5a789609 --- /dev/null +++ b/crates/openab-core/src/setup/wizard.rs @@ -0,0 +1,667 @@ +//! Interactive setup wizard TUI and Discord API client. + +use std::io::{self, IsTerminal, Write}; +use std::path::{Path, PathBuf}; + +use crate::setup::config::{generate_config, mask_bot_token}; +use crate::setup::validate::{validate_bot_token, validate_channel_id}; + +// --------------------------------------------------------------------------- +// Color codes (ANSI) +// --------------------------------------------------------------------------- + +const C: Colors = Colors { + reset: "\x1b[0m", + bold: "\x1b[1m", + cyan: "\x1b[36m", + green: "\x1b[32m", + red: "\x1b[31m", + yellow: "\x1b[33m", + magenta: "\x1b[35m", +}; + +struct Colors { + reset: &'static str, + bold: &'static str, + cyan: &'static str, + green: &'static str, + red: &'static str, + yellow: &'static str, + magenta: &'static str, +} + +const BORDER: char = '═'; + +macro_rules! 
cprintln { + ($color:expr, $fmt:expr) => {{ + println!("{}{}{}", $color, $fmt, C.reset); + }}; + ($color:expr, $fmt:expr, $($arg:tt)*) => {{ + println!("{}{}{}", $color, format!($fmt, $($arg)*), C.reset); + }}; +} + +// --------------------------------------------------------------------------- +// Input helpers +// --------------------------------------------------------------------------- + +fn is_interactive() -> bool { + std::io::stdin().is_terminal() && std::io::stdout().is_terminal() +} + +fn prompt(prompt_text: &str) -> String { + print!("{}{}: {}", C.yellow, prompt_text, C.reset); + io::stdout().flush().ok(); + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + input.trim().to_string() +} + +fn prompt_default(prompt_text: &str, default: &str) -> String { + print!("{}{} [{}]: {}", C.yellow, prompt_text, default, C.reset); + io::stdout().flush().ok(); + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + let input = input.trim(); + if input.is_empty() { + default.to_string() + } else { + input.to_string() + } +} + +fn prompt_password(prompt_text: &str) -> String { + print!("{}{}: {}", C.yellow, prompt_text, C.reset); // reset so the colour does not leak into later output + io::stdout().flush().ok(); + rpassword::read_password().unwrap_or_default() +} + +fn prompt_yes_no(prompt_text: &str, default: bool) -> bool { + let default_str = if default { "Y/n" } else { "y/N" }; + loop { + print!("{}{} [{}]: {}", C.yellow, prompt_text, default_str, C.reset); // reset like prompt()/prompt_default() + io::stdout().flush().ok(); + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + let input = input.trim().to_lowercase(); + if input.is_empty() { + return default; + } + match input.as_str() { + "y" | "yes" => return true, + "n" | "no" => return false, + _ => cprintln!(C.red, "Please enter 'y' or 'n'"), + } + } +} + +fn prompt_choice(prompt_text: &str, choices: &[&str]) -> usize { + println!(); + cprintln!(C.cyan, "{}", prompt_text); + for (i, choice) in choices.iter().enumerate() { + println!(" {}.
{}", i + 1, choice); + } + print!("{}Select [1-{}]: {}", C.yellow, choices.len(), C.reset); + io::stdout().flush().ok(); + loop { + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + match input.trim().parse::() { + Ok(n) if n >= 1 && n <= choices.len() => return n - 1, + _ => { + print!("{}Select [1-{}]: {}", C.yellow, choices.len(), C.reset); + io::stdout().flush().ok(); + } + } + } +} + +fn prompt_checklist(prompt_text: &str, items: &[&str]) -> Vec { + println!(); + cprintln!(C.cyan, "{}", prompt_text); + for (i, item) in items.iter().enumerate() { + println!(" [{}] {}", i + 1, item); + } + println!(); + print!( + "{}Enter numbers separated by commas (e.g. 1,3,5) or press Enter for all: {}", + C.yellow, C.reset + ); + io::stdout().flush().ok(); + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + let input = input.trim(); + if input.is_empty() { + return (0..items.len()).collect(); + } + input + .split(',') + .filter_map(|s| s.trim().parse::().ok()) + .filter(|n| *n >= 1 && *n <= items.len()) + .map(|n| n - 1) + .collect() +} + +// --------------------------------------------------------------------------- +// Box drawing helpers +// --------------------------------------------------------------------------- + +fn print_box(lines: &[&str]) { + let width = lines + .iter() + .map(|l| unicode_width::UnicodeWidthStr::width(&**l)) + .max() + .unwrap_or(60); + let width = width.clamp(60, 76); + println!(); + cprintln!( + C.cyan, + "{}", + "╔".to_string() + &BORDER.to_string().repeat(width + 2) + "╗" + ); + for line in lines { + let padded = format!(" {: Self { + Self { + token: token.to_string(), + http: reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(10)) + .build() + .expect("static HTTP client must build"), + } + } + + /// Verify token by fetching bot info + fn verify_token(&self) -> anyhow::Result<(String, String)> { + let resp = self + .http + 
.get("https://discord.com/api/v10/users/@me") + .header("Authorization", format!("Bot {}", self.token)) + .header("User-Agent", "OpenAB setup wizard") + .send()?; + if !resp.status().is_success() { + anyhow::bail!("Token verification failed: HTTP {}", resp.status()); + } + #[derive(serde::Deserialize)] + struct MeResponse { + id: String, + username: String, + } + let me: MeResponse = resp.json()?; + Ok((me.id, me.username)) + } + + /// Fetch guilds the bot is in + fn fetch_guilds(&self) -> anyhow::Result> { + let resp = self + .http + .get("https://discord.com/api/v10/users/@me/guilds") + .header("Authorization", format!("Bot {}", self.token)) + .header("User-Agent", "OpenAB setup wizard") + .send()?; + if !resp.status().is_success() { + anyhow::bail!("Failed to fetch guilds: HTTP {}", resp.status()); + } + #[derive(serde::Deserialize)] + struct Guild { + id: String, + name: String, + } + let guilds: Vec = resp.json()?; + Ok(guilds.into_iter().map(|g| (g.id, g.name)).collect()) + } + + /// Fetch channels in a guild + fn fetch_channels(&self, guild_id: &str) -> anyhow::Result> { + let url = format!("https://discord.com/api/v10/guilds/{}/channels", guild_id); + let resp = self + .http + .get(&url) + .header("Authorization", format!("Bot {}", self.token)) + .header("User-Agent", "OpenAB setup wizard") + .send()?; + if !resp.status().is_success() { + anyhow::bail!("Failed to fetch channels: HTTP {}", resp.status()); + } + #[derive(serde::Deserialize)] + struct Channel { + id: String, + #[serde(rename = "type")] + kind: u8, + name: String, + } + let channels: Vec = resp.json()?; + // type 0 = text channel + Ok(channels + .into_iter() + .filter(|c| c.kind == 0) + .map(|c| (c.id, c.name, guild_id.to_string())) + .collect()) + } +} + +// --------------------------------------------------------------------------- +// Section 1: Discord Bot Setup Guide +// --------------------------------------------------------------------------- + +fn section_discord_guide() { + 
print_box(&[ + "Discord Bot Setup Guide", + "", + "1. Go to: https://discord.com/developers/applications", + "2. Click 'New Application' -> name it (e.g. OpenAB)", + "3. Bot -> Reset Token -> COPY the token", + "", + "4. Enable Privileged Gateway Intents:", + " - Message Content Intent", + " - Guild Members Intent", + "", + "5. OAuth2 -> URL Generator:", + " - SCOPES: bot", + " - BOT PERMISSIONS:", + " Send Messages | Embed Links | Attach Files", + " Read Message History | Add Reactions", + " Use Slash Commands", + "", + "6. Visit the generated URL -> add bot to your server", + ]); +} + +// --------------------------------------------------------------------------- +// Section 2: Channel Selection +// --------------------------------------------------------------------------- + +fn section_channels(client: &DiscordClient) -> anyhow::Result> { + println!(); + cprintln!(C.bold, "--- Step 2: Allowed Channels ---"); + println!(); + + print!(" Fetching servers... "); + io::stdout().flush().ok(); + let guilds = client.fetch_guilds()?; + cprintln!(C.green, "OK Found {} server(s)", guilds.len()); + println!(); + + if guilds.is_empty() { + cprintln!(C.yellow, " No servers found. Enter channel IDs manually."); + let input = prompt(" Channel ID(s), comma-separated"); + let ids: Vec = input + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + for id in &ids { + validate_channel_id(id)?; + } + return Ok(ids); + } + + let guild_names: Vec<&str> = guilds.iter().map(|(_, n)| n.as_str()).collect(); + let guild_idx = prompt_choice(" Select server:", &guild_names); + let (guild_id, guild_name) = &guilds[guild_idx]; + + print!(" Fetching channels in '{}'... ", guild_name); + io::stdout().flush().ok(); + let channels = client.fetch_channels(guild_id)?; + cprintln!(C.green, "OK Found {} channel(s)", channels.len()); + println!(); + + if channels.is_empty() { + cprintln!( + C.yellow, + " No text channels found. Enter channel IDs manually." 
+ ); + let input = prompt(" Channel ID(s), comma-separated"); + let ids: Vec = input + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + for id in &ids { + validate_channel_id(id)?; + } + return Ok(ids); + } + + let channel_names: Vec = channels.iter().map(|(_, n, _)| format!("#{}", n)).collect(); + let channel_names_refs: Vec<&str> = channel_names.iter().map(|s| s.as_str()).collect(); + + let selected = prompt_checklist(" Select channels (by number):", &channel_names_refs); + let selected_ids: Vec = selected.iter().map(|&i| channels[i].0.clone()).collect(); + + println!(); + cprintln!(C.green, " Selected {} channel(s)", selected_ids.len()); + for id in &selected_ids { + if let Some((_, name, _)) = channels.iter().find(|(cid, _, _)| cid == id) { + println!(" * #{}", name); + } else { + println!(" * {}", id); + } + } + println!(); + + Ok(selected_ids) +} + +// --------------------------------------------------------------------------- +// Section 3: Agent Configuration +// --------------------------------------------------------------------------- + +fn section_agent() -> (String, String, bool) { + println!(); + cprintln!(C.bold, "--- Step 3: Agent Configuration ---"); + println!(); + + print_box(&[ + "Agent Installation Guide", + "", + "claude: npm install -g @anthropic-ai/claude-code", + "kiro: npm install -g @koryhutchison/kiro-cli", + "codex: npm install -g openai-codex (requires OpenAI API key)", + "gemini: npm install -g @google/gemini-cli", + "", + "Make sure the agent is in your PATH before continuing.", + ]); + println!(); + + let choices = ["claude", "kiro", "codex", "gemini"]; + let idx = prompt_choice(" Select agent:", &choices); + let agent = choices[idx]; + + let deploy_choices = ["Local (current directory)", "Docker / k8s"]; + let deploy_idx = prompt_choice(" Deployment target:", &deploy_choices); + let is_local = deploy_idx == 0; + let default_dir = match (is_local, agent) { + (true, _) => ".", + (false, 
"kiro") => "/home/agent", + (false, _) => "/home/node", + }; + + let working_dir = prompt_default(" Working directory", default_dir); + + cprintln!(C.green, " Agent: {} | Working dir: {}", agent, working_dir); + println!(); + + (agent.to_string(), working_dir, is_local) +} + +// --------------------------------------------------------------------------- +// Section 4: Pool Settings +// --------------------------------------------------------------------------- + +fn section_pool() -> (usize, u64) { + println!(); + cprintln!(C.bold, "--- Step 4: Session Pool ---"); + println!(); + + let max_sessions: usize = prompt_default(" Max sessions", "10").parse().unwrap_or(10); + let ttl_hours: u64 = prompt_default(" Session TTL (hours)", "24") + .parse() + .unwrap_or(24); + + cprintln!( + C.green, + " Max sessions: {} | TTL: {}h", + max_sessions, + ttl_hours + ); + println!(); + + (max_sessions, ttl_hours) +} + +// --------------------------------------------------------------------------- +// Preview & Save +// --------------------------------------------------------------------------- + +fn section_preview_and_save(config_content: &str, output_path: &PathBuf) -> anyhow::Result<()> { + println!(); + cprintln!(C.bold, "--- Preview ---"); + println!(); + println!("{}", mask_bot_token(config_content)); + println!(); + + if output_path.exists() && !prompt_yes_no(" File exists. 
Overwrite?", false) { + println!(" Saving cancelled."); + return Ok(()); + } + + std::fs::write(output_path, config_content)?; + cprintln!(C.green, "OK config.toml saved to {}", output_path.display()); + println!(); + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Non-interactive guidance +// --------------------------------------------------------------------------- + +fn print_noninteractive_guide() { + print_box(&[ + "Non-Interactive Mode", + "", + "The interactive wizard requires a terminal.", + "Create config.toml manually, then run:", + "", + " openab run config.toml", + "", + "Config format reference:", + " [discord]", + " bot_token = \"YOUR_BOT_TOKEN\"", + " allowed_channels = [\"CHANNEL_ID\"]", + "", + " [agent]", + " command = \"kiro-cli\"", + " args = [\"acp\", \"--trust-all-tools\"]", + " working_dir = \"/home/agent\"", + "", + " [pool]", + " max_sessions = 10", + " session_ttl_hours = 24", + "", + " [reactions]", + " enabled = true", + " remove_after_reply = false", + " ...", + ]); +} + +// --------------------------------------------------------------------------- +// Next steps printer +// --------------------------------------------------------------------------- + +fn print_next_steps(agent: &str, output_path: &Path, is_local: bool) { + println!(); + cprintln!(C.bold, "--- Next Steps ---"); + println!(); + + if is_local { + match agent { + "kiro" => { + cprintln!( + C.cyan, + " 1. Install kiro-cli (see https://kiro.dev for installer)" + ); + cprintln!(C.cyan, " 2. Authenticate:"); + println!(" kiro-cli login --use-device-flow"); + } + "claude" => { + cprintln!(C.cyan, " 1. Install Claude Code + ACP adapter:"); + println!(" npm install -g @anthropic-ai/claude-code @agentclientprotocol/claude-agent-acp"); + cprintln!(C.cyan, " 2. Authenticate:"); + println!(" claude auth login"); + } + "codex" => { + cprintln!(C.cyan, " 1. 
Install Codex CLI + ACP adapter:"); + println!(" npm install -g @openai/codex @zed-industries/codex-acp"); + cprintln!(C.cyan, " 2. Authenticate:"); + println!(" codex login --device-auth"); + } + "gemini" => { + cprintln!(C.cyan, " 1. Install Gemini CLI:"); + println!(" npm install -g @google/gemini-cli"); + cprintln!( + C.cyan, + " 2. Authenticate via Google OAuth, or set GEMINI_API_KEY in config.toml" + ); + } + _ => {} + } + + println!(); + cprintln!(C.green, " 3. Run the bot:"); + println!(" cargo run -- run {}", output_path.display()); + } else { + cprintln!( + C.cyan, + " Docker image already bundles the agent CLI and ACP adapter." + ); + println!(); + cprintln!(C.cyan, " 1. Deploy with Helm (or your preferred method):"); + println!(" helm install openab openab/openab \\"); + println!( + " --set agents.{}.discord.botToken=\"$BOT_TOKEN\"", + agent + ); + println!(); + cprintln!( + C.cyan, + " 2. Authenticate inside the pod (first time only):" + ); + match agent { + "kiro" => println!( + " kubectl exec -it deployment/openab-kiro -- kiro-cli login --use-device-flow" + ), + "claude" => { + println!(" kubectl exec -it deployment/openab-claude -- claude auth login") + } + "codex" => println!( + " kubectl exec -it deployment/openab-codex -- codex login --device-auth" + ), + "gemini" => { + println!(" Set GEMINI_API_KEY via secret, or exec into the pod for OAuth") + } + _ => {} + } + println!(); + cprintln!(C.green, " See README for full Helm options."); + } + println!(); +} + +// --------------------------------------------------------------------------- +// Main wizard entry point +// --------------------------------------------------------------------------- + +pub fn run_setup(output_path: Option) -> anyhow::Result<()> { + if !is_interactive() { + print_noninteractive_guide(); + return Ok(()); + } + + println!(); + cprintln!( + C.magenta, + "============================================================" + ); + cprintln!( + C.magenta, + " OpenAB Interactive Setup 
Wizard " + ); + cprintln!( + C.magenta, + "============================================================" + ); + + // Step 1: Discord Guide + Token + section_discord_guide(); + println!(); + let bot_token = prompt_password(" Bot Token (or press Enter to skip)"); + if bot_token.is_empty() { + cprintln!(C.yellow, " Skipped. Set bot_token manually in config.toml"); + println!(); + cprintln!( + C.green, + " Setup complete! Edit config.toml to add your bot token." + ); + return Ok(()); + } + validate_bot_token(&bot_token)?; + + let client = DiscordClient::new(&bot_token); + print!(" Verifying token with Discord API... "); + io::stdout().flush().ok(); + let (_bot_id, bot_username) = client.verify_token()?; + cprintln!(C.green, "OK Logged in as {}", bot_username); + + // Step 2: Channels + let channel_ids = match section_channels(&client) { + Ok(ids) if !ids.is_empty() => ids, + Ok(_) => { + cprintln!(C.yellow, " No channels selected."); + vec![] + } + Err(e) => { + cprintln!(C.yellow, " Channel fetch failed: {}. 
Enter manually.", e); + let input = prompt(" Channel ID(s), comma-separated"); + let ids: Vec = input + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + for id in &ids { + validate_channel_id(id).map_err(|e| anyhow::anyhow!("{}", e))?; + } + ids + } + }; + + // Step 3: Agent + let (agent, working_dir, is_local) = section_agent(); + + // Step 4: Pool + let (max_sessions, ttl_hours) = section_pool(); + + // Generate + let config_content = generate_config( + &bot_token, + &agent, + channel_ids, + &working_dir, + max_sessions, + ttl_hours, + ); + + // Output + let output_path = output_path.unwrap_or_else(|| PathBuf::from("config.toml")); + section_preview_and_save(&config_content, &output_path)?; + + print_next_steps(&agent, &output_path, is_local); + + Ok(()) +} diff --git a/crates/openab-core/src/slack.rs b/crates/openab-core/src/slack.rs new file mode 100644 index 000000000..94fdcf0a1 --- /dev/null +++ b/crates/openab-core/src/slack.rs @@ -0,0 +1,2329 @@ +use crate::acp::ContentBlock; +use crate::adapter::{ChannelRef, ChatAdapter, MessageRef, SenderContext}; +use crate::bot_turns::{BotTurnTracker, TurnAction, TurnSeverity}; +use crate::config::{AllowBots, AllowUsers, SttConfig}; +use crate::media; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; +use std::collections::{HashMap, HashSet}; +use std::sync::{Arc, LazyLock}; +use tokio::sync::watch; +use tokio_tungstenite::tungstenite; +use tracing::{debug, error, info, warn}; + +const SLACK_API: &str = "https://slack.com/api"; + +/// Map Unicode emoji to Slack short names for reactions API. +/// Only covers the default `[reactions.emojis]` set. Custom emoji configured +/// outside this map will fall back to `grey_question`. 
+fn unicode_to_slack_emoji(unicode: &str) -> &str { + match unicode { + "👀" => "eyes", + "🤔" => "thinking_face", + "🔥" => "fire", + "👨\u{200d}💻" => "technologist", + "⚡" => "zap", + "🆗" => "ok", + "😱" => "scream", + "🚫" => "no_entry_sign", + "😊" => "blush", + "😎" => "sunglasses", + "🫡" => "saluting_face", + "🤓" => "nerd_face", + "😏" => "smirk", + "✌\u{fe0f}" => "v", + "💪" => "muscle", + "🦾" => "mechanical_arm", + "🥱" => "yawning_face", + "😨" => "fearful", + "✅" => "white_check_mark", + "❌" => "x", + "🔧" => "wrench", + "🎤" => "microphone", + _ => "grey_question", + } +} + +// --- SlackAdapter: implements ChatAdapter for Slack --- + +/// TTL for cached user display names (5 minutes). +const USER_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(300); + +/// Maximum entries in the participation cache before eviction. +const PARTICIPATION_CACHE_MAX: usize = 1000; + +/// Maximum entries in the streams map before eviction (safety net for +/// aborted turns that begin a stream but never reach stream_finish). +const STREAM_CACHE_MAX: usize = 1024; + +#[derive(Default)] +struct StreamEntry { + active: bool, + degraded_buf: String, +} + +pub struct SlackAdapter { + client: reqwest::Client, + bot_token: String, + bot_user_id: tokio::sync::OnceCell, + user_cache: tokio::sync::Mutex>, + /// Cache: Bot ID (B...) → Bot User ID (U...) for trusted_bot_ids matching. + bot_id_cache: tokio::sync::Mutex>, + /// Positive-only cache: thread_ts → cached_at for threads where bot has participated. + participated_threads: tokio::sync::Mutex>, + /// Positive-only cache: thread_ts → cached_at for threads where other bots have posted. + /// Like participation, a thread becoming multi-bot is irreversible (bot messages don't disappear). + multibot_threads: tokio::sync::Mutex>, + /// Persistent disk cache for multibot thread detection (survives restarts). 
+ multibot_cache: crate::multibot_cache::MultibotCache, + /// TTL for participation cache entries (matches session_ttl_hours from config). + session_ttl: std::time::Duration, + /// Assistant mode: stream via chat.startStream + assistant.threads.setStatus. + assistant_mode: bool, + /// streaming message ts → state. active=false = degraded (post+edit fallback). + /// Lifecycle: stream_begin inserts, stream_finish removes; insert_stream + /// bounds the map (STREAM_CACHE_MAX) as a safety net against aborted turns. + streams: tokio::sync::Mutex>, +} + +impl SlackAdapter { + pub fn new( + bot_token: String, + session_ttl: std::time::Duration, + _allow_bot_messages: AllowBots, + assistant_mode: bool, + multibot_cache: crate::multibot_cache::MultibotCache, + ) -> Self { + Self { + // Bound every Slack Web API call; an unbounded inline gating call in the + // read loop could otherwise stall the Socket Mode idle-timeout watchdog. + client: reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()), + bot_token, + bot_user_id: tokio::sync::OnceCell::new(), + user_cache: tokio::sync::Mutex::new(HashMap::new()), + bot_id_cache: tokio::sync::Mutex::new(HashMap::new()), + participated_threads: tokio::sync::Mutex::new(HashMap::new()), + multibot_threads: tokio::sync::Mutex::new(HashMap::new()), + multibot_cache, + session_ttl, + assistant_mode, + streams: tokio::sync::Mutex::new(HashMap::new()), + } + } + + /// Returns the bot token for use in API calls outside the adapter. + pub fn bot_token(&self) -> &str { + &self.bot_token + } + + /// Eagerly record that another bot has posted in a thread. Called from the + /// event loop when a bot message arrives, so multibot detection doesn't + /// depend on fetching thread history. Idempotent. 
+ async fn note_other_bot_in_thread(&self, thread_ts: &str) { + { + let mut cache = self.multibot_threads.lock().await; + cache + .entry(thread_ts.to_string()) + .or_insert_with(tokio::time::Instant::now); + enforce_cache_bounds(&mut cache, self.session_ttl); + } + // Persist to disk — multibot is irreversible + self.multibot_cache.mark_multibot(thread_ts).await; + } + + + /// Insert a stream entry, bounding the map so aborted turns (begin without a + /// matching finish) can't leak unboundedly. Normal lifecycle: stream_begin + /// inserts, stream_finish removes. + async fn insert_stream(&self, ts: String, entry: StreamEntry) { + let mut map = self.streams.lock().await; + if map.len() >= STREAM_CACHE_MAX { + // Only evict inactive (degraded/stale) streams to avoid cutting off + // active streams mid-turn. If no inactive entries exist, fall through + // and allow the map to grow slightly beyond the soft cap. + let evict: Vec = map + .iter() + .filter(|(_, e)| !e.active) + .map(|(k, _)| k.clone()) + .collect(); + for k in evict { + map.remove(&k); + } + } + map.insert(ts, entry); + } + + /// Accumulate a delta into a degraded stream's buffer and return the new + /// cumulative text. Returns None if no (degraded) stream entry exists for + /// `ts` — never resurrects a removed/absent stream. No network I/O. + async fn accumulate_degraded(&self, ts: &str, delta: &str) -> Option { + let mut map = self.streams.lock().await; + let entry = map.get_mut(ts)?; + entry.degraded_buf.push_str(delta); + Some(entry.degraded_buf.clone()) + } + + /// Get the bot's own Slack user ID (cached after first call). 
+ async fn get_bot_user_id(&self) -> Option<&str> { + self.bot_user_id + .get_or_try_init(|| async { + let resp = self + .api_post("auth.test", serde_json::json!({})) + .await + .map_err(|e| anyhow!("auth.test failed: {e}"))?; + resp["user_id"] + .as_str() + .map(|s| s.to_string()) + .ok_or_else(|| anyhow!("no user_id in auth.test response")) + }) + .await + .inspect_err(|e| warn!(error = %e, "bot user ID unavailable; mention detection may suppress bot messages under Mentions mode")) + .ok() + .map(|s| s.as_str()) + } + + async fn api_post(&self, method: &str, body: serde_json::Value) -> Result { + let resp = self + .client + .post(format!("{SLACK_API}/{method}")) + .header("Authorization", format!("Bearer {}", self.bot_token)) + .header("Content-Type", "application/json; charset=utf-8") + .json(&body) + .send() + .await?; + + let json: serde_json::Value = resp.json().await?; + if json["ok"].as_bool() != Some(true) { + let err = json["error"].as_str().unwrap_or("unknown error"); + return Err(anyhow!("Slack API {method}: {err}")); + } + Ok(json) + } + + /// Call a Slack API method using GET with query parameters. + /// Required for read methods like conversations.replies that don't accept JSON body. + async fn api_get(&self, method: &str, params: &[(&str, &str)]) -> Result { + let resp = self + .client + .get(format!("{SLACK_API}/{method}")) + .header("Authorization", format!("Bearer {}", self.bot_token)) + .query(params) + .send() + .await?; + + let json: serde_json::Value = resp.json().await?; + if json["ok"].as_bool() != Some(true) { + let err = json["error"].as_str().unwrap_or("unknown error"); + return Err(anyhow!("Slack API {method}: {err}")); + } + Ok(json) + } + + /// Resolve a Slack user ID to display name via users.info API. + /// Results are cached for 5 minutes to avoid hitting Slack rate limits. 
+ async fn resolve_user_name(&self, user_id: &str) -> Option { + // Check cache first + { + let cache = self.user_cache.lock().await; + if let Some((name, ts)) = cache.get(user_id) { + if ts.elapsed() < USER_CACHE_TTL { + return Some(name.clone()); + } + } + } + + let resp = self + .api_post("users.info", serde_json::json!({ "user": user_id })) + .await + .ok()?; + let user = resp.get("user")?; + let profile = user.get("profile")?; + let display = profile + .get("display_name") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()); + let real = profile + .get("real_name") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()); + let name = user.get("name").and_then(|v| v.as_str()); + let resolved = display.or(real).or(name)?.to_string(); + + // Cache the result + self.user_cache.lock().await.insert( + user_id.to_string(), + (resolved.clone(), tokio::time::Instant::now()), + ); + + Some(resolved) + } + + /// Resolve a Bot ID (B...) to Bot User ID (U...) via bots.info API. + /// Cached permanently (bot IDs don't change). 
+ async fn resolve_bot_user_id(&self, bot_id: &str) -> Option { + if bot_id.is_empty() { + return None; + } + + { + let cache = self.bot_id_cache.lock().await; + if let Some(user_id) = cache.get(bot_id) { + return Some(user_id.clone()); + } + } + + let resp = self + .api_post("bots.info", serde_json::json!({ "bot": bot_id })) + .await + .inspect_err(|e| { + warn!( + bot_id, + error = %e, + "failed to resolve Slack bot ID via bots.info" + ) + }) + .ok()?; + let user_id = resp.get("bot")?.get("user_id")?.as_str()?.to_string(); + + self.bot_id_cache + .lock() + .await + .insert(bot_id.to_string(), user_id.clone()); + + Some(user_id) + } + + async fn trusted_bot_ids_contains( + &self, + trusted_bot_ids: &HashSet, + event_bot_id: &str, + ) -> bool { + if trusted_bot_ids.is_empty() { + return true; + } + if bot_id_matches_trusted(trusted_bot_ids, event_bot_id, None) { + return true; + } + let resolved = self.resolve_bot_user_id(event_bot_id).await; + bot_id_matches_trusted(trusted_bot_ids, event_bot_id, resolved.as_deref()) + } + + /// Check whether the bot has participated in a Slack thread and whether + /// other bots have also posted in it. + /// Returns `(involved, other_bot_present)`. + /// Involved = parent message @mentions the bot OR any message in thread is from the bot. + /// Fail-closed: returns `(false, false)` on API error (consistent with Discord's approach). + /// Caches positive results only — both states are irreversible. 
+ async fn bot_participated_in_thread(&self, channel: &str, thread_ts: &str) -> (bool, bool) { + let cached_involved = { + let cache = self.participated_threads.lock().await; + cache + .get(thread_ts) + .is_some_and(|ts| ts.elapsed() < self.session_ttl) + }; + let cached_multibot = { + let cache = self.multibot_threads.lock().await; + cache + .get(thread_ts) + .is_some_and(|ts| ts.elapsed() < self.session_ttl) + } || self.multibot_cache.is_multibot(thread_ts); + + // Eager multibot detection from message events populates the cache + // before this runs. When already involved and cached, skip the fetch. + if cached_involved { + return (true, cached_multibot); + } + + let bot_id = match self.get_bot_user_id().await { + Some(id) => id, + None => { + warn!("cannot resolve bot user ID, rejecting (fail-closed)"); + return (false, false); + } + }; + + let resp = self + .api_get( + "conversations.replies", + &[ + ("channel", channel), + ("ts", thread_ts), + ("limit", "200"), + ("inclusive", "true"), + ], + ) + .await; + + let json = match resp { + Ok(json) => json, + Err(e) => { + warn!(channel, thread_ts, error = %e, "failed to fetch thread replies, rejecting (fail-closed)"); + return (false, false); + } + }; + let Some(messages) = json["messages"].as_array() else { + return (false, false); + }; + + let parent_mentions_bot = messages + .first() + .and_then(|m| m["text"].as_str()) + .is_some_and(|text| text_mentions_uid(text, bot_id)); + + let bot_posted = messages.iter().any(|m| m["user"].as_str() == Some(bot_id)); + + let involved = parent_mentions_bot || bot_posted; + // other_bot_present relies solely on early detection + disk cache; + // no longer scanned from fetched messages (200-msg window was unreliable). + let other_bot_present = cached_multibot; + + if involved { + self.cache_participation(thread_ts).await; + } + + (involved, other_bot_present) + } + + /// Insert a positive participation entry, enforcing cache bounds. 
async fn cache_participation(&self, thread_ts: &str) {
    let mut cache = self.participated_threads.lock().await;
    // `insert` (not `entry().or_insert`) refreshes the timestamp on repeat hits.
    cache.insert(thread_ts.to_string(), tokio::time::Instant::now());
    enforce_cache_bounds(&mut cache, self.session_ttl);
}

/// Whether the stream registered under `ts` is a native (active) stream.
/// Returns `false` when the entry is degraded or no entry exists.
async fn stream_is_active(&self, ts: &str) -> bool {
    let map = self.streams.lock().await;
    map.get(ts).map(|e| e.active).unwrap_or(false)
}
}

/// Shared eviction policy for positive-only caches.
/// First drops expired entries; if still over, drops the oldest half.
///
/// No-op while the cache holds at most `PARTICIPATION_CACHE_MAX` entries.
fn enforce_cache_bounds(
    cache: &mut HashMap<String, tokio::time::Instant>,
    ttl: std::time::Duration,
) {
    if cache.len() <= PARTICIPATION_CACHE_MAX {
        return;
    }
    cache.retain(|_, ts| ts.elapsed() < ttl);
    if cache.len() > PARTICIPATION_CACHE_MAX {
        // Sort by insertion time ascending so the front of the list is oldest.
        let mut entries: Vec<_> = cache.iter().map(|(k, v)| (k.clone(), *v)).collect();
        entries.sort_by_key(|(_, ts)| *ts);
        let evict_count = entries.len() / 2;
        for (key, _) in entries.into_iter().take(evict_count) {
            cache.remove(&key);
        }
    }
}

#[async_trait]
impl ChatAdapter for SlackAdapter {
    fn platform(&self) -> &'static str {
        "slack"
    }

    fn message_limit(&self) -> usize {
        // Match the Block Kit `markdown` block cap (12k) minus headroom. Messages
        // are sent as markdown blocks, so the old 4000 mrkdwn-era limit would
        // split long replies (and Markdown tables) across messages needlessly —
        // a mid-table split renders as raw pipes. 11_900 keeps typical tables in
        // one block and cuts message-spam on long replies.
        MARKDOWN_BLOCK_LIMIT
    }

    /// Post `content` to `channel` (into its thread when `thread_id` is set)
    /// and return a reference to the new message, keyed by Slack's `ts`.
    async fn send_message(&self, channel: &ChannelRef, content: &str) -> Result<MessageRef> {
        let thread_ts = channel.thread_id.as_deref();
        let body = build_post_message_body(&channel.channel_id, thread_ts, content);
        let resp = match self.api_post("chat.postMessage", body).await {
            Ok(r) => r,
            // Graceful degradation: if the `blocks` payload is rejected (workspace
            // lacks the markdown block, or content exceeds the cumulative block
            // cap), retry text-only so the message still lands (mrkdwn fallback)
            // instead of failing outright.
            Err(e) if is_block_payload_rejected(&e) => {
                warn!(error = %e, "markdown block rejected; retrying chat.postMessage text-only");
                let fallback =
                    build_post_message_text_only(&channel.channel_id, thread_ts, content);
                self.api_post("chat.postMessage", fallback).await?
            }
            Err(e) => return Err(e),
        };
        let ts = resp["ts"]
            .as_str()
            .ok_or_else(|| anyhow!("no ts in chat.postMessage response"))?;
        Ok(MessageRef {
            channel: ChannelRef {
                platform: "slack".into(),
                channel_id: channel.channel_id.clone(),
                thread_id: channel.thread_id.clone(),
                parent_id: None,
                origin_event_id: None,
            },
            message_id: ts.to_string(),
        })
    }

    /// Build the thread reference rooted at `trigger_msg`. Makes no API call.
    async fn create_thread(
        &self,
        channel: &ChannelRef,
        trigger_msg: &MessageRef,
        _title: &str,
    ) -> Result<ChannelRef> {
        // Slack threads are implicit — posting with thread_ts creates/continues a thread.
        Ok(ChannelRef {
            platform: "slack".into(),
            channel_id: channel.channel_id.clone(),
            thread_id: Some(trigger_msg.message_id.clone()),
            parent_id: None,
            origin_event_id: None,
        })
    }

    /// Add a reaction; an `already_reacted` error is treated as success.
    async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> {
        let name = unicode_to_slack_emoji(emoji);
        match self
            .api_post(
                "reactions.add",
                serde_json::json!({
                    "channel": msg.channel.channel_id,
                    "timestamp": msg.message_id,
                    "name": name,
                }),
            )
            .await
        {
            Ok(_) => Ok(()),
            Err(e) if e.to_string().contains("already_reacted") => Ok(()),
            Err(e) => Err(e),
        }
    }

    /// Remove a reaction; a `no_reaction` error is treated as success.
    async fn remove_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> {
        let name = unicode_to_slack_emoji(emoji);
        match self
            .api_post(
                "reactions.remove",
                serde_json::json!({
                    "channel": msg.channel.channel_id,
                    "timestamp": msg.message_id,
                    "name": name,
                }),
            )
            .await
        {
            Ok(_) => Ok(()),
            Err(e) if e.to_string().contains("no_reaction") => Ok(()),
            Err(e) => Err(e),
        }
    }

    /// Replace the content of `msg` via chat.update.
    async fn edit_message(&self, msg: &MessageRef, content: &str) -> Result<()> {
        let body = build_update_body(&msg.channel.channel_id, &msg.message_id, content);
        match self.api_post("chat.update", body).await {
            Ok(_) => Ok(()),
            // See send_message: degrade to text-only if the blocks payload is rejected.
            Err(e) if is_block_payload_rejected(&e) => {
                warn!(error = %e, "markdown block rejected; retrying chat.update text-only");
                let fallback =
                    build_update_text_only(&msg.channel.channel_id, &msg.message_id, content);
                self.api_post("chat.update", fallback).await?;
                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    fn use_streaming(&self, other_bot_present: bool) -> bool {
        !other_bot_present
    }

    fn renders_native_tables(&self) -> bool {
        true
    }

    fn uses_assistant_status(&self) -> bool {
        self.assistant_mode
    }

    /// Native streaming requires assistant mode and no other bot in the thread.
    fn uses_native_streaming(&self, other_bot_present: bool) -> bool {
        let native = self.assistant_mode && !other_bot_present;
        debug!(
            assistant_mode = self.assistant_mode,
            other_bot_present,
            native,
            "slack assistant_mode decision (per turn)"
        );
        native
    }

    /// Open a stream for this turn and return the message it will write into.
    ///
    /// With a `recipient` (user ID, team ID) it tries chat.startStream and
    /// registers an active entry. Without one, or if that call fails or
    /// returns no `ts`, it posts a placeholder message and registers a
    /// degraded (inactive) entry instead.
    async fn stream_begin(
        &self,
        channel: &ChannelRef,
        recipient: Option<(String, String)>,
    ) -> Result<MessageRef> {
        let thread_ts = channel.thread_id.clone().unwrap_or_default();
        // recipient is bound to this turn (captured at message arrival, carried on
        // BufferedMessage) — no shared thread cache, so no cross-turn race.
        let make_ref = |ts: String| MessageRef {
            channel: ChannelRef {
                platform: "slack".into(),
                channel_id: channel.channel_id.clone(),
                thread_id: channel.thread_id.clone(),
                parent_id: None,
                origin_event_id: None,
            },
            message_id: ts,
        };

        if let Some((user_id, team_id)) = recipient {
            let body =
                build_start_stream_body(&channel.channel_id, &thread_ts, &user_id, &team_id);
            match self.api_post("chat.startStream", body).await {
                Ok(resp) => {
                    if let Some(ts) = resp["ts"].as_str() {
                        self.insert_stream(
                            ts.to_string(),
                            StreamEntry { active: true, degraded_buf: String::new() },
                        )
                        .await;
                        return Ok(make_ref(ts.to_string()));
                    }
                    error!("chat.startStream ok but no ts; falling back to post+edit");
                }
                Err(e) => {
                    error!(error = %e, "chat.startStream failed; falling back to post+edit for this turn");
                }
            }
        } else {
            // Expected for bot-authored turns (no recipient bound) and non-user
            // triggers, so warn! rather than error! to avoid on-call noise.
            warn!(thread_ts, "no recipient for turn; falling back to post+edit");
        }

        // Degraded fallback: plain placeholder via send_message; mark inactive.
        let msg = self.send_message(channel, "…").await?;
        self.insert_stream(
            msg.message_id.clone(),
            StreamEntry { active: false, degraded_buf: String::new() },
        )
        .await;
        Ok(msg)
    }

    /// Push `delta` to the stream behind `msg`. Always returns `Ok`: failures
    /// are logged or ignored because the final replace corrects the message.
    async fn stream_append(&self, msg: &MessageRef, delta: &str) -> Result<()> {
        let ts = &msg.message_id;
        if self.stream_is_active(ts).await {
            let body = build_append_stream_body(&msg.channel.channel_id, ts, delta);
            if let Err(e) = self.api_post("chat.appendStream", body).await {
                warn!(error = %e, "chat.appendStream failed (cosmetic; final replace will correct)");
            }
        } else if let Some(cumulative) = self.accumulate_degraded(ts, delta).await {
            let _ = self.edit_message(msg, &cumulative).await; // cosmetic mid-stream
        }
        Ok(())
    }

    /// Close the stream behind `msg`, replace it with `final_content`, and
    /// drop its entry from the stream map. Always returns `Ok`.
    async fn stream_finish(&self, msg: &MessageRef, final_content: &str) -> Result<()> {
        let ts = &msg.message_id;
        let active = self.stream_is_active(ts).await;
        if active {
            // Close the native stream WITHOUT re-sending content. The reply was
            // already streamed live via chat.appendStream; stopStream's
            // `markdown_text` *appends* (it does not replace), so passing the full
            // content here duplicates the whole reply (#1055). Close only, then
            // replace with the finalized content via chat.update below.
            let close = serde_json::json!({ "channel": msg.channel.channel_id, "ts": ts });
            if let Err(e) = self.api_post("chat.stopStream", close).await {
                warn!(error = %e, "chat.stopStream(close) failed; continuing to final replace");
            }
        }
        // Replace with the finalized content (Block Kit markdown). For the active
        // path this overwrites the streamed preview with a single clean copy
        // (rich rendering + native tables); for the degraded path it is the final
        // post+edit update. chat.update replaces, so there is no duplication.
        if let Err(e) = self.edit_message(msg, final_content).await {
            if active {
                // The native stream already delivered the reply (chat.appendStream),
                // and stopStream left it in place. Do NOT postMessage a fallback
                // here — that would post a duplicate copy. Keep the streamed
                // content as the final message.
                warn!(error = %e, "final chat.update failed; keeping streamed content (no duplicate post)");
            } else {
                // Degraded path: no streamed content exists (post+edit placeholder),
                // so post the final as a new message to avoid losing the reply.
                warn!(error = %e, "final chat.update failed; trying postMessage");
                if let Err(e2) = self.send_message(&msg.channel, final_content).await {
                    error!(error = %e2, "final postMessage also failed; reply may be incomplete");
                }
            }
        }
        self.streams.lock().await.remove(ts);
        Ok(())
    }

    /// Set the assistant thread status. Always returns `Ok`; a failed call is
    /// only logged.
    async fn set_status(&self, channel: &ChannelRef, status: &str) -> Result<()> {
        let thread_ts = channel.thread_id.clone().unwrap_or_default();
        let body = build_set_status_body(&channel.channel_id, &thread_ts, status);
        if let Err(e) = self.api_post("assistant.threads.setStatus", body).await {
            warn!(error = %e, status, "assistant.threads.setStatus failed (cosmetic)");
        }
        Ok(())
    }
}

// --- Socket Mode event loop ---

/// Hard cap on consecutive bot messages in a thread. Prevents runaway loops.
const MAX_CONSECUTIVE_BOT_TURNS: usize = 1000;

/// Socket Mode keepalive. Slack's inbound WebSocket can go half-open (e.g. a NAT
/// idle-timeout silently drops inbound frames with no Close/FIN), which leaves
/// `read.next()` blocked forever, so the reconnect loop never fires and the bot
/// goes deaf while still showing as connected. We proactively ping and force a
/// reconnect when no inbound frame (including Slack's own pings) has arrived
/// within the idle window. Reconnect backoff mirrors the gateway adapter.
+const PING_INTERVAL_SECS: u64 = 30; +const IDLE_TIMEOUT_SECS: u64 = 75; +const MAX_BACKOFF_SECS: u64 = 30; + +/// Next reconnect delay: double, capped. Reset to 1 on a successful connect. +fn next_backoff(cur: u64) -> u64 { + (cur * 2).min(MAX_BACKOFF_SECS) +} + +/// The socket is considered dead (half-open) when no inbound frame has arrived +/// within `timeout`; Slack sends periodic pings, so silence past the window +/// means the inbound path is gone. +fn socket_idle(since_last_inbound: std::time::Duration, timeout: std::time::Duration) -> bool { + since_last_inbound >= timeout +} + +/// Run the Slack adapter using Socket Mode (persistent WebSocket, no public URL needed). +/// Reconnects automatically on disconnect. +#[allow(clippy::too_many_arguments)] +pub async fn run_slack_adapter( + adapter: Arc, + app_token: String, + allow_all_channels: bool, + allow_all_users: bool, + allowed_channels: HashSet, + allowed_users: HashSet, + allow_bot_messages: AllowBots, + trusted_bot_ids: HashSet, + allow_user_messages: AllowUsers, + max_bot_turns: u32, + stt_config: SttConfig, + mut shutdown_rx: watch::Receiver, + dispatcher: Arc, +) -> Result<()> { + let bot_token = adapter.bot_token().to_string(); + let bot_turns = Arc::new(tokio::sync::Mutex::new(BotTurnTracker::new(max_bot_turns))); + // Warm the bot-user-id cache once so the per-message path never does the + // cold-cache `auth.test` inline in the read loop. + let _ = adapter.get_bot_user_id().await; + let mut backoff_secs = 1u64; + + loop { + // Check for shutdown before (re)connecting + if *shutdown_rx.borrow() { + info!("Slack adapter shutting down"); + return Ok(()); + } + + let ws_url = match get_socket_mode_url(&app_token).await { + Ok(url) => url, + Err(e) => { + error!(err = %e, backoff = backoff_secs, "failed to get Socket Mode URL, retrying"); + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)) => {} + _ = shutdown_rx.changed() => { return Ok(()); } + } + backoff_secs = next_backoff(backoff_secs); + continue; + } + }; + info!(url = %ws_url, "connecting to Slack Socket Mode"); + + match tokio_tungstenite::connect_async(&ws_url).await { + Ok((ws_stream, _)) => { + info!("Slack Socket Mode connected"); + backoff_secs = 1; // reset on success + let (mut write, mut read) = ws_stream.split(); + let mut ping_interval = + tokio::time::interval(std::time::Duration::from_secs(PING_INTERVAL_SECS)); + ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + let mut last_inbound = std::time::Instant::now(); + + loop { + tokio::select! { + msg_result = read.next() => { + last_inbound = std::time::Instant::now(); + let Some(msg_result) = msg_result else { break }; + match msg_result { + Ok(tungstenite::Message::Text(text)) => { + let envelope: serde_json::Value = + match serde_json::from_str(&text) { + Ok(v) => v, + Err(_) => continue, + }; + + // Acknowledge the envelope immediately + if let Some(envelope_id) = envelope["envelope_id"].as_str() { + let ack = serde_json::json!({"envelope_id": envelope_id}); + let _ = write + .send(tungstenite::Message::Text(ack.to_string())) + .await; + } + + // Slash commands and interactive block_actions aren't + // handled on Slack: slash commands are blocked by Slack + // in thread composers, and the channel-level delivery + // lacks the thread_ts needed to route to a session. + // Ack only; ignore payload. 
+ match envelope["type"].as_str() { + Some("slash_commands") | Some("interactive") => { + debug!( + envelope_type = envelope["type"].as_str().unwrap_or(""), + "ignoring Slack envelope type (not supported on this adapter)" + ); + continue; + } + _ => {} + } + + // Route events + if envelope["type"].as_str() == Some("events_api") { + let event = &envelope["payload"]["event"]; + let event_type = event["type"].as_str().unwrap_or(""); + match event_type { + "app_mention" => { + // Apply bot gating for app_mention events (same rules as message events) + let is_bot = event["bot_id"].is_string() + || event["subtype"].as_str() == Some("bot_message"); + if is_bot { + match allow_bot_messages { + AllowBots::Off => { continue; } + AllowBots::Mentions | AllowBots::All => { + if !trusted_bot_ids.is_empty() { + let event_bot_id = event["bot_id"].as_str().unwrap_or(""); + let is_trusted = adapter + .trusted_bot_ids_contains(&trusted_bot_ids, event_bot_id) + .await; + if !is_trusted { + debug!(event_bot_id, "bot not in trusted_bot_ids, ignoring app_mention"); + continue; + } + } + } + } + } + let event = event.clone(); + let adapter = adapter.clone(); + let bot_token = bot_token.clone(); + let allowed_channels = allowed_channels.clone(); + let allowed_users = allowed_users.clone(); + let stt_config = stt_config.clone(); + let dispatcher = dispatcher.clone(); + let team_id = envelope["payload"]["team_id"] + .as_str() + .unwrap_or("") + .to_string(); + tokio::spawn(async move { + handle_message( + &event, + &team_id, + &adapter, + &bot_token, + allow_all_channels, + allow_all_users, + &allowed_channels, + &allowed_users, + &stt_config, + &dispatcher, + ) + .await; + }); + } + "message" => { + let channel_id = event["channel"].as_str().unwrap_or(""); + let has_thread = event["thread_ts"].is_string(); + let is_bot = event["bot_id"].is_string() + || event["subtype"].as_str() == Some("bot_message"); + let subtype = event["subtype"].as_str().unwrap_or(""); + let msg_text = 
event["text"].as_str().unwrap_or(""); + let bot_uid_opt = adapter.get_bot_user_id().await.map(|s| s.to_string()); + let mentions_bot = bot_uid_opt + .as_ref() + .is_some_and(|bot_uid| text_mentions_uid(msg_text, bot_uid)); + let is_dm = channel_id.starts_with('D'); + let event_user_id = event["user"].as_str(); + let is_own_bot_msg = is_bot + && bot_uid_opt.as_deref().is_some() + && event_user_id == bot_uid_opt.as_deref(); + + debug!( + channel_id, + has_thread, + is_bot, + is_dm, + subtype, + mentions_bot, + text = msg_text, + "message event received" + ); + + // Skip non-message subtypes + let skip_subtype = matches!(subtype, + "message_changed" | "message_deleted" | + "channel_join" | "channel_leave" | + "channel_topic" | "channel_purpose" + ); + if skip_subtype { continue; } + + // --- Eager multibot detection --- + // Runs before self-check and bot gating so we always detect + // other bots even when allow_bot_messages=Off filters them out. + // Matches Discord #481 ordering. + if is_bot && !is_own_bot_msg { + if let Some(thread_ts) = event["thread_ts"].as_str() { + adapter.note_other_bot_in_thread(thread_ts).await; + } + } + + // --- Bot turn tracking --- + // Runs before self-check so ALL bot messages (including own) + // count toward the per-thread limit. Matches Discord #483. + // Keyed on thread_ts when in a thread, else channel:ts. + // Non-thread messages get a unique key per message, so the + // counter never accumulates — intentional, because bot-to-bot + // loops only happen inside threads. + let turn_key = if let Some(thread_ts) = event["thread_ts"].as_str() { + thread_ts.to_string() + } else { + format!("{}:{}", channel_id, event["ts"].as_str().unwrap_or("")) + }; + // Classify under the lock (order-sensitive, kept in the read + // loop), but run any warning send AFTER releasing it; holding + // the tracker mutex across `chat.postMessage` would stall turn + // tracking for every thread, not just this one. 
+ let turn_action = { + let mut tracker = bot_turns.lock().await; + if is_bot { + tracker.classify_bot_message(&turn_key) + } else { + if is_plain_user_message(subtype, msg_text) { + tracker.on_human_message(&turn_key); + } + TurnAction::Continue + } + }; + match turn_action { + TurnAction::Continue => {} + TurnAction::SilentStop => continue, + TurnAction::WarnAndStop { severity, turns, user_message } => { + match severity { + TurnSeverity::Hard => warn!(channel_id, turns, "hard bot turn limit reached"), + TurnSeverity::Soft => info!(channel_id, turns, max = max_bot_turns, "soft bot turn limit reached"), + } + let channel_allowed = allow_all_channels + || allowed_channels.contains(channel_id); + if !is_own_bot_msg && channel_allowed { + let warn_channel = ChannelRef { + platform: "slack".into(), + channel_id: channel_id.to_string(), + thread_id: event["thread_ts"].as_str().map(|s| s.to_string()), + parent_id: None, + origin_event_id: None, + }; + let adapter = adapter.clone(); + tokio::spawn(async move { + if let Err(e) = adapter.send_message(&warn_channel, &user_message).await { + warn!(error = %e, "failed to send bot turn limit warning"); + } + }); + } + continue; + } + } + + // Ignore own bot messages (after counting toward turns) + if is_own_bot_msg { continue; } + + // Skip messages that @mention the bot — app_mention handles those + // (except in DMs where app_mention doesn't fire) + if mentions_bot && !is_dm { continue; } + + // --- Bot message gating --- + if is_bot { + let event_bot_id = event["bot_id"].as_str().unwrap_or(""); + match allow_bot_messages { + AllowBots::Off => { continue; } + AllowBots::Mentions => { + if !mentions_bot { continue; } + } + AllowBots::All => { + // Loop protection: count consecutive bot msgs (fail-closed) + if let Some(thread_ts) = event["thread_ts"].as_str() { + let cap = MAX_CONSECUTIVE_BOT_TURNS; + let limit_str = std::cmp::min(cap + 1, 1000).to_string(); + match adapter.api_get( + "conversations.replies", + &[ + 
("channel", channel_id), + ("ts", thread_ts), + ("limit", &limit_str), + ("inclusive", "true"), + ], + ).await { + Ok(resp) => { + if let Some(msgs) = resp["messages"].as_array() { + let consecutive = msgs.iter().rev() + .take_while(|m| { + m["bot_id"].is_string() + || m["subtype"].as_str() == Some("bot_message") + }) + .count(); + if consecutive >= cap { + warn!(channel_id, cap, "bot turn cap reached, ignoring"); + continue; + } + } + } + Err(e) => { + warn!(channel_id, thread_ts, error = %e, "failed to fetch thread for bot loop check, rejecting (fail-closed)"); + continue; + } + } + } + } + } + // Check trusted_bot_ids + if !trusted_bot_ids.is_empty() { + let is_trusted = adapter + .trusted_bot_ids_contains(&trusted_bot_ids, event_bot_id) + .await; + if !is_trusted { + debug!(event_bot_id, "bot not in trusted_bot_ids, ignoring"); + continue; + } + } + // Bot messages must be in a thread (no top-level bot processing) + if !has_thread { continue; } + } + + // --- User message gating --- + if !is_bot { + if is_dm { + // DM: implicit mention — always process + } else { + match allow_user_messages { + AllowUsers::Mentions => { + if !mentions_bot { continue; } + } + AllowUsers::Involved => { + if !has_thread { + continue; + } + let thread_ts = event["thread_ts"].as_str().unwrap_or(""); + let (involved, _) = adapter + .bot_participated_in_thread(channel_id, thread_ts) + .await; + if !involved { + debug!(channel_id, thread_ts, "bot not involved in thread, ignoring"); + continue; + } + } + AllowUsers::MultibotMentions => { + if !has_thread { + continue; + } + let thread_ts = event["thread_ts"].as_str().unwrap_or(""); + let (involved, other_bot) = adapter + .bot_participated_in_thread(channel_id, thread_ts) + .await; + if !involved { + debug!(channel_id, thread_ts, "bot not involved in thread, ignoring"); + continue; + } + // In multi-bot threads, require @mention — mirrors + // Discord's `should_process_user_message`. 
In practice + // mention-bearing message events are already deduped + // earlier (app_mention handles the @-path), so this + // branch rarely sees `mentions_bot == true`, but keep + // the explicit check so the logic is self-consistent + // and survives changes to the earlier dedup. + if other_bot && !mentions_bot { + debug!(channel_id, thread_ts, "multi-bot thread without @mention, ignoring"); + continue; + } + } + } + } + } + + // Dispatch to handle_message (per-thread serialization comes + // from Dispatcher consumer task in batched mode and from + // pool.with_connection in per-message mode). + let team_id = envelope["payload"]["team_id"] + .as_str() + .unwrap_or("") + .to_string(); + let event = event.clone(); + let adapter = adapter.clone(); + let bot_token = bot_token.clone(); + let allowed_channels = allowed_channels.clone(); + let allowed_users = allowed_users.clone(); + let stt_config = stt_config.clone(); + let dispatcher = dispatcher.clone(); + tokio::spawn(async move { + handle_message( + &event, + &team_id, + &adapter, + &bot_token, + allow_all_channels, + allow_all_users, + &allowed_channels, + &allowed_users, + &stt_config, + &dispatcher, + ) + .await; + }); + } + _ => {} + } + } + } + Ok(tungstenite::Message::Ping(data)) => { + let _ = write.send(tungstenite::Message::Pong(data)).await; + } + Ok(tungstenite::Message::Close(_)) => { + warn!("Slack Socket Mode connection closed by server"); + break; + } + Err(e) => { + error!("Socket Mode read error: {e}"); + break; + } + _ => {} + } + } + _ = ping_interval.tick() => { + if socket_idle( + last_inbound.elapsed(), + std::time::Duration::from_secs(IDLE_TIMEOUT_SECS), + ) { + warn!( + idle_secs = last_inbound.elapsed().as_secs(), + "Slack Socket Mode idle past timeout (likely half-open), forcing reconnect" + ); + break; + } + if let Err(e) = write.send(tungstenite::Message::Ping(Vec::new())).await { + warn!(error = %e, "Slack Socket Mode ping failed, reconnecting"); + break; + } + } + _ = 
shutdown_rx.changed() => { + info!("Slack adapter received shutdown signal"); + let _ = write.send(tungstenite::Message::Close(None)).await; + return Ok(()); + } + } + } + } + Err(e) => { + error!(err = %e, backoff = backoff_secs, "failed to connect to Slack Socket Mode, retrying"); + } + } + + warn!(backoff = backoff_secs, "reconnecting to Slack Socket Mode"); + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)) => {} + _ = shutdown_rx.changed() => { return Ok(()); } + } + backoff_secs = next_backoff(backoff_secs); + } +} + +/// Call apps.connections.open to get a WebSocket URL for Socket Mode. +async fn get_socket_mode_url(app_token: &str) -> Result { + let client = reqwest::Client::new(); + let resp = client + .post(format!("{SLACK_API}/apps.connections.open")) + .header("Authorization", format!("Bearer {app_token}")) + .header("Content-Type", "application/x-www-form-urlencoded") + .send() + .await?; + let json: serde_json::Value = resp.json().await?; + if json["ok"].as_bool() != Some(true) { + let err = json["error"].as_str().unwrap_or("unknown"); + return Err(anyhow!("apps.connections.open: {err}")); + } + json["url"] + .as_str() + .map(|s| s.to_string()) + .ok_or_else(|| anyhow!("no url in apps.connections.open response")) +} + +#[allow(clippy::too_many_arguments)] +async fn handle_message( + event: &serde_json::Value, + team_id: &str, + adapter: &Arc, + bot_token: &str, + allow_all_channels: bool, + allow_all_users: bool, + allowed_channels: &HashSet, + allowed_users: &HashSet, + stt_config: &SttConfig, + dispatcher: &Arc, +) { + let channel_id = match event["channel"].as_str() { + Some(ch) => ch.to_string(), + None => return, + }; + // Bot messages may lack "user" field — fall back to "bot_id" as sender identifier + let user_id = match event["user"].as_str().or_else(|| event["bot_id"].as_str()) { + Some(u) => u.to_string(), + None => return, + }; + let is_bot_msg = + event["bot_id"].is_string() || 
event["subtype"].as_str() == Some("bot_message"); + let text = match event["text"].as_str() { + Some(t) => t.to_string(), + None => return, + }; + let ts = match event["ts"].as_str() { + Some(ts) => ts.to_string(), + None => return, + }; + let thread_ts = event["thread_ts"].as_str().map(|s| s.to_string()); + + // Check allowed channels + if !allow_all_channels && !allowed_channels.contains(&channel_id) { + return; + } + + // Check allowed users — skip for bot messages (they go through trusted_bot_ids instead) + if !is_bot_msg && !allow_all_users && !allowed_users.contains(&user_id) { + tracing::info!(user_id, "denied Slack user, ignoring"); + let msg_ref = MessageRef { + channel: ChannelRef { + platform: "slack".into(), + channel_id: channel_id.clone(), + thread_id: thread_ts.clone(), + parent_id: None, + origin_event_id: None, + }, + message_id: ts.clone(), + }; + let _ = adapter.add_reaction(&msg_ref, "🚫").await; + return; + } + + // Capture the native-streaming recipient for THIS turn, now that the sender has + // passed the channel + user allow-list checks above (so denied/unauthorized + // senders are never recorded). It rides on the per-turn BufferedMessage to + // stream_begin — no shared thread cache, no cross-turn race. Real users only: + // bot IDs (B...) are rejected by chat.startStream's recipient_user_id, and an + // empty team_id would silently degrade, so we surface that. + let stream_recipient = if is_bot_msg { + None + } else { + if team_id.is_empty() { + warn!("empty team_id; chat.startStream will degrade to post+edit"); + } + Some((user_id.clone(), team_id.to_string())) + }; + + // Resolve mentions: strip only this bot's own trigger mention so the LLM + // can still @-mention other users in its reply. 
+ let bot_id = adapter.get_bot_user_id().await; + let prompt = resolve_slack_mentions(&text, bot_id); + + // Process file attachments (images, audio) + let files = event["files"].as_array(); + let has_files = files.is_some_and(|f| !f.is_empty()); + + if prompt.is_empty() && !has_files { + return; + } + + // Caps mirror Discord's text-file attachment flow (PR #291) so both + // adapters apply the same limits: 5 files or 1 MB of text per message. + const TEXT_TOTAL_CAP: u64 = 1024 * 1024; + const TEXT_FILE_COUNT_CAP: u32 = 5; + + let mut extra_blocks = Vec::new(); + let mut echo_entries: Vec = Vec::new(); + let mut text_file_bytes: u64 = 0; + let mut text_file_count: u32 = 0; + let mut failed_image_files: Vec = Vec::new(); + + if let Some(files) = files { + for file in files { + let mimetype_raw = file["mimetype"].as_str().unwrap_or(""); + let mimetype = strip_mime_params(mimetype_raw); + let filename = file["name"].as_str().unwrap_or("file"); + let size = file["size"].as_u64().unwrap_or(0); + // Slack private files require Bearer token to download + let url = slack_file_download_url(file); + + if url.is_empty() { + continue; + } + + if media::is_audio_mime(mimetype) { + if stt_config.enabled { + match media::download_and_transcribe( + url, + filename, + mimetype, + size, + stt_config, + Some(bot_token), + ) + .await + { + Some(transcript) => { + debug!( + filename, + chars = transcript.len(), + "voice transcript injected" + ); + extra_blocks.insert( + 0, + ContentBlock::Text { + text: format!("[Voice message transcript]: {transcript}"), + }, + ); + echo_entries.push(crate::stt::EchoEntry::Success(transcript)); + } + None => { + warn!(filename, "STT failed for voice attachment"); + echo_entries.push(crate::stt::EchoEntry::Failed); + } + } + } else { + debug!(filename, "skipping audio attachment (STT disabled)"); + let msg_ref = MessageRef { + channel: ChannelRef { + platform: "slack".into(), + channel_id: channel_id.clone(), + thread_id: thread_ts.clone(), + 
parent_id: None, + origin_event_id: None, + }, + message_id: ts.clone(), + }; + let _ = adapter.add_reaction(&msg_ref, "🎤").await; + } + } else if media::is_text_file(filename, Some(mimetype)) { + if text_file_count >= TEXT_FILE_COUNT_CAP { + debug!( + filename, + count = text_file_count, + "text file count cap reached, skipping" + ); + continue; + } + // Pre-check with Slack-reported size as a fast path when the + // field is populated. Slack can report `size == 0` for + // externally-backed files, so this is advisory only — the + // authoritative cap check happens after download using + // `actual_bytes`. + if size > 0 && text_file_bytes + size > TEXT_TOTAL_CAP { + debug!( + filename, + total = text_file_bytes, + "text attachments total exceeds 1MB cap, skipping remaining" + ); + continue; + } + if let Some((block, actual_bytes)) = + media::download_and_read_text_file(url, filename, size, Some(bot_token)).await + { + if text_file_bytes + actual_bytes > TEXT_TOTAL_CAP { + debug!( + filename, + running = text_file_bytes, + actual = actual_bytes, + "text attachments total exceeds 1MB cap after download, dropping file", + ); + continue; + } + text_file_bytes += actual_bytes; + text_file_count += 1; + debug!(filename, "adding text file attachment"); + extra_blocks.push(block); + } + } else { + match media::download_and_encode_image( + url, + Some(mimetype), + filename, + size, + Some(bot_token), + ) + .await + { + Ok(block) => { + debug!(filename, "adding image attachment"); + extra_blocks.push(block); + } + Err(media::MediaFetchError::NotAnImage) => {} + Err(media::MediaFetchError::SizeExceeded { actual, limit }) => { + warn!(filename, actual, limit, "image exceeds size limit"); + failed_image_files.push(filename.to_string()); + } + Err( + media::MediaFetchError::UnsupportedResponseType { .. } + | media::MediaFetchError::InvalidImageBody { .. 
}, + ) => { + warn!( + filename, + "image validation failed; server may have returned non-image content" + ); + failed_image_files.push(filename.to_string()); + } + Err(media::MediaFetchError::ProcessingFailed(ref e)) => { + warn!(filename, error = %e, "image post-processing failed"); + failed_image_files.push(filename.to_string()); + } + Err(media::MediaFetchError::HttpStatus(status)) + if status.is_client_error() => + { + warn!(filename, %status, "image download denied"); + failed_image_files.push(filename.to_string()); + } + Err(e) => { + warn!(filename, error = %e, "image download failed"); + failed_image_files.push(filename.to_string()); + } + } + } + } + } + + // Notify user if any images couldn't be processed. + if !failed_image_files.is_empty() { + let warn_channel = ChannelRef { + platform: "slack".into(), + channel_id: channel_id.clone(), + thread_id: thread_ts.clone().or_else(|| Some(ts.clone())), + parent_id: None, + origin_event_id: None, + }; + let file_list = failed_image_files + .iter() + .map(|n| sanitize_slack_filename(n)) + .collect::>() + .join("`, `"); + let msg = format!( + ":warning: I couldn't process the file(s) you shared (`{file_list}`). \ + This can happen when the bot lacks the `files:read` OAuth scope, \ + the file format isn't supported (PNG/JPEG/GIF/WebP only), \ + or the file is too large." 
+ ); + if let Err(e) = adapter.send_message(&warn_channel, &msg).await { + warn!(error = %e, "failed to send image validation warning to user"); + } + } + + // Resolve Slack display name (best-effort, fallback to user_id) + let display_name = adapter + .resolve_user_name(&user_id) + .await + .unwrap_or_else(|| user_id.clone()); + + let sender = SenderContext { + schema: "openab.sender.v1".into(), + sender_id: user_id.clone(), + sender_name: display_name.clone(), + display_name, + channel: "slack".into(), + channel_id: channel_id.clone(), + thread_id: thread_ts.clone(), + is_bot: is_bot_msg, + timestamp: Some(crate::timestamp::slack_ts_to_iso8601(&ts)), + message_id: Some(ts.clone()), + receiver_id: bot_id.map(|id| id.to_string()), + }; + + let trigger_msg = MessageRef { + channel: ChannelRef { + platform: "slack".into(), + channel_id: channel_id.clone(), + thread_id: thread_ts.clone(), + parent_id: None, + origin_event_id: None, + }, + message_id: ts.clone(), + }; + + // Determine thread: if already in a thread, continue it; otherwise start a new thread + let thread_channel = ChannelRef { + platform: "slack".into(), + channel_id: channel_id.clone(), + thread_id: Some(thread_ts.unwrap_or(ts)), + parent_id: None, + origin_event_id: None, + }; + + // Serialize sender context with Slack-native key names so agents calling + // the Slack API directly see "thread_ts" rather than the generic "thread_id". 
+ let sender_json = { + let mut v = serde_json::to_value(&sender).unwrap(); + if let Some(obj) = v.as_object_mut() { + if let Some(tid) = obj.remove("thread_id") { + obj.insert("thread_ts".to_string(), tid); + } + } + v.to_string() + }; + + let adapter_dyn: Arc = adapter.clone(); + let other_bot_present = { + let cache = adapter.multibot_threads.lock().await; + thread_channel.thread_id.as_deref().is_some_and(|ts| { + cache + .get(ts) + .is_some_and(|inst| inst.elapsed() < adapter.session_ttl) + }) + } || thread_channel + .thread_id + .as_deref() + .is_some_and(|ts| adapter.multibot_cache.is_multibot(ts)); + + // Best-effort echo before the agent reply so the user can verify STT. + crate::stt::post_echo( + &adapter_dyn, + &thread_channel, + &trigger_msg, + &echo_entries, + stt_config, + ) + .await; + + let thread_id = thread_channel + .thread_id + .as_deref() + .unwrap_or(&thread_channel.channel_id); + let thread_key = dispatcher.key("slack", thread_id, &sender.sender_id); + let estimated_tokens = crate::dispatch::estimate_tokens(&prompt, &extra_blocks); + let buf_msg = crate::dispatch::BufferedMessage { + sender_json, + sender_name: sender.sender_name.clone(), + prompt, + extra_blocks, + trigger_msg, + arrived_at: std::time::Instant::now(), + estimated_tokens, + other_bot_present, + recipient: stream_recipient, + }; + if let Err(e) = dispatcher + .submit(thread_key, thread_channel, adapter_dyn, buf_msg) + .await + { + error!("Slack dispatcher submit error: {e}"); + } +} + +/// Strip all occurrences of the bot's own `<@BOT_UID>` or `<@BOT_UID|handle>` mention. +/// Other users' mentions stay intact so the LLM can @-mention them back. +/// If the bot UID isn't known, fall back to returning the text trimmed — +/// safer than stripping all mentions and losing user addressability. 
+fn resolve_slack_mentions(text: &str, bot_id: Option<&str>) -> String {
+    // Without a known bot UID there is no way to tell which mention is ours,
+    // so keep every mention and only trim the surrounding whitespace.
+    let Some(id) = bot_id else {
+        return text.trim().to_string();
+    };
+    // The prefix deliberately omits the closing delimiter: the byte that
+    // follows it decides between `<@UID>`, `<@UID|handle>`, and a longer UID
+    // that merely starts with ours (e.g. `<@UIDX>`).
+    let prefix = format!("<@{id}");
+    let mut out = String::with_capacity(text.len());
+    let mut s = text;
+    while let Some(pos) = s.find(&prefix) {
+        let after = &s[pos + prefix.len()..];
+        match after.as_bytes().first() {
+            // Bare form: drop the mention and resume right after `>`.
+            Some(b'>') => {
+                out.push_str(&s[..pos]);
+                s = &after[1..];
+            }
+            // Labelled form: drop everything up to and including the closing `>`.
+            Some(b'|') => {
+                if let Some(close) = after.find('>') {
+                    out.push_str(&s[..pos]);
+                    s = &after[close + 1..];
+                } else {
+                    // Unclosed label: not a well-formed mention, keep it as typed.
+                    out.push_str(&s[..pos + prefix.len()]);
+                    s = after;
+                }
+            }
+            // A different (longer) UID or truncated text: keep it and move on.
+            _ => {
+                out.push_str(&s[..pos + prefix.len()]);
+                s = after;
+            }
+        }
+    }
+    out.push_str(s);
+    out.trim().to_string()
+}
+
+/// Pick the best download URL for a Slack file object. `url_private_download`
+/// streams the raw bytes; `url_private` is the fallback for older file shapes.
+/// Returns `""` when neither is present (caller should skip the file).
+fn slack_file_download_url(file: &serde_json::Value) -> &str {
+    file["url_private_download"]
+        .as_str()
+        .or_else(|| file["url_private"].as_str())
+        .unwrap_or("")
+}
+
+/// Strip MIME parameters so type-detection helpers see the bare media type.
+/// Delegates to media::strip_mime_params (single source of truth).
+/// Needed because Slack occasionally sends `text/plain; charset=utf-8` and
+/// `media::is_text_file` expects the bare form.
+fn strip_mime_params(mimetype: &str) -> &str {
+    media::strip_mime_params(mimetype)
+}
+
+/// Sanitize a filename for safe embedding in a Slack mrkdwn message.
+///
+/// Ampersands (`&`), backticks (`` ` ``), and angle brackets (`<`, `>`) are escaped.
+/// `&` is encoded as `&amp;` first because Slack decodes HTML entities before parsing
+/// mrkdwn — a filename like `&lt;@here&gt;` would otherwise round-trip back to
+/// `<@here>` and trigger a mention ping. Backticks and angle brackets are Slack
+/// mrkdwn delimiters; without escaping, `<!here>` or `` `<@U123>` `` would render
+/// as mentions or @-here pings.
+pub(crate) fn sanitize_slack_filename(s: &str) -> String {
+    // Order matters: `&` must be escaped before anything else so that an
+    // entity already present in the filename cannot be decoded back into a
+    // delimiter downstream.
+    s.replace('&', "&amp;").replace('`', "'").replace('<', "(").replace('>', ")")
+}
+
+/// Returns `true` if `text` contains a Slack user mention for `uid`.
+///
+/// Accepts both `<@U...>` (bare) and `<@U...|handle>` (labelled) wire forms.
+/// Slack (and bots addressing peers) can emit the labelled form; `<@UID>` is
+/// not a substring of `<@UID|handle>`, so a bare `contains("<@UID>")` silently
+/// misses it.
+fn text_mentions_uid(text: &str, uid: &str) -> bool {
+    let prefix = format!("<@{uid}");
+    // Require `>` or `|` right after the UID so a shorter UID never matches a
+    // longer one that shares its prefix.
+    text.match_indices(&prefix)
+        .any(|(i, _)| matches!(text.as_bytes().get(i + prefix.len()), Some(b'>') | Some(b'|')))
+}
+
+/// Returns `true` when the sender of a bot event is in `trusted_bot_ids`,
+/// matching either the event's own bot ID or, when one is supplied, the user
+/// ID it resolved to. An empty `event_bot_id` is never trusted, even if the
+/// set happens to contain an empty string.
+fn bot_id_matches_trusted(
+    trusted_bot_ids: &HashSet<String>,
+    event_bot_id: &str,
+    resolved_user_id: Option<&str>,
+) -> bool {
+    if event_bot_id.is_empty() {
+        return false;
+    }
+
+    trusted_bot_ids.contains(event_bot_id)
+        || resolved_user_id.is_some_and(|uid| trusted_bot_ids.contains(uid))
+}
+
+/// True only when a Slack non-bot event represents a real user message
+/// that should reset the bot-turn counter.
+///
+/// Many Slack subtypes (pinned_item, channel_name, channel_archive,
+/// group_join / group_leave / group_topic / group_purpose, reminder_add,
+/// tombstone, …) carry a `user` field so the event loop sees
+/// `is_bot == false`, but they represent administrative/system actions,
+/// not conversation. Resetting the counter on them would let runaway
+/// bot-to-bot loops re-arm whenever any pin / rename / archive happens.
+///
+/// Mirrors Discord's `MessageType::Regular | InlineReply` + non-empty
+/// content gate in `src/discord.rs`. Regression parity for
+/// openabdev/openab#497.
+fn is_plain_user_message(subtype: &str, text: &str) -> bool {
+    // Empty text never counts as conversation, whatever the subtype.
+    if text.is_empty() {
+        return false;
+    }
+    // Allow-list: no subtype at all, plus the subtypes listed here.
+    matches!(
+        subtype,
+        "" | "me_message" | "thread_broadcast" | "file_share",
+    )
+}
+
+/// Slack caps a single Block Kit `markdown` block at 12,000 characters; we use
+/// 11,900 to keep ~100 chars of headroom. Doubles as the Slack `message_limit`
+/// so the router splits long replies into separate messages at the same bound
+/// (one markdown block per message stays under the API cap).
+const MARKDOWN_BLOCK_LIMIT: usize = 11_900;
+
+/// True if a Slack API error indicates the `blocks` payload was rejected, so the
+/// caller should retry text-only:
+/// - `invalid_blocks` — workspace can't render the Block Kit `markdown` block
+///   (malformed/unsupported payload).
+/// - `msg_blocks_too_long` — content exceeds Slack's cumulative ~12k cap across
+///   all `markdown` blocks in one message. Reachable by direct `send_message`
+///   callers that bypass the router's `message_limit` pre-split (e.g. STT echo).
+///
+/// `invalid_arguments` is deliberately excluded — it's a Slack catch-all (bad
+/// channel, missing/invalid `ts`, malformed `thread_ts`, …) and would trigger a
+/// pointless text-only retry that fails identically.
+///
+/// Matches the Slack error *code* exactly (the trailing token of `api_post`'s
+/// `"Slack API <method>: <code>"` message), not a substring of the message —
+/// so a future code like `invalid_blocks_field` does not falsely match.
+fn is_block_payload_rejected(e: &anyhow::Error) -> bool {
+    let s = e.to_string();
+    // `rsplit` always yields at least one item, so the fallback only guards
+    // the type; a message without ": " is compared as a whole.
+    let code = s.rsplit(": ").next().unwrap_or(s.as_str()).trim();
+    code == "invalid_blocks" || code == "msg_blocks_too_long"
+}
+
+/// Build Block Kit `markdown` blocks from raw Markdown. Slack renders these
+/// natively — real headings, lists, tables, blockquotes, and language-tagged
+/// code fences — unlike the legacy `text` mrkdwn field, which flattens headings
+/// to bold and cannot render tables. Long content is split at the block limit,
+/// reusing `format::split_message` so code-fence balance is preserved.
+///
+/// Follow-up (non-blocking): `split_message` is not table-aware — a single
+/// Markdown table exceeding `MARKDOWN_BLOCK_LIMIT` (11,900 chars) splits at line
+/// boundaries, so continuation blocks lack the header/separator rows and render
+/// as raw pipes. The 4000→11,900 bump makes this rare; a future improvement is
+/// to re-emit the table header at the top of each continuation chunk.
+fn build_markdown_blocks(content: &str) -> Vec<serde_json::Value> {
+    // Byte length is an upper bound on the character count, so content that
+    // fits in bytes is safe to emit as a single block without splitting.
+    let chunks = if content.len() <= MARKDOWN_BLOCK_LIMIT {
+        vec![content.to_string()]
+    } else {
+        crate::format::split_message(content, MARKDOWN_BLOCK_LIMIT)
+    };
+    chunks
+        .into_iter()
+        .map(|chunk| serde_json::json!({ "type": "markdown", "text": chunk }))
+        .collect()
+}
+
+/// Body for `chat.postMessage`: Block Kit `markdown` blocks (rich rendering)
+/// plus a `text` fallback used for notifications and accessibility.
+fn build_post_message_body(
+    channel_id: &str,
+    thread_ts: Option<&str>,
+    content: &str,
+) -> serde_json::Value {
+    let mut body = serde_json::json!({
+        "channel": channel_id,
+        "blocks": build_markdown_blocks(content),
+        "text": markdown_to_mrkdwn(content),
+    });
+    // `thread_ts` is only set for threaded replies; omitting the key posts at
+    // the top level of the channel.
+    if let Some(ts) = thread_ts {
+        body["thread_ts"] = serde_json::Value::String(ts.to_string());
+    }
+    body
+}
+
+/// Body for `chat.update`: same Block Kit `markdown` blocks + `text` fallback.
+fn build_update_body(channel_id: &str, ts: &str, content: &str) -> serde_json::Value {
+    serde_json::json!({
+        "channel": channel_id,
+        "ts": ts,
+        "blocks": build_markdown_blocks(content),
+        "text": markdown_to_mrkdwn(content),
+    })
+}
+
+/// Text-only `chat.postMessage` body (no `blocks`) — degradation path when a
+/// workspace rejects the Block Kit `markdown` block.
+fn build_post_message_text_only( + channel_id: &str, + thread_ts: Option<&str>, + content: &str, +) -> serde_json::Value { + let mut body = serde_json::json!({ + "channel": channel_id, + "text": markdown_to_mrkdwn(content), + }); + if let Some(ts) = thread_ts { + body["thread_ts"] = serde_json::Value::String(ts.to_string()); + } + body +} + +/// Text-only `chat.update` body (no `blocks`) — see `build_post_message_text_only`. +fn build_update_text_only(channel_id: &str, ts: &str, content: &str) -> serde_json::Value { + serde_json::json!({ + "channel": channel_id, + "ts": ts, + "text": markdown_to_mrkdwn(content), + }) +} + +/// Convert Markdown (as output by Claude Code) to Slack mrkdwn format. +/// Used for the `text` fallback field that accompanies Block Kit blocks +/// (shown in notification previews and to assistive tech). +fn markdown_to_mrkdwn(text: &str) -> String { + static BOLD_RE: LazyLock = + LazyLock::new(|| regex::Regex::new(r"\*\*(.+?)\*\*").unwrap()); + static ITALIC_RE: LazyLock = + LazyLock::new(|| regex::Regex::new(r"\*([^*]+?)\*").unwrap()); + static LINK_RE: LazyLock = + LazyLock::new(|| regex::Regex::new(r"\[([^\]]+)\]\(([^)]+)\)").unwrap()); + static HEADING_RE: LazyLock = + LazyLock::new(|| regex::Regex::new(r"(?m)^#{1,6}\s+(.+)$").unwrap()); + static CODE_BLOCK_LANG_RE: LazyLock = + LazyLock::new(|| regex::Regex::new(r"```\w+\n").unwrap()); + + // Order: bold first (** → placeholder), then italic (* → _), then restore bold + let text = BOLD_RE.replace_all(text, "\x01$1\x02"); // **bold** → \x01bold\x02 + let text = ITALIC_RE.replace_all(&text, "_${1}_"); // *italic* → _italic_ + // Restore bold: \x01bold\x02 → *bold* + let text = text.replace(['\x01', '\x02'], "*"); + let text = LINK_RE.replace_all(&text, "<$2|$1>"); // [text](url) → + let text = HEADING_RE.replace_all(&text, "*$1*"); // # heading → *heading* + let text = CODE_BLOCK_LANG_RE.replace_all(&text, "```\n"); // ```rust → ``` + text.into_owned() +} + +fn 
build_start_stream_body(channel: &str, thread_ts: &str, user_id: &str, team_id: &str) -> serde_json::Value { + serde_json::json!({ + "channel": channel, + "thread_ts": thread_ts, + "recipient_user_id": user_id, + "recipient_team_id": team_id, + }) +} + +fn build_append_stream_body(channel: &str, ts: &str, delta: &str) -> serde_json::Value { + serde_json::json!({ + "channel": channel, + "ts": ts, + "markdown_text": delta, + }) +} + +fn build_set_status_body(channel_id: &str, thread_ts: &str, status: &str) -> serde_json::Value { + serde_json::json!({ + "channel_id": channel_id, + "thread_ts": thread_ts, + "status": status, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + // --- builder tests --- + + #[test] + fn build_start_stream_body_has_recipient() { + let b = build_start_stream_body("C1", "1700.1", "U2", "T3"); + assert_eq!(b["channel"], "C1"); + assert_eq!(b["thread_ts"], "1700.1"); + assert_eq!(b["recipient_user_id"], "U2"); + assert_eq!(b["recipient_team_id"], "T3"); + } + + #[test] + fn build_append_stream_body_is_markdown_text_chunk() { + let b = build_append_stream_body("C1", "1700.9", "hello"); + assert_eq!(b["channel"], "C1"); + assert_eq!(b["ts"], "1700.9"); + assert_eq!(b["markdown_text"], "hello"); + } + + #[test] + fn build_set_status_body_shape() { + let b = build_set_status_body("C1", "1700.1", "Thinking\u{2026}"); + assert_eq!(b["channel_id"], "C1"); + assert_eq!(b["thread_ts"], "1700.1"); + assert_eq!(b["status"], "Thinking\u{2026}"); + } + + #[tokio::test] + async fn degraded_stream_append_accumulates() { + let adapter = SlackAdapter::new("xoxb-test".into(), std::time::Duration::from_secs(60), AllowBots::Off, true, crate::multibot_cache::MultibotCache::load("/dev/null".into())); + adapter.streams.lock().await.insert( + "TS".into(), + StreamEntry { active: false, degraded_buf: String::new() }, + ); + assert_eq!(adapter.accumulate_degraded("TS", "a").await.as_deref(), Some("a")); + assert_eq!(adapter.accumulate_degraded("TS", 
"b").await.as_deref(), Some("ab")); + // missing stream is not resurrected: + assert_eq!(adapter.accumulate_degraded("MISSING", "x").await, None); + } + use crate::adapter::ChatAdapter; + + /// Bot's own `<@UID>` trigger mention is stripped. + #[test] + fn resolve_mentions_strips_bot_mention() { + let out = resolve_slack_mentions("<@U1BOT> hello", Some("U1BOT")); + assert_eq!(out, "hello"); + } + + /// Other users' mentions are preserved so the LLM can address them back — + /// this is the core fix: the old `strip_slack_mention` wiped all `<@...>`. + #[test] + fn resolve_mentions_preserves_other_user_mentions() { + let out = resolve_slack_mentions("<@U1BOT> say hi to <@U2ALICE>", Some("U1BOT")); + assert_eq!(out, "say hi to <@U2ALICE>"); + } + + /// Multiple occurrences of the bot mention all get stripped. + #[test] + fn resolve_mentions_strips_repeated_bot_mentions() { + let out = resolve_slack_mentions("<@U1BOT> ping <@U1BOT>", Some("U1BOT")); + assert_eq!(out, "ping"); + } + + /// When the bot UID is unknown, fall back to preserving the text + /// (safer than stripping all user mentions). + #[test] + fn resolve_mentions_unknown_bot_preserves_all() { + let out = resolve_slack_mentions("<@U1BOT> hi <@U2ALICE>", None); + assert_eq!(out, "<@U1BOT> hi <@U2ALICE>"); + } + + /// Labelled form of another user's mention (`<@UID|handle>`) is preserved. + #[test] + fn resolve_mentions_preserves_labelled_other_user_mention() { + let out = resolve_slack_mentions("<@U1BOT> say hi to <@U2ALICE|alice>", Some("U1BOT")); + assert_eq!(out, "say hi to <@U2ALICE|alice>"); + } + + /// Labelled form `<@UID|handle>` is stripped the same as bare form. + #[test] + fn resolve_mentions_strips_labelled_bot_mention() { + let out = resolve_slack_mentions("<@U1BOT|my-bot> hello", Some("U1BOT")); + assert_eq!(out, "hello"); + } + + /// Labelled form mid-sentence is stripped and surrounding text preserved. 
+ #[test] + fn resolve_mentions_strips_labelled_mid_sentence() { + let out = resolve_slack_mentions("please ask <@U1BOT|handle> to run", Some("U1BOT")); + assert_eq!(out, "please ask to run"); + } + + /// Mixed bare and labelled forms of the same UID in one string are both stripped. + #[test] + fn resolve_mentions_strips_mixed_bare_and_labelled() { + let out = resolve_slack_mentions("<@U1BOT> and <@U1BOT|handle> run", Some("U1BOT")); + assert_eq!(out, "and run"); + } + + /// Malformed unclosed `<@UID|label` (no closing `>`) is preserved verbatim. + #[test] + fn resolve_mentions_malformed_unclosed_label_preserved() { + let out = resolve_slack_mentions("ask <@U1BOT|nolabel to run", Some("U1BOT")); + assert!(out.contains("<@U1BOT")); + } + + #[test] + fn resolve_mentions_preserves_longer_uid_prefix() { + let out = resolve_slack_mentions("<@U1BOTX> hello", Some("U1BOT")); + assert_eq!(out, "<@U1BOTX> hello"); + } + + // --- text_mentions_uid tests --- + + #[test] + fn mentions_uid_bare_form() { + assert!(text_mentions_uid("<@U123BOT> hello", "U123BOT")); + } + + #[test] + fn mentions_uid_labelled_form() { + assert!(text_mentions_uid("<@U123BOT|my-bot> hello", "U123BOT")); + } + + #[test] + fn mentions_uid_labelled_form_mid_sentence() { + assert!(text_mentions_uid("please ask <@U123BOT|handle> to run", "U123BOT")); + } + + #[test] + fn mentions_uid_no_match() { + assert!(!text_mentions_uid("hello world", "U123BOT")); + } + + #[test] + fn mentions_uid_no_false_positive_on_uid_prefix() { + assert!(!text_mentions_uid("<@U123BOT> hello", "U123")); + } + + #[test] + fn mentions_uid_second_mention_matches() { + assert!(text_mentions_uid("<@U999OTHER> and <@U123BOT>", "U123BOT")); + } + + #[test] + fn mentions_uid_empty_label_form() { + assert!(text_mentions_uid("<@U123BOT|> hello", "U123BOT")); + } + + #[test] + fn mentions_uid_truncated_no_closing_delimiter() { + assert!(!text_mentions_uid("<@U123BOT", "U123BOT")); + } + + // --- is_plain_user_message tests (regression for 
openabdev/openab#497 parity) --- + + /// Empty message text never counts as a user message (regardless of subtype). + #[test] + fn empty_text_is_not_plain_user_message() { + assert!(!is_plain_user_message("", "")); + assert!(!is_plain_user_message("me_message", "")); + } + + /// No subtype + non-empty text = plain user message (the common case). + #[test] + fn no_subtype_nonempty_text_is_plain_user_message() { + assert!(is_plain_user_message("", "hello")); + } + + /// Whitelisted subtypes with non-empty text are user messages. + #[test] + fn whitelisted_subtypes_are_plain_user_messages() { + assert!(is_plain_user_message("me_message", "waves")); + assert!(is_plain_user_message("thread_broadcast", "see channel")); + assert!(is_plain_user_message("file_share", "caption")); + } + + /// System-ish subtypes (even from real users) are NOT user messages — + /// resetting the counter on them would let bot-to-bot loops re-arm. + #[test] + fn system_subtypes_are_not_plain_user_messages() { + for subtype in [ + "pinned_item", + "unpinned_item", + "channel_name", + "channel_archive", + "channel_unarchive", + "group_join", + "group_leave", + "group_topic", + "group_purpose", + "reminder_add", + "tombstone", + ] { + assert!( + !is_plain_user_message(subtype, "some text"), + "subtype {subtype} must not count as a user message", + ); + } + } + + // --- slack_file_download_url tests --- + + /// Prefers url_private_download when both fields are present — + /// that endpoint always streams raw bytes even for browser-previewed types. + #[test] + fn slack_file_url_prefers_download_variant() { + let file = serde_json::json!({ + "url_private_download": "https://files.slack.com/.../download/log.txt", + "url_private": "https://files.slack.com/.../preview/log.txt", + }); + assert_eq!( + slack_file_download_url(&file), + "https://files.slack.com/.../download/log.txt", + ); + } + + /// Falls back to url_private when url_private_download is absent. 
+    #[test]
+    fn slack_file_url_falls_back_to_private() {
+        let file = serde_json::json!({
+            "url_private": "https://files.slack.com/.../log.txt",
+        });
+        assert_eq!(
+            slack_file_download_url(&file),
+            "https://files.slack.com/.../log.txt",
+        );
+    }
+
+    /// Externally-backed files with no private URL return empty — caller skips.
+    #[test]
+    fn slack_file_url_empty_for_external_only() {
+        let file = serde_json::json!({
+            "external_type": "gdrive",
+            "permalink": "https://docs.google.com/...",
+        });
+        assert_eq!(slack_file_download_url(&file), "");
+    }
+
+    // --- sanitize_slack_filename tests ---
+
+    #[test]
+    fn sanitize_leaves_normal_filename_unchanged() {
+        assert_eq!(sanitize_slack_filename("photo.png"), "photo.png");
+        assert_eq!(sanitize_slack_filename("my file (1).jpg"), "my file (1).jpg");
+    }
+
+    #[test]
+    fn sanitize_replaces_backtick() {
+        assert_eq!(sanitize_slack_filename("file`name.png"), "file'name.png");
+    }
+
+    #[test]
+    fn sanitize_replaces_angle_brackets() {
+        // Angle brackets are Slack mrkdwn delimiters; they must not pass through.
+        assert_eq!(sanitize_slack_filename("<@U123>"), "(@U123)");
+        assert_eq!(sanitize_slack_filename("<!here>"), "(!here)");
+    }
+
+    #[test]
+    fn sanitize_combined_injection_attempt() {
+        // A filename constructed to inject a Slack @here ping.
+        assert_eq!(
+            sanitize_slack_filename("`<!here>`"),
+            "'(!here)'"
+        );
+    }
+
+    #[test]
+    fn sanitize_escapes_ampersand_before_angle_brackets() {
+        // Slack mrkdwn decodes HTML entities before markup parsing.
+        // "&lt;@here&gt;" would round-trip back to "<@here>" and trigger a mention
+        // ping if & is not escaped. The & must be escaped first so downstream
+        // Slack entity decoding cannot reconstruct a mrkdwn delimiter.
+        assert_eq!(sanitize_slack_filename("&lt;@here&gt;"), "&amp;lt;@here&amp;gt;");
+        assert_eq!(sanitize_slack_filename("file&name.png"), "file&amp;name.png");
+    }
+
+    // --- strip_mime_params tests ---
+
+    /// MIME with charset parameter strips to bare media type.
+    #[test]
+    fn strip_mime_params_removes_charset() {
+        let bare = strip_mime_params("text/plain; charset=utf-8");
+        assert_eq!(bare, "text/plain");
+    }
+
+    /// A media type without parameters passes through untouched.
+    #[test]
+    fn strip_mime_params_bare_unchanged() {
+        let bare = strip_mime_params("image/png");
+        assert_eq!(bare, "image/png");
+    }
+
+    /// The empty string maps to the empty string.
+    #[test]
+    fn strip_mime_params_empty() {
+        let bare = strip_mime_params("");
+        assert_eq!(bare, "");
+    }
+
+    /// Leading and trailing whitespace is removed.
+    #[test]
+    fn strip_mime_params_trims_whitespace() {
+        let bare = strip_mime_params(" text/plain ");
+        assert_eq!(bare, "text/plain");
+    }
+
+    // --- bot_id_matches_trusted tests ---
+
+    #[test]
+    fn trusted_bot_ids_accepts_raw_slack_bot_id() {
+        let allow_list: HashSet<String> = ["B123BOT".to_string()].into_iter().collect();
+        let trusted = bot_id_matches_trusted(&allow_list, "B123BOT", None);
+        assert!(trusted);
+    }
+
+    #[test]
+    fn trusted_bot_ids_accepts_resolved_bot_user_id() {
+        let allow_list: HashSet<String> = ["U123BOT".to_string()].into_iter().collect();
+        let trusted = bot_id_matches_trusted(&allow_list, "B123BOT", Some("U123BOT"));
+        assert!(trusted);
+    }
+
+    #[test]
+    fn trusted_bot_ids_rejects_unknown_bot_when_resolution_fails() {
+        let allow_list: HashSet<String> = ["U123BOT".to_string()].into_iter().collect();
+        let trusted = bot_id_matches_trusted(&allow_list, "B999BOT", None);
+        assert!(!trusted);
+    }
+
+    #[test]
+    fn trusted_bot_ids_rejects_empty_event_bot_id() {
+        // Even a set that contains the empty string must not trust an empty ID.
+        let allow_list: HashSet<String> = ["".to_string()].into_iter().collect();
+        let trusted = bot_id_matches_trusted(&allow_list, "", None);
+        assert!(!trusted);
+    }
+
+    /// Per-thread streaming: ON by default, OFF when another bot is present (#534).
+ #[test] + fn streaming_per_thread() { + let ttl = std::time::Duration::from_secs(300); + let adapter = SlackAdapter::new("xoxb-test".into(), ttl, AllowBots::Mentions, false, crate::multibot_cache::MultibotCache::load("/dev/null".into())); + + assert!( + adapter.use_streaming(false), + "should stream when no other bot" + ); + assert!( + !adapter.use_streaming(true), + "should NOT stream when other bot present" + ); + } + + #[tokio::test] + async fn assistant_mode_gates_status_and_native_streaming() { + let ttl = std::time::Duration::from_secs(60); + // assistant_mode=true → status API on; native streaming on (no other bot), + // off when another bot is present; post+edit streaming on regardless. + let adapter = SlackAdapter::new("xoxb-test".into(), ttl, AllowBots::Off, true, crate::multibot_cache::MultibotCache::load("/dev/null".into())); + assert!(adapter.uses_assistant_status(), "assistant_mode enables status API"); + assert!(adapter.use_streaming(false), "post+edit streaming on when no other bot"); + assert!(adapter.uses_native_streaming(false), "native streaming on when no other bot"); + assert!(!adapter.uses_native_streaming(true), "other bot present disables native"); + // assistant_mode=false → no status API, no native streaming; post+edit still streams. + let adapter2 = SlackAdapter::new("xoxb-test".into(), ttl, AllowBots::Off, false, crate::multibot_cache::MultibotCache::load("/dev/null".into())); + assert!(!adapter2.uses_assistant_status()); + assert!(adapter2.use_streaming(false), "post+edit streaming independent of assistant_mode"); + assert!(!adapter2.uses_native_streaming(false), "native streaming requires assistant_mode"); + } + + /// chat.postMessage body carries Block Kit `markdown` blocks with the raw + /// Markdown preserved (NOT downgraded), plus a `text` fallback and thread_ts. 
+ #[test] + fn post_message_body_uses_raw_markdown_blocks() { + let b = build_post_message_body("C1", Some("1700.1"), "## Heading\n- item"); + assert_eq!(b["channel"], "C1"); + assert_eq!(b["thread_ts"], "1700.1"); + assert_eq!(b["blocks"][0]["type"], "markdown"); + // Raw markdown preserved — heading is NOT flattened to `*Heading*`. + assert_eq!(b["blocks"][0]["text"], "## Heading\n- item"); + assert!(b["text"].is_string(), "text fallback present for a11y/notifs"); + } + + /// thread_ts is omitted (top-level post) when the channel has no thread. + #[test] + fn post_message_body_omits_thread_ts_when_none() { + let b = build_post_message_body("C1", None, "hi"); + assert!(b.get("thread_ts").is_none()); + } + + /// chat.update body also uses Block Kit `markdown` blocks with raw markdown. + #[test] + fn update_body_uses_raw_markdown_blocks() { + let b = build_update_body("C1", "1700.9", "**bold**"); + assert_eq!(b["channel"], "C1"); + assert_eq!(b["ts"], "1700.9"); + assert_eq!(b["blocks"][0]["type"], "markdown"); + assert_eq!(b["blocks"][0]["text"], "**bold**"); + } + + /// Content over the per-block cap (11,900) splits into multiple markdown + /// blocks, each within the limit. Assert on char count — `split_message` + /// enforces `chars().count() <= limit`, not byte length. + #[test] + fn long_content_splits_into_multiple_markdown_blocks() { + let big = "lorem ipsum dolor\n".repeat(1000); // > MARKDOWN_BLOCK_LIMIT + assert!(big.chars().count() > MARKDOWN_BLOCK_LIMIT); + let blocks = build_markdown_blocks(&big); + assert!(blocks.len() >= 2, "should split into multiple blocks"); + for blk in &blocks { + assert_eq!(blk["type"], "markdown"); + assert!(blk["text"].as_str().unwrap().chars().count() <= MARKDOWN_BLOCK_LIMIT); + } + } + + /// Regression for the long-table split: a Markdown table that overflows the + /// old 4000 limit but fits the new 11,900 message_limit must stay in a single + /// chunk, so it isn't split mid-table into raw pipe text. 
+    #[test]
+    fn typical_long_table_stays_in_one_chunk() {
+        let cache = crate::multibot_cache::MultibotCache::load("/dev/null".into());
+        let slack = SlackAdapter::new(
+            "xoxb-test".into(),
+            std::time::Duration::from_secs(300),
+            AllowBots::Mentions,
+            true,
+            cache,
+        );
+        let limit = slack.message_limit();
+        assert_eq!(limit, MARKDOWN_BLOCK_LIMIT);
+
+        // Header and separator rows followed by 150 data rows.
+        let mut table = String::from("| col a | col b | col c |\n|---|---|---|\n");
+        table.extend((0..150).map(|i| format!("| row {i} aaaa | bbbb {i} | cccc {i} |\n")));
+
+        let table_chars = table.chars().count();
+        assert!(table_chars > 4000, "table must exceed old limit");
+        assert!(table_chars < limit, "but fit the new one");
+
+        let chunks = crate::format::split_message(&table, limit);
+        assert_eq!(
+            chunks.len(),
+            1,
+            "table within message_limit must not be split mid-table"
+        );
+    }
+
+    /// The text-only builders emit a `text` field and no `blocks` key.
+    #[test]
+    fn text_only_fallback_bodies_have_no_blocks() {
+        let post_body = build_post_message_text_only("C1", Some("1700.1"), "## H\n- x");
+        assert!(post_body.get("blocks").is_none());
+        assert!(post_body["text"].is_string());
+        assert_eq!(post_body["thread_ts"], "1700.1");
+
+        let update_body = build_update_text_only("C1", "1700.9", "**b**");
+        assert!(update_body.get("blocks").is_none());
+        assert!(update_body["text"].is_string());
+    }
+
+    /// Error classifier matches `invalid_blocks` (malformed/unsupported blocks)
+    /// and `msg_blocks_too_long` (over the cumulative block cap) → degrade to
+    /// text. `invalid_arguments` is a Slack catch-all and must NOT trigger a
+    /// pointless text-only retry; unrelated errors are ignored too.
+ #[test] + fn detects_block_payload_rejected_errors() { + assert!(is_block_payload_rejected(&anyhow!( + "Slack API chat.postMessage: invalid_blocks" + ))); + assert!( + is_block_payload_rejected(&anyhow!("Slack API chat.postMessage: msg_blocks_too_long")), + "oversize block payload should degrade to text-only" + ); + assert!( + !is_block_payload_rejected(&anyhow!("Slack API chat.update: invalid_arguments")), + "invalid_arguments is a catch-all, not a block-rejection signal" + ); + assert!(!is_block_payload_rejected(&anyhow!( + "Slack API chat.postMessage: channel_not_found" + ))); + // Exact error-code match, not substring: a future code that merely + // contains `invalid_blocks` must NOT trigger a text-only retry. + assert!( + !is_block_payload_rejected(&anyhow!("Slack API chat.postMessage: invalid_blocks_field")), + "must match the error code exactly, not as a substring" + ); + } + + /// Slack opts into native table rendering (Block Kit markdown / markdown_text + /// stream chunks), so the router skips the table→code-block conversion. + #[test] + fn slack_renders_native_tables() { + let ttl = std::time::Duration::from_secs(300); + let adapter = SlackAdapter::new("xoxb-test".into(), ttl, AllowBots::Mentions, true, crate::multibot_cache::MultibotCache::load("/dev/null".into())); + assert!(adapter.renders_native_tables()); + } +} + +#[cfg(test)] +mod socket_keepalive_tests { + use super::{next_backoff, socket_idle, IDLE_TIMEOUT_SECS, MAX_BACKOFF_SECS}; + use std::time::Duration; + + /// Backoff doubles and caps, matching the gateway adapter (1,2,4,8,16,30,30…). 
+ #[test] + fn backoff_doubles_then_caps() { + let mut b = 1u64; + let seq: Vec = (0..8) + .map(|_| { + let cur = b; + b = next_backoff(b); + cur + }) + .collect(); + assert_eq!(seq, vec![1, 2, 4, 8, 16, MAX_BACKOFF_SECS, MAX_BACKOFF_SECS, MAX_BACKOFF_SECS]); + assert_eq!(next_backoff(MAX_BACKOFF_SECS), MAX_BACKOFF_SECS); + } + + /// A half-open socket (no inbound past the window) is detected; an active one + /// (recent inbound, e.g. a Slack ping) is not. This is the deaf-socket guard. + #[test] + fn idle_detects_half_open_at_boundary() { + let timeout = Duration::from_secs(IDLE_TIMEOUT_SECS); + assert!(!socket_idle(Duration::from_secs(0), timeout)); + assert!(!socket_idle(Duration::from_secs(IDLE_TIMEOUT_SECS - 1), timeout)); + assert!(socket_idle(Duration::from_secs(IDLE_TIMEOUT_SECS), timeout)); + assert!(socket_idle(Duration::from_secs(IDLE_TIMEOUT_SECS + 10), timeout)); + } +} diff --git a/crates/openab-core/src/stt.rs b/crates/openab-core/src/stt.rs new file mode 100644 index 000000000..d266e6117 --- /dev/null +++ b/crates/openab-core/src/stt.rs @@ -0,0 +1,354 @@ +use crate::adapter::{ChannelRef, ChatAdapter, MessageRef}; +use crate::config::SttConfig; +use reqwest::multipart; +use std::sync::Arc; +use tracing::{debug, error, warn}; + +/// Outcome of attempting STT on a single audio attachment. +/// Used by adapters to feed `post_echo`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum EchoEntry { + Success(String), + Failed, +} + +/// Render a list of echo entries as a single multi-line quoted block. +/// Returns `None` for empty input so callers can short-circuit. +/// +/// Each entry produces one `> 🎤 …` line. Internal newlines inside a +/// transcript are flattened to spaces so each entry occupies exactly one +/// visual line — Discord and Slack both stop applying `>` at the next `\n`. 
+pub fn format_echo_message(entries: &[EchoEntry]) -> Option { + if entries.is_empty() { + return None; + } + let mut lines = Vec::with_capacity(entries.len()); + for e in entries { + match e { + EchoEntry::Success(text) => { + let flat = text.replace(['\n', '\r'], " "); + lines.push(format!("> 🎤 {flat}")); + } + EchoEntry::Failed => { + lines.push("> 🎤 (transcription failed)".to_string()); + } + } + } + Some(lines.join("\n")) +} + +/// Post a transcript echo to the thread and add a ⚠️ reaction for any failed +/// entries. No-op when the config disables echoing or when `entries` is empty. +/// +/// Errors from the adapter (send/reaction) are logged and swallowed — the +/// echo is best-effort and must never block the agent reply. +pub async fn post_echo( + adapter: &Arc, + thread: &ChannelRef, + trigger: &MessageRef, + entries: &[EchoEntry], + cfg: &SttConfig, +) { + if !cfg.echo_transcript { + return; + } + let Some(body) = format_echo_message(entries) else { + return; + }; + if let Err(e) = adapter.send_message(thread, &body).await { + warn!(error = %e, platform = adapter.platform(), "failed to send STT echo message"); + } + for entry in entries { + if matches!(entry, EchoEntry::Failed) { + if let Err(e) = adapter.add_reaction(trigger, "⚠️").await { + warn!(error = %e, platform = adapter.platform(), "failed to add STT failure reaction"); + } + // Add only one reaction even with multiple failures — emoji reactions + // are unique per (user, emoji, message), so additional calls are no-ops. + break; + } + } +} + +/// Transcribe audio bytes via an OpenAI-compatible `/audio/transcriptions` endpoint. 
pub async fn transcribe(
    client: &reqwest::Client,
    cfg: &SttConfig,
    audio_bytes: Vec<u8>,
    filename: String,
    mime_type: &str,
) -> Option<String> {
    // Returns `Some(trimmed transcript)` on success; every failure path
    // (bad MIME, transport error, non-2xx, unparseable or empty response)
    // yields `None`.

    // Tolerate a trailing slash in the configured base URL.
    let url = format!(
        "{}/audio/transcriptions",
        cfg.base_url.trim_end_matches('/')
    );

    // An unparseable MIME type used to vanish silently via `.ok()?`; log it so
    // a failed transcription can be traced back to the attachment metadata.
    let file_part = match multipart::Part::bytes(audio_bytes)
        .file_name(filename)
        .mime_str(mime_type)
    {
        Ok(part) => part,
        Err(e) => {
            warn!(error = %e, mime_type = %mime_type, "STT upload skipped: invalid MIME type");
            return None;
        }
    };

    let form = multipart::Form::new()
        .part("file", file_part)
        .text("model", cfg.model.clone())
        .text("response_format", "json");

    let resp = match client
        .post(&url)
        .bearer_auth(&cfg.api_key)
        .multipart(form)
        .send()
        .await
    {
        Ok(r) => r,
        Err(e) => {
            error!(error = %e, "STT request failed");
            return None;
        }
    };

    if !resp.status().is_success() {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        error!(status = %status, body = %body, "STT API error");
        return None;
    }

    let json: serde_json::Value = match resp.json().await {
        Ok(v) => v,
        Err(e) => {
            error!(error = %e, "STT response parse failed");
            return None;
        }
    };

    // A missing or non-string `text` field, or a whitespace-only transcript,
    // is treated as "no transcription".
    let text = json.get("text")?.as_str()?.trim().to_string();
    if text.is_empty() {
        return None;
    }

    // Report characters, not UTF-8 bytes, to match the field name.
    debug!(chars = text.chars().count(), "STT transcription complete");
    Some(text)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn format_single_success_entry() {
        let entries = vec![EchoEntry::Success("hello world".into())];
        let out = format_echo_message(&entries).expect("non-empty input → Some");
        assert_eq!(out, "> 🎤 hello world");
    }

    #[test]
    fn format_single_failure_entry() {
        let entries = vec![EchoEntry::Failed];
        let out = format_echo_message(&entries).expect("non-empty input → Some");
        assert_eq!(out, "> 🎤 (transcription failed)");
    }

    #[test]
    fn format_multiple_mixed_entries() {
        let entries = vec![
            EchoEntry::Success("first".into()),
            EchoEntry::Failed,
            EchoEntry::Success("third".into()),
        ];
        let out = format_echo_message(&entries).expect("non-empty input → Some");
assert_eq!(out, "> 🎤 first\n> 🎤 (transcription failed)\n> 🎤 third"); + } + + #[test] + fn format_empty_entries_returns_none() { + let entries: Vec = vec![]; + assert!(format_echo_message(&entries).is_none()); + } + + #[test] + fn format_strips_internal_newlines_in_transcript() { + // Multi-line transcripts must collapse to a single quoted line so the + // ">" prefix still applies to every visual line. + let entries = vec![EchoEntry::Success("line one\nline two".into())]; + let out = format_echo_message(&entries).expect("non-empty input → Some"); + assert_eq!(out, "> 🎤 line one line two"); + } + + use crate::adapter::{ChannelRef, ChatAdapter, MessageRef}; + use anyhow::Result; + use async_trait::async_trait; + use std::sync::{Arc, Mutex}; + + #[derive(Default)] + struct MockAdapter { + sent_messages: Mutex>, + reactions: Mutex>, + } + + #[async_trait] + impl ChatAdapter for MockAdapter { + fn platform(&self) -> &'static str { + "mock" + } + fn message_limit(&self) -> usize { + 4000 + } + async fn send_message(&self, channel: &ChannelRef, content: &str) -> Result { + self.sent_messages + .lock() + .unwrap() + .push((channel.clone(), content.to_string())); + Ok(MessageRef { + channel: channel.clone(), + message_id: "mock-msg".into(), + }) + } + async fn create_thread( + &self, + channel: &ChannelRef, + _trigger: &MessageRef, + _title: &str, + ) -> Result { + Ok(channel.clone()) + } + async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { + self.reactions + .lock() + .unwrap() + .push((msg.clone(), emoji.to_string())); + Ok(()) + } + async fn remove_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { + Ok(()) + } + fn use_streaming(&self, _other_bot_present: bool) -> bool { + false + } + } + + fn test_channel() -> ChannelRef { + ChannelRef { + platform: "mock".into(), + channel_id: "C1".into(), + thread_id: Some("T1".into()), + parent_id: None, + origin_event_id: None, + } + } + + fn test_trigger() -> MessageRef { + MessageRef { + 
channel: test_channel(), + message_id: "M1".into(), + } + } + + fn cfg(echo: bool) -> SttConfig { + SttConfig { + echo_transcript: echo, + ..SttConfig::default() + } + } + + #[tokio::test] + async fn post_echo_success_sends_one_message_no_reactions() { + let mock = Arc::new(MockAdapter::default()); + let adapter: Arc = mock.clone(); + let entries = vec![EchoEntry::Success("hello".into())]; + post_echo( + &adapter, + &test_channel(), + &test_trigger(), + &entries, + &cfg(true), + ) + .await; + + assert_eq!(mock.sent_messages.lock().unwrap().len(), 1); + assert_eq!(mock.sent_messages.lock().unwrap()[0].1, "> 🎤 hello"); + assert!(mock.reactions.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn post_echo_failure_adds_warning_reaction() { + let mock = Arc::new(MockAdapter::default()); + let adapter: Arc = mock.clone(); + let entries = vec![EchoEntry::Failed]; + post_echo( + &adapter, + &test_channel(), + &test_trigger(), + &entries, + &cfg(true), + ) + .await; + + assert_eq!(mock.sent_messages.lock().unwrap().len(), 1); + assert_eq!( + mock.sent_messages.lock().unwrap()[0].1, + "> 🎤 (transcription failed)" + ); + let reactions = mock.reactions.lock().unwrap(); + assert_eq!(reactions.len(), 1); + assert_eq!(reactions[0].1, "⚠️"); + } + + #[tokio::test] + async fn post_echo_mixed_one_message_one_reaction() { + let mock = Arc::new(MockAdapter::default()); + let adapter: Arc = mock.clone(); + let entries = vec![EchoEntry::Success("ok".into()), EchoEntry::Failed]; + post_echo( + &adapter, + &test_channel(), + &test_trigger(), + &entries, + &cfg(true), + ) + .await; + + assert_eq!(mock.sent_messages.lock().unwrap().len(), 1); + assert_eq!( + mock.sent_messages.lock().unwrap()[0].1, + "> 🎤 ok\n> 🎤 (transcription failed)" + ); + assert_eq!(mock.reactions.lock().unwrap().len(), 1); + } + + #[tokio::test] + async fn post_echo_disabled_is_noop() { + let mock = Arc::new(MockAdapter::default()); + let adapter: Arc = mock.clone(); + let entries = 
vec![EchoEntry::Success("hi".into()), EchoEntry::Failed]; + post_echo( + &adapter, + &test_channel(), + &test_trigger(), + &entries, + &cfg(false), + ) + .await; + + assert!(mock.sent_messages.lock().unwrap().is_empty()); + assert!(mock.reactions.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn post_echo_empty_entries_is_noop() { + let mock = Arc::new(MockAdapter::default()); + let adapter: Arc = mock.clone(); + let entries: Vec = vec![]; + post_echo( + &adapter, + &test_channel(), + &test_trigger(), + &entries, + &cfg(true), + ) + .await; + + assert!(mock.sent_messages.lock().unwrap().is_empty()); + assert!(mock.reactions.lock().unwrap().is_empty()); + } +} diff --git a/crates/openab-core/src/timestamp.rs b/crates/openab-core/src/timestamp.rs new file mode 100644 index 000000000..aa7adce46 --- /dev/null +++ b/crates/openab-core/src/timestamp.rs @@ -0,0 +1,114 @@ +//! ISO 8601 UTC timestamp helpers — no external crate dependency. +//! +//! Centralizes the Gregorian date math used by Slack (`.` ts strings) +//! and Gateway (`SystemTime::now()`) so both adapters share one implementation. + +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Convert days since the Unix epoch (1970-01-01) to a Gregorian (year, month, day). +/// Algorithm from . +fn days_to_ymd(days: u64) -> (u64, u64, u64) { + let z = days + 719468; + let era = z / 146097; + let doe = z % 146097; + let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + let y = yoe + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + let mp = (5 * doy + 2) / 153; + let d = doy - (153 * mp + 2) / 5 + 1; + let m = if mp < 10 { mp + 3 } else { mp - 9 }; + let y = if m <= 2 { y + 1 } else { y }; + (y, m, d) +} + +/// Format a Unix timestamp (seconds + millis) as ISO 8601 UTC with millisecond precision. 
+fn unix_to_iso8601(secs: u64, ms: u64) -> String { + let days = secs / 86400; + let time_secs = secs % 86400; + let h = time_secs / 3600; + let m = (time_secs % 3600) / 60; + let s = time_secs % 60; + let (year, month, day) = days_to_ymd(days); + format!("{year:04}-{month:02}-{day:02}T{h:02}:{m:02}:{s:02}.{ms:03}Z") +} + +/// Convert a Slack `ts` string (".") to ISO 8601 UTC. +/// Best-effort; falls back to epoch on parse failure. +/// +/// Parses as `f64` so the fractional part carries decimal semantics directly — +/// ".12" maps to 120 ms, not 12 ms — without any string-padding gymnastics. +pub fn slack_ts_to_iso8601(ts: &str) -> String { + let total = ts.parse::().unwrap_or(0.0); + let secs = total.trunc() as u64; + let ms = (total.fract() * 1000.0).round() as u64; + unix_to_iso8601(secs, ms) +} + +/// Current wall-clock instant as ISO 8601 UTC with millisecond precision. +pub fn now_iso8601() -> String { + let dur = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default(); + unix_to_iso8601(dur.as_secs(), (dur.subsec_millis()) as u64) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn slack_ts_epoch_zero() { + assert_eq!(slack_ts_to_iso8601("0.000000"), "1970-01-01T00:00:00.000Z"); + } + + #[test] + fn slack_ts_keeps_milliseconds() { + // 1714204397 = 2024-04-27T07:53:17 UTC; .123456 → .123 ms + assert_eq!( + slack_ts_to_iso8601("1714204397.123456"), + "2024-04-27T07:53:17.123Z" + ); + } + + #[test] + fn slack_ts_missing_fraction_uses_zero() { + assert_eq!( + slack_ts_to_iso8601("1714204397"), + "2024-04-27T07:53:17.000Z" + ); + } + + #[test] + fn slack_ts_two_digit_fraction_is_120ms_not_12ms() { + // ".12" carries decimal semantics: 0.12 s = 120 ms. + assert_eq!( + slack_ts_to_iso8601("1714204397.12"), + "2024-04-27T07:53:17.120Z" + ); + } + + #[test] + fn slack_ts_one_digit_fraction_is_100ms_not_1ms() { + // ".1" carries decimal semantics: 0.1 s = 100 ms. 
+ assert_eq!( + slack_ts_to_iso8601("1714204397.1"), + "2024-04-27T07:53:17.100Z" + ); + } + + #[test] + fn slack_ts_unparseable_falls_back_to_epoch() { + assert_eq!(slack_ts_to_iso8601("not-a-ts"), "1970-01-01T00:00:00.000Z"); + } + + #[test] + fn now_iso8601_has_expected_shape() { + let s = now_iso8601(); + // YYYY-MM-DDTHH:MM:SS.mmmZ = 24 chars + assert_eq!(s.len(), 24); + assert!(s.ends_with('Z')); + assert_eq!(&s[4..5], "-"); + assert_eq!(&s[10..11], "T"); + assert_eq!(&s[19..20], "."); + } +} diff --git a/crates/openab-gateway/Cargo.toml b/crates/openab-gateway/Cargo.toml new file mode 100644 index 000000000..3236ffcd1 --- /dev/null +++ b/crates/openab-gateway/Cargo.toml @@ -0,0 +1,43 @@ +[package] +name = "openab-gateway" +version = "0.5.4" +edition = "2021" +license = "MIT" + +[dependencies] +tokio = { version = "1", features = ["full"] } +axum = { version = "0.8", features = ["ws"] } +tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } +futures-util = "0.3" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +anyhow = "1" +uuid = { version = "1", features = ["v4"] } +chrono = { version = "0.4", features = ["serde"] } +hmac = "0.12" +sha2 = "0.10" +base64 = "0.22" +jsonwebtoken = "9" +aes = "0.8" +cbc = "0.1" +prost = "0.13" +subtle = "2" +sha1 = "0.10" +quick-xml = "0.37" +image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } +urlencoding = "2" + +[dev-dependencies] +wiremock = "0.6" + +[features] +default = ["telegram", "line", "feishu", "googlechat", "wecom", "teams"] +telegram = [] +line = [] +feishu = [] +googlechat = [] +wecom = [] +teams = [] diff --git a/crates/openab-gateway/src/adapters/feishu.rs b/crates/openab-gateway/src/adapters/feishu.rs new file mode 100644 index 
000000000..84e5ac017 --- /dev/null +++ b/crates/openab-gateway/src/adapters/feishu.rs @@ -0,0 +1,3928 @@ +use crate::schema::*; +use axum::extract::State; +use prost::Message as ProstMessage; +use serde::Deserialize; +use std::collections::{HashMap, VecDeque}; +use std::sync::Arc; +use std::time::Instant; +use tokio::sync::RwLock; +use tracing::{info, warn}; + +/// Timing-safe string comparison to prevent side-channel attacks on tokens. +fn constant_time_eq(a: &str, b: &str) -> bool { + use subtle::ConstantTimeEq; + if a.len() != b.len() { + return false; + } + a.as_bytes().ct_eq(b.as_bytes()).into() +} + +// --------------------------------------------------------------------------- +// Feishu WebSocket protobuf frame (pbbp2.Frame) +// --------------------------------------------------------------------------- + +#[derive(Clone, PartialEq, ProstMessage)] +pub struct WsFrame { + #[prost(uint64, tag = "1")] + pub seq_id: u64, + #[prost(uint64, tag = "2")] + pub log_id: u64, + #[prost(int32, tag = "3")] + pub service: i32, + #[prost(int32, tag = "4")] + pub method: i32, + #[prost(message, repeated, tag = "5")] + pub headers: Vec, + #[prost(string, optional, tag = "6")] + pub payload_encoding: Option, + #[prost(string, optional, tag = "7")] + pub payload_type: Option, + #[prost(bytes = "vec", optional, tag = "8")] + pub payload: Option>, + #[prost(string, optional, tag = "9")] + pub log_id_new: Option, +} + +#[derive(Clone, PartialEq, ProstMessage)] +pub struct WsHeader { + #[prost(string, tag = "1")] + pub key: String, + #[prost(string, tag = "2")] + pub value: String, +} + +// --------------------------------------------------------------------------- +// Configuration +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq)] +pub enum ConnectionMode { + Websocket, + Webhook, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum AllowBots { + Off, + Mentions, + All, +} + +/// Controls when the bot responds 
without @mention in threads. +/// Mirrors Discord's `allow_user_messages` setting. +#[derive(Debug, Clone, PartialEq, Default)] +pub enum AllowUsers { + /// Bot responds in threads it has participated in without @mention. + Involved, + /// Always require @mention, even in participated threads. + Mentions, + /// Like Involved, but if another bot has also posted in the thread, + /// require @mention to avoid all bots responding. + #[default] + MultibotMentions, +} + +#[derive(Debug, Clone)] +pub struct FeishuConfig { + pub app_id: String, + pub app_secret: String, + pub domain: String, + pub connection_mode: ConnectionMode, + pub webhook_path: String, + pub verification_token: Option, + pub encrypt_key: Option, + pub allowed_groups: Vec, + pub allowed_users: Vec, + pub require_mention: bool, + pub allow_bots: AllowBots, + pub allow_user_messages: AllowUsers, + pub trusted_bot_ids: Vec, + pub max_bot_turns: u32, + pub dedupe_ttl_secs: u64, + pub message_limit: usize, + /// TTL for participated-thread cache entries (seconds). Threads older than + /// this are forgotten and require a fresh @mention to re-engage. + /// Set to 0 (via FEISHU_SESSION_TTL_HOURS=0) to disable participation + /// tracking entirely — all messages will require @mention. + /// Converted from `FEISHU_SESSION_TTL_HOURS` (user-facing, in hours) to seconds internally. + pub session_ttl_secs: u64, + /// Override the API base URL. Used in tests to point at a mock server. + /// Always None in production (not read from env). + pub api_base_override: Option, +} + +impl FeishuConfig { + /// Build config from environment variables. Returns None if FEISHU_APP_ID + /// is not set (adapter disabled). 
+ pub fn from_env() -> Option { + let app_id = std::env::var("FEISHU_APP_ID").ok()?; + let app_secret = std::env::var("FEISHU_APP_SECRET").ok().unwrap_or_default(); + if app_secret.is_empty() { + warn!("FEISHU_APP_ID set but FEISHU_APP_SECRET is empty"); + return None; + } + let domain = std::env::var("FEISHU_DOMAIN").unwrap_or_else(|_| "feishu".into()); + let connection_mode = match std::env::var("FEISHU_CONNECTION_MODE") + .unwrap_or_else(|_| "websocket".into()) + .to_lowercase() + .as_str() + { + "webhook" => ConnectionMode::Webhook, + _ => ConnectionMode::Websocket, + }; + let webhook_path = std::env::var("FEISHU_WEBHOOK_PATH") + .unwrap_or_else(|_| "/webhook/feishu".into()); + let verification_token = std::env::var("FEISHU_VERIFICATION_TOKEN").ok(); + let encrypt_key = std::env::var("FEISHU_ENCRYPT_KEY").ok(); + let allowed_groups = parse_csv("FEISHU_ALLOWED_GROUPS"); + let allowed_users = parse_csv("FEISHU_ALLOWED_USERS"); + let require_mention = std::env::var("FEISHU_REQUIRE_MENTION") + .map(|v| v != "false" && v != "0") + .unwrap_or(true); + let allow_bots = match std::env::var("FEISHU_ALLOW_BOTS") + .unwrap_or_else(|_| "off".into()) + .to_lowercase() + .as_str() + { + "mentions" => AllowBots::Mentions, + "all" => AllowBots::All, + _ => AllowBots::Off, + }; + let trusted_bot_ids = parse_csv("FEISHU_TRUSTED_BOT_IDS"); + let allow_user_messages = match std::env::var("FEISHU_ALLOW_USER_MESSAGES") + .unwrap_or_else(|_| "multibot_mentions".into()) + .to_lowercase() + .replace('-', "_") + .as_str() + { + "involved" => AllowUsers::Involved, + "mentions" => AllowUsers::Mentions, + _ => AllowUsers::MultibotMentions, + }; + let max_bot_turns = std::env::var("FEISHU_MAX_BOT_TURNS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(20); + let dedupe_ttl_secs = std::env::var("FEISHU_DEDUPE_TTL_SECS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(300); + let message_limit = std::env::var("FEISHU_MESSAGE_LIMIT") + .ok() + .and_then(|v| v.parse().ok()) + 
.unwrap_or(4000); + let session_ttl_secs = std::env::var("FEISHU_SESSION_TTL_HOURS") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(24) + * 3600; + + Some(Self { + app_id, + app_secret, + domain, + connection_mode, + webhook_path, + verification_token, + encrypt_key, + allowed_groups, + allowed_users, + require_mention, + allow_bots, + allow_user_messages, + trusted_bot_ids, + max_bot_turns, + dedupe_ttl_secs, + message_limit, + session_ttl_secs, + api_base_override: None, + }) + } + + /// API base URL for the configured domain. + pub fn api_base(&self) -> String { + if let Some(ref base) = self.api_base_override { + return base.clone(); + } + if self.domain == "lark" { + "https://open.larksuite.com".into() + } else { + "https://open.feishu.cn".into() + } + } +} + +fn parse_csv(var: &str) -> Vec { + std::env::var(var) + .unwrap_or_default() + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() +} + +// --------------------------------------------------------------------------- +// Feishu event types (im.message.receive_v1) +// --------------------------------------------------------------------------- + +mod event_types { + use super::*; + + #[derive(Debug, Deserialize)] + pub struct FeishuEventEnvelope { + pub header: Option, + pub event: Option, + pub challenge: Option, + // Parsed by serde, not consumed in current code paths. + #[allow(dead_code)] + #[serde(rename = "type")] + pub event_type_field: Option, + } + + #[derive(Debug, Deserialize)] + pub struct FeishuEventHeader { + pub event_id: Option, + // Parsed by serde, not consumed in current code paths. 
+ #[allow(dead_code)] + pub event_type: Option, + } + + #[derive(Debug, Deserialize)] + pub struct FeishuEventBody { + pub sender: Option, + pub message: Option, + } + + #[derive(Debug, Deserialize)] + pub struct FeishuSender { + pub sender_id: Option, + pub sender_type: Option, + } + + #[derive(Debug, Deserialize)] + pub struct FeishuSenderId { + pub open_id: Option, + } + + #[derive(Debug, Deserialize)] + pub struct FeishuMessage { + pub message_id: Option, + pub chat_id: Option, + pub chat_type: Option, + pub message_type: Option, + pub content: Option, + pub mentions: Option>, + pub root_id: Option, + pub parent_id: Option, + } + + #[derive(Debug, Deserialize)] + pub struct FeishuMention { + pub key: Option, + pub id: Option, + // Parsed by serde, not consumed in current code paths. + #[allow(dead_code)] + pub name: Option, + } + + #[derive(Debug, Deserialize)] + pub struct FeishuMentionId { + pub open_id: Option, + } + + /// Parse a feishu im.message.receive_v1 event into a GatewayEvent. + /// Returns None if the event should be skipped (unsupported type, bot message, etc). + /// The Vec contains references to media that need async download. + /// + /// `bypass_mention_gating`: whether the bot should skip @mention requirement for this message. + /// This is the final computed result from mode-specific logic (detect_and_mark_multibot), + /// already accounting for the configured `allow_user_messages` mode. + /// Do NOT pass raw participation status here. 
+ pub fn parse_message_event( + envelope: &FeishuEventEnvelope, + bot_open_id: Option<&str>, + config: &FeishuConfig, + bypass_mention_gating: bool, + ) -> Option<(GatewayEvent, Vec)> { + let _header = envelope.header.as_ref()?; + let event = envelope.event.as_ref()?; + let msg = event.message.as_ref()?; + let sender = event.sender.as_ref()?; + + let msg_type = msg.message_type.as_deref().unwrap_or("text"); + if !matches!(msg_type, "text" | "image" | "file" | "post" | "audio") { + return None; + } + // Skip bot messages with explicit sender_type + if matches!(sender.sender_type.as_deref(), Some("bot") | Some("app")) { + return None; + } + + let sender_open_id = sender.sender_id.as_ref()?.open_id.as_deref()?; + // Skip messages from self + if let Some(bot_id) = bot_open_id { + if sender_open_id == bot_id { + return None; + } + } + + // Check if sender is a known bot: + // Bot identification: + // 1. If trusted_bot_ids is configured, check against it + // 2. If trusted_bot_ids is empty, we cannot reliably identify bots + // (Feishu marks other bots as sender_type="user") + let is_bot_sender = if !config.trusted_bot_ids.is_empty() { + config.trusted_bot_ids.iter().any(|id| id == sender_open_id) + } else { + false + }; + + // User allowlist: if configured, only allow listed users. + // Trusted bots bypass user allowlist (same as Discord behavior). 
+ if !is_bot_sender + && !config.allowed_users.is_empty() + && !config.allowed_users.iter().any(|u| u == sender_open_id) + { + return None; + } + + if is_bot_sender { + match config.allow_bots { + AllowBots::Off => return None, + AllowBots::Mentions | AllowBots::All => { + // Allowed — will check mentions below for Mentions mode + } + } + } + + let chat_id = msg.chat_id.as_deref()?; + // Group allowlist: if configured, only allow listed groups + let is_group = msg.chat_type.as_deref() != Some("p2p"); + if is_group + && !config.allowed_groups.is_empty() + && !config.allowed_groups.iter().any(|g| g == chat_id) + { + return None; + } + + let content_json: serde_json::Value = msg.content.as_deref() + .and_then(|s| serde_json::from_str(s).ok())?; + + let message_id = msg.message_id.as_deref()?; + + // Parse content based on message type + let (clean_text, mention_ids, media_refs) = match msg_type { + "image" => { + let image_key = content_json.get("image_key")?.as_str()?; + let mentions = extract_mentions( + "", msg.mentions.as_deref().unwrap_or(&[]), bot_open_id, + ); + let refs = vec![MediaRef::Image { + message_id: message_id.to_string(), + image_key: image_key.to_string(), + }]; + (String::new(), mentions.1, refs) + } + "file" => { + let file_key = content_json.get("file_key")?.as_str()?; + let file_name = content_json.get("file_name") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let mentions = extract_mentions( + "", msg.mentions.as_deref().unwrap_or(&[]), bot_open_id, + ); + let refs = vec![MediaRef::File { + message_id: message_id.to_string(), + file_key: file_key.to_string(), + file_name: file_name.to_string(), + }]; + (String::new(), mentions.1, refs) + } + "audio" => { + let file_key = content_json.get("file_key")?.as_str()?; + let mentions = extract_mentions( + "", msg.mentions.as_deref().unwrap_or(&[]), bot_open_id, + ); + let refs = vec![MediaRef::Audio { + message_id: message_id.to_string(), + file_key: file_key.to_string(), + }]; + 
(String::new(), mentions.1, refs) + } + "post" => { + // Rich text: content is {"title":"...","content":[[{tag,text,...},{tag,image_key,...}]]} + let mut texts = Vec::new(); + let mut refs = Vec::new(); + if let Some(rows) = content_json.get("content").and_then(|v| v.as_array()) { + for row in rows { + if let Some(elements) = row.as_array() { + for el in elements { + match el.get("tag").and_then(|v| v.as_str()) { + Some("text") => { + if let Some(t) = el.get("text").and_then(|v| v.as_str()) { + texts.push(t.to_string()); + } + } + Some("img") => { + if let Some(key) = el.get("image_key").and_then(|v| v.as_str()) { + refs.push(MediaRef::Image { + message_id: message_id.to_string(), + image_key: key.to_string(), + }); + } + } + Some("a") => { + if let Some(t) = el.get("text").and_then(|v| v.as_str()) { + texts.push(t.to_string()); + } + } + Some("at") => { + // Mentions handled via msg.mentions at envelope level + } + _ => {} + } + } + } + } + } + let raw_text = texts.join(""); + let (clean, ids) = extract_mentions( + &raw_text, + msg.mentions.as_deref().unwrap_or(&[]), + bot_open_id, + ); + (clean, ids, refs) + } + _ => { + // text + let raw_text = content_json.get("text").and_then(|v| v.as_str()).unwrap_or(""); + if raw_text.trim().is_empty() { + return None; + } + let (clean, ids) = extract_mentions( + raw_text, + msg.mentions.as_deref().unwrap_or(&[]), + bot_open_id, + ); + if clean.trim().is_empty() { + return None; + } + (clean, ids, Vec::new()) + } + }; + + let channel_type = match msg.chat_type.as_deref() { + Some("p2p") => "direct", + _ => "group", + }; + + let thread_id = msg.root_id.clone().or_else(|| msg.parent_id.clone()); + + // Gateway-side mention gating: in groups, skip if require_mention + // is true and bot is not mentioned (for human senders). + // Bypass: if bot has previously replied in this thread (participated), + // no @mention needed (like Discord's "involved" mode). 
+ let in_thread = thread_id.is_some(); + if channel_type == "group" + && !is_bot_sender + && config.require_mention + && !(in_thread && bypass_mention_gating) + { + if let Some(bot_id) = bot_open_id { + let bot_mentioned = mention_ids.iter().any(|id| id == bot_id); + if !bot_mentioned { + return None; + } + } + } + + // Bot-to-bot mention gating: in AllowBots::Mentions mode, + // bot messages must @mention this bot (like Discord "mentions" mode). + // Note: in DMs there is no @mention mechanism, so bot DMs are + // silently dropped in Mentions mode. Use AllowBots::All for DM bots. + if is_bot_sender && config.allow_bots == AllowBots::Mentions { + if let Some(bot_id) = bot_open_id { + let bot_mentioned = mention_ids.iter().any(|id| id == bot_id); + if !bot_mentioned { + return None; + } + } + } + + let event = GatewayEvent::new( + "feishu", + ChannelInfo { + id: chat_id.to_string(), + channel_type: channel_type.to_string(), + thread_id, + }, + SenderInfo { + id: sender_open_id.to_string(), + name: sender_open_id.to_string(), + display_name: sender_open_id.to_string(), + is_bot: is_bot_sender, + }, + clean_text.trim(), + message_id, + mention_ids, + ); + Some((event, media_refs)) + } + + fn extract_mentions( + raw_text: &str, + mentions: &[FeishuMention], + bot_open_id: Option<&str>, + ) -> (String, Vec) { + let mut text = raw_text.to_string(); + let mut ids = Vec::new(); + for m in mentions { + let open_id = m.id.as_ref().and_then(|id| id.open_id.as_deref()); + if let Some(oid) = open_id { + ids.push(oid.to_string()); + if let Some(key) = m.key.as_deref() { + if bot_open_id == Some(oid) { + text = text.replacen(key, "", 1); + } + } + } + } + (text, ids) + } +} + +pub use event_types::*; + +// --------------------------------------------------------------------------- +// Deduplication +// --------------------------------------------------------------------------- + +pub struct DedupeCache { + seen: std::sync::Mutex>, + ttl_secs: u64, + max_size: usize, +} + +impl 
DedupeCache { + pub fn new(ttl_secs: u64) -> Self { + Self { + seen: std::sync::Mutex::new(HashMap::new()), + ttl_secs, + max_size: 10_000, + } + } + + /// Returns true if this id was already seen (duplicate). + pub fn is_duplicate(&self, id: &str) -> bool { + let mut map = self.seen.lock().unwrap_or_else(|e| e.into_inner()); + // Lazy sweep + if map.len() >= self.max_size { + map.retain(|_, ts| ts.elapsed().as_secs() < self.ttl_secs); + } + if let Some(ts) = map.get(id) { + if ts.elapsed().as_secs() < self.ttl_secs { + return true; + } + } + map.insert(id.to_string(), Instant::now()); + false + } +} + +// --------------------------------------------------------------------------- +// Token cache +// --------------------------------------------------------------------------- + +pub struct FeishuTokenCache { + /// (token, created_at, ttl_secs) + token: RwLock>, + api_base: String, + app_id: String, + app_secret: String, +} + +/// Refresh margin: renew 5 minutes before expiry. +const TOKEN_REFRESH_MARGIN_SECS: u64 = 300; + +impl FeishuTokenCache { + pub fn new(config: &FeishuConfig) -> Self { + Self { + token: RwLock::new(None), + api_base: config.api_base(), + app_id: config.app_id.clone(), + app_secret: config.app_secret.clone(), + } + } + + /// Construct with explicit api_base (for tests). + pub fn with_base(config: &FeishuConfig, api_base: &str) -> Self { + Self { + token: RwLock::new(None), + api_base: api_base.to_string(), + app_id: config.app_id.clone(), + app_secret: config.app_secret.clone(), + } + } + + /// Get a valid tenant_access_token, refreshing if expired or missing. 
+ pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result { + // Fast path: read lock + { + let guard = self.token.read().await; + if let Some((ref tok, ref ts, ttl)) = *guard { + if ts.elapsed().as_secs() < ttl.saturating_sub(TOKEN_REFRESH_MARGIN_SECS) { + return Ok(tok.clone()); + } + } + } + // Slow path: write lock + refresh + let mut guard = self.token.write().await; + // Double-check after acquiring write lock + if let Some((ref tok, ref ts, ttl)) = *guard { + if ts.elapsed().as_secs() < ttl.saturating_sub(TOKEN_REFRESH_MARGIN_SECS) { + return Ok(tok.clone()); + } + } + let (new_token, expire) = self.refresh(client).await?; + *guard = Some((new_token.clone(), Instant::now(), expire)); + Ok(new_token) + } + + async fn refresh(&self, client: &reqwest::Client) -> anyhow::Result<(String, u64)> { + let url = format!( + "{}/open-apis/auth/v3/tenant_access_token/internal", + self.api_base + ); + let resp = client + .post(&url) + .json(&serde_json::json!({ + "app_id": self.app_id, + "app_secret": self.app_secret, + })) + .send() + .await + .map_err(|e| anyhow::anyhow!("feishu token refresh request failed: {e}"))?; + + let status = resp.status(); + let body: serde_json::Value = resp + .json() + .await + .map_err(|e| anyhow::anyhow!("feishu token refresh parse failed: {e}"))?; + + let code = body.get("code").and_then(|v| v.as_i64()).unwrap_or(-1); + if code != 0 { + let msg = body + .get("msg") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + anyhow::bail!("feishu token refresh error: code={code} msg={msg} status={status}"); + } + + let expire = body.get("expire").and_then(|v| v.as_u64()).unwrap_or(7200); + + let token = body.get("tenant_access_token") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .ok_or_else(|| anyhow::anyhow!("feishu token refresh: missing tenant_access_token"))?; + + Ok((token, expire)) + } +} + +// --------------------------------------------------------------------------- +// Adapter (aggregated state) +// 
--------------------------------------------------------------------------- + +pub struct FeishuAdapter { + pub config: FeishuConfig, + pub token_cache: Arc, + pub bot_open_id: Arc>>, + pub dedupe: Arc, + pub rate_limiter: Arc, + pub name_cache: Arc>>, + /// Per-channel bot turn counter. Key = chat_id, Value = (count, last_reset). + /// Human message resets count to 0. Prevents runaway bot-to-bot loops. + pub bot_turns: Arc>>, // eviction: human msg resets; follow-up can add TTL like participated_threads + /// Positive-only cache: thread_id (root_id) → last_replied_at. + /// When bot has replied in a thread, subsequent messages in that thread + /// bypass @mention gating (like Discord's "involved" mode). + pub participated_threads: Arc>>, + /// Positive-only cache: thread_id → first_seen for threads where other bots + /// have posted. Used by multibot-mentions mode to require @mention. + pub multibot_threads: Arc>>, + /// Per-message edit count tracker for Feishu's 20-edits-per-message hard cap + /// (errcode 230072 — "The message has reached the number of times it can be edited"). + /// Insertion-order FIFO eviction: when over `EDIT_COUNTS_CACHE_MAX`, the + /// oldest *insertions* are dropped, not the lowest-count entries — so a + /// just-started active stream is far less likely to be evicted than under a + /// count-ascending policy. (A very long-lived stream can still age out once + /// 4096 newer messages have been inserted behind it; that resets its count + /// to 1, which is acceptable — it only loses the local preemptive margin and + /// the on-wire 230072 sentinel still backstops.) + pub edit_counts: Arc>, + pub client: reqwest::Client, +} + +/// Insertion-order edit-count cache for Feishu's per-message edit cap. +/// +/// `counts` holds the current edit count (or `u32::MAX` cap-reached sentinel) +/// for each message_id. 
`order` records insertion order so eviction is FIFO +/// rather than count-ascending; this matters because count-ascending would +/// preferentially target *active* streams (low count = just started) while +/// leaving stale cap-reached entries in place. FIFO instead ages out the +/// oldest insertions, which strongly favours keeping active streams. +#[derive(Default)] +pub struct EditCountsCache { + pub counts: HashMap, + pub order: VecDeque, +} + +impl FeishuAdapter { + pub fn new(config: FeishuConfig) -> Self { + let token_cache = Arc::new(FeishuTokenCache::new(&config)); + let dedupe = Arc::new(DedupeCache::new(config.dedupe_ttl_secs)); + let rate_limiter = Arc::new(RateLimiter::new(60, 120)); + Self { + config, + token_cache, + dedupe, + rate_limiter, + bot_open_id: Arc::new(RwLock::new(None)), + name_cache: Arc::new(std::sync::Mutex::new(HashMap::new())), + bot_turns: Arc::new(std::sync::Mutex::new(HashMap::new())), + participated_threads: Arc::new(std::sync::Mutex::new(HashMap::new())), + multibot_threads: Arc::new(std::sync::Mutex::new(HashMap::new())), + edit_counts: Arc::new(std::sync::Mutex::new(EditCountsCache::default())), + client: reqwest::Client::new(), + } + } + + /// Resolve bot identity (open_id) via API. Called during startup for both + /// WebSocket and webhook modes so mention gating works in either mode. 
+ pub async fn resolve_bot_identity(&self) { + let token = match self.token_cache.get_token(&self.client).await { + Ok(t) => t, + Err(e) => { + warn!(err = %e, "feishu bot identity lookup failed (token error), mention gating may not work"); + return; + } + }; + match get_bot_info(&self.client, &self.config.api_base(), &token).await { + Ok(bot_id) => { + info!(bot_open_id = %bot_id, "feishu bot identity resolved"); + *self.bot_open_id.write().await = Some(bot_id); + } + Err(e) => { + warn!(err = %e, "feishu bot identity lookup failed, mention gating may not work"); + } + } + } +} + +// --------------------------------------------------------------------------- +// WebSocket long connection +// --------------------------------------------------------------------------- + +use futures_util::{SinkExt, StreamExt}; +use tokio::sync::{broadcast, watch}; + +/// Get WebSocket endpoint URL from feishu API. +/// Note: This API uses AppID+AppSecret directly, not Bearer token. +async fn get_ws_endpoint( + client: &reqwest::Client, + api_base: &str, + app_id: &str, + app_secret: &str, +) -> anyhow::Result { + let url = format!("{}/callback/ws/endpoint", api_base); + let resp = client + .post(&url) + .json(&serde_json::json!({ + "AppID": app_id, + "AppSecret": app_secret, + })) + .send() + .await?; + let body: serde_json::Value = resp.json().await?; + let code = body.get("code").and_then(|v| v.as_i64()).unwrap_or(-1); + if code != 0 { + let msg = body.get("msg").and_then(|v| v.as_str()).unwrap_or("unknown"); + anyhow::bail!("feishu ws endpoint error: code={code} msg={msg}"); + } + body.get("data") + .and_then(|d| d.get("URL")) + .and_then(|u| u.as_str()) + .map(|s| s.to_string()) + .ok_or_else(|| anyhow::anyhow!("feishu ws endpoint: missing URL")) +} + +/// Get bot identity (open_id) via bot info API. 
+async fn get_bot_info( + client: &reqwest::Client, + api_base: &str, + token: &str, +) -> anyhow::Result { + let url = format!("{}/open-apis/bot/v3/info", api_base); + let resp = client.get(&url).bearer_auth(token).send().await?; + let body: serde_json::Value = resp.json().await?; + body.get("bot") + .and_then(|b| b.get("open_id")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .ok_or_else(|| anyhow::anyhow!("feishu bot info: missing open_id")) +} + +/// Spawn the feishu WebSocket long-connection task. +/// Returns a JoinHandle that runs until shutdown_rx fires. +pub async fn start_websocket( + adapter: &FeishuAdapter, + event_tx: broadcast::Sender, + mut shutdown_rx: watch::Receiver, +) -> anyhow::Result> { + let token_cache = adapter.token_cache.clone(); + let bot_open_id_store = adapter.bot_open_id.clone(); + let dedupe = adapter.dedupe.clone(); + let config = adapter.config.clone(); + let client = adapter.client.clone(); + let name_cache = adapter.name_cache.clone(); + let bot_turns = adapter.bot_turns.clone(); + let participated_threads = adapter.participated_threads.clone(); + let multibot_threads = adapter.multibot_threads.clone(); + + let handle = tokio::spawn(async move { + let mut backoff_secs = 1u64; + loop { + let result = ws_connect_loop( + &token_cache, + &bot_open_id_store, + &dedupe, + &config, + &client, + &event_tx, + &mut shutdown_rx, + &name_cache, + &bot_turns, + &participated_threads, + &multibot_threads, + ) + .await; + + if *shutdown_rx.borrow() { + info!("feishu websocket shutting down"); + break; + } + + match result { + Ok(()) => { + info!("feishu websocket disconnected, reconnecting..."); + backoff_secs = 1; + } + Err(e) => { + tracing::error!(err = %e, backoff = backoff_secs, "feishu websocket error, reconnecting..."); + backoff_secs = (backoff_secs * 2).min(120); + } + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)) => {} + _ = shutdown_rx.changed() => { break; } + } + } + }); + + Ok(handle) +} + +/// Single WebSocket connection lifecycle. +#[allow(clippy::too_many_arguments)] +async fn ws_connect_loop( + token_cache: &Arc, + bot_open_id_store: &Arc>>, + dedupe: &Arc, + config: &FeishuConfig, + client: &reqwest::Client, + event_tx: &broadcast::Sender, + shutdown_rx: &mut watch::Receiver, + name_cache: &Arc>>, + bot_turns: &Arc>>, + participated_threads: &Arc>>, + multibot_threads: &Arc>>, +) -> anyhow::Result<()> { + let api_base = config.api_base(); + + // Refresh bot identity on each reconnect in case it was not resolved earlier + if bot_open_id_store.read().await.is_none() { + if let Ok(token) = token_cache.get_token(client).await { + if let Ok(bot_id) = get_bot_info(client, &api_base, &token).await { + info!(bot_open_id = %bot_id, "feishu bot identity resolved on reconnect"); + *bot_open_id_store.write().await = Some(bot_id); + } + } + } + + let ws_url = get_ws_endpoint(client, &api_base, &config.app_id, &config.app_secret).await?; + info!(url = %ws_url, "feishu websocket connecting"); + + let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?; + let (mut ws_tx, mut ws_rx) = ws_stream.split(); + info!("feishu websocket connected"); + + loop { + tokio::select! 
{ + msg = ws_rx.next() => { + match msg { + Some(Ok(tokio_tungstenite::tungstenite::Message::Text(text))) => { + handle_ws_message( + &text, bot_open_id_store, dedupe, config, event_tx, + name_cache, token_cache, client, bot_turns, participated_threads, multibot_threads, + ).await; + } + Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(data))) => { + let _ = ws_tx.send(tokio_tungstenite::tungstenite::Message::Pong(data)).await; + } + Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => { + return Ok(()); + } + Some(Err(e)) => { + return Err(anyhow::anyhow!("websocket error: {e}")); + } + Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => { + match WsFrame::decode(data.as_ref()) { + Ok(frame) => { + // method=1 is data frame (events), method=0 is control + if frame.method == 1 { + if let Some(ref payload) = frame.payload { + if let Ok(text) = String::from_utf8(payload.clone()) { + handle_ws_message( + &text, bot_open_id_store, dedupe, config, event_tx, + name_cache, token_cache, client, bot_turns, participated_threads, multibot_threads, + ).await; + } + } + // Send ACK: echo frame back with {"code":200} payload + let mut ack = frame.clone(); + ack.payload = Some(b"{\"code\":200}".to_vec()); + let ack_bytes = ack.encode_to_vec(); + let _ = ws_tx.send( + tokio_tungstenite::tungstenite::Message::Binary(ack_bytes) + ).await; + } + } + Err(e) => { + tracing::debug!(err = %e, len = data.len(), "feishu ws protobuf decode failed"); + } + } + } + _ => {} + } + } + _ = shutdown_rx.changed() => { + let _ = ws_tx.send(tokio_tungstenite::tungstenite::Message::Close(None)).await; + return Ok(()); + } + } + } +} + +/// Process a single WebSocket text message. 
+#[allow(clippy::too_many_arguments)] +async fn handle_ws_message( + text: &str, + bot_open_id_store: &Arc>>, + dedupe: &Arc, + config: &FeishuConfig, + event_tx: &broadcast::Sender, + name_cache: &Arc>>, + token_cache: &Arc, + client: &reqwest::Client, + bot_turns: &Arc>>, + participated_threads: &Arc>>, + multibot_threads: &Arc>>, +) { + let envelope: FeishuEventEnvelope = match serde_json::from_str(text) { + Ok(e) => e, + Err(_) => return, + }; + + // Handle challenge frame (Feishu may send this in WS mode for verification) + if let Some(ref challenge) = envelope.challenge { + tracing::debug!(challenge = %challenge, "feishu ws challenge received (ignored in WS mode)"); + return; + } + + // Debug: log sender_type for diagnosing bot-to-bot loops + if let Some(ref event) = envelope.event { + if let Some(ref sender) = event.sender { + tracing::debug!( + sender_type = ?sender.sender_type, + sender_id = ?sender.sender_id.as_ref().and_then(|s| s.open_id.as_deref()), + "feishu ws event sender" + ); + } + } + + // Dedupe by event_id + if let Some(ref header) = envelope.header { + if let Some(ref event_id) = header.event_id { + if dedupe.is_duplicate(event_id) { + return; + } + } + } + + let bot_id = bot_open_id_store.read().await; + let bot_id_ref = bot_id.as_deref(); + + // Check if the message is in a thread where bot has previously replied, + // respecting the allow_user_messages mode: + // - Involved (default): bypass @mention if participated + // - MultibotMentions: bypass only if participated AND no other bot in thread + // - Mentions: never bypass + let bypass_mention = detect_and_mark_multibot( + &envelope, bot_id_ref, config, participated_threads, multibot_threads, + ); + + if let Some((mut gateway_event, media_refs)) = parse_message_event(&envelope, bot_id_ref, config, bypass_mention) { + // Also dedupe by message_id + if dedupe.is_duplicate(&gateway_event.message_id) { + return; + } + + // Bot turn tracking: prevent runaway bot-to-bot loops + let channel_id = 
&gateway_event.channel.id; + { + let mut turns = bot_turns.lock().unwrap_or_else(|e| e.into_inner()); + if gateway_event.sender.is_bot { + let count = turns.entry(channel_id.to_string()).or_insert(0); + *count += 1; + if *count > config.max_bot_turns { + warn!( + channel = %channel_id, + count = *count, + max = config.max_bot_turns, + "feishu: bot turn limit reached, dropping message" + ); + return; + } + // (Feishu doesn't push bot messages to other bots' WebSocket, + // so multibot detection is done via mentions instead — see below.) + } else { + // Human message resets bot turn counter + turns.remove(channel_id.as_str()); + } + } + + // Resolve sender display name (lazy, cached) + let name = resolve_user_name( + &gateway_event.sender.id, name_cache, token_cache, client, &config.api_base(), + ).await; + gateway_event.sender.name = name.clone(); + gateway_event.sender.display_name = name; + + // Download media attachments (images, text files) + if !media_refs.is_empty() { + if let Ok(token) = token_cache.get_token(client).await { + let api_base = config.api_base(); + for media_ref in &media_refs { + let attachment = match media_ref { + MediaRef::Image { message_id, image_key } => { + download_feishu_image(client, &api_base, &token, message_id, image_key).await + } + MediaRef::File { message_id, file_key, file_name } => { + download_feishu_file(client, &api_base, &token, message_id, file_key, file_name).await + } + MediaRef::Audio { message_id, file_key } => { + download_feishu_audio(client, &api_base, &token, message_id, file_key).await + } + }; + if let Some(att) = attachment { + gateway_event.content.attachments.push(att); + } + } + } + } + + // Skip if no text and no attachments (e.g. 
unsupported file type) + if gateway_event.content.text.trim().is_empty() && gateway_event.content.attachments.is_empty() { + return; + } + + let json = serde_json::to_string(&gateway_event).unwrap(); + info!( + channel = %gateway_event.channel.id, + thread_id = ?gateway_event.channel.thread_id, + sender = %gateway_event.sender.id, + "feishu → gateway" + ); + let _ = event_tx.send(json); + } +} + +/// Resolve user display name from open_id via Contact API, with caching. +async fn resolve_user_name( + open_id: &str, + name_cache: &Arc>>, + token_cache: &Arc, + client: &reqwest::Client, + api_base: &str, +) -> String { + { + let cache = name_cache.lock().unwrap_or_else(|e| e.into_inner()); + if let Some(name) = cache.get(open_id) { + return name.clone(); + } + } + let token = match token_cache.get_token(client).await { + Ok(t) => t, + Err(_) => return open_id.to_string(), + }; + let url = format!( + "{}/open-apis/contact/v3/users/{}?user_id_type=open_id", + api_base, open_id + ); + let resolved = match client.get(&url).bearer_auth(&token).send().await { + Ok(resp) => { + let body: serde_json::Value = resp.json().await.unwrap_or_default(); + body.pointer("/data/user/name") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + } + Err(_) => None, + }; + // Only cache successful resolutions — don't cache fallback open_id + // so retries can succeed after permissions are granted. + if let Some(ref name) = resolved { + let mut cache = name_cache.lock().unwrap_or_else(|e| e.into_inner()); + if cache.len() < 10_000 { + cache.insert(open_id.to_string(), name.clone()); + } + } + resolved.unwrap_or_else(|| open_id.to_string()) +} + +// --------------------------------------------------------------------------- +// Send message +/// Edit (update) an existing feishu message in-place for streaming. +/// Feishu message edit cap: API returns errcode 230072 after 20 edits per message. 
+/// We stop preemptively at 18 to leave a 2-edit safety margin (handles races where +/// multiple in-flight edits could each push count to the wall) and also catch 230072 +/// defensively in case the local count drifts from server reality. +const FEISHU_EDIT_CAP: u32 = 18; + +/// Maximum entries in the per-adapter edit_counts cache before lazy eviction kicks in. +const EDIT_COUNTS_CACHE_MAX: usize = 4096; + +/// Validates that a Feishu message_id matches the expected `om_` shape +/// before it is interpolated into a REST URL path. Feishu's documented +/// message_id format is the `om_` prefix followed by base62-style characters +/// (`[A-Za-z0-9_]`). Rejecting anything else stops crafted IDs containing `/`, +/// `?`, or `#` from altering URL semantics — defence in depth, since the trust +/// boundary is the core↔gateway WebSocket and not external input. +fn is_valid_feishu_message_id(id: &str) -> bool { + let bytes = id.as_bytes(); + if !id.starts_with("om_") || id.len() < 4 || id.len() > 128 { + return false; + } + bytes + .iter() + .all(|b| b.is_ascii_alphanumeric() || *b == b'_') +} + +/// Detect whether a Feishu API response body indicates the per-message edit +/// cap (errcode 230072). Trusts JSON `code` field when the body parses as +/// JSON; falls back to substring match only on non-JSON bodies (proxy HTML, +/// truncated responses, …) so a JSON body with an unrelated `code` cannot be +/// false-positively flagged just because some inner string contains "230072". +fn is_feishu_cap_reached_body(body: &str) -> bool { + match serde_json::from_str::(body) { + Ok(v) => v + .get("code") + .and_then(|c| c.as_i64()) + .is_some_and(|code| code == 230072), + Err(_) => { + body.contains("230072") + || body.contains("number of times it can be edited") + } + } +} + +/// Outcome of an edit_feishu_message attempt. Distinguishes the cap-reached case +/// from generic failure so the caller can stop attempting edits and let the +/// core finalize path handle recovery. 
+pub enum EditOutcome { + /// Edit succeeded; the on-screen message now reflects the new content. + Edited, + /// The 20-edits-per-message cap is exhausted (either tracked locally or + /// signaled by errcode 230072). Caller should stop attempting edits; + /// recovery (delete placeholder + send fresh) is handled at the core + /// finalize layer in `src/adapter.rs`, not here — appending new messages + /// per cosmetic flush would spam the user with continuation messages. + CapReached, + /// Generic failure (network, token, other API errors). + Failed(String), +} + +/// Increment the edit count for a message_id. New keys are appended to the +/// FIFO order queue; existing keys keep their position. When the cache is +/// over `EDIT_COUNTS_CACHE_MAX`, the oldest *insertions* are evicted (not the +/// lowest-count entries) so active streams are not bumped out from under +/// themselves. +fn increment_edit_count( + cache: &Arc>, + message_id: &str, +) { + let mut c = cache.lock().unwrap_or_else(|e| e.into_inner()); + let was_new = !c.counts.contains_key(message_id); + let entry = c.counts.entry(message_id.to_string()).or_insert(0); + if *entry != u32::MAX { + *entry = entry.saturating_add(1); + } + if was_new { + c.order.push_back(message_id.to_string()); + evict_if_overcap(&mut c); + } +} + +/// Mark a message_id as cap-reached; subsequent edit attempts skip the API +/// call and signal `EditOutcome::CapReached` directly so the core finalize +/// path can take over. +fn mark_edit_cap( + cache: &Arc>, + message_id: &str, +) { + let mut c = cache.lock().unwrap_or_else(|e| e.into_inner()); + let was_new = !c.counts.contains_key(message_id); + c.counts.insert(message_id.to_string(), u32::MAX); + if was_new { + c.order.push_back(message_id.to_string()); + evict_if_overcap(&mut c); + } +} + +/// FIFO eviction helper: when over `EDIT_COUNTS_CACHE_MAX`, drop the oldest +/// half by insertion order. 
Tolerant of `order`/`counts` drift — entries that +/// only exist in `order` are silently skipped. +fn evict_if_overcap(c: &mut EditCountsCache) { + if c.counts.len() > EDIT_COUNTS_CACHE_MAX { + let evict = c.counts.len() / 2; + for _ in 0..evict { + if let Some(oldest) = c.order.pop_front() { + c.counts.remove(&oldest); + } else { + break; + } + } + } +} + +/// Return true if this message_id has already reached the edit cap (either +/// tracked locally or marked via 230072 sentinel). +fn is_edit_cap_reached( + cache: &Arc>, + message_id: &str, +) -> bool { + let c = cache.lock().unwrap_or_else(|e| e.into_inner()); + c.counts + .get(message_id) + .is_some_and(|&n| n >= FEISHU_EDIT_CAP) +} + +/// Edit (update) an existing Feishu message in-place for streaming. +/// +/// Returns [`EditOutcome`] so the caller can distinguish success, cap-reached, +/// and generic failure. Performs a preemptive local cap check (`FEISHU_EDIT_CAP`) +/// before hitting the network, and detects the server-side errcode 230072 via +/// body-code-first parsing if the local count drifts from reality. +async fn edit_feishu_message( + adapter: &FeishuAdapter, + message_id: &str, + text: &str, +) -> EditOutcome { + // Pre-check: if we've already tracked >= FEISHU_EDIT_CAP edits (or the sentinel + // u32::MAX from a 230072 response), skip the API call and signal CapReached so + // the caller can stop attempting edits and let the core finalize path recover. 
+ if is_edit_cap_reached(&adapter.edit_counts, message_id) { + return EditOutcome::CapReached; + } + + let token = match adapter.token_cache.get_token(&adapter.client).await { + Ok(t) => t, + Err(e) => { + tracing::error!(err = %e, "feishu: cannot get token for edit"); + return EditOutcome::Failed(format!("token error: {e}")); + } + }; + let api_base = adapter.config.api_base(); + let url = format!("{}/open-apis/im/v1/messages/{}", api_base, message_id); + let post_content = markdown_to_post(text); + let body = serde_json::json!({ + "msg_type": "post", + "content": post_content.to_string(), + }); + match adapter.client.put(&url).bearer_auth(&token) + .header("Content-Type", "application/json; charset=utf-8") + .json(&body).send().await + { + Ok(resp) => { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + // Feishu OpenAPI convention: the business result lives in the body + // `code` field, and an edit-cap rejection (errcode 230072) can arrive + // with HTTP 200. So we decide on the body — consistent with token + // refresh and the WS endpoint elsewhere in this file — rather than + // trusting HTTP status alone, which would miscount a 200 + non-zero + // `code` response as a successful edit and never reach cap detection. + // + // This relies on Feishu returning `code` as a JSON integer (which it + // always does). A non-integer or absent code falls through to the + // HTTP-status fallback below, so a malformed 2xx body is treated as + // success — acceptable, since Feishu never emits such a body. + // + // 1. Cap reached? `is_feishu_cap_reached_body` is the sole authority + // (JSON code == 230072, or substring fallback for non-JSON bodies). + if is_feishu_cap_reached_body(&body) { + mark_edit_cap(&adapter.edit_counts, message_id); + tracing::warn!( + message_id = %message_id, + status = %status, + "feishu edit cap reached (errcode 230072); signaling core for cap-reached recovery" + ); + return EditOutcome::CapReached; + } + // 2. 
Otherwise classify by body `code` (0 = success), falling back to + // HTTP status only for non-JSON bodies (proxy HTML, truncated). + match serde_json::from_str::(&body) + .ok() + .and_then(|v| v.get("code").and_then(|c| c.as_i64())) + { + Some(0) => { + increment_edit_count(&adapter.edit_counts, message_id); + tracing::trace!(message_id = %message_id, "feishu message edited"); + EditOutcome::Edited + } + Some(code) => { + tracing::error!( + message_id = %message_id, + status = %status, + code, + body = %body, + "feishu edit message failed" + ); + EditOutcome::Failed(format!("code {code}: {body}")) + } + None => { + // Body wasn't JSON-with-code; trust HTTP status as last resort. + if status.is_success() { + increment_edit_count(&adapter.edit_counts, message_id); + tracing::trace!(message_id = %message_id, "feishu message edited (non-JSON 2xx body)"); + EditOutcome::Edited + } else { + tracing::error!( + message_id = %message_id, + status = %status, + body = %body, + "feishu edit message failed" + ); + EditOutcome::Failed(format!("HTTP {status}: {body}")) + } + } + } + } + Err(e) => { + tracing::error!(message_id = %message_id, err = %e, "feishu edit message request failed"); + EditOutcome::Failed(format!("request error: {e}")) + } + } +} + +/// Delete a Feishu message via DELETE /open-apis/im/v1/messages/{id}. +/// Unlike PATCH (edit), DELETE is not subject to the 20-edits-per-message cap, +/// so this works even on messages that have already exhausted their edit quota. +/// Used by the streaming finalize path to remove the half-edited placeholder +/// before sending the full content as fresh messages, avoiding visual overlap. +/// +/// `message_id` shape is validated by the caller (`handle_reply` dispatch seam, +/// via `is_valid_feishu_message_id`) before this is reached, so it is safe to +/// interpolate into the URL path here. 
+async fn delete_feishu_message( + adapter: &FeishuAdapter, + message_id: &str, +) -> Result<(), String> { + let token = adapter + .token_cache + .get_token(&adapter.client) + .await + .map_err(|e| format!("token error: {e}"))?; + let api_base = adapter.config.api_base(); + let url = format!("{}/open-apis/im/v1/messages/{}", api_base, message_id); + match adapter + .client + .delete(&url) + .bearer_auth(&token) + .send() + .await + { + Ok(resp) if resp.status().is_success() => { + tracing::info!(message_id = %message_id, "feishu message deleted"); + Ok(()) + } + Ok(resp) => { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + tracing::warn!(status = %status, body = %body, message_id = %message_id, "feishu delete message failed"); + Err(format!("HTTP {status}: {body}")) + } + Err(e) => { + tracing::warn!(err = %e, message_id = %message_id, "feishu delete message request failed"); + Err(format!("request error: {e}")) + } + } +} + +// --------------------------------------------------------------------------- +// Markdown → Feishu post conversion +// --------------------------------------------------------------------------- + +/// Convert markdown text to feishu post content JSON. +/// Supported: code blocks → code_block tag, links → a tag, @mentions preserved. +/// Unsupported inline formatting (bold, italic, etc.) is stripped to plain text. +fn markdown_to_post(md: &str) -> serde_json::Value { + let mut lines: Vec> = Vec::new(); + + // We work byte-offset based for code fence detection, line-based otherwise. 
+ let raw_lines: Vec<&str> = md.split('\n').collect(); + let mut li = 0; + while li < raw_lines.len() { + let line = raw_lines[li]; + // Detect fenced code block + let trimmed = line.trim_start(); + if let Some(after_fence) = trimmed.strip_prefix("```") { + let lang = after_fence.trim().to_string(); + let mut code = String::new(); + li += 1; + while li < raw_lines.len() { + if raw_lines[li].trim_start().starts_with("```") { + break; + } + if !code.is_empty() { + code.push('\n'); + } + code.push_str(raw_lines[li]); + li += 1; + } + li += 1; // skip closing ``` + let mut block = serde_json::json!({"tag": "code_block", "text": code}); + if !lang.is_empty() { + block["language"] = serde_json::Value::String(lang); + } + lines.push(vec![block]); + continue; + } + // Normal line: parse inline elements + let elems = parse_inline(line); + lines.push(elems); + li += 1; + } + + serde_json::json!({ + "zh_cn": { + "content": lines + } + }) +} + +/// Parse inline markdown elements in a single line. +/// Extracts links [text](url) → a tag, strips bold/italic/strikethrough markers. 
+fn parse_inline(line: &str) -> Vec<serde_json::Value> {
+    // Produces the Feishu post elements for one line: `{"tag": "text"}` runs
+    // interleaved with `{"tag": "a"}` links. The result is never empty — a
+    // blank line yields a single empty text element.
+    let mut elems = Vec::new();
+    let mut buf = String::new();
+    let chars: Vec<char> = line.chars().collect();
+    let len = chars.len();
+    let mut i = 0;
+
+    while i < len {
+        // Link: [text](url)
+        if chars[i] == '[' {
+            if let Some((text, url, end)) = try_parse_link(&chars, i) {
+                // Flush pending plain text so element order matches the source.
+                if !buf.is_empty() {
+                    elems.push(serde_json::json!({"tag": "text", "text": buf}));
+                    buf.clear();
+                }
+                elems.push(serde_json::json!({"tag": "a", "text": text, "href": url}));
+                i = end;
+                continue;
+            }
+        }
+        // Inline code: find matching closing backtick(s), preserve content literally
+        if chars[i] == '`' {
+            let ticks = chars[i..].iter().take_while(|&&c| c == '`').count();
+            let content_start = i + ticks;
+            // Find a closing backtick run of exactly the same length.
+            let mut close_at = None;
+            let mut scan = content_start;
+            while scan < len {
+                if chars[scan] != '`' {
+                    scan += 1;
+                    continue;
+                }
+                let close_ticks = chars[scan..].iter().take_while(|&&c| c == '`').count();
+                if close_ticks == ticks {
+                    close_at = Some(scan);
+                    break;
+                }
+                scan += close_ticks;
+            }
+            match close_at {
+                Some(end) => {
+                    // Found matching close — content between is literal
+                    buf.extend(chars[content_start..end].iter().copied());
+                    i = end + ticks;
+                }
+                None => {
+                    // No matching close — keep the backticks themselves as
+                    // literal text and resume normal parsing right after them
+                    // (previously the opening backticks were silently dropped).
+                    buf.extend(std::iter::repeat('`').take(ticks));
+                    i = content_start;
+                }
+            }
+            continue;
+        }
+        // Strip paired markdown markers: **bold**, *italic*, ~~strike~~
+        // Unpaired markers are kept as literal text (e.g. ~/.ssh, *.rs, 3 * 4)
+        if chars[i] == '*' || chars[i] == '~' {
+            let ch = chars[i];
+            let run = chars[i..].iter().take_while(|&&c| c == ch).count();
+            let after = i + run;
+            // An opening run must be followed by non-whitespace and a closing
+            // run must be preceded by non-whitespace. Without this, a list
+            // bullet ("* item with *x*") or arithmetic ("3 * 4 * 5") pairs up
+            // unrelated markers and eats them.
+            let can_open = after < len && !chars[after].is_whitespace();
+            let mut close_at = None;
+            let mut scan = after;
+            while can_open && scan < len {
+                if chars[scan] != ch {
+                    scan += 1;
+                    continue;
+                }
+                let close_run = chars[scan..].iter().take_while(|&&c| c == ch).count();
+                // `scan > after` here because chars[after] != ch, so scan - 1 is valid.
+                if close_run == run && !chars[scan - 1].is_whitespace() {
+                    close_at = Some(scan);
+                    break;
+                }
+                scan += close_run;
+            }
+            match close_at {
+                Some(close) => {
+                    // Found matching close — strip both, keep inner text
+                    // (the inner text is copied as-is, not re-parsed).
+                    buf.extend(chars[after..close].iter().copied());
+                    i = close + run;
+                }
+                None => {
+                    // No matching close — keep markers as literal
+                    buf.extend(std::iter::repeat(ch).take(run));
+                    i = after;
+                }
+            }
+            continue;
+        }
+        buf.push(chars[i]);
+        i += 1;
+    }
+    if !buf.is_empty() {
+        elems.push(serde_json::json!({"tag": "text", "text": buf}));
+    }
+    if elems.is_empty() {
+        elems.push(serde_json::json!({"tag": "text", "text": ""}));
+    }
+    elems
+}
+
+/// Try to parse a markdown link starting at position `start` (which is '[').
+/// Returns (text, url, next_index) on success.
+fn try_parse_link(chars: &[char], start: usize) -> Option<(String, String, usize)> {
+    // Label: everything between '[' and the first ']' that follows it.
+    let label_from = start + 1;
+    let label_to = label_from + chars.get(label_from..)?.iter().position(|&c| c == ']')?;
+    // The ']' must be immediately followed by '(' — no whitespace allowed.
+    if chars.get(label_to + 1) != Some(&'(') {
+        return None;
+    }
+    // Target: everything up to the first ')'. An unterminated target is not a link.
+    let url_from = label_to + 2;
+    let url_to = url_from + chars[url_from..].iter().position(|&c| c == ')')?;
+    let label: String = chars[label_from..label_to].iter().collect();
+    let target: String = chars[url_from..url_to].iter().collect();
+    // Third element is the index of the first char after the closing ')'.
+    Some((label, target, url_to + 1))
+}
+
+// ---------------------------------------------------------------------------
+// Media helpers
+// ---------------------------------------------------------------------------
+
+/// Reference to a media resource that needs async download after parse_message_event.
+pub enum MediaRef {
+    /// An image, addressed by the owning message and its image key.
+    Image { message_id: String, image_key: String },
+    /// A file attachment, addressed by message and file key, with its display name.
+    File { message_id: String, file_key: String, file_name: String },
+    /// An audio message, addressed by the owning message and its file key.
+    Audio { message_id: String, file_key: String },
+}
+
+/// Longest side, in pixels, an image may keep before it is downscaled.
+const IMAGE_MAX_DIMENSION_PX: u32 = 1200;
+/// JPEG quality used when re-encoding images.
+const IMAGE_JPEG_QUALITY: u8 = 75;
+/// Largest image body accepted for download.
+const IMAGE_MAX_DOWNLOAD: u64 = 10 * 1024 * 1024; // 10 MB
+/// Largest text-file body accepted for download.
+const FILE_MAX_DOWNLOAD: u64 = 512 * 1024; // 512 KB
+
+/// Resize image so longest side <= 1200px, then encode as JPEG.
+/// GIFs are passed through unchanged to preserve animation.
+fn resize_and_compress(raw: &[u8]) -> Result<(Vec<u8>, String), image::ImageError> {
+    // Returns the encoded bytes together with their MIME type
+    // ("image/gif" for pass-through GIFs, "image/jpeg" otherwise).
+    use image::ImageReader;
+    use std::io::Cursor;
+
+    let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?;
+    if reader.format() == Some(image::ImageFormat::Gif) {
+        // Re-encoding would flatten the animation, so hand the bytes back untouched.
+        return Ok((raw.to_vec(), "image/gif".to_string()));
+    }
+    let img = reader.decode()?;
+    let img = if img.width() > IMAGE_MAX_DIMENSION_PX || img.height() > IMAGE_MAX_DIMENSION_PX {
+        // `resize` keeps the aspect ratio and fits the image inside the given
+        // box, so passing the box itself is enough. Pre-computing truncated
+        // target dimensions could yield 0 for very elongated images
+        // (e.g. 5000x1), which makes the fitted result meaningless.
+        img.resize(
+            IMAGE_MAX_DIMENSION_PX,
+            IMAGE_MAX_DIMENSION_PX,
+            image::imageops::FilterType::Lanczos3,
+        )
+    } else {
+        img
+    };
+    // JPEG has no alpha channel. Normalise anything other than 8-bit grey/RGB
+    // to RGB8 so encoding does not depend on which colour types the encoder
+    // happens to accept for PNG/WebP inputs with transparency or 16-bit samples.
+    let img = match img.color() {
+        image::ColorType::L8 | image::ColorType::Rgb8 => img,
+        _ => image::DynamicImage::ImageRgb8(img.to_rgb8()),
+    };
+    let mut buf = Cursor::new(Vec::new());
+    let encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, IMAGE_JPEG_QUALITY);
+    img.write_with_encoder(encoder)?;
+    Ok((buf.into_inner(), "image/jpeg".to_string()))
+}
+
+/// Download a Feishu image by message_id + image_key → resize/compress → Attachment
+/// whose bytes are written via `store_media` and referenced by `path` (`data` stays empty).
+pub async fn download_feishu_image(
+    client: &reqwest::Client,
+    api_base: &str,
+    token: &str,
+    message_id: &str,
+    image_key: &str,
+) -> Option<crate::schema::Attachment> {
+    // Returns None on any failure (transport error, non-2xx status, oversized
+    // body, decode/encode error, or store failure); each case is logged.
+    //
+    // Path segments are percent-encoded, matching download_feishu_audio, so an
+    // id containing URL metacharacters cannot change the request target.
+    let url = format!(
+        "{}/open-apis/im/v1/messages/{}/resources/{}?type=image",
+        api_base,
+        urlencoding::encode(message_id),
+        urlencoding::encode(image_key)
+    );
+    let resp = match client.get(&url).bearer_auth(token).send().await {
+        Ok(r) => r,
+        Err(e) => {
+            tracing::warn!(image_key, error = %e, "feishu image download failed");
+            return None;
+        }
+    };
+    if !resp.status().is_success() {
+        tracing::warn!(image_key, status = %resp.status(), "feishu image download failed");
+        return None;
+    }
+    // Early gate: reject oversized downloads before buffering the full body
+    let declared_size = resp
+        .headers()
+        .get(reqwest::header::CONTENT_LENGTH)
+        .and_then(|cl| cl.to_str().ok())
+        .and_then(|s| s.parse::<u64>().ok());
+    if let Some(size) = declared_size {
+        if size > IMAGE_MAX_DOWNLOAD {
+            tracing::warn!(image_key, size, "feishu image Content-Length exceeds 10MB limit, skipping download");
+            return None;
+        }
+    }
+    let bytes = match resp.bytes().await {
+        Ok(b) => b,
+        Err(e) => {
+            // Previously dropped silently via `.ok()?`.
+            tracing::warn!(image_key, error = %e, "feishu image body read failed");
+            return None;
+        }
+    };
+    // Fallback check (Content-Length may be absent or misreported)
+    if bytes.len() as u64 > IMAGE_MAX_DOWNLOAD {
+        tracing::warn!(image_key, size = bytes.len(), "feishu image exceeds 10MB limit");
+        return None;
+    }
+    let (compressed, mime) = match resize_and_compress(&bytes) {
+        Ok(v) => v,
+        Err(e) => {
+            tracing::warn!(image_key, error = %e, "feishu image resize failed");
+            return None;
+        }
+    };
+    let path = crate::store::store_media(&compressed).await?;
+    let ext = if mime == "image/gif" { "gif" } else { "jpg" };
+    Some(crate::schema::Attachment {
+        attachment_type: "image".into(),
+        filename: format!("{}.{}", image_key, ext),
+        mime_type: mime,
+        // Payload lives on disk at `path`; the inline data field is left empty.
+        data: String::new(),
+        size: compressed.len() as u64,
+        path: Some(path),
+    })
+}
+
+/// Download a Feishu file by message_id + file_key → stored Attachment (text files only).
+pub async fn download_feishu_file(
+    client: &reqwest::Client,
+    api_base: &str,
+    token: &str,
+    message_id: &str,
+    file_key: &str,
+    file_name: &str,
+) -> Option<crate::schema::Attachment> {
+    // Only download text-like files.
+    // The extension is whatever follows the last '.', so a name with no dot
+    // at all (e.g. "txt", "Makefile") has no extension and is skipped;
+    // dotfiles such as ".env" still count as extension "env".
+    let ext = file_name
+        .rsplit_once('.')
+        .map(|(_, e)| e.to_lowercase())
+        .unwrap_or_default();
+    const TEXT_EXTS: &[&str] = &[
+        "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml",
+        "rs", "py", "js", "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp",
+        "rb", "sh", "bash", "sql", "html", "css", "ini", "cfg", "conf", "env",
+    ];
+    if !TEXT_EXTS.contains(&ext.as_str()) {
+        tracing::debug!(file_name, "skipping non-text file attachment");
+        return None;
+    }
+    // Percent-encode path segments, matching download_feishu_audio.
+    let url = format!(
+        "{}/open-apis/im/v1/messages/{}/resources/{}?type=file",
+        api_base,
+        urlencoding::encode(message_id),
+        urlencoding::encode(file_key)
+    );
+    let resp = match client.get(&url).bearer_auth(token).send().await {
+        Ok(r) => r,
+        Err(e) => {
+            tracing::warn!(file_name, error = %e, "feishu file download failed");
+            return None;
+        }
+    };
+    if !resp.status().is_success() {
+        tracing::warn!(file_name, status = %resp.status(), "feishu file download failed");
+        return None;
+    }
+    // Early gate: reject oversized downloads before buffering the full body
+    let declared_size = resp
+        .headers()
+        .get(reqwest::header::CONTENT_LENGTH)
+        .and_then(|cl| cl.to_str().ok())
+        .and_then(|s| s.parse::<u64>().ok());
+    if let Some(size) = declared_size {
+        if size > FILE_MAX_DOWNLOAD {
+            tracing::warn!(file_name, size, "feishu file Content-Length exceeds 512KB limit, skipping download");
+            return None;
+        }
+    }
+    let bytes = match resp.bytes().await {
+        Ok(b) => b,
+        Err(e) => {
+            tracing::warn!(file_name, error = %e, "feishu file body read failed");
+            return None;
+        }
+    };
+    // Fallback check (Content-Length may be absent or misreported)
+    if bytes.len() as u64 > FILE_MAX_DOWNLOAD {
+        tracing::warn!(file_name, size = bytes.len(), "feishu file exceeds 512KB limit");
+        return None;
+    }
+    let path = crate::store::store_media(&bytes).await?;
+    Some(crate::schema::Attachment {
+        attachment_type: "text_file".into(),
+        filename: file_name.to_string(),
+        mime_type: "text/plain".into(),
+        // Payload lives on disk at `path`; the inline data field is left empty.
+        data: String::new(),
+        size: bytes.len() as u64,
+        path: Some(path),
+    })
+}
+
+const AUDIO_MAX_DOWNLOAD: u64 = 25 * 1024 * 1024; // 25 MB (Whisper API limit)
+
+/// Download a Feishu audio message by message_id + file_key → stored Attachment.
+///
+/// The MIME type is taken from the response's Content-Type header and falls
+/// back to "audio/ogg"; the filename always uses the `.ogg` suffix.
+pub async fn download_feishu_audio(
+    client: &reqwest::Client,
+    api_base: &str,
+    token: &str,
+    message_id: &str,
+    file_key: &str,
+) -> Option<crate::schema::Attachment> {
+    use urlencoding::encode;
+    let url = format!(
+        "{}/open-apis/im/v1/messages/{}/resources/{}?type=file",
+        api_base, encode(message_id), encode(file_key)
+    );
+    let resp = match client.get(&url).bearer_auth(token).send().await {
+        Ok(r) => r,
+        Err(e) => {
+            tracing::warn!(file_key, error = %e, "feishu audio download failed");
+            return None;
+        }
+    };
+    if !resp.status().is_success() {
+        tracing::warn!(file_key, status = %resp.status(), "feishu audio download failed");
+        return None;
+    }
+    let content_type = resp
+        .headers()
+        .get(reqwest::header::CONTENT_TYPE)
+        .and_then(|v| v.to_str().ok())
+        .unwrap_or("audio/ogg")
+        .to_string();
+    // Early gate on the declared size, before buffering the body.
+    let declared_size = resp
+        .headers()
+        .get(reqwest::header::CONTENT_LENGTH)
+        .and_then(|cl| cl.to_str().ok())
+        .and_then(|s| s.parse::<u64>().ok());
+    if let Some(size) = declared_size {
+        if size > AUDIO_MAX_DOWNLOAD {
+            tracing::warn!(file_key, size, "feishu audio exceeds 25MB limit");
+            return None;
+        }
+    }
+    let bytes = match resp.bytes().await {
+        Ok(b) => b,
+        Err(e) => {
+            tracing::warn!(file_key, error = %e, "feishu audio body read failed");
+            return None;
+        }
+    };
+    // Fallback check on the actual size (the header may be absent or wrong).
+    if bytes.len() as u64 > AUDIO_MAX_DOWNLOAD {
+        tracing::warn!(file_key, size = bytes.len(), "feishu audio exceeds 25MB limit");
+        return None;
+    }
+    tracing::debug!(file_key, size = bytes.len(), "feishu audio downloaded");
+    let path = crate::store::store_media(&bytes).await?;
+    Some(crate::schema::Attachment {
+        attachment_type: "audio".into(),
+        filename: format!("{}.ogg", file_key),
+        mime_type: content_type,
+        data: String::new(),
+        size: bytes.len() as u64,
+        path: Some(path),
+    })
+}
+
+/// Send a post (rich text) message to a feishu chat_id.
+/// Returns the sent message_id on success, None on failure.
+/// When `reply_to` is Some(root_id), uses the reply API to stay in a thread.
+/// When `reply_to` is None, sends a new message to the chat.
+pub async fn send_post_message(
+    client: &reqwest::Client,
+    api_base: &str,
+    token: &str,
+    chat_id: &str,
+    reply_to: Option<&str>,
+    text: &str,
+) -> Option<String> {
+    // The post payload is identical for both endpoints; render it once.
+    let content = markdown_to_post(text).to_string();
+    let (url, body) = match reply_to {
+        Some(root_id) => (
+            // `root_id` may come from a quoted message id or a thread id, so
+            // percent-encode it before placing it in the URL path.
+            format!(
+                "{}/open-apis/im/v1/messages/{}/reply",
+                api_base,
+                urlencoding::encode(root_id)
+            ),
+            serde_json::json!({
+                "msg_type": "post",
+                "content": content,
+            }),
+        ),
+        None => (
+            format!("{}/open-apis/im/v1/messages?receive_id_type=chat_id", api_base),
+            serde_json::json!({
+                "receive_id": chat_id,
+                "msg_type": "post",
+                "content": content,
+            }),
+        ),
+    };
+
+    match client
+        .post(&url)
+        .bearer_auth(token)
+        .header("Content-Type", "application/json; charset=utf-8")
+        .json(&body)
+        .send()
+        .await
+    {
+        Ok(resp) => {
+            if resp.status().is_success() {
+                // An unparseable 2xx body is logged and treated as "sent, id unknown".
+                let resp_body: serde_json::Value = match resp.json().await {
+                    Ok(v) => v,
+                    Err(e) => {
+                        tracing::warn!(err = %e, "feishu post: failed to parse response body");
+                        serde_json::Value::default()
+                    }
+                };
+                let msg_id = resp_body
+                    .pointer("/data/message_id")
+                    .and_then(|v| v.as_str())
+                    .map(|s| s.to_string());
+                info!(chat_id = %chat_id, reply_to = ?reply_to, message_id = ?msg_id, "feishu post message sent");
+                msg_id
+            } else {
+                let status = resp.status();
+                let body_text = resp.text().await.unwrap_or_default();
+                tracing::error!(status = %status, body = %body_text, "feishu send post message failed");
+                None
+            }
+        }
+        Err(e) => {
+            tracing::error!(err = %e, "feishu send post message request failed");
+            None
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+
+/// Send a text message to a feishu chat_id.
+/// Returns the sent message_id on success (for self-echo dedupe), None on failure.
+/// Kept for webhook fallback and tests; normal reply path uses send_post_message.
+#[allow(dead_code)]
+pub async fn send_text_message(
+    client: &reqwest::Client,
+    api_base: &str,
+    token: &str,
+    chat_id: &str,
+    text: &str,
+) -> Option<String> {
+    let url = format!(
+        "{}/open-apis/im/v1/messages?receive_id_type=chat_id",
+        api_base
+    );
+    // Feishu expects `content` to be a JSON document serialised into a string.
+    let content = serde_json::json!({"text": text}).to_string();
+    let body = serde_json::json!({
+        "receive_id": chat_id,
+        "msg_type": "text",
+        "content": content,
+    });
+
+    match client
+        .post(&url)
+        .bearer_auth(token)
+        .header("Content-Type", "application/json; charset=utf-8")
+        .json(&body)
+        .send()
+        .await
+    {
+        Ok(resp) => {
+            if resp.status().is_success() {
+                let msg_id = match resp.json::<serde_json::Value>().await {
+                    Ok(body) => body
+                        .pointer("/data/message_id")
+                        .and_then(|v| v.as_str())
+                        .map(|s| s.to_string()),
+                    Err(e) => {
+                        warn!(chat_id = %chat_id, err = %e, "feishu 200 response not valid JSON, self-echo dedupe will be skipped");
+                        None
+                    }
+                };
+                info!(chat_id = %chat_id, message_id = ?msg_id, "feishu message sent");
+                msg_id
+            } else {
+                let status = resp.status();
+                let body_text = resp.text().await.unwrap_or_default();
+                tracing::error!(status = %status, body = %body_text, "feishu send message failed");
+                None
+            }
+        }
+        Err(e) => {
+            tracing::error!(err = %e, "feishu send message request failed");
+            None
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Reactions (emoji on original message)
+// ---------------------------------------------------------------------------
+
+/// Map OAB emoji to feishu reaction_type. Feishu uses string keys like "THUMBSUP".
+fn emoji_to_feishu_reaction(emoji: &str) -> Option<&'static str> {
+    // Returns None for any emoji without an entry in this table.
+    match emoji {
+        "👀" => Some("EYES"),
+        "🤔" => Some("THINKING"),
+        "🔥" => Some("FIRE"),
+        "👨\u{200d}💻" => Some("TECHNOLOGIST"),
+        "⚡" => Some("LIGHTNING"),
+        "🆗" => Some("OK"),
+        "👍" => Some("THUMBSUP"),
+        "😱" => Some("SCREAM"),
+        _ => None,
+    }
+}
+
+/// Add the reaction mapped from `emoji` to `message_id`. Best-effort: every
+/// failure is logged and swallowed, nothing is reported to the caller.
+async fn add_reaction(adapter: &FeishuAdapter, message_id: &str, emoji: &str) {
+    let reaction_type = match emoji_to_feishu_reaction(emoji) {
+        Some(r) => r,
+        None => {
+            tracing::debug!(emoji = %emoji, "feishu: no mapping for reaction emoji");
+            return;
+        }
+    };
+    let token = match adapter.token_cache.get_token(&adapter.client).await {
+        Ok(t) => t,
+        Err(e) => {
+            tracing::error!(err = %e, "feishu: cannot get token for reaction");
+            return;
+        }
+    };
+    let url = format!(
+        "{}/open-apis/im/v1/messages/{}/reactions",
+        adapter.config.api_base(), message_id
+    );
+    let sent = adapter.client
+        .post(&url)
+        .bearer_auth(&token)
+        .json(&serde_json::json!({"reaction_type": {"emoji_type": reaction_type}}))
+        .send()
+        .await;
+    match sent {
+        // A non-2xx answer used to pass unnoticed; only transport errors were logged.
+        Ok(resp) if !resp.status().is_success() => {
+            tracing::warn!(status = %resp.status(), emoji = %emoji, "feishu add_reaction rejected");
+        }
+        Ok(_) => {}
+        Err(e) => tracing::error!(err = %e, "feishu add_reaction failed"),
+    }
+}
+
+/// Remove the bot's own reaction of the type mapped from `emoji` on `message_id`.
+/// Best-effort: failures are logged and swallowed.
+async fn remove_reaction(adapter: &FeishuAdapter, message_id: &str, emoji: &str) {
+    let reaction_type = match emoji_to_feishu_reaction(emoji) {
+        Some(r) => r,
+        None => return,
+    };
+    let token = match adapter.token_cache.get_token(&adapter.client).await {
+        Ok(t) => t,
+        Err(e) => {
+            tracing::error!(err = %e, "feishu: cannot get token for reaction");
+            return;
+        }
+    };
+    // Feishu remove reaction needs reaction_id. Simpler approach: delete by type.
+    // GET reactions, find matching, DELETE by id.
+    let list_url = format!(
+        "{}/open-apis/im/v1/messages/{}/reactions?reaction_type={}",
+        adapter.config.api_base(), message_id, reaction_type
+    );
+    let resp = match adapter.client.get(&list_url).bearer_auth(&token).send().await {
+        Ok(r) => r,
+        Err(e) => {
+            tracing::warn!(err = %e, "feishu remove_reaction: list request failed");
+            return;
+        }
+    };
+    let body: serde_json::Value = match resp.json().await {
+        Ok(v) => v,
+        Err(e) => {
+            tracing::warn!(err = %e, "feishu remove_reaction: list response not valid JSON");
+            return;
+        }
+    };
+    let Some(items) = body.pointer("/data/items").and_then(|v| v.as_array()) else {
+        return;
+    };
+    // Copy the id out so the read guard is released before the DELETE request;
+    // holding it across the network call would stall any writer for that long.
+    let bot_id = adapter.bot_open_id.read().await.clone();
+    // Find our bot's reaction_id
+    // NOTE(review): when the bot id is unknown (None), an item without an
+    // operator open_id compares equal (None == None) — confirm this is intended.
+    for item in items {
+        let is_ours = item.pointer("/operator/operator_id/open_id")
+            .and_then(|v| v.as_str()) == bot_id.as_deref();
+        if !is_ours {
+            continue;
+        }
+        if let Some(reaction_id) = item.get("reaction_id").and_then(|v| v.as_str()) {
+            let del_url = format!(
+                "{}/open-apis/im/v1/messages/{}/reactions/{}",
+                adapter.config.api_base(), message_id, reaction_id
+            );
+            if let Err(e) = adapter.client.delete(&del_url).bearer_auth(&token).send().await {
+                tracing::warn!(err = %e, "feishu remove_reaction: delete request failed");
+            }
+            return;
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Reply handler
+// ---------------------------------------------------------------------------
+
+/// Check if the bot has participated in the thread referenced by this envelope.
+/// Returns `true` if the message is in a thread and that thread has a valid
+/// (non-expired) participation entry in the cache.
+fn check_thread_participated(
+    envelope: &FeishuEventEnvelope,
+    cache: &Arc<std::sync::Mutex<HashMap<String, Instant>>>,
+    session_ttl_secs: u64,
+) -> bool {
+    envelope
+        .event
+        .as_ref()
+        .and_then(|e| e.message.as_ref())
+        // The thread key is root_id when present, otherwise parent_id.
+        .and_then(|m| m.root_id.as_deref().or(m.parent_id.as_deref()))
+        .map(|tid| {
+            // Intentionally recover from poisoned mutex — cache data loss is acceptable
+            // and preferable to panicking the gateway.
+            let c = cache.lock().unwrap_or_else(|e| e.into_inner());
+            c.get(tid).is_some_and(|ts| ts.elapsed().as_secs() < session_ttl_secs)
+        })
+        .unwrap_or(false)
+}
+
+/// Max entries before eviction.
+/// Shared by both `participated_threads` and
+/// `multibot_threads` caches — they have the same cardinality (one entry per
+/// active thread) so a single limit is appropriate for both.
+const PARTICIPATION_CACHE_MAX: usize = 1000;
+
+/// Detect if a message @mentions another bot in a participated thread, and if
+/// so, mark the thread in the multibot cache. Returns whether @mention gating
+/// should be bypassed, respecting the configured `allow_user_messages` mode.
+///
+/// This consolidates the duplicated multibot detection logic used by both the
+/// WebSocket and webhook paths.
+fn detect_and_mark_multibot(
+    envelope: &FeishuEventEnvelope,
+    bot_open_id: Option<&str>,
+    config: &FeishuConfig,
+    participated_threads: &Arc<std::sync::Mutex<HashMap<String, Instant>>>,
+    multibot_threads: &Arc<std::sync::Mutex<HashMap<String, Instant>>>,
+) -> bool {
+    let self_participated = check_thread_participated(
+        envelope, participated_threads, config.session_ttl_secs,
+    );
+
+    let message = envelope.event.as_ref().and_then(|e| e.message.as_ref());
+    let thread_id_for_check =
+        message.and_then(|m| m.root_id.as_deref().or(m.parent_id.as_deref()));
+
+    // Early multibot detection: if a message in a participated thread @mentions
+    // another bot, mark the thread as multibot immediately.
+    if let Some(tid) = thread_id_for_check {
+        if self_participated {
+            if let Some(mention_list) = message.and_then(|m| m.mentions.as_ref()) {
+                let bot_self_id = bot_open_id.unwrap_or("");
+                let mention_ids: Vec<_> = mention_list.iter().filter_map(|m| {
+                    m.id.as_ref().and_then(|id| id.open_id.as_deref())
+                }).collect();
+
+                // With an explicit trusted-bot list, only those ids count as
+                // bots. Otherwise, when a user allow-list exists, any mention
+                // that is neither this bot nor an allowed user counts as one.
+                let mentions_other_bot = if !config.trusted_bot_ids.is_empty() {
+                    mention_ids.iter().any(|oid| {
+                        config.trusted_bot_ids.iter().any(|bid| bid == oid)
+                    })
+                } else if !config.allowed_users.is_empty() {
+                    mention_ids.iter().any(|oid| {
+                        *oid != bot_self_id && !config.allowed_users.iter().any(|u| u == oid)
+                    })
+                } else {
+                    false
+                };
+
+                if mentions_other_bot {
+                    info!(thread_id = %tid, "multibot thread detected via @mention");
+                    let mut cache = multibot_threads.lock().unwrap_or_else(|e| e.into_inner());
+                    // Keep the first-detection timestamp while it is still
+                    // valid, but replace an expired one. A plain
+                    // `or_insert_with` left stale entries in place, so a
+                    // re-detected thread stayed "expired" until eviction.
+                    let stale = cache
+                        .get(tid)
+                        .is_none_or(|ts| ts.elapsed().as_secs() >= config.session_ttl_secs);
+                    if stale {
+                        cache.insert(tid.to_string(), Instant::now());
+                    }
+                    if cache.len() > PARTICIPATION_CACHE_MAX {
+                        cache.retain(|_, ts| ts.elapsed().as_secs() < config.session_ttl_secs);
+                    }
+                }
+            }
+        }
+    }
+
+    // Compute bypass_mention_gating based on mode
+    match config.allow_user_messages {
+        AllowUsers::Mentions => false,
+        AllowUsers::Involved => self_participated,
+        AllowUsers::MultibotMentions => {
+            if !self_participated {
+                false
+            } else {
+                // Bypass only while the thread has no live multibot mark.
+                thread_id_for_check
+                    .map(|tid| {
+                        let cache = multibot_threads.lock().unwrap_or_else(|e| e.into_inner());
+                        cache
+                            .get(tid)
+                            .is_none_or(|ts| ts.elapsed().as_secs() >= config.session_ttl_secs)
+                    })
+                    .unwrap_or(true)
+            }
+        }
+    }
+}
+
+/// Record that the bot has participated in a thread. Evicts oldest entries
+/// when the cache exceeds PARTICIPATION_CACHE_MAX.
+fn record_participation(
+    cache: &Arc<std::sync::Mutex<HashMap<String, Instant>>>,
+    thread_id: &str,
+    session_ttl_secs: u64,
+) {
+    if session_ttl_secs == 0 {
+        return; // Participation tracking disabled
+    }
+    // Intentionally recover from poisoned mutex — cache data loss is acceptable
+    // and preferable to panicking the gateway.
+    let mut map = cache.lock().unwrap_or_else(|e| e.into_inner());
+    map.insert(thread_id.to_string(), Instant::now());
+    // Evict if over capacity: first drop expired entries, then oldest half if still over
+    if map.len() > PARTICIPATION_CACHE_MAX {
+        map.retain(|_, ts| ts.elapsed().as_secs() < session_ttl_secs);
+        if map.len() > PARTICIPATION_CACHE_MAX {
+            let mut entries: Vec<_> = map.iter().map(|(k, v)| (k.clone(), *v)).collect();
+            entries.sort_by_key(|(_, ts)| *ts);
+            let evict_count = entries.len() / 2;
+            for (k, _) in entries.into_iter().take(evict_count) {
+                map.remove(&k);
+            }
+        }
+    }
+}
+
+/// Report a command/send outcome back to core over `event_tx`.
+///
+/// Does nothing when `reply.request_id` is `None` (fire-and-forget requests).
+/// Serialization and broadcast errors are ignored, as they were at every
+/// former inline copy of this code.
+fn emit_gateway_response(
+    reply: &GatewayReply,
+    event_tx: &tokio::sync::broadcast::Sender<String>,
+    success: bool,
+    message_id: Option<String>,
+    error: Option<String>,
+) {
+    let Some(ref req_id) = reply.request_id else {
+        return;
+    };
+    let resp = crate::schema::GatewayResponse {
+        schema: "openab.gateway.response.v1".into(),
+        request_id: req_id.clone(),
+        success,
+        thread_id: None,
+        message_id,
+        error,
+    };
+    if let Ok(json) = serde_json::to_string(&resp) {
+        let _ = event_tx.send(json);
+    }
+}
+
+/// Handle one reply from core: either execute a command (reactions, edit,
+/// delete) against `reply.reply_to`, or send `reply.content.text` to the chat,
+/// splitting it into chunks when it exceeds the configured message limit.
+pub async fn handle_reply(
+    reply: &GatewayReply,
+    adapter: &FeishuAdapter,
+    event_tx: &tokio::sync::broadcast::Sender<String>,
+) {
+    // Handle reactions — add/remove emoji on the original message
+    if let Some(ref cmd) = reply.command {
+        // Defence-in-depth: every command below interpolates `reply.reply_to`
+        // into a REST URL path (edit/delete → /im/v1/messages/{id}; reactions →
+        // /im/v1/messages/{id}/reactions). Validate the id shape once here, at
+        // the dispatch seam, so a crafted id with URL metacharacters can't alter
+        // request semantics. Trust boundary is the core↔gateway WebSocket, so
+        // this is belt-and-suspenders — but it closes the guard over every
+        // url-path-bearing command instead of just delete.
+        let interpolates_message_id = matches!(
+            cmd.as_str(),
+            "edit_message" | "delete_message" | "add_reaction" | "remove_reaction"
+        );
+        if interpolates_message_id && !is_valid_feishu_message_id(&reply.reply_to) {
+            // "draft" is a known sentinel from core when streaming_placeholder=false;
+            // not a security concern, just a no-op — log at debug to avoid noise.
+            if reply.reply_to == "draft" {
+                tracing::debug!(
+                    command = %cmd,
+                    message_id = %reply.reply_to,
+                    "feishu: skipping command — draft placeholder has no real message_id"
+                );
+            } else {
+                tracing::warn!(
+                    command = %cmd,
+                    message_id = %reply.reply_to,
+                    "feishu: refusing command — message_id failed shape validation"
+                );
+            }
+            emit_gateway_response(
+                reply,
+                event_tx,
+                false,
+                None,
+                Some("invalid message_id format".to_string()),
+            );
+            return;
+        }
+        match cmd.as_str() {
+            "add_reaction" => {
+                add_reaction(adapter, &reply.reply_to, &reply.content.text).await;
+                return;
+            }
+            "remove_reaction" => {
+                remove_reaction(adapter, &reply.reply_to, &reply.content.text).await;
+                return;
+            }
+            "edit_message" => {
+                let outcome = edit_feishu_message(
+                    adapter,
+                    &reply.reply_to,
+                    &reply.content.text,
+                ).await;
+                // Translate outcome → (success, message_id, error). For
+                // CapReached we deliberately do NOT append-new at the gateway
+                // layer (see the rationale on the CapReached arm below); we
+                // signal failure so core's finalize path owns recovery.
+                let (success, message_id, error) = match outcome {
+                    EditOutcome::Edited => {
+                        (true, Some(reply.reply_to.clone()), None)
+                    }
+                    EditOutcome::CapReached => {
+                        // Do NOT append-new fallback at the gateway layer. Core's
+                        // cosmetic streaming loop flushes every ~1500ms — if every
+                        // post-cap edit spawned a new message, the user would be
+                        // spammed with 20+ duplicate continuation messages over the
+                        // remainder of a long reply.
+                        //
+                        // Instead, signal failure so:
+                        //   1. core's mid-stream cosmetic edit loop hits its
+                        //      consecutive-failures break (3 strikes) and stops
+                        //      attempting edits, freezing the placeholder mid-content
+                        //   2. the final delivery path in src/adapter.rs sees the
+                        //      placeholder edit fail and falls back to send_message
+                        //      so the user gets the full reply as a fresh message
+                        //
+                        // Net UX: half-edited placeholder + one complete continuation
+                        // message + ✅ done reaction (vs. today's mid-truncation + 🆗
+                        // false success, or naive append-fallback's 25-message spam).
+                        (
+                            false,
+                            None,
+                            Some("edit_cap_reached".to_string()),
+                        )
+                    }
+                    EditOutcome::Failed(err) => (false, None, Some(err)),
+                };
+                emit_gateway_response(reply, event_tx, success, message_id, error);
+                return;
+            }
+            "create_topic" | "set_reaction" => {
+                tracing::debug!(command = %cmd, "feishu: skipping unsupported command");
+                return;
+            }
+            "delete_message" => {
+                let result = delete_feishu_message(adapter, &reply.reply_to).await;
+                let (success, error) = match result {
+                    Ok(()) => (true, None),
+                    Err(e) => (false, Some(e)),
+                };
+                // Dormant by design: core's delete_message is fire-and-forget
+                // (request_id = None), so this response is currently never
+                // emitted. Kept for symmetry with the other handlers and so
+                // delete becomes observable for free if a future caller (or
+                // another gateway client) sets request_id.
+                emit_gateway_response(reply, event_tx, success, None, error);
+                return;
+            }
+            // Any other command falls through to the normal send path below.
+            _ => {}
+        }
+    }
+
+    let token = match adapter.token_cache.get_token(&adapter.client).await {
+        Ok(t) => t,
+        Err(e) => {
+            tracing::error!(err = %e, "feishu: cannot get token for reply");
+            emit_gateway_response(
+                reply,
+                event_tx,
+                false,
+                None,
+                Some(format!("token error: {e}")),
+            );
+            return;
+        }
+    };
+
+    let api_base = adapter.config.api_base();
+    let text = &reply.content.text;
+    let limit = adapter.config.message_limit;
+    // quote_message_id (agent-controlled reply-to) takes priority over thread_id
+    let reply_target = reply.quote_message_id.as_deref()
+        .or(reply.channel.thread_id.as_deref());
+    let thread_id = reply.channel.thread_id.as_deref();
+
+    // Split long messages; store sent message_ids in dedupe to prevent
+    // self-echo (Feishu pushes bot's own messages back via WebSocket)
+    // Use post (rich text) format for markdown rendering.
+    // When in a thread (thread_id present), use reply API to stay in the same thread.
+    if text.len() <= limit {
+        let result = send_post_message(&adapter.client, &api_base, &token, &reply.channel.id, reply_target, text).await;
+        // Fallback: if quote_message_id caused failure, retry without it
+        let result = if result.is_none() && reply.quote_message_id.is_some() {
+            tracing::warn!(quote_message_id = ?reply.quote_message_id, channel_id = %reply.channel.id, "reply-to failed, falling back to plain send");
+            send_post_message(&adapter.client, &api_base, &token, &reply.channel.id, thread_id, text).await
+        } else {
+            result
+        };
+        match result {
+            Some(msg_id) => {
+                // Called for its side effect: registers our own message id so
+                // the echoed event is recognised as a duplicate.
+                adapter.dedupe.is_duplicate(&msg_id);
+                // Record thread participation for mention bypass
+                if let Some(tid) = thread_id {
+                    record_participation(&adapter.participated_threads, tid, adapter.config.session_ttl_secs);
+                }
+                // Send response with message_id back to OAB core (for streaming edit)
+                emit_gateway_response(reply, event_tx, true, Some(msg_id), None);
+            }
+            None => {
+                // Send failure response so core doesn't wait 5s for timeout
+                emit_gateway_response(
+                    reply,
+                    event_tx,
+                    false,
+                    None,
+                    Some("send_post_message failed".into()),
+                );
+            }
+        }
+    } else {
+        // Track per-chunk success so we can report partial-failure back to core.
+        // Previously this branch returned no GatewayResponse at all and used
+        // "any chunk succeeded" as the success criterion — letting core fall
+        // through to a 5s timeout and silently mark the turn delivered. With
+        // request/response now wired through, we propagate exact health.
+        let chunks: Vec<&str> = split_text(text, limit);
+        let total_chunks = chunks.len();
+        let mut succeeded = 0usize;
+        let mut last_msg_id: Option<String> = None;
+        for chunk in &chunks {
+            if let Some(msg_id) = send_post_message(&adapter.client, &api_base, &token, &reply.channel.id, reply_target, chunk).await {
+                adapter.dedupe.is_duplicate(&msg_id);
+                succeeded += 1;
+                last_msg_id = Some(msg_id);
+            }
+        }
+        // Fallback: if quote_message_id caused all chunks to fail, retry without it
+        if succeeded == 0 && reply.quote_message_id.is_some() {
+            tracing::warn!(quote_message_id = ?reply.quote_message_id, channel_id = %reply.channel.id, "chunked reply-to failed, falling back to plain send");
+            for chunk in &chunks {
+                if let Some(msg_id) = send_post_message(&adapter.client, &api_base, &token, &reply.channel.id, thread_id, chunk).await {
+                    adapter.dedupe.is_duplicate(&msg_id);
+                    succeeded += 1;
+                    last_msg_id = Some(msg_id);
+                }
+            }
+        }
+        if succeeded > 0 {
+            if let Some(tid) = thread_id {
+                record_participation(&adapter.participated_threads, tid, adapter.config.session_ttl_secs);
+            }
+        }
+        // Report back to core. Success requires every chunk delivered — partial
+        // success becomes failure so dispatch surfaces ❌ rather than 🆗.
+        let success = succeeded == total_chunks && total_chunks > 0;
+        let error = if success {
+            None
+        } else {
+            Some(format!(
+                "chunked send delivered {succeeded}/{total_chunks} chunks"
+            ))
+        };
+        emit_gateway_response(reply, event_tx, success, last_msg_id, error);
+    }
+}
+
+/// Split text into chunks of at most `limit` bytes, breaking at newline or
+/// space boundaries when possible. Safe for multi-byte UTF-8 (e.g. Chinese).
+fn split_text(text: &str, limit: usize) -> Vec<&str> {
+    // Chunks are contiguous, non-empty slices of `text`; concatenated in order
+    // they reproduce `text` exactly. An empty `text` yields no chunks.
+    let mut chunks = Vec::new();
+    let mut start = 0;
+    while start < text.len() {
+        if start + limit >= text.len() {
+            chunks.push(&text[start..]);
+            break;
+        }
+        // Find a char-safe boundary at or before start + limit
+        let mut end = start + limit;
+        while !text.is_char_boundary(end) {
+            end -= 1;
+        }
+        if end == start {
+            // `limit` is 0 or smaller than the next character's byte width.
+            // Emit that one character anyway (exceeding `limit`), otherwise
+            // `start` never advances and the loop spins forever.
+            end = start + text[start..].chars().next().map_or(1, char::len_utf8);
+        }
+        // Try to break at a newline or space within the last 200 bytes.
+        // search_start must also be on a char boundary to avoid panic.
+        let mut search_start = if end > start + 200 { end - 200 } else { start };
+        while search_start < end && !text.is_char_boundary(search_start) {
+            search_start += 1;
+        }
+        let window = &text[search_start..end];
+        // Prefer the last newline, then the last space; the separator stays at
+        // the end of the current chunk. With neither, cut hard at `end`.
+        let break_at = window
+            .rfind('\n')
+            .or_else(|| window.rfind(' '))
+            .map(|pos| search_start + pos + 1)
+            .unwrap_or(end);
+        chunks.push(&text[start..break_at]);
+        start = break_at;
+    }
+    chunks
+}
+
+// ---------------------------------------------------------------------------
+// Webhook handler
+// ---------------------------------------------------------------------------
+
+/// Max webhook body size: 1 MB
+const WEBHOOK_BODY_LIMIT: usize = 1_048_576;
+
+/// Simple per-IP rate limiter state.
+///
+/// Each key maps to `(requests seen in the current window, window start)`.
+pub struct RateLimiter {
+    counts: std::sync::Mutex<HashMap<String, (u64, Instant)>>,
+    window_secs: u64,
+    max_requests: u64,
+}
+
+impl RateLimiter {
+    /// Create a limiter allowing `max_requests` per key in each
+    /// `window_secs`-second window.
+    pub fn new(window_secs: u64, max_requests: u64) -> Self {
+        Self {
+            counts: std::sync::Mutex::new(HashMap::new()),
+            window_secs,
+            max_requests,
+        }
+    }
+
+    /// Returns true if the request should be rejected (rate exceeded).
+    pub fn check(&self, key: &str) -> bool {
+        // A poisoned lock is recovered rather than propagated.
+        let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner());
+        // Lazy cleanup
+        if map.len() > 4096 {
+            map.retain(|_, (_, ts)| ts.elapsed().as_secs() < self.window_secs);
+        }
+        let entry = map.entry(key.to_string()).or_insert((0, Instant::now()));
+        if entry.1.elapsed().as_secs() >= self.window_secs {
+            // Window elapsed: start a new one, counting this request as the first.
+            *entry = (1, Instant::now());
+            false
+        } else {
+            entry.0 += 1;
+            entry.0 > self.max_requests
+        }
+    }
+}
+
+/// Verify webhook signature: SHA256(timestamp + nonce + encrypt_key + body).
+///
+/// The digest is rendered as lowercase hex and compared against
+/// `expected_sig` with `constant_time_eq`.
+fn verify_signature(
+    timestamp: &str,
+    nonce: &str,
+    encrypt_key: &str,
+    body: &[u8],
+    expected_sig: &str,
+) -> bool {
+    use sha2::{Digest, Sha256};
+    let mut hasher = Sha256::new();
+    hasher.update(timestamp.as_bytes());
+    hasher.update(nonce.as_bytes());
+    hasher.update(encrypt_key.as_bytes());
+    hasher.update(body);
+    let result = format!("{:x}", hasher.finalize());
+    constant_time_eq(&result, expected_sig)
+}
+
+/// Decrypt AES-CBC encrypted event body.
+/// Feishu uses AES-256-CBC with SHA256(encrypt_key) as key, first 16 bytes of
+/// ciphertext as IV.
+fn decrypt_event(encrypt_key: &str, encrypted: &str) -> anyhow::Result<String> {
+    // `encrypted` is standard base64 of IV || ciphertext. Returns the decrypted
+    // UTF-8 text, or an error naming the step that failed (base64, length,
+    // cipher init, decrypt/unpad, UTF-8).
+    use sha2::{Digest, Sha256};
+    let key = Sha256::digest(encrypt_key.as_bytes());
+    let cipher_bytes = base64::Engine::decode(
+        &base64::engine::general_purpose::STANDARD,
+        encrypted,
+    )
+    .map_err(|e| anyhow::anyhow!("base64 decode failed: {e}"))?;
+
+    // Need at least the 16-byte IV.
+    if cipher_bytes.len() < 16 {
+        anyhow::bail!("encrypted data too short");
+    }
+
+    let iv = &cipher_bytes[..16];
+    let ciphertext = &cipher_bytes[16..];
+
+    // AES-256-CBC decrypt
+    use aes::cipher::{BlockDecryptMut, KeyIvInit};
+    type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
+
+    let decryptor = Aes256CbcDec::new_from_slices(&key, iv)
+        .map_err(|e| anyhow::anyhow!("aes init failed: {e}"))?;
+
+    let mut buf = ciphertext.to_vec();
+    // NOTE(review): the padding type parameter was lost in extraction; PKCS#7
+    // is the scheme Feishu's event encryption is understood to use — confirm.
+    let plaintext = decryptor
+        .decrypt_padded_mut::<aes::cipher::block_padding::Pkcs7>(&mut buf)
+        .map_err(|e| anyhow::anyhow!("aes decrypt failed: {e}"))?;
+
+    String::from_utf8(plaintext.to_vec())
+        .map_err(|e| anyhow::anyhow!("decrypted data not utf8: {e}"))
+}
+
+pub async fn webhook(
+    State(state): State>,
+    headers: axum::http::HeaderMap,
+    body: axum::body::Bytes,
+) -> axum::response::Response {
+    use axum::response::IntoResponse;
+
+    let feishu = match state.feishu.as_ref() {
+        Some(f) => f,
+        None => return axum::http::StatusCode::SERVICE_UNAVAILABLE.into_response(),
+    };
+
+    // Body size limit
+    if body.len() > WEBHOOK_BODY_LIMIT {
+        warn!(size = body.len(), "feishu webhook body too large");
+        return axum::http::StatusCode::PAYLOAD_TOO_LARGE.into_response();
+    }
+
+    // Rate limit (by X-Forwarded-For or fallback)
+    let ip = headers
+        .get("x-forwarded-for")
+        .and_then(|v| v.to_str().ok())
+        .unwrap_or("unknown");
+    if feishu.rate_limiter.check(ip) {
+        return (axum::http::StatusCode::TOO_MANY_REQUESTS, "rate limit exceeded")
+            .into_response();
+    }
+
+    // Signature verification (if encrypt_key configured)
+    if let Some(ref encrypt_key) = feishu.config.encrypt_key {
+        let sig = headers
+            .get("x-lark-signature")
+            .and_then(|v| v.to_str().ok());
let timestamp = headers + .get("x-lark-request-timestamp") + .and_then(|v| v.to_str().ok()); + let nonce = headers + .get("x-lark-request-nonce") + .and_then(|v| v.to_str().ok()); + + match (sig, timestamp, nonce) { + (Some(sig), Some(ts), Some(nonce)) => { + if !verify_signature(ts, nonce, encrypt_key, &body, sig) { + warn!("feishu webhook rejected: invalid signature"); + return axum::http::StatusCode::UNAUTHORIZED.into_response(); + } + } + _ => { + warn!("feishu webhook rejected: missing signature headers"); + return axum::http::StatusCode::UNAUTHORIZED.into_response(); + } + } + } else { + warn!("FEISHU_ENCRYPT_KEY not configured — webhook signature verification is SKIPPED (insecure)"); + } + + // Parse body — may be encrypted + let event_json: serde_json::Value = match serde_json::from_slice(&body) { + Ok(v) => v, + Err(e) => { + warn!(err = %e, "feishu webhook parse error"); + return axum::http::StatusCode::BAD_REQUEST.into_response(); + } + }; + + // Handle encrypted events + let event_json = if let Some(encrypted) = event_json.get("encrypt").and_then(|v| v.as_str()) { + let encrypt_key = match feishu.config.encrypt_key.as_deref() { + Some(k) => k, + None => { + warn!("feishu webhook: encrypted event but no FEISHU_ENCRYPT_KEY configured"); + return axum::http::StatusCode::BAD_REQUEST.into_response(); + } + }; + match decrypt_event(encrypt_key, encrypted) { + Ok(decrypted) => match serde_json::from_str(&decrypted) { + Ok(v) => v, + Err(e) => { + warn!(err = %e, "feishu webhook: decrypted event parse error"); + return axum::http::StatusCode::BAD_REQUEST.into_response(); + } + }, + Err(e) => { + warn!(err = %e, "feishu webhook: decrypt failed"); + return axum::http::StatusCode::BAD_REQUEST.into_response(); + } + } + } else { + event_json + }; + + // URL verification challenge + if event_json.get("challenge").is_some() { + // Verify token if configured + if let Some(ref expected_token) = feishu.config.verification_token { + let token = 
event_json.get("token").and_then(|v| v.as_str()); + match token { + Some(t) if constant_time_eq(t, expected_token) => {} + _ => { + warn!("feishu webhook: URL verification token mismatch"); + return axum::http::StatusCode::UNAUTHORIZED.into_response(); + } + } + } + let challenge = event_json["challenge"].as_str().unwrap_or(""); + return axum::Json(serde_json::json!({"challenge": challenge})).into_response(); + } + + // Verification token check for regular events + if let Some(ref expected_token) = feishu.config.verification_token { + let token = event_json + .pointer("/header/token") + .or_else(|| event_json.get("token")) + .and_then(|v| v.as_str()); + match token { + Some(t) if constant_time_eq(t, expected_token) => {} + _ => { + warn!("feishu webhook rejected: invalid verification token"); + return axum::http::StatusCode::UNAUTHORIZED.into_response(); + } + } + } + + // Parse as event envelope + let envelope: FeishuEventEnvelope = match serde_json::from_value(event_json) { + Ok(e) => e, + Err(e) => { + warn!(err = %e, "feishu webhook: event envelope parse error"); + return axum::http::StatusCode::OK.into_response(); + } + }; + + // Dedupe + parse + broadcast (same logic as WebSocket handler) + if let Some(ref header) = envelope.header { + if let Some(ref event_id) = header.event_id { + if feishu.dedupe.is_duplicate(event_id) { + return axum::http::StatusCode::OK.into_response(); + } + } + } + + let bot_id = feishu.bot_open_id.read().await; + let bot_id_ref = bot_id.as_deref(); + + // Check participated threads and multibot detection for mention bypass + let bypass_mention = detect_and_mark_multibot( + &envelope, bot_id_ref, &feishu.config, + &feishu.participated_threads, &feishu.multibot_threads, + ); + + if let Some((mut gateway_event, media_refs)) = parse_message_event(&envelope, bot_id_ref, &feishu.config, bypass_mention) { + if !feishu.dedupe.is_duplicate(&gateway_event.message_id) { + let name = resolve_user_name( + &gateway_event.sender.id, 
&feishu.name_cache, &feishu.token_cache, + &feishu.client, &feishu.config.api_base(), + ).await; + gateway_event.sender.name = name.clone(); + gateway_event.sender.display_name = name; + + // Download media attachments + if !media_refs.is_empty() { + if let Ok(token) = feishu.token_cache.get_token(&feishu.client).await { + let api_base = feishu.config.api_base(); + for media_ref in &media_refs { + let attachment = match media_ref { + MediaRef::Image { message_id, image_key } => { + download_feishu_image(&feishu.client, &api_base, &token, message_id, image_key).await + } + MediaRef::File { message_id, file_key, file_name } => { + download_feishu_file(&feishu.client, &api_base, &token, message_id, file_key, file_name).await + } + MediaRef::Audio { message_id, file_key } => { + download_feishu_audio(&feishu.client, &api_base, &token, message_id, file_key).await + } + }; + if let Some(att) = attachment { + gateway_event.content.attachments.push(att); + } + } + } + } + + // Skip if no text and no attachments (e.g. 
unsupported file type) + if gateway_event.content.text.trim().is_empty() && gateway_event.content.attachments.is_empty() { + return axum::http::StatusCode::OK.into_response(); + } + + let json = serde_json::to_string(&gateway_event).unwrap(); + info!( + channel = %gateway_event.channel.id, + sender = %gateway_event.sender.id, + "feishu webhook → gateway" + ); + let _ = state.event_tx.send(json); + } + } + + axum::http::StatusCode::OK.into_response() +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use wiremock::matchers::{body_json, header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + fn test_config() -> FeishuConfig { + FeishuConfig { + app_id: "cli_test".into(), + app_secret: "secret_test".into(), + domain: "feishu".into(), + connection_mode: ConnectionMode::Websocket, + webhook_path: "/webhook/feishu".into(), + verification_token: None, + encrypt_key: None, + allowed_groups: vec![], + allowed_users: vec![], + require_mention: true, + allow_bots: AllowBots::Off, + allow_user_messages: AllowUsers::MultibotMentions, + trusted_bot_ids: vec![], + max_bot_turns: 20, + dedupe_ttl_secs: 300, + message_limit: 4000, + session_ttl_secs: 86400, + api_base_override: None, + } + } + + // --- Token tests --- + + #[tokio::test] + async fn token_refresh_success() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/open-apis/auth/v3/tenant_access_token/internal")) + .and(body_json(serde_json::json!({ + "app_id": "cli_test", + "app_secret": "secret_test", + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "msg": "ok", + "tenant_access_token": "t-test-token-123", + "expire": 7200 + }))) + .expect(1) + .mount(&server) + .await; + + let config = test_config(); + let cache = 
FeishuTokenCache::with_base(&config, &server.uri()); + let client = reqwest::Client::new(); + + let token = cache.get_token(&client).await.unwrap(); + assert_eq!(token, "t-test-token-123"); + + // Second call should use cache, not hit server again (expect(1) above) + let token2 = cache.get_token(&client).await.unwrap(); + assert_eq!(token2, "t-test-token-123"); + } + + #[tokio::test] + async fn token_refresh_api_error() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/open-apis/auth/v3/tenant_access_token/internal")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 10003, + "msg": "invalid app_secret", + }))) + .expect(1) + .mount(&server) + .await; + + let config = test_config(); + let cache = FeishuTokenCache::with_base(&config, &server.uri()); + let client = reqwest::Client::new(); + + let err = cache.get_token(&client).await.unwrap_err(); + let msg = err.to_string(); + assert!(msg.contains("10003"), "error should contain code: {msg}"); + assert!( + !msg.contains("secret_test"), + "error must not leak secret: {msg}" + ); + } + + // --- Send message tests --- + + #[tokio::test] + async fn send_text_message_success() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/open-apis/im/v1/messages")) + .and(header("authorization", "Bearer t-tok")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "msg": "success", + "data": {"message_id": "om_test123"} + }))) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let msg_id = send_text_message(&client, &server.uri(), "t-tok", "oc_chat1", "hello").await; + assert_eq!(msg_id.as_deref(), Some("om_test123")); + } + + #[tokio::test] + async fn send_text_message_api_failure() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/open-apis/im/v1/messages")) + 
.respond_with(ResponseTemplate::new(400).set_body_string("bad request")) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let msg_id = send_text_message(&client, &server.uri(), "t-tok", "oc_chat1", "hello").await; + assert!(msg_id.is_none()); + } + + #[tokio::test] + async fn send_text_message_invalid_json_returns_none() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/open-apis/im/v1/messages")) + .respond_with(ResponseTemplate::new(200).set_body_string("not json")) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let msg_id = send_text_message(&client, &server.uri(), "t-tok", "oc_chat1", "hello").await; + assert!(msg_id.is_none()); + } + + #[tokio::test] + async fn send_text_message_missing_message_id_returns_none() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/open-apis/im/v1/messages")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "msg": "success", + }))) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let msg_id = send_text_message(&client, &server.uri(), "t-tok", "oc_chat1", "hello").await; + assert!(msg_id.is_none()); + } + + // --- Split text tests --- + + #[test] + fn split_text_short() { + let chunks = split_text("hello", 100); + assert_eq!(chunks, vec!["hello"]); + } + + #[test] + fn split_text_exact_limit() { + let text = "a".repeat(100); + let chunks = split_text(&text, 100); + assert_eq!(chunks.len(), 1); + } + + #[test] + fn split_text_chinese_utf8_safe() { + // Each Chinese char is 3 bytes. 20 chars = 60 bytes. + // Limit 10 would land mid-char without boundary check. 
+ let text = "你好世界測試飛書中文聊天消息分割安全驗證完成"; + let chunks = split_text(text, 10); + assert!(chunks.len() > 1); + let reassembled: String = chunks.concat(); + assert_eq!(reassembled, text); + } + + #[test] + fn split_text_search_start_char_boundary() { + // Regression: search_start = end - 200 could land mid-char. + // 300 Chinese chars (900 bytes) with limit=500 forces search_start + // into the middle of multi-byte chars. + let text: String = "飛書".repeat(150); // 300 chars, 900 bytes + let chunks = split_text(&text, 500); + assert!(chunks.len() >= 2); + let reassembled: String = chunks.concat(); + assert_eq!(reassembled, text); + } + + #[test] + fn split_text_long_breaks_at_newline() { + let text = format!("{}\n{}", "a".repeat(50), "b".repeat(50)); + let chunks = split_text(&text, 60); + assert_eq!(chunks.len(), 2); + assert!(chunks[0].ends_with('\n')); + } + + // --- Event parsing tests --- + + fn make_envelope( + chat_type: &str, + text: &str, + sender_open_id: &str, + mentions: Option>, + ) -> FeishuEventEnvelope { + FeishuEventEnvelope { + header: Some(FeishuEventHeader { + event_id: Some("evt_test".into()), + event_type: Some("im.message.receive_v1".into()), + }), + event: Some(FeishuEventBody { + sender: Some(FeishuSender { + sender_id: Some(FeishuSenderId { + open_id: Some(sender_open_id.into()), + }), + sender_type: Some("user".into()), + }), + message: Some(FeishuMessage { + message_id: Some("om_msg1".into()), + chat_id: Some("oc_chat1".into()), + chat_type: Some(chat_type.into()), + message_type: Some("text".into()), + content: Some(serde_json::json!({"text": text}).to_string()), + mentions, + root_id: None, + parent_id: None, + }), + }), + challenge: None, + event_type_field: None, + } + } + + #[test] + fn parse_dm_text() { + let env = make_envelope("p2p", "hello", "ou_user1", None); + let cfg = test_config(); + let (evt, _media) = parse_message_event(&env, Some("ou_bot"), &cfg, false).unwrap(); + assert_eq!(evt.platform, "feishu"); + 
assert_eq!(evt.channel.channel_type, "direct"); + assert_eq!(evt.channel.id, "oc_chat1"); + assert_eq!(evt.sender.id, "ou_user1"); + assert_eq!(evt.content.text, "hello"); + assert!(evt.mentions.is_empty()); + } + + #[test] + fn parse_group_with_bot_mention() { + let mentions = vec![FeishuMention { + key: Some("@_user_1".into()), + id: Some(FeishuMentionId { + open_id: Some("ou_bot".into()), + }), + name: Some("Bot".into()), + }]; + let env = make_envelope("group", "@_user_1 explain VPC", "ou_user1", Some(mentions)); + let cfg = test_config(); + let (evt, _media) = parse_message_event(&env, Some("ou_bot"), &cfg, false).unwrap(); + assert_eq!(evt.channel.channel_type, "group"); + assert_eq!(evt.content.text, "explain VPC"); + assert_eq!(evt.mentions, vec!["ou_bot"]); + } + + #[test] + fn parse_group_without_mention_filtered() { + let env = make_envelope("group", "just chatting", "ou_user1", None); + let cfg = test_config(); // require_mention = true + // Gateway-side mention gating: group message without bot mention is filtered + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); + } + + #[test] + fn parse_group_without_mention_allowed_when_disabled() { + let env = make_envelope("group", "just chatting", "ou_user1", None); + let mut cfg = test_config(); + cfg.require_mention = false; + let evt = parse_message_event(&env, Some("ou_bot"), &cfg, false); + assert!(evt.is_some()); + } + + #[test] + fn parse_skips_bot_sender() { + let mut env = make_envelope("p2p", "hello", "ou_bot", None); + env.event.as_mut().unwrap().sender.as_mut().unwrap().sender_type = Some("bot".into()); + let cfg = test_config(); + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); + } + + #[test] + fn parse_skips_empty_text() { + let env = make_envelope("p2p", " ", "ou_user1", None); + let cfg = test_config(); + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); + } + + #[test] + fn parse_skips_non_text_message() { + let mut 
env = make_envelope("p2p", "hello", "ou_user1", None); + env.event.as_mut().unwrap().message.as_mut().unwrap().message_type = Some("sticker".into()); + let cfg = test_config(); + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); + } + + #[test] + fn parse_skips_self_message() { + let env = make_envelope("p2p", "hello", "ou_bot", None); + let cfg = test_config(); + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); + } + + // --- Dedupe tests --- + + #[test] + fn dedupe_first_is_not_duplicate() { + let cache = DedupeCache::new(300); + assert!(!cache.is_duplicate("msg_1")); + } + + #[test] + fn dedupe_second_is_duplicate() { + let cache = DedupeCache::new(300); + assert!(!cache.is_duplicate("msg_1")); + assert!(cache.is_duplicate("msg_1")); + } + + // --- Webhook security tests --- + + #[test] + fn verify_signature_correct() { + use sha2::{Digest, Sha256}; + let ts = "1234567890"; + let nonce = "abc"; + let key = "mykey"; + let body = b"hello"; + let mut hasher = Sha256::new(); + hasher.update(ts.as_bytes()); + hasher.update(nonce.as_bytes()); + hasher.update(key.as_bytes()); + hasher.update(body); + let expected = format!("{:x}", hasher.finalize()); + assert!(verify_signature(ts, nonce, key, body, &expected)); + } + + #[test] + fn verify_signature_wrong() { + assert!(!verify_signature("ts", "nonce", "key", b"body", "bad_sig")); + } + + #[test] + fn rate_limiter_allows_within_limit() { + let rl = RateLimiter::new(60, 3); + assert!(!rl.check("ip1")); + assert!(!rl.check("ip1")); + assert!(!rl.check("ip1")); + } + + #[test] + fn rate_limiter_rejects_over_limit() { + let rl = RateLimiter::new(60, 2); + assert!(!rl.check("ip1")); + assert!(!rl.check("ip1")); + assert!(rl.check("ip1")); // 3rd request exceeds limit of 2 + } + + // --- Name resolution tests --- + + #[tokio::test] + async fn resolve_user_name_success_and_cache() { + let server = MockServer::start().await; + Mock::given(method("POST")) + 
.and(path("/open-apis/auth/v3/tenant_access_token/internal")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, "tenant_access_token": "t-tok", "expire": 7200 + }))) + .mount(&server) + .await; + Mock::given(method("GET")) + .and(path("/open-apis/contact/v3/users/ou_user1")) + .and(header("authorization", "Bearer t-tok")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "data": { "user": { "name": "Alice", "open_id": "ou_user1" } } + }))) + .expect(1) // should only be called once (cached on second call) + .mount(&server) + .await; + + let config = test_config(); + let token_cache = Arc::new(FeishuTokenCache::with_base(&config, &server.uri())); + let name_cache = Arc::new(std::sync::Mutex::new(HashMap::new())); + let client = reqwest::Client::new(); + + let name = resolve_user_name("ou_user1", &name_cache, &token_cache, &client, &server.uri()).await; + assert_eq!(name, "Alice"); + + // Second call should use cache (expect(1) above ensures no second API call) + let name2 = resolve_user_name("ou_user1", &name_cache, &token_cache, &client, &server.uri()).await; + assert_eq!(name2, "Alice"); + } + + #[tokio::test] + async fn resolve_user_name_api_error_falls_back_to_open_id() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/open-apis/auth/v3/tenant_access_token/internal")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, "tenant_access_token": "t-tok", "expire": 7200 + }))) + .mount(&server) + .await; + Mock::given(method("GET")) + .and(path("/open-apis/contact/v3/users/ou_unknown")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 40003, "msg": "user not found" + }))) + .mount(&server) + .await; + + let config = test_config(); + let token_cache = Arc::new(FeishuTokenCache::with_base(&config, &server.uri())); + let name_cache = 
Arc::new(std::sync::Mutex::new(HashMap::new())); + let client = reqwest::Client::new(); + + let name = resolve_user_name("ou_unknown", &name_cache, &token_cache, &client, &server.uri()).await; + assert_eq!(name, "ou_unknown"); + } + + // --- extract_mentions tests --- + + #[test] + fn extract_mentions_replacen_only_first() { + // If mention key appears in normal text too, only the first occurrence is removed + let mentions = vec![FeishuMention { + key: Some("@_user_1".into()), + id: Some(FeishuMentionId { open_id: Some("ou_bot".into()) }), + name: Some("Bot".into()), + }]; + let env = make_envelope("group", "@_user_1 tell me about @_user_1 patterns", "ou_user1", Some(mentions)); + let cfg = test_config(); + let (evt, _media) = parse_message_event(&env, Some("ou_bot"), &cfg, false).unwrap(); + // Only first @_user_1 removed, second preserved + assert!(evt.content.text.contains("@_user_1")); + } + + // --- allowed_users filtering --- + + #[test] + fn parse_allowed_users_blocks_unlisted() { + let env = make_envelope("p2p", "hello", "ou_stranger", None); + let mut cfg = test_config(); + cfg.allowed_users = vec!["ou_vip".into()]; + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); + } + + #[test] + fn parse_allowed_users_permits_listed() { + let env = make_envelope("p2p", "hello", "ou_vip", None); + let mut cfg = test_config(); + cfg.allowed_users = vec!["ou_vip".into()]; + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_some()); + } + + // --- allowed_groups filtering --- + + #[test] + fn parse_allowed_groups_blocks_unlisted() { + let mentions = vec![FeishuMention { + key: Some("@_user_1".into()), + id: Some(FeishuMentionId { open_id: Some("ou_bot".into()) }), + name: Some("Bot".into()), + }]; + let env = make_envelope("group", "@_user_1 hello", "ou_user1", Some(mentions)); + let mut cfg = test_config(); + cfg.allowed_groups = vec!["oc_other".into()]; // oc_chat1 not in list + assert!(parse_message_event(&env, Some("ou_bot"), 
&cfg, false).is_none()); + } + + #[test] + fn parse_allowed_groups_permits_listed() { + let mentions = vec![FeishuMention { + key: Some("@_user_1".into()), + id: Some(FeishuMentionId { open_id: Some("ou_bot".into()) }), + name: Some("Bot".into()), + }]; + let env = make_envelope("group", "@_user_1 hello", "ou_user1", Some(mentions)); + let mut cfg = test_config(); + cfg.allowed_groups = vec!["oc_chat1".into()]; + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_some()); + } + + // --- Token TTL from API response --- + + #[tokio::test] + async fn token_uses_api_expire_field() { + let server = MockServer::start().await; + // Return a short expire (10s). With 300s margin, token should be + // considered expired immediately on second call. + Mock::given(method("POST")) + .and(path("/open-apis/auth/v3/tenant_access_token/internal")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "tenant_access_token": "t-short", + "expire": 10 + }))) + .expect(2) // called twice because 10s < 300s margin → always expired + .mount(&server) + .await; + + let config = test_config(); + let cache = FeishuTokenCache::with_base(&config, &server.uri()); + let client = reqwest::Client::new(); + + let t1 = cache.get_token(&client).await.unwrap(); + assert_eq!(t1, "t-short"); + // Second call should refresh (expire=10 < margin=300) + let t2 = cache.get_token(&client).await.unwrap(); + assert_eq!(t2, "t-short"); + // expect(2) verifies it was called twice + } + + // --- constant_time_eq --- + + #[test] + fn constant_time_eq_same() { + assert!(constant_time_eq("abc123", "abc123")); + } + + #[test] + fn constant_time_eq_different() { + assert!(!constant_time_eq("abc123", "abc124")); + } + + #[test] + fn constant_time_eq_different_length() { + assert!(!constant_time_eq("short", "longer_string")); + } + + // --- Thread ID parsing --- + + #[test] + fn parse_thread_id_from_root_id() { + let mut env = make_envelope("p2p", "reply", "ou_user1", 
None); + env.event.as_mut().unwrap().message.as_mut().unwrap().root_id = Some("om_root".into()); + let cfg = test_config(); + let (evt, _media) = parse_message_event(&env, Some("ou_bot"), &cfg, false).unwrap(); + assert_eq!(evt.channel.thread_id, Some("om_root".into())); + } + + #[test] + fn parse_thread_id_from_parent_id() { + let mut env = make_envelope("p2p", "reply", "ou_user1", None); + env.event.as_mut().unwrap().message.as_mut().unwrap().parent_id = Some("om_parent".into()); + let cfg = test_config(); + let (evt, _media) = parse_message_event(&env, Some("ou_bot"), &cfg, false).unwrap(); + assert_eq!(evt.channel.thread_id, Some("om_parent".into())); + } + + // --- Emoji reaction mapping --- + + #[test] + fn emoji_mapping_known() { + assert_eq!(emoji_to_feishu_reaction("👍"), Some("THUMBSUP")); + assert_eq!(emoji_to_feishu_reaction("🔥"), Some("FIRE")); + assert_eq!(emoji_to_feishu_reaction("👀"), Some("EYES")); + } + + #[test] + fn emoji_mapping_unknown() { + assert_eq!(emoji_to_feishu_reaction("🎉"), None); + } + + // --- Participated thread tests --- + + #[test] + fn participated_thread_bypasses_mention_gating() { + let cfg = test_config(); // require_mention = true + // Build envelope with root_id (in a thread) + let mut env = make_envelope("group", "Hello", "ou_user1", None); + env.event.as_mut().unwrap().message.as_mut().unwrap().root_id = Some("root_123".into()); + // Without participation: no @mention → None + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); + // With participation: no @mention → Some (bypass) + let result = parse_message_event(&env, Some("ou_bot"), &cfg, true); + assert!(result.is_some()); + let (evt, _) = result.unwrap(); + assert_eq!(evt.channel.thread_id.as_deref(), Some("root_123")); + } + + #[test] + fn participated_no_effect_without_thread() { + let cfg = test_config(); // require_mention = true + // Message in main channel (no thread_id) — participated flag doesn't help + let env = make_envelope("group", 
"Hello", "ou_user1", None); + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, true).is_none()); + } + + #[test] + fn record_participation_and_eviction() { + let cache = Arc::new(std::sync::Mutex::new(HashMap::new())); + // Record a thread + record_participation(&cache, "thread_1", 86400); + assert_eq!(cache.lock().unwrap().len(), 1); + // Fill beyond PARTICIPATION_CACHE_MAX + for i in 0..PARTICIPATION_CACHE_MAX + 10 { + record_participation(&cache, &format!("thread_{i}"), 86400); + } + // After eviction, should be roughly half + assert!(cache.lock().unwrap().len() <= PARTICIPATION_CACHE_MAX); + } + + // --- Multibot-mentions mode tests --- + + #[test] + fn multibot_mentions_mode_bypasses_when_single_bot() { + let mut cfg = test_config(); + cfg.allow_user_messages = AllowUsers::MultibotMentions; + let mut env = make_envelope("group", "Hello", "ou_user1", None); + env.event.as_mut().unwrap().message.as_mut().unwrap().root_id = Some("root_456".into()); + // participated + no other bot → bypass_mention_gating=true + let result = parse_message_event(&env, Some("ou_bot"), &cfg, true); + assert!(result.is_some()); + } + + #[test] + fn multibot_mentions_mode_requires_mention_when_not_participated() { + let mut cfg = test_config(); + cfg.allow_user_messages = AllowUsers::MultibotMentions; + let mut env = make_envelope("group", "Hello", "ou_user1", None); + env.event.as_mut().unwrap().message.as_mut().unwrap().root_id = Some("root_456".into()); + // not participated → bypass_mention_gating=false + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); + } + + #[test] + fn mentions_mode_never_bypasses() { + let mut cfg = test_config(); + cfg.allow_user_messages = AllowUsers::Mentions; + let mut env = make_envelope("group", "Hello", "ou_user1", None); + env.event.as_mut().unwrap().message.as_mut().unwrap().root_id = Some("root_789".into()); + // Even with bypass_mention_gating=true, Mentions mode never bypasses + // (caller would pass false because 
Mentions mode always returns false) + assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); + } + + #[test] + fn quote_message_id_takes_priority_over_thread_id() { + use crate::schema::{GatewayReply, ReplyChannel, Content}; + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "evt_123".into(), + platform: "feishu".into(), + channel: ReplyChannel { + id: "chat_123".into(), + thread_id: Some("om_root".into()), + }, + content: Content { + content_type: "text".into(), + text: "hello".into(), + attachments: vec![], + }, + command: None, + request_id: None, + quote_message_id: Some("om_specific".into()), + }; + // quote_message_id should take priority + let reply_target = reply.quote_message_id.as_deref() + .or(reply.channel.thread_id.as_deref()); + assert_eq!(reply_target, Some("om_specific")); + } + + #[test] + fn reply_target_falls_back_to_thread_id_when_no_quote() { + use crate::schema::{GatewayReply, ReplyChannel, Content}; + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "evt_123".into(), + platform: "feishu".into(), + channel: ReplyChannel { + id: "chat_123".into(), + thread_id: Some("om_root".into()), + }, + content: Content { + content_type: "text".into(), + text: "hello".into(), + attachments: vec![], + }, + command: None, + request_id: None, + quote_message_id: None, + }; + let reply_target = reply.quote_message_id.as_deref() + .or(reply.channel.thread_id.as_deref()); + assert_eq!(reply_target, Some("om_root")); + } + + #[test] + fn reply_target_is_none_when_both_absent() { + use crate::schema::{GatewayReply, ReplyChannel, Content}; + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "evt_123".into(), + platform: "feishu".into(), + channel: ReplyChannel { + id: "chat_123".into(), + thread_id: None, + }, + content: Content { + content_type: "text".into(), + text: "hello".into(), + attachments: vec![], + }, + command: None, + request_id: 
None, + quote_message_id: None, + }; + let reply_target = reply.quote_message_id.as_deref() + .or(reply.channel.thread_id.as_deref()); + assert_eq!(reply_target, None); + } + + #[tokio::test] + async fn quote_message_id_fallback_on_reply_failure() { + // Tests the actual handle_reply fallback path: when quote_message_id + // is set and the reply API fails, handle_reply retries as plain send. + let server = MockServer::start().await; + + // Token endpoint + Mock::given(method("POST")) + .and(path("/open-apis/auth/v3/tenant_access_token/internal")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "tenant_access_token": "t-test", + "expire": 7200 + }))) + .mount(&server) + .await; + + // Reply API returns 400 (invalid quote_message_id) + Mock::given(method("POST")) + .and(path("/open-apis/im/v1/messages/om_invalid/reply")) + .respond_with(ResponseTemplate::new(400).set_body_string("invalid message_id")) + .expect(1) + .named("reply_api_fail") + .mount(&server) + .await; + + // Plain send endpoint succeeds (fallback path) + Mock::given(method("POST")) + .and(path("/open-apis/im/v1/messages")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "data": {"message_id": "om_fallback_ok"} + }))) + .expect(1) + .named("plain_send_fallback") + .mount(&server) + .await; + + let mut config = test_config(); + config.api_base_override = Some(server.uri()); + let adapter = FeishuAdapter::new(config); + + let (event_tx, _rx) = tokio::sync::broadcast::channel(16); + + let reply = crate::schema::GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "evt_123".into(), + platform: "feishu".into(), + channel: crate::schema::ReplyChannel { + id: "oc_chat1".into(), + thread_id: None, + }, + content: crate::schema::Content { + content_type: "text".into(), + text: "hello from fallback test".into(), + attachments: vec![], + }, + command: None, + request_id: None, + quote_message_id: 
Some("om_invalid".into()), + }; + + handle_reply(&reply, &adapter, &event_tx).await; + // wiremock expect(1) on both mocks verifies: + // 1. Reply API was called (and failed) + // 2. Plain send was called (fallback triggered by quote_message_id.is_some() guard) + } + + // --- Edit-cap helpers (F3/F4/F8/F10): no network required --- + + fn fresh_cache() -> Arc> { + Arc::new(std::sync::Mutex::new(EditCountsCache::default())) + } + + #[test] + fn cap_detect_json_code_match() { + // Real-shape Feishu error body: trusted JSON code field == 230072. + let body = r#"{"code":230072,"msg":"The message has reached the number of times it can be edited","data":{}}"#; + assert!(is_feishu_cap_reached_body(body)); + } + + #[test] + fn cap_detect_json_other_code_no_false_positive() { + // JSON parses but code is unrelated; any inner string containing + // "230072" must NOT trigger cap detection. + let body = r#"{"code":99999,"msg":"some other error 230072 in description"}"#; + assert!(!is_feishu_cap_reached_body(body)); + } + + #[test] + fn cap_detect_substring_fallback_for_non_json() { + // Proxy-style HTML / non-JSON body — substring fallback kicks in. + let html = "Error 230072 — number of times it can be edited"; + assert!(is_feishu_cap_reached_body(html)); + + let plain = "upstream error: 230072"; + assert!(is_feishu_cap_reached_body(plain)); + } + + #[test] + fn cap_detect_unrelated_body_returns_false() { + let body = r#"{"code":99991,"msg":"rate limited","data":{}}"#; + assert!(!is_feishu_cap_reached_body(body)); + assert!(!is_feishu_cap_reached_body("plain text without the code")); + assert!(!is_feishu_cap_reached_body("")); + } + + #[test] + fn cap_pre_check_below_threshold_does_not_trip() { + let cache = fresh_cache(); + // Cap is FEISHU_EDIT_CAP (18). 17 increments must stay below. 
+ for _ in 0..17 { + increment_edit_count(&cache, "om_msg1"); + } + assert!(!is_edit_cap_reached(&cache, "om_msg1")); + } + + #[test] + fn cap_pre_check_at_threshold_trips() { + let cache = fresh_cache(); + for _ in 0..(FEISHU_EDIT_CAP as usize) { + increment_edit_count(&cache, "om_msg1"); + } + assert!(is_edit_cap_reached(&cache, "om_msg1")); + } + + #[test] + fn mark_edit_cap_short_circuits_pre_check() { + let cache = fresh_cache(); + mark_edit_cap(&cache, "om_msg1"); + // Sentinel u32::MAX >= FEISHU_EDIT_CAP, so pre-check trips immediately. + assert!(is_edit_cap_reached(&cache, "om_msg1")); + } + + #[test] + fn mark_edit_cap_does_not_double_increment() { + let cache = fresh_cache(); + mark_edit_cap(&cache, "om_msg1"); + increment_edit_count(&cache, "om_msg1"); + // Increment must not push past u32::MAX sentinel. + let map = cache.lock().unwrap(); + assert_eq!(map.counts.get("om_msg1").copied(), Some(u32::MAX)); + } + + #[test] + fn eviction_drops_oldest_inserts_not_lowest_count() { + // Pre-fill cache to over capacity, simulating a long-running adapter. + let cache = fresh_cache(); + // First insert message_id "old_*" with high counts so they would + // *survive* a count-ascending eviction (the buggy strategy). They + // must instead be the *first* evicted under FIFO. + let overcap = EDIT_COUNTS_CACHE_MAX + 100; + for i in 0..overcap { + let id = format!("om_msg_{i:05}"); + increment_edit_count(&cache, &id); + } + // Insert a fresh "active stream" id last — its low count would have + // marked it for eviction under count-ascending. With FIFO it must + // survive. + increment_edit_count(&cache, "om_active_recent"); + + let map = cache.lock().unwrap(); + // FIFO eviction: the newest insert must still be present. + assert!( + map.counts.contains_key("om_active_recent"), + "active recent insert was evicted under FIFO — bug regressed" + ); + // FIFO eviction: at least one of the very first inserts must be gone. 
+ let some_oldest_evicted = (0..50).any(|i| { + let id = format!("om_msg_{i:05}"); + !map.counts.contains_key(&id) + }); + assert!( + some_oldest_evicted, + "no early-insert key was evicted — FIFO not working" + ); + // Cache size bounded. + assert!( + map.counts.len() <= EDIT_COUNTS_CACHE_MAX, + "cache size {} > max {}", + map.counts.len(), + EDIT_COUNTS_CACHE_MAX + ); + } + + #[test] + fn message_id_validation_accepts_valid_shapes() { + assert!(is_valid_feishu_message_id("om_dc13264520392907fcq2e6kpngacls")); + assert!(is_valid_feishu_message_id("om_abc123")); + assert!(is_valid_feishu_message_id("om_A_B_c_1_2_3")); + } + + #[test] + fn message_id_validation_rejects_path_traversal_and_query() { + // The shape guard is the F8 defence: stop crafted IDs containing URL + // metachars from altering /im/v1/messages/{id} semantics. + assert!(!is_valid_feishu_message_id("../etc/passwd")); + assert!(!is_valid_feishu_message_id("om_abc/extra")); + assert!(!is_valid_feishu_message_id("om_abc?q=1")); + assert!(!is_valid_feishu_message_id("om_abc#frag")); + assert!(!is_valid_feishu_message_id("om_abc%2Fextra")); + assert!(!is_valid_feishu_message_id("")); + assert!(!is_valid_feishu_message_id("om_")); + assert!(!is_valid_feishu_message_id("not_om_prefix")); + // Length cap (defense against pathological inputs). 
+ let too_long = format!("om_{}", "a".repeat(200)); + assert!(!is_valid_feishu_message_id(&too_long)); + } + + // --- edit_feishu_message integration (wiremock): proves the cap is detected + // through the HTTP-status gate, including the HTTP-200 + body-code case --- + + async fn mount_token(server: &MockServer) { + Mock::given(method("POST")) + .and(path("/open-apis/auth/v3/tenant_access_token/internal")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "msg": "ok", + "tenant_access_token": "t-edit-test", + "expire": 7200 + }))) + .mount(server) + .await; + } + + #[tokio::test] + async fn edit_cap_detected_on_http_200_body_code() { + // Feishu returns the edit-cap rejection as HTTP 200 + {"code":230072}. + // Regression guard for the body-code-first fix: a status-only success + // gate would miscount this as Edited and never trip cap detection. + let server = MockServer::start().await; + mount_token(&server).await; + Mock::given(method("PUT")) + .and(path("/open-apis/im/v1/messages/om_capped")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 230072, + "msg": "The message has reached the number of times it can be edited." + }))) + .mount(&server) + .await; + + let mut config = test_config(); + config.api_base_override = Some(server.uri()); + let adapter = FeishuAdapter::new(config); + + let outcome = edit_feishu_message(&adapter, "om_capped", "hello").await; + assert!( + matches!(outcome, EditOutcome::CapReached), + "HTTP 200 + code 230072 must yield CapReached, got non-cap outcome" + ); + // Sentinel marked → subsequent pre-check short-circuits. + assert!(is_edit_cap_reached(&adapter.edit_counts, "om_capped")); + } + + #[tokio::test] + async fn edit_success_on_http_200_code_zero() { + // HTTP 200 + {"code":0} is a real success → Edited + count incremented. 
+ let server = MockServer::start().await; + mount_token(&server).await; + Mock::given(method("PUT")) + .and(path("/open-apis/im/v1/messages/om_ok")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "msg": "success", + "data": {} + }))) + .mount(&server) + .await; + + let mut config = test_config(); + config.api_base_override = Some(server.uri()); + let adapter = FeishuAdapter::new(config); + + let outcome = edit_feishu_message(&adapter, "om_ok", "hello").await; + assert!( + matches!(outcome, EditOutcome::Edited), + "HTTP 200 + code 0 must yield Edited" + ); + let map = adapter.edit_counts.lock().unwrap(); + assert_eq!(map.counts.get("om_ok").copied(), Some(1)); + } + + #[tokio::test] + async fn edit_failure_on_http_200_other_code() { + // HTTP 200 + non-zero, non-cap code is a genuine failure, not a success. + let server = MockServer::start().await; + mount_token(&server).await; + Mock::given(method("PUT")) + .and(path("/open-apis/im/v1/messages/om_err")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 99991, + "msg": "rate limited" + }))) + .mount(&server) + .await; + + let mut config = test_config(); + config.api_base_override = Some(server.uri()); + let adapter = FeishuAdapter::new(config); + + let outcome = edit_feishu_message(&adapter, "om_err", "hello").await; + assert!( + matches!(outcome, EditOutcome::Failed(_)), + "HTTP 200 + code 99991 must yield Failed, not Edited" + ); + // Failure must NOT increment the edit count. + let map = adapter.edit_counts.lock().unwrap(); + assert_eq!(map.counts.get("om_err").copied(), None); + } + + // --- handle_reply dispatch-seam message_id validation (R3) --- + // These exercise the seam reject path directly (the edit_* tests above call + // edit_feishu_message and bypass the seam). The guard runs before any + // network call, so no mock server is needed. 
+ + #[tokio::test] + async fn handle_reply_seam_rejects_invalid_id_with_response() { + let adapter = FeishuAdapter::new(test_config()); + let (event_tx, mut event_rx) = tokio::sync::broadcast::channel(8); + + let reply = crate::schema::GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "draft".into(), // sentinel, not an om_ id → rejected + platform: "feishu".into(), + channel: crate::schema::ReplyChannel { + id: "oc_chat1".into(), + thread_id: None, + }, + content: crate::schema::Content { + content_type: "text".into(), + text: "hello".into(), + attachments: vec![], + }, + command: Some("edit_message".into()), + request_id: Some("req_seam_1".into()), + quote_message_id: None, + }; + + handle_reply(&reply, &adapter, &event_tx).await; + + let raw = event_rx.try_recv().expect("expected a GatewayResponse"); + let resp: serde_json::Value = serde_json::from_str(&raw).unwrap(); + assert_eq!(resp["request_id"], "req_seam_1"); + assert_eq!(resp["success"], false); + assert_eq!(resp["error"], "invalid message_id format"); + } + + #[tokio::test] + async fn handle_reply_seam_rejects_invalid_id_silently_without_request_id() { + let adapter = FeishuAdapter::new(test_config()); + let (event_tx, mut event_rx) = tokio::sync::broadcast::channel(8); + + let reply = crate::schema::GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "om_bad/segment".into(), // URL metachar → rejected + platform: "feishu".into(), + channel: crate::schema::ReplyChannel { + id: "oc_chat1".into(), + thread_id: None, + }, + content: crate::schema::Content { + content_type: "text".into(), + text: "hello".into(), + attachments: vec![], + }, + command: Some("delete_message".into()), + request_id: None, + quote_message_id: None, + }; + + handle_reply(&reply, &adapter, &event_tx).await; + + assert!( + event_rx.try_recv().is_err(), + "no response expected when request_id is absent" + ); + } +} diff --git a/crates/openab-gateway/src/adapters/googlechat.rs 
b/crates/openab-gateway/src/adapters/googlechat.rs new file mode 100644 index 000000000..93c0c8f8e --- /dev/null +++ b/crates/openab-gateway/src/adapters/googlechat.rs @@ -0,0 +1,2470 @@ +use crate::schema::*; +use axum::extract::State; +use axum::http::HeaderMap; +use axum::response::IntoResponse; +use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; +use serde::Deserialize; +use std::sync::Arc; +use std::time::Instant; +use tokio::sync::RwLock; +use tracing::{error, info, warn}; + +pub const GOOGLE_CHAT_API_BASE: &str = "https://chat.googleapis.com/v1"; +const GOOGLE_CHAT_MESSAGE_LIMIT: usize = 4096; + +const IMAGE_MAX_DIMENSION_PX: u32 = 1200; +const IMAGE_JPEG_QUALITY: u8 = 75; +const IMAGE_MAX_DOWNLOAD: u64 = 10 * 1024 * 1024; // 10 MB +const FILE_MAX_DOWNLOAD: u64 = 512 * 1024; // 512 KB +const AUDIO_MAX_DOWNLOAD: u64 = 25 * 1024 * 1024; // 25 MB +/// Per-request timeout for Google Chat Media API downloads. Prevents a hung +/// connection from blocking the spawned download task indefinitely. +const MEDIA_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); +/// Cap on text file attachments per message (matches Discord/Slack). +const TEXT_FILE_COUNT_CAP: usize = 5; +/// Cap on aggregate text file bytes per message (matches Discord/Slack 1 MB). +const TEXT_TOTAL_CAP: u64 = 1024 * 1024; + +// --- Google Chat types --- +// +// Google Chat delivers webhooks in two shapes depending on the App's +// Connection settings in the Cloud Console: +// - HTTP endpoint URL mode: top-level fields (message, user, space, ...) +// - Pub/Sub mode: wrapped under `chat.messagePayload` +// Both are supported via the optional fields below; the handler prefers +// the wrapped form and falls back to top-level when `chat` is absent. 
+ +#[derive(Debug, Deserialize)] +pub struct GoogleChatEnvelope { + pub chat: Option, + // HTTP endpoint URL top-level fields (used when `chat` is None) + pub message: Option, + pub user: Option, + pub space: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ChatPayload { + pub user: Option, + pub message_payload: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct MessagePayload { + pub message: Option, + pub space: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GoogleChatMessage { + pub name: String, + pub text: Option, + pub argument_text: Option, + pub sender: Option, + pub thread: Option, + pub space: Option, + #[serde(default)] + pub attachment: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GoogleChatAttachment { + #[allow(dead_code)] + pub name: Option, + pub content_name: Option, + pub content_type: Option, + pub source: Option, + pub attachment_data_ref: Option, + #[allow(dead_code)] + pub drive_data_ref: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AttachmentDataRef { + pub resource_name: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +#[allow(dead_code)] +pub struct DriveDataRef { + pub drive_file_id: Option, +} + +/// Reference to media that needs async download after webhook parse. 
+#[derive(Debug, Clone)] +pub enum GoogleChatMediaRef { + Image { + resource_name: String, + content_name: String, + }, + File { + resource_name: String, + content_name: String, + }, + Audio { + resource_name: String, + content_name: String, + content_type: String, + }, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GoogleChatUser { + pub name: String, + pub display_name: String, + #[serde(rename = "type")] + pub user_type: String, +} + +#[derive(Debug, Deserialize)] +pub struct GoogleChatThread { + pub name: String, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GoogleChatSpace { + pub name: String, + #[serde(rename = "type")] + pub space_type: Option, + // Parsed by serde, not consumed in current code paths. + #[allow(dead_code)] + pub space_type_renamed: Option, +} + +// --- Webhook JWT verification --- + +const GOOGLE_CHAT_ISSUER: &str = "https://accounts.google.com"; +const GOOGLE_CHAT_JWKS_URL: &str = "https://www.googleapis.com/oauth2/v3/certs"; +const GOOGLE_CHAT_SIGNER_EMAIL: &str = "chat@system.gserviceaccount.com"; +const JWKS_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(3600); + +/// Verify the JWT's `email` claim belongs to Google Chat. +/// HTTP endpoint URL webhooks are signed by `chat@system.gserviceaccount.com`. +/// Without this check, any Google-issued ID token would be accepted. 
+fn verify_email_claim(claims: &serde_json::Value) -> Result<(), String> { + let email = claims + .get("email") + .and_then(|v| v.as_str()) + .ok_or("missing email claim")?; + if email != GOOGLE_CHAT_SIGNER_EMAIL { + return Err(format!( + "email claim mismatch: expected {GOOGLE_CHAT_SIGNER_EMAIL}, got {email}" + )); + } + Ok(()) +} + +#[derive(Debug, Clone, Deserialize)] +struct JwkKey { + kid: Option, + n: String, + e: String, + kty: String, +} + +#[derive(Debug, Deserialize)] +struct JwksResponse { + keys: Vec, +} + +pub struct GoogleChatJwtVerifier { + audience: String, + client: reqwest::Client, + jwks_cache: RwLock, Instant)>>, +} + +impl GoogleChatJwtVerifier { + pub fn new(audience: String) -> Self { + Self { + audience, + client: reqwest::Client::new(), + jwks_cache: RwLock::new(None), + } + } + + async fn get_jwks(&self) -> Result, String> { + { + let cache = self.jwks_cache.read().await; + if let Some((ref keys, fetched_at)) = *cache { + if fetched_at.elapsed() < JWKS_CACHE_TTL { + return Ok(keys.clone()); + } + } + } + let jwks: JwksResponse = self + .client + .get(GOOGLE_CHAT_JWKS_URL) + .send() + .await + .map_err(|e| format!("JWKS fetch error: {e}"))? 
+ .json() + .await + .map_err(|e| format!("JWKS parse error: {e}"))?; + + let keys = jwks.keys; + *self.jwks_cache.write().await = Some((keys.clone(), Instant::now())); + Ok(keys) + } + + pub async fn verify(&self, auth_header: &str) -> Result<(), String> { + let token = auth_header + .strip_prefix("Bearer ") + .ok_or("missing Bearer prefix")?; + + let header = + jsonwebtoken::decode_header(token).map_err(|e| format!("invalid JWT header: {e}"))?; + let kid = header.kid.ok_or("no kid in JWT header")?; + + let keys = self.get_jwks().await?; + let key = match keys.iter().find(|k| k.kid.as_deref() == Some(&kid)) { + Some(k) => k.clone(), + None => { + // Key rotation: invalidate cache and retry + *self.jwks_cache.write().await = None; + let refreshed = self.get_jwks().await?; + refreshed + .into_iter() + .find(|k| k.kid.as_deref() == Some(&kid)) + .ok_or_else(|| format!("no matching JWK for kid={kid}"))? + } + }; + + if key.kty != "RSA" { + return Err(format!("unsupported key type: {}", key.kty)); + } + + let decoding_key = DecodingKey::from_rsa_components(&key.n, &key.e) + .map_err(|e| format!("RSA key decode error: {e}"))?; + + let mut validation = Validation::new(Algorithm::RS256); + validation.set_audience(&[&self.audience]); + validation.set_issuer(&[GOOGLE_CHAT_ISSUER]); + validation.validate_exp = true; + + let token_data = decode::(token, &decoding_key, &validation) + .map_err(|e| format!("JWT validation failed: {e}"))?; + + verify_email_claim(&token_data.claims)?; + + Ok(()) + } +} + +// --- Adapter (encapsulates all Google Chat state) --- + +pub struct GoogleChatAdapter { + pub token_cache: Option, + pub access_token: Option, + pub jwt_verifier: Option, + pub client: reqwest::Client, + pub api_base: String, +} + +impl GoogleChatAdapter { + pub fn new( + token_cache: Option, + access_token: Option, + jwt_verifier: Option, + ) -> Self { + Self { + token_cache, + access_token, + jwt_verifier, + client: reqwest::Client::new(), + api_base: 
GOOGLE_CHAT_API_BASE.into(), + } + } + + async fn get_token(&self) -> Option { + if let Some(ref cache) = self.token_cache { + match cache.get_token(&self.client).await { + Ok(t) => return Some(t), + Err(e) => { + error!("googlechat token refresh failed: {e}"); + return None; + } + } + } + self.access_token.clone() + } + + async fn edit_message(&self, message_name: &str, text: &str) { + let Some(token) = self.get_token().await else { + tracing::warn!("googlechat edit_message: no token available"); + return; + }; + + let formatted = markdown_to_gchat(text); + let url = format!( + "{}/{}?updateMask=text", + self.api_base, message_name + ); + let body = serde_json::json!({ "text": formatted }); + + match self.client.patch(&url).bearer_auth(&token).json(&body).send().await { + Ok(r) if r.status().is_success() => { + tracing::trace!(message_name = %message_name, "googlechat message edited"); + } + Ok(r) => { + let status = r.status(); + let body = r.text().await.unwrap_or_default(); + error!(status = %status, body = %body, "googlechat edit_message failed"); + } + Err(e) => { + error!(err = %e, "googlechat edit_message request failed"); + } + } + } + + pub async fn handle_reply( + &self, + reply: &GatewayReply, + event_tx: &tokio::sync::broadcast::Sender, + ) { + // Command routing + match reply.command.as_deref() { + Some("add_reaction") | Some("remove_reaction") | Some("create_topic") => return, + Some("edit_message") => { + self.edit_message(&reply.reply_to, &reply.content.text).await; + return; + } + _ => {} + } + + info!( + space = %reply.channel.id, + thread_id = ?reply.channel.thread_id, + "gateway → googlechat" + ); + + let Some(token) = self.get_token().await else { + info!( + text = %reply.content.text, + "googlechat reply (dry-run, no credentials configured)" + ); + if let Some(ref req_id) = reply.request_id { + let resp = crate::schema::GatewayResponse { + schema: "openab.gateway.response.v1".into(), + request_id: req_id.clone(), + success: false, + 
thread_id: None, + message_id: None, + error: Some("no credentials configured".into()), + }; + if let Ok(json) = serde_json::to_string(&resp) { + let _ = event_tx.send(json); + } + } + return; + }; + + let text = &reply.content.text; + let chunks = split_text(text, GOOGLE_CHAT_MESSAGE_LIMIT); + + // Empty message: short-circuit, send failure ack and skip API call + if chunks.is_empty() { + if let Some(ref req_id) = reply.request_id { + let resp = crate::schema::GatewayResponse { + schema: "openab.gateway.response.v1".into(), + request_id: req_id.clone(), + success: false, + thread_id: None, + message_id: None, + error: Some("empty message".into()), + }; + if let Ok(json) = serde_json::to_string(&resp) { + let _ = event_tx.send(json); + } + } + return; + } + + if chunks.len() == 1 { + let result = send_message( + &self.client, + &token, + &reply.channel.id, + reply.channel.thread_id.as_deref(), + text, + &self.api_base, + ) + .await; + + if let Some(ref req_id) = reply.request_id { + let (success, message_id, error) = match result { + Ok(name) => (true, Some(name), None), + Err(e) => (false, None, Some(e)), + }; + let resp = crate::schema::GatewayResponse { + schema: "openab.gateway.response.v1".into(), + request_id: req_id.clone(), + success, + thread_id: None, + message_id, + error, + }; + if let Ok(json) = serde_json::to_string(&resp) { + let _ = event_tx.send(json); + } + } + } else { + let mut first_msg_name: Option = None; + let mut first_error: Option = None; + for chunk in chunks { + match send_message( + &self.client, + &token, + &reply.channel.id, + reply.channel.thread_id.as_deref(), + chunk, + &self.api_base, + ) + .await + { + Ok(name) => { + if first_msg_name.is_none() { + first_msg_name = Some(name); + } + } + Err(e) => { + if first_error.is_none() { + first_error = Some(e); + } + } + } + } + if let Some(ref req_id) = reply.request_id { + let resp = crate::schema::GatewayResponse { + schema: "openab.gateway.response.v1".into(), + request_id: 
req_id.clone(), + success: first_msg_name.is_some() && first_error.is_none(), + thread_id: None, + message_id: first_msg_name, + error: first_error, + }; + if let Ok(json) = serde_json::to_string(&resp) { + let _ = event_tx.send(json); + } + } + } + } +} + +// --- Webhook handler --- + +pub async fn webhook( + State(state): State>, + headers: HeaderMap, + body: axum::body::Bytes, +) -> axum::response::Response { + info!("googlechat webhook received ({} bytes)", body.len()); + + if let Some(ref adapter) = state.google_chat { + if let Some(ref verifier) = adapter.jwt_verifier { + let auth_header = match headers + .get("authorization") + .and_then(|v| v.to_str().ok()) + { + Some(h) => h, + None => { + warn!("googlechat webhook: missing authorization header"); + return (axum::http::StatusCode::UNAUTHORIZED, "unauthorized").into_response(); + } + }; + if let Err(e) = verifier.verify(auth_header).await { + warn!(error = %e, "googlechat webhook JWT verification failed"); + return (axum::http::StatusCode::UNAUTHORIZED, "unauthorized").into_response(); + } + } + } + + let envelope: GoogleChatEnvelope = match serde_json::from_slice(&body) { + Ok(e) => e, + Err(e) => { + let body_str = String::from_utf8_lossy(&body); + error!(body = %body_str, "googlechat webhook parse error: {e}"); + return (axum::http::StatusCode::BAD_REQUEST, "bad request").into_response(); + } + }; + + // Try the Pub/Sub `chat`-wrapped shape first, then fall back to the + // HTTP endpoint URL top-level shape. 
+ let (msg_opt, top_user, top_space) = if let Some(chat) = envelope.chat { + let user = chat.user; + let (msg, space) = match chat.message_payload { + Some(p) => (p.message, p.space), + None => (None, None), + }; + (msg, user, space) + } else { + (envelope.message, envelope.user, envelope.space) + }; + + let Some(ref msg) = msg_opt else { + return empty_json_response(); + }; + + let text = msg + .argument_text + .as_deref() + .or(msg.text.as_deref()) + .unwrap_or(""); + + let media_refs = parse_attachments(&msg.attachment); + + // Drop event only if BOTH text and attachments are empty + if text.trim().is_empty() && media_refs.is_empty() { + return empty_json_response(); + } + + let sender = msg.sender.as_ref().or(top_user.as_ref()); + let space = msg.space.as_ref().or(top_space.as_ref()); + + let is_bot = sender.map(|s| s.user_type == "BOT").unwrap_or(false); + if is_bot { + return empty_json_response(); + } + + let sender_id = sender.map(|s| s.name.clone()).unwrap_or_default(); + let display_name = sender + .map(|s| s.display_name.clone()) + .unwrap_or_else(|| "Unknown".into()); + let sender_name = sender_id + .strip_prefix("users/") + .unwrap_or(&sender_id) + .to_string(); + + let space_name = space.map(|s| s.name.clone()).unwrap_or_default(); + let space_type = space + .and_then(|s| s.space_type.clone()) + .unwrap_or_else(|| "ROOM".into()); + + let thread_id = msg.thread.as_ref().map(|t| t.name.clone()); + + let message_id = msg + .name + .rsplit('/') + .next() + .unwrap_or(&msg.name) + .to_string(); + + // No attachments → emit event synchronously and respond 200 + if media_refs.is_empty() { + send_googlechat_event( + &state, + &space_name, + space_type, + thread_id, + &sender_id, + &sender_name, + &display_name, + text, + &message_id, + Vec::new(), + ); + return empty_json_response(); + } + + // Has attachments — spawn background task so the webhook returns 200 within + // Google Chat's 30 s deadline regardless of how long downloads take. 
+ let text = text.to_string(); + let state = state.clone(); + let spawn_space = space_name.clone(); + tokio::spawn(async move { + use futures_util::FutureExt; + let result = std::panic::AssertUnwindSafe(async { + let mut downloaded: Vec = Vec::new(); + let mut text_file_count: usize = 0; + let mut text_file_bytes: u64 = 0; + if let Some(ref adapter) = state.google_chat { + if let Some(token) = adapter.get_token().await { + for media_ref in &media_refs { + let attachment = match media_ref { + GoogleChatMediaRef::Image { + resource_name, + content_name, + .. + } => { + download_googlechat_image( + &adapter.client, + &token, + &adapter.api_base, + resource_name, + content_name, + ) + .await + } + GoogleChatMediaRef::File { + resource_name, + content_name, + .. + } => { + if text_file_count >= TEXT_FILE_COUNT_CAP { + warn!(content_name = %content_name, cap = TEXT_FILE_COUNT_CAP, "googlechat text file count cap reached, skipping"); + continue; + } + let remaining = TEXT_TOTAL_CAP.saturating_sub(text_file_bytes); + let att = download_googlechat_file( + &adapter.client, + &token, + &adapter.api_base, + resource_name, + content_name, + remaining, + ) + .await; + let Some(att) = att else { continue }; + text_file_count += 1; + text_file_bytes += att.size; + Some(att) + } + GoogleChatMediaRef::Audio { + resource_name, + content_name, + content_type, + } => { + download_googlechat_audio( + &adapter.client, + &token, + &adapter.api_base, + resource_name, + content_name, + content_type, + ) + .await + } + }; + if let Some(att) = attachment { + downloaded.push(att); + } + } + } else { + warn!("googlechat: no token available for attachment download"); + } + } + + // If text is empty AND every attachment failed to download, drop the event. 
+ if text.trim().is_empty() && downloaded.is_empty() { + warn!( + space = %space_name, + "googlechat: empty text + all attachments failed, dropping event" + ); + return; + } + + send_googlechat_event( + &state, + &space_name, + space_type, + thread_id, + &sender_id, + &sender_name, + &display_name, + &text, + &message_id, + downloaded, + ); + }).catch_unwind().await; + if let Err(e) = result { + error!(space = %spawn_space, "googlechat attachment download task panicked: {e:?}"); + } + }); + + empty_json_response() +} + +#[allow(clippy::too_many_arguments)] +fn send_googlechat_event( + state: &Arc, + space_name: &str, + space_type: String, + thread_id: Option, + sender_id: &str, + sender_name: &str, + display_name: &str, + text: &str, + message_id: &str, + attachments: Vec, +) { + let mut gw_event = GatewayEvent::new( + "googlechat", + ChannelInfo { + id: space_name.to_string(), + channel_type: space_type, + thread_id, + }, + SenderInfo { + id: sender_id.to_string(), + name: sender_name.to_string(), + display_name: display_name.to_string(), + is_bot: false, + }, + text, + message_id, + vec![], + ); + gw_event.content.attachments = attachments; + + let attachment_count = gw_event.content.attachments.len(); + let json = match serde_json::to_string(&gw_event) { + Ok(j) => j, + Err(e) => { + error!(error = %e, "googlechat: failed to serialize GatewayEvent"); + return; + } + }; + info!( + space = %space_name, + sender = %sender_name, + attachment_count, + "googlechat → gateway" + ); + let _ = state.event_tx.send(json); +} + +fn empty_json_response() -> axum::response::Response { + use axum::response::IntoResponse; + ( + [(axum::http::header::CONTENT_TYPE, "application/json")], + "{}", + ) + .into_response() +} + +// --- Token cache with JWT auto-refresh --- + +pub struct GoogleChatTokenCache { + token: RwLock>, + sa_email: String, + private_key: String, +} + +const TOKEN_REFRESH_MARGIN_SECS: u64 = 300; + +impl GoogleChatTokenCache { + pub fn new(sa_key_json: &str) -> 
Result { + let key: serde_json::Value = + serde_json::from_str(sa_key_json).map_err(|e| format!("invalid SA key JSON: {e}"))?; + let email = key + .get("client_email") + .and_then(|v| v.as_str()) + .ok_or("missing client_email in SA key")? + .to_string(); + let pkey = key + .get("private_key") + .and_then(|v| v.as_str()) + .ok_or("missing private_key in SA key")? + .to_string(); + Ok(Self { + token: RwLock::new(None), + sa_email: email, + private_key: pkey, + }) + } + + pub async fn get_token(&self, client: &reqwest::Client) -> Result { + { + let guard = self.token.read().await; + if let Some((ref tok, ref ts, ttl)) = *guard { + if ts.elapsed().as_secs() < ttl.saturating_sub(TOKEN_REFRESH_MARGIN_SECS) { + return Ok(tok.clone()); + } + } + } + let mut guard = self.token.write().await; + if let Some((ref tok, ref ts, ttl)) = *guard { + if ts.elapsed().as_secs() < ttl.saturating_sub(TOKEN_REFRESH_MARGIN_SECS) { + return Ok(tok.clone()); + } + } + let (new_token, expire) = self.refresh(client).await?; + *guard = Some((new_token.clone(), Instant::now(), expire)); + info!("googlechat access token refreshed (expires in {expire}s)"); + Ok(new_token) + } + + async fn refresh(&self, client: &reqwest::Client) -> Result<(String, u64), String> { + let jwt = self.build_jwt().map_err(|e| format!("JWT build error: {e}"))?; + let resp = client + .post("https://oauth2.googleapis.com/token") + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), + ("assertion", &jwt), + ]) + .send() + .await + .map_err(|e| format!("token exchange request failed: {e}"))?; + + let body: serde_json::Value = resp + .json() + .await + .map_err(|e| format!("token exchange parse failed: {e}"))?; + + let token = body + .get("access_token") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + let err = body + .get("error_description") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error"); + format!("token exchange failed: {err}") + })? 
+ .to_string(); + + let expires_in = body + .get("expires_in") + .and_then(|v| v.as_u64()) + .unwrap_or(3600); + + Ok((token, expires_in)) + } + + fn build_jwt(&self) -> Result { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_err(|e| e.to_string())? + .as_secs(); + + let claims = serde_json::json!({ + "iss": self.sa_email, + "scope": "https://www.googleapis.com/auth/chat.bot", + "aud": "https://oauth2.googleapis.com/token", + "iat": now, + "exp": now + 3600, + }); + + let key = jsonwebtoken::EncodingKey::from_rsa_pem(self.private_key.as_bytes()) + .map_err(|e| format!("RSA key parse error: {e}"))?; + let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256); + jsonwebtoken::encode(&header, &claims, &key) + .map_err(|e| format!("JWT encode error: {e}")) + } +} + +/// Convert markdown to Google Chat native formatting. +/// +/// Called by both `send_message` and `edit_message`. Assumes the caller passes +/// **raw markdown** — passing already-converted text would double-convert +/// (e.g. `*bold*` from a previous pass would be re-parsed as `*italic*`). +/// OAB core is expected to always emit raw markdown for both initial replies +/// and streaming edits. 
+fn markdown_to_gchat(text: &str) -> String { + let mut result = String::with_capacity(text.len()); + let lines: Vec<&str> = text.split('\n').collect(); + let mut i = 0; + while i < lines.len() { + let line = lines[i]; + // Detect fenced code block — pass through unchanged + if line.trim_start().starts_with("```") { + result.push_str(line); + result.push('\n'); + i += 1; + while i < lines.len() { + result.push_str(lines[i]); + if lines[i].trim_start().starts_with("```") { + i += 1; + if i < lines.len() { + result.push('\n'); + } + break; + } + result.push('\n'); + i += 1; + } + continue; + } + // Heading → bold + let converted = if let Some(heading) = line + .strip_prefix("### ") + .or_else(|| line.strip_prefix("## ")) + .or_else(|| line.strip_prefix("# ")) + { + format!("*{}*", heading.trim()) + } else { + convert_inline(line) + }; + result.push_str(&converted); + i += 1; + if i < lines.len() { + result.push('\n'); + } + } + result +} + +// TODO(perf): allocates Vec per line. Acceptable at current scale, +// but on hot streaming paths with many edit_message updates this could be +// rewritten with byte-level iteration over &str. 
+fn convert_inline(line: &str) -> String { + let mut out = String::with_capacity(line.len()); + let chars: Vec = line.chars().collect(); + let mut i = 0; + while i < chars.len() { + // Inline code — pass through + if chars[i] == '`' { + out.push('`'); + i += 1; + while i < chars.len() && chars[i] != '`' { + out.push(chars[i]); + i += 1; + } + if i < chars.len() { + out.push('`'); + i += 1; + } + continue; + } + // Markdown link: [text](url) + if chars[i] == '[' { + if let Some((link_text, url, end)) = parse_md_link(&chars, i) { + let converted_text = convert_inline(&link_text); + out.push_str(&format!("<{}|{}>", url, converted_text)); + i = end; + continue; + } + } + // Bold: **text** → *text* + if chars[i] == '*' && i + 1 < chars.len() && chars[i + 1] == '*' { + if let Some(end) = find_closing(&chars, i + 2, &['*', '*']) { + out.push('*'); + let inner: String = chars[i + 2..end].iter().collect(); + out.push_str(&convert_inline(&inner)); + out.push('*'); + i = end + 2; + continue; + } + } + // Bold: __text__ → *text* + if chars[i] == '_' && i + 1 < chars.len() && chars[i + 1] == '_' { + if let Some(end) = find_closing(&chars, i + 2, &['_', '_']) { + out.push('*'); + let inner: String = chars[i + 2..end].iter().collect(); + out.push_str(&convert_inline(&inner)); + out.push('*'); + i = end + 2; + continue; + } + } + // Strikethrough: ~~text~~ → ~text~ + if chars[i] == '~' && i + 1 < chars.len() && chars[i + 1] == '~' { + if let Some(end) = find_closing(&chars, i + 2, &['~', '~']) { + out.push('~'); + let inner: String = chars[i + 2..end].iter().collect(); + out.push_str(&convert_inline(&inner)); + out.push('~'); + i = end + 2; + continue; + } + } + // Italic: *text* → _text_ (single asterisk, not part of **bold**) + // Must come AFTER the **bold** check above. Requires non-asterisk + // immediately after opening * and before closing *. 
+ if chars[i] == '*' + && i + 1 < chars.len() + && chars[i + 1] != '*' + && !chars[i + 1].is_whitespace() + { + if let Some(end) = find_single(&chars, i + 1, '*') { + if end > i + 1 && !chars[end - 1].is_whitespace() { + out.push('_'); + let inner: String = chars[i + 1..end].iter().collect(); + out.push_str(&convert_inline(&inner)); + out.push('_'); + i = end + 1; + continue; + } + } + } + out.push(chars[i]); + i += 1; + } + out +} + +fn find_single(chars: &[char], start: usize, target: char) -> Option { + let mut i = start; + while i < chars.len() { + if chars[i] == target { + return Some(i); + } + i += 1; + } + None +} + +fn parse_md_link(chars: &[char], start: usize) -> Option<(String, String, usize)> { + let mut i = start + 1; + let mut depth = 1; + let text_start = i; + while i < chars.len() && depth > 0 { + if chars[i] == '[' { + depth += 1; + } else if chars[i] == ']' { + depth -= 1; + } + if depth > 0 { + i += 1; + } + } + if depth != 0 { + return None; + } + let text: String = chars[text_start..i].iter().collect(); + i += 1; // skip ']' + if i >= chars.len() || chars[i] != '(' { + return None; + } + i += 1; // skip '(' + let url_start = i; + let mut paren_depth = 1; + while i < chars.len() && paren_depth > 0 { + if chars[i] == '(' { + paren_depth += 1; + } else if chars[i] == ')' { + paren_depth -= 1; + } + if paren_depth > 0 { + i += 1; + } + } + if paren_depth != 0 { + return None; + } + let url: String = chars[url_start..i].iter().collect(); + Some((text, url, i + 1)) +} + +fn find_closing(chars: &[char], start: usize, pattern: &[char]) -> Option { + if pattern.len() < 2 { + return None; + } + let mut i = start; + while i + 1 < chars.len() { + if chars[i] == pattern[0] && chars[i + 1] == pattern[1] { + return Some(i); + } + i += 1; + } + None +} + +async fn send_message( + client: &reqwest::Client, + token: &str, + space: &str, + thread_id: Option<&str>, + text: &str, + api_base: &str, +) -> Result { + let mut url = format!("{}/{}/messages", api_base, 
space); + + let formatted = markdown_to_gchat(text); + let mut body = serde_json::json!({ + "text": formatted, + }); + + if let Some(thread_id) = thread_id { + body["thread"] = serde_json::json!({ + "name": thread_id, + }); + url.push_str("?messageReplyOption=REPLY_MESSAGE_FALLBACK_TO_NEW_THREAD"); + } + + let resp = client + .post(&url) + .bearer_auth(token) + .json(&body) + .send() + .await; + + match resp { + Ok(r) if r.status().is_success() => { + let body = r.text().await.unwrap_or_default(); + let parsed: serde_json::Value = serde_json::from_str(&body).unwrap_or_default(); + parsed + .get("name") + .and_then(|v| v.as_str()) + .map(String::from) + .ok_or_else(|| "missing message name in response".into()) + } + Ok(r) => { + let status = r.status(); + let body = r.text().await.unwrap_or_default(); + error!(status = %status, body = %body, "googlechat send error"); + Err(format!("send failed: {} {}", status, body)) + } + Err(e) => { + error!("googlechat send error: {e}"); + Err(format!("request error: {e}")) + } + } +} + +fn split_text(text: &str, limit: usize) -> Vec<&str> { + let mut chunks = Vec::new(); + let mut start = 0; + while start < text.len() { + if start + limit >= text.len() { + chunks.push(&text[start..]); + break; + } + let mut end = start + limit; + while !text.is_char_boundary(end) { + end -= 1; + } + let mut search_start = if end > start + 200 { end - 200 } else { start }; + while search_start < end && !text.is_char_boundary(search_start) { + search_start += 1; + } + let break_at = text[search_start..end] + .rfind('\n') + .or_else(|| text[search_start..end].rfind(' ')) + .map(|pos| search_start + pos + 1) + .unwrap_or(end); + chunks.push(&text[start..break_at]); + start = break_at; + } + chunks +} + +// --- Attachment parsing & download --- + +/// Whitelist of text-like file extensions for `download_googlechat_file`. 
+const TEXT_EXTS: &[&str] = &[ + "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", + "rs", "py", "js", "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", + "rb", "sh", "bash", "sql", "html", "css", "ini", "cfg", "conf", +]; + +/// Parse Google Chat attachment array into media references for async download. +/// +/// Skips Drive-sourced attachments (different download API), and unknown +/// content types. Branches on `contentType` prefix to bucket into image / +/// audio / file. +fn parse_attachments(attachments: &[GoogleChatAttachment]) -> Vec { + let mut refs = Vec::new(); + for att in attachments { + // Only handle UPLOADED_CONTENT (Drive needs separate Drive API call) + if att.source.as_deref() != Some("UPLOADED_CONTENT") { + continue; + } + let resource_name = match att + .attachment_data_ref + .as_ref() + .and_then(|d| d.resource_name.clone()) + { + Some(rn) => rn, + None => continue, + }; + let content_type = att.content_type.clone().unwrap_or_default(); + let content_name = att.content_name.clone().unwrap_or_else(|| "file".into()); + + if content_type.starts_with("image/") { + refs.push(GoogleChatMediaRef::Image { + resource_name, + content_name, + }); + } else if content_type.starts_with("audio/") { + refs.push(GoogleChatMediaRef::Audio { + resource_name, + content_name, + content_type, + }); + } else if content_type.starts_with("video/") { + info!(content_name = %content_name, content_type = %content_type, "googlechat: video attachment skipped (not yet supported)"); + } else { + refs.push(GoogleChatMediaRef::File { + resource_name, + content_name, + }); + } + } + refs +} + +/// Resize image so longest side ≤ 1200px, then encode as JPEG. +/// GIFs are passed through unchanged to preserve animation. 
+fn resize_and_compress(raw: &[u8]) -> Result<(Vec, String), image::ImageError> { + use image::ImageReader; + use std::io::Cursor; + + let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?; + let format = reader.format(); + if format == Some(image::ImageFormat::Gif) { + return Ok((raw.to_vec(), "image/gif".to_string())); + } + let img = reader.decode()?; + let (w, h) = (img.width(), img.height()); + let img = if w > IMAGE_MAX_DIMENSION_PX || h > IMAGE_MAX_DIMENSION_PX { + let max_side = std::cmp::max(w, h); + let ratio = f64::from(IMAGE_MAX_DIMENSION_PX) / f64::from(max_side); + let new_w = (f64::from(w) * ratio) as u32; + let new_h = (f64::from(h) * ratio) as u32; + img.resize(new_w, new_h, image::imageops::FilterType::Lanczos3) + } else { + img + }; + let mut buf = Cursor::new(Vec::new()); + let encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, IMAGE_JPEG_QUALITY); + img.write_with_encoder(encoder)?; + Ok((buf.into_inner(), "image/jpeg".to_string())) +} + +/// Build the Media API URL for a given resource_name. +/// Google Chat Media API uses `{+resourceName}` (RFC 6570 reserved expansion), +/// so `/` must stay literal while other special chars are percent-encoded. +fn media_url(api_base: &str, resource_name: &str) -> String { + let encoded: String = resource_name + .bytes() + .map(|b| match b { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' | b'/' => { + (b as char).to_string() + } + _ => format!("%{:02X}", b), + }) + .collect(); + format!("{}/media/{}?alt=media", api_base, encoded) +} + +/// Download an image attachment via Google Chat Media API → resize/compress → base64. 
+pub async fn download_googlechat_image( + client: &reqwest::Client, + token: &str, + api_base: &str, + resource_name: &str, + content_name: &str, +) -> Option { + let url = media_url(api_base, resource_name); + let resp = match client.get(&url).bearer_auth(token).timeout(MEDIA_REQUEST_TIMEOUT).send().await { + Ok(r) => r, + Err(e) => { + warn!(content_name, error = %e, "googlechat image download failed"); + return None; + } + }; + if !resp.status().is_success() { + warn!(content_name, status = %resp.status(), "googlechat image download failed"); + return None; + } + if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { + if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { + if size > IMAGE_MAX_DOWNLOAD { + warn!(content_name, size, "googlechat image Content-Length exceeds 10MB limit"); + return None; + } + } + } + let bytes = resp.bytes().await.ok()?; + if bytes.len() as u64 > IMAGE_MAX_DOWNLOAD { + warn!(content_name, size = bytes.len(), "googlechat image exceeds 10MB limit"); + return None; + } + let (compressed, mime) = match resize_and_compress(&bytes) { + Ok(v) => v, + Err(e) => { + warn!(content_name, error = %e, "googlechat image resize failed"); + return None; + } + }; + let path = crate::store::store_media(&compressed).await?; + Some(crate::schema::Attachment { + attachment_type: "image".into(), + filename: content_name.to_string(), + mime_type: mime, + data: String::new(), + size: compressed.len() as u64, + path: Some(path), + }) +} + +/// Download a text-like file via Google Chat Media API → base64. +/// Non-text extensions are skipped to avoid sending binary garbage to the model. 
+pub async fn download_googlechat_file( + client: &reqwest::Client, + token: &str, + api_base: &str, + resource_name: &str, + content_name: &str, + remaining_budget: u64, +) -> Option { + let ext = content_name.rsplit('.').next().unwrap_or("").to_lowercase(); + if !TEXT_EXTS.contains(&ext.as_str()) { + tracing::debug!(content_name, "skipping non-text googlechat file attachment"); + return None; + } + let max_size = FILE_MAX_DOWNLOAD.min(remaining_budget); + let url = media_url(api_base, resource_name); + let resp = match client.get(&url).bearer_auth(token).timeout(MEDIA_REQUEST_TIMEOUT).send().await { + Ok(r) => r, + Err(e) => { + warn!(content_name, error = %e, "googlechat file download failed"); + return None; + } + }; + if !resp.status().is_success() { + warn!(content_name, status = %resp.status(), "googlechat file download failed"); + return None; + } + if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { + if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { + if size > max_size { + warn!(content_name, size, limit = max_size, "googlechat file Content-Length exceeds limit"); + return None; + } + } + } + let bytes = resp.bytes().await.ok()?; + if bytes.len() as u64 > max_size { + warn!(content_name, size = bytes.len(), limit = max_size, "googlechat file exceeds size limit"); + return None; + } + let path = crate::store::store_media(&bytes).await?; + Some(crate::schema::Attachment { + attachment_type: "text_file".into(), + filename: content_name.to_string(), + mime_type: "text/plain".into(), + data: String::new(), + size: bytes.len() as u64, + path: Some(path), + }) +} + +/// Download an audio attachment as-is (no resize/transcode) → filesystem store. +/// Core's STT pipeline (when available) consumes this as `audio` attachment_type. 
+pub async fn download_googlechat_audio( + client: &reqwest::Client, + token: &str, + api_base: &str, + resource_name: &str, + content_name: &str, + content_type: &str, +) -> Option { + let url = media_url(api_base, resource_name); + let resp = match client.get(&url).bearer_auth(token).timeout(MEDIA_REQUEST_TIMEOUT).send().await { + Ok(r) => r, + Err(e) => { + warn!(content_name, error = %e, "googlechat audio download failed"); + return None; + } + }; + if !resp.status().is_success() { + warn!(content_name, status = %resp.status(), "googlechat audio download failed"); + return None; + } + if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { + if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { + if size > AUDIO_MAX_DOWNLOAD { + warn!(content_name, size, "googlechat audio Content-Length exceeds 25MB limit"); + return None; + } + } + } + let bytes = resp.bytes().await.ok()?; + if bytes.len() as u64 > AUDIO_MAX_DOWNLOAD { + warn!(content_name, size = bytes.len(), "googlechat audio exceeds 25MB limit"); + return None; + } + let path = crate::store::store_media(&bytes).await?; + Some(crate::schema::Attachment { + attachment_type: "audio".into(), + filename: content_name.to_string(), + mime_type: content_type.to_string(), + data: String::new(), + size: bytes.len() as u64, + path: Some(path), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + // --- Webhook parsing tests --- + + fn make_envelope( + text: &str, + argument_text: Option<&str>, + sender_type: &str, + space_type: &str, + thread_name: Option<&str>, + ) -> String { + let arg_field = argument_text + .map(|a| format!(r#""argumentText": "{a}","#)) + .unwrap_or_default(); + let thread_field = thread_name + .map(|t| format!(r#","thread": {{"name": "{t}"}}"#)) + .unwrap_or_default(); + format!( + r#"{{ + "chat": {{ + "user": {{ + "name": "users/111", + "displayName": "Test", + "type": "{sender_type}" + }}, + "messagePayload": {{ + "message": {{ + "name": "spaces/SP/messages/msg1", + 
"text": "{text}", + {arg_field} + "sender": {{ + "name": "users/111", + "displayName": "Test", + "type": "{sender_type}" + }}, + "space": {{ + "name": "spaces/SP", + "type": "{space_type}" + }} + {thread_field} + }}, + "space": {{ + "name": "spaces/SP", + "type": "{space_type}" + }} + }} + }} + }}"# + ) + } + + #[test] + fn parse_dm_message() { + let json = make_envelope("hello", None, "HUMAN", "DM", None); + let envelope: GoogleChatEnvelope = serde_json::from_str(&json).unwrap(); + let chat = envelope.chat.unwrap(); + let msg = chat.message_payload.unwrap().message.unwrap(); + assert_eq!(msg.text.as_deref(), Some("hello")); + assert_eq!(msg.sender.unwrap().user_type, "HUMAN"); + } + + #[test] + fn parse_space_message_with_thread() { + let json = make_envelope( + "@Bot hi", + Some("hi"), + "HUMAN", + "ROOM", + Some("spaces/SP/threads/t1"), + ); + let envelope: GoogleChatEnvelope = serde_json::from_str(&json).unwrap(); + let chat = envelope.chat.unwrap(); + let payload = chat.message_payload.unwrap(); + let msg = payload.message.as_ref().unwrap(); + assert_eq!(msg.argument_text.as_deref(), Some("hi")); + assert_eq!(msg.thread.as_ref().unwrap().name, "spaces/SP/threads/t1"); + assert_eq!(payload.space.as_ref().unwrap().space_type.as_deref(), Some("ROOM")); + } + + #[test] + fn parse_bot_message_detected() { + let json = make_envelope("bot says hi", None, "BOT", "DM", None); + let envelope: GoogleChatEnvelope = serde_json::from_str(&json).unwrap(); + let chat = envelope.chat.unwrap(); + let user = chat.user.unwrap(); + assert_eq!(user.user_type, "BOT"); + } + + #[test] + fn parse_missing_chat_field() { + let json = r#"{"type": "ADDED_TO_SPACE"}"#; + let envelope: GoogleChatEnvelope = serde_json::from_str(json).unwrap(); + assert!(envelope.chat.is_none()); + } + + #[test] + fn parse_missing_message_payload() { + let json = r#"{"chat": {"user": {"name": "u/1", "displayName": "X", "type": "HUMAN"}}}"#; + let envelope: GoogleChatEnvelope = 
serde_json::from_str(json).unwrap(); + assert!(envelope.chat.unwrap().message_payload.is_none()); + } + + #[test] + fn parse_invalid_json() { + let result: Result = serde_json::from_str("not json"); + assert!(result.is_err()); + } + + #[test] + fn argument_text_preferred_over_text() { + let json = make_envelope("@Bot explain", Some("explain"), "HUMAN", "ROOM", None); + let envelope: GoogleChatEnvelope = serde_json::from_str(&json).unwrap(); + let msg = envelope + .chat + .unwrap() + .message_payload + .unwrap() + .message + .unwrap(); + let text = msg + .argument_text + .as_deref() + .or(msg.text.as_deref()) + .unwrap(); + assert_eq!(text, "explain"); + } + + #[test] + fn sender_name_strips_users_prefix() { + let sender_id = "users/123456"; + let name = sender_id.strip_prefix("users/").unwrap_or(sender_id); + assert_eq!(name, "123456"); + } + + #[test] + fn message_id_extracts_last_segment() { + let msg_name = "spaces/SP/messages/abc123"; + let id = msg_name.rsplit('/').next().unwrap_or(msg_name); + assert_eq!(id, "abc123"); + } + + // --- split_text tests --- + + #[test] + fn split_text_short() { + let chunks = split_text("hello", 100); + assert_eq!(chunks, vec!["hello"]); + } + + #[test] + fn split_text_exact_limit() { + let text = "a".repeat(100); + let chunks = split_text(&text, 100); + assert_eq!(chunks.len(), 1); + } + + #[test] + fn split_text_over_limit() { + let text = "a".repeat(150); + let chunks = split_text(&text, 100); + assert_eq!(chunks.len(), 2); + let reassembled: String = chunks.concat(); + assert_eq!(reassembled, text); + } + + #[test] + fn split_text_breaks_at_newline() { + let text = format!("{}\n{}", "a".repeat(50), "b".repeat(50)); + let chunks = split_text(&text, 60); + assert_eq!(chunks.len(), 2); + assert!(chunks[0].ends_with('\n')); + } + + #[test] + fn split_text_breaks_at_space() { + let text = format!("{} {}", "a".repeat(50), "b".repeat(50)); + let chunks = split_text(&text, 60); + assert_eq!(chunks.len(), 2); + } + + #[test] + fn 
split_text_chinese_utf8_safe() { + let text = "你好世界測試谷歌聊天中文消息分割安全驗證完成"; + let chunks = split_text(text, 10); + assert!(chunks.len() > 1); + let reassembled: String = chunks.concat(); + assert_eq!(reassembled, text); + } + + #[test] + fn split_text_search_start_char_boundary() { + let text: String = "谷歌".repeat(150); // 300 chars, 900 bytes + let chunks = split_text(&text, 500); + assert!(chunks.len() >= 2); + let reassembled: String = chunks.concat(); + assert_eq!(reassembled, text); + } + + #[test] + fn split_text_empty() { + let chunks = split_text("", 100); + assert!(chunks.is_empty()); + } + + // --- Token cache tests --- + + #[test] + fn token_cache_rejects_invalid_json() { + let result = GoogleChatTokenCache::new("not json"); + assert!(result.is_err()); + } + + #[test] + fn token_cache_rejects_missing_fields() { + match GoogleChatTokenCache::new(r#"{"type": "service_account"}"#) { + Err(e) => assert!(e.contains("client_email"), "unexpected error: {e}"), + Ok(_) => panic!("expected error for missing client_email"), + } + } + + #[test] + fn token_cache_accepts_valid_sa_key() { + let key = r#"{ + "type": "service_account", + "client_email": "test@test.iam.gserviceaccount.com", + "private_key": "-----BEGIN RSA PRIVATE KEY-----\nMIIBogIBAAJBALvRE+oCMiEhtfO5ufaVc9wGPUMgPGxmVFiMPC/NMxmCSiMGNO9h\nCOyByeF78QHp4gOW/lgVU8MJkv33hVMbOr0CAwEAAQJAD2k/cFR5MIkw1PFcm98K\n9MqYKGpJCmGBjFY0ek0FHoC14d/hpAGaoWMjNaAyjU/IbGv1fj8C5MfFRal0fV/L\nAQIhAP0T6FPJMm3O4bM18kMHnOP2+Y5kxMpVxCCjkVNH7D09AiEAvXEQJYwR+PFs\njDDhEm4VPmk+lKJoQlopj8TN5gQV8DECIBcXbU+LPWx4H+qRElhCB1B5a9mYmpY\nV6LFPnvSfHqNAiEAiNj5+A6E7WJ50il+5NG5yn7gXh8vNxdCYIw5qx6C2bECIBmW\nVGVRhSmNsmDMJFsGIdKJsnEXpizIVHtfpXsS4j9X\n-----END RSA PRIVATE KEY-----\n" + }"#; + let result = GoogleChatTokenCache::new(key); + assert!(result.is_ok()); + } + + // --- Bot filtering logic test --- + + #[test] + fn bot_user_type_detected() { + let json = make_envelope("hello", None, "BOT", "DM", None); + let envelope: GoogleChatEnvelope = 
serde_json::from_str(&json).unwrap(); + let chat = envelope.chat.unwrap(); + let sender = chat + .message_payload + .as_ref() + .and_then(|p| p.message.as_ref()) + .and_then(|m| m.sender.as_ref()) + .or(chat.user.as_ref()); + let is_bot = sender.map(|s| s.user_type == "BOT").unwrap_or(false); + assert!(is_bot); + } + + // --- JWT verifier tests --- + + #[tokio::test] + async fn jwt_rejects_missing_bearer_prefix() { + let verifier = GoogleChatJwtVerifier::new("123456".into()); + let result = verifier.verify("NotBearer xyz").await; + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Bearer")); + } + + #[tokio::test] + async fn jwt_rejects_invalid_token() { + let verifier = GoogleChatJwtVerifier::new("123456".into()); + let result = verifier.verify("Bearer not.a.valid.jwt").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn jwt_rejects_empty_bearer() { + let verifier = GoogleChatJwtVerifier::new("123456".into()); + let result = verifier.verify("Bearer ").await; + assert!(result.is_err()); + } + + #[test] + fn email_claim_accepts_chat_system_account() { + let claims = serde_json::json!({"email": "chat@system.gserviceaccount.com"}); + assert!(verify_email_claim(&claims).is_ok()); + } + + #[test] + fn email_claim_rejects_other_google_email() { + let claims = serde_json::json!({"email": "attacker@example.iam.gserviceaccount.com"}); + let err = verify_email_claim(&claims).unwrap_err(); + assert!(err.contains("email claim mismatch")); + } + + #[test] + fn email_claim_rejects_unrelated_gserviceaccount() { + let claims = serde_json::json!({"email": "my-sa@my-project.iam.gserviceaccount.com"}); + assert!(verify_email_claim(&claims).is_err()); + } + + #[test] + fn email_claim_rejects_missing_email() { + let claims = serde_json::json!({"sub": "123", "iss": "accounts.google.com"}); + let err = verify_email_claim(&claims).unwrap_err(); + assert!(err.contains("missing email")); + } + + #[test] + fn email_claim_rejects_non_string_email() { + let 
claims = serde_json::json!({"email": 12345}); + assert!(verify_email_claim(&claims).is_err()); + } + + #[test] + fn human_user_type_not_filtered() { + let json = make_envelope("hello", None, "HUMAN", "DM", None); + let envelope: GoogleChatEnvelope = serde_json::from_str(&json).unwrap(); + let chat = envelope.chat.unwrap(); + let sender = chat + .message_payload + .as_ref() + .and_then(|p| p.message.as_ref()) + .and_then(|m| m.sender.as_ref()) + .or(chat.user.as_ref()); + let is_bot = sender.map(|s| s.user_type == "BOT").unwrap_or(false); + assert!(!is_bot); + } + + // --- markdown_to_gchat tests --- + + #[test] + fn markdown_bold_double_asterisk() { + assert_eq!(markdown_to_gchat("hello **world**"), "hello *world*"); + } + + #[test] + fn markdown_bold_underscore() { + assert_eq!(markdown_to_gchat("hello __world__"), "hello *world*"); + } + + #[test] + fn markdown_link_conversion() { + assert_eq!( + markdown_to_gchat("see [docs](https://example.com) here"), + "see here" + ); + } + + #[test] + fn markdown_heading_to_bold() { + assert_eq!(markdown_to_gchat("# Title\ntext"), "*Title*\ntext"); + assert_eq!(markdown_to_gchat("## Sub\ntext"), "*Sub*\ntext"); + assert_eq!(markdown_to_gchat("### Deep\ntext"), "*Deep*\ntext"); + } + + #[test] + fn markdown_code_block_preserved() { + let input = "before\n```rust\nlet **x** = 1;\n```\nafter **bold**"; + let output = markdown_to_gchat(input); + assert!(output.contains("let **x** = 1;")); + assert!(output.contains("after *bold*")); + } + + #[test] + fn markdown_inline_code_preserved() { + assert_eq!( + markdown_to_gchat("use `**not bold**` here **bold**"), + "use `**not bold**` here *bold*" + ); + } + + #[test] + fn markdown_strikethrough() { + assert_eq!(markdown_to_gchat("~~deleted~~"), "~deleted~"); + assert_eq!( + markdown_to_gchat("keep ~~this~~ and ~~that~~"), + "keep ~this~ and ~that~" + ); + } + + #[test] + fn markdown_italic_asterisk() { + assert_eq!(markdown_to_gchat("*italic*"), "_italic_"); + assert_eq!( + 
markdown_to_gchat("plain *one* and *two*"), + "plain _one_ and _two_" + ); + } + + #[test] + fn markdown_italic_does_not_match_bold() { + assert_eq!(markdown_to_gchat("**bold**"), "*bold*"); + assert_eq!( + markdown_to_gchat("**bold** and *italic*"), + "*bold* and _italic_" + ); + } + + #[test] + fn markdown_italic_underscore_passes_through() { + // Google Chat italic is _text_, single underscore should pass through + assert_eq!(markdown_to_gchat("_italic_"), "_italic_"); + } + + #[test] + fn markdown_italic_no_match_when_unbalanced() { + // Lone asterisks (no closing) should pass through + assert_eq!(markdown_to_gchat("a * b"), "a * b"); + // Whitespace adjacent to asterisks should not match (avoid matching multiplication) + assert_eq!(markdown_to_gchat("2 * 3 * 4"), "2 * 3 * 4"); + } + + #[test] + fn markdown_empty_string() { + assert_eq!(markdown_to_gchat(""), ""); + } + + #[test] + fn markdown_no_conversion_needed() { + assert_eq!(markdown_to_gchat("plain text"), "plain text"); + } + + #[test] + fn markdown_multiple_links() { + assert_eq!( + markdown_to_gchat("[a](http://a.com) and [b](http://b.com)"), + " and " + ); + } + + #[test] + fn markdown_nested_bold_in_link_text() { + assert_eq!( + markdown_to_gchat("[**bold link**](http://x.com)"), + "" + ); + } + + #[test] + fn parse_send_message_response_name() { + let resp_json = r#"{"name": "spaces/SP1/messages/msg123", "text": "hello"}"#; + let parsed: serde_json::Value = serde_json::from_str(resp_json).unwrap(); + let name = parsed.get("name").and_then(|v| v.as_str()); + assert_eq!(name, Some("spaces/SP1/messages/msg123")); + } + + #[tokio::test] + async fn handle_reply_sends_gateway_response_success() { + use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path_regex("/spaces/.*/messages")) + .respond_with(ResponseTemplate::new(200).set_body_json( + serde_json::json!({"name": 
"spaces/TEST/messages/msg_abc"}), + )) + .mount(&mock_server) + .await; + + let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); + let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); + adapter.api_base = mock_server.uri(); + + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "orig_msg".into(), + platform: "googlechat".into(), + channel: ReplyChannel { + id: "spaces/TEST".into(), + thread_id: None, + }, + content: Content { + content_type: "text".into(), + attachments: Vec::new(), + text: "hello".into(), + }, + command: None, + request_id: Some("req_123".into()), + quote_message_id: None, + }; + + adapter.handle_reply(&reply, &event_tx).await; + + let received = event_rx.try_recv(); + assert!(received.is_ok(), "expected GatewayResponse on event_tx"); + let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); + assert_eq!(resp.request_id, "req_123"); + assert!(resp.success); + assert_eq!(resp.message_id, Some("spaces/TEST/messages/msg_abc".into())); + } + + #[tokio::test] + async fn handle_reply_sends_failure_response_on_api_error() { + use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path_regex("/spaces/.*/messages")) + .respond_with(ResponseTemplate::new(500)) + .mount(&mock_server) + .await; + + let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); + let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); + adapter.api_base = mock_server.uri(); + + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "orig_msg".into(), + platform: "googlechat".into(), + channel: ReplyChannel { + id: "spaces/TEST".into(), + thread_id: None, + }, + content: Content { + content_type: "text".into(), + attachments: Vec::new(), + text: "hello".into(), + }, + command: None, + 
request_id: Some("req_fail".into()), + quote_message_id: None, + }; + + adapter.handle_reply(&reply, &event_tx).await; + + let received = event_rx.try_recv(); + assert!(received.is_ok(), "expected GatewayResponse on event_tx"); + let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); + assert_eq!(resp.request_id, "req_fail"); + assert!(!resp.success); + assert!(resp.message_id.is_none()); + let err = resp.error.expect("error should be set on send failure"); + assert!(err.contains("500"), "error should include status code, got: {}", err); + } + + #[tokio::test] + async fn handle_reply_empty_message_short_circuits() { + use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + let mock_server = MockServer::start().await; + // Mount a mock that would fail the test if called + Mock::given(method("POST")) + .and(path_regex("/spaces/.*/messages")) + .respond_with(ResponseTemplate::new(500)) + .expect(0) + .mount(&mock_server) + .await; + + let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); + let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); + adapter.api_base = mock_server.uri(); + + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "orig_msg".into(), + platform: "googlechat".into(), + channel: ReplyChannel { + id: "spaces/TEST".into(), + thread_id: None, + }, + content: Content { + content_type: "text".into(), + attachments: Vec::new(), + text: "".into(), + }, + command: None, + request_id: Some("req_empty".into()), + quote_message_id: None, + }; + + adapter.handle_reply(&reply, &event_tx).await; + + let received = event_rx.try_recv(); + assert!(received.is_ok(), "expected failure GatewayResponse for empty message"); + let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); + assert_eq!(resp.request_id, "req_empty"); + assert!(!resp.success); + assert_eq!(resp.error, Some("empty message".into())); 
+ } + + #[tokio::test] + async fn handle_reply_multi_chunk_failure_includes_error() { + use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path_regex("/spaces/.*/messages")) + .respond_with(ResponseTemplate::new(500)) + .mount(&mock_server) + .await; + + let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); + let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); + adapter.api_base = mock_server.uri(); + + let long_text = "x".repeat(5000); + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "orig_msg".into(), + platform: "googlechat".into(), + channel: ReplyChannel { + id: "spaces/TEST".into(), + thread_id: None, + }, + content: Content { + content_type: "text".into(), + attachments: Vec::new(), + text: long_text, + }, + command: None, + request_id: Some("req_multi_fail".into()), + quote_message_id: None, + }; + + adapter.handle_reply(&reply, &event_tx).await; + + let received = event_rx.try_recv(); + assert!(received.is_ok(), "expected GatewayResponse"); + let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); + assert_eq!(resp.request_id, "req_multi_fail"); + assert!(!resp.success); + assert!(resp.message_id.is_none()); + let err = resp.error.expect("multi-chunk failure should set error"); + assert!(err.contains("500")); + } + + #[tokio::test] + async fn handle_reply_token_failure_sends_error_response() { + let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); + let adapter = GoogleChatAdapter::new(None, None, None); + + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "orig_msg".into(), + platform: "googlechat".into(), + channel: ReplyChannel { + id: "spaces/TEST".into(), + thread_id: None, + }, + content: Content { + content_type: "text".into(), + attachments: Vec::new(), + text: 
"hello".into(), + }, + command: None, + request_id: Some("req_notoken".into()), + quote_message_id: None, + }; + + adapter.handle_reply(&reply, &event_tx).await; + + let received = event_rx.try_recv(); + assert!(received.is_ok(), "expected failure GatewayResponse"); + let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); + assert_eq!(resp.request_id, "req_notoken"); + assert!(!resp.success); + assert_eq!(resp.error, Some("no credentials configured".into())); + } + + #[tokio::test] + async fn handle_reply_edit_message_does_not_send_response() { + use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + let mock_server = MockServer::start().await; + Mock::given(method("PATCH")) + .and(path_regex("/spaces/.*/messages/.*")) + .respond_with(ResponseTemplate::new(200).set_body_json( + serde_json::json!({"name": "spaces/SP/messages/msg1"}), + )) + .mount(&mock_server) + .await; + + let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); + let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); + adapter.api_base = mock_server.uri(); + + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "spaces/SP/messages/msg1".into(), + platform: "googlechat".into(), + channel: ReplyChannel { + id: "spaces/SP".into(), + thread_id: None, + }, + content: Content { + content_type: "text".into(), + attachments: Vec::new(), + text: "updated text".into(), + }, + command: Some("edit_message".into()), + request_id: None, + quote_message_id: None, + }; + + adapter.handle_reply(&reply, &event_tx).await; + + let received = event_rx.try_recv(); + assert!(received.is_err()); + } + + #[tokio::test] + async fn handle_reply_multi_chunk_sends_gateway_response() { + use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + 
.and(path_regex("/spaces/.*/messages")) + .respond_with(ResponseTemplate::new(200).set_body_json( + serde_json::json!({"name": "spaces/TEST/messages/first_chunk"}), + )) + .mount(&mock_server) + .await; + + let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); + let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); + adapter.api_base = mock_server.uri(); + + let long_text = "x".repeat(5000); + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "orig_msg".into(), + platform: "googlechat".into(), + channel: ReplyChannel { + id: "spaces/TEST".into(), + thread_id: None, + }, + content: Content { + content_type: "text".into(), + attachments: Vec::new(), + text: long_text, + }, + command: None, + request_id: Some("req_multi".into()), + quote_message_id: None, + }; + + adapter.handle_reply(&reply, &event_tx).await; + + let received = event_rx.try_recv(); + assert!(received.is_ok(), "expected GatewayResponse for multi-chunk"); + let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); + assert_eq!(resp.request_id, "req_multi"); + assert!(resp.success); + assert_eq!(resp.message_id, Some("spaces/TEST/messages/first_chunk".into())); + } + + #[tokio::test] + async fn handle_reply_multi_chunk_partial_failure_reports_failure() { + // Mixed success/failure: chunk 1 succeeds, subsequent chunks fail. + // Expect success=false (any chunk failure marks overall as failed), + // but message_id is still set so core has a reference. 
+ use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + let mock_server = MockServer::start().await; + // First request: 200 OK with message name + Mock::given(method("POST")) + .and(path_regex("/spaces/.*/messages")) + .respond_with(ResponseTemplate::new(200).set_body_json( + serde_json::json!({"name": "spaces/TEST/messages/first_chunk"}), + )) + .up_to_n_times(1) + .mount(&mock_server) + .await; + // Subsequent requests: 500 + Mock::given(method("POST")) + .and(path_regex("/spaces/.*/messages")) + .respond_with(ResponseTemplate::new(500)) + .mount(&mock_server) + .await; + + let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); + let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); + adapter.api_base = mock_server.uri(); + + let long_text = "x".repeat(5000); + let reply = GatewayReply { + schema: "openab.gateway.reply.v1".into(), + reply_to: "orig_msg".into(), + platform: "googlechat".into(), + channel: ReplyChannel { + id: "spaces/TEST".into(), + thread_id: None, + }, + content: Content { + content_type: "text".into(), + attachments: Vec::new(), + text: long_text, + }, + command: None, + request_id: Some("req_partial".into()), + quote_message_id: None, + }; + + adapter.handle_reply(&reply, &event_tx).await; + + let received = event_rx.try_recv(); + assert!(received.is_ok(), "expected GatewayResponse"); + let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); + assert_eq!(resp.request_id, "req_partial"); + assert!(!resp.success, "partial failure must report success=false"); + assert_eq!(resp.message_id, Some("spaces/TEST/messages/first_chunk".into())); + let err = resp.error.expect("partial failure should set error"); + assert!(err.contains("500")); + } + + // --- Attachment parsing tests --- + + fn make_attachment( + source: &str, + content_type: &str, + content_name: &str, + resource_name: Option<&str>, + ) -> GoogleChatAttachment { + 
GoogleChatAttachment { + name: Some("spaces/SP/messages/MSG/attachments/ATT".into()), + content_name: Some(content_name.into()), + content_type: Some(content_type.into()), + source: Some(source.into()), + attachment_data_ref: resource_name.map(|rn| AttachmentDataRef { + resource_name: Some(rn.into()), + }), + drive_data_ref: None, + } + } + + #[test] + fn parse_attachments_image() { + let atts = vec![make_attachment( + "UPLOADED_CONTENT", + "image/png", + "photo.png", + Some("AATT_resource"), + )]; + let refs = parse_attachments(&atts); + assert_eq!(refs.len(), 1); + match &refs[0] { + GoogleChatMediaRef::Image { + resource_name, + content_name, + } => { + assert_eq!(resource_name, "AATT_resource"); + assert_eq!(content_name, "photo.png"); + } + other => panic!("expected Image, got {:?}", other), + } + } + + #[test] + fn parse_attachments_audio() { + let atts = vec![make_attachment( + "UPLOADED_CONTENT", + "audio/mp4", + "voice.m4a", + Some("AATT"), + )]; + let refs = parse_attachments(&atts); + assert!(matches!(refs[0], GoogleChatMediaRef::Audio { .. })); + } + + #[test] + fn parse_attachments_file() { + let atts = vec![make_attachment( + "UPLOADED_CONTENT", + "text/plain", + "notes.txt", + Some("AATT"), + )]; + let refs = parse_attachments(&atts); + assert!(matches!(refs[0], GoogleChatMediaRef::File { .. 
})); + } + + #[test] + fn parse_attachments_skips_drive() { + let atts = vec![GoogleChatAttachment { + name: Some("spaces/SP/messages/MSG/attachments/ATT".into()), + content_name: Some("doc".into()), + content_type: Some("application/vnd.google-apps.document".into()), + source: Some("DRIVE_FILE".into()), + attachment_data_ref: None, + drive_data_ref: Some(DriveDataRef { + drive_file_id: Some("drive_id_123".into()), + }), + }]; + assert_eq!(parse_attachments(&atts).len(), 0); + } + + #[test] + fn parse_attachments_skips_missing_resource_name() { + let atts = vec![make_attachment( + "UPLOADED_CONTENT", + "image/png", + "photo.png", + None, + )]; + assert_eq!(parse_attachments(&atts).len(), 0); + } + + #[test] + fn media_url_preserves_slashes_and_encodes_specials() { + let url = media_url("https://chat.googleapis.com/v1", "spaces/SP/messages/MSG/attachments/ATT"); + assert_eq!( + url, + "https://chat.googleapis.com/v1/media/spaces/SP/messages/MSG/attachments/ATT?alt=media" + ); + let url2 = media_url("https://chat.googleapis.com/v1", "AATT/some+resource=name"); + assert_eq!( + url2, + "https://chat.googleapis.com/v1/media/AATT/some%2Bresource%3Dname?alt=media" + ); + } + + #[tokio::test] + async fn download_googlechat_image_resizes_and_returns_attachment() { + use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + // Generate a small valid PNG + let img = image::RgbImage::from_pixel(10, 10, image::Rgb([255, 0, 0])); + let mut buf = std::io::Cursor::new(Vec::new()); + image::DynamicImage::ImageRgb8(img) + .write_to(&mut buf, image::ImageFormat::Png) + .unwrap(); + let png_bytes = buf.into_inner(); + + let mock_server = MockServer::start().await; + Mock::given(method("GET")) + .and(path_regex("/media/.*")) + .respond_with( + ResponseTemplate::new(200) + .set_body_bytes(png_bytes) + .insert_header("content-type", "image/png"), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let result = 
download_googlechat_image( + &client, + "fake-token", + &mock_server.uri(), + "AATT_resource", + "photo.png", + ) + .await; + let att = result.expect("expected successful download"); + assert_eq!(att.attachment_type, "image"); + assert_eq!(att.filename, "photo.png"); + assert_eq!(att.mime_type, "image/jpeg"); // resized PNG → JPEG + assert!(att.path.is_some()); // stored to filesystem + assert!(att.size > 0); + } + + #[tokio::test] + async fn download_googlechat_file_rejects_non_text_extension() { + let client = reqwest::Client::new(); + let result = download_googlechat_file( + &client, + "fake-token", + "https://unused", // not called for non-text + "AATT", + "binary.exe", + TEXT_TOTAL_CAP, + ) + .await; + assert!(result.is_none(), "non-text extensions must be skipped"); + } + + #[tokio::test] + async fn download_googlechat_file_text_extension_succeeds() { + use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + let mock_server = MockServer::start().await; + Mock::given(method("GET")) + .and(path_regex("/media/.*")) + .respond_with( + ResponseTemplate::new(200).set_body_bytes(b"hello world".to_vec()), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let result = download_googlechat_file( + &client, + "fake-token", + &mock_server.uri(), + "AATT", + "notes.txt", + TEXT_TOTAL_CAP, + ) + .await; + let att = result.expect("expected successful download"); + assert_eq!(att.attachment_type, "text_file"); + assert_eq!(att.filename, "notes.txt"); + assert_eq!(att.mime_type, "text/plain"); + } + + #[tokio::test] + async fn download_googlechat_audio_returns_attachment() { + use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + let mock_server = MockServer::start().await; + let audio_bytes = vec![0u8; 1024]; + Mock::given(method("GET")) + .and(path_regex("/media/.*")) + .respond_with(ResponseTemplate::new(200).set_body_bytes(audio_bytes.clone())) + 
.mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let result = download_googlechat_audio( + &client, + "fake-token", + &mock_server.uri(), + "AATT", + "voice.m4a", + "audio/mp4", + ) + .await; + let att = result.expect("expected successful download"); + assert_eq!(att.attachment_type, "audio"); + assert_eq!(att.filename, "voice.m4a"); + assert_eq!(att.mime_type, "audio/mp4"); + assert_eq!(att.size, 1024); + } + + #[tokio::test] + async fn download_googlechat_image_rejects_oversized_content_length() { + use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::matchers::{method, path_regex}; + + let mock_server = MockServer::start().await; + Mock::given(method("GET")) + .and(path_regex("/media/.*")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("content-length", "20000000") // 20 MB > 10 MB limit + .set_body_bytes(vec![0u8; 100]), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let result = download_googlechat_image( + &client, + "fake-token", + &mock_server.uri(), + "AATT", + "huge.png", + ) + .await; + assert!(result.is_none(), "oversized image must be rejected"); + } + + #[test] + fn parses_http_endpoint_url_top_level_envelope() { + let envelope: GoogleChatEnvelope = serde_json::from_value(serde_json::json!({ + "message": { + "name": "spaces/AAAA/messages/BBBB", + "text": "hello", + "attachment": [] + }, + "user": { + "name": "users/123", + "displayName": "Test User", + "type": "HUMAN" + }, + "space": { + "name": "spaces/AAAA", + "type": "DM" + } + })) + .unwrap(); + assert!(envelope.chat.is_none()); + assert!(envelope.message.is_some()); + assert_eq!(envelope.message.unwrap().name, "spaces/AAAA/messages/BBBB"); + assert!(envelope.user.is_some()); + assert_eq!(envelope.user.unwrap().name, "users/123"); + assert!(envelope.space.is_some()); + assert_eq!(envelope.space.unwrap().name, "spaces/AAAA"); + } +} diff --git a/crates/openab-gateway/src/adapters/line.rs 
b/crates/openab-gateway/src/adapters/line.rs
new file mode 100644
index 000000000..1323d2605
--- /dev/null
+++ b/crates/openab-gateway/src/adapters/line.rs
@@ -0,0 +1,780 @@
use crate::media::{resize_and_compress, IMAGE_MAX_DOWNLOAD};
use crate::schema::*;
use crate::store;
use axum::extract::State;
use serde::Deserialize;
use std::sync::Arc;
use tracing::{error, info, warn};

// --- LINE types ---

/// Body of one LINE webhook delivery: a batch of events processed in order.
#[derive(Debug, Deserialize)]
pub struct LineWebhookBody {
    events: Vec<LineEvent>,
}

/// A single webhook event. Only events whose `type` is `"message"` are
/// turned into gateway events; everything else is skipped.
#[derive(Debug, Deserialize)]
struct LineEvent {
    #[serde(rename = "type")]
    event_type: String,
    source: Option<LineSource>,
    message: Option<LineMessage>,
    /// Token cached per gateway event so a later reply can use the Reply API.
    #[serde(rename = "replyToken")]
    reply_token: Option<String>,
}

/// Where an event came from. `source_type` is matched against `"group"` and
/// `"room"`; any other value is treated as a 1:1 user source.
#[derive(Debug, Deserialize)]
struct LineSource {
    #[serde(rename = "type")]
    source_type: String,
    #[serde(rename = "userId")]
    user_id: Option<String>,
    #[serde(rename = "groupId")]
    group_id: Option<String>,
    #[serde(rename = "roomId")]
    room_id: Option<String>,
}

/// Message payload of an event. Only `"text"` and `"image"` message types
/// are handled by this adapter.
#[derive(Debug, Deserialize)]
struct LineMessage {
    id: String,
    #[serde(rename = "type")]
    message_type: String,
    text: Option<String>,
    #[serde(rename = "contentProvider")]
    content_provider: Option<LineContentProvider>,
    mention: Option<LineMention>,
}

/// Mention metadata attached to a text message.
#[derive(Debug, Deserialize)]
struct LineMention {
    mentionees: Vec<LineMentionee>,
}

/// One mentioned account. `is_self` marks the mentionee that is the bot
/// itself and drives @mention gating in groups and rooms.
#[derive(Debug, Deserialize)]
struct LineMentionee {
    #[serde(rename = "userId")]
    user_id: Option<String>,
    #[serde(rename = "isSelf", default)]
    is_self: bool,
}

/// Describes who hosts an image's binary content. A `provider_type` of
/// `"external"` is logged as unsupported; anything else is downloaded from
/// the LINE content API.
#[derive(Debug, Deserialize)]
struct LineContentProvider {
    #[serde(rename = "type")]
    provider_type: String,
    #[serde(rename = "originalContentUrl")]
    original_content_url: Option<String>,
}

/// Base URL for LINE Messaging API. Overridden in tests via the `api_base` parameter.
pub const LINE_API_BASE: &str = "https://api.line.me";
/// Base URL for LINE binary content download API.
+pub const LINE_DATA_API_BASE: &str = "https://api-data.line.me"; + +// --- Webhook handler --- + +pub async fn webhook( + State(state): State>, + headers: axum::http::HeaderMap, + body: axum::body::Bytes, +) -> axum::http::StatusCode { + // Validate X-Line-Signature + if let Some(ref channel_secret) = state.line_channel_secret { + use base64::Engine; + use hmac::{Hmac, Mac}; + use sha2::Sha256; + + let signature = headers + .get("x-line-signature") + .and_then(|v| v.to_str().ok()); + let Some(signature) = signature else { + warn!("LINE webhook rejected: missing X-Line-Signature"); + return axum::http::StatusCode::UNAUTHORIZED; + }; + + let mut mac = Hmac::::new_from_slice(channel_secret.as_bytes()).expect("HMAC key"); + mac.update(&body); + let expected = + base64::engine::general_purpose::STANDARD.encode(mac.finalize().into_bytes()); + if signature != expected { + warn!("LINE webhook rejected: invalid signature"); + return axum::http::StatusCode::UNAUTHORIZED; + } + } + + let webhook_body: LineWebhookBody = match serde_json::from_slice(&body) { + Ok(b) => b, + Err(e) => { + warn!("LINE webhook parse error: {e}"); + return axum::http::StatusCode::BAD_REQUEST; + } + }; + + let webhook_received_at = std::time::Instant::now(); + let background_state = state.clone(); + let permit = match background_state + .line_webhook_semaphore + .clone() + .acquire_owned() + .await + { + Ok(permit) => permit, + Err(_) => { + warn!("LINE webhook worker semaphore closed unexpectedly"); + return axum::http::StatusCode::SERVICE_UNAVAILABLE; + } + }; + tokio::spawn(async move { + let _permit = permit; + process_line_webhook_events(background_state, webhook_body, webhook_received_at).await; + }); + + axum::http::StatusCode::OK +} + +async fn process_line_webhook_events( + state: Arc, + webhook_body: LineWebhookBody, + webhook_received_at: std::time::Instant, +) { + // Acknowledge the webhook before image download/processing so LINE does not + // redeliver solely because gateway-side 
attachment work is slow. We keep one + // task per webhook payload so events from the same payload preserve order. + // + // Tradeoff: + // - Pros: lowers webhook latency and reduces redelivery pressure from LINE. + // - Cons: once 200 OK is returned, a later crash/task failure will not be + // retried by LINE. This PR intentionally keeps scope small and does not add + // background-task durability or duplicate suppression on top of early-ack. + // - Cons: an earlier image event from one webhook payload can also be emitted + // after a later text event from another payload if the image path is slower. + // - Guardrail: a shared semaphore bounds how many LINE payloads can enter the + // post-ack path concurrently. When saturated, new webhooks wait for capacity + // before spawning background work so bursts do not create unbounded backlog. + for event in webhook_body.events { + let Some(gateway_event) = build_gateway_event_from_line_event( + &event, + &state.client, + state.line_access_token.as_deref(), + LINE_DATA_API_BASE, + ) + .await + else { + continue; + }; + + // Cache before broadcasting the event. Once event_tx.send() fires, OAB + // may reply immediately; inserting afterward can silently force Push API. + // We still use webhook receipt time so TTL reflects true reply-token age. 
+ if let Some(ref reply_token) = event.reply_token { + let mut cache = state + .reply_token_cache + .lock() + .unwrap_or_else(|e| e.into_inner()); + if cache.len() >= crate::REPLY_TOKEN_CACHE_MAX { + warn!( + size = cache.len(), + "reply token cache full, skipping insert" + ); + } else { + cache.insert( + gateway_event.event_id.clone(), + (reply_token.clone(), webhook_received_at), + ); + info!(event_id = %gateway_event.event_id, "cached LINE replyToken"); + } + } + + let json = serde_json::to_string(&gateway_event).unwrap(); + info!(channel = %gateway_event.channel.id, sender = %gateway_event.sender.id, "line → gateway"); + let _ = state.event_tx.send(json); + } +} + +fn sanitize_line_external_url_for_log(url: &str) -> String { + reqwest::Url::parse(url) + .ok() + .and_then(|parsed| parsed.host_str().map(str::to_owned)) + .unwrap_or_else(|| "invalid-or-missing-host".to_string()) +} + +async fn build_gateway_event_from_line_event( + event: &LineEvent, + client: &reqwest::Client, + line_access_token: Option<&str>, + data_api_base: &str, +) -> Option { + if event.event_type != "message" { + return None; + } + + let msg = event.message.as_ref()?; + if msg.message_type != "text" && msg.message_type != "image" { + return None; + } + + let text = msg.text.as_deref().unwrap_or(""); + let mut attachments = Vec::new(); + + if msg.message_type == "image" { + match msg + .content_provider + .as_ref() + .map(|provider| provider.provider_type.as_str()) + { + Some("external") => { + let original = msg + .content_provider + .as_ref() + .and_then(|provider| provider.original_content_url.as_deref()) + .unwrap_or("unknown"); + warn!( + message_id = %msg.id, + external_content_host = %sanitize_line_external_url_for_log(original), + "LINE external image content is not supported yet" + ); + } + _ => { + if let Some(access_token) = line_access_token { + if let Some(attachment) = + download_line_image(client, access_token, &msg.id, data_api_base).await + { + attachments.push(attachment); 
+ } + } else { + warn!(message_id = %msg.id, "LINE image received but LINE_CHANNEL_ACCESS_TOKEN is not configured"); + } + } + } + } + + // Do not synthesize placeholder text for failed/unsupported image downloads. + // Core treats content.text as the user's prompt, so a fake marker would create + // a misleading turn instead of preserving the actual image content. + let event_text = text; + + if msg.message_type == "image" && event_text.trim().is_empty() && attachments.is_empty() { + info!( + message_id = %msg.id, + "LINE image event produced no attachment; skipping without synthesizing placeholder text" + ); + } + + if event_text.trim().is_empty() && attachments.is_empty() { + return None; + } + + let source = event.source.as_ref(); + let (channel_id, channel_type) = match source { + Some(s) if s.source_type == "group" => match s.group_id.as_deref() { + Some(id) if !id.is_empty() => (id.to_string(), "group".to_string()), + _ => { + warn!("LINE group event missing groupId, skipping"); + return None; + } + }, + Some(s) if s.source_type == "room" => match s.room_id.as_deref() { + Some(id) if !id.is_empty() => (id.to_string(), "room".to_string()), + _ => { + warn!("LINE room event missing roomId, skipping"); + return None; + } + }, + Some(s) => match s.user_id.as_deref() { + Some(id) if !id.is_empty() => (id.to_string(), "user".to_string()), + _ => { + warn!("LINE user event missing userId, skipping"); + return None; + } + }, + None => { + warn!("LINE event missing source, skipping"); + return None; + } + }; + let user_id = source + .and_then(|s| s.user_id.as_deref()) + .unwrap_or("unknown"); + + // Extract mentioned user IDs from the LINE webhook mention object. + // LINE populates this in group/room text messages when users are @-mentioned. 
+ let mentionees = msg + .mention + .as_ref() + .map(|m| m.mentionees.as_slice()) + .unwrap_or_default(); + let mention_ids: Vec = mentionees + .iter() + .filter_map(|m| m.user_id.clone()) + .collect(); + + // @mention gating: in groups/rooms, only forward the event if the bot is mentioned. + // LINE sets isSelf=true on the mentionee that is the bot itself — no env var needed. + // 1:1 DMs always pass through. + let is_group = channel_type == "group" || channel_type == "room"; + if is_group && !mentionees.iter().any(|m| m.is_self) { + info!( + channel = %channel_id, + "line group message dropped (@mention gating: bot not mentioned)" + ); + return None; + } + + let mut gateway_event = GatewayEvent::new( + "line", + ChannelInfo { + id: channel_id, + channel_type, + thread_id: None, + }, + SenderInfo { + id: user_id.into(), + name: user_id.into(), + display_name: user_id.into(), + is_bot: false, + }, + event_text, + &msg.id, + mention_ids, + ); + gateway_event.content.attachments = attachments; + Some(gateway_event) +} + +pub async fn download_line_image( + client: &reqwest::Client, + access_token: &str, + message_id: &str, + api_base: &str, +) -> Option { + let mut resp = match client + .get(format!( + "{}/v2/bot/message/{}/content", + api_base, message_id + )) + .bearer_auth(access_token) + .send() + .await + { + Ok(resp) => resp, + Err(e) => { + warn!(message_id, error = %e, "LINE image download failed"); + return None; + } + }; + + if !resp.status().is_success() { + warn!(message_id, status = %resp.status(), "LINE image download failed"); + return None; + } + + if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { + if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { + if size > IMAGE_MAX_DOWNLOAD { + warn!(message_id, size, "LINE image Content-Length exceeds limit"); + return None; + } + } + } + + let mut body = Vec::new(); + loop { + let chunk = match resp.chunk().await { + Ok(Some(chunk)) => chunk, + Ok(None) => break, + Err(e) => { + 
warn!(message_id, error = %e, "LINE image download failed while reading body"); + return None; + } + }; + body.extend_from_slice(&chunk); + if body.len() as u64 > IMAGE_MAX_DOWNLOAD { + warn!(message_id, size = body.len(), "LINE image exceeds limit"); + return None; + } + } + + let (compressed, mime) = + match tokio::task::spawn_blocking(move || resize_and_compress(&body)).await { + Ok(Ok(v)) => v, + Ok(Err(e)) => { + warn!(message_id, error = %e, "LINE image resize/compress failed"); + return None; + } + Err(e) => { + warn!(message_id, error = %e, "LINE image processing task failed"); + return None; + } + }; + let path = store::store_media(&compressed).await?; + let ext = if mime == "image/gif" { "gif" } else { "jpg" }; + Some(Attachment { + attachment_type: "image".into(), + filename: format!("line_{}.{}", message_id, ext), + mime_type: mime, + data: String::new(), + size: compressed.len() as u64, + path: Some(path), + }) +} + +// --- Reply handler (hybrid Reply/Push dispatch) --- + +/// Dispatch a reply to LINE using the hybrid Reply/Push strategy. +/// +/// Returns `true` if Reply API was used (or assumed used), `false` if Push API was used. 
+pub async fn dispatch_line_reply( + client: &reqwest::Client, + access_token: &str, + reply_cache: &crate::ReplyTokenCache, + reply: &GatewayReply, + api_base: &str, +) -> bool { + if matches!( + reply.command.as_deref(), + Some("add_reaction") | Some("remove_reaction") | Some("create_topic") + ) { + info!(command = ?reply.command.as_deref(), "line: ignoring unsupported command"); + return false; + } + + // Extract token from cache (drop lock before HTTP call) + let cached_token = { + let mut cache = reply_cache.lock().unwrap_or_else(|e| e.into_inner()); + cache + .remove(&reply.reply_to) + .and_then(|(token, cached_at)| { + if cached_at.elapsed().as_secs() < crate::REPLY_TOKEN_TTL_SECS { + Some(token) + } else { + info!("LINE replyToken expired, using Push API"); + None + } + }) + }; + + // Try Reply API first (free, no quota consumed) + let mut used_reply = false; + if let Some(reply_token) = cached_token { + info!(to = %reply.channel.id, "gateway → line (reply API)"); + let resp = client + .post(format!("{}/v2/bot/message/reply", api_base)) + .bearer_auth(access_token) + .json(&serde_json::json!({ + "replyToken": reply_token, + "messages": [{"type": "text", "text": reply.content.text}] + })) + .send() + .await; + match resp { + Ok(r) if r.status().is_success() => { + used_reply = true; + } + Ok(r) => { + let status = r.status(); + let body = r.text().await.unwrap_or_default(); + let body_lower = body.to_lowercase(); + let token_unusable = status.as_u16() == 400 + && ((body_lower.contains("invalid") && body_lower.contains("reply token")) + || body_lower.contains("expired")); + if token_unusable { + warn!(status = %status, body = %body, "LINE reply token unusable, falling back to Push"); + } else { + error!(status = %status, body = %body, "LINE Reply API error, NOT falling back to Push (possible duplicate risk)"); + used_reply = true; + } + } + Err(e) => { + error!(err = %e, "LINE Reply API network error, NOT falling back to Push (possible duplicate risk)"); + 
used_reply = true; + } + } + } + + // Fallback to Push API + if !used_reply { + info!(to = %reply.channel.id, "gateway → line (push API)"); + let _ = client + .post(format!("{}/v2/bot/message/push", api_base)) + .bearer_auth(access_token) + .json(&serde_json::json!({ + "to": reply.channel.id, + "messages": [{"type": "text", "text": reply.content.text}] + })) + .send() + .await + .map_err(|e| error!("line push error: {e}")); + } + + used_reply +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::extract::State; + use std::collections::HashMap; + use std::sync::Arc; + use tokio::sync::{broadcast, Mutex, Semaphore}; + use wiremock::matchers::{header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + #[tokio::test] + async fn download_line_image_resizes_and_returns_attachment() { + let server = MockServer::start().await; + let img = image::RgbImage::from_pixel(16, 16, image::Rgb([0, 128, 255])); + let mut buf = std::io::Cursor::new(Vec::new()); + img.write_to(&mut buf, image::ImageFormat::Png).unwrap(); + let bytes = buf.into_inner(); + + let _mock = Mock::given(method("GET")) + .and(path("/v2/bot/message/msg123/content")) + .and(header("authorization", "Bearer line_token")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("content-type", "image/png") + .set_body_bytes(bytes), + ) + .mount_as_scoped(&server) + .await; + + let attachment = download_line_image( + &reqwest::Client::new(), + "line_token", + "msg123", + &server.uri(), + ) + .await + .expect("attachment should be downloaded"); + + assert_eq!(attachment.attachment_type, "image"); + assert!(attachment.filename.starts_with("line_msg123.")); + assert!(attachment.path.is_some()); + assert!(attachment.size > 0); + + let path = attachment.path.unwrap(); + let stored = tokio::fs::read(&path).await.unwrap(); + assert!(!stored.is_empty()); + let _ = tokio::fs::remove_file(path).await; + } + + #[tokio::test] + async fn 
build_gateway_event_from_line_image_attaches_downloaded_image() { + let server = MockServer::start().await; + let img = image::RgbImage::from_pixel(8, 8, image::Rgb([255, 0, 0])); + let mut buf = std::io::Cursor::new(Vec::new()); + img.write_to(&mut buf, image::ImageFormat::Png).unwrap(); + let bytes = buf.into_inner(); + + let _mock = Mock::given(method("GET")) + .and(path("/v2/bot/message/msg_image/content")) + .and(header("authorization", "Bearer line_token")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("content-type", "image/png") + .set_body_bytes(bytes), + ) + .mount_as_scoped(&server) + .await; + + let event: LineEvent = serde_json::from_value(serde_json::json!({ + "type": "message", + "replyToken": "reply123", + "source": {"type": "user", "userId": "U123"}, + "message": { + "id": "msg_image", + "type": "image", + "contentProvider": {"type": "line"} + } + })) + .unwrap(); + + let gateway_event = build_gateway_event_from_line_event( + &event, + &reqwest::Client::new(), + Some("line_token"), + &server.uri(), + ) + .await + .expect("image event should produce a gateway event"); + + assert_eq!(gateway_event.platform, "line"); + assert_eq!(gateway_event.content.text, ""); + assert_eq!(gateway_event.content.attachments.len(), 1); + + let path = gateway_event.content.attachments[0] + .path + .clone() + .expect("path should be stored"); + let _ = tokio::fs::remove_file(path).await; + } + + #[tokio::test] + async fn download_line_image_rejects_oversized_content_length() { + let server = MockServer::start().await; + + let _mock = Mock::given(method("GET")) + .and(path("/v2/bot/message/msg_big/content")) + .and(header("authorization", "Bearer line_token")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("content-type", "image/png") + .insert_header("content-length", (IMAGE_MAX_DOWNLOAD + 1).to_string()) + .set_body_bytes(vec![0u8; IMAGE_MAX_DOWNLOAD as usize + 1]), + ) + .mount_as_scoped(&server) + .await; + + let attachment = 
download_line_image( + &reqwest::Client::new(), + "line_token", + "msg_big", + &server.uri(), + ) + .await; + + assert!(attachment.is_none()); + } + + #[tokio::test] + async fn webhook_acknowledges_before_async_event_forwarding() { + let (event_tx, mut event_rx) = broadcast::channel::(8); + let state = Arc::new(crate::AppState { + telegram_bot_token: None, + telegram_secret_token: None, + telegram_rich_messages: false, + line_channel_secret: None, + line_access_token: None, + teams: None, + teams_service_urls: Mutex::new(HashMap::new()), + feishu: None, + google_chat: None, + wecom: None, + ws_token: None, + event_tx, + reply_token_cache: Arc::new(std::sync::Mutex::new(HashMap::new())), + line_webhook_semaphore: Arc::new(Semaphore::new(crate::LINE_WEBHOOK_CONCURRENCY_MAX)), + client: reqwest::Client::new(), + }); + + let body = axum::body::Bytes::from( + serde_json::json!({ + "events": [{ + "type": "message", + "replyToken": "reply123", + "source": {"type": "user", "userId": "U123"}, + "message": {"id": "msg123", "type": "text", "text": "hello"} + }] + }) + .to_string(), + ); + + let status = webhook(State(state.clone()), axum::http::HeaderMap::new(), body).await; + assert_eq!(status, axum::http::StatusCode::OK); + + let event_json = tokio::time::timeout(std::time::Duration::from_secs(1), event_rx.recv()) + .await + .expect("background task should forward an event") + .expect("broadcast should succeed"); + let event: GatewayEvent = serde_json::from_str(&event_json).expect("valid gateway event"); + + assert_eq!(event.message_id, "msg123"); + assert_eq!(event.content.text, "hello"); + + let cache = state + .reply_token_cache + .lock() + .unwrap_or_else(|e| e.into_inner()); + let (token, cached_at) = cache + .get(&event.event_id) + .expect("reply token should be cached"); + assert_eq!(token, "reply123"); + assert!(cached_at.elapsed() < std::time::Duration::from_secs(1)); + } + + // --- @mention gating tests --- + + fn make_group_text_event(text: &str, bot_mentioned: 
bool) -> LineEvent { + let mention = if bot_mentioned { + serde_json::json!({"mentionees": [{"userId": "Ubot123", "type": "user", "isSelf": true}]}) + } else { + serde_json::json!({"mentionees": [{"userId": "Uother", "type": "user", "isSelf": false}]}) + }; + serde_json::from_value(serde_json::json!({ + "type": "message", + "source": {"type": "group", "groupId": "C001", "userId": "U_sender"}, + "message": { + "id": "msg001", + "type": "text", + "text": text, + "mention": mention + } + })) + .unwrap() + } + + #[tokio::test] + async fn group_message_passes_when_bot_mentioned() { + let event = make_group_text_event("@Bot hello", true); + let result = build_gateway_event_from_line_event( + &event, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + ) + .await; + assert!(result.is_some()); + let gw = result.unwrap(); + assert_eq!(gw.mentions, vec!["Ubot123"]); + } + + #[tokio::test] + async fn group_message_dropped_when_bot_not_mentioned() { + let event = make_group_text_event("hey everyone", false); + let result = build_gateway_event_from_line_event( + &event, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + ) + .await; + assert!(result.is_none()); + } + + #[tokio::test] + async fn group_message_dropped_when_no_mention_at_all() { + let event: LineEvent = serde_json::from_value(serde_json::json!({ + "type": "message", + "source": {"type": "group", "groupId": "C001", "userId": "U_sender"}, + "message": {"id": "msg001", "type": "text", "text": "plain message no mention"} + })) + .unwrap(); + let result = build_gateway_event_from_line_event( + &event, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + ) + .await; + assert!(result.is_none()); + } + + #[tokio::test] + async fn dm_passes_even_without_mention() { + let event: LineEvent = serde_json::from_value(serde_json::json!({ + "type": "message", + "source": {"type": "user", "userId": "U_human"}, + "message": {"id": "msg002", "type": "text", "text": "hello bot"} + })) + .unwrap(); + let result = 
build_gateway_event_from_line_event( + &event, + &reqwest::Client::new(), + None, + LINE_DATA_API_BASE, + ) + .await; + assert!(result.is_some()); + } +} diff --git a/crates/openab-gateway/src/adapters/mod.rs b/crates/openab-gateway/src/adapters/mod.rs new file mode 100644 index 000000000..cbacad507 --- /dev/null +++ b/crates/openab-gateway/src/adapters/mod.rs @@ -0,0 +1,12 @@ +#[cfg(feature = "telegram")] +pub mod telegram; +#[cfg(feature = "line")] +pub mod line; +#[cfg(feature = "feishu")] +pub mod feishu; +#[cfg(feature = "googlechat")] +pub mod googlechat; +#[cfg(feature = "wecom")] +pub mod wecom; +#[cfg(feature = "teams")] +pub mod teams; diff --git a/crates/openab-gateway/src/adapters/teams.rs b/crates/openab-gateway/src/adapters/teams.rs new file mode 100644 index 000000000..09ac09df8 --- /dev/null +++ b/crates/openab-gateway/src/adapters/teams.rs @@ -0,0 +1,877 @@ +use crate::schema::*; +use axum::extract::State; +use axum::http::{HeaderMap, StatusCode}; +use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; +use serde::Deserialize; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{debug, error, info, warn}; + +// --- Bot Framework activity types --- + +#[allow(dead_code)] // Bot Framework schema fields — needed for future features +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Activity { + #[serde(rename = "type")] + pub activity_type: String, + pub id: Option, + pub timestamp: Option, + pub service_url: Option, + pub channel_id: Option, + pub from: Option, + pub conversation: Option, + pub text: Option, + pub tenant: Option, + pub channel_data: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ChannelAccount { + pub id: Option, + pub name: Option, + pub aad_object_id: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ConversationAccount { + pub id: Option, + pub 
conversation_type: Option, + pub is_group: Option, + pub tenant_id: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TenantInfo { + pub id: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ChannelData { + pub tenant: Option, +} + +impl Activity { + /// Resolve tenant id from any of the locations Teams may put it. + pub fn resolved_tenant_id(&self) -> Option<&str> { + self.tenant + .as_ref() + .and_then(|t| t.id.as_deref()) + .or_else(|| { + self.channel_data + .as_ref() + .and_then(|c| c.tenant.as_ref()) + .and_then(|t| t.id.as_deref()) + }) + .or_else(|| { + self.conversation + .as_ref() + .and_then(|c| c.tenant_id.as_deref()) + }) + } +} + +// --- OpenID configuration --- + +#[derive(Debug, Deserialize)] +struct OpenIdConfig { + jwks_uri: String, +} + +#[derive(Debug, Deserialize)] +struct JwksResponse { + keys: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +struct JwkKey { + kid: Option, + n: String, + e: String, + kty: String, + #[serde(default)] + endorsements: Vec, +} + +// --- OAuth token --- + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + expires_in: u64, +} + +struct CachedToken { + token: String, + expires_at: std::time::Instant, +} + +// --- Teams adapter config --- + +pub struct TeamsConfig { + pub app_id: String, + pub app_secret: String, + pub oauth_endpoint: String, + pub openid_metadata: String, + pub allowed_tenants: Vec, +} + +impl TeamsConfig { + pub fn from_env() -> Option { + let app_id = std::env::var("TEAMS_APP_ID").ok()?; + let app_secret = std::env::var("TEAMS_APP_SECRET").ok()?; + Some(Self { + app_id, + app_secret, + oauth_endpoint: std::env::var("TEAMS_OAUTH_ENDPOINT").unwrap_or_else(|_| { + "https://login.microsoftonline.com/botframework.com/oauth2/v2.0/token".into() + }), + openid_metadata: std::env::var("TEAMS_OPENID_METADATA").unwrap_or_else(|_| { + 
"https://login.botframework.com/v1/.well-known/openidconfiguration".into() + }), + allowed_tenants: std::env::var("TEAMS_ALLOWED_TENANTS") + .unwrap_or_default() + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(), + }) + } +} + +// --- Teams adapter state --- + +pub struct TeamsAdapter { + config: TeamsConfig, + client: reqwest::Client, + token_cache: RwLock>, + jwks_cache: RwLock, std::time::Instant)>>, +} + +const JWKS_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(3600); +const TOKEN_REFRESH_MARGIN: std::time::Duration = std::time::Duration::from_secs(300); + +impl TeamsAdapter { + pub fn new(config: TeamsConfig) -> Self { + Self { + config, + client: reqwest::Client::new(), + token_cache: RwLock::new(None), + jwks_cache: RwLock::new(None), + } + } + + /// Get a valid OAuth bearer token, refreshing if needed. + async fn get_token(&self) -> anyhow::Result { + // Check cache + { + let cache = self.token_cache.read().await; + if let Some(ref cached) = *cache { + if cached.expires_at > std::time::Instant::now() + TOKEN_REFRESH_MARGIN { + return Ok(cached.token.clone()); + } + } + } + + // Fetch new token + let resp: TokenResponse = self + .client + .post(&self.config.oauth_endpoint) + .form(&[ + ("grant_type", "client_credentials"), + ("client_id", &self.config.app_id), + ("client_secret", &self.config.app_secret), + ("scope", "https://api.botframework.com/.default"), + ]) + .send() + .await? + .json() + .await?; + + let token = resp.access_token.clone(); + *self.token_cache.write().await = Some(CachedToken { + token: resp.access_token, + expires_at: std::time::Instant::now() + std::time::Duration::from_secs(resp.expires_in), + }); + info!("teams OAuth token refreshed"); + Ok(token) + } + + /// Fetch and cache JWKS signing keys from Microsoft's OpenID metadata. 
+ async fn get_jwks(&self) -> anyhow::Result> { + { + let cache = self.jwks_cache.read().await; + if let Some((ref keys, fetched_at)) = *cache { + if fetched_at.elapsed() < JWKS_CACHE_TTL { + return Ok(keys.clone()); + } + } + } + + let config: OpenIdConfig = self + .client + .get(&self.config.openid_metadata) + .send() + .await? + .json() + .await?; + + let jwks: JwksResponse = self + .client + .get(&config.jwks_uri) + .send() + .await? + .json() + .await?; + + let keys = jwks.keys; + *self.jwks_cache.write().await = Some((keys.clone(), std::time::Instant::now())); + info!(count = keys.len(), "teams JWKS keys refreshed"); + Ok(keys) + } + + /// Force-refresh JWKS keys, bypassing cache TTL. Called on cache miss (kid not found). + async fn refresh_jwks(&self) -> anyhow::Result> { + // Invalidate cache so get_jwks fetches fresh + *self.jwks_cache.write().await = None; + self.get_jwks().await + } + + /// Validate the JWT bearer token from an inbound Bot Framework request. + /// Checks: signature, issuer, audience, expiry, serviceUrl claim, and channel endorsements. + pub async fn validate_jwt(&self, auth_header: &str, activity: &Activity) -> anyhow::Result<()> { + let token = auth_header + .strip_prefix("Bearer ") + .ok_or_else(|| anyhow::anyhow!("missing Bearer prefix"))?; + + // Decode header to get kid + let header = jsonwebtoken::decode_header(token)?; + let kid = header + .kid + .ok_or_else(|| anyhow::anyhow!("no kid in JWT header"))?; + + let keys = self.get_jwks().await?; + let key = match keys.iter().find(|k| k.kid.as_deref() == Some(&kid)) { + Some(k) => k.clone(), + None => { + // Cache miss: Microsoft may have rotated keys. Force refresh and retry. + let refreshed = self.refresh_jwks().await?; + refreshed + .into_iter() + .find(|k| k.kid.as_deref() == Some(&kid)) + .ok_or_else(|| anyhow::anyhow!("no matching JWK for kid={kid} after refresh"))? 
+ } + }; + + if key.kty != "RSA" { + anyhow::bail!("unsupported key type: {}", key.kty); + } + + // B2: Validate channel endorsements — key must endorse the activity's channelId + let channel_id = activity.channel_id.as_deref() + .ok_or_else(|| anyhow::anyhow!("activity missing channelId"))?; + if key.endorsements.is_empty() { + anyhow::bail!("JWK has no endorsements — cannot verify channelId={channel_id}"); + } + if !key.endorsements.iter().any(|e| e == channel_id) { + anyhow::bail!( + "JWK endorsements {:?} do not include channelId={channel_id}", + key.endorsements + ); + } + + let decoding_key = DecodingKey::from_rsa_components(&key.n, &key.e)?; + let mut validation = Validation::new(Algorithm::RS256); + validation.set_audience(&[&self.config.app_id]); + // Bot Framework tokens can use RS256 or RS384 + validation.algorithms = vec![Algorithm::RS256, Algorithm::RS384]; + // Bot Framework issuer per auth spec + validation.set_issuer(&["https://api.botframework.com"]); + validation.validate_aud = true; + validation.validate_exp = true; + validation.validate_nbf = false; + + let token_data = decode::(token, &decoding_key, &validation)?; + + // B1: Validate serviceUrl claim matches activity's serviceUrl + let activity_service_url = activity.service_url.as_deref() + .ok_or_else(|| anyhow::anyhow!("activity missing serviceUrl"))?; + let token_service_url = token_data.claims.get("serviceurl") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("JWT missing serviceurl claim"))?; + if token_service_url != activity_service_url { + anyhow::bail!( + "serviceUrl mismatch: token={token_service_url}, activity={activity_service_url}" + ); + } + + Ok(()) + } + + /// Check tenant allowlist. + fn check_tenant(&self, activity: &Activity) -> bool { + if self.config.allowed_tenants.is_empty() { + return true; + } + activity + .resolved_tenant_id() + .is_some_and(|tid| self.config.allowed_tenants.iter().any(|a| a == tid)) + } + + /// Send a reply via Bot Framework REST API. 
+ pub async fn send_activity( + &self, + service_url: &str, + conversation_id: &str, + text: &str, + reply_to_id: Option<&str>, + ) -> anyhow::Result { + let token = self.get_token().await?; + let url = format!( + "{}v3/conversations/{}/activities", + ensure_trailing_slash(service_url), + conversation_id + ); + + let mut body = serde_json::json!({ + "type": "message", + "from": { "id": &self.config.app_id }, + "text": text, + "textFormat": "markdown", + }); + if let Some(id) = reply_to_id { + body["replyToId"] = serde_json::Value::String(id.to_string()); + } + + let resp = self + .client + .post(&url) + .bearer_auth(&token) + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!("Bot Framework API error {status}: {body}"); + } + + let result: serde_json::Value = resp.json().await?; + Ok(result["id"].as_str().unwrap_or("").to_string()) + } + + /// Edit an existing activity (for streaming updates). + pub async fn update_activity( + &self, + service_url: &str, + conversation_id: &str, + activity_id: &str, + text: &str, + ) -> anyhow::Result<()> { + let token = self.get_token().await?; + let url = format!( + "{}v3/conversations/{}/activities/{}", + ensure_trailing_slash(service_url), + conversation_id, + activity_id + ); + + let body = serde_json::json!({ + "type": "message", + "from": { "id": &self.config.app_id }, + "text": text, + }); + + let resp = self + .client + .put(&url) + .bearer_auth(&token) + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!("Bot Framework update error {status}: {body}"); + } + Ok(()) + } +} + +fn ensure_trailing_slash(url: &str) -> String { + if url.ends_with('/') { + url.to_string() + } else { + format!("{url}/") + } +} + +// --- Webhook handler --- + +/// Max webhook body size: 256 KB. 
Real Teams activities are a few KB; the +/// activity is parsed *before* JWT auth (Bot Framework requires serviceUrl / +/// channelId from the body to validate the token), so this caps the +/// unauthenticated parse attack surface. Mirrors the feishu adapter's limit. +const WEBHOOK_BODY_LIMIT: usize = 256 * 1024; + +pub async fn webhook( + State(state): State>, + headers: HeaderMap, + body: String, +) -> StatusCode { + let teams = match &state.teams { + Some(t) => t, + None => return StatusCode::NOT_FOUND, + }; + + // Defense-in-depth: bound the pre-auth body size (axum's default limit is 2 MB). + if body.len() > WEBHOOK_BODY_LIMIT { + warn!(size = body.len(), "teams webhook body too large"); + return StatusCode::PAYLOAD_TOO_LARGE; + } + + // Extract auth header early (before parsing activity) + let auth_header = match headers.get("authorization").and_then(|v| v.to_str().ok()) { + Some(h) => h.to_string(), + None => { + warn!("teams webhook: missing authorization header"); + return StatusCode::UNAUTHORIZED; + } + }; + + // Parse activity first (needed for JWT serviceUrl + endorsements validation). + // + // SECURITY NOTE (OX untrusted-deserialization finding — false positive): + // `Activity` is a strict, derive-only DTO (String / Option<_> / nested + // structs) with no custom Deserialize, no side-effectful Drop, and no enum + // variant dispatch. serde_json's data model cannot instantiate arbitrary + // types (unlike bincode/serde_yaml/rmp-serde), so object-injection / RCE + // does not apply. The recommended "strict DTO + validate after" pattern is + // already in place: JWT, activity-type, and tenant-allowlist checks below. + // DoS is bounded by serde_json's recursion limit (128) and the body cap above. 
+ let activity: Activity = match serde_json::from_str(&body) { + Ok(a) => a, + Err(e) => { + warn!(error = %e, "teams: invalid activity JSON"); + return StatusCode::BAD_REQUEST; + } + }; + + // JWT validation (with activity context for serviceUrl + channelId checks) + if let Err(e) = teams.validate_jwt(&auth_header, &activity).await { + warn!(error = %e, "teams JWT validation failed"); + return StatusCode::UNAUTHORIZED; + } + + // Only handle message activities + if activity.activity_type != "message" { + debug!(activity_type = %activity.activity_type, "teams: ignoring non-message activity"); + return StatusCode::OK; + } + + // Tenant check + if !teams.check_tenant(&activity) { + let tid = activity.resolved_tenant_id().unwrap_or("unknown"); + warn!(tenant = tid, "teams: tenant not in allowlist"); + return StatusCode::FORBIDDEN; + } + + let text = match activity.text.as_deref() { + Some(t) if !t.trim().is_empty() => t.trim(), + _ => return StatusCode::OK, + }; + + let conversation_id = activity + .conversation + .as_ref() + .and_then(|c| c.id.as_deref()) + .unwrap_or(""); + let conversation_type = activity + .conversation + .as_ref() + .and_then(|c| c.conversation_type.as_deref()) + .unwrap_or("personal"); + let service_url = activity.service_url.as_deref().unwrap_or(""); + let sender_id = activity + .from + .as_ref() + .and_then(|f| f.id.as_deref()) + .unwrap_or(""); + let sender_name = activity + .from + .as_ref() + .and_then(|f| f.name.as_deref()) + .unwrap_or("Unknown"); + let activity_id = activity.id.as_deref().unwrap_or(""); + + // B3: Guard against empty service_url — replies will fail without it + if service_url.is_empty() { + warn!("teams: activity missing service_url, cannot route replies"); + return StatusCode::OK; + } + + let event = GatewayEvent::new( + "teams", + ChannelInfo { + id: conversation_id.to_string(), + channel_type: conversation_type.to_string(), + thread_id: None, // Teams conversations don't have sub-threads in the same way + }, + 
SenderInfo { + id: sender_id.to_string(), + name: sender_name.to_string(), + display_name: sender_name.to_string(), + is_bot: false, + }, + text, + activity_id, + vec![], // Teams @mentions parsing deferred to future PR + ); + + // Store service_url for reply routing + state.teams_service_urls.lock().await.insert( + conversation_id.to_string(), + (service_url.to_string(), std::time::Instant::now()), + ); + + let json = serde_json::to_string(&event).unwrap(); + let tenant_id = activity.resolved_tenant_id().unwrap_or(""); + info!( + conversation = conversation_id, + sender = sender_name, + tenant = tenant_id, + service_url = service_url, + "teams → gateway" + ); + let _ = state.event_tx.send(json); + + StatusCode::OK +} + +// --- Reply handler --- + +pub async fn handle_reply( + reply: &GatewayReply, + teams: &TeamsAdapter, + service_urls: &tokio::sync::Mutex< + std::collections::HashMap, + >, +) { + // Reactions are not supported on Teams — silently ignore + if reply.command.as_deref() == Some("add_reaction") + || reply.command.as_deref() == Some("remove_reaction") + { + return; + } + + let service_url = { + let mut urls = service_urls.lock().await; + match urls.get_mut(&reply.channel.id) { + Some((url, ts)) => { + // Refresh timestamp on reply to prevent TTL expiry during active conversations + *ts = std::time::Instant::now(); + url.clone() + } + None => { + error!(conversation = %reply.channel.id, "teams: no service_url for conversation"); + return; + } + } + }; + + let reply_to_id = if reply.reply_to.is_empty() { + None + } else { + Some(reply.reply_to.as_str()) + }; + + info!(conversation = %reply.channel.id, "gateway → teams"); + match teams + .send_activity( + &service_url, + &reply.channel.id, + &reply.content.text, + reply_to_id, + ) + .await + { + Ok(id) => debug!(activity_id = %id, "teams activity sent"), + Err(e) => error!(error = %e, "teams send error"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // --- ensure_trailing_slash --- + + #[test] 
+ fn trailing_slash_adds_when_missing() { + assert_eq!( + ensure_trailing_slash("https://example.com"), + "https://example.com/" + ); + } + + #[test] + fn trailing_slash_keeps_when_present() { + assert_eq!( + ensure_trailing_slash("https://example.com/"), + "https://example.com/" + ); + } + + #[test] + fn trailing_slash_empty_string() { + assert_eq!(ensure_trailing_slash(""), "/"); + } + + // --- check_tenant --- + + fn make_config(tenants: Vec<&str>) -> TeamsConfig { + TeamsConfig { + app_id: "test-app".into(), + app_secret: "test-secret".into(), + oauth_endpoint: "https://example.com/token".into(), + openid_metadata: "https://example.com/openid".into(), + allowed_tenants: tenants.into_iter().map(|s| s.to_string()).collect(), + } + } + + fn make_test_state() -> Arc { + let (event_tx, _rx) = tokio::sync::broadcast::channel(16); + + Arc::new(crate::AppState { + telegram_bot_token: None, + telegram_secret_token: None, + telegram_rich_messages: false, + line_channel_secret: None, + line_access_token: None, + teams: Some(TeamsAdapter::new(make_config(vec![]))), + teams_service_urls: tokio::sync::Mutex::new(std::collections::HashMap::new()), + feishu: None, + google_chat: None, + wecom: None, + ws_token: None, + event_tx, + reply_token_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())), + line_webhook_semaphore: Arc::new(tokio::sync::Semaphore::new(crate::LINE_WEBHOOK_CONCURRENCY_MAX)), + client: reqwest::Client::new(), + }) + } + + fn make_activity_with_tenant(tenant_id: Option<&str>) -> Activity { + Activity { + activity_type: "message".into(), + id: Some("act1".into()), + timestamp: None, + service_url: Some("https://smba.trafficmanager.net/".into()), + channel_id: Some("msteams".into()), + from: None, + conversation: None, + text: Some("hello".into()), + tenant: tenant_id.map(|id| TenantInfo { + id: Some(id.into()), + }), + channel_data: None, + } + } + + // --- webhook body limit --- + + #[tokio::test] + async fn 
webhook_rejects_oversized_body_before_auth() { + let status = webhook( + State(make_test_state()), + HeaderMap::new(), + "x".repeat(WEBHOOK_BODY_LIMIT + 1), + ) + .await; + + assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE); + } + + #[tokio::test] + async fn webhook_allows_body_at_limit_to_reach_auth() { + let status = webhook( + State(make_test_state()), + HeaderMap::new(), + "x".repeat(WEBHOOK_BODY_LIMIT), + ) + .await; + + assert_eq!(status, StatusCode::UNAUTHORIZED); + } + + #[test] + fn tenant_allowed_when_list_empty() { + let adapter = TeamsAdapter::new(make_config(vec![])); + let activity = make_activity_with_tenant(Some("any-tenant")); + assert!(adapter.check_tenant(&activity)); + } + + #[test] + fn tenant_allowed_when_in_list() { + let adapter = TeamsAdapter::new(make_config(vec!["tenant-a", "tenant-b"])); + let activity = make_activity_with_tenant(Some("tenant-b")); + assert!(adapter.check_tenant(&activity)); + } + + #[test] + fn tenant_rejected_when_not_in_list() { + let adapter = TeamsAdapter::new(make_config(vec!["tenant-a"])); + let activity = make_activity_with_tenant(Some("tenant-x")); + assert!(!adapter.check_tenant(&activity)); + } + + #[test] + fn tenant_rejected_when_no_tenant_info() { + let adapter = TeamsAdapter::new(make_config(vec!["tenant-a"])); + let activity = make_activity_with_tenant(None); + assert!(!adapter.check_tenant(&activity)); + } + + #[test] + fn tenant_allowed_when_no_tenant_and_empty_list() { + let adapter = TeamsAdapter::new(make_config(vec![])); + let activity = make_activity_with_tenant(None); + assert!(adapter.check_tenant(&activity)); + } + + // --- resolved_tenant_id --- + + #[test] + fn resolved_tenant_falls_back_to_channel_data() { + // Teams personal/channel webhooks put tenant in channelData, not top-level + let json = r#"{ + "type": "message", + "channelData": {"tenant": {"id": "from-channel-data"}} + }"#; + let activity: Activity = serde_json::from_str(json).unwrap(); + assert_eq!(activity.resolved_tenant_id(), 
Some("from-channel-data")); + } + + #[test] + fn resolved_tenant_prefers_top_level_over_channel_data() { + let json = r#"{ + "type": "message", + "tenant": {"id": "top-level"}, + "channelData": {"tenant": {"id": "from-channel-data"}} + }"#; + let activity: Activity = serde_json::from_str(json).unwrap(); + assert_eq!(activity.resolved_tenant_id(), Some("top-level")); + } + + #[test] + fn resolved_tenant_falls_back_to_conversation_tenant_id() { + let json = r#"{ + "type": "message", + "conversation": {"id": "c1", "tenantId": "from-conversation"} + }"#; + let activity: Activity = serde_json::from_str(json).unwrap(); + assert_eq!(activity.resolved_tenant_id(), Some("from-conversation")); + } + + #[test] + fn resolved_tenant_returns_none_when_absent() { + let json = r#"{"type": "message"}"#; + let activity: Activity = serde_json::from_str(json).unwrap(); + assert_eq!(activity.resolved_tenant_id(), None); + } + + // --- validate_jwt error paths --- + + #[tokio::test] + async fn jwt_rejects_missing_bearer_prefix() { + let adapter = TeamsAdapter::new(make_config(vec![])); + let activity = make_activity_with_tenant(Some("t1")); + let result = adapter.validate_jwt("NotBearer xyz", &activity).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Bearer")); + } + + #[tokio::test] + async fn jwt_rejects_empty_bearer() { + let adapter = TeamsAdapter::new(make_config(vec![])); + let activity = make_activity_with_tenant(Some("t1")); + let result = adapter.validate_jwt("Bearer ", &activity).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn jwt_rejects_garbage_token() { + let adapter = TeamsAdapter::new(make_config(vec![])); + let activity = make_activity_with_tenant(Some("t1")); + let result = adapter.validate_jwt("Bearer not.a.valid.jwt", &activity).await; + assert!(result.is_err()); + } + + // --- Activity deserialization --- + + #[test] + fn deserialize_minimal_activity() { + let json = r#"{"type": "message"}"#; + let activity: 
Activity = serde_json::from_str(json).unwrap(); + assert_eq!(activity.activity_type, "message"); + assert!(activity.text.is_none()); + assert!(activity.from.is_none()); + } + + #[test] + fn deserialize_full_activity() { + let json = r#"{ + "type": "message", + "id": "act123", + "serviceUrl": "https://smba.trafficmanager.net/", + "channelId": "msteams", + "from": {"id": "user1", "name": "Alice", "aadObjectId": "aad-123"}, + "conversation": {"id": "conv1", "conversationType": "personal", "isGroup": false}, + "text": "hello bot", + "tenant": {"id": "tenant-abc"} + }"#; + let activity: Activity = serde_json::from_str(json).unwrap(); + assert_eq!(activity.activity_type, "message"); + assert_eq!(activity.text.as_deref(), Some("hello bot")); + assert_eq!( + activity.from.as_ref().unwrap().name.as_deref(), + Some("Alice") + ); + assert_eq!( + activity.tenant.as_ref().unwrap().id.as_deref(), + Some("tenant-abc") + ); + } + + #[test] + fn deserialize_non_message_activity() { + let json = r#"{"type": "conversationUpdate"}"#; + let activity: Activity = serde_json::from_str(json).unwrap(); + assert_eq!(activity.activity_type, "conversationUpdate"); + } + + #[test] + fn deserialize_invalid_json_fails() { + let result = serde_json::from_str::("not json"); + assert!(result.is_err()); + } + + // --- TeamsConfig::from_env --- + + #[test] + fn config_from_env_returns_none_without_vars() { + // Ensure the env vars are not set (they shouldn't be in test) + std::env::remove_var("TEAMS_APP_ID"); + std::env::remove_var("TEAMS_APP_SECRET"); + assert!(TeamsConfig::from_env().is_none()); + } +} diff --git a/crates/openab-gateway/src/adapters/telegram.rs b/crates/openab-gateway/src/adapters/telegram.rs new file mode 100644 index 000000000..60a98bd06 --- /dev/null +++ b/crates/openab-gateway/src/adapters/telegram.rs @@ -0,0 +1,782 @@ +use crate::media::{resize_and_compress, MediaKind, AUDIO_MAX_DOWNLOAD, FILE_MAX_DOWNLOAD, IMAGE_MAX_DOWNLOAD}; +use crate::schema::*; +use crate::store; +use 
axum::extract::State; +use axum::Json; +use serde::Deserialize; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; +use tracing::{error, info, warn}; + +/// Base URL for Telegram Bot API. Extracted as constant for consistency +/// with LINE's `LINE_API_BASE` and to enable future mock testing. +pub const TELEGRAM_API_BASE: &str = "https://api.telegram.org"; + +// --- Telegram types --- + +#[derive(Debug, Deserialize)] +pub struct TelegramUpdate { + message: Option, +} + +#[derive(Debug, Deserialize)] +struct TelegramMessage { + message_id: i64, + message_thread_id: Option, + chat: TelegramChat, + from: Option, + text: Option, + caption: Option, + #[serde(default)] + entities: Vec, + #[serde(default)] + caption_entities: Vec, + #[serde(default)] + photo: Vec, + document: Option, + voice: Option, + audio: Option, +} + +#[derive(Debug, Deserialize)] +struct TelegramPhoto { + file_id: String, + width: u32, + height: u32, +} + +#[derive(Debug, Deserialize)] +struct TelegramDocument { + file_id: String, + file_name: Option, + mime_type: Option, +} + +#[derive(Debug, Deserialize)] +struct TelegramVoice { + file_id: String, + #[allow(dead_code)] // TODO: use for Content-Type hint + mime_type: Option, +} + +#[derive(Debug, Deserialize)] +struct TelegramAudio { + file_id: String, + #[allow(dead_code)] // TODO: use for filename + file_name: Option, + #[allow(dead_code)] // TODO: use for Content-Type hint + mime_type: Option, +} + +#[derive(Debug, Deserialize)] +struct TelegramEntity { + #[serde(rename = "type")] + entity_type: String, + offset: usize, + length: usize, +} + +#[derive(Debug, Deserialize)] +struct TelegramChat { + id: i64, + #[serde(rename = "type")] + chat_type: String, + #[allow(dead_code)] + is_forum: Option, +} + +#[derive(Debug, Deserialize)] +struct TelegramUser { + id: i64, + first_name: String, + last_name: Option, + username: Option, + is_bot: bool, +} + +// --- Webhook handler --- + +pub async fn webhook( + State(state): 
State>, + headers: axum::http::HeaderMap, + Json(update): Json, +) -> axum::http::StatusCode { + if let Some(ref expected) = state.telegram_secret_token { + let provided = headers + .get("x-telegram-bot-api-secret-token") + .and_then(|v| v.to_str().ok()); + if provided != Some(expected.as_str()) { + warn!("webhook rejected: invalid or missing secret_token"); + return axum::http::StatusCode::UNAUTHORIZED; + } + } + + let Some(msg) = update.message else { + return axum::http::StatusCode::OK; + }; + let is_photo = !msg.photo.is_empty(); + let is_document = msg.document.is_some(); + let is_voice = msg.voice.is_some(); + let is_audio = msg.audio.is_some(); + let text = msg.text.as_deref().or(msg.caption.as_deref()).unwrap_or(""); + + if text.trim().is_empty() && !is_photo && !is_document && !is_voice && !is_audio { + return axum::http::StatusCode::OK; + } + + let mut attachments = Vec::new(); + if is_photo || is_document || is_voice || is_audio { + if let Some(ref token) = state.telegram_bot_token { + let client = &state.client; + if is_photo { + if let Some(largest) = msg.photo.iter().max_by_key(|p| p.width * p.height) { + if let Some(att) = + download_telegram_media(client, token, &largest.file_id, MediaKind::Image).await + { + attachments.push(att); + } + } + } else if let Some(doc) = msg.document { + let file_name = doc.file_name.unwrap_or_else(|| "unknown.txt".to_string()); + let mime_type = doc.mime_type.unwrap_or_else(|| "text/plain".to_string()); + if let Some(att) = + download_telegram_document(client, token, &doc.file_id, &file_name, &mime_type).await + { + attachments.push(att); + } + } else if let Some(voice) = msg.voice { + if let Some(att) = download_telegram_media(client, token, &voice.file_id, MediaKind::Audio).await { + attachments.push(att); + } + } else if let Some(audio) = msg.audio { + if let Some(att) = download_telegram_media(client, token, &audio.file_id, MediaKind::Audio).await { + attachments.push(att); + } + } + } + } + + let from = 
msg.from.as_ref(); + let sender_name = from + .and_then(|u| u.username.as_deref()) + .unwrap_or("unknown"); + let display_name = from + .map(|u| { + let mut n = u.first_name.clone(); + if let Some(last) = &u.last_name { + n.push(' '); + n.push_str(last); + } + n + }) + .unwrap_or_else(|| "Unknown".into()); + + let mentions: Vec = msg + .entities + .iter() + .chain(msg.caption_entities.iter()) + .filter(|e| e.entity_type == "mention") + .filter_map(|e| { + text.get(e.offset..e.offset + e.length) + .map(|s| s.trim_start_matches('@').to_string()) + }) + .collect(); + + let mut event = GatewayEvent::new( + "telegram", + ChannelInfo { + id: msg.chat.id.to_string(), + channel_type: msg.chat.chat_type.clone(), + thread_id: msg.message_thread_id.map(|id| id.to_string()), + }, + SenderInfo { + id: from.map(|u| u.id.to_string()).unwrap_or_default(), + name: sender_name.into(), + display_name, + is_bot: from.map(|u| u.is_bot).unwrap_or(false), + }, + text, + &msg.message_id.to_string(), + mentions, + ); + event.content.attachments = attachments; + + // Guard: skip empty events (no text + no attachments) + if event.content.text.trim().is_empty() && event.content.attachments.is_empty() { + return axum::http::StatusCode::OK; + } + + let json = serde_json::to_string(&event).unwrap(); + info!(chat_id = %msg.chat.id, sender = %sender_name, "telegram → gateway"); + let _ = state.event_tx.send(json); + axum::http::StatusCode::OK +} + +/// Split text into chunks of at most `limit` characters, breaking at newlines when possible. 
+fn chunk_text(text: &str, limit: usize) -> Vec { + if text.chars().count() <= limit { + return vec![text.to_string()]; + } + let mut chunks = Vec::new(); + let mut current = String::new(); + for line in text.lines() { + if !current.is_empty() && current.chars().count() + line.chars().count() + 1 > limit { + chunks.push(std::mem::take(&mut current)); + } + if !current.is_empty() { + current.push('\n'); + } + if line.chars().count() > limit { + // Line itself exceeds limit — hard split + for ch in line.chars() { + current.push(ch); + if current.chars().count() >= limit { + chunks.push(std::mem::take(&mut current)); + } + } + } else { + current.push_str(line); + } + } + if !current.is_empty() { + chunks.push(current); + } + chunks +} + +fn is_markdown_parse_error(description: &str) -> bool { + let desc_lower = description.to_lowercase(); + desc_lower.contains("can't find end") + || desc_lower.contains("can't parse") + || desc_lower.contains("parse entities") +} + +/// Returns true if the content is complex enough to benefit from sendRichMessage. +/// +/// Design decisions: +/// - We classify at the adapter layer (not agent) so agents don't need prompt changes. +/// - Conservative: only route to rich when legacy sendMessage would visibly break. +/// - False positives are acceptable (rich renders simple text fine too), but we avoid +/// unnecessary API switches for plain prose to reduce risk surface. +/// - LaTeX and blockquotes are intentionally omitted for now (Phase 2). +fn is_complex_markdown(text: &str) -> bool { + // 🟡 Code blocks intentionally NOT routed to rich — sendMessage preserves + // syntax highlighting (language header + copy button) which RichBlockPreformatted lacks. + + // sendMessage hard limit is 4096 chars. Rich messages support 32768. + if text.chars().count() > 4096 { + return true; + } + text.lines().any(|line| { + let trimmed = line.trim_start(); + // ATX headings (h1-h6): sendMessage has zero heading support. 
+ if trimmed.starts_with("# ") + || trimmed.starts_with("## ") + || trimmed.starts_with("### ") + || trimmed.starts_with("#### ") + || trimmed.starts_with("##### ") + || trimmed.starts_with("###### ") + { + return true; + } + // GFM table separator row detection. + if trimmed.starts_with('|') && trimmed.ends_with('|') { + let inner = &trimmed[1..trimmed.len() - 1]; + if inner.split('|').all(|cell| { + let c = cell.trim().trim_matches(':'); + !c.is_empty() && c.chars().all(|ch| ch == '-') + }) { + return true; + } + } + false + }) +} + +/// Send a rich message via Bot API 10.1 sendRichMessage. +/// +/// Design: we pass agent markdown directly via InputRichMessage.markdown. +/// Rich Markdown is GFM-compatible, so no conversion layer is needed. +/// The API handles rendering (tables, syntax highlighting, headings, etc.) +async fn send_rich_message( + client: &reqwest::Client, + bot_token: &str, + chat_id: &str, + thread_id: &Option, + text: &str, +) -> Result { + let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/sendRichMessage"); + let body = serde_json::json!({ + "chat_id": chat_id, + "message_thread_id": thread_id, + "rich_message": { "markdown": text }, + }); + let resp = client.post(&url).json(&body).send().await.map_err(|e| e.to_string())?; + let json: serde_json::Value = resp.json().await.map_err(|e| e.to_string())?; + if json["ok"].as_bool() == Some(true) { + Ok(json) + } else { + Err(json["description"].as_str().unwrap_or("unknown error").to_string()) + } +} + +/// Stream a partial rich message via sendRichMessageDraft. +/// +/// Design: ephemeral 30-second preview. Caller must follow up with +/// sendRichMessage to persist. Same draft_id = animated transition. +/// Wired but unused until gateway streaming infrastructure integrates. 
+#[allow(dead_code)] +async fn send_rich_message_draft( + client: &reqwest::Client, + bot_token: &str, + chat_id: &str, + thread_id: &Option, + draft_id: i64, + text: &str, +) -> Result<(), String> { + let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/sendRichMessageDraft"); + let body = serde_json::json!({ + "chat_id": chat_id, + "message_thread_id": thread_id, + "draft_id": draft_id, + "rich_message": if text.contains(", + reaction_state: &Arc>>>, + rich_messages: bool, +) { + // Handle create_topic command + if reply.command.as_deref() == Some("create_topic") { + let req_id = reply.request_id.clone().unwrap_or_default(); + info!(chat_id = %reply.channel.id, "creating forum topic"); + let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/createForumTopic"); + let resp = client + .post(&url) + .json(&serde_json::json!({"chat_id": reply.channel.id, "name": reply.content.text})) + .send() + .await; + let gw_resp = match resp { + Ok(r) => { + let body: serde_json::Value = r.json().await.unwrap_or_default(); + if body["ok"].as_bool() == Some(true) { + let tid = body["result"]["message_thread_id"] + .as_i64() + .map(|id| id.to_string()); + info!(thread_id = ?tid, "forum topic created"); + GatewayResponse { + schema: "openab.gateway.response.v1".into(), + request_id: req_id, + success: true, + thread_id: tid, + message_id: None, + error: None, + } + } else { + let err = body["description"] + .as_str() + .unwrap_or("unknown error") + .to_string(); + warn!(err = %err, "createForumTopic failed"); + GatewayResponse { + schema: "openab.gateway.response.v1".into(), + request_id: req_id, + success: false, + thread_id: None, + message_id: None, + error: Some(err), + } + } + } + Err(e) => GatewayResponse { + schema: "openab.gateway.response.v1".into(), + request_id: req_id, + success: false, + thread_id: None, + message_id: None, + error: Some(e.to_string()), + }, + }; + let json = serde_json::to_string(&gw_resp).unwrap(); + let _ = event_tx.send(json); + return; + } + + // 
Handle edit_message + if reply.command.as_deref() == Some("edit_message") { + if reply.reply_to == "draft" { + // Dummy "draft" ref from streaming without placeholder. + if rich_messages { + // Skip short updates — let thinking animation show until meaningful content arrives + if reply.content.text.len() < 30 { + return; + } + let text = if reply.content.text.len() > 32768 { + &reply.content.text[..reply.content.text.floor_char_boundary(32768)] + } else { + &reply.content.text + }; + // Combine channel + thread to avoid draft_id collision in forum topics + let chan: i64 = reply.channel.id.parse::().unwrap_or(1).abs(); + let tid: i64 = reply.channel.thread_id.as_deref().and_then(|t| t.parse::().ok()).unwrap_or(0).abs(); + let draft_id: i64 = (chan.wrapping_add(tid)) % 1_000_000 + 1; + let _ = send_rich_message_draft(client, bot_token, &reply.channel.id, &reply.channel.thread_id, draft_id, text).await; + } + // else: rich_messages=false with dummy ref — silently drop (no real msg to edit) + return; + } + // Real message_id — perform actual editMessageText (legacy streaming path) + let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/editMessageText"); + let _ = client + .post(&url) + .json(&serde_json::json!({ + "chat_id": reply.channel.id, + "message_id": reply.reply_to, + "text": &reply.content.text, + "parse_mode": "Markdown", + })) + .send() + .await; + return; + } + + // Handle add_reaction / remove_reaction + if reply.command.as_deref() == Some("add_reaction") + || reply.command.as_deref() == Some("remove_reaction") + { + // Send thinking draft on reaction changes — reflects agent state + if rich_messages && reply.command.as_deref() == Some("add_reaction") { + let thinking_text = match reply.content.text.as_str() { + "👀" => Some("Looking..."), + "🤔" => Some("Thinking..."), + "👨\u{200d}💻" => Some("Writing code..."), + "🔥" => Some("Working..."), + "⚡" => Some("Running tools..."), + _ => None, + }; + if let Some(text) = thinking_text { + let chan: i64 = 
reply.channel.id.parse::().unwrap_or(1).abs(); + let tid: i64 = reply.channel.thread_id.as_deref().and_then(|t| t.parse::().ok()).unwrap_or(0).abs(); + let draft_id: i64 = (chan.wrapping_add(tid)) % 1_000_000 + 1; + let _ = send_rich_message_draft( + client, bot_token, &reply.channel.id, &reply.channel.thread_id, draft_id, text, + ).await; + } + } + + let msg_key = format!("{}:{}", reply.channel.id, reply.reply_to); + let emoji = &reply.content.text; + let tg_emoji = match emoji.as_str() { + "🆗" => "👍", + other => other, + }; + let is_add = reply.command.as_deref() == Some("add_reaction"); + { + let mut reactions = reaction_state.lock().await; + let set = reactions.entry(msg_key.clone()).or_default(); + if is_add { + if !set.contains(&tg_emoji.to_string()) { + set.push(tg_emoji.to_string()); + } + } else { + set.retain(|e| e != tg_emoji); + } + } + let current: Vec = { + let reactions = reaction_state.lock().await; + reactions + .get(&msg_key) + .map(|v| { + v.iter() + .map(|e| serde_json::json!({"type": "emoji", "emoji": e})) + .collect() + }) + .unwrap_or_default() + }; + let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/setMessageReaction"); + let _ = client + .post(&url) + .json(&serde_json::json!({ + "chat_id": reply.channel.id, + "message_id": reply.reply_to, + "reaction": current, + })) + .send() + .await + .map_err(|e| error!("telegram reaction error: {e}")); + return; + } + + // Normal send_message + info!( + chat_id = %reply.channel.id, + thread_id = ?reply.channel.thread_id, + "gateway → telegram" + ); + + // --- Rich Message routing --- + // Design: try sendRichMessage first for complex content. On ANY failure + // (unsupported client, API version mismatch, network error), fall back to + // legacy sendMessage (chunked). This ensures zero-downtime rollout. + if rich_messages && is_complex_markdown(&reply.content.text) { + // Bot API limit: 32768 UTF-8 characters (not bytes). 
+ let text = &reply.content.text; + let rich_text: String = if text.chars().count() > 32768 { + text.chars().take(32768).collect() + } else { + text.to_string() + }; + match send_rich_message(client, bot_token, &reply.channel.id, &reply.channel.thread_id, &rich_text).await { + Ok(_) => return, + Err(e) => warn!("sendRichMessage failed ({e}), falling back to sendMessage"), + } + } + + // Legacy sendMessage — chunk at 4096 chars to avoid rejection. + let chunks = chunk_text(&reply.content.text, 4096); + for chunk in &chunks { + let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/sendMessage"); + let resp = client + .post(&url) + .json(&serde_json::json!({ + "chat_id": reply.channel.id, + "text": chunk, + "message_thread_id": reply.channel.thread_id, + "parse_mode": "Markdown", + })) + .send() + .await; + + match resp { + Ok(r) => { + let body: serde_json::Value = r.json().await.unwrap_or_default(); + if body["ok"].as_bool() != Some(true) { + let desc = body["description"].as_str().unwrap_or("unknown error"); + if is_markdown_parse_error(desc) { + warn!("Markdown send failed: {desc}, retrying as plain text"); + match client + .post(&url) + .json(&serde_json::json!({ + "chat_id": reply.channel.id, + "text": chunk, + "message_thread_id": reply.channel.thread_id, + })) + .send() + .await + { + Ok(retry_r) => { + let retry_body: serde_json::Value = + retry_r.json().await.unwrap_or_default(); + if retry_body["ok"].as_bool() != Some(true) { + error!( + "telegram plain-text retry failed: {}", + retry_body["description"] + .as_str() + .unwrap_or("unknown error") + ); + } + } + Err(e) => error!("telegram plain-text send error: {e}"), + } + } else { + error!("telegram send failed: {desc}"); + } + } + } + Err(e) => error!("telegram send error: {e}"), + } + } +} + +/// Download media from Telegram via getFile → store to filesystem (colocate mode). 
+async fn download_telegram_media( + client: &reqwest::Client, + bot_token: &str, + file_id: &str, + kind: MediaKind, +) -> Option { + let get_file_url = format!("{TELEGRAM_API_BASE}/bot{}/getFile", bot_token); + let resp = client.get(&get_file_url).query(&[("file_id", file_id)]).send().await.ok()?; + let body: serde_json::Value = resp.json().await.ok()?; + let file_path = body["result"]["file_path"].as_str()?; + + let download_url = format!("{TELEGRAM_API_BASE}/file/bot{}/{}", bot_token, file_path); + let resp = client.get(&download_url).send().await.ok()?; + if !resp.status().is_success() { + return None; + } + + let max_size = match kind { + MediaKind::Image => IMAGE_MAX_DOWNLOAD, + MediaKind::Audio => AUDIO_MAX_DOWNLOAD, + }; + + if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { + if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { + if size > max_size { + warn!(file_id, size, kind = ?kind, "Telegram media Content-Length exceeds limit"); + return None; + } + } + } + + let default_mime = match kind { + MediaKind::Image => "image/jpeg", + MediaKind::Audio => "audio/ogg", + }; + let content_type = resp + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|h| h.to_str().ok()) + .unwrap_or(default_mime) + .to_string(); + + let bytes = resp.bytes().await.ok()?; + if bytes.len() as u64 > max_size { + warn!(file_id, size = bytes.len(), kind = ?kind, "Telegram media exceeds limit"); + return None; + } + + let (data_bytes, mime) = match kind { + MediaKind::Image => match resize_and_compress(&bytes) { + Ok((c, m)) => (c, m), + Err(e) => { + error!(err = %e, "Telegram image processing failed"); + return None; + } + }, + MediaKind::Audio => (bytes.to_vec(), content_type), + }; + + // Store to filesystem instead of base64 encoding + let path = store::store_media(&data_bytes).await?; + let att_type = match kind { + MediaKind::Image => "image", + MediaKind::Audio => "audio", + }; + info!(file_id, size = data_bytes.len(), kind = ?kind, 
"Telegram media stored"); + + Some(Attachment { + attachment_type: att_type.into(), + filename: format!("{}.{}", file_id, match kind { + MediaKind::Image => "jpg", + MediaKind::Audio => crate::media::audio_extension(&mime), + }), + mime_type: mime, + data: String::new(), // No base64 — using file path + size: data_bytes.len() as u64, + path: Some(path), + }) +} + +/// Download text document from Telegram → store to filesystem. +async fn download_telegram_document( + client: &reqwest::Client, + bot_token: &str, + file_id: &str, + file_name: &str, + mime_type: &str, +) -> Option { + if !crate::media::is_text_extension(file_name) { + tracing::debug!(file_name, "skipping non-text file attachment"); + return None; + } + + let get_file_url = format!("{TELEGRAM_API_BASE}/bot{}/getFile", bot_token); + let resp = client.get(&get_file_url).query(&[("file_id", file_id)]).send().await.ok()?; + let body: serde_json::Value = resp.json().await.ok()?; + let file_path = body["result"]["file_path"].as_str()?; + + let download_url = format!("{TELEGRAM_API_BASE}/file/bot{}/{}", bot_token, file_path); + let resp = client.get(&download_url).send().await.ok()?; + if !resp.status().is_success() { + return None; + } + + if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { + if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { + if size > FILE_MAX_DOWNLOAD { + warn!(file_id, size, "Telegram document Content-Length exceeds limit"); + return None; + } + } + } + + let bytes = resp.bytes().await.ok()?; + if bytes.len() as u64 > FILE_MAX_DOWNLOAD { + warn!(file_id, size = bytes.len(), "Telegram document exceeds limit"); + return None; + } + + // Validate UTF-8 — reject binary files + if String::from_utf8(bytes.to_vec()).is_err() { + warn!(file_id, file_name, "Telegram document is not valid UTF-8, skipping"); + return None; + } + + let path = store::store_media(&bytes).await?; + info!(file_id, file_name, size = bytes.len(), "Telegram document stored"); + + Some(Attachment { 
+ attachment_type: "text_file".into(), + filename: file_name.to_string(), + mime_type: mime_type.to_string(), + data: String::new(), + size: bytes.len() as u64, + path: Some(path), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_markdown_parse_error() { + assert!(is_markdown_parse_error("Bad Request: can't find end of italic entity at byte offset 37")); + assert!(is_markdown_parse_error("Bad Request: can't parse entities: Can't find end of bold entity")); + assert!(is_markdown_parse_error("can't parse entities in message text")); + assert!(!is_markdown_parse_error("Unauthorized")); + assert!(!is_markdown_parse_error("Bad Request: chat not found")); + } + + #[test] + fn test_is_complex_markdown() { + // Tables + assert!(is_complex_markdown("| Col1 | Col2 |\n|---|---|\n| a | b |")); + assert!(is_complex_markdown("| Col1 | Col2 |\n| :--- | ---: |\n| a | b |")); + assert!(is_complex_markdown("| A | B |\n| :---: | :---: |\n| x | y |")); + // Code blocks — intentionally NOT complex (preserves syntax highlighting on legacy path) + assert!(!is_complex_markdown("```rust\nfn main() {}\n```")); + assert!(!is_complex_markdown("~~~\ncode\n~~~")); + // Headings + assert!(is_complex_markdown("# Heading\n\nSome text")); + assert!(is_complex_markdown("## Heading 2 at start")); + assert!(is_complex_markdown("### Heading 3 at start")); + assert!(is_complex_markdown("#### Heading 4")); + assert!(is_complex_markdown("text\n##### Heading 5")); + assert!(is_complex_markdown(" ## Indented heading")); + // Size + assert!(is_complex_markdown(&"x".repeat(4097))); + // Negatives + assert!(!is_complex_markdown("Hello world")); + assert!(!is_complex_markdown("*bold* and _italic_")); + assert!(!is_complex_markdown("#hashtag no space")); + assert!(!is_complex_markdown("| just | pipes |")); + } +} diff --git a/crates/openab-gateway/src/adapters/wecom.rs b/crates/openab-gateway/src/adapters/wecom.rs new file mode 100644 index 000000000..e3e97ff17 --- /dev/null +++ 
b/crates/openab-gateway/src/adapters/wecom.rs @@ -0,0 +1,1654 @@ +use anyhow::Result; +use axum::extract::State; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{info, warn}; + +pub struct WecomConfig { + pub corp_id: String, + pub agent_id: String, + pub secret: String, + pub token: String, + pub encoding_aes_key: String, + pub webhook_path: String, + pub streaming_enabled: bool, + pub debounce_secs: u64, +} + +impl WecomConfig { + pub fn from_env() -> Option { + Self::from_reader(|k| std::env::var(k).ok()) + } + + /// Build config from an arbitrary string reader. Tests use this with a + /// HashMap so they don't mutate process-wide environment variables — + /// `env::set_var` races other tests under cargo's parallel runner. + fn from_reader Option>(read: F) -> Option { + let corp_id = read("WECOM_CORP_ID")?; + let secret = read("WECOM_SECRET")?; + let token = read("WECOM_TOKEN")?; + let encoding_aes_key = read("WECOM_ENCODING_AES_KEY")?; + let agent_id = read("WECOM_AGENT_ID")?; + if agent_id.parse::().is_err() { + warn!("WECOM_AGENT_ID must be a numeric value, got '{}'", agent_id); + return None; + } + let webhook_path = read("WECOM_WEBHOOK_PATH").unwrap_or_else(|| "/webhook/wecom".into()); + // Streaming opts-in: WeCom callback mode has no edit-message API, so + // streaming is implemented via thinking-placeholder + recall + resend, + // which causes a brief client flicker. Default off; set to true only if + // the UX tradeoff is acceptable. 
+ let streaming_enabled = read("WECOM_STREAMING_ENABLED") + .map(|v| v == "true" || v == "1") + .unwrap_or(false); + let debounce_secs = read("WECOM_DEBOUNCE_SECS") + .and_then(|v| v.parse::().ok()) + .unwrap_or(3); + + if encoding_aes_key.len() != 43 { + warn!("WECOM_ENCODING_AES_KEY must be 43 characters, got {}", encoding_aes_key.len()); + return None; + } + + info!( + corp_id = %corp_id, + agent_id = %agent_id, + streaming_enabled, + debounce_secs, + "wecom adapter configured" + ); + Some(Self { + corp_id, + agent_id, + secret, + token, + encoding_aes_key, + webhook_path, + streaming_enabled, + debounce_secs, + }) + } +} + +fn decode_aes_key(encoding_aes_key: &str) -> anyhow::Result> { + use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig}; + use base64::Engine; + // WeCom's EncodingAESKey is 43 base64 chars without trailing padding. + // Append "=" to make it a 44-char standard base64 string before decoding. + // Indifferent + allow_trailing_bits accommodate WeCom's non-standard + // encoding: the 43rd char's last 2 bits are not part of the output and + // must be ignored rather than rejected. 
+ let padded = format!("{}=", encoding_aes_key); + let config = GeneralPurposeConfig::new() + .with_decode_padding_mode(DecodePaddingMode::Indifferent) + .with_decode_allow_trailing_bits(true); + let engine = GeneralPurpose::new(&base64::alphabet::STANDARD, config); + let key = engine + .decode(&padded) + .map_err(|e| anyhow::anyhow!("encoding_aes_key base64 decode failed: {e}"))?; + anyhow::ensure!( + key.len() == 32, + "encoding_aes_key must decode to 32 bytes, got {}", + key.len() + ); + Ok(key) +} + +fn compute_signature(token: &str, timestamp: &str, nonce: &str, encrypt: &str) -> String { + use sha1::Digest; + let mut parts = [token, timestamp, nonce, encrypt]; + parts.sort_unstable(); + let joined: String = parts.concat(); + let hash = sha1::Sha1::digest(joined.as_bytes()); + format!("{:x}", hash) +} + +fn verify_signature( + token: &str, + timestamp: &str, + nonce: &str, + encrypt: &str, + expected: &str, +) -> bool { + let computed = compute_signature(token, timestamp, nonce, encrypt); + tracing::debug!( + computed = %computed, + expected = %expected, + token_len = token.len(), + encrypt_len = encrypt.len(), + "signature comparison" + ); + subtle::ConstantTimeEq::ct_eq(computed.as_bytes(), expected.as_bytes()).into() +} + +fn decrypt_message( + encoding_aes_key: &str, + encrypted: &str, + expected_corp_id: &str, +) -> anyhow::Result { + use aes::cipher::{BlockDecryptMut, KeyIvInit}; + use base64::Engine; + + let key = decode_aes_key(encoding_aes_key)?; + let iv = &key[..16]; + + let cipher_bytes = base64::engine::general_purpose::STANDARD + .decode(encrypted) + .map_err(|e| anyhow::anyhow!("base64 decode failed: {e}"))?; + + if cipher_bytes.is_empty() || cipher_bytes.len() % 16 != 0 { + anyhow::bail!("ciphertext length {} not a multiple of 16", cipher_bytes.len()); + } + + type Aes256CbcDec = cbc::Decryptor; + let decryptor = Aes256CbcDec::new_from_slices(&key, iv) + .map_err(|e| anyhow::anyhow!("aes init failed: {e}"))?; + + let mut buf = 
cipher_bytes.to_vec(); + // WeCom uses PKCS7 with block_size=32, not 16. Decrypt without padding validation + // and strip padding manually. + let plaintext = decryptor + .decrypt_padded_mut::(&mut buf) + .map_err(|e| anyhow::anyhow!("aes decrypt failed: {e}"))?; + + // Strip WeCom PKCS7 padding (block_size=32): last byte indicates pad length (1-32) + let pad_byte = *plaintext.last().ok_or_else(|| anyhow::anyhow!("empty plaintext"))? as usize; + if pad_byte == 0 || pad_byte > 32 || pad_byte > plaintext.len() { + anyhow::bail!("invalid wecom padding value: {pad_byte}"); + } + let pad_start = plaintext.len() - pad_byte; + if !plaintext[pad_start..].iter().all(|&b| b as usize == pad_byte) { + anyhow::bail!("invalid PKCS#7 padding: not all padding bytes match"); + } + let plaintext = &plaintext[..pad_start]; + + // Plaintext structure: random(16) + msg_len(4, big-endian) + msg + corp_id + if plaintext.len() < 20 { + anyhow::bail!("decrypted payload too short"); + } + let msg_len = + u32::from_be_bytes([plaintext[16], plaintext[17], plaintext[18], plaintext[19]]) as usize; + if plaintext.len() < 20 + msg_len { + anyhow::bail!("msg_len exceeds payload size"); + } + let msg = &plaintext[20..20 + msg_len]; + let corp_id = &plaintext[20 + msg_len..]; + + let corp_id_str = + std::str::from_utf8(corp_id).map_err(|e| anyhow::anyhow!("corp_id not utf8: {e}"))?; + if corp_id_str != expected_corp_id { + anyhow::bail!("corp_id mismatch: expected {expected_corp_id}, got {corp_id_str}"); + } + + String::from_utf8(msg.to_vec()).map_err(|e| anyhow::anyhow!("message not utf8: {e}")) +} + +// --- Deduplication --- + +const DEDUPE_TTL_SECS: u64 = 30; +const DEDUPE_MAX_SIZE: usize = 10_000; + +struct DedupeCache { + entries: std::sync::Mutex>, +} + +impl DedupeCache { + fn new() -> Self { + Self { + entries: std::sync::Mutex::new(std::collections::HashMap::new()), + } + } + + fn check_and_insert(&self, msg_id: &str) -> bool { + let mut entries = self.entries.lock().unwrap_or_else(|e| 
e.into_inner()); + let now = std::time::Instant::now(); + + if entries.len() >= DEDUPE_MAX_SIZE { + entries.retain(|_, t| now.duration_since(*t).as_secs() < DEDUPE_TTL_SECS); + } + + if let Some(t) = entries.get(msg_id) { + if now.duration_since(*t).as_secs() < DEDUPE_TTL_SECS { + return false; + } + } + + entries.insert(msg_id.to_string(), now); + true + } +} + +// --- Token cache --- + +pub const WECOM_API_BASE: &str = "https://qyapi.weixin.qq.com"; +const TOKEN_REFRESH_MARGIN_SECS: u64 = 300; + +pub struct WecomTokenCache { + inner: RwLock>, + base_url: String, +} + +impl WecomTokenCache { + fn new() -> Self { + Self { + inner: RwLock::new(None), + base_url: WECOM_API_BASE.into(), + } + } + + #[cfg(test)] + fn with_base_url(base_url: String) -> Self { + Self { + inner: RwLock::new(None), + base_url, + } + } + + pub async fn get_token( + &self, + client: &reqwest::Client, + corp_id: &str, + secret: &str, + ) -> Result { + // Fast path: read lock + { + let guard = self.inner.read().await; + if let Some((ref token, created_at, expires_in)) = *guard { + let elapsed = created_at.elapsed().as_secs(); + if elapsed + TOKEN_REFRESH_MARGIN_SECS < expires_in { + return Ok(token.clone()); + } + } + } + + // Slow path: write lock + refresh + let mut guard = self.inner.write().await; + // Double-check after acquiring write lock + if let Some((ref token, created_at, expires_in)) = *guard { + let elapsed = created_at.elapsed().as_secs(); + if elapsed + TOKEN_REFRESH_MARGIN_SECS < expires_in { + return Ok(token.clone()); + } + } + + // WeCom's gettoken API requires `corpsecret` as a query parameter — the + // protocol mandates this, we can't move it to a header. Operators must + // configure their reverse proxy / load balancer to redact query strings + // on `/cgi-bin/gettoken` paths before logging access logs. We do not log + // this URL anywhere from the gateway side. 
+ let url = format!( + "{}/cgi-bin/gettoken?corpid={}&corpsecret={}", + self.base_url, corp_id, secret + ); + let resp: serde_json::Value = client.get(&url).send().await?.json().await?; + + let errcode = resp["errcode"].as_i64().unwrap_or(-1); + if errcode != 0 { + anyhow::bail!( + "wecom gettoken failed: errcode={}, errmsg={}", + errcode, + resp["errmsg"] + ); + } + + let token = resp["access_token"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("missing access_token in response"))? + .to_string(); + let expires_in = resp["expires_in"].as_u64().unwrap_or(7200); + + *guard = Some((token.clone(), std::time::Instant::now(), expires_in)); + Ok(token) + } + + pub async fn force_refresh( + &self, + client: &reqwest::Client, + corp_id: &str, + secret: &str, + ) -> Result { + let mut guard = self.inner.write().await; + *guard = None; + drop(guard); + self.get_token(client, corp_id, secret).await + } +} + +// --- Adapter --- + +struct PendingStream { + text_watch: tokio::sync::watch::Sender, +} + +type PendingMap = Arc>>; + +pub struct WecomAdapter { + pub config: WecomConfig, + pub token_cache: Arc, + client: reqwest::Client, + dedupe: DedupeCache, + pending_streams: PendingMap, +} + +impl WecomAdapter { + pub fn new(config: WecomConfig) -> Self { + Self { + token_cache: Arc::new(WecomTokenCache::new()), + client: reqwest::Client::new(), + dedupe: DedupeCache::new(), + pending_streams: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())), + config, + } + } + + + pub async fn handle_reply( + &self, + reply: &crate::schema::GatewayReply, + event_tx: &tokio::sync::broadcast::Sender, + ) { + if let Some(cmd) = reply.command.as_deref() { + match cmd { + "add_reaction" | "remove_reaction" | "create_topic" => { + info!(command = cmd, "wecom: ignoring unsupported command"); + return; + } + "edit_message" => { + self.handle_edit_message(reply); + return; + } + _ => {} + } + } + + let text = &reply.content.text; + if text.is_empty() { + return; + } + + let to_user = 
reply + .channel + .id + .rsplit(':') + .next() + .unwrap_or(&reply.channel.id); + + let has_pending = { + let pending = self.pending_streams.lock().unwrap_or_else(|e| e.into_inner()); + pending.contains_key(&reply.channel.id) + }; + let is_streaming_placeholder = reply.request_id.is_some() && !has_pending; + if is_streaming_placeholder { + // Optionally send a thinking placeholder. With streaming disabled + // (default), buffer chunks silently and send the consolidated text + // when the debounce settles — no recall/flicker. + let placeholder_id = if self.config.streaming_enabled { + info!(to_user = to_user, "wecom: sending thinking placeholder"); + match self.send_text(to_user, "⏳...").await { + Ok(id) => Some(id), + Err(e) => { + warn!("wecom send thinking failed: {e}"); + return; + } + } + } else { + None + }; + + let (text_tx, text_rx) = tokio::sync::watch::channel(String::new()); + { + let mut pending = self.pending_streams.lock().unwrap_or_else(|e| e.into_inner()); + pending.insert(reply.channel.id.clone(), PendingStream { + text_watch: text_tx, + }); + } + let client = self.client.clone(); + let token_cache = self.token_cache.clone(); + let corp_id = self.config.corp_id.clone(); + let secret = self.config.secret.clone(); + let agent_id = self.config.agent_id.clone(); + let thinking_id = placeholder_id.clone(); + let flush_to_user = to_user.to_string(); + let channel_id_clone = reply.channel.id.clone(); + let pending_clone = self.pending_streams.clone(); + let debounce_secs = self.config.debounce_secs; + tokio::spawn(async move { + let mut rx = text_rx; + let debounce = std::time::Duration::from_secs(debounce_secs); + let mut last_text = String::new(); + let max_idle = std::time::Duration::from_secs(300); + let started = std::time::Instant::now(); + loop { + match tokio::time::timeout(debounce, rx.changed()).await { + Ok(Ok(())) => { + last_text = rx.borrow().clone(); + } + Ok(Err(_)) => break, + Err(_) => { + if !last_text.is_empty() { + break; + } + if 
started.elapsed() > max_idle { + warn!("wecom: debounce task timed out after 5 minutes"); + break; + } + } + } + } + // Acquire pending lock first, then capture any late writes + // that landed between the loop break and now. Holding the + // lock blocks handle_reply from sending more chunks for this + // channel, so this read is the last writeable moment. Then + // remove the entry, which drops text_tx and closes the channel. + { + let mut pending = pending_clone.lock().unwrap_or_else(|e| e.into_inner()); + let final_text = rx.borrow().clone(); + if !final_text.is_empty() { + last_text = final_text; + } + pending.remove(&channel_id_clone); + } + if last_text.is_empty() { + return; + } + flush_thinking( + &client, &token_cache, &corp_id, &secret, &agent_id, + thinking_id.as_deref(), &flush_to_user, &last_text, + ).await; + }); + + if let Some(ref req_id) = reply.request_id { + let resp = crate::schema::GatewayResponse { + schema: "openab.gateway.response.v1".into(), + request_id: req_id.clone(), + success: true, + thread_id: None, + message_id: placeholder_id, + error: None, + }; + if let Ok(json) = serde_json::to_string(&resp) { + let _ = event_tx.send(json); + } + } + return; + } + + if has_pending { + // Re-check under lock: the debounce task may have removed the entry + // between our earlier read of `has_pending` and now. If it did, + // fall through to the direct-send path so the chunk isn't lost. 
+ let appended = { + let pending = self.pending_streams.lock().unwrap_or_else(|e| e.into_inner()); + if let Some(stream) = pending.get(&reply.channel.id) { + let current = stream.text_watch.borrow().clone(); + let combined = if current.is_empty() { + text.to_string() + } else { + format!("{}\n{}", current, text) + }; + let _ = stream.text_watch.send(combined); + true + } else { + false + } + }; + if appended { + if let Some(ref req_id) = reply.request_id { + let resp = crate::schema::GatewayResponse { + schema: "openab.gateway.response.v1".into(), + request_id: req_id.clone(), + success: true, + thread_id: None, + message_id: None, + error: None, + }; + if let Ok(json) = serde_json::to_string(&resp) { + let _ = event_tx.send(json); + } + } + return; + } + // Pending entry was already removed (debounce flushed) — fall + // through to direct-send below so this chunk still reaches the user. + } + + info!(to_user = to_user, "wecom: sending reply"); + let chunks = split_text_lines(text, 2048); + let mut msg_id = None; + + for chunk in &chunks { + match self.send_text(to_user, chunk).await { + Ok(id) => { + if msg_id.is_none() { + msg_id = Some(id); + } + } + Err(e) => warn!("wecom send failed: {e}"), + } + } + + if let Some(ref req_id) = reply.request_id { + let resp = crate::schema::GatewayResponse { + schema: "openab.gateway.response.v1".into(), + request_id: req_id.clone(), + success: msg_id.is_some(), + thread_id: None, + message_id: msg_id, + error: None, + }; + if let Ok(json) = serde_json::to_string(&resp) { + let _ = event_tx.send(json); + } + } + } + + fn handle_edit_message(&self, reply: &crate::schema::GatewayReply) { + let text = reply.content.text.trim(); + if text.is_empty() { + return; + } + let pending = self.pending_streams.lock().unwrap_or_else(|e| e.into_inner()); + if let Some(stream) = pending.get(&reply.channel.id) { + let _ = stream.text_watch.send(text.to_string()); + } + } + + + async fn send_text(&self, to_user: &str, text: &str) -> Result { + 
let agent_id: u64 = self.config.agent_id.parse().expect("agent_id validated at startup"); + let body = serde_json::json!({ + "touser": to_user, + "msgtype": "text", + "agentid": agent_id, + "text": { "content": text } + }); + + let resp = post_with_token_retry( + &self.client, + &self.token_cache, + &self.config.corp_id, + &self.config.secret, + "/cgi-bin/message/send", + &body, + ) + .await?; + Ok(resp["msgid"].as_str().unwrap_or("").to_string()) + } +} + +/// POST a JSON body to a WeCom API endpoint with automatic token refresh +/// on errcode 42001 (access_token expired). Used by both `send_text` and +/// the streaming flush path so a long-running stream can't lose its final +/// reply if the cached token expires mid-flight. +async fn post_with_token_retry( + client: &reqwest::Client, + token_cache: &WecomTokenCache, + corp_id: &str, + secret: &str, + api_path: &str, + body: &serde_json::Value, +) -> Result { + let token = token_cache.get_token(client, corp_id, secret).await?; + let url = format!("{}{}?access_token={}", token_cache.base_url, api_path, token); + let resp: serde_json::Value = client.post(&url).json(body).send().await?.json().await?; + let errcode = resp["errcode"].as_i64().unwrap_or(-1); + + if errcode == 42001 { + warn!(api_path, "wecom: access_token expired, refreshing and retrying"); + let new_token = token_cache.force_refresh(client, corp_id, secret).await?; + let retry_url = format!("{}{}?access_token={}", token_cache.base_url, api_path, new_token); + let retry_resp: serde_json::Value = + client.post(&retry_url).json(body).send().await?.json().await?; + let retry_code = retry_resp["errcode"].as_i64().unwrap_or(-1); + if retry_code != 0 { + anyhow::bail!( + "wecom {} retry failed: errcode={}, errmsg={}", + api_path, + retry_code, + retry_resp["errmsg"] + ); + } + Ok(retry_resp) + } else if errcode != 0 { + anyhow::bail!( + "wecom {} failed: errcode={}, errmsg={}", + api_path, + errcode, + resp["errmsg"] + ); + } else { + Ok(resp) + } +} + +// 
--- Handlers --- + +fn handle_verify_request( + token: &str, + encoding_aes_key: &str, + corp_id: &str, + msg_signature: &str, + timestamp: &str, + nonce: &str, + echostr: &str, +) -> anyhow::Result { + if !verify_signature(token, timestamp, nonce, echostr, msg_signature) { + anyhow::bail!("signature verification failed"); + } + decrypt_message(encoding_aes_key, echostr, corp_id) +} + +// --- XML parsing --- + +struct CallbackEnvelope { + to_user_name: String, + encrypt: String, +} + +struct WecomMessage { + from_user: String, + msg_type: String, + content: String, + msg_id: String, + pic_url: String, + media_id: String, + file_name: String, +} + +fn parse_envelope_xml(xml: &str) -> Result { + use quick_xml::events::Event; + use quick_xml::Reader; + + let mut reader = Reader::from_str(xml); + let mut to_user_name = String::new(); + let mut encrypt = String::new(); + let mut current_tag = String::new(); + + loop { + match reader.read_event() { + Ok(Event::Start(e)) => { + current_tag = String::from_utf8_lossy(e.name().as_ref()).to_string(); + } + Ok(Event::CData(e)) => { + let text = String::from_utf8_lossy(&e).to_string(); + match current_tag.as_str() { + "ToUserName" => to_user_name = text, + "Encrypt" => encrypt = text, + _ => {} + } + } + Ok(Event::Text(e)) => { + let text = e.unescape().unwrap_or_default().to_string(); + match current_tag.as_str() { + "ToUserName" if to_user_name.is_empty() => to_user_name = text, + "Encrypt" if encrypt.is_empty() => encrypt = text, + _ => {} + } + } + Ok(Event::End(_)) => { + current_tag.clear(); + } + Ok(Event::Eof) => break, + Err(e) => anyhow::bail!("xml parse error: {e}"), + _ => {} + } + } + + if encrypt.is_empty() { + anyhow::bail!("missing Encrypt field in callback XML"); + } + Ok(CallbackEnvelope { + to_user_name, + encrypt, + }) +} + +fn parse_message_xml(xml: &str) -> Result { + use quick_xml::events::Event; + use quick_xml::Reader; + + let mut reader = Reader::from_str(xml); + let mut from_user = String::new(); + 
let mut msg_type = String::new(); + let mut content = String::new(); + let mut msg_id = String::new(); + let mut pic_url = String::new(); + let mut media_id = String::new(); + let mut file_name = String::new(); + let mut current_tag = String::new(); + + loop { + match reader.read_event() { + Ok(Event::Start(e)) => { + current_tag = String::from_utf8_lossy(e.name().as_ref()).to_string(); + } + Ok(Event::CData(e)) => { + let text = String::from_utf8_lossy(&e).to_string(); + match current_tag.as_str() { + "FromUserName" => from_user = text, + "MsgType" => msg_type = text, + "Content" => content = text, + "MsgId" => msg_id = text, + "PicUrl" => pic_url = text, + "MediaId" => media_id = text, + "FileName" => file_name = text, + _ => {} + } + } + Ok(Event::Text(e)) => { + let text = e.unescape().unwrap_or_default().to_string(); + match current_tag.as_str() { + "FromUserName" if from_user.is_empty() => from_user = text, + "MsgType" if msg_type.is_empty() => msg_type = text, + "Content" if content.is_empty() => content = text, + "MsgId" if msg_id.is_empty() => msg_id = text, + "PicUrl" if pic_url.is_empty() => pic_url = text, + "MediaId" if media_id.is_empty() => media_id = text, + "FileName" if file_name.is_empty() => file_name = text, + _ => {} + } + } + Ok(Event::End(_)) => { + current_tag.clear(); + } + Ok(Event::Eof) => break, + Err(e) => anyhow::bail!("xml parse error: {e}"), + _ => {} + } + } + + Ok(WecomMessage { + from_user, + msg_type, + content, + msg_id, + pic_url, + media_id, + file_name, + }) +} + +#[allow(clippy::too_many_arguments)] +async fn flush_thinking( + client: &reqwest::Client, + token_cache: &WecomTokenCache, + corp_id: &str, + secret: &str, + agent_id: &str, + thinking_msg_id: Option<&str>, + to_user: &str, + text: &str, +) { + info!(?thinking_msg_id, text_len = text.len(), "wecom: flush_thinking starting"); + + // Recall thinking placeholder (only when streaming was enabled) + if let Some(id) = thinking_msg_id { + let body = serde_json::json!({ 
"msgid": id }); + match post_with_token_retry( + client, + token_cache, + corp_id, + secret, + "/cgi-bin/message/recall", + &body, + ) + .await + { + Ok(resp) => info!(body = %resp, "wecom: recall response"), + Err(e) => warn!(error = %e, "wecom: recall failed"), + } + } + + // Send final text. Each chunk goes through retry-on-token-expiry so a + // long stream that outlives the cached token still delivers its reply. + let aid = agent_id.parse::().unwrap_or(0); + let chunks = split_text_lines(text, 2048); + info!(chunk_count = chunks.len(), "wecom: sending final chunks"); + for (i, chunk) in chunks.iter().enumerate() { + let body = serde_json::json!({ + "touser": to_user, + "msgtype": "text", + "agentid": aid, + "text": { "content": chunk } + }); + match post_with_token_retry( + client, + token_cache, + corp_id, + secret, + "/cgi-bin/message/send", + &body, + ) + .await + { + Ok(val) => { + let msg_id = val["msgid"].as_str().unwrap_or(""); + info!(msg_id = %msg_id, chunk_idx = i, "wecom: sent final reply chunk"); + } + Err(e) => warn!(error = %e, chunk_idx = i, "wecom flush send failed"), + } + } +} + +/// Split `text` into chunks that each fit within `limit` bytes (WeCom's +/// `message/send` truncates server-side at 2048 bytes). Splits prefer +/// newline boundaries; lines that exceed the limit themselves are split at +/// UTF-8 char boundaries via `char_indices()` so multibyte characters are +/// never severed mid-codepoint. The `limit` and all `len()` comparisons in +/// this function are in **bytes**, matching WeCom's server-side check. 
+fn split_text_lines(text: &str, limit: usize) -> Vec { + if text.len() <= limit { + return vec![text.to_string()]; + } + let mut chunks = Vec::new(); + let mut current = String::new(); + for line in text.split('\n') { + if line.len() > limit { + if !current.is_empty() { + chunks.push(current); + current = String::new(); + } + // Split long line at char boundaries + let mut pos = 0; + for (i, ch) in line.char_indices() { + if i - pos + ch.len_utf8() > limit { + chunks.push(line[pos..i].to_string()); + pos = i; + } + } + if pos < line.len() { + current = line[pos..].to_string(); + } + continue; + } + let candidate_len = if current.is_empty() { + line.len() + } else { + current.len() + 1 + line.len() + }; + if candidate_len > limit && !current.is_empty() { + chunks.push(current); + current = String::new(); + } + if !current.is_empty() { + current.push('\n'); + } + current.push_str(line); + } + if !current.is_empty() { + chunks.push(current); + } + chunks +} + +pub async fn verify( + State(state): State>, + query: axum::extract::Query>, +) -> axum::response::Response { + use axum::response::IntoResponse; + + let wecom = match state.wecom.as_ref() { + Some(w) => w, + None => return axum::http::StatusCode::SERVICE_UNAVAILABLE.into_response(), + }; + + let msg_signature = query.get("msg_signature").map(|s| s.as_str()).unwrap_or(""); + let timestamp = query.get("timestamp").map(|s| s.as_str()).unwrap_or(""); + let nonce = query.get("nonce").map(|s| s.as_str()).unwrap_or(""); + let echostr = query.get("echostr").map(|s| s.as_str()).unwrap_or(""); + + info!( + msg_signature = %msg_signature, + timestamp = %timestamp, + nonce = %nonce, + echostr_len = echostr.len(), + "wecom verify request received" + ); + + match handle_verify_request( + &wecom.config.token, + &wecom.config.encoding_aes_key, + &wecom.config.corp_id, + msg_signature, + timestamp, + nonce, + echostr, + ) { + Ok(plaintext) => plaintext.into_response(), + Err(e) => { + warn!("wecom callback verification failed: 
{e}"); + axum::http::StatusCode::FORBIDDEN.into_response() + } + } +} + +pub async fn webhook( + State(state): State>, + query: axum::extract::Query>, + body: axum::body::Bytes, +) -> axum::response::Response { + use axum::response::IntoResponse; + + let wecom = match state.wecom.as_ref() { + Some(w) => w, + None => return axum::http::StatusCode::SERVICE_UNAVAILABLE.into_response(), + }; + + let msg_signature = query.get("msg_signature").map(|s| s.as_str()).unwrap_or(""); + let timestamp = query.get("timestamp").map(|s| s.as_str()).unwrap_or(""); + let nonce = query.get("nonce").map(|s| s.as_str()).unwrap_or(""); + + // Reject stale callbacks. WeCom retries within ~5s, our dedup window is + // 30s, so a 5-minute freshness check rejects replays without false- + // positives on legitimate retries. The signature itself doesn't bind a + // freshness expectation, so without this an attacker who captured a + // signed payload could replay it indefinitely. + if let Ok(ts) = timestamp.parse::() { + let now = chrono::Utc::now().timestamp(); + if (now - ts).abs() > 300 { + warn!(timestamp_age_secs = now - ts, "wecom webhook: rejecting stale callback"); + return axum::http::StatusCode::FORBIDDEN.into_response(); + } + } + + let body_str = match std::str::from_utf8(&body) { + Ok(s) => s, + Err(_) => return axum::http::StatusCode::BAD_REQUEST.into_response(), + }; + + let envelope = match parse_envelope_xml(body_str) { + Ok(e) => e, + Err(e) => { + warn!("wecom envelope parse error: {e}"); + return axum::http::StatusCode::BAD_REQUEST.into_response(); + } + }; + + // ToUserName in the outer envelope must match our configured Corp ID. + // The decrypt step also validates the inner Corp ID suffix; checking here + // first surfaces misrouted callbacks before we touch crypto. 
+ if envelope.to_user_name != wecom.config.corp_id { + warn!( + envelope_to = %envelope.to_user_name, + expected = %wecom.config.corp_id, + "wecom webhook: envelope ToUserName mismatch" + ); + return axum::http::StatusCode::FORBIDDEN.into_response(); + } + + if !verify_signature( + &wecom.config.token, + timestamp, + nonce, + &envelope.encrypt, + msg_signature, + ) { + warn!("wecom webhook signature verification failed"); + return axum::http::StatusCode::FORBIDDEN.into_response(); + } + + info!(encrypt_len = envelope.encrypt.len(), "wecom: decrypting callback"); + let decrypted = match decrypt_message( + &wecom.config.encoding_aes_key, + &envelope.encrypt, + &wecom.config.corp_id, + ) { + Ok(d) => { + info!("wecom: decrypt ok"); + d + } + Err(e) => { + warn!(encrypt_len = envelope.encrypt.len(), "wecom decrypt failed: {e}"); + return "success".into_response(); + } + }; + + let msg = match parse_message_xml(&decrypted) { + Ok(m) => m, + Err(e) => { + warn!("wecom message parse error: {e}"); + return "success".into_response(); + } + }; + + info!( + msg_type = %msg.msg_type, + has_pic_url = !msg.pic_url.is_empty(), + msg_id = %msg.msg_id, + "wecom: parsed message" + ); + + if !matches!(msg.msg_type.as_str(), "text" | "image" | "file") { + return "success".into_response(); + } + + if !wecom.dedupe.check_and_insert(&msg.msg_id) { + return "success".into_response(); + } + + let text = match msg.msg_type.as_str() { + "text" => msg.content.clone(), + "image" => "Describe this image.".to_string(), + "file" => format!("User sent a file: {}", msg.file_name), + _ => String::new(), + }; + + let mut attachments = Vec::new(); + if msg.msg_type == "image" && !msg.pic_url.is_empty() { + match download_wecom_image(&wecom.client, &msg.pic_url).await { + Some(att) => attachments.push(att), + None => info!("wecom: image download failed, forwarding without attachment"), + } + } + if msg.msg_type == "file" && !msg.media_id.is_empty() { + match download_wecom_file( + &wecom.client, + 
&wecom.token_cache, + &wecom.config.corp_id, + &wecom.config.secret, + &msg.media_id, + &msg.file_name, + ) + .await + { + Some(att) => attachments.push(att), + None => info!("wecom: file download failed, forwarding without attachment"), + } + } + + if text.trim().is_empty() && attachments.is_empty() { + return "success".into_response(); + } + + let channel_id = format!("wecom:{}:{}", wecom.config.corp_id, msg.from_user); + let mut event = crate::schema::GatewayEvent::new( + "wecom", + crate::schema::ChannelInfo { + id: channel_id, + channel_type: "direct".into(), + thread_id: None, + }, + crate::schema::SenderInfo { + id: msg.from_user.clone(), + name: msg.from_user.clone(), + display_name: msg.from_user.clone(), + is_bot: false, + }, + &text, + &msg.msg_id, + vec![], + ); + event.content.attachments = attachments; + + let att_sizes: Vec = event.content.attachments.iter().map(|a| a.data.len()).collect(); + info!( + attachments = event.content.attachments.len(), + text_len = event.content.text.len(), + att_data_sizes = ?att_sizes, + att_mime = ?event.content.attachments.iter().map(|a| a.mime_type.as_str()).collect::>(), + "wecom: forwarding event to OAB" + ); + if let Ok(json) = serde_json::to_string(&event) { + info!( + json_len = json.len(), + has_attachments_in_json = json.contains("\"attachments\""), + "wecom: event JSON ready" + ); + let _ = state.event_tx.send(json); + } + + "success".into_response() +} + +const IMAGE_MAX_DOWNLOAD: u64 = 10 * 1024 * 1024; +const IMAGE_MAX_DIMENSION_PX: u32 = 1200; +const IMAGE_JPEG_QUALITY: u8 = 75; + +async fn download_wecom_image( + client: &reqwest::Client, + pic_url: &str, +) -> Option { + // Only fetch over HTTPS. WeCom's CDN serves images over HTTPS; rejecting + // non-HTTPS URLs prevents SSRF if the AES key is ever compromised and + // an attacker forges a callback with PicUrl pointing at an internal host. 
+ if !pic_url.starts_with("https://") { + warn!(pic_url, "wecom: rejecting non-HTTPS pic_url"); + return None; + } + info!(pic_url, "wecom: downloading image"); + let resp = match client.get(pic_url).send().await { + Ok(r) => r, + Err(e) => { + warn!(error = %e, "wecom image download failed"); + return None; + } + }; + if !resp.status().is_success() { + warn!(status = %resp.status(), "wecom image download failed"); + return None; + } + if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { + if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { + if size > IMAGE_MAX_DOWNLOAD { + warn!(size, "wecom image exceeds 10MB limit, skipping"); + return None; + } + } + } + let bytes = resp.bytes().await.ok()?; + if bytes.len() as u64 > IMAGE_MAX_DOWNLOAD { + warn!(size = bytes.len(), "wecom image exceeds 10MB limit"); + return None; + } + let (compressed, mime) = match resize_and_compress(&bytes) { + Ok(v) => v, + Err(e) => { + warn!(error = %e, "wecom: image resize/compress failed"); + return None; + } + }; + let path = crate::store::store_media(&compressed).await?; + let ext = if mime == "image/gif" { "gif" } else { "jpg" }; + Some(crate::schema::Attachment { + attachment_type: "image".into(), + filename: format!("wecom_{}.{}", chrono::Utc::now().timestamp(), ext), + mime_type: mime, + data: String::new(), + size: compressed.len() as u64, + path: Some(path), + }) +} + +const FILE_MAX_DOWNLOAD: u64 = 20 * 1024 * 1024; + +const TEXT_EXTENSIONS: &[&str] = &[ + "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", "rs", "py", "js", + "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", "rb", "sh", "bash", "zsh", "fish", + "ps1", "bat", "sql", "html", "css", "scss", "less", "ini", "cfg", "conf", "env", + "swift", "kt", "scala", "r", "pl", "lua", "graphql", "tsv", +]; + +const TEXT_FILENAMES: &[&str] = &[ + "dockerfile", "makefile", "justfile", "rakefile", "gemfile", + "procfile", "vagrantfile", ".gitignore", ".dockerignore", 
".editorconfig", +]; + +fn is_text_file(filename: &str) -> bool { + let lower = filename.to_lowercase(); + if lower.contains('.') { + if let Some(ext) = lower.rsplit('.').next() { + if TEXT_EXTENSIONS.contains(&ext) { + return true; + } + } + } + TEXT_FILENAMES.contains(&lower.as_str()) +} + +/// GET /cgi-bin/media/get with token-expiry retry. The media API returns +/// JSON `{"errcode":42001,...}` instead of binary when the token is stale, +/// so we sniff Content-Type and retry once with a force-refreshed token. +async fn fetch_media_with_retry( + client: &reqwest::Client, + token_cache: &WecomTokenCache, + corp_id: &str, + secret: &str, + media_id: &str, +) -> Result { + let token = token_cache.get_token(client, corp_id, secret).await?; + let url = format!( + "{}/cgi-bin/media/get?access_token={}&media_id={}", + token_cache.base_url, token, media_id + ); + let resp = client.get(&url).send().await?; + let content_type = resp + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + if !content_type.contains("json") { + return Ok(resp); + } + // JSON body means error path. Inspect for 42001 and retry once. 
+ let body = resp.text().await.unwrap_or_default(); + let val: serde_json::Value = serde_json::from_str(&body).unwrap_or_default(); + let errcode = val["errcode"].as_i64().unwrap_or(-1); + if errcode == 42001 { + warn!("wecom media: access_token expired, refreshing and retrying"); + let new_token = token_cache.force_refresh(client, corp_id, secret).await?; + let retry_url = format!( + "{}/cgi-bin/media/get?access_token={}&media_id={}", + token_cache.base_url, new_token, media_id + ); + return Ok(client.get(&retry_url).send().await?); + } + anyhow::bail!("wecom media error: {body}") +} + +async fn download_wecom_file( + client: &reqwest::Client, + token_cache: &WecomTokenCache, + corp_id: &str, + secret: &str, + media_id: &str, + filename: &str, +) -> Option { + info!(filename, media_id, "wecom: downloading file"); + let resp = match fetch_media_with_retry(client, token_cache, corp_id, secret, media_id).await { + Ok(r) => r, + Err(e) => { + warn!(error = %e, "wecom file download failed"); + return None; + } + }; + if !resp.status().is_success() { + warn!(status = %resp.status(), "wecom file download failed"); + return None; + } + if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { + if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { + if size > FILE_MAX_DOWNLOAD { + warn!(size, "wecom file exceeds 20MB limit, skipping"); + return None; + } + } + } + let bytes = resp.bytes().await.ok()?; + if bytes.len() as u64 > FILE_MAX_DOWNLOAD { + warn!(size = bytes.len(), "wecom file exceeds 20MB limit"); + return None; + } + + if !is_text_file(filename) { + info!(filename, "wecom: skipping non-text file"); + return None; + } + + let text_content = match String::from_utf8(bytes.to_vec()) { + Ok(s) => s, + Err(_) => { + info!(filename, "wecom: file is not valid UTF-8, skipping"); + return None; + } + }; + + let path = crate::store::store_media(text_content.as_bytes()).await?; + let size = text_content.len() as u64; + + Some(crate::schema::Attachment { + 
attachment_type: "text_file".into(), + filename: filename.to_string(), + mime_type: "text/plain".into(), + data: String::new(), + size, + path: Some(path), + }) +} + +fn resize_and_compress(raw: &[u8]) -> Result<(Vec, String), image::ImageError> { + use image::ImageReader; + use std::io::Cursor; + + let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?; + let format = reader.format(); + if format == Some(image::ImageFormat::Gif) { + return Ok((raw.to_vec(), "image/gif".to_string())); + } + let img = reader.decode()?; + let (w, h) = (img.width(), img.height()); + let img = if w > IMAGE_MAX_DIMENSION_PX || h > IMAGE_MAX_DIMENSION_PX { + let max_side = std::cmp::max(w, h); + let ratio = f64::from(IMAGE_MAX_DIMENSION_PX) / f64::from(max_side); + let new_w = (f64::from(w) * ratio) as u32; + let new_h = (f64::from(h) * ratio) as u32; + img.resize(new_w, new_h, image::imageops::FilterType::Lanczos3) + } else { + img + }; + let mut buf = Cursor::new(Vec::new()); + let encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, IMAGE_JPEG_QUALITY); + img.write_with_encoder(encoder)?; + Ok((buf.into_inner(), "image/jpeg".to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_env(pairs: &[(&str, &str)]) -> impl Fn(&str) -> Option { + let map: std::collections::HashMap = pairs + .iter() + .map(|(k, v)| ((*k).to_string(), (*v).to_string())) + .collect(); + move |k: &str| map.get(k).cloned() + } + + #[test] + fn config_from_env_all_present() { + let env = make_env(&[ + ("WECOM_CORP_ID", "ww_test_corp"), + ("WECOM_SECRET", "test_secret"), + ("WECOM_TOKEN", "test_token"), + ("WECOM_ENCODING_AES_KEY", "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG"), + ("WECOM_AGENT_ID", "1000002"), + ]); + let config = WecomConfig::from_reader(env).unwrap(); + assert_eq!(config.corp_id, "ww_test_corp"); + assert_eq!(config.agent_id, "1000002"); + assert_eq!(config.webhook_path, "/webhook/wecom"); + assert!(!config.streaming_enabled, "streaming defaults 
off"); + assert_eq!(config.debounce_secs, 3); + } + + #[test] + fn config_from_env_missing_required() { + let env = make_env(&[]); + assert!(WecomConfig::from_reader(env).is_none()); + } + + fn encrypt_for_test(encoding_aes_key: &str, msg: &str, corp_id: &str) -> String { + use aes::cipher::{BlockEncryptMut, KeyIvInit}; + use base64::Engine; + + let key = decode_aes_key(encoding_aes_key).unwrap(); + let iv = &key[..16]; + + let msg_bytes = msg.as_bytes(); + let corp_id_bytes = corp_id.as_bytes(); + let msg_len = (msg_bytes.len() as u32).to_be_bytes(); + + let mut plaintext = Vec::new(); + plaintext.extend_from_slice(&[0u8; 16]); // random bytes (zeros for test) + plaintext.extend_from_slice(&msg_len); + plaintext.extend_from_slice(msg_bytes); + plaintext.extend_from_slice(corp_id_bytes); + + // WeCom uses PKCS7 padding with block_size=32 + let block_size = 32; + let pad_len = block_size - (plaintext.len() % block_size); + for _ in 0..pad_len { + plaintext.push(pad_len as u8); + } + + // Encrypt with NoPadding since we already padded manually + let total_len = plaintext.len(); + let mut buf = vec![0u8; total_len + 16]; // extra space just in case + buf[..total_len].copy_from_slice(&plaintext); + + type Aes256CbcEnc = cbc::Encryptor; + let encryptor = Aes256CbcEnc::new_from_slices(&key, iv).unwrap(); + let encrypted = encryptor + .encrypt_padded_mut::(&mut buf, total_len) + .unwrap(); + + base64::engine::general_purpose::STANDARD.encode(encrypted) + } + + #[test] + fn aes_key_decode() { + let key_str = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; + let key_bytes = decode_aes_key(key_str).unwrap(); + assert_eq!(key_bytes.len(), 32); + } + + #[test] + fn signature_verify() { + let token = "testtoken"; + let timestamp = "1409659813"; + let nonce = "1372623149"; + let encrypt = "msg_encrypt_content"; + + let sig = compute_signature(token, timestamp, nonce, encrypt); + assert!(verify_signature(token, timestamp, nonce, encrypt, &sig)); + assert!(!verify_signature( + 
token, + timestamp, + nonce, + encrypt, + "wrong_signature_value_here" + )); + } + + #[test] + fn decrypt_wecom_payload() { + let key_str = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; + let corp_id = "ww_test_corp"; + let msg = "hello world"; + + let encrypted = encrypt_for_test(key_str, msg, corp_id); + let decrypted = decrypt_message(key_str, &encrypted, corp_id).unwrap(); + assert_eq!(decrypted, msg); + } + + #[test] + fn verify_callback_echostr() { + let token = "testtoken"; + let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; + let corp_id = "ww_test_corp"; + let echostr_plain = "success_echo_string"; + + let echostr_encrypted = encrypt_for_test(encoding_aes_key, echostr_plain, corp_id); + let sig = compute_signature(token, "1409659813", "nonce123", &echostr_encrypted); + + let result = handle_verify_request( + token, + encoding_aes_key, + corp_id, + &sig, + "1409659813", + "nonce123", + &echostr_encrypted, + ); + assert_eq!(result.unwrap(), echostr_plain); + } + + #[test] + fn parse_text_message_xml() { + let xml = r#"134883186012345678901234561000002"#; + + let msg = parse_message_xml(xml).unwrap(); + assert_eq!(msg.from_user, "user001"); + assert_eq!(msg.msg_type, "text"); + assert_eq!(msg.content, "hello bot"); + assert_eq!(msg.msg_id, "1234567890123456"); + } + + #[test] + fn parse_callback_envelope() { + let xml = r#""#; + + let envelope = parse_envelope_xml(xml).unwrap(); + assert_eq!(envelope.to_user_name, "ww_test_corp"); + assert_eq!(envelope.encrypt, "some_encrypted_base64"); + } + + #[test] + fn dedupe_rejects_duplicates() { + let cache = DedupeCache::new(); + assert!(cache.check_and_insert("msg_001")); + assert!(!cache.check_and_insert("msg_001")); + assert!(cache.check_and_insert("msg_002")); + } + + #[tokio::test] + async fn token_refresh_success() { + use wiremock::matchers::{method, query_param}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + Mock::given(method("GET")) 
+ .and(query_param("corpid", "ww_test_corp")) + .and(query_param("corpsecret", "test_secret")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "errcode": 0, + "errmsg": "ok", + "access_token": "test_token_abc", + "expires_in": 7200 + }))) + .expect(1) + .mount(&server) + .await; + + let cache = WecomTokenCache::with_base_url(server.uri()); + let client = reqwest::Client::new(); + let token = cache.get_token(&client, "ww_test_corp", "test_secret").await.unwrap(); + assert_eq!(token, "test_token_abc"); + + // Second call uses cache (mock expects exactly 1 call) + let token2 = cache.get_token(&client, "ww_test_corp", "test_secret").await.unwrap(); + assert_eq!(token2, "test_token_abc"); + } + + #[test] + fn split_text_lines_multi() { + let text = "line1\nline2\nline3"; + let chunks = split_text_lines(text, 11); + assert_eq!(chunks.len(), 2); + assert_eq!(chunks[0], "line1\nline2"); + assert_eq!(chunks[1], "line3"); + } + + #[test] + fn split_text_lines_within_limit() { + let text = "short"; + let chunks = split_text_lines(text, 100); + assert_eq!(chunks, vec!["short"]); + } + + #[test] + fn split_text_lines_long_line() { + let text = "abcdefghij"; + let chunks = split_text_lines(text, 4); + assert_eq!(chunks, vec!["abcd", "efgh", "ij"]); + } + + #[test] + fn split_text_lines_long_line_utf8() { + let text = "你好世界測試"; // 18 bytes, 6 chars + let chunks = split_text_lines(text, 6); + assert_eq!(chunks, vec!["你好", "世界", "測試"]); + } + + #[test] + fn is_text_file_check() { + assert!(is_text_file("readme.md")); + assert!(is_text_file("config.json")); + assert!(is_text_file("data.csv")); + assert!(is_text_file("MAIN.PY")); + assert!(!is_text_file("photo.png")); + assert!(!is_text_file("archive.zip")); + assert!(!is_text_file("doc.pdf")); + } + + #[test] + fn parse_file_message() { + let xml = r#"134883186066661000002"#; + let msg = parse_message_xml(xml).unwrap(); + assert_eq!(msg.msg_type, "file"); + assert_eq!(msg.media_id, "media_abc123"); + 
assert_eq!(msg.file_name, "report.csv"); + } + + #[test] + fn full_webhook_decrypt_and_parse() { + let token = "testtoken"; + let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; + let corp_id = "ww_test_corp"; + let timestamp = "1409659813"; + let nonce = "nonce123"; + + // Simulate the inner message + let inner_xml = "134883186099991000002"; + + // Encrypt it + let encrypted = encrypt_for_test(encoding_aes_key, inner_xml, corp_id); + + // Compute signature + let sig = compute_signature(token, timestamp, nonce, &encrypted); + + // Verify signature + assert!(verify_signature(token, timestamp, nonce, &encrypted, &sig)); + + // Decrypt + let decrypted = decrypt_message(encoding_aes_key, &encrypted, corp_id).unwrap(); + assert_eq!(decrypted, inner_xml); + + // Parse + let msg = parse_message_xml(&decrypted).unwrap(); + assert_eq!(msg.from_user, "user42"); + assert_eq!(msg.msg_type, "text"); + assert_eq!(msg.content, "ping"); + assert_eq!(msg.msg_id, "9999"); + } + + #[test] + fn parse_image_message() { + let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; + let corp_id = "ww_test_corp"; + + let inner_xml = "134883186088881000002"; + + let encrypted = encrypt_for_test(encoding_aes_key, inner_xml, corp_id); + let decrypted = decrypt_message(encoding_aes_key, &encrypted, corp_id).unwrap(); + let msg = parse_message_xml(&decrypted).unwrap(); + assert_eq!(msg.msg_type, "image"); + assert_eq!(msg.pic_url, "http://example.com/pic.jpg"); + assert_eq!(msg.from_user, "user42"); + } + + #[test] + fn unsupported_msg_type_skipped() { + let xml = "134883186077771000002"; + let msg = parse_message_xml(xml).unwrap(); + assert_eq!(msg.msg_type, "voice"); + assert!(!matches!(msg.msg_type.as_str(), "text" | "image")); + } + + #[test] + fn verify_rejects_wrong_signature() { + let token = "testtoken"; + let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; + let corp_id = "ww_test_corp"; + let echostr_plain = "test_echo"; + + let 
echostr_encrypted = encrypt_for_test(encoding_aes_key, echostr_plain, corp_id); + + let result = handle_verify_request( + token, + encoding_aes_key, + corp_id, + "completely_wrong_signature", + "1409659813", + "nonce123", + &echostr_encrypted, + ); + assert!(result.is_err()); + } + + #[test] + fn decrypt_with_large_padding_value() { + // Verifies decryption works when WeCom's 32-byte padding exceeds 16 + let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; + let corp_id = "ww_test_corp"; + // Choose a message where (16 + 4 + msg_len + corp_id_len) % 32 < 16, + // producing a pad value > 16 which would fail with PKCS7/block_size=16. + // 16 + 4 + 1 + 12 = 33 → 33 % 32 = 1 → pad = 31 + let msg = "x"; + let encrypted = encrypt_for_test(encoding_aes_key, msg, corp_id); + let decrypted = decrypt_message(encoding_aes_key, &encrypted, corp_id).unwrap(); + assert_eq!(decrypted, msg); + } + + #[test] + fn decrypt_rejects_wrong_corp_id() { + let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; + let corp_id = "ww_test_corp"; + let msg = "hello"; + + let encrypted = encrypt_for_test(encoding_aes_key, msg, corp_id); + let result = decrypt_message(encoding_aes_key, &encrypted, "ww_other_corp"); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("corp_id mismatch")); + } +} diff --git a/crates/openab-gateway/src/lib.rs b/crates/openab-gateway/src/lib.rs new file mode 100644 index 000000000..ad11f9db1 --- /dev/null +++ b/crates/openab-gateway/src/lib.rs @@ -0,0 +1,4 @@ +pub mod adapters; +pub mod media; +pub mod schema; +pub mod store; diff --git a/crates/openab-gateway/src/media.rs b/crates/openab-gateway/src/media.rs new file mode 100644 index 000000000..f6eb88565 --- /dev/null +++ b/crates/openab-gateway/src/media.rs @@ -0,0 +1,123 @@ +use image::ImageReader; +use std::io::Cursor; + +/// Media type for download functions — avoids stringly-typed branching. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MediaKind { + Image, + Audio, +} + +pub const IMAGE_MAX_DIMENSION_PX: u32 = 1200; +pub const IMAGE_JPEG_QUALITY: u8 = 75; +pub const IMAGE_MAX_DOWNLOAD: u64 = 10 * 1024 * 1024; // 10 MB +pub const FILE_MAX_DOWNLOAD: u64 = 20 * 1024 * 1024; // 20 MB (same as store cap) +pub const AUDIO_MAX_DOWNLOAD: u64 = 20 * 1024 * 1024; // 20 MB +pub const GIF_MAX_SIZE: usize = 5 * 1024 * 1024; // 5 MB — prevents base64 bloat exceeding LLM payload limits + +/// Resize image so longest side <= 1200px, then encode as JPEG. +/// GIFs under 5MB are passed through unchanged to preserve animation. +pub fn resize_and_compress(raw: &[u8]) -> Result<(Vec, String), image::ImageError> { + let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?; + let format = reader.format(); + if format == Some(image::ImageFormat::Gif) { + if raw.len() > GIF_MAX_SIZE { + return Err(image::ImageError::Limits( + image::error::LimitError::from_kind(image::error::LimitErrorKind::DimensionError), + )); + } + return Ok((raw.to_vec(), "image/gif".to_string())); + } + let img = reader.decode()?; + let (w, h) = (img.width(), img.height()); + let img = if w > IMAGE_MAX_DIMENSION_PX || h > IMAGE_MAX_DIMENSION_PX { + let max_side = std::cmp::max(w, h); + let ratio = f64::from(IMAGE_MAX_DIMENSION_PX) / f64::from(max_side); + let new_w = (f64::from(w) * ratio) as u32; + let new_h = (f64::from(h) * ratio) as u32; + img.resize(new_w, new_h, image::imageops::FilterType::Lanczos3) + } else { + img + }; + let mut buf = Cursor::new(Vec::new()); + let encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, IMAGE_JPEG_QUALITY); + img.write_with_encoder(encoder)?; + Ok((buf.into_inner(), "image/jpeg".to_string())) +} + +/// Derive file extension from Content-Type for audio files. 
+pub fn audio_extension(content_type: &str) -> &'static str { + if content_type.contains("mpeg") || content_type.contains("mp3") { + "mp3" + } else if content_type.contains("m4a") || content_type.contains("mp4") { + "m4a" + } else { + "ogg" + } +} + +/// Check if a filename has a text-like extension suitable for reading as UTF-8. +pub fn is_text_extension(filename: &str) -> bool { + const TEXT_EXTS: &[&str] = &[ + "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", "rs", "py", + "js", "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", "rb", "sh", "bash", + "sql", "html", "css", "ini", "cfg", "conf", + ]; + let ext = filename.rsplit('.').next().unwrap_or("").to_lowercase(); + TEXT_EXTS.contains(&ext.as_str()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn gif_under_limit_passes_through() { + let gif = b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00!\xf9\x04\x00\x00\x00\x00\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;"; + let result = resize_and_compress(gif); + assert!(result.is_ok()); + let (data, mime) = result.unwrap(); + assert_eq!(mime, "image/gif"); + assert_eq!(data, gif); + } + + #[test] + fn gif_over_limit_returns_error() { + let mut data = b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00".to_vec(); + data.resize(GIF_MAX_SIZE + 1, 0); + let result = resize_and_compress(&data); + assert!(result.is_err()); + } + + #[test] + fn small_jpeg_not_resized() { + let img = image::RgbImage::from_pixel(2, 2, image::Rgb([255, 0, 0])); + let mut buf = std::io::Cursor::new(Vec::new()); + img.write_to(&mut buf, image::ImageFormat::Jpeg).unwrap(); + let result = resize_and_compress(&buf.into_inner()); + assert!(result.is_ok()); + assert_eq!(result.unwrap().1, "image/jpeg"); + } + + #[test] + fn large_image_gets_resized() { + let img = image::RgbImage::from_pixel(2000, 2000, image::Rgb([0, 128, 255])); + let mut buf = std::io::Cursor::new(Vec::new()); + img.write_to(&mut buf, 
image::ImageFormat::Png).unwrap(); + let result = resize_and_compress(&buf.into_inner()); + assert!(result.is_ok()); + let (data, mime) = result.unwrap(); + assert_eq!(mime, "image/jpeg"); + let decoded = image::load_from_memory(&data).unwrap(); + assert!(decoded.width() <= IMAGE_MAX_DIMENSION_PX); + assert!(decoded.height() <= IMAGE_MAX_DIMENSION_PX); + } + + #[test] + fn text_extension_check() { + assert!(is_text_extension("main.rs")); + assert!(is_text_extension("data.csv")); + assert!(!is_text_extension("archive.zip")); + assert!(!is_text_extension("photo.jpg")); + } +} diff --git a/crates/openab-gateway/src/schema.rs b/crates/openab-gateway/src/schema.rs new file mode 100644 index 000000000..740d0fab8 --- /dev/null +++ b/crates/openab-gateway/src/schema.rs @@ -0,0 +1,126 @@ +use serde::{Deserialize, Serialize}; + +// --- Event schema (ADR openab.gateway.event.v1) --- + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GatewayEvent { + pub schema: String, + pub event_id: String, + pub timestamp: String, + pub platform: String, + pub event_type: String, + pub channel: ChannelInfo, + pub sender: SenderInfo, + pub content: Content, + pub mentions: Vec, + pub message_id: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChannelInfo { + pub id: String, + #[serde(rename = "type")] + pub channel_type: String, + pub thread_id: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SenderInfo { + pub id: String, + pub name: String, + pub display_name: String, + pub is_bot: bool, +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct Content { + #[serde(rename = "type")] + pub content_type: String, + pub text: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub attachments: Vec, +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct Attachment { + #[serde(rename = "type")] + pub attachment_type: String, // "image", "text_file", "audio" + pub filename: 
String, + pub mime_type: String, + /// Base64-encoded data (deprecated — use `path` for colocate mode). + /// Kept for backward compatibility; Core prefers `path` when present. + #[serde(default, skip_serializing_if = "String::is_empty")] + pub data: String, + pub size: u64, // size in bytes (after compression for images) + /// Local file path for colocate mode (gateway + core share filesystem). + /// When set, Core reads bytes directly from this path instead of decoding `data`. + /// Path format: ~/.openab/media/inbound/ (no extension, MIME in mime_type). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub path: Option, +} + +// --- Reply schema (ADR openab.gateway.reply.v1) --- + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GatewayReply { + pub schema: String, + pub reply_to: String, + pub platform: String, + pub channel: ReplyChannel, + pub content: Content, + #[serde(default)] + pub command: Option, + #[serde(default)] + pub request_id: Option, + /// When set, send this message as a reply/quote to the specified platform message ID. + /// Unlike `reply_to` (which identifies the triggering event for routing/dedup), + /// this field controls the visual reply/quote UI on the platform. + /// If quoting fails, the gateway MUST fall back to sending without quoting. + #[serde(default)] + pub quote_message_id: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ReplyChannel { + pub id: String, + pub thread_id: Option, +} + +/// Response from gateway back to OAB for commands (e.g. 
create_topic) +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GatewayResponse { + pub schema: String, + pub request_id: String, + pub success: bool, + pub thread_id: Option, + pub message_id: Option, + pub error: Option, +} + +impl GatewayEvent { + pub fn new( + platform: &str, + channel: ChannelInfo, + sender: SenderInfo, + text: &str, + message_id: &str, + mentions: Vec, + ) -> Self { + Self { + schema: "openab.gateway.event.v1".into(), + event_id: format!("evt_{}", uuid::Uuid::new_v4()), + timestamp: chrono::Utc::now().to_rfc3339(), + platform: platform.into(), + event_type: "message".into(), + channel, + sender, + content: Content { + content_type: "text".into(), + text: text.into(), + attachments: Vec::new(), + }, + mentions, + message_id: message_id.into(), + } + } +} diff --git a/crates/openab-gateway/src/store.rs b/crates/openab-gateway/src/store.rs new file mode 100644 index 000000000..b08e69903 --- /dev/null +++ b/crates/openab-gateway/src/store.rs @@ -0,0 +1,132 @@ +use std::path::{Path, PathBuf}; +use tokio::fs; +use tracing::{error, info}; +use uuid::Uuid; + +/// Inbound media directory under $HOME. +/// Pattern follows OpenClaw's `~/.openclaw/media/inbound/`. +/// +/// # Security Considerations +/// +/// - **Path traversal prevention**: Filenames are always server-generated UUIDs, +/// never user-supplied. No extension, no special characters — eliminates path +/// traversal attacks (e.g. `../../etc/passwd`). +/// +/// - **No auth token leakage**: Platform media URLs (Telegram getFile, LINE Content API) +/// contain bot tokens or require auth headers. By downloading in the gateway and +/// storing locally, tokens never reach Core or the agent. +/// +/// - **TTL auto-eviction**: Files are evicted after 2 minutes. Prevents disk exhaustion +/// from accumulated media and limits the window for any leaked file to be exploited. 
+/// +/// - **Colocate trust boundary**: This module assumes gateway and core share the same +/// filesystem (same pod / same $HOME). The file path is passed over the internal WS +/// connection — never exposed externally. If gateway and core are separated in the +/// future, switch to HTTP media proxy with internal-only binding. +/// +/// - **Size limits enforced before write**: Callers must validate file size against +/// IMAGE_MAX_DOWNLOAD / AUDIO_MAX_DOWNLOAD / FILE_MAX_DOWNLOAD before calling +/// `store_media()`. This module does NOT re-validate — it trusts the caller. +/// +/// - **No executable content**: Stored files are raw bytes (images, audio, text). +/// Core reads them as data only — never executed. The `mime_type` in the event +/// payload determines processing path, not the file content or name. +const MEDIA_INBOUND_DIR: &str = ".openab/media/inbound"; + +/// TTL for stored media files (2 minutes) +const TTL_SECS: u64 = 120; + +/// Get the inbound media directory path, creating it if needed. +pub async fn media_dir() -> PathBuf { + let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".into()); + let dir = Path::new(&home).join(MEDIA_INBOUND_DIR); + if !dir.exists() { + let _ = fs::create_dir_all(&dir).await; + } + dir +} + +/// Maximum file size accepted by store (defense-in-depth, callers should pre-check). +const MAX_STORE_SIZE: usize = 20 * 1024 * 1024; // 20 MB (matches AUDIO_MAX_DOWNLOAD) + +/// Store media bytes to disk, return the absolute file path. +/// Filename is UUID only (no extension) — MIME type is carried in the event payload. +/// Rejects files exceeding MAX_STORE_SIZE as a defense-in-depth measure. 
+pub async fn store_media(bytes: &[u8]) -> Option { + if bytes.len() > MAX_STORE_SIZE { + error!(size = bytes.len(), max = MAX_STORE_SIZE, "store_media rejected: exceeds size limit"); + return None; + } + let dir = media_dir().await; + let filename = Uuid::new_v4().to_string(); + let path = dir.join(&filename); + match fs::write(&path, bytes).await { + Ok(_) => { + info!(path = %path.display(), size = bytes.len(), "media stored"); + Some(path.to_string_lossy().into_owned()) + } + Err(e) => { + error!(error = %e, "failed to store media file"); + None + } + } +} + +/// Background task: evict files older than TTL_SECS. +pub async fn eviction_loop() { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(30)); + loop { + interval.tick().await; + if let Err(e) = evict_expired().await { + error!(error = %e, "media eviction error"); + } + } +} + +async fn evict_expired() -> std::io::Result<()> { + let dir = media_dir().await; + if !dir.exists() { + return Ok(()); + } + let mut entries = fs::read_dir(&dir).await?; + let now = std::time::SystemTime::now(); + while let Some(entry) = entries.next_entry().await? 
{ + if let Ok(meta) = entry.metadata().await { + if let Ok(modified) = meta.modified() { + if let Ok(age) = now.duration_since(modified) { + if age.as_secs() > TTL_SECS { + let path = entry.path(); + let _ = fs::remove_file(&path).await; + tracing::debug!(path = %path.display(), "evicted expired media"); + } + } + } + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn store_and_read_back() { + let data = b"hello media"; + let path = store_media(data).await.unwrap(); + let read_back = fs::read(&path).await.unwrap(); + assert_eq!(read_back, data); + // Cleanup + let _ = fs::remove_file(&path).await; + } + + #[tokio::test] + async fn filename_is_uuid_no_extension() { + let path = store_media(b"test").await.unwrap(); + let filename = Path::new(&path).file_name().unwrap().to_str().unwrap(); + // UUID v4 format: 8-4-4-4-12 hex chars + assert_eq!(filename.len(), 36); + assert!(!filename.contains('.')); + let _ = fs::remove_file(&path).await; + } +} From 00adf39d100f13c0201506656db721d786acb35f Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: Thu, 18 Jun 2026 20:17:50 +0000 Subject: [PATCH 02/20] fix: wire workspace into main.rs, remove duplicated source MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review findings from 口渡法師: - src/main.rs now uses openab_core:: imports instead of local mod declarations - Removed old src/ modules (now live in crates/openab-core/src/) - Removed old gateway/src/ (now in crates/openab-gateway/src/) - Feature-gated discord/slack in main.rs with #[cfg(feature = ...)] - No more code duplication between src/ and crates/ --- Cargo.toml | 1 + gateway/src/adapters/feishu.rs | 3928 ---------------------------- gateway/src/adapters/googlechat.rs | 2470 ----------------- gateway/src/adapters/line.rs | 780 ------ gateway/src/adapters/mod.rs | 6 - gateway/src/adapters/teams.rs | 877 ------- gateway/src/adapters/telegram.rs | 782 ------ 
gateway/src/adapters/wecom.rs | 1654 ------------ gateway/src/main.rs | 801 ------ gateway/src/media.rs | 123 - gateway/src/schema.rs | 126 - gateway/src/store.rs | 132 - src/acp/agentcore.rs | 722 ----- src/acp/connection.rs | 937 ------- src/acp/mod.rs | 9 - src/acp/pool.rs | 622 ----- src/acp/protocol.rs | 406 --- src/adapter.rs | 1659 ------------ src/bot_turns.rs | 368 --- src/config.rs | 1500 ----------- src/cron.rs | 1768 ------------- src/directives.rs | 314 --- src/discord.rs | 3203 ----------------------- src/dispatch.rs | 1727 ------------ src/error_display.rs | 323 --- src/format.rs | 327 --- src/gateway.rs | 1054 -------- src/hooks.rs | 425 --- src/main.rs | 95 +- src/markdown.rs | 349 --- src/media.rs | 846 ------ src/multibot_cache.rs | 85 - src/reactions.rs | 276 -- src/remind.rs | 399 --- src/secrets.rs | 479 ---- src/setup/config.rs | 157 -- src/setup/mod.rs | 12 - src/setup/validate.rs | 78 - src/setup/wizard.rs | 667 ----- src/slack.rs | 2329 ----------------- src/stt.rs | 354 --- src/timestamp.rs | 114 - 42 files changed, 46 insertions(+), 33238 deletions(-) delete mode 100644 gateway/src/adapters/feishu.rs delete mode 100644 gateway/src/adapters/googlechat.rs delete mode 100644 gateway/src/adapters/line.rs delete mode 100644 gateway/src/adapters/mod.rs delete mode 100644 gateway/src/adapters/teams.rs delete mode 100644 gateway/src/adapters/telegram.rs delete mode 100644 gateway/src/adapters/wecom.rs delete mode 100644 gateway/src/main.rs delete mode 100644 gateway/src/media.rs delete mode 100644 gateway/src/schema.rs delete mode 100644 gateway/src/store.rs delete mode 100644 src/acp/agentcore.rs delete mode 100644 src/acp/connection.rs delete mode 100644 src/acp/mod.rs delete mode 100644 src/acp/pool.rs delete mode 100644 src/acp/protocol.rs delete mode 100644 src/adapter.rs delete mode 100644 src/bot_turns.rs delete mode 100644 src/config.rs delete mode 100644 src/cron.rs delete mode 100644 src/directives.rs delete mode 100644 src/discord.rs 
delete mode 100644 src/dispatch.rs delete mode 100644 src/error_display.rs delete mode 100644 src/format.rs delete mode 100644 src/gateway.rs delete mode 100644 src/hooks.rs delete mode 100644 src/markdown.rs delete mode 100644 src/media.rs delete mode 100644 src/multibot_cache.rs delete mode 100644 src/reactions.rs delete mode 100644 src/remind.rs delete mode 100644 src/secrets.rs delete mode 100644 src/setup/config.rs delete mode 100644 src/setup/mod.rs delete mode 100644 src/setup/validate.rs delete mode 100644 src/setup/wizard.rs delete mode 100644 src/slack.rs delete mode 100644 src/stt.rs delete mode 100644 src/timestamp.rs diff --git a/Cargo.toml b/Cargo.toml index 88b2456d6..abf9fecc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } clap = { version = "4", features = ["derive"] } anyhow = "1" +serenity = { version = "0.12", default-features = false, features = ["client", "gateway", "model", "rustls_backend", "cache"] } [features] # Default: core only (Discord + Slack). Gateway ships as separate binary. diff --git a/gateway/src/adapters/feishu.rs b/gateway/src/adapters/feishu.rs deleted file mode 100644 index 84e5ac017..000000000 --- a/gateway/src/adapters/feishu.rs +++ /dev/null @@ -1,3928 +0,0 @@ -use crate::schema::*; -use axum::extract::State; -use prost::Message as ProstMessage; -use serde::Deserialize; -use std::collections::{HashMap, VecDeque}; -use std::sync::Arc; -use std::time::Instant; -use tokio::sync::RwLock; -use tracing::{info, warn}; - -/// Timing-safe string comparison to prevent side-channel attacks on tokens. 
-fn constant_time_eq(a: &str, b: &str) -> bool { - use subtle::ConstantTimeEq; - if a.len() != b.len() { - return false; - } - a.as_bytes().ct_eq(b.as_bytes()).into() -} - -// --------------------------------------------------------------------------- -// Feishu WebSocket protobuf frame (pbbp2.Frame) -// --------------------------------------------------------------------------- - -#[derive(Clone, PartialEq, ProstMessage)] -pub struct WsFrame { - #[prost(uint64, tag = "1")] - pub seq_id: u64, - #[prost(uint64, tag = "2")] - pub log_id: u64, - #[prost(int32, tag = "3")] - pub service: i32, - #[prost(int32, tag = "4")] - pub method: i32, - #[prost(message, repeated, tag = "5")] - pub headers: Vec, - #[prost(string, optional, tag = "6")] - pub payload_encoding: Option, - #[prost(string, optional, tag = "7")] - pub payload_type: Option, - #[prost(bytes = "vec", optional, tag = "8")] - pub payload: Option>, - #[prost(string, optional, tag = "9")] - pub log_id_new: Option, -} - -#[derive(Clone, PartialEq, ProstMessage)] -pub struct WsHeader { - #[prost(string, tag = "1")] - pub key: String, - #[prost(string, tag = "2")] - pub value: String, -} - -// --------------------------------------------------------------------------- -// Configuration -// --------------------------------------------------------------------------- - -#[derive(Debug, Clone, PartialEq)] -pub enum ConnectionMode { - Websocket, - Webhook, -} - -#[derive(Debug, Clone, PartialEq)] -pub enum AllowBots { - Off, - Mentions, - All, -} - -/// Controls when the bot responds without @mention in threads. -/// Mirrors Discord's `allow_user_messages` setting. -#[derive(Debug, Clone, PartialEq, Default)] -pub enum AllowUsers { - /// Bot responds in threads it has participated in without @mention. - Involved, - /// Always require @mention, even in participated threads. - Mentions, - /// Like Involved, but if another bot has also posted in the thread, - /// require @mention to avoid all bots responding. 
- #[default] - MultibotMentions, -} - -#[derive(Debug, Clone)] -pub struct FeishuConfig { - pub app_id: String, - pub app_secret: String, - pub domain: String, - pub connection_mode: ConnectionMode, - pub webhook_path: String, - pub verification_token: Option, - pub encrypt_key: Option, - pub allowed_groups: Vec, - pub allowed_users: Vec, - pub require_mention: bool, - pub allow_bots: AllowBots, - pub allow_user_messages: AllowUsers, - pub trusted_bot_ids: Vec, - pub max_bot_turns: u32, - pub dedupe_ttl_secs: u64, - pub message_limit: usize, - /// TTL for participated-thread cache entries (seconds). Threads older than - /// this are forgotten and require a fresh @mention to re-engage. - /// Set to 0 (via FEISHU_SESSION_TTL_HOURS=0) to disable participation - /// tracking entirely — all messages will require @mention. - /// Converted from `FEISHU_SESSION_TTL_HOURS` (user-facing, in hours) to seconds internally. - pub session_ttl_secs: u64, - /// Override the API base URL. Used in tests to point at a mock server. - /// Always None in production (not read from env). - pub api_base_override: Option, -} - -impl FeishuConfig { - /// Build config from environment variables. Returns None if FEISHU_APP_ID - /// is not set (adapter disabled). 
- pub fn from_env() -> Option { - let app_id = std::env::var("FEISHU_APP_ID").ok()?; - let app_secret = std::env::var("FEISHU_APP_SECRET").ok().unwrap_or_default(); - if app_secret.is_empty() { - warn!("FEISHU_APP_ID set but FEISHU_APP_SECRET is empty"); - return None; - } - let domain = std::env::var("FEISHU_DOMAIN").unwrap_or_else(|_| "feishu".into()); - let connection_mode = match std::env::var("FEISHU_CONNECTION_MODE") - .unwrap_or_else(|_| "websocket".into()) - .to_lowercase() - .as_str() - { - "webhook" => ConnectionMode::Webhook, - _ => ConnectionMode::Websocket, - }; - let webhook_path = std::env::var("FEISHU_WEBHOOK_PATH") - .unwrap_or_else(|_| "/webhook/feishu".into()); - let verification_token = std::env::var("FEISHU_VERIFICATION_TOKEN").ok(); - let encrypt_key = std::env::var("FEISHU_ENCRYPT_KEY").ok(); - let allowed_groups = parse_csv("FEISHU_ALLOWED_GROUPS"); - let allowed_users = parse_csv("FEISHU_ALLOWED_USERS"); - let require_mention = std::env::var("FEISHU_REQUIRE_MENTION") - .map(|v| v != "false" && v != "0") - .unwrap_or(true); - let allow_bots = match std::env::var("FEISHU_ALLOW_BOTS") - .unwrap_or_else(|_| "off".into()) - .to_lowercase() - .as_str() - { - "mentions" => AllowBots::Mentions, - "all" => AllowBots::All, - _ => AllowBots::Off, - }; - let trusted_bot_ids = parse_csv("FEISHU_TRUSTED_BOT_IDS"); - let allow_user_messages = match std::env::var("FEISHU_ALLOW_USER_MESSAGES") - .unwrap_or_else(|_| "multibot_mentions".into()) - .to_lowercase() - .replace('-', "_") - .as_str() - { - "involved" => AllowUsers::Involved, - "mentions" => AllowUsers::Mentions, - _ => AllowUsers::MultibotMentions, - }; - let max_bot_turns = std::env::var("FEISHU_MAX_BOT_TURNS") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(20); - let dedupe_ttl_secs = std::env::var("FEISHU_DEDUPE_TTL_SECS") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(300); - let message_limit = std::env::var("FEISHU_MESSAGE_LIMIT") - .ok() - .and_then(|v| v.parse().ok()) - 
.unwrap_or(4000); - let session_ttl_secs = std::env::var("FEISHU_SESSION_TTL_HOURS") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(24) - * 3600; - - Some(Self { - app_id, - app_secret, - domain, - connection_mode, - webhook_path, - verification_token, - encrypt_key, - allowed_groups, - allowed_users, - require_mention, - allow_bots, - allow_user_messages, - trusted_bot_ids, - max_bot_turns, - dedupe_ttl_secs, - message_limit, - session_ttl_secs, - api_base_override: None, - }) - } - - /// API base URL for the configured domain. - pub fn api_base(&self) -> String { - if let Some(ref base) = self.api_base_override { - return base.clone(); - } - if self.domain == "lark" { - "https://open.larksuite.com".into() - } else { - "https://open.feishu.cn".into() - } - } -} - -fn parse_csv(var: &str) -> Vec { - std::env::var(var) - .unwrap_or_default() - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect() -} - -// --------------------------------------------------------------------------- -// Feishu event types (im.message.receive_v1) -// --------------------------------------------------------------------------- - -mod event_types { - use super::*; - - #[derive(Debug, Deserialize)] - pub struct FeishuEventEnvelope { - pub header: Option, - pub event: Option, - pub challenge: Option, - // Parsed by serde, not consumed in current code paths. - #[allow(dead_code)] - #[serde(rename = "type")] - pub event_type_field: Option, - } - - #[derive(Debug, Deserialize)] - pub struct FeishuEventHeader { - pub event_id: Option, - // Parsed by serde, not consumed in current code paths. 
- #[allow(dead_code)] - pub event_type: Option, - } - - #[derive(Debug, Deserialize)] - pub struct FeishuEventBody { - pub sender: Option, - pub message: Option, - } - - #[derive(Debug, Deserialize)] - pub struct FeishuSender { - pub sender_id: Option, - pub sender_type: Option, - } - - #[derive(Debug, Deserialize)] - pub struct FeishuSenderId { - pub open_id: Option, - } - - #[derive(Debug, Deserialize)] - pub struct FeishuMessage { - pub message_id: Option, - pub chat_id: Option, - pub chat_type: Option, - pub message_type: Option, - pub content: Option, - pub mentions: Option>, - pub root_id: Option, - pub parent_id: Option, - } - - #[derive(Debug, Deserialize)] - pub struct FeishuMention { - pub key: Option, - pub id: Option, - // Parsed by serde, not consumed in current code paths. - #[allow(dead_code)] - pub name: Option, - } - - #[derive(Debug, Deserialize)] - pub struct FeishuMentionId { - pub open_id: Option, - } - - /// Parse a feishu im.message.receive_v1 event into a GatewayEvent. - /// Returns None if the event should be skipped (unsupported type, bot message, etc). - /// The Vec contains references to media that need async download. - /// - /// `bypass_mention_gating`: whether the bot should skip @mention requirement for this message. - /// This is the final computed result from mode-specific logic (detect_and_mark_multibot), - /// already accounting for the configured `allow_user_messages` mode. - /// Do NOT pass raw participation status here. 
- pub fn parse_message_event( - envelope: &FeishuEventEnvelope, - bot_open_id: Option<&str>, - config: &FeishuConfig, - bypass_mention_gating: bool, - ) -> Option<(GatewayEvent, Vec)> { - let _header = envelope.header.as_ref()?; - let event = envelope.event.as_ref()?; - let msg = event.message.as_ref()?; - let sender = event.sender.as_ref()?; - - let msg_type = msg.message_type.as_deref().unwrap_or("text"); - if !matches!(msg_type, "text" | "image" | "file" | "post" | "audio") { - return None; - } - // Skip bot messages with explicit sender_type - if matches!(sender.sender_type.as_deref(), Some("bot") | Some("app")) { - return None; - } - - let sender_open_id = sender.sender_id.as_ref()?.open_id.as_deref()?; - // Skip messages from self - if let Some(bot_id) = bot_open_id { - if sender_open_id == bot_id { - return None; - } - } - - // Check if sender is a known bot: - // Bot identification: - // 1. If trusted_bot_ids is configured, check against it - // 2. If trusted_bot_ids is empty, we cannot reliably identify bots - // (Feishu marks other bots as sender_type="user") - let is_bot_sender = if !config.trusted_bot_ids.is_empty() { - config.trusted_bot_ids.iter().any(|id| id == sender_open_id) - } else { - false - }; - - // User allowlist: if configured, only allow listed users. - // Trusted bots bypass user allowlist (same as Discord behavior). 
- if !is_bot_sender - && !config.allowed_users.is_empty() - && !config.allowed_users.iter().any(|u| u == sender_open_id) - { - return None; - } - - if is_bot_sender { - match config.allow_bots { - AllowBots::Off => return None, - AllowBots::Mentions | AllowBots::All => { - // Allowed — will check mentions below for Mentions mode - } - } - } - - let chat_id = msg.chat_id.as_deref()?; - // Group allowlist: if configured, only allow listed groups - let is_group = msg.chat_type.as_deref() != Some("p2p"); - if is_group - && !config.allowed_groups.is_empty() - && !config.allowed_groups.iter().any(|g| g == chat_id) - { - return None; - } - - let content_json: serde_json::Value = msg.content.as_deref() - .and_then(|s| serde_json::from_str(s).ok())?; - - let message_id = msg.message_id.as_deref()?; - - // Parse content based on message type - let (clean_text, mention_ids, media_refs) = match msg_type { - "image" => { - let image_key = content_json.get("image_key")?.as_str()?; - let mentions = extract_mentions( - "", msg.mentions.as_deref().unwrap_or(&[]), bot_open_id, - ); - let refs = vec![MediaRef::Image { - message_id: message_id.to_string(), - image_key: image_key.to_string(), - }]; - (String::new(), mentions.1, refs) - } - "file" => { - let file_key = content_json.get("file_key")?.as_str()?; - let file_name = content_json.get("file_name") - .and_then(|v| v.as_str()) - .unwrap_or("unknown"); - let mentions = extract_mentions( - "", msg.mentions.as_deref().unwrap_or(&[]), bot_open_id, - ); - let refs = vec![MediaRef::File { - message_id: message_id.to_string(), - file_key: file_key.to_string(), - file_name: file_name.to_string(), - }]; - (String::new(), mentions.1, refs) - } - "audio" => { - let file_key = content_json.get("file_key")?.as_str()?; - let mentions = extract_mentions( - "", msg.mentions.as_deref().unwrap_or(&[]), bot_open_id, - ); - let refs = vec![MediaRef::Audio { - message_id: message_id.to_string(), - file_key: file_key.to_string(), - }]; - 
(String::new(), mentions.1, refs) - } - "post" => { - // Rich text: content is {"title":"...","content":[[{tag,text,...},{tag,image_key,...}]]} - let mut texts = Vec::new(); - let mut refs = Vec::new(); - if let Some(rows) = content_json.get("content").and_then(|v| v.as_array()) { - for row in rows { - if let Some(elements) = row.as_array() { - for el in elements { - match el.get("tag").and_then(|v| v.as_str()) { - Some("text") => { - if let Some(t) = el.get("text").and_then(|v| v.as_str()) { - texts.push(t.to_string()); - } - } - Some("img") => { - if let Some(key) = el.get("image_key").and_then(|v| v.as_str()) { - refs.push(MediaRef::Image { - message_id: message_id.to_string(), - image_key: key.to_string(), - }); - } - } - Some("a") => { - if let Some(t) = el.get("text").and_then(|v| v.as_str()) { - texts.push(t.to_string()); - } - } - Some("at") => { - // Mentions handled via msg.mentions at envelope level - } - _ => {} - } - } - } - } - } - let raw_text = texts.join(""); - let (clean, ids) = extract_mentions( - &raw_text, - msg.mentions.as_deref().unwrap_or(&[]), - bot_open_id, - ); - (clean, ids, refs) - } - _ => { - // text - let raw_text = content_json.get("text").and_then(|v| v.as_str()).unwrap_or(""); - if raw_text.trim().is_empty() { - return None; - } - let (clean, ids) = extract_mentions( - raw_text, - msg.mentions.as_deref().unwrap_or(&[]), - bot_open_id, - ); - if clean.trim().is_empty() { - return None; - } - (clean, ids, Vec::new()) - } - }; - - let channel_type = match msg.chat_type.as_deref() { - Some("p2p") => "direct", - _ => "group", - }; - - let thread_id = msg.root_id.clone().or_else(|| msg.parent_id.clone()); - - // Gateway-side mention gating: in groups, skip if require_mention - // is true and bot is not mentioned (for human senders). - // Bypass: if bot has previously replied in this thread (participated), - // no @mention needed (like Discord's "involved" mode). 
- let in_thread = thread_id.is_some(); - if channel_type == "group" - && !is_bot_sender - && config.require_mention - && !(in_thread && bypass_mention_gating) - { - if let Some(bot_id) = bot_open_id { - let bot_mentioned = mention_ids.iter().any(|id| id == bot_id); - if !bot_mentioned { - return None; - } - } - } - - // Bot-to-bot mention gating: in AllowBots::Mentions mode, - // bot messages must @mention this bot (like Discord "mentions" mode). - // Note: in DMs there is no @mention mechanism, so bot DMs are - // silently dropped in Mentions mode. Use AllowBots::All for DM bots. - if is_bot_sender && config.allow_bots == AllowBots::Mentions { - if let Some(bot_id) = bot_open_id { - let bot_mentioned = mention_ids.iter().any(|id| id == bot_id); - if !bot_mentioned { - return None; - } - } - } - - let event = GatewayEvent::new( - "feishu", - ChannelInfo { - id: chat_id.to_string(), - channel_type: channel_type.to_string(), - thread_id, - }, - SenderInfo { - id: sender_open_id.to_string(), - name: sender_open_id.to_string(), - display_name: sender_open_id.to_string(), - is_bot: is_bot_sender, - }, - clean_text.trim(), - message_id, - mention_ids, - ); - Some((event, media_refs)) - } - - fn extract_mentions( - raw_text: &str, - mentions: &[FeishuMention], - bot_open_id: Option<&str>, - ) -> (String, Vec) { - let mut text = raw_text.to_string(); - let mut ids = Vec::new(); - for m in mentions { - let open_id = m.id.as_ref().and_then(|id| id.open_id.as_deref()); - if let Some(oid) = open_id { - ids.push(oid.to_string()); - if let Some(key) = m.key.as_deref() { - if bot_open_id == Some(oid) { - text = text.replacen(key, "", 1); - } - } - } - } - (text, ids) - } -} - -pub use event_types::*; - -// --------------------------------------------------------------------------- -// Deduplication -// --------------------------------------------------------------------------- - -pub struct DedupeCache { - seen: std::sync::Mutex>, - ttl_secs: u64, - max_size: usize, -} - -impl 
DedupeCache { - pub fn new(ttl_secs: u64) -> Self { - Self { - seen: std::sync::Mutex::new(HashMap::new()), - ttl_secs, - max_size: 10_000, - } - } - - /// Returns true if this id was already seen (duplicate). - pub fn is_duplicate(&self, id: &str) -> bool { - let mut map = self.seen.lock().unwrap_or_else(|e| e.into_inner()); - // Lazy sweep - if map.len() >= self.max_size { - map.retain(|_, ts| ts.elapsed().as_secs() < self.ttl_secs); - } - if let Some(ts) = map.get(id) { - if ts.elapsed().as_secs() < self.ttl_secs { - return true; - } - } - map.insert(id.to_string(), Instant::now()); - false - } -} - -// --------------------------------------------------------------------------- -// Token cache -// --------------------------------------------------------------------------- - -pub struct FeishuTokenCache { - /// (token, created_at, ttl_secs) - token: RwLock>, - api_base: String, - app_id: String, - app_secret: String, -} - -/// Refresh margin: renew 5 minutes before expiry. -const TOKEN_REFRESH_MARGIN_SECS: u64 = 300; - -impl FeishuTokenCache { - pub fn new(config: &FeishuConfig) -> Self { - Self { - token: RwLock::new(None), - api_base: config.api_base(), - app_id: config.app_id.clone(), - app_secret: config.app_secret.clone(), - } - } - - /// Construct with explicit api_base (for tests). - pub fn with_base(config: &FeishuConfig, api_base: &str) -> Self { - Self { - token: RwLock::new(None), - api_base: api_base.to_string(), - app_id: config.app_id.clone(), - app_secret: config.app_secret.clone(), - } - } - - /// Get a valid tenant_access_token, refreshing if expired or missing. 
- pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result { - // Fast path: read lock - { - let guard = self.token.read().await; - if let Some((ref tok, ref ts, ttl)) = *guard { - if ts.elapsed().as_secs() < ttl.saturating_sub(TOKEN_REFRESH_MARGIN_SECS) { - return Ok(tok.clone()); - } - } - } - // Slow path: write lock + refresh - let mut guard = self.token.write().await; - // Double-check after acquiring write lock - if let Some((ref tok, ref ts, ttl)) = *guard { - if ts.elapsed().as_secs() < ttl.saturating_sub(TOKEN_REFRESH_MARGIN_SECS) { - return Ok(tok.clone()); - } - } - let (new_token, expire) = self.refresh(client).await?; - *guard = Some((new_token.clone(), Instant::now(), expire)); - Ok(new_token) - } - - async fn refresh(&self, client: &reqwest::Client) -> anyhow::Result<(String, u64)> { - let url = format!( - "{}/open-apis/auth/v3/tenant_access_token/internal", - self.api_base - ); - let resp = client - .post(&url) - .json(&serde_json::json!({ - "app_id": self.app_id, - "app_secret": self.app_secret, - })) - .send() - .await - .map_err(|e| anyhow::anyhow!("feishu token refresh request failed: {e}"))?; - - let status = resp.status(); - let body: serde_json::Value = resp - .json() - .await - .map_err(|e| anyhow::anyhow!("feishu token refresh parse failed: {e}"))?; - - let code = body.get("code").and_then(|v| v.as_i64()).unwrap_or(-1); - if code != 0 { - let msg = body - .get("msg") - .and_then(|v| v.as_str()) - .unwrap_or("unknown"); - anyhow::bail!("feishu token refresh error: code={code} msg={msg} status={status}"); - } - - let expire = body.get("expire").and_then(|v| v.as_u64()).unwrap_or(7200); - - let token = body.get("tenant_access_token") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .ok_or_else(|| anyhow::anyhow!("feishu token refresh: missing tenant_access_token"))?; - - Ok((token, expire)) - } -} - -// --------------------------------------------------------------------------- -// Adapter (aggregated state) -// 
--------------------------------------------------------------------------- - -pub struct FeishuAdapter { - pub config: FeishuConfig, - pub token_cache: Arc, - pub bot_open_id: Arc>>, - pub dedupe: Arc, - pub rate_limiter: Arc, - pub name_cache: Arc>>, - /// Per-channel bot turn counter. Key = chat_id, Value = (count, last_reset). - /// Human message resets count to 0. Prevents runaway bot-to-bot loops. - pub bot_turns: Arc>>, // eviction: human msg resets; follow-up can add TTL like participated_threads - /// Positive-only cache: thread_id (root_id) → last_replied_at. - /// When bot has replied in a thread, subsequent messages in that thread - /// bypass @mention gating (like Discord's "involved" mode). - pub participated_threads: Arc>>, - /// Positive-only cache: thread_id → first_seen for threads where other bots - /// have posted. Used by multibot-mentions mode to require @mention. - pub multibot_threads: Arc>>, - /// Per-message edit count tracker for Feishu's 20-edits-per-message hard cap - /// (errcode 230072 — "The message has reached the number of times it can be edited"). - /// Insertion-order FIFO eviction: when over `EDIT_COUNTS_CACHE_MAX`, the - /// oldest *insertions* are dropped, not the lowest-count entries — so a - /// just-started active stream is far less likely to be evicted than under a - /// count-ascending policy. (A very long-lived stream can still age out once - /// 4096 newer messages have been inserted behind it; that resets its count - /// to 1, which is acceptable — it only loses the local preemptive margin and - /// the on-wire 230072 sentinel still backstops.) - pub edit_counts: Arc>, - pub client: reqwest::Client, -} - -/// Insertion-order edit-count cache for Feishu's per-message edit cap. -/// -/// `counts` holds the current edit count (or `u32::MAX` cap-reached sentinel) -/// for each message_id. 
`order` records insertion order so eviction is FIFO -/// rather than count-ascending; this matters because count-ascending would -/// preferentially target *active* streams (low count = just started) while -/// leaving stale cap-reached entries in place. FIFO instead ages out the -/// oldest insertions, which strongly favours keeping active streams. -#[derive(Default)] -pub struct EditCountsCache { - pub counts: HashMap, - pub order: VecDeque, -} - -impl FeishuAdapter { - pub fn new(config: FeishuConfig) -> Self { - let token_cache = Arc::new(FeishuTokenCache::new(&config)); - let dedupe = Arc::new(DedupeCache::new(config.dedupe_ttl_secs)); - let rate_limiter = Arc::new(RateLimiter::new(60, 120)); - Self { - config, - token_cache, - dedupe, - rate_limiter, - bot_open_id: Arc::new(RwLock::new(None)), - name_cache: Arc::new(std::sync::Mutex::new(HashMap::new())), - bot_turns: Arc::new(std::sync::Mutex::new(HashMap::new())), - participated_threads: Arc::new(std::sync::Mutex::new(HashMap::new())), - multibot_threads: Arc::new(std::sync::Mutex::new(HashMap::new())), - edit_counts: Arc::new(std::sync::Mutex::new(EditCountsCache::default())), - client: reqwest::Client::new(), - } - } - - /// Resolve bot identity (open_id) via API. Called during startup for both - /// WebSocket and webhook modes so mention gating works in either mode. 
- pub async fn resolve_bot_identity(&self) { - let token = match self.token_cache.get_token(&self.client).await { - Ok(t) => t, - Err(e) => { - warn!(err = %e, "feishu bot identity lookup failed (token error), mention gating may not work"); - return; - } - }; - match get_bot_info(&self.client, &self.config.api_base(), &token).await { - Ok(bot_id) => { - info!(bot_open_id = %bot_id, "feishu bot identity resolved"); - *self.bot_open_id.write().await = Some(bot_id); - } - Err(e) => { - warn!(err = %e, "feishu bot identity lookup failed, mention gating may not work"); - } - } - } -} - -// --------------------------------------------------------------------------- -// WebSocket long connection -// --------------------------------------------------------------------------- - -use futures_util::{SinkExt, StreamExt}; -use tokio::sync::{broadcast, watch}; - -/// Get WebSocket endpoint URL from feishu API. -/// Note: This API uses AppID+AppSecret directly, not Bearer token. -async fn get_ws_endpoint( - client: &reqwest::Client, - api_base: &str, - app_id: &str, - app_secret: &str, -) -> anyhow::Result { - let url = format!("{}/callback/ws/endpoint", api_base); - let resp = client - .post(&url) - .json(&serde_json::json!({ - "AppID": app_id, - "AppSecret": app_secret, - })) - .send() - .await?; - let body: serde_json::Value = resp.json().await?; - let code = body.get("code").and_then(|v| v.as_i64()).unwrap_or(-1); - if code != 0 { - let msg = body.get("msg").and_then(|v| v.as_str()).unwrap_or("unknown"); - anyhow::bail!("feishu ws endpoint error: code={code} msg={msg}"); - } - body.get("data") - .and_then(|d| d.get("URL")) - .and_then(|u| u.as_str()) - .map(|s| s.to_string()) - .ok_or_else(|| anyhow::anyhow!("feishu ws endpoint: missing URL")) -} - -/// Get bot identity (open_id) via bot info API. 
-async fn get_bot_info( - client: &reqwest::Client, - api_base: &str, - token: &str, -) -> anyhow::Result { - let url = format!("{}/open-apis/bot/v3/info", api_base); - let resp = client.get(&url).bearer_auth(token).send().await?; - let body: serde_json::Value = resp.json().await?; - body.get("bot") - .and_then(|b| b.get("open_id")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .ok_or_else(|| anyhow::anyhow!("feishu bot info: missing open_id")) -} - -/// Spawn the feishu WebSocket long-connection task. -/// Returns a JoinHandle that runs until shutdown_rx fires. -pub async fn start_websocket( - adapter: &FeishuAdapter, - event_tx: broadcast::Sender, - mut shutdown_rx: watch::Receiver, -) -> anyhow::Result> { - let token_cache = adapter.token_cache.clone(); - let bot_open_id_store = adapter.bot_open_id.clone(); - let dedupe = adapter.dedupe.clone(); - let config = adapter.config.clone(); - let client = adapter.client.clone(); - let name_cache = adapter.name_cache.clone(); - let bot_turns = adapter.bot_turns.clone(); - let participated_threads = adapter.participated_threads.clone(); - let multibot_threads = adapter.multibot_threads.clone(); - - let handle = tokio::spawn(async move { - let mut backoff_secs = 1u64; - loop { - let result = ws_connect_loop( - &token_cache, - &bot_open_id_store, - &dedupe, - &config, - &client, - &event_tx, - &mut shutdown_rx, - &name_cache, - &bot_turns, - &participated_threads, - &multibot_threads, - ) - .await; - - if *shutdown_rx.borrow() { - info!("feishu websocket shutting down"); - break; - } - - match result { - Ok(()) => { - info!("feishu websocket disconnected, reconnecting..."); - backoff_secs = 1; - } - Err(e) => { - tracing::error!(err = %e, backoff = backoff_secs, "feishu websocket error, reconnecting..."); - backoff_secs = (backoff_secs * 2).min(120); - } - } - - tokio::select! 
{ - _ = tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)) => {} - _ = shutdown_rx.changed() => { break; } - } - } - }); - - Ok(handle) -} - -/// Single WebSocket connection lifecycle. -#[allow(clippy::too_many_arguments)] -async fn ws_connect_loop( - token_cache: &Arc, - bot_open_id_store: &Arc>>, - dedupe: &Arc, - config: &FeishuConfig, - client: &reqwest::Client, - event_tx: &broadcast::Sender, - shutdown_rx: &mut watch::Receiver, - name_cache: &Arc>>, - bot_turns: &Arc>>, - participated_threads: &Arc>>, - multibot_threads: &Arc>>, -) -> anyhow::Result<()> { - let api_base = config.api_base(); - - // Refresh bot identity on each reconnect in case it was not resolved earlier - if bot_open_id_store.read().await.is_none() { - if let Ok(token) = token_cache.get_token(client).await { - if let Ok(bot_id) = get_bot_info(client, &api_base, &token).await { - info!(bot_open_id = %bot_id, "feishu bot identity resolved on reconnect"); - *bot_open_id_store.write().await = Some(bot_id); - } - } - } - - let ws_url = get_ws_endpoint(client, &api_base, &config.app_id, &config.app_secret).await?; - info!(url = %ws_url, "feishu websocket connecting"); - - let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?; - let (mut ws_tx, mut ws_rx) = ws_stream.split(); - info!("feishu websocket connected"); - - loop { - tokio::select! 
{ - msg = ws_rx.next() => { - match msg { - Some(Ok(tokio_tungstenite::tungstenite::Message::Text(text))) => { - handle_ws_message( - &text, bot_open_id_store, dedupe, config, event_tx, - name_cache, token_cache, client, bot_turns, participated_threads, multibot_threads, - ).await; - } - Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(data))) => { - let _ = ws_tx.send(tokio_tungstenite::tungstenite::Message::Pong(data)).await; - } - Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => { - return Ok(()); - } - Some(Err(e)) => { - return Err(anyhow::anyhow!("websocket error: {e}")); - } - Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => { - match WsFrame::decode(data.as_ref()) { - Ok(frame) => { - // method=1 is data frame (events), method=0 is control - if frame.method == 1 { - if let Some(ref payload) = frame.payload { - if let Ok(text) = String::from_utf8(payload.clone()) { - handle_ws_message( - &text, bot_open_id_store, dedupe, config, event_tx, - name_cache, token_cache, client, bot_turns, participated_threads, multibot_threads, - ).await; - } - } - // Send ACK: echo frame back with {"code":200} payload - let mut ack = frame.clone(); - ack.payload = Some(b"{\"code\":200}".to_vec()); - let ack_bytes = ack.encode_to_vec(); - let _ = ws_tx.send( - tokio_tungstenite::tungstenite::Message::Binary(ack_bytes) - ).await; - } - } - Err(e) => { - tracing::debug!(err = %e, len = data.len(), "feishu ws protobuf decode failed"); - } - } - } - _ => {} - } - } - _ = shutdown_rx.changed() => { - let _ = ws_tx.send(tokio_tungstenite::tungstenite::Message::Close(None)).await; - return Ok(()); - } - } - } -} - -/// Process a single WebSocket text message. 
-#[allow(clippy::too_many_arguments)] -async fn handle_ws_message( - text: &str, - bot_open_id_store: &Arc>>, - dedupe: &Arc, - config: &FeishuConfig, - event_tx: &broadcast::Sender, - name_cache: &Arc>>, - token_cache: &Arc, - client: &reqwest::Client, - bot_turns: &Arc>>, - participated_threads: &Arc>>, - multibot_threads: &Arc>>, -) { - let envelope: FeishuEventEnvelope = match serde_json::from_str(text) { - Ok(e) => e, - Err(_) => return, - }; - - // Handle challenge frame (Feishu may send this in WS mode for verification) - if let Some(ref challenge) = envelope.challenge { - tracing::debug!(challenge = %challenge, "feishu ws challenge received (ignored in WS mode)"); - return; - } - - // Debug: log sender_type for diagnosing bot-to-bot loops - if let Some(ref event) = envelope.event { - if let Some(ref sender) = event.sender { - tracing::debug!( - sender_type = ?sender.sender_type, - sender_id = ?sender.sender_id.as_ref().and_then(|s| s.open_id.as_deref()), - "feishu ws event sender" - ); - } - } - - // Dedupe by event_id - if let Some(ref header) = envelope.header { - if let Some(ref event_id) = header.event_id { - if dedupe.is_duplicate(event_id) { - return; - } - } - } - - let bot_id = bot_open_id_store.read().await; - let bot_id_ref = bot_id.as_deref(); - - // Check if the message is in a thread where bot has previously replied, - // respecting the allow_user_messages mode: - // - Involved (default): bypass @mention if participated - // - MultibotMentions: bypass only if participated AND no other bot in thread - // - Mentions: never bypass - let bypass_mention = detect_and_mark_multibot( - &envelope, bot_id_ref, config, participated_threads, multibot_threads, - ); - - if let Some((mut gateway_event, media_refs)) = parse_message_event(&envelope, bot_id_ref, config, bypass_mention) { - // Also dedupe by message_id - if dedupe.is_duplicate(&gateway_event.message_id) { - return; - } - - // Bot turn tracking: prevent runaway bot-to-bot loops - let channel_id = 
&gateway_event.channel.id; - { - let mut turns = bot_turns.lock().unwrap_or_else(|e| e.into_inner()); - if gateway_event.sender.is_bot { - let count = turns.entry(channel_id.to_string()).or_insert(0); - *count += 1; - if *count > config.max_bot_turns { - warn!( - channel = %channel_id, - count = *count, - max = config.max_bot_turns, - "feishu: bot turn limit reached, dropping message" - ); - return; - } - // (Feishu doesn't push bot messages to other bots' WebSocket, - // so multibot detection is done via mentions instead — see below.) - } else { - // Human message resets bot turn counter - turns.remove(channel_id.as_str()); - } - } - - // Resolve sender display name (lazy, cached) - let name = resolve_user_name( - &gateway_event.sender.id, name_cache, token_cache, client, &config.api_base(), - ).await; - gateway_event.sender.name = name.clone(); - gateway_event.sender.display_name = name; - - // Download media attachments (images, text files) - if !media_refs.is_empty() { - if let Ok(token) = token_cache.get_token(client).await { - let api_base = config.api_base(); - for media_ref in &media_refs { - let attachment = match media_ref { - MediaRef::Image { message_id, image_key } => { - download_feishu_image(client, &api_base, &token, message_id, image_key).await - } - MediaRef::File { message_id, file_key, file_name } => { - download_feishu_file(client, &api_base, &token, message_id, file_key, file_name).await - } - MediaRef::Audio { message_id, file_key } => { - download_feishu_audio(client, &api_base, &token, message_id, file_key).await - } - }; - if let Some(att) = attachment { - gateway_event.content.attachments.push(att); - } - } - } - } - - // Skip if no text and no attachments (e.g. 
unsupported file type) - if gateway_event.content.text.trim().is_empty() && gateway_event.content.attachments.is_empty() { - return; - } - - let json = serde_json::to_string(&gateway_event).unwrap(); - info!( - channel = %gateway_event.channel.id, - thread_id = ?gateway_event.channel.thread_id, - sender = %gateway_event.sender.id, - "feishu → gateway" - ); - let _ = event_tx.send(json); - } -} - -/// Resolve user display name from open_id via Contact API, with caching. -async fn resolve_user_name( - open_id: &str, - name_cache: &Arc>>, - token_cache: &Arc, - client: &reqwest::Client, - api_base: &str, -) -> String { - { - let cache = name_cache.lock().unwrap_or_else(|e| e.into_inner()); - if let Some(name) = cache.get(open_id) { - return name.clone(); - } - } - let token = match token_cache.get_token(client).await { - Ok(t) => t, - Err(_) => return open_id.to_string(), - }; - let url = format!( - "{}/open-apis/contact/v3/users/{}?user_id_type=open_id", - api_base, open_id - ); - let resolved = match client.get(&url).bearer_auth(&token).send().await { - Ok(resp) => { - let body: serde_json::Value = resp.json().await.unwrap_or_default(); - body.pointer("/data/user/name") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - } - Err(_) => None, - }; - // Only cache successful resolutions — don't cache fallback open_id - // so retries can succeed after permissions are granted. - if let Some(ref name) = resolved { - let mut cache = name_cache.lock().unwrap_or_else(|e| e.into_inner()); - if cache.len() < 10_000 { - cache.insert(open_id.to_string(), name.clone()); - } - } - resolved.unwrap_or_else(|| open_id.to_string()) -} - -// --------------------------------------------------------------------------- -// Send message -/// Edit (update) an existing feishu message in-place for streaming. -/// Feishu message edit cap: API returns errcode 230072 after 20 edits per message. 
-/// We stop preemptively at 18 to leave a 2-edit safety margin (handles races where -/// multiple in-flight edits could each push count to the wall) and also catch 230072 -/// defensively in case the local count drifts from server reality. -const FEISHU_EDIT_CAP: u32 = 18; - -/// Maximum entries in the per-adapter edit_counts cache before lazy eviction kicks in. -const EDIT_COUNTS_CACHE_MAX: usize = 4096; - -/// Validates that a Feishu message_id matches the expected `om_` shape -/// before it is interpolated into a REST URL path. Feishu's documented -/// message_id format is the `om_` prefix followed by base62-style characters -/// (`[A-Za-z0-9_]`). Rejecting anything else stops crafted IDs containing `/`, -/// `?`, or `#` from altering URL semantics — defence in depth, since the trust -/// boundary is the core↔gateway WebSocket and not external input. -fn is_valid_feishu_message_id(id: &str) -> bool { - let bytes = id.as_bytes(); - if !id.starts_with("om_") || id.len() < 4 || id.len() > 128 { - return false; - } - bytes - .iter() - .all(|b| b.is_ascii_alphanumeric() || *b == b'_') -} - -/// Detect whether a Feishu API response body indicates the per-message edit -/// cap (errcode 230072). Trusts JSON `code` field when the body parses as -/// JSON; falls back to substring match only on non-JSON bodies (proxy HTML, -/// truncated responses, …) so a JSON body with an unrelated `code` cannot be -/// false-positively flagged just because some inner string contains "230072". -fn is_feishu_cap_reached_body(body: &str) -> bool { - match serde_json::from_str::(body) { - Ok(v) => v - .get("code") - .and_then(|c| c.as_i64()) - .is_some_and(|code| code == 230072), - Err(_) => { - body.contains("230072") - || body.contains("number of times it can be edited") - } - } -} - -/// Outcome of an edit_feishu_message attempt. Distinguishes the cap-reached case -/// from generic failure so the caller can stop attempting edits and let the -/// core finalize path handle recovery. 
-pub enum EditOutcome { - /// Edit succeeded; the on-screen message now reflects the new content. - Edited, - /// The 20-edits-per-message cap is exhausted (either tracked locally or - /// signaled by errcode 230072). Caller should stop attempting edits; - /// recovery (delete placeholder + send fresh) is handled at the core - /// finalize layer in `src/adapter.rs`, not here — appending new messages - /// per cosmetic flush would spam the user with continuation messages. - CapReached, - /// Generic failure (network, token, other API errors). - Failed(String), -} - -/// Increment the edit count for a message_id. New keys are appended to the -/// FIFO order queue; existing keys keep their position. When the cache is -/// over `EDIT_COUNTS_CACHE_MAX`, the oldest *insertions* are evicted (not the -/// lowest-count entries) so active streams are not bumped out from under -/// themselves. -fn increment_edit_count( - cache: &Arc>, - message_id: &str, -) { - let mut c = cache.lock().unwrap_or_else(|e| e.into_inner()); - let was_new = !c.counts.contains_key(message_id); - let entry = c.counts.entry(message_id.to_string()).or_insert(0); - if *entry != u32::MAX { - *entry = entry.saturating_add(1); - } - if was_new { - c.order.push_back(message_id.to_string()); - evict_if_overcap(&mut c); - } -} - -/// Mark a message_id as cap-reached; subsequent edit attempts skip the API -/// call and signal `EditOutcome::CapReached` directly so the core finalize -/// path can take over. -fn mark_edit_cap( - cache: &Arc>, - message_id: &str, -) { - let mut c = cache.lock().unwrap_or_else(|e| e.into_inner()); - let was_new = !c.counts.contains_key(message_id); - c.counts.insert(message_id.to_string(), u32::MAX); - if was_new { - c.order.push_back(message_id.to_string()); - evict_if_overcap(&mut c); - } -} - -/// FIFO eviction helper: when over `EDIT_COUNTS_CACHE_MAX`, drop the oldest -/// half by insertion order. 
Tolerant of `order`/`counts` drift — entries that -/// only exist in `order` are silently skipped. -fn evict_if_overcap(c: &mut EditCountsCache) { - if c.counts.len() > EDIT_COUNTS_CACHE_MAX { - let evict = c.counts.len() / 2; - for _ in 0..evict { - if let Some(oldest) = c.order.pop_front() { - c.counts.remove(&oldest); - } else { - break; - } - } - } -} - -/// Return true if this message_id has already reached the edit cap (either -/// tracked locally or marked via 230072 sentinel). -fn is_edit_cap_reached( - cache: &Arc>, - message_id: &str, -) -> bool { - let c = cache.lock().unwrap_or_else(|e| e.into_inner()); - c.counts - .get(message_id) - .is_some_and(|&n| n >= FEISHU_EDIT_CAP) -} - -/// Edit (update) an existing Feishu message in-place for streaming. -/// -/// Returns [`EditOutcome`] so the caller can distinguish success, cap-reached, -/// and generic failure. Performs a preemptive local cap check (`FEISHU_EDIT_CAP`) -/// before hitting the network, and detects the server-side errcode 230072 via -/// body-code-first parsing if the local count drifts from reality. -async fn edit_feishu_message( - adapter: &FeishuAdapter, - message_id: &str, - text: &str, -) -> EditOutcome { - // Pre-check: if we've already tracked >= FEISHU_EDIT_CAP edits (or the sentinel - // u32::MAX from a 230072 response), skip the API call and signal CapReached so - // the caller can stop attempting edits and let the core finalize path recover. 
- if is_edit_cap_reached(&adapter.edit_counts, message_id) { - return EditOutcome::CapReached; - } - - let token = match adapter.token_cache.get_token(&adapter.client).await { - Ok(t) => t, - Err(e) => { - tracing::error!(err = %e, "feishu: cannot get token for edit"); - return EditOutcome::Failed(format!("token error: {e}")); - } - }; - let api_base = adapter.config.api_base(); - let url = format!("{}/open-apis/im/v1/messages/{}", api_base, message_id); - let post_content = markdown_to_post(text); - let body = serde_json::json!({ - "msg_type": "post", - "content": post_content.to_string(), - }); - match adapter.client.put(&url).bearer_auth(&token) - .header("Content-Type", "application/json; charset=utf-8") - .json(&body).send().await - { - Ok(resp) => { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - // Feishu OpenAPI convention: the business result lives in the body - // `code` field, and an edit-cap rejection (errcode 230072) can arrive - // with HTTP 200. So we decide on the body — consistent with token - // refresh and the WS endpoint elsewhere in this file — rather than - // trusting HTTP status alone, which would miscount a 200 + non-zero - // `code` response as a successful edit and never reach cap detection. - // - // This relies on Feishu returning `code` as a JSON integer (which it - // always does). A non-integer or absent code falls through to the - // HTTP-status fallback below, so a malformed 2xx body is treated as - // success — acceptable, since Feishu never emits such a body. - // - // 1. Cap reached? `is_feishu_cap_reached_body` is the sole authority - // (JSON code == 230072, or substring fallback for non-JSON bodies). - if is_feishu_cap_reached_body(&body) { - mark_edit_cap(&adapter.edit_counts, message_id); - tracing::warn!( - message_id = %message_id, - status = %status, - "feishu edit cap reached (errcode 230072); signaling core for cap-reached recovery" - ); - return EditOutcome::CapReached; - } - // 2. 
Otherwise classify by body `code` (0 = success), falling back to - // HTTP status only for non-JSON bodies (proxy HTML, truncated). - match serde_json::from_str::(&body) - .ok() - .and_then(|v| v.get("code").and_then(|c| c.as_i64())) - { - Some(0) => { - increment_edit_count(&adapter.edit_counts, message_id); - tracing::trace!(message_id = %message_id, "feishu message edited"); - EditOutcome::Edited - } - Some(code) => { - tracing::error!( - message_id = %message_id, - status = %status, - code, - body = %body, - "feishu edit message failed" - ); - EditOutcome::Failed(format!("code {code}: {body}")) - } - None => { - // Body wasn't JSON-with-code; trust HTTP status as last resort. - if status.is_success() { - increment_edit_count(&adapter.edit_counts, message_id); - tracing::trace!(message_id = %message_id, "feishu message edited (non-JSON 2xx body)"); - EditOutcome::Edited - } else { - tracing::error!( - message_id = %message_id, - status = %status, - body = %body, - "feishu edit message failed" - ); - EditOutcome::Failed(format!("HTTP {status}: {body}")) - } - } - } - } - Err(e) => { - tracing::error!(message_id = %message_id, err = %e, "feishu edit message request failed"); - EditOutcome::Failed(format!("request error: {e}")) - } - } -} - -/// Delete a Feishu message via DELETE /open-apis/im/v1/messages/{id}. -/// Unlike PATCH (edit), DELETE is not subject to the 20-edits-per-message cap, -/// so this works even on messages that have already exhausted their edit quota. -/// Used by the streaming finalize path to remove the half-edited placeholder -/// before sending the full content as fresh messages, avoiding visual overlap. -/// -/// `message_id` shape is validated by the caller (`handle_reply` dispatch seam, -/// via `is_valid_feishu_message_id`) before this is reached, so it is safe to -/// interpolate into the URL path here. 
-async fn delete_feishu_message( - adapter: &FeishuAdapter, - message_id: &str, -) -> Result<(), String> { - let token = adapter - .token_cache - .get_token(&adapter.client) - .await - .map_err(|e| format!("token error: {e}"))?; - let api_base = adapter.config.api_base(); - let url = format!("{}/open-apis/im/v1/messages/{}", api_base, message_id); - match adapter - .client - .delete(&url) - .bearer_auth(&token) - .send() - .await - { - Ok(resp) if resp.status().is_success() => { - tracing::info!(message_id = %message_id, "feishu message deleted"); - Ok(()) - } - Ok(resp) => { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - tracing::warn!(status = %status, body = %body, message_id = %message_id, "feishu delete message failed"); - Err(format!("HTTP {status}: {body}")) - } - Err(e) => { - tracing::warn!(err = %e, message_id = %message_id, "feishu delete message request failed"); - Err(format!("request error: {e}")) - } - } -} - -// --------------------------------------------------------------------------- -// Markdown → Feishu post conversion -// --------------------------------------------------------------------------- - -/// Convert markdown text to feishu post content JSON. -/// Supported: code blocks → code_block tag, links → a tag, @mentions preserved. -/// Unsupported inline formatting (bold, italic, etc.) is stripped to plain text. -fn markdown_to_post(md: &str) -> serde_json::Value { - let mut lines: Vec> = Vec::new(); - - // We work byte-offset based for code fence detection, line-based otherwise. 
- let raw_lines: Vec<&str> = md.split('\n').collect(); - let mut li = 0; - while li < raw_lines.len() { - let line = raw_lines[li]; - // Detect fenced code block - let trimmed = line.trim_start(); - if let Some(after_fence) = trimmed.strip_prefix("```") { - let lang = after_fence.trim().to_string(); - let mut code = String::new(); - li += 1; - while li < raw_lines.len() { - if raw_lines[li].trim_start().starts_with("```") { - break; - } - if !code.is_empty() { - code.push('\n'); - } - code.push_str(raw_lines[li]); - li += 1; - } - li += 1; // skip closing ``` - let mut block = serde_json::json!({"tag": "code_block", "text": code}); - if !lang.is_empty() { - block["language"] = serde_json::Value::String(lang); - } - lines.push(vec![block]); - continue; - } - // Normal line: parse inline elements - let elems = parse_inline(line); - lines.push(elems); - li += 1; - } - - serde_json::json!({ - "zh_cn": { - "content": lines - } - }) -} - -/// Parse inline markdown elements in a single line. -/// Extracts links [text](url) → a tag, strips bold/italic/strikethrough markers. 
-fn parse_inline(line: &str) -> Vec { - let mut elems = Vec::new(); - let mut buf = String::new(); - let chars: Vec = line.chars().collect(); - let len = chars.len(); - let mut i = 0; - - while i < len { - // Link: [text](url) - if chars[i] == '[' { - if let Some((text, url, end)) = try_parse_link(&chars, i) { - if !buf.is_empty() { - elems.push(serde_json::json!({"tag": "text", "text": buf})); - buf.clear(); - } - elems.push(serde_json::json!({"tag": "a", "text": text, "href": url})); - i = end; - continue; - } - } - // Inline code: find matching closing backtick(s), preserve content literally - if chars[i] == '`' { - let mut ticks = 0; - while i + ticks < len && chars[i + ticks] == '`' { - ticks += 1; - } - i += ticks; - // Find matching closing backtick sequence of same length - let mut end = i; - 'outer: while end < len { - if chars[end] == '`' { - let mut close_ticks = 0; - while end + close_ticks < len && chars[end + close_ticks] == '`' { - close_ticks += 1; - } - if close_ticks == ticks { - // Found matching close — content between is literal - buf.extend(chars[i..end].iter().copied()); - i = end + close_ticks; - break 'outer; - } - end += close_ticks; - } else { - end += 1; - } - } - if end >= len { - // No matching close — treat backticks as literal - buf.extend(chars[i..len].iter().copied()); - i = len; - } - continue; - } - // Strip paired markdown markers: **bold**, *italic*, ~~strike~~ - // Unpaired markers are kept as literal text (e.g. 
~/.ssh, *.rs, 3 * 4) - if chars[i] == '*' || chars[i] == '~' { - let ch = chars[i]; - let mut run = 0; - while i + run < len && chars[i + run] == ch { - run += 1; - } - // Look ahead for a matching closing run of same length - let after = i + run; - let mut scan = after; - let mut found_close = false; - while scan < len { - if chars[scan] == ch { - let mut close_run = 0; - while scan + close_run < len && chars[scan + close_run] == ch { - close_run += 1; - } - if close_run == run { - // Found matching close — strip both, keep inner text - buf.extend(chars[after..scan].iter().copied()); - i = scan + close_run; - found_close = true; - break; - } - scan += close_run; - } else { - scan += 1; - } - } - if !found_close { - // No matching close — keep markers as literal - for _ in 0..run { - buf.push(ch); - } - i += run; - } - continue; - } - buf.push(chars[i]); - i += 1; - } - if !buf.is_empty() { - elems.push(serde_json::json!({"tag": "text", "text": buf})); - } - if elems.is_empty() { - elems.push(serde_json::json!({"tag": "text", "text": ""})); - } - elems -} - -/// Try to parse a markdown link starting at position `start` (which is '['). -/// Returns (text, url, next_index) on success. 
-fn try_parse_link(chars: &[char], start: usize) -> Option<(String, String, usize)> { - let len = chars.len(); - // Find closing ] - let mut i = start + 1; - let mut text = String::new(); - while i < len && chars[i] != ']' { - text.push(chars[i]); - i += 1; - } - if i >= len { - return None; - } - i += 1; // skip ] - if i >= len || chars[i] != '(' { - return None; - } - i += 1; // skip ( - let mut url = String::new(); - while i < len && chars[i] != ')' { - url.push(chars[i]); - i += 1; - } - if i >= len { - return None; - } - i += 1; // skip ) - Some((text, url, i)) -} - -// --------------------------------------------------------------------------- -// Media helpers -// --------------------------------------------------------------------------- - -/// Reference to a media resource that needs async download after parse_message_event. -pub enum MediaRef { - Image { message_id: String, image_key: String }, - File { message_id: String, file_key: String, file_name: String }, - Audio { message_id: String, file_key: String }, -} - -const IMAGE_MAX_DIMENSION_PX: u32 = 1200; -const IMAGE_JPEG_QUALITY: u8 = 75; -const IMAGE_MAX_DOWNLOAD: u64 = 10 * 1024 * 1024; // 10 MB -const FILE_MAX_DOWNLOAD: u64 = 512 * 1024; // 512 KB - -/// Resize image so longest side <= 1200px, then encode as JPEG. -/// GIFs are passed through unchanged to preserve animation. 
-fn resize_and_compress(raw: &[u8]) -> Result<(Vec, String), image::ImageError> { - use image::ImageReader; - use std::io::Cursor; - - let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?; - let format = reader.format(); - if format == Some(image::ImageFormat::Gif) { - return Ok((raw.to_vec(), "image/gif".to_string())); - } - let img = reader.decode()?; - let (w, h) = (img.width(), img.height()); - let img = if w > IMAGE_MAX_DIMENSION_PX || h > IMAGE_MAX_DIMENSION_PX { - let max_side = std::cmp::max(w, h); - let ratio = f64::from(IMAGE_MAX_DIMENSION_PX) / f64::from(max_side); - let new_w = (f64::from(w) * ratio) as u32; - let new_h = (f64::from(h) * ratio) as u32; - img.resize(new_w, new_h, image::imageops::FilterType::Lanczos3) - } else { - img - }; - let mut buf = Cursor::new(Vec::new()); - let encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, IMAGE_JPEG_QUALITY); - img.write_with_encoder(encoder)?; - Ok((buf.into_inner(), "image/jpeg".to_string())) -} - -/// Download a Feishu image by message_id + image_key → resize/compress → base64 Attachment. 
-pub async fn download_feishu_image( - client: &reqwest::Client, - api_base: &str, - token: &str, - message_id: &str, - image_key: &str, -) -> Option { - let url = format!( - "{}/open-apis/im/v1/messages/{}/resources/{}?type=image", - api_base, message_id, image_key - ); - let resp = match client.get(&url).bearer_auth(token).send().await { - Ok(r) => r, - Err(e) => { - tracing::warn!(image_key, error = %e, "feishu image download failed"); - return None; - } - }; - if !resp.status().is_success() { - tracing::warn!(image_key, status = %resp.status(), "feishu image download failed"); - return None; - } - // Early gate: reject oversized downloads before buffering the full body - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > IMAGE_MAX_DOWNLOAD { - tracing::warn!(image_key, size, "feishu image Content-Length exceeds 10MB limit, skipping download"); - return None; - } - } - } - let bytes = resp.bytes().await.ok()?; - // Fallback check (Content-Length may be absent or misreported) - if bytes.len() as u64 > IMAGE_MAX_DOWNLOAD { - tracing::warn!(image_key, size = bytes.len(), "feishu image exceeds 10MB limit"); - return None; - } - let (compressed, mime) = match resize_and_compress(&bytes) { - Ok(v) => v, - Err(e) => { - tracing::warn!(image_key, error = %e, "feishu image resize failed"); - return None; - } - }; - let path = crate::store::store_media(&compressed).await?; - let ext = if mime == "image/gif" { "gif" } else { "jpg" }; - Some(crate::schema::Attachment { - attachment_type: "image".into(), - filename: format!("{}.{}", image_key, ext), - mime_type: mime, - data: String::new(), - size: compressed.len() as u64, - path: Some(path), - }) -} - -/// Download a Feishu file by message_id + file_key → base64 Attachment (text files only). 
-pub async fn download_feishu_file( - client: &reqwest::Client, - api_base: &str, - token: &str, - message_id: &str, - file_key: &str, - file_name: &str, -) -> Option { - // Only download text-like files - let ext = file_name.rsplit('.').next().unwrap_or("").to_lowercase(); - const TEXT_EXTS: &[&str] = &[ - "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", - "rs", "py", "js", "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", - "rb", "sh", "bash", "sql", "html", "css", "ini", "cfg", "conf", "env", - ]; - if !TEXT_EXTS.contains(&ext.as_str()) { - tracing::debug!(file_name, "skipping non-text file attachment"); - return None; - } - let url = format!( - "{}/open-apis/im/v1/messages/{}/resources/{}?type=file", - api_base, message_id, file_key - ); - let resp = match client.get(&url).bearer_auth(token).send().await { - Ok(r) => r, - Err(e) => { - tracing::warn!(file_name, error = %e, "feishu file download failed"); - return None; - } - }; - if !resp.status().is_success() { - tracing::warn!(file_name, status = %resp.status(), "feishu file download failed"); - return None; - } - // Early gate: reject oversized downloads before buffering the full body - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > FILE_MAX_DOWNLOAD { - tracing::warn!(file_name, size, "feishu file Content-Length exceeds 512KB limit, skipping download"); - return None; - } - } - } - let bytes = resp.bytes().await.ok()?; - // Fallback check (Content-Length may be absent or misreported) - if bytes.len() as u64 > FILE_MAX_DOWNLOAD { - tracing::warn!(file_name, size = bytes.len(), "feishu file exceeds 512KB limit"); - return None; - } - let path = crate::store::store_media(&bytes).await?; - Some(crate::schema::Attachment { - attachment_type: "text_file".into(), - filename: file_name.to_string(), - mime_type: "text/plain".into(), - data: String::new(), - size: bytes.len() as u64, - 
path: Some(path), - }) -} - -const AUDIO_MAX_DOWNLOAD: u64 = 25 * 1024 * 1024; // 25 MB (Whisper API limit) - -/// Download a Feishu audio message by message_id + file_key → base64 Attachment. -pub async fn download_feishu_audio( - client: &reqwest::Client, - api_base: &str, - token: &str, - message_id: &str, - file_key: &str, -) -> Option { - use urlencoding::encode; - let url = format!( - "{}/open-apis/im/v1/messages/{}/resources/{}?type=file", - api_base, encode(message_id), encode(file_key) - ); - let resp = match client.get(&url).bearer_auth(token).send().await { - Ok(r) => r, - Err(e) => { - tracing::warn!(file_key, error = %e, "feishu audio download failed"); - return None; - } - }; - if !resp.status().is_success() { - tracing::warn!(file_key, status = %resp.status(), "feishu audio download failed"); - return None; - } - let content_type = resp - .headers() - .get(reqwest::header::CONTENT_TYPE) - .and_then(|v| v.to_str().ok()) - .unwrap_or("audio/ogg") - .to_string(); - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > AUDIO_MAX_DOWNLOAD { - tracing::warn!(file_key, size, "feishu audio exceeds 25MB limit"); - return None; - } - } - } - let bytes = resp.bytes().await.ok()?; - if bytes.len() as u64 > AUDIO_MAX_DOWNLOAD { - tracing::warn!(file_key, size = bytes.len(), "feishu audio exceeds 25MB limit"); - return None; - } - tracing::debug!(file_key, size = bytes.len(), "feishu audio downloaded"); - let path = crate::store::store_media(&bytes).await?; - Some(crate::schema::Attachment { - attachment_type: "audio".into(), - filename: format!("{}.ogg", file_key), - mime_type: content_type, - data: String::new(), - size: bytes.len() as u64, - path: Some(path), - }) -} - -/// Send a post (rich text) message to a feishu chat_id. -/// Returns the sent message_id on success, None on failure. -/// When `reply_to` is Some(root_id), uses the reply API to stay in a thread. 
-/// When `reply_to` is None, sends a new message to the chat. -pub async fn send_post_message( - client: &reqwest::Client, - api_base: &str, - token: &str, - chat_id: &str, - reply_to: Option<&str>, - text: &str, -) -> Option { - let (url, body) = if let Some(root_id) = reply_to { - ( - format!("{}/open-apis/im/v1/messages/{}/reply", api_base, root_id), - serde_json::json!({ - "msg_type": "post", - "content": markdown_to_post(text).to_string(), - }), - ) - } else { - ( - format!("{}/open-apis/im/v1/messages?receive_id_type=chat_id", api_base), - serde_json::json!({ - "receive_id": chat_id, - "msg_type": "post", - "content": markdown_to_post(text).to_string(), - }), - ) - }; - - match client - .post(&url) - .bearer_auth(token) - .header("Content-Type", "application/json; charset=utf-8") - .json(&body) - .send() - .await - { - Ok(resp) => { - if resp.status().is_success() { - let resp_body: serde_json::Value = match resp.json().await { - Ok(v) => v, - Err(e) => { - tracing::warn!(err = %e, "feishu post: failed to parse response body"); - serde_json::Value::default() - } - }; - let msg_id = resp_body - .pointer("/data/message_id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - info!(chat_id = %chat_id, reply_to = ?reply_to, message_id = ?msg_id, "feishu post message sent"); - msg_id - } else { - let status = resp.status(); - let text = resp.text().await.unwrap_or_default(); - tracing::error!(status = %status, body = %text, "feishu send post message failed"); - None - } - } - Err(e) => { - tracing::error!(err = %e, "feishu send post message request failed"); - None - } - } -} - -// --------------------------------------------------------------------------- - -/// Send a text message to a feishu chat_id. -/// Returns the sent message_id on success (for self-echo dedupe), None on failure. -/// Kept for webhook fallback and tests; normal reply path uses send_post_message. 
-#[allow(dead_code)] -pub async fn send_text_message( - client: &reqwest::Client, - api_base: &str, - token: &str, - chat_id: &str, - text: &str, -) -> Option { - let url = format!( - "{}/open-apis/im/v1/messages?receive_id_type=chat_id", - api_base - ); - let content = serde_json::json!({"text": text}).to_string(); - let body = serde_json::json!({ - "receive_id": chat_id, - "msg_type": "text", - "content": content, - }); - - match client - .post(&url) - .bearer_auth(token) - .header("Content-Type", "application/json; charset=utf-8") - .json(&body) - .send() - .await - { - Ok(resp) => { - if resp.status().is_success() { - let msg_id = match resp.json::().await { - Ok(body) => body - .pointer("/data/message_id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()), - Err(e) => { - warn!(chat_id = %chat_id, err = %e, "feishu 200 response not valid JSON, self-echo dedupe will be skipped"); - None - } - }; - info!(chat_id = %chat_id, message_id = ?msg_id, "feishu message sent"); - msg_id - } else { - let status = resp.status(); - let text = resp.text().await.unwrap_or_default(); - tracing::error!(status = %status, body = %text, "feishu send message failed"); - None - } - } - Err(e) => { - tracing::error!(err = %e, "feishu send message request failed"); - None - } - } -} - -// --------------------------------------------------------------------------- -// Reactions (emoji on original message) -// --------------------------------------------------------------------------- - -/// Map OAB emoji to feishu reaction_type. Feishu uses string keys like "THUMBSUP". 
-fn emoji_to_feishu_reaction(emoji: &str) -> Option<&'static str> { - match emoji { - "👀" => Some("EYES"), - "🤔" => Some("THINKING"), - "🔥" => Some("FIRE"), - "👨\u{200d}💻" => Some("TECHNOLOGIST"), - "⚡" => Some("LIGHTNING"), - "🆗" => Some("OK"), - "👍" => Some("THUMBSUP"), - "😱" => Some("SCREAM"), - _ => None, - } -} - -async fn add_reaction(adapter: &FeishuAdapter, message_id: &str, emoji: &str) { - let reaction_type = match emoji_to_feishu_reaction(emoji) { - Some(r) => r, - None => { - tracing::debug!(emoji = %emoji, "feishu: no mapping for reaction emoji"); - return; - } - }; - let token = match adapter.token_cache.get_token(&adapter.client).await { - Ok(t) => t, - Err(e) => { tracing::error!(err = %e, "feishu: cannot get token for reaction"); return; } - }; - let url = format!( - "{}/open-apis/im/v1/messages/{}/reactions", - adapter.config.api_base(), message_id - ); - let _ = adapter.client - .post(&url) - .bearer_auth(&token) - .json(&serde_json::json!({"reaction_type": {"emoji_type": reaction_type}})) - .send() - .await - .map_err(|e| tracing::error!(err = %e, "feishu add_reaction failed")); -} - -async fn remove_reaction(adapter: &FeishuAdapter, message_id: &str, emoji: &str) { - let reaction_type = match emoji_to_feishu_reaction(emoji) { - Some(r) => r, - None => return, - }; - let token = match adapter.token_cache.get_token(&adapter.client).await { - Ok(t) => t, - Err(e) => { tracing::error!(err = %e, "feishu: cannot get token for reaction"); return; } - }; - // Feishu remove reaction needs reaction_id. Simpler approach: delete by type. - // GET reactions, find matching, DELETE by id. 
- let list_url = format!( - "{}/open-apis/im/v1/messages/{}/reactions?reaction_type={}", - adapter.config.api_base(), message_id, reaction_type - ); - let resp = match adapter.client.get(&list_url).bearer_auth(&token).send().await { - Ok(r) => r, - Err(_) => return, - }; - let body: serde_json::Value = match resp.json().await { - Ok(v) => v, - Err(_) => return, - }; - // Find our bot's reaction_id - if let Some(items) = body.pointer("/data/items").and_then(|v| v.as_array()) { - let bot_id = adapter.bot_open_id.read().await; - for item in items { - let is_ours = item.pointer("/operator/operator_id/open_id") - .and_then(|v| v.as_str()) == bot_id.as_deref(); - if is_ours { - if let Some(reaction_id) = item.get("reaction_id").and_then(|v| v.as_str()) { - let del_url = format!( - "{}/open-apis/im/v1/messages/{}/reactions/{}", - adapter.config.api_base(), message_id, reaction_id - ); - let _ = adapter.client.delete(&del_url).bearer_auth(&token).send().await; - return; - } - } - } - } -} - -// --------------------------------------------------------------------------- -// Reply handler -// --------------------------------------------------------------------------- - -/// Check if the bot has participated in the thread referenced by this envelope. -/// Returns `true` if the message is in a thread and that thread has a valid -/// (non-expired) participation entry in the cache. -fn check_thread_participated( - envelope: &FeishuEventEnvelope, - cache: &Arc>>, - session_ttl_secs: u64, -) -> bool { - envelope - .event - .as_ref() - .and_then(|e| e.message.as_ref()) - .and_then(|m| m.root_id.as_deref().or(m.parent_id.as_deref())) - .map(|tid| { - // Intentionally recover from poisoned mutex — cache data loss is acceptable - // and preferable to panicking the gateway. - let c = cache.lock().unwrap_or_else(|e| e.into_inner()); - c.get(tid).is_some_and(|ts| ts.elapsed().as_secs() < session_ttl_secs) - }) - .unwrap_or(false) -} - -/// Max entries before eviction. 
Shared by both `participated_threads` and -/// `multibot_threads` caches — they have the same cardinality (one entry per -/// active thread) so a single limit is appropriate for both. -const PARTICIPATION_CACHE_MAX: usize = 1000; - -/// Detect if a message @mentions another bot in a participated thread, and if -/// so, mark the thread in the multibot cache. Returns whether @mention gating -/// should be bypassed, respecting the configured `allow_user_messages` mode. -/// -/// This consolidates the duplicated multibot detection logic used by both the -/// WebSocket and webhook paths. -fn detect_and_mark_multibot( - envelope: &FeishuEventEnvelope, - bot_open_id: Option<&str>, - config: &FeishuConfig, - participated_threads: &Arc>>, - multibot_threads: &Arc>>, -) -> bool { - let self_participated = check_thread_participated( - envelope, participated_threads, config.session_ttl_secs, - ); - - let thread_id_for_check = envelope - .event - .as_ref() - .and_then(|e| e.message.as_ref()) - .and_then(|m| m.root_id.as_deref().or(m.parent_id.as_deref())); - - // Early multibot detection: if a message in a participated thread @mentions - // another bot, mark the thread as multibot immediately. 
- if let Some(tid) = thread_id_for_check { - if self_participated { - let mentions = envelope - .event - .as_ref() - .and_then(|e| e.message.as_ref()) - .and_then(|m| m.mentions.as_ref()); - if let Some(mention_list) = mentions { - let bot_self_id = bot_open_id.unwrap_or(""); - let mention_ids: Vec<_> = mention_list.iter().filter_map(|m| { - m.id.as_ref().and_then(|id| id.open_id.as_deref()) - }).collect(); - - let mentions_other_bot = if !config.trusted_bot_ids.is_empty() { - mention_ids.iter().any(|oid| { - config.trusted_bot_ids.iter().any(|bid| bid == oid) - }) - } else if !config.allowed_users.is_empty() { - mention_ids.iter().any(|oid| { - *oid != bot_self_id && !config.allowed_users.iter().any(|u| u == oid) - }) - } else { - false - }; - - if mentions_other_bot { - info!(thread_id = %tid, "multibot thread detected via @mention"); - let mut cache = multibot_threads.lock().unwrap_or_else(|e| e.into_inner()); - cache.entry(tid.to_string()).or_insert_with(Instant::now); - if cache.len() > PARTICIPATION_CACHE_MAX { - cache.retain(|_, ts| ts.elapsed().as_secs() < config.session_ttl_secs); - } - } - } - } - } - - // Compute bypass_mention_gating based on mode - match config.allow_user_messages { - AllowUsers::Mentions => false, - AllowUsers::Involved => self_participated, - AllowUsers::MultibotMentions => { - if !self_participated { - false - } else { - thread_id_for_check - .map(|tid| { - let cache = multibot_threads.lock().unwrap_or_else(|e| e.into_inner()); - cache - .get(tid) - .is_none_or(|ts| ts.elapsed().as_secs() >= config.session_ttl_secs) - }) - .unwrap_or(true) - } - } - } -} - -/// Record that the bot has participated in a thread. Evicts oldest entries -/// when the cache exceeds PARTICIPATION_CACHE_MAX. 
-fn record_participation( - cache: &Arc>>, - thread_id: &str, - session_ttl_secs: u64, -) { - if session_ttl_secs == 0 { - return; // Participation tracking disabled - } - // Intentionally recover from poisoned mutex — cache data loss is acceptable - // and preferable to panicking the gateway. - let mut map = cache.lock().unwrap_or_else(|e| e.into_inner()); - map.insert(thread_id.to_string(), Instant::now()); - // Evict if over capacity: first drop expired entries, then oldest half if still over - if map.len() > PARTICIPATION_CACHE_MAX { - map.retain(|_, ts| ts.elapsed().as_secs() < session_ttl_secs); - if map.len() > PARTICIPATION_CACHE_MAX { - let mut entries: Vec<_> = map.iter().map(|(k, v)| (k.clone(), *v)).collect(); - entries.sort_by_key(|(_, ts)| *ts); - let evict_count = entries.len() / 2; - for (k, _) in entries.into_iter().take(evict_count) { - map.remove(&k); - } - } - } -} - -pub async fn handle_reply( - reply: &GatewayReply, - adapter: &FeishuAdapter, - event_tx: &tokio::sync::broadcast::Sender, -) { - // Handle reactions — add/remove emoji on the original message - if let Some(ref cmd) = reply.command { - // Defence-in-depth: every command below interpolates `reply.reply_to` - // into a REST URL path (edit/delete → /im/v1/messages/{id}; reactions → - // /im/v1/messages/{id}/reactions). Validate the id shape once here, at - // the dispatch seam, so a crafted id with URL metacharacters can't alter - // request semantics. Trust boundary is the core↔gateway WebSocket, so - // this is belt-and-suspenders — but it closes the guard over every - // url-path-bearing command instead of just delete. - let interpolates_message_id = matches!( - cmd.as_str(), - "edit_message" | "delete_message" | "add_reaction" | "remove_reaction" - ); - if interpolates_message_id && !is_valid_feishu_message_id(&reply.reply_to) { - // "draft" is a known sentinel from core when streaming_placeholder=false; - // not a security concern, just a no-op — log at debug to avoid noise. 
- if reply.reply_to == "draft" { - tracing::debug!( - command = %cmd, - message_id = %reply.reply_to, - "feishu: skipping command — draft placeholder has no real message_id" - ); - } else { - tracing::warn!( - command = %cmd, - message_id = %reply.reply_to, - "feishu: refusing command — message_id failed shape validation" - ); - } - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success: false, - thread_id: None, - message_id: None, - error: Some("invalid message_id format".to_string()), - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - return; - } - match cmd.as_str() { - "add_reaction" => { - add_reaction(adapter, &reply.reply_to, &reply.content.text).await; - return; - } - "remove_reaction" => { - remove_reaction(adapter, &reply.reply_to, &reply.content.text).await; - return; - } - "edit_message" => { - let outcome = edit_feishu_message( - adapter, - &reply.reply_to, - &reply.content.text, - ).await; - // Translate outcome → (success, message_id, error). For - // CapReached we deliberately do NOT append-new at the gateway - // layer (see the rationale on the CapReached arm below); we - // signal failure so core's finalize path owns recovery. - let (success, message_id, error) = match outcome { - EditOutcome::Edited => { - (true, Some(reply.reply_to.clone()), None) - } - EditOutcome::CapReached => { - // Do NOT append-new fallback at the gateway layer. Core's - // cosmetic streaming loop flushes every ~1500ms — if every - // post-cap edit spawned a new message, the user would be - // spammed with 20+ duplicate continuation messages over the - // remainder of a long reply. - // - // Instead, signal failure so: - // 1. core's mid-stream cosmetic edit loop hits its - // consecutive-failures break (3 strikes) and stops - // attempting edits, freezing the placeholder mid-content - // 2. 
the final delivery path in src/adapter.rs sees the - // placeholder edit fail and falls back to send_message - // so the user gets the full reply as a fresh message - // - // Net UX: half-edited placeholder + one complete continuation - // message + ✅ done reaction (vs. today's mid-truncation + 🆗 - // false success, or naive append-fallback's 25-message spam). - ( - false, - None, - Some("edit_cap_reached".to_string()), - ) - } - EditOutcome::Failed(err) => (false, None, Some(err)), - }; - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success, - thread_id: None, - message_id, - error, - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - return; - } - "create_topic" | "set_reaction" => { - tracing::debug!(command = %cmd, "feishu: skipping unsupported command"); - return; - } - "delete_message" => { - let result = delete_feishu_message(adapter, &reply.reply_to).await; - let (success, error) = match result { - Ok(()) => (true, None), - Err(e) => (false, Some(e)), - }; - // Dormant by design: core's delete_message is fire-and-forget - // (request_id = None), so this response branch is currently - // never taken. Kept for symmetry with the other handlers and so - // delete becomes observable for free if a future caller (or - // another gateway client) sets request_id. 
- if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success, - thread_id: None, - message_id: None, - error, - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - return; - } - _ => {} - } - } - - let token = match adapter.token_cache.get_token(&adapter.client).await { - Ok(t) => t, - Err(e) => { - tracing::error!(err = %e, "feishu: cannot get token for reply"); - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success: false, - thread_id: None, - message_id: None, - error: Some(format!("token error: {e}")), - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - return; - } - }; - - let api_base = adapter.config.api_base(); - let text = &reply.content.text; - let limit = adapter.config.message_limit; - // quote_message_id (agent-controlled reply-to) takes priority over thread_id - let reply_target = reply.quote_message_id.as_deref() - .or(reply.channel.thread_id.as_deref()); - let thread_id = reply.channel.thread_id.as_deref(); - - // Split long messages; store sent message_ids in dedupe to prevent - // self-echo (Feishu pushes bot's own messages back via WebSocket) - // Use post (rich text) format for markdown rendering. - // When in a thread (thread_id present), use reply API to stay in the same thread. 
- if text.len() <= limit { - let result = send_post_message(&adapter.client, &api_base, &token, &reply.channel.id, reply_target, text).await; - // Fallback: if quote_message_id caused failure, retry without it - let result = if result.is_none() && reply.quote_message_id.is_some() { - tracing::warn!(quote_message_id = ?reply.quote_message_id, channel_id = %reply.channel.id, "reply-to failed, falling back to plain send"); - send_post_message(&adapter.client, &api_base, &token, &reply.channel.id, thread_id, text).await - } else { - result - }; - match result { - Some(msg_id) => { - adapter.dedupe.is_duplicate(&msg_id); - // Record thread participation for mention bypass - if let Some(tid) = thread_id { - record_participation(&adapter.participated_threads, tid, adapter.config.session_ttl_secs); - } - // Send response with message_id back to OAB core (for streaming edit) - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success: true, - thread_id: None, - message_id: Some(msg_id), - error: None, - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - } - None => { - // Send failure response so core doesn't wait 5s for timeout - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success: false, - thread_id: None, - message_id: None, - error: Some("send_post_message failed".into()), - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - } - } - } else { - // Track per-chunk success so we can report partial-failure back to core. - // Previously this branch returned no GatewayResponse at all and used - // "any chunk succeeded" as the success criterion — letting core fall - // through to a 5s timeout and silently mark the turn delivered. 
With - // request/response now wired through, we propagate exact health. - let chunks: Vec<&str> = split_text(text, limit); - let total_chunks = chunks.len(); - let mut succeeded = 0usize; - let mut last_msg_id: Option = None; - for chunk in &chunks { - if let Some(msg_id) = send_post_message(&adapter.client, &api_base, &token, &reply.channel.id, reply_target, chunk).await { - adapter.dedupe.is_duplicate(&msg_id); - succeeded += 1; - last_msg_id = Some(msg_id); - } - } - // Fallback: if quote_message_id caused all chunks to fail, retry without it - if succeeded == 0 && reply.quote_message_id.is_some() { - tracing::warn!(quote_message_id = ?reply.quote_message_id, channel_id = %reply.channel.id, "chunked reply-to failed, falling back to plain send"); - for chunk in &chunks { - if let Some(msg_id) = send_post_message(&adapter.client, &api_base, &token, &reply.channel.id, thread_id, chunk).await { - adapter.dedupe.is_duplicate(&msg_id); - succeeded += 1; - last_msg_id = Some(msg_id); - } - } - } - if succeeded > 0 { - if let Some(tid) = thread_id { - record_participation(&adapter.participated_threads, tid, adapter.config.session_ttl_secs); - } - } - // Report back to core. Success requires every chunk delivered — partial - // success becomes failure so dispatch surfaces ❌ rather than 🆗. - if let Some(ref req_id) = reply.request_id { - let success = succeeded == total_chunks && total_chunks > 0; - let error = if success { - None - } else { - Some(format!( - "chunked send delivered {succeeded}/{total_chunks} chunks" - )) - }; - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success, - thread_id: None, - message_id: last_msg_id, - error, - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - } -} - -/// Split text into chunks of at most `limit` bytes, breaking at newline or -/// space boundaries when possible. Safe for multi-byte UTF-8 (e.g. Chinese). 
-fn split_text(text: &str, limit: usize) -> Vec<&str> { - let mut chunks = Vec::new(); - let mut start = 0; - while start < text.len() { - if start + limit >= text.len() { - chunks.push(&text[start..]); - break; - } - // Find a char-safe boundary at or before start + limit - let mut end = start + limit; - while !text.is_char_boundary(end) { - end -= 1; - } - // Try to break at a newline or space within the last 200 bytes. - // search_start must also be on a char boundary to avoid panic. - let mut search_start = if end > start + 200 { end - 200 } else { start }; - while search_start < end && !text.is_char_boundary(search_start) { - search_start += 1; - } - let break_at = text[search_start..end] - .rfind('\n') - .or_else(|| text[search_start..end].rfind(' ')) - .map(|pos| search_start + pos + 1) - .unwrap_or(end); - chunks.push(&text[start..break_at]); - start = break_at; - } - chunks -} - -// --------------------------------------------------------------------------- -// Webhook handler -// --------------------------------------------------------------------------- - -/// Max webhook body size: 1 MB -const WEBHOOK_BODY_LIMIT: usize = 1_048_576; - -/// Simple per-IP rate limiter state. -pub struct RateLimiter { - counts: std::sync::Mutex>, - window_secs: u64, - max_requests: u64, -} - -impl RateLimiter { - pub fn new(window_secs: u64, max_requests: u64) -> Self { - Self { - counts: std::sync::Mutex::new(HashMap::new()), - window_secs, - max_requests, - } - } - - /// Returns true if the request should be rejected (rate exceeded). 
- pub fn check(&self, key: &str) -> bool { - let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner()); - // Lazy cleanup - if map.len() > 4096 { - map.retain(|_, (_, ts)| ts.elapsed().as_secs() < self.window_secs); - } - let entry = map.entry(key.to_string()).or_insert((0, Instant::now())); - if entry.1.elapsed().as_secs() >= self.window_secs { - *entry = (1, Instant::now()); - false - } else { - entry.0 += 1; - entry.0 > self.max_requests - } - } -} - -/// Verify webhook signature: SHA256(timestamp + nonce + encrypt_key + body). -fn verify_signature( - timestamp: &str, - nonce: &str, - encrypt_key: &str, - body: &[u8], - expected_sig: &str, -) -> bool { - use sha2::{Digest, Sha256}; - let mut hasher = Sha256::new(); - hasher.update(timestamp.as_bytes()); - hasher.update(nonce.as_bytes()); - hasher.update(encrypt_key.as_bytes()); - hasher.update(body); - let result = format!("{:x}", hasher.finalize()); - constant_time_eq(&result, expected_sig) -} - -/// Decrypt AES-CBC encrypted event body. -/// Feishu uses AES-256-CBC with SHA256(encrypt_key) as key, first 16 bytes of -/// ciphertext as IV. 
-fn decrypt_event(encrypt_key: &str, encrypted: &str) -> anyhow::Result { - use sha2::{Digest, Sha256}; - let key = Sha256::digest(encrypt_key.as_bytes()); - let cipher_bytes = base64::Engine::decode( - &base64::engine::general_purpose::STANDARD, - encrypted, - ) - .map_err(|e| anyhow::anyhow!("base64 decode failed: {e}"))?; - - if cipher_bytes.len() < 16 { - anyhow::bail!("encrypted data too short"); - } - - let iv = &cipher_bytes[..16]; - let ciphertext = &cipher_bytes[16..]; - - // AES-256-CBC decrypt - use aes::cipher::{BlockDecryptMut, KeyIvInit}; - type Aes256CbcDec = cbc::Decryptor; - - let decryptor = Aes256CbcDec::new_from_slices(&key, iv) - .map_err(|e| anyhow::anyhow!("aes init failed: {e}"))?; - - let mut buf = ciphertext.to_vec(); - let plaintext = decryptor - .decrypt_padded_mut::(&mut buf) - .map_err(|e| anyhow::anyhow!("aes decrypt failed: {e}"))?; - - String::from_utf8(plaintext.to_vec()) - .map_err(|e| anyhow::anyhow!("decrypted data not utf8: {e}")) -} - -pub async fn webhook( - State(state): State>, - headers: axum::http::HeaderMap, - body: axum::body::Bytes, -) -> axum::response::Response { - use axum::response::IntoResponse; - - let feishu = match state.feishu.as_ref() { - Some(f) => f, - None => return axum::http::StatusCode::SERVICE_UNAVAILABLE.into_response(), - }; - - // Body size limit - if body.len() > WEBHOOK_BODY_LIMIT { - warn!(size = body.len(), "feishu webhook body too large"); - return axum::http::StatusCode::PAYLOAD_TOO_LARGE.into_response(); - } - - // Rate limit (by X-Forwarded-For or fallback) - let ip = headers - .get("x-forwarded-for") - .and_then(|v| v.to_str().ok()) - .unwrap_or("unknown"); - if feishu.rate_limiter.check(ip) { - return (axum::http::StatusCode::TOO_MANY_REQUESTS, "rate limit exceeded") - .into_response(); - } - - // Signature verification (if encrypt_key configured) - if let Some(ref encrypt_key) = feishu.config.encrypt_key { - let sig = headers - .get("x-lark-signature") - .and_then(|v| v.to_str().ok()); - 
let timestamp = headers - .get("x-lark-request-timestamp") - .and_then(|v| v.to_str().ok()); - let nonce = headers - .get("x-lark-request-nonce") - .and_then(|v| v.to_str().ok()); - - match (sig, timestamp, nonce) { - (Some(sig), Some(ts), Some(nonce)) => { - if !verify_signature(ts, nonce, encrypt_key, &body, sig) { - warn!("feishu webhook rejected: invalid signature"); - return axum::http::StatusCode::UNAUTHORIZED.into_response(); - } - } - _ => { - warn!("feishu webhook rejected: missing signature headers"); - return axum::http::StatusCode::UNAUTHORIZED.into_response(); - } - } - } else { - warn!("FEISHU_ENCRYPT_KEY not configured — webhook signature verification is SKIPPED (insecure)"); - } - - // Parse body — may be encrypted - let event_json: serde_json::Value = match serde_json::from_slice(&body) { - Ok(v) => v, - Err(e) => { - warn!(err = %e, "feishu webhook parse error"); - return axum::http::StatusCode::BAD_REQUEST.into_response(); - } - }; - - // Handle encrypted events - let event_json = if let Some(encrypted) = event_json.get("encrypt").and_then(|v| v.as_str()) { - let encrypt_key = match feishu.config.encrypt_key.as_deref() { - Some(k) => k, - None => { - warn!("feishu webhook: encrypted event but no FEISHU_ENCRYPT_KEY configured"); - return axum::http::StatusCode::BAD_REQUEST.into_response(); - } - }; - match decrypt_event(encrypt_key, encrypted) { - Ok(decrypted) => match serde_json::from_str(&decrypted) { - Ok(v) => v, - Err(e) => { - warn!(err = %e, "feishu webhook: decrypted event parse error"); - return axum::http::StatusCode::BAD_REQUEST.into_response(); - } - }, - Err(e) => { - warn!(err = %e, "feishu webhook: decrypt failed"); - return axum::http::StatusCode::BAD_REQUEST.into_response(); - } - } - } else { - event_json - }; - - // URL verification challenge - if event_json.get("challenge").is_some() { - // Verify token if configured - if let Some(ref expected_token) = feishu.config.verification_token { - let token = 
event_json.get("token").and_then(|v| v.as_str()); - match token { - Some(t) if constant_time_eq(t, expected_token) => {} - _ => { - warn!("feishu webhook: URL verification token mismatch"); - return axum::http::StatusCode::UNAUTHORIZED.into_response(); - } - } - } - let challenge = event_json["challenge"].as_str().unwrap_or(""); - return axum::Json(serde_json::json!({"challenge": challenge})).into_response(); - } - - // Verification token check for regular events - if let Some(ref expected_token) = feishu.config.verification_token { - let token = event_json - .pointer("/header/token") - .or_else(|| event_json.get("token")) - .and_then(|v| v.as_str()); - match token { - Some(t) if constant_time_eq(t, expected_token) => {} - _ => { - warn!("feishu webhook rejected: invalid verification token"); - return axum::http::StatusCode::UNAUTHORIZED.into_response(); - } - } - } - - // Parse as event envelope - let envelope: FeishuEventEnvelope = match serde_json::from_value(event_json) { - Ok(e) => e, - Err(e) => { - warn!(err = %e, "feishu webhook: event envelope parse error"); - return axum::http::StatusCode::OK.into_response(); - } - }; - - // Dedupe + parse + broadcast (same logic as WebSocket handler) - if let Some(ref header) = envelope.header { - if let Some(ref event_id) = header.event_id { - if feishu.dedupe.is_duplicate(event_id) { - return axum::http::StatusCode::OK.into_response(); - } - } - } - - let bot_id = feishu.bot_open_id.read().await; - let bot_id_ref = bot_id.as_deref(); - - // Check participated threads and multibot detection for mention bypass - let bypass_mention = detect_and_mark_multibot( - &envelope, bot_id_ref, &feishu.config, - &feishu.participated_threads, &feishu.multibot_threads, - ); - - if let Some((mut gateway_event, media_refs)) = parse_message_event(&envelope, bot_id_ref, &feishu.config, bypass_mention) { - if !feishu.dedupe.is_duplicate(&gateway_event.message_id) { - let name = resolve_user_name( - &gateway_event.sender.id, 
&feishu.name_cache, &feishu.token_cache, - &feishu.client, &feishu.config.api_base(), - ).await; - gateway_event.sender.name = name.clone(); - gateway_event.sender.display_name = name; - - // Download media attachments - if !media_refs.is_empty() { - if let Ok(token) = feishu.token_cache.get_token(&feishu.client).await { - let api_base = feishu.config.api_base(); - for media_ref in &media_refs { - let attachment = match media_ref { - MediaRef::Image { message_id, image_key } => { - download_feishu_image(&feishu.client, &api_base, &token, message_id, image_key).await - } - MediaRef::File { message_id, file_key, file_name } => { - download_feishu_file(&feishu.client, &api_base, &token, message_id, file_key, file_name).await - } - MediaRef::Audio { message_id, file_key } => { - download_feishu_audio(&feishu.client, &api_base, &token, message_id, file_key).await - } - }; - if let Some(att) = attachment { - gateway_event.content.attachments.push(att); - } - } - } - } - - // Skip if no text and no attachments (e.g. 
unsupported file type) - if gateway_event.content.text.trim().is_empty() && gateway_event.content.attachments.is_empty() { - return axum::http::StatusCode::OK.into_response(); - } - - let json = serde_json::to_string(&gateway_event).unwrap(); - info!( - channel = %gateway_event.channel.id, - sender = %gateway_event.sender.id, - "feishu webhook → gateway" - ); - let _ = state.event_tx.send(json); - } - } - - axum::http::StatusCode::OK.into_response() -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - use wiremock::matchers::{body_json, header, method, path}; - use wiremock::{Mock, MockServer, ResponseTemplate}; - - fn test_config() -> FeishuConfig { - FeishuConfig { - app_id: "cli_test".into(), - app_secret: "secret_test".into(), - domain: "feishu".into(), - connection_mode: ConnectionMode::Websocket, - webhook_path: "/webhook/feishu".into(), - verification_token: None, - encrypt_key: None, - allowed_groups: vec![], - allowed_users: vec![], - require_mention: true, - allow_bots: AllowBots::Off, - allow_user_messages: AllowUsers::MultibotMentions, - trusted_bot_ids: vec![], - max_bot_turns: 20, - dedupe_ttl_secs: 300, - message_limit: 4000, - session_ttl_secs: 86400, - api_base_override: None, - } - } - - // --- Token tests --- - - #[tokio::test] - async fn token_refresh_success() { - let server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/open-apis/auth/v3/tenant_access_token/internal")) - .and(body_json(serde_json::json!({ - "app_id": "cli_test", - "app_secret": "secret_test", - }))) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, - "msg": "ok", - "tenant_access_token": "t-test-token-123", - "expire": 7200 - }))) - .expect(1) - .mount(&server) - .await; - - let config = test_config(); - let cache = 
FeishuTokenCache::with_base(&config, &server.uri()); - let client = reqwest::Client::new(); - - let token = cache.get_token(&client).await.unwrap(); - assert_eq!(token, "t-test-token-123"); - - // Second call should use cache, not hit server again (expect(1) above) - let token2 = cache.get_token(&client).await.unwrap(); - assert_eq!(token2, "t-test-token-123"); - } - - #[tokio::test] - async fn token_refresh_api_error() { - let server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/open-apis/auth/v3/tenant_access_token/internal")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 10003, - "msg": "invalid app_secret", - }))) - .expect(1) - .mount(&server) - .await; - - let config = test_config(); - let cache = FeishuTokenCache::with_base(&config, &server.uri()); - let client = reqwest::Client::new(); - - let err = cache.get_token(&client).await.unwrap_err(); - let msg = err.to_string(); - assert!(msg.contains("10003"), "error should contain code: {msg}"); - assert!( - !msg.contains("secret_test"), - "error must not leak secret: {msg}" - ); - } - - // --- Send message tests --- - - #[tokio::test] - async fn send_text_message_success() { - let server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/open-apis/im/v1/messages")) - .and(header("authorization", "Bearer t-tok")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, - "msg": "success", - "data": {"message_id": "om_test123"} - }))) - .expect(1) - .mount(&server) - .await; - - let client = reqwest::Client::new(); - let msg_id = send_text_message(&client, &server.uri(), "t-tok", "oc_chat1", "hello").await; - assert_eq!(msg_id.as_deref(), Some("om_test123")); - } - - #[tokio::test] - async fn send_text_message_api_failure() { - let server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/open-apis/im/v1/messages")) - 
.respond_with(ResponseTemplate::new(400).set_body_string("bad request")) - .expect(1) - .mount(&server) - .await; - - let client = reqwest::Client::new(); - let msg_id = send_text_message(&client, &server.uri(), "t-tok", "oc_chat1", "hello").await; - assert!(msg_id.is_none()); - } - - #[tokio::test] - async fn send_text_message_invalid_json_returns_none() { - let server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/open-apis/im/v1/messages")) - .respond_with(ResponseTemplate::new(200).set_body_string("not json")) - .expect(1) - .mount(&server) - .await; - - let client = reqwest::Client::new(); - let msg_id = send_text_message(&client, &server.uri(), "t-tok", "oc_chat1", "hello").await; - assert!(msg_id.is_none()); - } - - #[tokio::test] - async fn send_text_message_missing_message_id_returns_none() { - let server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/open-apis/im/v1/messages")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, - "msg": "success", - }))) - .expect(1) - .mount(&server) - .await; - - let client = reqwest::Client::new(); - let msg_id = send_text_message(&client, &server.uri(), "t-tok", "oc_chat1", "hello").await; - assert!(msg_id.is_none()); - } - - // --- Split text tests --- - - #[test] - fn split_text_short() { - let chunks = split_text("hello", 100); - assert_eq!(chunks, vec!["hello"]); - } - - #[test] - fn split_text_exact_limit() { - let text = "a".repeat(100); - let chunks = split_text(&text, 100); - assert_eq!(chunks.len(), 1); - } - - #[test] - fn split_text_chinese_utf8_safe() { - // Each Chinese char is 3 bytes. 20 chars = 60 bytes. - // Limit 10 would land mid-char without boundary check. 
- let text = "你好世界測試飛書中文聊天消息分割安全驗證完成"; - let chunks = split_text(text, 10); - assert!(chunks.len() > 1); - let reassembled: String = chunks.concat(); - assert_eq!(reassembled, text); - } - - #[test] - fn split_text_search_start_char_boundary() { - // Regression: search_start = end - 200 could land mid-char. - // 300 Chinese chars (900 bytes) with limit=500 forces search_start - // into the middle of multi-byte chars. - let text: String = "飛書".repeat(150); // 300 chars, 900 bytes - let chunks = split_text(&text, 500); - assert!(chunks.len() >= 2); - let reassembled: String = chunks.concat(); - assert_eq!(reassembled, text); - } - - #[test] - fn split_text_long_breaks_at_newline() { - let text = format!("{}\n{}", "a".repeat(50), "b".repeat(50)); - let chunks = split_text(&text, 60); - assert_eq!(chunks.len(), 2); - assert!(chunks[0].ends_with('\n')); - } - - // --- Event parsing tests --- - - fn make_envelope( - chat_type: &str, - text: &str, - sender_open_id: &str, - mentions: Option>, - ) -> FeishuEventEnvelope { - FeishuEventEnvelope { - header: Some(FeishuEventHeader { - event_id: Some("evt_test".into()), - event_type: Some("im.message.receive_v1".into()), - }), - event: Some(FeishuEventBody { - sender: Some(FeishuSender { - sender_id: Some(FeishuSenderId { - open_id: Some(sender_open_id.into()), - }), - sender_type: Some("user".into()), - }), - message: Some(FeishuMessage { - message_id: Some("om_msg1".into()), - chat_id: Some("oc_chat1".into()), - chat_type: Some(chat_type.into()), - message_type: Some("text".into()), - content: Some(serde_json::json!({"text": text}).to_string()), - mentions, - root_id: None, - parent_id: None, - }), - }), - challenge: None, - event_type_field: None, - } - } - - #[test] - fn parse_dm_text() { - let env = make_envelope("p2p", "hello", "ou_user1", None); - let cfg = test_config(); - let (evt, _media) = parse_message_event(&env, Some("ou_bot"), &cfg, false).unwrap(); - assert_eq!(evt.platform, "feishu"); - 
assert_eq!(evt.channel.channel_type, "direct"); - assert_eq!(evt.channel.id, "oc_chat1"); - assert_eq!(evt.sender.id, "ou_user1"); - assert_eq!(evt.content.text, "hello"); - assert!(evt.mentions.is_empty()); - } - - #[test] - fn parse_group_with_bot_mention() { - let mentions = vec![FeishuMention { - key: Some("@_user_1".into()), - id: Some(FeishuMentionId { - open_id: Some("ou_bot".into()), - }), - name: Some("Bot".into()), - }]; - let env = make_envelope("group", "@_user_1 explain VPC", "ou_user1", Some(mentions)); - let cfg = test_config(); - let (evt, _media) = parse_message_event(&env, Some("ou_bot"), &cfg, false).unwrap(); - assert_eq!(evt.channel.channel_type, "group"); - assert_eq!(evt.content.text, "explain VPC"); - assert_eq!(evt.mentions, vec!["ou_bot"]); - } - - #[test] - fn parse_group_without_mention_filtered() { - let env = make_envelope("group", "just chatting", "ou_user1", None); - let cfg = test_config(); // require_mention = true - // Gateway-side mention gating: group message without bot mention is filtered - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); - } - - #[test] - fn parse_group_without_mention_allowed_when_disabled() { - let env = make_envelope("group", "just chatting", "ou_user1", None); - let mut cfg = test_config(); - cfg.require_mention = false; - let evt = parse_message_event(&env, Some("ou_bot"), &cfg, false); - assert!(evt.is_some()); - } - - #[test] - fn parse_skips_bot_sender() { - let mut env = make_envelope("p2p", "hello", "ou_bot", None); - env.event.as_mut().unwrap().sender.as_mut().unwrap().sender_type = Some("bot".into()); - let cfg = test_config(); - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); - } - - #[test] - fn parse_skips_empty_text() { - let env = make_envelope("p2p", " ", "ou_user1", None); - let cfg = test_config(); - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); - } - - #[test] - fn parse_skips_non_text_message() { - let mut 
env = make_envelope("p2p", "hello", "ou_user1", None); - env.event.as_mut().unwrap().message.as_mut().unwrap().message_type = Some("sticker".into()); - let cfg = test_config(); - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); - } - - #[test] - fn parse_skips_self_message() { - let env = make_envelope("p2p", "hello", "ou_bot", None); - let cfg = test_config(); - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); - } - - // --- Dedupe tests --- - - #[test] - fn dedupe_first_is_not_duplicate() { - let cache = DedupeCache::new(300); - assert!(!cache.is_duplicate("msg_1")); - } - - #[test] - fn dedupe_second_is_duplicate() { - let cache = DedupeCache::new(300); - assert!(!cache.is_duplicate("msg_1")); - assert!(cache.is_duplicate("msg_1")); - } - - // --- Webhook security tests --- - - #[test] - fn verify_signature_correct() { - use sha2::{Digest, Sha256}; - let ts = "1234567890"; - let nonce = "abc"; - let key = "mykey"; - let body = b"hello"; - let mut hasher = Sha256::new(); - hasher.update(ts.as_bytes()); - hasher.update(nonce.as_bytes()); - hasher.update(key.as_bytes()); - hasher.update(body); - let expected = format!("{:x}", hasher.finalize()); - assert!(verify_signature(ts, nonce, key, body, &expected)); - } - - #[test] - fn verify_signature_wrong() { - assert!(!verify_signature("ts", "nonce", "key", b"body", "bad_sig")); - } - - #[test] - fn rate_limiter_allows_within_limit() { - let rl = RateLimiter::new(60, 3); - assert!(!rl.check("ip1")); - assert!(!rl.check("ip1")); - assert!(!rl.check("ip1")); - } - - #[test] - fn rate_limiter_rejects_over_limit() { - let rl = RateLimiter::new(60, 2); - assert!(!rl.check("ip1")); - assert!(!rl.check("ip1")); - assert!(rl.check("ip1")); // 3rd request exceeds limit of 2 - } - - // --- Name resolution tests --- - - #[tokio::test] - async fn resolve_user_name_success_and_cache() { - let server = MockServer::start().await; - Mock::given(method("POST")) - 
.and(path("/open-apis/auth/v3/tenant_access_token/internal")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, "tenant_access_token": "t-tok", "expire": 7200 - }))) - .mount(&server) - .await; - Mock::given(method("GET")) - .and(path("/open-apis/contact/v3/users/ou_user1")) - .and(header("authorization", "Bearer t-tok")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, - "data": { "user": { "name": "Alice", "open_id": "ou_user1" } } - }))) - .expect(1) // should only be called once (cached on second call) - .mount(&server) - .await; - - let config = test_config(); - let token_cache = Arc::new(FeishuTokenCache::with_base(&config, &server.uri())); - let name_cache = Arc::new(std::sync::Mutex::new(HashMap::new())); - let client = reqwest::Client::new(); - - let name = resolve_user_name("ou_user1", &name_cache, &token_cache, &client, &server.uri()).await; - assert_eq!(name, "Alice"); - - // Second call should use cache (expect(1) above ensures no second API call) - let name2 = resolve_user_name("ou_user1", &name_cache, &token_cache, &client, &server.uri()).await; - assert_eq!(name2, "Alice"); - } - - #[tokio::test] - async fn resolve_user_name_api_error_falls_back_to_open_id() { - let server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/open-apis/auth/v3/tenant_access_token/internal")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, "tenant_access_token": "t-tok", "expire": 7200 - }))) - .mount(&server) - .await; - Mock::given(method("GET")) - .and(path("/open-apis/contact/v3/users/ou_unknown")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 40003, "msg": "user not found" - }))) - .mount(&server) - .await; - - let config = test_config(); - let token_cache = Arc::new(FeishuTokenCache::with_base(&config, &server.uri())); - let name_cache = 
Arc::new(std::sync::Mutex::new(HashMap::new())); - let client = reqwest::Client::new(); - - let name = resolve_user_name("ou_unknown", &name_cache, &token_cache, &client, &server.uri()).await; - assert_eq!(name, "ou_unknown"); - } - - // --- extract_mentions tests --- - - #[test] - fn extract_mentions_replacen_only_first() { - // If mention key appears in normal text too, only the first occurrence is removed - let mentions = vec![FeishuMention { - key: Some("@_user_1".into()), - id: Some(FeishuMentionId { open_id: Some("ou_bot".into()) }), - name: Some("Bot".into()), - }]; - let env = make_envelope("group", "@_user_1 tell me about @_user_1 patterns", "ou_user1", Some(mentions)); - let cfg = test_config(); - let (evt, _media) = parse_message_event(&env, Some("ou_bot"), &cfg, false).unwrap(); - // Only first @_user_1 removed, second preserved - assert!(evt.content.text.contains("@_user_1")); - } - - // --- allowed_users filtering --- - - #[test] - fn parse_allowed_users_blocks_unlisted() { - let env = make_envelope("p2p", "hello", "ou_stranger", None); - let mut cfg = test_config(); - cfg.allowed_users = vec!["ou_vip".into()]; - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); - } - - #[test] - fn parse_allowed_users_permits_listed() { - let env = make_envelope("p2p", "hello", "ou_vip", None); - let mut cfg = test_config(); - cfg.allowed_users = vec!["ou_vip".into()]; - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_some()); - } - - // --- allowed_groups filtering --- - - #[test] - fn parse_allowed_groups_blocks_unlisted() { - let mentions = vec![FeishuMention { - key: Some("@_user_1".into()), - id: Some(FeishuMentionId { open_id: Some("ou_bot".into()) }), - name: Some("Bot".into()), - }]; - let env = make_envelope("group", "@_user_1 hello", "ou_user1", Some(mentions)); - let mut cfg = test_config(); - cfg.allowed_groups = vec!["oc_other".into()]; // oc_chat1 not in list - assert!(parse_message_event(&env, Some("ou_bot"), 
&cfg, false).is_none()); - } - - #[test] - fn parse_allowed_groups_permits_listed() { - let mentions = vec![FeishuMention { - key: Some("@_user_1".into()), - id: Some(FeishuMentionId { open_id: Some("ou_bot".into()) }), - name: Some("Bot".into()), - }]; - let env = make_envelope("group", "@_user_1 hello", "ou_user1", Some(mentions)); - let mut cfg = test_config(); - cfg.allowed_groups = vec!["oc_chat1".into()]; - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_some()); - } - - // --- Token TTL from API response --- - - #[tokio::test] - async fn token_uses_api_expire_field() { - let server = MockServer::start().await; - // Return a short expire (10s). With 300s margin, token should be - // considered expired immediately on second call. - Mock::given(method("POST")) - .and(path("/open-apis/auth/v3/tenant_access_token/internal")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, - "tenant_access_token": "t-short", - "expire": 10 - }))) - .expect(2) // called twice because 10s < 300s margin → always expired - .mount(&server) - .await; - - let config = test_config(); - let cache = FeishuTokenCache::with_base(&config, &server.uri()); - let client = reqwest::Client::new(); - - let t1 = cache.get_token(&client).await.unwrap(); - assert_eq!(t1, "t-short"); - // Second call should refresh (expire=10 < margin=300) - let t2 = cache.get_token(&client).await.unwrap(); - assert_eq!(t2, "t-short"); - // expect(2) verifies it was called twice - } - - // --- constant_time_eq --- - - #[test] - fn constant_time_eq_same() { - assert!(constant_time_eq("abc123", "abc123")); - } - - #[test] - fn constant_time_eq_different() { - assert!(!constant_time_eq("abc123", "abc124")); - } - - #[test] - fn constant_time_eq_different_length() { - assert!(!constant_time_eq("short", "longer_string")); - } - - // --- Thread ID parsing --- - - #[test] - fn parse_thread_id_from_root_id() { - let mut env = make_envelope("p2p", "reply", "ou_user1", 
None); - env.event.as_mut().unwrap().message.as_mut().unwrap().root_id = Some("om_root".into()); - let cfg = test_config(); - let (evt, _media) = parse_message_event(&env, Some("ou_bot"), &cfg, false).unwrap(); - assert_eq!(evt.channel.thread_id, Some("om_root".into())); - } - - #[test] - fn parse_thread_id_from_parent_id() { - let mut env = make_envelope("p2p", "reply", "ou_user1", None); - env.event.as_mut().unwrap().message.as_mut().unwrap().parent_id = Some("om_parent".into()); - let cfg = test_config(); - let (evt, _media) = parse_message_event(&env, Some("ou_bot"), &cfg, false).unwrap(); - assert_eq!(evt.channel.thread_id, Some("om_parent".into())); - } - - // --- Emoji reaction mapping --- - - #[test] - fn emoji_mapping_known() { - assert_eq!(emoji_to_feishu_reaction("👍"), Some("THUMBSUP")); - assert_eq!(emoji_to_feishu_reaction("🔥"), Some("FIRE")); - assert_eq!(emoji_to_feishu_reaction("👀"), Some("EYES")); - } - - #[test] - fn emoji_mapping_unknown() { - assert_eq!(emoji_to_feishu_reaction("🎉"), None); - } - - // --- Participated thread tests --- - - #[test] - fn participated_thread_bypasses_mention_gating() { - let cfg = test_config(); // require_mention = true - // Build envelope with root_id (in a thread) - let mut env = make_envelope("group", "Hello", "ou_user1", None); - env.event.as_mut().unwrap().message.as_mut().unwrap().root_id = Some("root_123".into()); - // Without participation: no @mention → None - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); - // With participation: no @mention → Some (bypass) - let result = parse_message_event(&env, Some("ou_bot"), &cfg, true); - assert!(result.is_some()); - let (evt, _) = result.unwrap(); - assert_eq!(evt.channel.thread_id.as_deref(), Some("root_123")); - } - - #[test] - fn participated_no_effect_without_thread() { - let cfg = test_config(); // require_mention = true - // Message in main channel (no thread_id) — participated flag doesn't help - let env = make_envelope("group", 
"Hello", "ou_user1", None); - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, true).is_none()); - } - - #[test] - fn record_participation_and_eviction() { - let cache = Arc::new(std::sync::Mutex::new(HashMap::new())); - // Record a thread - record_participation(&cache, "thread_1", 86400); - assert_eq!(cache.lock().unwrap().len(), 1); - // Fill beyond PARTICIPATION_CACHE_MAX - for i in 0..PARTICIPATION_CACHE_MAX + 10 { - record_participation(&cache, &format!("thread_{i}"), 86400); - } - // After eviction, should be roughly half - assert!(cache.lock().unwrap().len() <= PARTICIPATION_CACHE_MAX); - } - - // --- Multibot-mentions mode tests --- - - #[test] - fn multibot_mentions_mode_bypasses_when_single_bot() { - let mut cfg = test_config(); - cfg.allow_user_messages = AllowUsers::MultibotMentions; - let mut env = make_envelope("group", "Hello", "ou_user1", None); - env.event.as_mut().unwrap().message.as_mut().unwrap().root_id = Some("root_456".into()); - // participated + no other bot → bypass_mention_gating=true - let result = parse_message_event(&env, Some("ou_bot"), &cfg, true); - assert!(result.is_some()); - } - - #[test] - fn multibot_mentions_mode_requires_mention_when_not_participated() { - let mut cfg = test_config(); - cfg.allow_user_messages = AllowUsers::MultibotMentions; - let mut env = make_envelope("group", "Hello", "ou_user1", None); - env.event.as_mut().unwrap().message.as_mut().unwrap().root_id = Some("root_456".into()); - // not participated → bypass_mention_gating=false - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); - } - - #[test] - fn mentions_mode_never_bypasses() { - let mut cfg = test_config(); - cfg.allow_user_messages = AllowUsers::Mentions; - let mut env = make_envelope("group", "Hello", "ou_user1", None); - env.event.as_mut().unwrap().message.as_mut().unwrap().root_id = Some("root_789".into()); - // Even with bypass_mention_gating=true, Mentions mode never bypasses - // (caller would pass false because 
Mentions mode always returns false) - assert!(parse_message_event(&env, Some("ou_bot"), &cfg, false).is_none()); - } - - #[test] - fn quote_message_id_takes_priority_over_thread_id() { - use crate::schema::{GatewayReply, ReplyChannel, Content}; - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "evt_123".into(), - platform: "feishu".into(), - channel: ReplyChannel { - id: "chat_123".into(), - thread_id: Some("om_root".into()), - }, - content: Content { - content_type: "text".into(), - text: "hello".into(), - attachments: vec![], - }, - command: None, - request_id: None, - quote_message_id: Some("om_specific".into()), - }; - // quote_message_id should take priority - let reply_target = reply.quote_message_id.as_deref() - .or(reply.channel.thread_id.as_deref()); - assert_eq!(reply_target, Some("om_specific")); - } - - #[test] - fn reply_target_falls_back_to_thread_id_when_no_quote() { - use crate::schema::{GatewayReply, ReplyChannel, Content}; - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "evt_123".into(), - platform: "feishu".into(), - channel: ReplyChannel { - id: "chat_123".into(), - thread_id: Some("om_root".into()), - }, - content: Content { - content_type: "text".into(), - text: "hello".into(), - attachments: vec![], - }, - command: None, - request_id: None, - quote_message_id: None, - }; - let reply_target = reply.quote_message_id.as_deref() - .or(reply.channel.thread_id.as_deref()); - assert_eq!(reply_target, Some("om_root")); - } - - #[test] - fn reply_target_is_none_when_both_absent() { - use crate::schema::{GatewayReply, ReplyChannel, Content}; - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "evt_123".into(), - platform: "feishu".into(), - channel: ReplyChannel { - id: "chat_123".into(), - thread_id: None, - }, - content: Content { - content_type: "text".into(), - text: "hello".into(), - attachments: vec![], - }, - command: None, - request_id: 
None, - quote_message_id: None, - }; - let reply_target = reply.quote_message_id.as_deref() - .or(reply.channel.thread_id.as_deref()); - assert_eq!(reply_target, None); - } - - #[tokio::test] - async fn quote_message_id_fallback_on_reply_failure() { - // Tests the actual handle_reply fallback path: when quote_message_id - // is set and the reply API fails, handle_reply retries as plain send. - let server = MockServer::start().await; - - // Token endpoint - Mock::given(method("POST")) - .and(path("/open-apis/auth/v3/tenant_access_token/internal")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, - "tenant_access_token": "t-test", - "expire": 7200 - }))) - .mount(&server) - .await; - - // Reply API returns 400 (invalid quote_message_id) - Mock::given(method("POST")) - .and(path("/open-apis/im/v1/messages/om_invalid/reply")) - .respond_with(ResponseTemplate::new(400).set_body_string("invalid message_id")) - .expect(1) - .named("reply_api_fail") - .mount(&server) - .await; - - // Plain send endpoint succeeds (fallback path) - Mock::given(method("POST")) - .and(path("/open-apis/im/v1/messages")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, - "data": {"message_id": "om_fallback_ok"} - }))) - .expect(1) - .named("plain_send_fallback") - .mount(&server) - .await; - - let mut config = test_config(); - config.api_base_override = Some(server.uri()); - let adapter = FeishuAdapter::new(config); - - let (event_tx, _rx) = tokio::sync::broadcast::channel(16); - - let reply = crate::schema::GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "evt_123".into(), - platform: "feishu".into(), - channel: crate::schema::ReplyChannel { - id: "oc_chat1".into(), - thread_id: None, - }, - content: crate::schema::Content { - content_type: "text".into(), - text: "hello from fallback test".into(), - attachments: vec![], - }, - command: None, - request_id: None, - quote_message_id: 
Some("om_invalid".into()), - }; - - handle_reply(&reply, &adapter, &event_tx).await; - // wiremock expect(1) on both mocks verifies: - // 1. Reply API was called (and failed) - // 2. Plain send was called (fallback triggered by quote_message_id.is_some() guard) - } - - // --- Edit-cap helpers (F3/F4/F8/F10): no network required --- - - fn fresh_cache() -> Arc> { - Arc::new(std::sync::Mutex::new(EditCountsCache::default())) - } - - #[test] - fn cap_detect_json_code_match() { - // Real-shape Feishu error body: trusted JSON code field == 230072. - let body = r#"{"code":230072,"msg":"The message has reached the number of times it can be edited","data":{}}"#; - assert!(is_feishu_cap_reached_body(body)); - } - - #[test] - fn cap_detect_json_other_code_no_false_positive() { - // JSON parses but code is unrelated; any inner string containing - // "230072" must NOT trigger cap detection. - let body = r#"{"code":99999,"msg":"some other error 230072 in description"}"#; - assert!(!is_feishu_cap_reached_body(body)); - } - - #[test] - fn cap_detect_substring_fallback_for_non_json() { - // Proxy-style HTML / non-JSON body — substring fallback kicks in. - let html = "Error 230072 — number of times it can be edited"; - assert!(is_feishu_cap_reached_body(html)); - - let plain = "upstream error: 230072"; - assert!(is_feishu_cap_reached_body(plain)); - } - - #[test] - fn cap_detect_unrelated_body_returns_false() { - let body = r#"{"code":99991,"msg":"rate limited","data":{}}"#; - assert!(!is_feishu_cap_reached_body(body)); - assert!(!is_feishu_cap_reached_body("plain text without the code")); - assert!(!is_feishu_cap_reached_body("")); - } - - #[test] - fn cap_pre_check_below_threshold_does_not_trip() { - let cache = fresh_cache(); - // Cap is FEISHU_EDIT_CAP (18). 17 increments must stay below. 
- for _ in 0..17 { - increment_edit_count(&cache, "om_msg1"); - } - assert!(!is_edit_cap_reached(&cache, "om_msg1")); - } - - #[test] - fn cap_pre_check_at_threshold_trips() { - let cache = fresh_cache(); - for _ in 0..(FEISHU_EDIT_CAP as usize) { - increment_edit_count(&cache, "om_msg1"); - } - assert!(is_edit_cap_reached(&cache, "om_msg1")); - } - - #[test] - fn mark_edit_cap_short_circuits_pre_check() { - let cache = fresh_cache(); - mark_edit_cap(&cache, "om_msg1"); - // Sentinel u32::MAX >= FEISHU_EDIT_CAP, so pre-check trips immediately. - assert!(is_edit_cap_reached(&cache, "om_msg1")); - } - - #[test] - fn mark_edit_cap_does_not_double_increment() { - let cache = fresh_cache(); - mark_edit_cap(&cache, "om_msg1"); - increment_edit_count(&cache, "om_msg1"); - // Increment must not push past u32::MAX sentinel. - let map = cache.lock().unwrap(); - assert_eq!(map.counts.get("om_msg1").copied(), Some(u32::MAX)); - } - - #[test] - fn eviction_drops_oldest_inserts_not_lowest_count() { - // Pre-fill cache to over capacity, simulating a long-running adapter. - let cache = fresh_cache(); - // First insert message_id "old_*" with high counts so they would - // *survive* a count-ascending eviction (the buggy strategy). They - // must instead be the *first* evicted under FIFO. - let overcap = EDIT_COUNTS_CACHE_MAX + 100; - for i in 0..overcap { - let id = format!("om_msg_{i:05}"); - increment_edit_count(&cache, &id); - } - // Insert a fresh "active stream" id last — its low count would have - // marked it for eviction under count-ascending. With FIFO it must - // survive. - increment_edit_count(&cache, "om_active_recent"); - - let map = cache.lock().unwrap(); - // FIFO eviction: the newest insert must still be present. - assert!( - map.counts.contains_key("om_active_recent"), - "active recent insert was evicted under FIFO — bug regressed" - ); - // FIFO eviction: at least one of the very first inserts must be gone. 
- let some_oldest_evicted = (0..50).any(|i| { - let id = format!("om_msg_{i:05}"); - !map.counts.contains_key(&id) - }); - assert!( - some_oldest_evicted, - "no early-insert key was evicted — FIFO not working" - ); - // Cache size bounded. - assert!( - map.counts.len() <= EDIT_COUNTS_CACHE_MAX, - "cache size {} > max {}", - map.counts.len(), - EDIT_COUNTS_CACHE_MAX - ); - } - - #[test] - fn message_id_validation_accepts_valid_shapes() { - assert!(is_valid_feishu_message_id("om_dc13264520392907fcq2e6kpngacls")); - assert!(is_valid_feishu_message_id("om_abc123")); - assert!(is_valid_feishu_message_id("om_A_B_c_1_2_3")); - } - - #[test] - fn message_id_validation_rejects_path_traversal_and_query() { - // The shape guard is the F8 defence: stop crafted IDs containing URL - // metachars from altering /im/v1/messages/{id} semantics. - assert!(!is_valid_feishu_message_id("../etc/passwd")); - assert!(!is_valid_feishu_message_id("om_abc/extra")); - assert!(!is_valid_feishu_message_id("om_abc?q=1")); - assert!(!is_valid_feishu_message_id("om_abc#frag")); - assert!(!is_valid_feishu_message_id("om_abc%2Fextra")); - assert!(!is_valid_feishu_message_id("")); - assert!(!is_valid_feishu_message_id("om_")); - assert!(!is_valid_feishu_message_id("not_om_prefix")); - // Length cap (defense against pathological inputs). 
- let too_long = format!("om_{}", "a".repeat(200)); - assert!(!is_valid_feishu_message_id(&too_long)); - } - - // --- edit_feishu_message integration (wiremock): proves the cap is detected - // through the HTTP-status gate, including the HTTP-200 + body-code case --- - - async fn mount_token(server: &MockServer) { - Mock::given(method("POST")) - .and(path("/open-apis/auth/v3/tenant_access_token/internal")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, - "msg": "ok", - "tenant_access_token": "t-edit-test", - "expire": 7200 - }))) - .mount(server) - .await; - } - - #[tokio::test] - async fn edit_cap_detected_on_http_200_body_code() { - // Feishu returns the edit-cap rejection as HTTP 200 + {"code":230072}. - // Regression guard for the body-code-first fix: a status-only success - // gate would miscount this as Edited and never trip cap detection. - let server = MockServer::start().await; - mount_token(&server).await; - Mock::given(method("PUT")) - .and(path("/open-apis/im/v1/messages/om_capped")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 230072, - "msg": "The message has reached the number of times it can be edited." - }))) - .mount(&server) - .await; - - let mut config = test_config(); - config.api_base_override = Some(server.uri()); - let adapter = FeishuAdapter::new(config); - - let outcome = edit_feishu_message(&adapter, "om_capped", "hello").await; - assert!( - matches!(outcome, EditOutcome::CapReached), - "HTTP 200 + code 230072 must yield CapReached, got non-cap outcome" - ); - // Sentinel marked → subsequent pre-check short-circuits. - assert!(is_edit_cap_reached(&adapter.edit_counts, "om_capped")); - } - - #[tokio::test] - async fn edit_success_on_http_200_code_zero() { - // HTTP 200 + {"code":0} is a real success → Edited + count incremented. 
- let server = MockServer::start().await; - mount_token(&server).await; - Mock::given(method("PUT")) - .and(path("/open-apis/im/v1/messages/om_ok")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 0, - "msg": "success", - "data": {} - }))) - .mount(&server) - .await; - - let mut config = test_config(); - config.api_base_override = Some(server.uri()); - let adapter = FeishuAdapter::new(config); - - let outcome = edit_feishu_message(&adapter, "om_ok", "hello").await; - assert!( - matches!(outcome, EditOutcome::Edited), - "HTTP 200 + code 0 must yield Edited" - ); - let map = adapter.edit_counts.lock().unwrap(); - assert_eq!(map.counts.get("om_ok").copied(), Some(1)); - } - - #[tokio::test] - async fn edit_failure_on_http_200_other_code() { - // HTTP 200 + non-zero, non-cap code is a genuine failure, not a success. - let server = MockServer::start().await; - mount_token(&server).await; - Mock::given(method("PUT")) - .and(path("/open-apis/im/v1/messages/om_err")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "code": 99991, - "msg": "rate limited" - }))) - .mount(&server) - .await; - - let mut config = test_config(); - config.api_base_override = Some(server.uri()); - let adapter = FeishuAdapter::new(config); - - let outcome = edit_feishu_message(&adapter, "om_err", "hello").await; - assert!( - matches!(outcome, EditOutcome::Failed(_)), - "HTTP 200 + code 99991 must yield Failed, not Edited" - ); - // Failure must NOT increment the edit count. - let map = adapter.edit_counts.lock().unwrap(); - assert_eq!(map.counts.get("om_err").copied(), None); - } - - // --- handle_reply dispatch-seam message_id validation (R3) --- - // These exercise the seam reject path directly (the edit_* tests above call - // edit_feishu_message and bypass the seam). The guard runs before any - // network call, so no mock server is needed. 
- - #[tokio::test] - async fn handle_reply_seam_rejects_invalid_id_with_response() { - let adapter = FeishuAdapter::new(test_config()); - let (event_tx, mut event_rx) = tokio::sync::broadcast::channel(8); - - let reply = crate::schema::GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "draft".into(), // sentinel, not an om_ id → rejected - platform: "feishu".into(), - channel: crate::schema::ReplyChannel { - id: "oc_chat1".into(), - thread_id: None, - }, - content: crate::schema::Content { - content_type: "text".into(), - text: "hello".into(), - attachments: vec![], - }, - command: Some("edit_message".into()), - request_id: Some("req_seam_1".into()), - quote_message_id: None, - }; - - handle_reply(&reply, &adapter, &event_tx).await; - - let raw = event_rx.try_recv().expect("expected a GatewayResponse"); - let resp: serde_json::Value = serde_json::from_str(&raw).unwrap(); - assert_eq!(resp["request_id"], "req_seam_1"); - assert_eq!(resp["success"], false); - assert_eq!(resp["error"], "invalid message_id format"); - } - - #[tokio::test] - async fn handle_reply_seam_rejects_invalid_id_silently_without_request_id() { - let adapter = FeishuAdapter::new(test_config()); - let (event_tx, mut event_rx) = tokio::sync::broadcast::channel(8); - - let reply = crate::schema::GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "om_bad/segment".into(), // URL metachar → rejected - platform: "feishu".into(), - channel: crate::schema::ReplyChannel { - id: "oc_chat1".into(), - thread_id: None, - }, - content: crate::schema::Content { - content_type: "text".into(), - text: "hello".into(), - attachments: vec![], - }, - command: Some("delete_message".into()), - request_id: None, - quote_message_id: None, - }; - - handle_reply(&reply, &adapter, &event_tx).await; - - assert!( - event_rx.try_recv().is_err(), - "no response expected when request_id is absent" - ); - } -} diff --git a/gateway/src/adapters/googlechat.rs 
b/gateway/src/adapters/googlechat.rs deleted file mode 100644 index 93c0c8f8e..000000000 --- a/gateway/src/adapters/googlechat.rs +++ /dev/null @@ -1,2470 +0,0 @@ -use crate::schema::*; -use axum::extract::State; -use axum::http::HeaderMap; -use axum::response::IntoResponse; -use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; -use serde::Deserialize; -use std::sync::Arc; -use std::time::Instant; -use tokio::sync::RwLock; -use tracing::{error, info, warn}; - -pub const GOOGLE_CHAT_API_BASE: &str = "https://chat.googleapis.com/v1"; -const GOOGLE_CHAT_MESSAGE_LIMIT: usize = 4096; - -const IMAGE_MAX_DIMENSION_PX: u32 = 1200; -const IMAGE_JPEG_QUALITY: u8 = 75; -const IMAGE_MAX_DOWNLOAD: u64 = 10 * 1024 * 1024; // 10 MB -const FILE_MAX_DOWNLOAD: u64 = 512 * 1024; // 512 KB -const AUDIO_MAX_DOWNLOAD: u64 = 25 * 1024 * 1024; // 25 MB -/// Per-request timeout for Google Chat Media API downloads. Prevents a hung -/// connection from blocking the spawned download task indefinitely. -const MEDIA_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); -/// Cap on text file attachments per message (matches Discord/Slack). -const TEXT_FILE_COUNT_CAP: usize = 5; -/// Cap on aggregate text file bytes per message (matches Discord/Slack 1 MB). -const TEXT_TOTAL_CAP: u64 = 1024 * 1024; - -// --- Google Chat types --- -// -// Google Chat delivers webhooks in two shapes depending on the App's -// Connection settings in the Cloud Console: -// - HTTP endpoint URL mode: top-level fields (message, user, space, ...) -// - Pub/Sub mode: wrapped under `chat.messagePayload` -// Both are supported via the optional fields below; the handler prefers -// the wrapped form and falls back to top-level when `chat` is absent. 
- -#[derive(Debug, Deserialize)] -pub struct GoogleChatEnvelope { - pub chat: Option, - // HTTP endpoint URL top-level fields (used when `chat` is None) - pub message: Option, - pub user: Option, - pub space: Option, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ChatPayload { - pub user: Option, - pub message_payload: Option, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct MessagePayload { - pub message: Option, - pub space: Option, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct GoogleChatMessage { - pub name: String, - pub text: Option, - pub argument_text: Option, - pub sender: Option, - pub thread: Option, - pub space: Option, - #[serde(default)] - pub attachment: Vec, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct GoogleChatAttachment { - #[allow(dead_code)] - pub name: Option, - pub content_name: Option, - pub content_type: Option, - pub source: Option, - pub attachment_data_ref: Option, - #[allow(dead_code)] - pub drive_data_ref: Option, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct AttachmentDataRef { - pub resource_name: Option, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -#[allow(dead_code)] -pub struct DriveDataRef { - pub drive_file_id: Option, -} - -/// Reference to media that needs async download after webhook parse. 
-#[derive(Debug, Clone)] -pub enum GoogleChatMediaRef { - Image { - resource_name: String, - content_name: String, - }, - File { - resource_name: String, - content_name: String, - }, - Audio { - resource_name: String, - content_name: String, - content_type: String, - }, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct GoogleChatUser { - pub name: String, - pub display_name: String, - #[serde(rename = "type")] - pub user_type: String, -} - -#[derive(Debug, Deserialize)] -pub struct GoogleChatThread { - pub name: String, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct GoogleChatSpace { - pub name: String, - #[serde(rename = "type")] - pub space_type: Option, - // Parsed by serde, not consumed in current code paths. - #[allow(dead_code)] - pub space_type_renamed: Option, -} - -// --- Webhook JWT verification --- - -const GOOGLE_CHAT_ISSUER: &str = "https://accounts.google.com"; -const GOOGLE_CHAT_JWKS_URL: &str = "https://www.googleapis.com/oauth2/v3/certs"; -const GOOGLE_CHAT_SIGNER_EMAIL: &str = "chat@system.gserviceaccount.com"; -const JWKS_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(3600); - -/// Verify the JWT's `email` claim belongs to Google Chat. -/// HTTP endpoint URL webhooks are signed by `chat@system.gserviceaccount.com`. -/// Without this check, any Google-issued ID token would be accepted. 
-fn verify_email_claim(claims: &serde_json::Value) -> Result<(), String> { - let email = claims - .get("email") - .and_then(|v| v.as_str()) - .ok_or("missing email claim")?; - if email != GOOGLE_CHAT_SIGNER_EMAIL { - return Err(format!( - "email claim mismatch: expected {GOOGLE_CHAT_SIGNER_EMAIL}, got {email}" - )); - } - Ok(()) -} - -#[derive(Debug, Clone, Deserialize)] -struct JwkKey { - kid: Option, - n: String, - e: String, - kty: String, -} - -#[derive(Debug, Deserialize)] -struct JwksResponse { - keys: Vec, -} - -pub struct GoogleChatJwtVerifier { - audience: String, - client: reqwest::Client, - jwks_cache: RwLock, Instant)>>, -} - -impl GoogleChatJwtVerifier { - pub fn new(audience: String) -> Self { - Self { - audience, - client: reqwest::Client::new(), - jwks_cache: RwLock::new(None), - } - } - - async fn get_jwks(&self) -> Result, String> { - { - let cache = self.jwks_cache.read().await; - if let Some((ref keys, fetched_at)) = *cache { - if fetched_at.elapsed() < JWKS_CACHE_TTL { - return Ok(keys.clone()); - } - } - } - let jwks: JwksResponse = self - .client - .get(GOOGLE_CHAT_JWKS_URL) - .send() - .await - .map_err(|e| format!("JWKS fetch error: {e}"))? 
- .json() - .await - .map_err(|e| format!("JWKS parse error: {e}"))?; - - let keys = jwks.keys; - *self.jwks_cache.write().await = Some((keys.clone(), Instant::now())); - Ok(keys) - } - - pub async fn verify(&self, auth_header: &str) -> Result<(), String> { - let token = auth_header - .strip_prefix("Bearer ") - .ok_or("missing Bearer prefix")?; - - let header = - jsonwebtoken::decode_header(token).map_err(|e| format!("invalid JWT header: {e}"))?; - let kid = header.kid.ok_or("no kid in JWT header")?; - - let keys = self.get_jwks().await?; - let key = match keys.iter().find(|k| k.kid.as_deref() == Some(&kid)) { - Some(k) => k.clone(), - None => { - // Key rotation: invalidate cache and retry - *self.jwks_cache.write().await = None; - let refreshed = self.get_jwks().await?; - refreshed - .into_iter() - .find(|k| k.kid.as_deref() == Some(&kid)) - .ok_or_else(|| format!("no matching JWK for kid={kid}"))? - } - }; - - if key.kty != "RSA" { - return Err(format!("unsupported key type: {}", key.kty)); - } - - let decoding_key = DecodingKey::from_rsa_components(&key.n, &key.e) - .map_err(|e| format!("RSA key decode error: {e}"))?; - - let mut validation = Validation::new(Algorithm::RS256); - validation.set_audience(&[&self.audience]); - validation.set_issuer(&[GOOGLE_CHAT_ISSUER]); - validation.validate_exp = true; - - let token_data = decode::(token, &decoding_key, &validation) - .map_err(|e| format!("JWT validation failed: {e}"))?; - - verify_email_claim(&token_data.claims)?; - - Ok(()) - } -} - -// --- Adapter (encapsulates all Google Chat state) --- - -pub struct GoogleChatAdapter { - pub token_cache: Option, - pub access_token: Option, - pub jwt_verifier: Option, - pub client: reqwest::Client, - pub api_base: String, -} - -impl GoogleChatAdapter { - pub fn new( - token_cache: Option, - access_token: Option, - jwt_verifier: Option, - ) -> Self { - Self { - token_cache, - access_token, - jwt_verifier, - client: reqwest::Client::new(), - api_base: 
GOOGLE_CHAT_API_BASE.into(), - } - } - - async fn get_token(&self) -> Option { - if let Some(ref cache) = self.token_cache { - match cache.get_token(&self.client).await { - Ok(t) => return Some(t), - Err(e) => { - error!("googlechat token refresh failed: {e}"); - return None; - } - } - } - self.access_token.clone() - } - - async fn edit_message(&self, message_name: &str, text: &str) { - let Some(token) = self.get_token().await else { - tracing::warn!("googlechat edit_message: no token available"); - return; - }; - - let formatted = markdown_to_gchat(text); - let url = format!( - "{}/{}?updateMask=text", - self.api_base, message_name - ); - let body = serde_json::json!({ "text": formatted }); - - match self.client.patch(&url).bearer_auth(&token).json(&body).send().await { - Ok(r) if r.status().is_success() => { - tracing::trace!(message_name = %message_name, "googlechat message edited"); - } - Ok(r) => { - let status = r.status(); - let body = r.text().await.unwrap_or_default(); - error!(status = %status, body = %body, "googlechat edit_message failed"); - } - Err(e) => { - error!(err = %e, "googlechat edit_message request failed"); - } - } - } - - pub async fn handle_reply( - &self, - reply: &GatewayReply, - event_tx: &tokio::sync::broadcast::Sender, - ) { - // Command routing - match reply.command.as_deref() { - Some("add_reaction") | Some("remove_reaction") | Some("create_topic") => return, - Some("edit_message") => { - self.edit_message(&reply.reply_to, &reply.content.text).await; - return; - } - _ => {} - } - - info!( - space = %reply.channel.id, - thread_id = ?reply.channel.thread_id, - "gateway → googlechat" - ); - - let Some(token) = self.get_token().await else { - info!( - text = %reply.content.text, - "googlechat reply (dry-run, no credentials configured)" - ); - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success: false, - 
thread_id: None, - message_id: None, - error: Some("no credentials configured".into()), - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - return; - }; - - let text = &reply.content.text; - let chunks = split_text(text, GOOGLE_CHAT_MESSAGE_LIMIT); - - // Empty message: short-circuit, send failure ack and skip API call - if chunks.is_empty() { - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success: false, - thread_id: None, - message_id: None, - error: Some("empty message".into()), - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - return; - } - - if chunks.len() == 1 { - let result = send_message( - &self.client, - &token, - &reply.channel.id, - reply.channel.thread_id.as_deref(), - text, - &self.api_base, - ) - .await; - - if let Some(ref req_id) = reply.request_id { - let (success, message_id, error) = match result { - Ok(name) => (true, Some(name), None), - Err(e) => (false, None, Some(e)), - }; - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success, - thread_id: None, - message_id, - error, - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - } else { - let mut first_msg_name: Option = None; - let mut first_error: Option = None; - for chunk in chunks { - match send_message( - &self.client, - &token, - &reply.channel.id, - reply.channel.thread_id.as_deref(), - chunk, - &self.api_base, - ) - .await - { - Ok(name) => { - if first_msg_name.is_none() { - first_msg_name = Some(name); - } - } - Err(e) => { - if first_error.is_none() { - first_error = Some(e); - } - } - } - } - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: 
req_id.clone(), - success: first_msg_name.is_some() && first_error.is_none(), - thread_id: None, - message_id: first_msg_name, - error: first_error, - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - } - } -} - -// --- Webhook handler --- - -pub async fn webhook( - State(state): State>, - headers: HeaderMap, - body: axum::body::Bytes, -) -> axum::response::Response { - info!("googlechat webhook received ({} bytes)", body.len()); - - if let Some(ref adapter) = state.google_chat { - if let Some(ref verifier) = adapter.jwt_verifier { - let auth_header = match headers - .get("authorization") - .and_then(|v| v.to_str().ok()) - { - Some(h) => h, - None => { - warn!("googlechat webhook: missing authorization header"); - return (axum::http::StatusCode::UNAUTHORIZED, "unauthorized").into_response(); - } - }; - if let Err(e) = verifier.verify(auth_header).await { - warn!(error = %e, "googlechat webhook JWT verification failed"); - return (axum::http::StatusCode::UNAUTHORIZED, "unauthorized").into_response(); - } - } - } - - let envelope: GoogleChatEnvelope = match serde_json::from_slice(&body) { - Ok(e) => e, - Err(e) => { - let body_str = String::from_utf8_lossy(&body); - error!(body = %body_str, "googlechat webhook parse error: {e}"); - return (axum::http::StatusCode::BAD_REQUEST, "bad request").into_response(); - } - }; - - // Try the Pub/Sub `chat`-wrapped shape first, then fall back to the - // HTTP endpoint URL top-level shape. 
- let (msg_opt, top_user, top_space) = if let Some(chat) = envelope.chat { - let user = chat.user; - let (msg, space) = match chat.message_payload { - Some(p) => (p.message, p.space), - None => (None, None), - }; - (msg, user, space) - } else { - (envelope.message, envelope.user, envelope.space) - }; - - let Some(ref msg) = msg_opt else { - return empty_json_response(); - }; - - let text = msg - .argument_text - .as_deref() - .or(msg.text.as_deref()) - .unwrap_or(""); - - let media_refs = parse_attachments(&msg.attachment); - - // Drop event only if BOTH text and attachments are empty - if text.trim().is_empty() && media_refs.is_empty() { - return empty_json_response(); - } - - let sender = msg.sender.as_ref().or(top_user.as_ref()); - let space = msg.space.as_ref().or(top_space.as_ref()); - - let is_bot = sender.map(|s| s.user_type == "BOT").unwrap_or(false); - if is_bot { - return empty_json_response(); - } - - let sender_id = sender.map(|s| s.name.clone()).unwrap_or_default(); - let display_name = sender - .map(|s| s.display_name.clone()) - .unwrap_or_else(|| "Unknown".into()); - let sender_name = sender_id - .strip_prefix("users/") - .unwrap_or(&sender_id) - .to_string(); - - let space_name = space.map(|s| s.name.clone()).unwrap_or_default(); - let space_type = space - .and_then(|s| s.space_type.clone()) - .unwrap_or_else(|| "ROOM".into()); - - let thread_id = msg.thread.as_ref().map(|t| t.name.clone()); - - let message_id = msg - .name - .rsplit('/') - .next() - .unwrap_or(&msg.name) - .to_string(); - - // No attachments → emit event synchronously and respond 200 - if media_refs.is_empty() { - send_googlechat_event( - &state, - &space_name, - space_type, - thread_id, - &sender_id, - &sender_name, - &display_name, - text, - &message_id, - Vec::new(), - ); - return empty_json_response(); - } - - // Has attachments — spawn background task so the webhook returns 200 within - // Google Chat's 30 s deadline regardless of how long downloads take. 
- let text = text.to_string(); - let state = state.clone(); - let spawn_space = space_name.clone(); - tokio::spawn(async move { - use futures_util::FutureExt; - let result = std::panic::AssertUnwindSafe(async { - let mut downloaded: Vec = Vec::new(); - let mut text_file_count: usize = 0; - let mut text_file_bytes: u64 = 0; - if let Some(ref adapter) = state.google_chat { - if let Some(token) = adapter.get_token().await { - for media_ref in &media_refs { - let attachment = match media_ref { - GoogleChatMediaRef::Image { - resource_name, - content_name, - .. - } => { - download_googlechat_image( - &adapter.client, - &token, - &adapter.api_base, - resource_name, - content_name, - ) - .await - } - GoogleChatMediaRef::File { - resource_name, - content_name, - .. - } => { - if text_file_count >= TEXT_FILE_COUNT_CAP { - warn!(content_name = %content_name, cap = TEXT_FILE_COUNT_CAP, "googlechat text file count cap reached, skipping"); - continue; - } - let remaining = TEXT_TOTAL_CAP.saturating_sub(text_file_bytes); - let att = download_googlechat_file( - &adapter.client, - &token, - &adapter.api_base, - resource_name, - content_name, - remaining, - ) - .await; - let Some(att) = att else { continue }; - text_file_count += 1; - text_file_bytes += att.size; - Some(att) - } - GoogleChatMediaRef::Audio { - resource_name, - content_name, - content_type, - } => { - download_googlechat_audio( - &adapter.client, - &token, - &adapter.api_base, - resource_name, - content_name, - content_type, - ) - .await - } - }; - if let Some(att) = attachment { - downloaded.push(att); - } - } - } else { - warn!("googlechat: no token available for attachment download"); - } - } - - // If text is empty AND every attachment failed to download, drop the event. 
- if text.trim().is_empty() && downloaded.is_empty() { - warn!( - space = %space_name, - "googlechat: empty text + all attachments failed, dropping event" - ); - return; - } - - send_googlechat_event( - &state, - &space_name, - space_type, - thread_id, - &sender_id, - &sender_name, - &display_name, - &text, - &message_id, - downloaded, - ); - }).catch_unwind().await; - if let Err(e) = result { - error!(space = %spawn_space, "googlechat attachment download task panicked: {e:?}"); - } - }); - - empty_json_response() -} - -#[allow(clippy::too_many_arguments)] -fn send_googlechat_event( - state: &Arc, - space_name: &str, - space_type: String, - thread_id: Option, - sender_id: &str, - sender_name: &str, - display_name: &str, - text: &str, - message_id: &str, - attachments: Vec, -) { - let mut gw_event = GatewayEvent::new( - "googlechat", - ChannelInfo { - id: space_name.to_string(), - channel_type: space_type, - thread_id, - }, - SenderInfo { - id: sender_id.to_string(), - name: sender_name.to_string(), - display_name: display_name.to_string(), - is_bot: false, - }, - text, - message_id, - vec![], - ); - gw_event.content.attachments = attachments; - - let attachment_count = gw_event.content.attachments.len(); - let json = match serde_json::to_string(&gw_event) { - Ok(j) => j, - Err(e) => { - error!(error = %e, "googlechat: failed to serialize GatewayEvent"); - return; - } - }; - info!( - space = %space_name, - sender = %sender_name, - attachment_count, - "googlechat → gateway" - ); - let _ = state.event_tx.send(json); -} - -fn empty_json_response() -> axum::response::Response { - use axum::response::IntoResponse; - ( - [(axum::http::header::CONTENT_TYPE, "application/json")], - "{}", - ) - .into_response() -} - -// --- Token cache with JWT auto-refresh --- - -pub struct GoogleChatTokenCache { - token: RwLock>, - sa_email: String, - private_key: String, -} - -const TOKEN_REFRESH_MARGIN_SECS: u64 = 300; - -impl GoogleChatTokenCache { - pub fn new(sa_key_json: &str) -> 
Result { - let key: serde_json::Value = - serde_json::from_str(sa_key_json).map_err(|e| format!("invalid SA key JSON: {e}"))?; - let email = key - .get("client_email") - .and_then(|v| v.as_str()) - .ok_or("missing client_email in SA key")? - .to_string(); - let pkey = key - .get("private_key") - .and_then(|v| v.as_str()) - .ok_or("missing private_key in SA key")? - .to_string(); - Ok(Self { - token: RwLock::new(None), - sa_email: email, - private_key: pkey, - }) - } - - pub async fn get_token(&self, client: &reqwest::Client) -> Result { - { - let guard = self.token.read().await; - if let Some((ref tok, ref ts, ttl)) = *guard { - if ts.elapsed().as_secs() < ttl.saturating_sub(TOKEN_REFRESH_MARGIN_SECS) { - return Ok(tok.clone()); - } - } - } - let mut guard = self.token.write().await; - if let Some((ref tok, ref ts, ttl)) = *guard { - if ts.elapsed().as_secs() < ttl.saturating_sub(TOKEN_REFRESH_MARGIN_SECS) { - return Ok(tok.clone()); - } - } - let (new_token, expire) = self.refresh(client).await?; - *guard = Some((new_token.clone(), Instant::now(), expire)); - info!("googlechat access token refreshed (expires in {expire}s)"); - Ok(new_token) - } - - async fn refresh(&self, client: &reqwest::Client) -> Result<(String, u64), String> { - let jwt = self.build_jwt().map_err(|e| format!("JWT build error: {e}"))?; - let resp = client - .post("https://oauth2.googleapis.com/token") - .form(&[ - ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), - ("assertion", &jwt), - ]) - .send() - .await - .map_err(|e| format!("token exchange request failed: {e}"))?; - - let body: serde_json::Value = resp - .json() - .await - .map_err(|e| format!("token exchange parse failed: {e}"))?; - - let token = body - .get("access_token") - .and_then(|v| v.as_str()) - .ok_or_else(|| { - let err = body - .get("error_description") - .and_then(|v| v.as_str()) - .unwrap_or("unknown error"); - format!("token exchange failed: {err}") - })? 
- .to_string(); - - let expires_in = body - .get("expires_in") - .and_then(|v| v.as_u64()) - .unwrap_or(3600); - - Ok((token, expires_in)) - } - - fn build_jwt(&self) -> Result { - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map_err(|e| e.to_string())? - .as_secs(); - - let claims = serde_json::json!({ - "iss": self.sa_email, - "scope": "https://www.googleapis.com/auth/chat.bot", - "aud": "https://oauth2.googleapis.com/token", - "iat": now, - "exp": now + 3600, - }); - - let key = jsonwebtoken::EncodingKey::from_rsa_pem(self.private_key.as_bytes()) - .map_err(|e| format!("RSA key parse error: {e}"))?; - let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256); - jsonwebtoken::encode(&header, &claims, &key) - .map_err(|e| format!("JWT encode error: {e}")) - } -} - -/// Convert markdown to Google Chat native formatting. -/// -/// Called by both `send_message` and `edit_message`. Assumes the caller passes -/// **raw markdown** — passing already-converted text would double-convert -/// (e.g. `*bold*` from a previous pass would be re-parsed as `*italic*`). -/// OAB core is expected to always emit raw markdown for both initial replies -/// and streaming edits. 
-fn markdown_to_gchat(text: &str) -> String { - let mut result = String::with_capacity(text.len()); - let lines: Vec<&str> = text.split('\n').collect(); - let mut i = 0; - while i < lines.len() { - let line = lines[i]; - // Detect fenced code block — pass through unchanged - if line.trim_start().starts_with("```") { - result.push_str(line); - result.push('\n'); - i += 1; - while i < lines.len() { - result.push_str(lines[i]); - if lines[i].trim_start().starts_with("```") { - i += 1; - if i < lines.len() { - result.push('\n'); - } - break; - } - result.push('\n'); - i += 1; - } - continue; - } - // Heading → bold - let converted = if let Some(heading) = line - .strip_prefix("### ") - .or_else(|| line.strip_prefix("## ")) - .or_else(|| line.strip_prefix("# ")) - { - format!("*{}*", heading.trim()) - } else { - convert_inline(line) - }; - result.push_str(&converted); - i += 1; - if i < lines.len() { - result.push('\n'); - } - } - result -} - -// TODO(perf): allocates Vec per line. Acceptable at current scale, -// but on hot streaming paths with many edit_message updates this could be -// rewritten with byte-level iteration over &str. 
-fn convert_inline(line: &str) -> String { - let mut out = String::with_capacity(line.len()); - let chars: Vec = line.chars().collect(); - let mut i = 0; - while i < chars.len() { - // Inline code — pass through - if chars[i] == '`' { - out.push('`'); - i += 1; - while i < chars.len() && chars[i] != '`' { - out.push(chars[i]); - i += 1; - } - if i < chars.len() { - out.push('`'); - i += 1; - } - continue; - } - // Markdown link: [text](url) - if chars[i] == '[' { - if let Some((link_text, url, end)) = parse_md_link(&chars, i) { - let converted_text = convert_inline(&link_text); - out.push_str(&format!("<{}|{}>", url, converted_text)); - i = end; - continue; - } - } - // Bold: **text** → *text* - if chars[i] == '*' && i + 1 < chars.len() && chars[i + 1] == '*' { - if let Some(end) = find_closing(&chars, i + 2, &['*', '*']) { - out.push('*'); - let inner: String = chars[i + 2..end].iter().collect(); - out.push_str(&convert_inline(&inner)); - out.push('*'); - i = end + 2; - continue; - } - } - // Bold: __text__ → *text* - if chars[i] == '_' && i + 1 < chars.len() && chars[i + 1] == '_' { - if let Some(end) = find_closing(&chars, i + 2, &['_', '_']) { - out.push('*'); - let inner: String = chars[i + 2..end].iter().collect(); - out.push_str(&convert_inline(&inner)); - out.push('*'); - i = end + 2; - continue; - } - } - // Strikethrough: ~~text~~ → ~text~ - if chars[i] == '~' && i + 1 < chars.len() && chars[i + 1] == '~' { - if let Some(end) = find_closing(&chars, i + 2, &['~', '~']) { - out.push('~'); - let inner: String = chars[i + 2..end].iter().collect(); - out.push_str(&convert_inline(&inner)); - out.push('~'); - i = end + 2; - continue; - } - } - // Italic: *text* → _text_ (single asterisk, not part of **bold**) - // Must come AFTER the **bold** check above. Requires non-asterisk - // immediately after opening * and before closing *. 
- if chars[i] == '*' - && i + 1 < chars.len() - && chars[i + 1] != '*' - && !chars[i + 1].is_whitespace() - { - if let Some(end) = find_single(&chars, i + 1, '*') { - if end > i + 1 && !chars[end - 1].is_whitespace() { - out.push('_'); - let inner: String = chars[i + 1..end].iter().collect(); - out.push_str(&convert_inline(&inner)); - out.push('_'); - i = end + 1; - continue; - } - } - } - out.push(chars[i]); - i += 1; - } - out -} - -fn find_single(chars: &[char], start: usize, target: char) -> Option { - let mut i = start; - while i < chars.len() { - if chars[i] == target { - return Some(i); - } - i += 1; - } - None -} - -fn parse_md_link(chars: &[char], start: usize) -> Option<(String, String, usize)> { - let mut i = start + 1; - let mut depth = 1; - let text_start = i; - while i < chars.len() && depth > 0 { - if chars[i] == '[' { - depth += 1; - } else if chars[i] == ']' { - depth -= 1; - } - if depth > 0 { - i += 1; - } - } - if depth != 0 { - return None; - } - let text: String = chars[text_start..i].iter().collect(); - i += 1; // skip ']' - if i >= chars.len() || chars[i] != '(' { - return None; - } - i += 1; // skip '(' - let url_start = i; - let mut paren_depth = 1; - while i < chars.len() && paren_depth > 0 { - if chars[i] == '(' { - paren_depth += 1; - } else if chars[i] == ')' { - paren_depth -= 1; - } - if paren_depth > 0 { - i += 1; - } - } - if paren_depth != 0 { - return None; - } - let url: String = chars[url_start..i].iter().collect(); - Some((text, url, i + 1)) -} - -fn find_closing(chars: &[char], start: usize, pattern: &[char]) -> Option { - if pattern.len() < 2 { - return None; - } - let mut i = start; - while i + 1 < chars.len() { - if chars[i] == pattern[0] && chars[i + 1] == pattern[1] { - return Some(i); - } - i += 1; - } - None -} - -async fn send_message( - client: &reqwest::Client, - token: &str, - space: &str, - thread_id: Option<&str>, - text: &str, - api_base: &str, -) -> Result { - let mut url = format!("{}/{}/messages", api_base, 
space); - - let formatted = markdown_to_gchat(text); - let mut body = serde_json::json!({ - "text": formatted, - }); - - if let Some(thread_id) = thread_id { - body["thread"] = serde_json::json!({ - "name": thread_id, - }); - url.push_str("?messageReplyOption=REPLY_MESSAGE_FALLBACK_TO_NEW_THREAD"); - } - - let resp = client - .post(&url) - .bearer_auth(token) - .json(&body) - .send() - .await; - - match resp { - Ok(r) if r.status().is_success() => { - let body = r.text().await.unwrap_or_default(); - let parsed: serde_json::Value = serde_json::from_str(&body).unwrap_or_default(); - parsed - .get("name") - .and_then(|v| v.as_str()) - .map(String::from) - .ok_or_else(|| "missing message name in response".into()) - } - Ok(r) => { - let status = r.status(); - let body = r.text().await.unwrap_or_default(); - error!(status = %status, body = %body, "googlechat send error"); - Err(format!("send failed: {} {}", status, body)) - } - Err(e) => { - error!("googlechat send error: {e}"); - Err(format!("request error: {e}")) - } - } -} - -fn split_text(text: &str, limit: usize) -> Vec<&str> { - let mut chunks = Vec::new(); - let mut start = 0; - while start < text.len() { - if start + limit >= text.len() { - chunks.push(&text[start..]); - break; - } - let mut end = start + limit; - while !text.is_char_boundary(end) { - end -= 1; - } - let mut search_start = if end > start + 200 { end - 200 } else { start }; - while search_start < end && !text.is_char_boundary(search_start) { - search_start += 1; - } - let break_at = text[search_start..end] - .rfind('\n') - .or_else(|| text[search_start..end].rfind(' ')) - .map(|pos| search_start + pos + 1) - .unwrap_or(end); - chunks.push(&text[start..break_at]); - start = break_at; - } - chunks -} - -// --- Attachment parsing & download --- - -/// Whitelist of text-like file extensions for `download_googlechat_file`. 
-const TEXT_EXTS: &[&str] = &[ - "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", - "rs", "py", "js", "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", - "rb", "sh", "bash", "sql", "html", "css", "ini", "cfg", "conf", -]; - -/// Parse Google Chat attachment array into media references for async download. -/// -/// Skips Drive-sourced attachments (different download API), and unknown -/// content types. Branches on `contentType` prefix to bucket into image / -/// audio / file. -fn parse_attachments(attachments: &[GoogleChatAttachment]) -> Vec { - let mut refs = Vec::new(); - for att in attachments { - // Only handle UPLOADED_CONTENT (Drive needs separate Drive API call) - if att.source.as_deref() != Some("UPLOADED_CONTENT") { - continue; - } - let resource_name = match att - .attachment_data_ref - .as_ref() - .and_then(|d| d.resource_name.clone()) - { - Some(rn) => rn, - None => continue, - }; - let content_type = att.content_type.clone().unwrap_or_default(); - let content_name = att.content_name.clone().unwrap_or_else(|| "file".into()); - - if content_type.starts_with("image/") { - refs.push(GoogleChatMediaRef::Image { - resource_name, - content_name, - }); - } else if content_type.starts_with("audio/") { - refs.push(GoogleChatMediaRef::Audio { - resource_name, - content_name, - content_type, - }); - } else if content_type.starts_with("video/") { - info!(content_name = %content_name, content_type = %content_type, "googlechat: video attachment skipped (not yet supported)"); - } else { - refs.push(GoogleChatMediaRef::File { - resource_name, - content_name, - }); - } - } - refs -} - -/// Resize image so longest side ≤ 1200px, then encode as JPEG. -/// GIFs are passed through unchanged to preserve animation. 
-fn resize_and_compress(raw: &[u8]) -> Result<(Vec, String), image::ImageError> { - use image::ImageReader; - use std::io::Cursor; - - let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?; - let format = reader.format(); - if format == Some(image::ImageFormat::Gif) { - return Ok((raw.to_vec(), "image/gif".to_string())); - } - let img = reader.decode()?; - let (w, h) = (img.width(), img.height()); - let img = if w > IMAGE_MAX_DIMENSION_PX || h > IMAGE_MAX_DIMENSION_PX { - let max_side = std::cmp::max(w, h); - let ratio = f64::from(IMAGE_MAX_DIMENSION_PX) / f64::from(max_side); - let new_w = (f64::from(w) * ratio) as u32; - let new_h = (f64::from(h) * ratio) as u32; - img.resize(new_w, new_h, image::imageops::FilterType::Lanczos3) - } else { - img - }; - let mut buf = Cursor::new(Vec::new()); - let encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, IMAGE_JPEG_QUALITY); - img.write_with_encoder(encoder)?; - Ok((buf.into_inner(), "image/jpeg".to_string())) -} - -/// Build the Media API URL for a given resource_name. -/// Google Chat Media API uses `{+resourceName}` (RFC 6570 reserved expansion), -/// so `/` must stay literal while other special chars are percent-encoded. -fn media_url(api_base: &str, resource_name: &str) -> String { - let encoded: String = resource_name - .bytes() - .map(|b| match b { - b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' | b'/' => { - (b as char).to_string() - } - _ => format!("%{:02X}", b), - }) - .collect(); - format!("{}/media/{}?alt=media", api_base, encoded) -} - -/// Download an image attachment via Google Chat Media API → resize/compress → base64. 
-pub async fn download_googlechat_image( - client: &reqwest::Client, - token: &str, - api_base: &str, - resource_name: &str, - content_name: &str, -) -> Option { - let url = media_url(api_base, resource_name); - let resp = match client.get(&url).bearer_auth(token).timeout(MEDIA_REQUEST_TIMEOUT).send().await { - Ok(r) => r, - Err(e) => { - warn!(content_name, error = %e, "googlechat image download failed"); - return None; - } - }; - if !resp.status().is_success() { - warn!(content_name, status = %resp.status(), "googlechat image download failed"); - return None; - } - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > IMAGE_MAX_DOWNLOAD { - warn!(content_name, size, "googlechat image Content-Length exceeds 10MB limit"); - return None; - } - } - } - let bytes = resp.bytes().await.ok()?; - if bytes.len() as u64 > IMAGE_MAX_DOWNLOAD { - warn!(content_name, size = bytes.len(), "googlechat image exceeds 10MB limit"); - return None; - } - let (compressed, mime) = match resize_and_compress(&bytes) { - Ok(v) => v, - Err(e) => { - warn!(content_name, error = %e, "googlechat image resize failed"); - return None; - } - }; - let path = crate::store::store_media(&compressed).await?; - Some(crate::schema::Attachment { - attachment_type: "image".into(), - filename: content_name.to_string(), - mime_type: mime, - data: String::new(), - size: compressed.len() as u64, - path: Some(path), - }) -} - -/// Download a text-like file via Google Chat Media API → base64. -/// Non-text extensions are skipped to avoid sending binary garbage to the model. 
-pub async fn download_googlechat_file( - client: &reqwest::Client, - token: &str, - api_base: &str, - resource_name: &str, - content_name: &str, - remaining_budget: u64, -) -> Option { - let ext = content_name.rsplit('.').next().unwrap_or("").to_lowercase(); - if !TEXT_EXTS.contains(&ext.as_str()) { - tracing::debug!(content_name, "skipping non-text googlechat file attachment"); - return None; - } - let max_size = FILE_MAX_DOWNLOAD.min(remaining_budget); - let url = media_url(api_base, resource_name); - let resp = match client.get(&url).bearer_auth(token).timeout(MEDIA_REQUEST_TIMEOUT).send().await { - Ok(r) => r, - Err(e) => { - warn!(content_name, error = %e, "googlechat file download failed"); - return None; - } - }; - if !resp.status().is_success() { - warn!(content_name, status = %resp.status(), "googlechat file download failed"); - return None; - } - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > max_size { - warn!(content_name, size, limit = max_size, "googlechat file Content-Length exceeds limit"); - return None; - } - } - } - let bytes = resp.bytes().await.ok()?; - if bytes.len() as u64 > max_size { - warn!(content_name, size = bytes.len(), limit = max_size, "googlechat file exceeds size limit"); - return None; - } - let path = crate::store::store_media(&bytes).await?; - Some(crate::schema::Attachment { - attachment_type: "text_file".into(), - filename: content_name.to_string(), - mime_type: "text/plain".into(), - data: String::new(), - size: bytes.len() as u64, - path: Some(path), - }) -} - -/// Download an audio attachment as-is (no resize/transcode) → filesystem store. -/// Core's STT pipeline (when available) consumes this as `audio` attachment_type. 
-pub async fn download_googlechat_audio( - client: &reqwest::Client, - token: &str, - api_base: &str, - resource_name: &str, - content_name: &str, - content_type: &str, -) -> Option { - let url = media_url(api_base, resource_name); - let resp = match client.get(&url).bearer_auth(token).timeout(MEDIA_REQUEST_TIMEOUT).send().await { - Ok(r) => r, - Err(e) => { - warn!(content_name, error = %e, "googlechat audio download failed"); - return None; - } - }; - if !resp.status().is_success() { - warn!(content_name, status = %resp.status(), "googlechat audio download failed"); - return None; - } - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > AUDIO_MAX_DOWNLOAD { - warn!(content_name, size, "googlechat audio Content-Length exceeds 25MB limit"); - return None; - } - } - } - let bytes = resp.bytes().await.ok()?; - if bytes.len() as u64 > AUDIO_MAX_DOWNLOAD { - warn!(content_name, size = bytes.len(), "googlechat audio exceeds 25MB limit"); - return None; - } - let path = crate::store::store_media(&bytes).await?; - Some(crate::schema::Attachment { - attachment_type: "audio".into(), - filename: content_name.to_string(), - mime_type: content_type.to_string(), - data: String::new(), - size: bytes.len() as u64, - path: Some(path), - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - // --- Webhook parsing tests --- - - fn make_envelope( - text: &str, - argument_text: Option<&str>, - sender_type: &str, - space_type: &str, - thread_name: Option<&str>, - ) -> String { - let arg_field = argument_text - .map(|a| format!(r#""argumentText": "{a}","#)) - .unwrap_or_default(); - let thread_field = thread_name - .map(|t| format!(r#","thread": {{"name": "{t}"}}"#)) - .unwrap_or_default(); - format!( - r#"{{ - "chat": {{ - "user": {{ - "name": "users/111", - "displayName": "Test", - "type": "{sender_type}" - }}, - "messagePayload": {{ - "message": {{ - "name": "spaces/SP/messages/msg1", - 
"text": "{text}", - {arg_field} - "sender": {{ - "name": "users/111", - "displayName": "Test", - "type": "{sender_type}" - }}, - "space": {{ - "name": "spaces/SP", - "type": "{space_type}" - }} - {thread_field} - }}, - "space": {{ - "name": "spaces/SP", - "type": "{space_type}" - }} - }} - }} - }}"# - ) - } - - #[test] - fn parse_dm_message() { - let json = make_envelope("hello", None, "HUMAN", "DM", None); - let envelope: GoogleChatEnvelope = serde_json::from_str(&json).unwrap(); - let chat = envelope.chat.unwrap(); - let msg = chat.message_payload.unwrap().message.unwrap(); - assert_eq!(msg.text.as_deref(), Some("hello")); - assert_eq!(msg.sender.unwrap().user_type, "HUMAN"); - } - - #[test] - fn parse_space_message_with_thread() { - let json = make_envelope( - "@Bot hi", - Some("hi"), - "HUMAN", - "ROOM", - Some("spaces/SP/threads/t1"), - ); - let envelope: GoogleChatEnvelope = serde_json::from_str(&json).unwrap(); - let chat = envelope.chat.unwrap(); - let payload = chat.message_payload.unwrap(); - let msg = payload.message.as_ref().unwrap(); - assert_eq!(msg.argument_text.as_deref(), Some("hi")); - assert_eq!(msg.thread.as_ref().unwrap().name, "spaces/SP/threads/t1"); - assert_eq!(payload.space.as_ref().unwrap().space_type.as_deref(), Some("ROOM")); - } - - #[test] - fn parse_bot_message_detected() { - let json = make_envelope("bot says hi", None, "BOT", "DM", None); - let envelope: GoogleChatEnvelope = serde_json::from_str(&json).unwrap(); - let chat = envelope.chat.unwrap(); - let user = chat.user.unwrap(); - assert_eq!(user.user_type, "BOT"); - } - - #[test] - fn parse_missing_chat_field() { - let json = r#"{"type": "ADDED_TO_SPACE"}"#; - let envelope: GoogleChatEnvelope = serde_json::from_str(json).unwrap(); - assert!(envelope.chat.is_none()); - } - - #[test] - fn parse_missing_message_payload() { - let json = r#"{"chat": {"user": {"name": "u/1", "displayName": "X", "type": "HUMAN"}}}"#; - let envelope: GoogleChatEnvelope = 
serde_json::from_str(json).unwrap(); - assert!(envelope.chat.unwrap().message_payload.is_none()); - } - - #[test] - fn parse_invalid_json() { - let result: Result = serde_json::from_str("not json"); - assert!(result.is_err()); - } - - #[test] - fn argument_text_preferred_over_text() { - let json = make_envelope("@Bot explain", Some("explain"), "HUMAN", "ROOM", None); - let envelope: GoogleChatEnvelope = serde_json::from_str(&json).unwrap(); - let msg = envelope - .chat - .unwrap() - .message_payload - .unwrap() - .message - .unwrap(); - let text = msg - .argument_text - .as_deref() - .or(msg.text.as_deref()) - .unwrap(); - assert_eq!(text, "explain"); - } - - #[test] - fn sender_name_strips_users_prefix() { - let sender_id = "users/123456"; - let name = sender_id.strip_prefix("users/").unwrap_or(sender_id); - assert_eq!(name, "123456"); - } - - #[test] - fn message_id_extracts_last_segment() { - let msg_name = "spaces/SP/messages/abc123"; - let id = msg_name.rsplit('/').next().unwrap_or(msg_name); - assert_eq!(id, "abc123"); - } - - // --- split_text tests --- - - #[test] - fn split_text_short() { - let chunks = split_text("hello", 100); - assert_eq!(chunks, vec!["hello"]); - } - - #[test] - fn split_text_exact_limit() { - let text = "a".repeat(100); - let chunks = split_text(&text, 100); - assert_eq!(chunks.len(), 1); - } - - #[test] - fn split_text_over_limit() { - let text = "a".repeat(150); - let chunks = split_text(&text, 100); - assert_eq!(chunks.len(), 2); - let reassembled: String = chunks.concat(); - assert_eq!(reassembled, text); - } - - #[test] - fn split_text_breaks_at_newline() { - let text = format!("{}\n{}", "a".repeat(50), "b".repeat(50)); - let chunks = split_text(&text, 60); - assert_eq!(chunks.len(), 2); - assert!(chunks[0].ends_with('\n')); - } - - #[test] - fn split_text_breaks_at_space() { - let text = format!("{} {}", "a".repeat(50), "b".repeat(50)); - let chunks = split_text(&text, 60); - assert_eq!(chunks.len(), 2); - } - - #[test] - fn 
split_text_chinese_utf8_safe() { - let text = "你好世界測試谷歌聊天中文消息分割安全驗證完成"; - let chunks = split_text(text, 10); - assert!(chunks.len() > 1); - let reassembled: String = chunks.concat(); - assert_eq!(reassembled, text); - } - - #[test] - fn split_text_search_start_char_boundary() { - let text: String = "谷歌".repeat(150); // 300 chars, 900 bytes - let chunks = split_text(&text, 500); - assert!(chunks.len() >= 2); - let reassembled: String = chunks.concat(); - assert_eq!(reassembled, text); - } - - #[test] - fn split_text_empty() { - let chunks = split_text("", 100); - assert!(chunks.is_empty()); - } - - // --- Token cache tests --- - - #[test] - fn token_cache_rejects_invalid_json() { - let result = GoogleChatTokenCache::new("not json"); - assert!(result.is_err()); - } - - #[test] - fn token_cache_rejects_missing_fields() { - match GoogleChatTokenCache::new(r#"{"type": "service_account"}"#) { - Err(e) => assert!(e.contains("client_email"), "unexpected error: {e}"), - Ok(_) => panic!("expected error for missing client_email"), - } - } - - #[test] - fn token_cache_accepts_valid_sa_key() { - let key = r#"{ - "type": "service_account", - "client_email": "test@test.iam.gserviceaccount.com", - "private_key": "-----BEGIN RSA PRIVATE KEY-----\nMIIBogIBAAJBALvRE+oCMiEhtfO5ufaVc9wGPUMgPGxmVFiMPC/NMxmCSiMGNO9h\nCOyByeF78QHp4gOW/lgVU8MJkv33hVMbOr0CAwEAAQJAD2k/cFR5MIkw1PFcm98K\n9MqYKGpJCmGBjFY0ek0FHoC14d/hpAGaoWMjNaAyjU/IbGv1fj8C5MfFRal0fV/L\nAQIhAP0T6FPJMm3O4bM18kMHnOP2+Y5kxMpVxCCjkVNH7D09AiEAvXEQJYwR+PFs\njDDhEm4VPmk+lKJoQlopj8TN5gQV8DECIBcXbU+LPWx4H+qRElhCB1B5a9mYmpY\nV6LFPnvSfHqNAiEAiNj5+A6E7WJ50il+5NG5yn7gXh8vNxdCYIw5qx6C2bECIBmW\nVGVRhSmNsmDMJFsGIdKJsnEXpizIVHtfpXsS4j9X\n-----END RSA PRIVATE KEY-----\n" - }"#; - let result = GoogleChatTokenCache::new(key); - assert!(result.is_ok()); - } - - // --- Bot filtering logic test --- - - #[test] - fn bot_user_type_detected() { - let json = make_envelope("hello", None, "BOT", "DM", None); - let envelope: GoogleChatEnvelope = 
serde_json::from_str(&json).unwrap(); - let chat = envelope.chat.unwrap(); - let sender = chat - .message_payload - .as_ref() - .and_then(|p| p.message.as_ref()) - .and_then(|m| m.sender.as_ref()) - .or(chat.user.as_ref()); - let is_bot = sender.map(|s| s.user_type == "BOT").unwrap_or(false); - assert!(is_bot); - } - - // --- JWT verifier tests --- - - #[tokio::test] - async fn jwt_rejects_missing_bearer_prefix() { - let verifier = GoogleChatJwtVerifier::new("123456".into()); - let result = verifier.verify("NotBearer xyz").await; - assert!(result.is_err()); - assert!(result.unwrap_err().contains("Bearer")); - } - - #[tokio::test] - async fn jwt_rejects_invalid_token() { - let verifier = GoogleChatJwtVerifier::new("123456".into()); - let result = verifier.verify("Bearer not.a.valid.jwt").await; - assert!(result.is_err()); - } - - #[tokio::test] - async fn jwt_rejects_empty_bearer() { - let verifier = GoogleChatJwtVerifier::new("123456".into()); - let result = verifier.verify("Bearer ").await; - assert!(result.is_err()); - } - - #[test] - fn email_claim_accepts_chat_system_account() { - let claims = serde_json::json!({"email": "chat@system.gserviceaccount.com"}); - assert!(verify_email_claim(&claims).is_ok()); - } - - #[test] - fn email_claim_rejects_other_google_email() { - let claims = serde_json::json!({"email": "attacker@example.iam.gserviceaccount.com"}); - let err = verify_email_claim(&claims).unwrap_err(); - assert!(err.contains("email claim mismatch")); - } - - #[test] - fn email_claim_rejects_unrelated_gserviceaccount() { - let claims = serde_json::json!({"email": "my-sa@my-project.iam.gserviceaccount.com"}); - assert!(verify_email_claim(&claims).is_err()); - } - - #[test] - fn email_claim_rejects_missing_email() { - let claims = serde_json::json!({"sub": "123", "iss": "accounts.google.com"}); - let err = verify_email_claim(&claims).unwrap_err(); - assert!(err.contains("missing email")); - } - - #[test] - fn email_claim_rejects_non_string_email() { - let 
claims = serde_json::json!({"email": 12345}); - assert!(verify_email_claim(&claims).is_err()); - } - - #[test] - fn human_user_type_not_filtered() { - let json = make_envelope("hello", None, "HUMAN", "DM", None); - let envelope: GoogleChatEnvelope = serde_json::from_str(&json).unwrap(); - let chat = envelope.chat.unwrap(); - let sender = chat - .message_payload - .as_ref() - .and_then(|p| p.message.as_ref()) - .and_then(|m| m.sender.as_ref()) - .or(chat.user.as_ref()); - let is_bot = sender.map(|s| s.user_type == "BOT").unwrap_or(false); - assert!(!is_bot); - } - - // --- markdown_to_gchat tests --- - - #[test] - fn markdown_bold_double_asterisk() { - assert_eq!(markdown_to_gchat("hello **world**"), "hello *world*"); - } - - #[test] - fn markdown_bold_underscore() { - assert_eq!(markdown_to_gchat("hello __world__"), "hello *world*"); - } - - #[test] - fn markdown_link_conversion() { - assert_eq!( - markdown_to_gchat("see [docs](https://example.com) here"), - "see here" - ); - } - - #[test] - fn markdown_heading_to_bold() { - assert_eq!(markdown_to_gchat("# Title\ntext"), "*Title*\ntext"); - assert_eq!(markdown_to_gchat("## Sub\ntext"), "*Sub*\ntext"); - assert_eq!(markdown_to_gchat("### Deep\ntext"), "*Deep*\ntext"); - } - - #[test] - fn markdown_code_block_preserved() { - let input = "before\n```rust\nlet **x** = 1;\n```\nafter **bold**"; - let output = markdown_to_gchat(input); - assert!(output.contains("let **x** = 1;")); - assert!(output.contains("after *bold*")); - } - - #[test] - fn markdown_inline_code_preserved() { - assert_eq!( - markdown_to_gchat("use `**not bold**` here **bold**"), - "use `**not bold**` here *bold*" - ); - } - - #[test] - fn markdown_strikethrough() { - assert_eq!(markdown_to_gchat("~~deleted~~"), "~deleted~"); - assert_eq!( - markdown_to_gchat("keep ~~this~~ and ~~that~~"), - "keep ~this~ and ~that~" - ); - } - - #[test] - fn markdown_italic_asterisk() { - assert_eq!(markdown_to_gchat("*italic*"), "_italic_"); - assert_eq!( - 
markdown_to_gchat("plain *one* and *two*"), - "plain _one_ and _two_" - ); - } - - #[test] - fn markdown_italic_does_not_match_bold() { - assert_eq!(markdown_to_gchat("**bold**"), "*bold*"); - assert_eq!( - markdown_to_gchat("**bold** and *italic*"), - "*bold* and _italic_" - ); - } - - #[test] - fn markdown_italic_underscore_passes_through() { - // Google Chat italic is _text_, single underscore should pass through - assert_eq!(markdown_to_gchat("_italic_"), "_italic_"); - } - - #[test] - fn markdown_italic_no_match_when_unbalanced() { - // Lone asterisks (no closing) should pass through - assert_eq!(markdown_to_gchat("a * b"), "a * b"); - // Whitespace adjacent to asterisks should not match (avoid matching multiplication) - assert_eq!(markdown_to_gchat("2 * 3 * 4"), "2 * 3 * 4"); - } - - #[test] - fn markdown_empty_string() { - assert_eq!(markdown_to_gchat(""), ""); - } - - #[test] - fn markdown_no_conversion_needed() { - assert_eq!(markdown_to_gchat("plain text"), "plain text"); - } - - #[test] - fn markdown_multiple_links() { - assert_eq!( - markdown_to_gchat("[a](http://a.com) and [b](http://b.com)"), - " and " - ); - } - - #[test] - fn markdown_nested_bold_in_link_text() { - assert_eq!( - markdown_to_gchat("[**bold link**](http://x.com)"), - "" - ); - } - - #[test] - fn parse_send_message_response_name() { - let resp_json = r#"{"name": "spaces/SP1/messages/msg123", "text": "hello"}"#; - let parsed: serde_json::Value = serde_json::from_str(resp_json).unwrap(); - let name = parsed.get("name").and_then(|v| v.as_str()); - assert_eq!(name, Some("spaces/SP1/messages/msg123")); - } - - #[tokio::test] - async fn handle_reply_sends_gateway_response_success() { - use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - let mock_server = MockServer::start().await; - Mock::given(method("POST")) - .and(path_regex("/spaces/.*/messages")) - .respond_with(ResponseTemplate::new(200).set_body_json( - serde_json::json!({"name": 
"spaces/TEST/messages/msg_abc"}), - )) - .mount(&mock_server) - .await; - - let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); - let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); - adapter.api_base = mock_server.uri(); - - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "orig_msg".into(), - platform: "googlechat".into(), - channel: ReplyChannel { - id: "spaces/TEST".into(), - thread_id: None, - }, - content: Content { - content_type: "text".into(), - attachments: Vec::new(), - text: "hello".into(), - }, - command: None, - request_id: Some("req_123".into()), - quote_message_id: None, - }; - - adapter.handle_reply(&reply, &event_tx).await; - - let received = event_rx.try_recv(); - assert!(received.is_ok(), "expected GatewayResponse on event_tx"); - let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); - assert_eq!(resp.request_id, "req_123"); - assert!(resp.success); - assert_eq!(resp.message_id, Some("spaces/TEST/messages/msg_abc".into())); - } - - #[tokio::test] - async fn handle_reply_sends_failure_response_on_api_error() { - use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - let mock_server = MockServer::start().await; - Mock::given(method("POST")) - .and(path_regex("/spaces/.*/messages")) - .respond_with(ResponseTemplate::new(500)) - .mount(&mock_server) - .await; - - let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); - let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); - adapter.api_base = mock_server.uri(); - - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "orig_msg".into(), - platform: "googlechat".into(), - channel: ReplyChannel { - id: "spaces/TEST".into(), - thread_id: None, - }, - content: Content { - content_type: "text".into(), - attachments: Vec::new(), - text: "hello".into(), - }, - command: None, - 
request_id: Some("req_fail".into()), - quote_message_id: None, - }; - - adapter.handle_reply(&reply, &event_tx).await; - - let received = event_rx.try_recv(); - assert!(received.is_ok(), "expected GatewayResponse on event_tx"); - let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); - assert_eq!(resp.request_id, "req_fail"); - assert!(!resp.success); - assert!(resp.message_id.is_none()); - let err = resp.error.expect("error should be set on send failure"); - assert!(err.contains("500"), "error should include status code, got: {}", err); - } - - #[tokio::test] - async fn handle_reply_empty_message_short_circuits() { - use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - let mock_server = MockServer::start().await; - // Mount a mock that would fail the test if called - Mock::given(method("POST")) - .and(path_regex("/spaces/.*/messages")) - .respond_with(ResponseTemplate::new(500)) - .expect(0) - .mount(&mock_server) - .await; - - let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); - let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); - adapter.api_base = mock_server.uri(); - - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "orig_msg".into(), - platform: "googlechat".into(), - channel: ReplyChannel { - id: "spaces/TEST".into(), - thread_id: None, - }, - content: Content { - content_type: "text".into(), - attachments: Vec::new(), - text: "".into(), - }, - command: None, - request_id: Some("req_empty".into()), - quote_message_id: None, - }; - - adapter.handle_reply(&reply, &event_tx).await; - - let received = event_rx.try_recv(); - assert!(received.is_ok(), "expected failure GatewayResponse for empty message"); - let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); - assert_eq!(resp.request_id, "req_empty"); - assert!(!resp.success); - assert_eq!(resp.error, Some("empty message".into())); 
- } - - #[tokio::test] - async fn handle_reply_multi_chunk_failure_includes_error() { - use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - let mock_server = MockServer::start().await; - Mock::given(method("POST")) - .and(path_regex("/spaces/.*/messages")) - .respond_with(ResponseTemplate::new(500)) - .mount(&mock_server) - .await; - - let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); - let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); - adapter.api_base = mock_server.uri(); - - let long_text = "x".repeat(5000); - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "orig_msg".into(), - platform: "googlechat".into(), - channel: ReplyChannel { - id: "spaces/TEST".into(), - thread_id: None, - }, - content: Content { - content_type: "text".into(), - attachments: Vec::new(), - text: long_text, - }, - command: None, - request_id: Some("req_multi_fail".into()), - quote_message_id: None, - }; - - adapter.handle_reply(&reply, &event_tx).await; - - let received = event_rx.try_recv(); - assert!(received.is_ok(), "expected GatewayResponse"); - let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); - assert_eq!(resp.request_id, "req_multi_fail"); - assert!(!resp.success); - assert!(resp.message_id.is_none()); - let err = resp.error.expect("multi-chunk failure should set error"); - assert!(err.contains("500")); - } - - #[tokio::test] - async fn handle_reply_token_failure_sends_error_response() { - let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); - let adapter = GoogleChatAdapter::new(None, None, None); - - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "orig_msg".into(), - platform: "googlechat".into(), - channel: ReplyChannel { - id: "spaces/TEST".into(), - thread_id: None, - }, - content: Content { - content_type: "text".into(), - attachments: Vec::new(), - text: 
"hello".into(), - }, - command: None, - request_id: Some("req_notoken".into()), - quote_message_id: None, - }; - - adapter.handle_reply(&reply, &event_tx).await; - - let received = event_rx.try_recv(); - assert!(received.is_ok(), "expected failure GatewayResponse"); - let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); - assert_eq!(resp.request_id, "req_notoken"); - assert!(!resp.success); - assert_eq!(resp.error, Some("no credentials configured".into())); - } - - #[tokio::test] - async fn handle_reply_edit_message_does_not_send_response() { - use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - let mock_server = MockServer::start().await; - Mock::given(method("PATCH")) - .and(path_regex("/spaces/.*/messages/.*")) - .respond_with(ResponseTemplate::new(200).set_body_json( - serde_json::json!({"name": "spaces/SP/messages/msg1"}), - )) - .mount(&mock_server) - .await; - - let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); - let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); - adapter.api_base = mock_server.uri(); - - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "spaces/SP/messages/msg1".into(), - platform: "googlechat".into(), - channel: ReplyChannel { - id: "spaces/SP".into(), - thread_id: None, - }, - content: Content { - content_type: "text".into(), - attachments: Vec::new(), - text: "updated text".into(), - }, - command: Some("edit_message".into()), - request_id: None, - quote_message_id: None, - }; - - adapter.handle_reply(&reply, &event_tx).await; - - let received = event_rx.try_recv(); - assert!(received.is_err()); - } - - #[tokio::test] - async fn handle_reply_multi_chunk_sends_gateway_response() { - use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - let mock_server = MockServer::start().await; - Mock::given(method("POST")) - 
.and(path_regex("/spaces/.*/messages")) - .respond_with(ResponseTemplate::new(200).set_body_json( - serde_json::json!({"name": "spaces/TEST/messages/first_chunk"}), - )) - .mount(&mock_server) - .await; - - let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); - let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); - adapter.api_base = mock_server.uri(); - - let long_text = "x".repeat(5000); - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "orig_msg".into(), - platform: "googlechat".into(), - channel: ReplyChannel { - id: "spaces/TEST".into(), - thread_id: None, - }, - content: Content { - content_type: "text".into(), - attachments: Vec::new(), - text: long_text, - }, - command: None, - request_id: Some("req_multi".into()), - quote_message_id: None, - }; - - adapter.handle_reply(&reply, &event_tx).await; - - let received = event_rx.try_recv(); - assert!(received.is_ok(), "expected GatewayResponse for multi-chunk"); - let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); - assert_eq!(resp.request_id, "req_multi"); - assert!(resp.success); - assert_eq!(resp.message_id, Some("spaces/TEST/messages/first_chunk".into())); - } - - #[tokio::test] - async fn handle_reply_multi_chunk_partial_failure_reports_failure() { - // Mixed success/failure: chunk 1 succeeds, subsequent chunks fail. - // Expect success=false (any chunk failure marks overall as failed), - // but message_id is still set so core has a reference. 
- use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - let mock_server = MockServer::start().await; - // First request: 200 OK with message name - Mock::given(method("POST")) - .and(path_regex("/spaces/.*/messages")) - .respond_with(ResponseTemplate::new(200).set_body_json( - serde_json::json!({"name": "spaces/TEST/messages/first_chunk"}), - )) - .up_to_n_times(1) - .mount(&mock_server) - .await; - // Subsequent requests: 500 - Mock::given(method("POST")) - .and(path_regex("/spaces/.*/messages")) - .respond_with(ResponseTemplate::new(500)) - .mount(&mock_server) - .await; - - let (event_tx, mut event_rx) = tokio::sync::broadcast::channel::(16); - let mut adapter = GoogleChatAdapter::new(None, Some("fake-token".into()), None); - adapter.api_base = mock_server.uri(); - - let long_text = "x".repeat(5000); - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: "orig_msg".into(), - platform: "googlechat".into(), - channel: ReplyChannel { - id: "spaces/TEST".into(), - thread_id: None, - }, - content: Content { - content_type: "text".into(), - attachments: Vec::new(), - text: long_text, - }, - command: None, - request_id: Some("req_partial".into()), - quote_message_id: None, - }; - - adapter.handle_reply(&reply, &event_tx).await; - - let received = event_rx.try_recv(); - assert!(received.is_ok(), "expected GatewayResponse"); - let resp: GatewayResponse = serde_json::from_str(&received.unwrap()).unwrap(); - assert_eq!(resp.request_id, "req_partial"); - assert!(!resp.success, "partial failure must report success=false"); - assert_eq!(resp.message_id, Some("spaces/TEST/messages/first_chunk".into())); - let err = resp.error.expect("partial failure should set error"); - assert!(err.contains("500")); - } - - // --- Attachment parsing tests --- - - fn make_attachment( - source: &str, - content_type: &str, - content_name: &str, - resource_name: Option<&str>, - ) -> GoogleChatAttachment { - 
GoogleChatAttachment { - name: Some("spaces/SP/messages/MSG/attachments/ATT".into()), - content_name: Some(content_name.into()), - content_type: Some(content_type.into()), - source: Some(source.into()), - attachment_data_ref: resource_name.map(|rn| AttachmentDataRef { - resource_name: Some(rn.into()), - }), - drive_data_ref: None, - } - } - - #[test] - fn parse_attachments_image() { - let atts = vec![make_attachment( - "UPLOADED_CONTENT", - "image/png", - "photo.png", - Some("AATT_resource"), - )]; - let refs = parse_attachments(&atts); - assert_eq!(refs.len(), 1); - match &refs[0] { - GoogleChatMediaRef::Image { - resource_name, - content_name, - } => { - assert_eq!(resource_name, "AATT_resource"); - assert_eq!(content_name, "photo.png"); - } - other => panic!("expected Image, got {:?}", other), - } - } - - #[test] - fn parse_attachments_audio() { - let atts = vec![make_attachment( - "UPLOADED_CONTENT", - "audio/mp4", - "voice.m4a", - Some("AATT"), - )]; - let refs = parse_attachments(&atts); - assert!(matches!(refs[0], GoogleChatMediaRef::Audio { .. })); - } - - #[test] - fn parse_attachments_file() { - let atts = vec![make_attachment( - "UPLOADED_CONTENT", - "text/plain", - "notes.txt", - Some("AATT"), - )]; - let refs = parse_attachments(&atts); - assert!(matches!(refs[0], GoogleChatMediaRef::File { .. 
})); - } - - #[test] - fn parse_attachments_skips_drive() { - let atts = vec![GoogleChatAttachment { - name: Some("spaces/SP/messages/MSG/attachments/ATT".into()), - content_name: Some("doc".into()), - content_type: Some("application/vnd.google-apps.document".into()), - source: Some("DRIVE_FILE".into()), - attachment_data_ref: None, - drive_data_ref: Some(DriveDataRef { - drive_file_id: Some("drive_id_123".into()), - }), - }]; - assert_eq!(parse_attachments(&atts).len(), 0); - } - - #[test] - fn parse_attachments_skips_missing_resource_name() { - let atts = vec![make_attachment( - "UPLOADED_CONTENT", - "image/png", - "photo.png", - None, - )]; - assert_eq!(parse_attachments(&atts).len(), 0); - } - - #[test] - fn media_url_preserves_slashes_and_encodes_specials() { - let url = media_url("https://chat.googleapis.com/v1", "spaces/SP/messages/MSG/attachments/ATT"); - assert_eq!( - url, - "https://chat.googleapis.com/v1/media/spaces/SP/messages/MSG/attachments/ATT?alt=media" - ); - let url2 = media_url("https://chat.googleapis.com/v1", "AATT/some+resource=name"); - assert_eq!( - url2, - "https://chat.googleapis.com/v1/media/AATT/some%2Bresource%3Dname?alt=media" - ); - } - - #[tokio::test] - async fn download_googlechat_image_resizes_and_returns_attachment() { - use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - // Generate a small valid PNG - let img = image::RgbImage::from_pixel(10, 10, image::Rgb([255, 0, 0])); - let mut buf = std::io::Cursor::new(Vec::new()); - image::DynamicImage::ImageRgb8(img) - .write_to(&mut buf, image::ImageFormat::Png) - .unwrap(); - let png_bytes = buf.into_inner(); - - let mock_server = MockServer::start().await; - Mock::given(method("GET")) - .and(path_regex("/media/.*")) - .respond_with( - ResponseTemplate::new(200) - .set_body_bytes(png_bytes) - .insert_header("content-type", "image/png"), - ) - .mount(&mock_server) - .await; - - let client = reqwest::Client::new(); - let result = 
download_googlechat_image( - &client, - "fake-token", - &mock_server.uri(), - "AATT_resource", - "photo.png", - ) - .await; - let att = result.expect("expected successful download"); - assert_eq!(att.attachment_type, "image"); - assert_eq!(att.filename, "photo.png"); - assert_eq!(att.mime_type, "image/jpeg"); // resized PNG → JPEG - assert!(att.path.is_some()); // stored to filesystem - assert!(att.size > 0); - } - - #[tokio::test] - async fn download_googlechat_file_rejects_non_text_extension() { - let client = reqwest::Client::new(); - let result = download_googlechat_file( - &client, - "fake-token", - "https://unused", // not called for non-text - "AATT", - "binary.exe", - TEXT_TOTAL_CAP, - ) - .await; - assert!(result.is_none(), "non-text extensions must be skipped"); - } - - #[tokio::test] - async fn download_googlechat_file_text_extension_succeeds() { - use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - let mock_server = MockServer::start().await; - Mock::given(method("GET")) - .and(path_regex("/media/.*")) - .respond_with( - ResponseTemplate::new(200).set_body_bytes(b"hello world".to_vec()), - ) - .mount(&mock_server) - .await; - - let client = reqwest::Client::new(); - let result = download_googlechat_file( - &client, - "fake-token", - &mock_server.uri(), - "AATT", - "notes.txt", - TEXT_TOTAL_CAP, - ) - .await; - let att = result.expect("expected successful download"); - assert_eq!(att.attachment_type, "text_file"); - assert_eq!(att.filename, "notes.txt"); - assert_eq!(att.mime_type, "text/plain"); - } - - #[tokio::test] - async fn download_googlechat_audio_returns_attachment() { - use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - let mock_server = MockServer::start().await; - let audio_bytes = vec![0u8; 1024]; - Mock::given(method("GET")) - .and(path_regex("/media/.*")) - .respond_with(ResponseTemplate::new(200).set_body_bytes(audio_bytes.clone())) - 
.mount(&mock_server) - .await; - - let client = reqwest::Client::new(); - let result = download_googlechat_audio( - &client, - "fake-token", - &mock_server.uri(), - "AATT", - "voice.m4a", - "audio/mp4", - ) - .await; - let att = result.expect("expected successful download"); - assert_eq!(att.attachment_type, "audio"); - assert_eq!(att.filename, "voice.m4a"); - assert_eq!(att.mime_type, "audio/mp4"); - assert_eq!(att.size, 1024); - } - - #[tokio::test] - async fn download_googlechat_image_rejects_oversized_content_length() { - use wiremock::{Mock, MockServer, ResponseTemplate}; - use wiremock::matchers::{method, path_regex}; - - let mock_server = MockServer::start().await; - Mock::given(method("GET")) - .and(path_regex("/media/.*")) - .respond_with( - ResponseTemplate::new(200) - .insert_header("content-length", "20000000") // 20 MB > 10 MB limit - .set_body_bytes(vec![0u8; 100]), - ) - .mount(&mock_server) - .await; - - let client = reqwest::Client::new(); - let result = download_googlechat_image( - &client, - "fake-token", - &mock_server.uri(), - "AATT", - "huge.png", - ) - .await; - assert!(result.is_none(), "oversized image must be rejected"); - } - - #[test] - fn parses_http_endpoint_url_top_level_envelope() { - let envelope: GoogleChatEnvelope = serde_json::from_value(serde_json::json!({ - "message": { - "name": "spaces/AAAA/messages/BBBB", - "text": "hello", - "attachment": [] - }, - "user": { - "name": "users/123", - "displayName": "Test User", - "type": "HUMAN" - }, - "space": { - "name": "spaces/AAAA", - "type": "DM" - } - })) - .unwrap(); - assert!(envelope.chat.is_none()); - assert!(envelope.message.is_some()); - assert_eq!(envelope.message.unwrap().name, "spaces/AAAA/messages/BBBB"); - assert!(envelope.user.is_some()); - assert_eq!(envelope.user.unwrap().name, "users/123"); - assert!(envelope.space.is_some()); - assert_eq!(envelope.space.unwrap().name, "spaces/AAAA"); - } -} diff --git a/gateway/src/adapters/line.rs b/gateway/src/adapters/line.rs 
deleted file mode 100644 index 1323d2605..000000000 --- a/gateway/src/adapters/line.rs +++ /dev/null @@ -1,780 +0,0 @@ -use crate::media::{resize_and_compress, IMAGE_MAX_DOWNLOAD}; -use crate::schema::*; -use crate::store; -use axum::extract::State; -use serde::Deserialize; -use std::sync::Arc; -use tracing::{error, info, warn}; - -// --- LINE types --- - -#[derive(Debug, Deserialize)] -pub struct LineWebhookBody { - events: Vec, -} - -#[derive(Debug, Deserialize)] -struct LineEvent { - #[serde(rename = "type")] - event_type: String, - source: Option, - message: Option, - #[serde(rename = "replyToken")] - reply_token: Option, -} - -#[derive(Debug, Deserialize)] -struct LineSource { - #[serde(rename = "type")] - source_type: String, - #[serde(rename = "userId")] - user_id: Option, - #[serde(rename = "groupId")] - group_id: Option, - #[serde(rename = "roomId")] - room_id: Option, -} - -#[derive(Debug, Deserialize)] -struct LineMessage { - id: String, - #[serde(rename = "type")] - message_type: String, - text: Option, - #[serde(rename = "contentProvider")] - content_provider: Option, - mention: Option, -} - -#[derive(Debug, Deserialize)] -struct LineMention { - mentionees: Vec, -} - -#[derive(Debug, Deserialize)] -struct LineMentionee { - #[serde(rename = "userId")] - user_id: Option, - #[serde(rename = "isSelf", default)] - is_self: bool, -} - -#[derive(Debug, Deserialize)] -struct LineContentProvider { - #[serde(rename = "type")] - provider_type: String, - #[serde(rename = "originalContentUrl")] - original_content_url: Option, -} - -/// Base URL for LINE Messaging API. Overridden in tests via the `api_base` parameter. -pub const LINE_API_BASE: &str = "https://api.line.me"; -/// Base URL for LINE binary content download API. 
-pub const LINE_DATA_API_BASE: &str = "https://api-data.line.me"; - -// --- Webhook handler --- - -pub async fn webhook( - State(state): State>, - headers: axum::http::HeaderMap, - body: axum::body::Bytes, -) -> axum::http::StatusCode { - // Validate X-Line-Signature - if let Some(ref channel_secret) = state.line_channel_secret { - use base64::Engine; - use hmac::{Hmac, Mac}; - use sha2::Sha256; - - let signature = headers - .get("x-line-signature") - .and_then(|v| v.to_str().ok()); - let Some(signature) = signature else { - warn!("LINE webhook rejected: missing X-Line-Signature"); - return axum::http::StatusCode::UNAUTHORIZED; - }; - - let mut mac = Hmac::::new_from_slice(channel_secret.as_bytes()).expect("HMAC key"); - mac.update(&body); - let expected = - base64::engine::general_purpose::STANDARD.encode(mac.finalize().into_bytes()); - if signature != expected { - warn!("LINE webhook rejected: invalid signature"); - return axum::http::StatusCode::UNAUTHORIZED; - } - } - - let webhook_body: LineWebhookBody = match serde_json::from_slice(&body) { - Ok(b) => b, - Err(e) => { - warn!("LINE webhook parse error: {e}"); - return axum::http::StatusCode::BAD_REQUEST; - } - }; - - let webhook_received_at = std::time::Instant::now(); - let background_state = state.clone(); - let permit = match background_state - .line_webhook_semaphore - .clone() - .acquire_owned() - .await - { - Ok(permit) => permit, - Err(_) => { - warn!("LINE webhook worker semaphore closed unexpectedly"); - return axum::http::StatusCode::SERVICE_UNAVAILABLE; - } - }; - tokio::spawn(async move { - let _permit = permit; - process_line_webhook_events(background_state, webhook_body, webhook_received_at).await; - }); - - axum::http::StatusCode::OK -} - -async fn process_line_webhook_events( - state: Arc, - webhook_body: LineWebhookBody, - webhook_received_at: std::time::Instant, -) { - // Acknowledge the webhook before image download/processing so LINE does not - // redeliver solely because gateway-side 
attachment work is slow. We keep one - // task per webhook payload so events from the same payload preserve order. - // - // Tradeoff: - // - Pros: lowers webhook latency and reduces redelivery pressure from LINE. - // - Cons: once 200 OK is returned, a later crash/task failure will not be - // retried by LINE. This PR intentionally keeps scope small and does not add - // background-task durability or duplicate suppression on top of early-ack. - // - Cons: an earlier image event from one webhook payload can also be emitted - // after a later text event from another payload if the image path is slower. - // - Guardrail: a shared semaphore bounds how many LINE payloads can enter the - // post-ack path concurrently. When saturated, new webhooks wait for capacity - // before spawning background work so bursts do not create unbounded backlog. - for event in webhook_body.events { - let Some(gateway_event) = build_gateway_event_from_line_event( - &event, - &state.client, - state.line_access_token.as_deref(), - LINE_DATA_API_BASE, - ) - .await - else { - continue; - }; - - // Cache before broadcasting the event. Once event_tx.send() fires, OAB - // may reply immediately; inserting afterward can silently force Push API. - // We still use webhook receipt time so TTL reflects true reply-token age. 
- if let Some(ref reply_token) = event.reply_token { - let mut cache = state - .reply_token_cache - .lock() - .unwrap_or_else(|e| e.into_inner()); - if cache.len() >= crate::REPLY_TOKEN_CACHE_MAX { - warn!( - size = cache.len(), - "reply token cache full, skipping insert" - ); - } else { - cache.insert( - gateway_event.event_id.clone(), - (reply_token.clone(), webhook_received_at), - ); - info!(event_id = %gateway_event.event_id, "cached LINE replyToken"); - } - } - - let json = serde_json::to_string(&gateway_event).unwrap(); - info!(channel = %gateway_event.channel.id, sender = %gateway_event.sender.id, "line → gateway"); - let _ = state.event_tx.send(json); - } -} - -fn sanitize_line_external_url_for_log(url: &str) -> String { - reqwest::Url::parse(url) - .ok() - .and_then(|parsed| parsed.host_str().map(str::to_owned)) - .unwrap_or_else(|| "invalid-or-missing-host".to_string()) -} - -async fn build_gateway_event_from_line_event( - event: &LineEvent, - client: &reqwest::Client, - line_access_token: Option<&str>, - data_api_base: &str, -) -> Option { - if event.event_type != "message" { - return None; - } - - let msg = event.message.as_ref()?; - if msg.message_type != "text" && msg.message_type != "image" { - return None; - } - - let text = msg.text.as_deref().unwrap_or(""); - let mut attachments = Vec::new(); - - if msg.message_type == "image" { - match msg - .content_provider - .as_ref() - .map(|provider| provider.provider_type.as_str()) - { - Some("external") => { - let original = msg - .content_provider - .as_ref() - .and_then(|provider| provider.original_content_url.as_deref()) - .unwrap_or("unknown"); - warn!( - message_id = %msg.id, - external_content_host = %sanitize_line_external_url_for_log(original), - "LINE external image content is not supported yet" - ); - } - _ => { - if let Some(access_token) = line_access_token { - if let Some(attachment) = - download_line_image(client, access_token, &msg.id, data_api_base).await - { - attachments.push(attachment); 
- } - } else { - warn!(message_id = %msg.id, "LINE image received but LINE_CHANNEL_ACCESS_TOKEN is not configured"); - } - } - } - } - - // Do not synthesize placeholder text for failed/unsupported image downloads. - // Core treats content.text as the user's prompt, so a fake marker would create - // a misleading turn instead of preserving the actual image content. - let event_text = text; - - if msg.message_type == "image" && event_text.trim().is_empty() && attachments.is_empty() { - info!( - message_id = %msg.id, - "LINE image event produced no attachment; skipping without synthesizing placeholder text" - ); - } - - if event_text.trim().is_empty() && attachments.is_empty() { - return None; - } - - let source = event.source.as_ref(); - let (channel_id, channel_type) = match source { - Some(s) if s.source_type == "group" => match s.group_id.as_deref() { - Some(id) if !id.is_empty() => (id.to_string(), "group".to_string()), - _ => { - warn!("LINE group event missing groupId, skipping"); - return None; - } - }, - Some(s) if s.source_type == "room" => match s.room_id.as_deref() { - Some(id) if !id.is_empty() => (id.to_string(), "room".to_string()), - _ => { - warn!("LINE room event missing roomId, skipping"); - return None; - } - }, - Some(s) => match s.user_id.as_deref() { - Some(id) if !id.is_empty() => (id.to_string(), "user".to_string()), - _ => { - warn!("LINE user event missing userId, skipping"); - return None; - } - }, - None => { - warn!("LINE event missing source, skipping"); - return None; - } - }; - let user_id = source - .and_then(|s| s.user_id.as_deref()) - .unwrap_or("unknown"); - - // Extract mentioned user IDs from the LINE webhook mention object. - // LINE populates this in group/room text messages when users are @-mentioned. 
- let mentionees = msg - .mention - .as_ref() - .map(|m| m.mentionees.as_slice()) - .unwrap_or_default(); - let mention_ids: Vec = mentionees - .iter() - .filter_map(|m| m.user_id.clone()) - .collect(); - - // @mention gating: in groups/rooms, only forward the event if the bot is mentioned. - // LINE sets isSelf=true on the mentionee that is the bot itself — no env var needed. - // 1:1 DMs always pass through. - let is_group = channel_type == "group" || channel_type == "room"; - if is_group && !mentionees.iter().any(|m| m.is_self) { - info!( - channel = %channel_id, - "line group message dropped (@mention gating: bot not mentioned)" - ); - return None; - } - - let mut gateway_event = GatewayEvent::new( - "line", - ChannelInfo { - id: channel_id, - channel_type, - thread_id: None, - }, - SenderInfo { - id: user_id.into(), - name: user_id.into(), - display_name: user_id.into(), - is_bot: false, - }, - event_text, - &msg.id, - mention_ids, - ); - gateway_event.content.attachments = attachments; - Some(gateway_event) -} - -pub async fn download_line_image( - client: &reqwest::Client, - access_token: &str, - message_id: &str, - api_base: &str, -) -> Option { - let mut resp = match client - .get(format!( - "{}/v2/bot/message/{}/content", - api_base, message_id - )) - .bearer_auth(access_token) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - warn!(message_id, error = %e, "LINE image download failed"); - return None; - } - }; - - if !resp.status().is_success() { - warn!(message_id, status = %resp.status(), "LINE image download failed"); - return None; - } - - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > IMAGE_MAX_DOWNLOAD { - warn!(message_id, size, "LINE image Content-Length exceeds limit"); - return None; - } - } - } - - let mut body = Vec::new(); - loop { - let chunk = match resp.chunk().await { - Ok(Some(chunk)) => chunk, - Ok(None) => break, - Err(e) => { - 
warn!(message_id, error = %e, "LINE image download failed while reading body"); - return None; - } - }; - body.extend_from_slice(&chunk); - if body.len() as u64 > IMAGE_MAX_DOWNLOAD { - warn!(message_id, size = body.len(), "LINE image exceeds limit"); - return None; - } - } - - let (compressed, mime) = - match tokio::task::spawn_blocking(move || resize_and_compress(&body)).await { - Ok(Ok(v)) => v, - Ok(Err(e)) => { - warn!(message_id, error = %e, "LINE image resize/compress failed"); - return None; - } - Err(e) => { - warn!(message_id, error = %e, "LINE image processing task failed"); - return None; - } - }; - let path = store::store_media(&compressed).await?; - let ext = if mime == "image/gif" { "gif" } else { "jpg" }; - Some(Attachment { - attachment_type: "image".into(), - filename: format!("line_{}.{}", message_id, ext), - mime_type: mime, - data: String::new(), - size: compressed.len() as u64, - path: Some(path), - }) -} - -// --- Reply handler (hybrid Reply/Push dispatch) --- - -/// Dispatch a reply to LINE using the hybrid Reply/Push strategy. -/// -/// Returns `true` if Reply API was used (or assumed used), `false` if Push API was used. 
-pub async fn dispatch_line_reply( - client: &reqwest::Client, - access_token: &str, - reply_cache: &crate::ReplyTokenCache, - reply: &GatewayReply, - api_base: &str, -) -> bool { - if matches!( - reply.command.as_deref(), - Some("add_reaction") | Some("remove_reaction") | Some("create_topic") - ) { - info!(command = ?reply.command.as_deref(), "line: ignoring unsupported command"); - return false; - } - - // Extract token from cache (drop lock before HTTP call) - let cached_token = { - let mut cache = reply_cache.lock().unwrap_or_else(|e| e.into_inner()); - cache - .remove(&reply.reply_to) - .and_then(|(token, cached_at)| { - if cached_at.elapsed().as_secs() < crate::REPLY_TOKEN_TTL_SECS { - Some(token) - } else { - info!("LINE replyToken expired, using Push API"); - None - } - }) - }; - - // Try Reply API first (free, no quota consumed) - let mut used_reply = false; - if let Some(reply_token) = cached_token { - info!(to = %reply.channel.id, "gateway → line (reply API)"); - let resp = client - .post(format!("{}/v2/bot/message/reply", api_base)) - .bearer_auth(access_token) - .json(&serde_json::json!({ - "replyToken": reply_token, - "messages": [{"type": "text", "text": reply.content.text}] - })) - .send() - .await; - match resp { - Ok(r) if r.status().is_success() => { - used_reply = true; - } - Ok(r) => { - let status = r.status(); - let body = r.text().await.unwrap_or_default(); - let body_lower = body.to_lowercase(); - let token_unusable = status.as_u16() == 400 - && ((body_lower.contains("invalid") && body_lower.contains("reply token")) - || body_lower.contains("expired")); - if token_unusable { - warn!(status = %status, body = %body, "LINE reply token unusable, falling back to Push"); - } else { - error!(status = %status, body = %body, "LINE Reply API error, NOT falling back to Push (possible duplicate risk)"); - used_reply = true; - } - } - Err(e) => { - error!(err = %e, "LINE Reply API network error, NOT falling back to Push (possible duplicate risk)"); - 
used_reply = true; - } - } - } - - // Fallback to Push API - if !used_reply { - info!(to = %reply.channel.id, "gateway → line (push API)"); - let _ = client - .post(format!("{}/v2/bot/message/push", api_base)) - .bearer_auth(access_token) - .json(&serde_json::json!({ - "to": reply.channel.id, - "messages": [{"type": "text", "text": reply.content.text}] - })) - .send() - .await - .map_err(|e| error!("line push error: {e}")); - } - - used_reply -} - -#[cfg(test)] -mod tests { - use super::*; - use axum::extract::State; - use std::collections::HashMap; - use std::sync::Arc; - use tokio::sync::{broadcast, Mutex, Semaphore}; - use wiremock::matchers::{header, method, path}; - use wiremock::{Mock, MockServer, ResponseTemplate}; - - #[tokio::test] - async fn download_line_image_resizes_and_returns_attachment() { - let server = MockServer::start().await; - let img = image::RgbImage::from_pixel(16, 16, image::Rgb([0, 128, 255])); - let mut buf = std::io::Cursor::new(Vec::new()); - img.write_to(&mut buf, image::ImageFormat::Png).unwrap(); - let bytes = buf.into_inner(); - - let _mock = Mock::given(method("GET")) - .and(path("/v2/bot/message/msg123/content")) - .and(header("authorization", "Bearer line_token")) - .respond_with( - ResponseTemplate::new(200) - .insert_header("content-type", "image/png") - .set_body_bytes(bytes), - ) - .mount_as_scoped(&server) - .await; - - let attachment = download_line_image( - &reqwest::Client::new(), - "line_token", - "msg123", - &server.uri(), - ) - .await - .expect("attachment should be downloaded"); - - assert_eq!(attachment.attachment_type, "image"); - assert!(attachment.filename.starts_with("line_msg123.")); - assert!(attachment.path.is_some()); - assert!(attachment.size > 0); - - let path = attachment.path.unwrap(); - let stored = tokio::fs::read(&path).await.unwrap(); - assert!(!stored.is_empty()); - let _ = tokio::fs::remove_file(path).await; - } - - #[tokio::test] - async fn 
build_gateway_event_from_line_image_attaches_downloaded_image() { - let server = MockServer::start().await; - let img = image::RgbImage::from_pixel(8, 8, image::Rgb([255, 0, 0])); - let mut buf = std::io::Cursor::new(Vec::new()); - img.write_to(&mut buf, image::ImageFormat::Png).unwrap(); - let bytes = buf.into_inner(); - - let _mock = Mock::given(method("GET")) - .and(path("/v2/bot/message/msg_image/content")) - .and(header("authorization", "Bearer line_token")) - .respond_with( - ResponseTemplate::new(200) - .insert_header("content-type", "image/png") - .set_body_bytes(bytes), - ) - .mount_as_scoped(&server) - .await; - - let event: LineEvent = serde_json::from_value(serde_json::json!({ - "type": "message", - "replyToken": "reply123", - "source": {"type": "user", "userId": "U123"}, - "message": { - "id": "msg_image", - "type": "image", - "contentProvider": {"type": "line"} - } - })) - .unwrap(); - - let gateway_event = build_gateway_event_from_line_event( - &event, - &reqwest::Client::new(), - Some("line_token"), - &server.uri(), - ) - .await - .expect("image event should produce a gateway event"); - - assert_eq!(gateway_event.platform, "line"); - assert_eq!(gateway_event.content.text, ""); - assert_eq!(gateway_event.content.attachments.len(), 1); - - let path = gateway_event.content.attachments[0] - .path - .clone() - .expect("path should be stored"); - let _ = tokio::fs::remove_file(path).await; - } - - #[tokio::test] - async fn download_line_image_rejects_oversized_content_length() { - let server = MockServer::start().await; - - let _mock = Mock::given(method("GET")) - .and(path("/v2/bot/message/msg_big/content")) - .and(header("authorization", "Bearer line_token")) - .respond_with( - ResponseTemplate::new(200) - .insert_header("content-type", "image/png") - .insert_header("content-length", (IMAGE_MAX_DOWNLOAD + 1).to_string()) - .set_body_bytes(vec![0u8; IMAGE_MAX_DOWNLOAD as usize + 1]), - ) - .mount_as_scoped(&server) - .await; - - let attachment = 
download_line_image( - &reqwest::Client::new(), - "line_token", - "msg_big", - &server.uri(), - ) - .await; - - assert!(attachment.is_none()); - } - - #[tokio::test] - async fn webhook_acknowledges_before_async_event_forwarding() { - let (event_tx, mut event_rx) = broadcast::channel::(8); - let state = Arc::new(crate::AppState { - telegram_bot_token: None, - telegram_secret_token: None, - telegram_rich_messages: false, - line_channel_secret: None, - line_access_token: None, - teams: None, - teams_service_urls: Mutex::new(HashMap::new()), - feishu: None, - google_chat: None, - wecom: None, - ws_token: None, - event_tx, - reply_token_cache: Arc::new(std::sync::Mutex::new(HashMap::new())), - line_webhook_semaphore: Arc::new(Semaphore::new(crate::LINE_WEBHOOK_CONCURRENCY_MAX)), - client: reqwest::Client::new(), - }); - - let body = axum::body::Bytes::from( - serde_json::json!({ - "events": [{ - "type": "message", - "replyToken": "reply123", - "source": {"type": "user", "userId": "U123"}, - "message": {"id": "msg123", "type": "text", "text": "hello"} - }] - }) - .to_string(), - ); - - let status = webhook(State(state.clone()), axum::http::HeaderMap::new(), body).await; - assert_eq!(status, axum::http::StatusCode::OK); - - let event_json = tokio::time::timeout(std::time::Duration::from_secs(1), event_rx.recv()) - .await - .expect("background task should forward an event") - .expect("broadcast should succeed"); - let event: GatewayEvent = serde_json::from_str(&event_json).expect("valid gateway event"); - - assert_eq!(event.message_id, "msg123"); - assert_eq!(event.content.text, "hello"); - - let cache = state - .reply_token_cache - .lock() - .unwrap_or_else(|e| e.into_inner()); - let (token, cached_at) = cache - .get(&event.event_id) - .expect("reply token should be cached"); - assert_eq!(token, "reply123"); - assert!(cached_at.elapsed() < std::time::Duration::from_secs(1)); - } - - // --- @mention gating tests --- - - fn make_group_text_event(text: &str, bot_mentioned: 
bool) -> LineEvent { - let mention = if bot_mentioned { - serde_json::json!({"mentionees": [{"userId": "Ubot123", "type": "user", "isSelf": true}]}) - } else { - serde_json::json!({"mentionees": [{"userId": "Uother", "type": "user", "isSelf": false}]}) - }; - serde_json::from_value(serde_json::json!({ - "type": "message", - "source": {"type": "group", "groupId": "C001", "userId": "U_sender"}, - "message": { - "id": "msg001", - "type": "text", - "text": text, - "mention": mention - } - })) - .unwrap() - } - - #[tokio::test] - async fn group_message_passes_when_bot_mentioned() { - let event = make_group_text_event("@Bot hello", true); - let result = build_gateway_event_from_line_event( - &event, - &reqwest::Client::new(), - None, - LINE_DATA_API_BASE, - ) - .await; - assert!(result.is_some()); - let gw = result.unwrap(); - assert_eq!(gw.mentions, vec!["Ubot123"]); - } - - #[tokio::test] - async fn group_message_dropped_when_bot_not_mentioned() { - let event = make_group_text_event("hey everyone", false); - let result = build_gateway_event_from_line_event( - &event, - &reqwest::Client::new(), - None, - LINE_DATA_API_BASE, - ) - .await; - assert!(result.is_none()); - } - - #[tokio::test] - async fn group_message_dropped_when_no_mention_at_all() { - let event: LineEvent = serde_json::from_value(serde_json::json!({ - "type": "message", - "source": {"type": "group", "groupId": "C001", "userId": "U_sender"}, - "message": {"id": "msg001", "type": "text", "text": "plain message no mention"} - })) - .unwrap(); - let result = build_gateway_event_from_line_event( - &event, - &reqwest::Client::new(), - None, - LINE_DATA_API_BASE, - ) - .await; - assert!(result.is_none()); - } - - #[tokio::test] - async fn dm_passes_even_without_mention() { - let event: LineEvent = serde_json::from_value(serde_json::json!({ - "type": "message", - "source": {"type": "user", "userId": "U_human"}, - "message": {"id": "msg002", "type": "text", "text": "hello bot"} - })) - .unwrap(); - let result = 
build_gateway_event_from_line_event( - &event, - &reqwest::Client::new(), - None, - LINE_DATA_API_BASE, - ) - .await; - assert!(result.is_some()); - } -} diff --git a/gateway/src/adapters/mod.rs b/gateway/src/adapters/mod.rs deleted file mode 100644 index 94a2a8a79..000000000 --- a/gateway/src/adapters/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod feishu; -pub mod googlechat; -pub mod line; -pub mod teams; -pub mod telegram; -pub mod wecom; diff --git a/gateway/src/adapters/teams.rs b/gateway/src/adapters/teams.rs deleted file mode 100644 index 09ac09df8..000000000 --- a/gateway/src/adapters/teams.rs +++ /dev/null @@ -1,877 +0,0 @@ -use crate::schema::*; -use axum::extract::State; -use axum::http::{HeaderMap, StatusCode}; -use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; -use serde::Deserialize; -use std::sync::Arc; -use tokio::sync::RwLock; -use tracing::{debug, error, info, warn}; - -// --- Bot Framework activity types --- - -#[allow(dead_code)] // Bot Framework schema fields — needed for future features -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Activity { - #[serde(rename = "type")] - pub activity_type: String, - pub id: Option, - pub timestamp: Option, - pub service_url: Option, - pub channel_id: Option, - pub from: Option, - pub conversation: Option, - pub text: Option, - pub tenant: Option, - pub channel_data: Option, -} - -#[allow(dead_code)] -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ChannelAccount { - pub id: Option, - pub name: Option, - pub aad_object_id: Option, -} - -#[allow(dead_code)] -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ConversationAccount { - pub id: Option, - pub conversation_type: Option, - pub is_group: Option, - pub tenant_id: Option, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct TenantInfo { - pub id: Option, -} - -#[allow(dead_code)] -#[derive(Debug, Deserialize)] 
-#[serde(rename_all = "camelCase")] -pub struct ChannelData { - pub tenant: Option, -} - -impl Activity { - /// Resolve tenant id from any of the locations Teams may put it. - pub fn resolved_tenant_id(&self) -> Option<&str> { - self.tenant - .as_ref() - .and_then(|t| t.id.as_deref()) - .or_else(|| { - self.channel_data - .as_ref() - .and_then(|c| c.tenant.as_ref()) - .and_then(|t| t.id.as_deref()) - }) - .or_else(|| { - self.conversation - .as_ref() - .and_then(|c| c.tenant_id.as_deref()) - }) - } -} - -// --- OpenID configuration --- - -#[derive(Debug, Deserialize)] -struct OpenIdConfig { - jwks_uri: String, -} - -#[derive(Debug, Deserialize)] -struct JwksResponse { - keys: Vec, -} - -#[derive(Debug, Clone, Deserialize)] -struct JwkKey { - kid: Option, - n: String, - e: String, - kty: String, - #[serde(default)] - endorsements: Vec, -} - -// --- OAuth token --- - -#[derive(Debug, Deserialize)] -struct TokenResponse { - access_token: String, - expires_in: u64, -} - -struct CachedToken { - token: String, - expires_at: std::time::Instant, -} - -// --- Teams adapter config --- - -pub struct TeamsConfig { - pub app_id: String, - pub app_secret: String, - pub oauth_endpoint: String, - pub openid_metadata: String, - pub allowed_tenants: Vec, -} - -impl TeamsConfig { - pub fn from_env() -> Option { - let app_id = std::env::var("TEAMS_APP_ID").ok()?; - let app_secret = std::env::var("TEAMS_APP_SECRET").ok()?; - Some(Self { - app_id, - app_secret, - oauth_endpoint: std::env::var("TEAMS_OAUTH_ENDPOINT").unwrap_or_else(|_| { - "https://login.microsoftonline.com/botframework.com/oauth2/v2.0/token".into() - }), - openid_metadata: std::env::var("TEAMS_OPENID_METADATA").unwrap_or_else(|_| { - "https://login.botframework.com/v1/.well-known/openidconfiguration".into() - }), - allowed_tenants: std::env::var("TEAMS_ALLOWED_TENANTS") - .unwrap_or_default() - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect(), - }) - } -} - -// --- Teams adapter 
state --- - -pub struct TeamsAdapter { - config: TeamsConfig, - client: reqwest::Client, - token_cache: RwLock>, - jwks_cache: RwLock, std::time::Instant)>>, -} - -const JWKS_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(3600); -const TOKEN_REFRESH_MARGIN: std::time::Duration = std::time::Duration::from_secs(300); - -impl TeamsAdapter { - pub fn new(config: TeamsConfig) -> Self { - Self { - config, - client: reqwest::Client::new(), - token_cache: RwLock::new(None), - jwks_cache: RwLock::new(None), - } - } - - /// Get a valid OAuth bearer token, refreshing if needed. - async fn get_token(&self) -> anyhow::Result { - // Check cache - { - let cache = self.token_cache.read().await; - if let Some(ref cached) = *cache { - if cached.expires_at > std::time::Instant::now() + TOKEN_REFRESH_MARGIN { - return Ok(cached.token.clone()); - } - } - } - - // Fetch new token - let resp: TokenResponse = self - .client - .post(&self.config.oauth_endpoint) - .form(&[ - ("grant_type", "client_credentials"), - ("client_id", &self.config.app_id), - ("client_secret", &self.config.app_secret), - ("scope", "https://api.botframework.com/.default"), - ]) - .send() - .await? - .json() - .await?; - - let token = resp.access_token.clone(); - *self.token_cache.write().await = Some(CachedToken { - token: resp.access_token, - expires_at: std::time::Instant::now() + std::time::Duration::from_secs(resp.expires_in), - }); - info!("teams OAuth token refreshed"); - Ok(token) - } - - /// Fetch and cache JWKS signing keys from Microsoft's OpenID metadata. - async fn get_jwks(&self) -> anyhow::Result> { - { - let cache = self.jwks_cache.read().await; - if let Some((ref keys, fetched_at)) = *cache { - if fetched_at.elapsed() < JWKS_CACHE_TTL { - return Ok(keys.clone()); - } - } - } - - let config: OpenIdConfig = self - .client - .get(&self.config.openid_metadata) - .send() - .await? - .json() - .await?; - - let jwks: JwksResponse = self - .client - .get(&config.jwks_uri) - .send() - .await? 
- .json() - .await?; - - let keys = jwks.keys; - *self.jwks_cache.write().await = Some((keys.clone(), std::time::Instant::now())); - info!(count = keys.len(), "teams JWKS keys refreshed"); - Ok(keys) - } - - /// Force-refresh JWKS keys, bypassing cache TTL. Called on cache miss (kid not found). - async fn refresh_jwks(&self) -> anyhow::Result> { - // Invalidate cache so get_jwks fetches fresh - *self.jwks_cache.write().await = None; - self.get_jwks().await - } - - /// Validate the JWT bearer token from an inbound Bot Framework request. - /// Checks: signature, issuer, audience, expiry, serviceUrl claim, and channel endorsements. - pub async fn validate_jwt(&self, auth_header: &str, activity: &Activity) -> anyhow::Result<()> { - let token = auth_header - .strip_prefix("Bearer ") - .ok_or_else(|| anyhow::anyhow!("missing Bearer prefix"))?; - - // Decode header to get kid - let header = jsonwebtoken::decode_header(token)?; - let kid = header - .kid - .ok_or_else(|| anyhow::anyhow!("no kid in JWT header"))?; - - let keys = self.get_jwks().await?; - let key = match keys.iter().find(|k| k.kid.as_deref() == Some(&kid)) { - Some(k) => k.clone(), - None => { - // Cache miss: Microsoft may have rotated keys. Force refresh and retry. - let refreshed = self.refresh_jwks().await?; - refreshed - .into_iter() - .find(|k| k.kid.as_deref() == Some(&kid)) - .ok_or_else(|| anyhow::anyhow!("no matching JWK for kid={kid} after refresh"))? 
- } - }; - - if key.kty != "RSA" { - anyhow::bail!("unsupported key type: {}", key.kty); - } - - // B2: Validate channel endorsements — key must endorse the activity's channelId - let channel_id = activity.channel_id.as_deref() - .ok_or_else(|| anyhow::anyhow!("activity missing channelId"))?; - if key.endorsements.is_empty() { - anyhow::bail!("JWK has no endorsements — cannot verify channelId={channel_id}"); - } - if !key.endorsements.iter().any(|e| e == channel_id) { - anyhow::bail!( - "JWK endorsements {:?} do not include channelId={channel_id}", - key.endorsements - ); - } - - let decoding_key = DecodingKey::from_rsa_components(&key.n, &key.e)?; - let mut validation = Validation::new(Algorithm::RS256); - validation.set_audience(&[&self.config.app_id]); - // Bot Framework tokens can use RS256 or RS384 - validation.algorithms = vec![Algorithm::RS256, Algorithm::RS384]; - // Bot Framework issuer per auth spec - validation.set_issuer(&["https://api.botframework.com"]); - validation.validate_aud = true; - validation.validate_exp = true; - validation.validate_nbf = false; - - let token_data = decode::(token, &decoding_key, &validation)?; - - // B1: Validate serviceUrl claim matches activity's serviceUrl - let activity_service_url = activity.service_url.as_deref() - .ok_or_else(|| anyhow::anyhow!("activity missing serviceUrl"))?; - let token_service_url = token_data.claims.get("serviceurl") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("JWT missing serviceurl claim"))?; - if token_service_url != activity_service_url { - anyhow::bail!( - "serviceUrl mismatch: token={token_service_url}, activity={activity_service_url}" - ); - } - - Ok(()) - } - - /// Check tenant allowlist. - fn check_tenant(&self, activity: &Activity) -> bool { - if self.config.allowed_tenants.is_empty() { - return true; - } - activity - .resolved_tenant_id() - .is_some_and(|tid| self.config.allowed_tenants.iter().any(|a| a == tid)) - } - - /// Send a reply via Bot Framework REST API. 
- pub async fn send_activity( - &self, - service_url: &str, - conversation_id: &str, - text: &str, - reply_to_id: Option<&str>, - ) -> anyhow::Result { - let token = self.get_token().await?; - let url = format!( - "{}v3/conversations/{}/activities", - ensure_trailing_slash(service_url), - conversation_id - ); - - let mut body = serde_json::json!({ - "type": "message", - "from": { "id": &self.config.app_id }, - "text": text, - "textFormat": "markdown", - }); - if let Some(id) = reply_to_id { - body["replyToId"] = serde_json::Value::String(id.to_string()); - } - - let resp = self - .client - .post(&url) - .bearer_auth(&token) - .json(&body) - .send() - .await?; - - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - anyhow::bail!("Bot Framework API error {status}: {body}"); - } - - let result: serde_json::Value = resp.json().await?; - Ok(result["id"].as_str().unwrap_or("").to_string()) - } - - /// Edit an existing activity (for streaming updates). - pub async fn update_activity( - &self, - service_url: &str, - conversation_id: &str, - activity_id: &str, - text: &str, - ) -> anyhow::Result<()> { - let token = self.get_token().await?; - let url = format!( - "{}v3/conversations/{}/activities/{}", - ensure_trailing_slash(service_url), - conversation_id, - activity_id - ); - - let body = serde_json::json!({ - "type": "message", - "from": { "id": &self.config.app_id }, - "text": text, - }); - - let resp = self - .client - .put(&url) - .bearer_auth(&token) - .json(&body) - .send() - .await?; - - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - anyhow::bail!("Bot Framework update error {status}: {body}"); - } - Ok(()) - } -} - -fn ensure_trailing_slash(url: &str) -> String { - if url.ends_with('/') { - url.to_string() - } else { - format!("{url}/") - } -} - -// --- Webhook handler --- - -/// Max webhook body size: 256 KB. 
Real Teams activities are a few KB; the -/// activity is parsed *before* JWT auth (Bot Framework requires serviceUrl / -/// channelId from the body to validate the token), so this caps the -/// unauthenticated parse attack surface. Mirrors the feishu adapter's limit. -const WEBHOOK_BODY_LIMIT: usize = 256 * 1024; - -pub async fn webhook( - State(state): State>, - headers: HeaderMap, - body: String, -) -> StatusCode { - let teams = match &state.teams { - Some(t) => t, - None => return StatusCode::NOT_FOUND, - }; - - // Defense-in-depth: bound the pre-auth body size (axum's default limit is 2 MB). - if body.len() > WEBHOOK_BODY_LIMIT { - warn!(size = body.len(), "teams webhook body too large"); - return StatusCode::PAYLOAD_TOO_LARGE; - } - - // Extract auth header early (before parsing activity) - let auth_header = match headers.get("authorization").and_then(|v| v.to_str().ok()) { - Some(h) => h.to_string(), - None => { - warn!("teams webhook: missing authorization header"); - return StatusCode::UNAUTHORIZED; - } - }; - - // Parse activity first (needed for JWT serviceUrl + endorsements validation). - // - // SECURITY NOTE (OX untrusted-deserialization finding — false positive): - // `Activity` is a strict, derive-only DTO (String / Option<_> / nested - // structs) with no custom Deserialize, no side-effectful Drop, and no enum - // variant dispatch. serde_json's data model cannot instantiate arbitrary - // types (unlike bincode/serde_yaml/rmp-serde), so object-injection / RCE - // does not apply. The recommended "strict DTO + validate after" pattern is - // already in place: JWT, activity-type, and tenant-allowlist checks below. - // DoS is bounded by serde_json's recursion limit (128) and the body cap above. 
- let activity: Activity = match serde_json::from_str(&body) { - Ok(a) => a, - Err(e) => { - warn!(error = %e, "teams: invalid activity JSON"); - return StatusCode::BAD_REQUEST; - } - }; - - // JWT validation (with activity context for serviceUrl + channelId checks) - if let Err(e) = teams.validate_jwt(&auth_header, &activity).await { - warn!(error = %e, "teams JWT validation failed"); - return StatusCode::UNAUTHORIZED; - } - - // Only handle message activities - if activity.activity_type != "message" { - debug!(activity_type = %activity.activity_type, "teams: ignoring non-message activity"); - return StatusCode::OK; - } - - // Tenant check - if !teams.check_tenant(&activity) { - let tid = activity.resolved_tenant_id().unwrap_or("unknown"); - warn!(tenant = tid, "teams: tenant not in allowlist"); - return StatusCode::FORBIDDEN; - } - - let text = match activity.text.as_deref() { - Some(t) if !t.trim().is_empty() => t.trim(), - _ => return StatusCode::OK, - }; - - let conversation_id = activity - .conversation - .as_ref() - .and_then(|c| c.id.as_deref()) - .unwrap_or(""); - let conversation_type = activity - .conversation - .as_ref() - .and_then(|c| c.conversation_type.as_deref()) - .unwrap_or("personal"); - let service_url = activity.service_url.as_deref().unwrap_or(""); - let sender_id = activity - .from - .as_ref() - .and_then(|f| f.id.as_deref()) - .unwrap_or(""); - let sender_name = activity - .from - .as_ref() - .and_then(|f| f.name.as_deref()) - .unwrap_or("Unknown"); - let activity_id = activity.id.as_deref().unwrap_or(""); - - // B3: Guard against empty service_url — replies will fail without it - if service_url.is_empty() { - warn!("teams: activity missing service_url, cannot route replies"); - return StatusCode::OK; - } - - let event = GatewayEvent::new( - "teams", - ChannelInfo { - id: conversation_id.to_string(), - channel_type: conversation_type.to_string(), - thread_id: None, // Teams conversations don't have sub-threads in the same way - }, - 
SenderInfo { - id: sender_id.to_string(), - name: sender_name.to_string(), - display_name: sender_name.to_string(), - is_bot: false, - }, - text, - activity_id, - vec![], // Teams @mentions parsing deferred to future PR - ); - - // Store service_url for reply routing - state.teams_service_urls.lock().await.insert( - conversation_id.to_string(), - (service_url.to_string(), std::time::Instant::now()), - ); - - let json = serde_json::to_string(&event).unwrap(); - let tenant_id = activity.resolved_tenant_id().unwrap_or(""); - info!( - conversation = conversation_id, - sender = sender_name, - tenant = tenant_id, - service_url = service_url, - "teams → gateway" - ); - let _ = state.event_tx.send(json); - - StatusCode::OK -} - -// --- Reply handler --- - -pub async fn handle_reply( - reply: &GatewayReply, - teams: &TeamsAdapter, - service_urls: &tokio::sync::Mutex< - std::collections::HashMap, - >, -) { - // Reactions are not supported on Teams — silently ignore - if reply.command.as_deref() == Some("add_reaction") - || reply.command.as_deref() == Some("remove_reaction") - { - return; - } - - let service_url = { - let mut urls = service_urls.lock().await; - match urls.get_mut(&reply.channel.id) { - Some((url, ts)) => { - // Refresh timestamp on reply to prevent TTL expiry during active conversations - *ts = std::time::Instant::now(); - url.clone() - } - None => { - error!(conversation = %reply.channel.id, "teams: no service_url for conversation"); - return; - } - } - }; - - let reply_to_id = if reply.reply_to.is_empty() { - None - } else { - Some(reply.reply_to.as_str()) - }; - - info!(conversation = %reply.channel.id, "gateway → teams"); - match teams - .send_activity( - &service_url, - &reply.channel.id, - &reply.content.text, - reply_to_id, - ) - .await - { - Ok(id) => debug!(activity_id = %id, "teams activity sent"), - Err(e) => error!(error = %e, "teams send error"), - } -} - -#[cfg(test)] -mod tests { - use super::*; - - // --- ensure_trailing_slash --- - - #[test] 
- fn trailing_slash_adds_when_missing() { - assert_eq!( - ensure_trailing_slash("https://example.com"), - "https://example.com/" - ); - } - - #[test] - fn trailing_slash_keeps_when_present() { - assert_eq!( - ensure_trailing_slash("https://example.com/"), - "https://example.com/" - ); - } - - #[test] - fn trailing_slash_empty_string() { - assert_eq!(ensure_trailing_slash(""), "/"); - } - - // --- check_tenant --- - - fn make_config(tenants: Vec<&str>) -> TeamsConfig { - TeamsConfig { - app_id: "test-app".into(), - app_secret: "test-secret".into(), - oauth_endpoint: "https://example.com/token".into(), - openid_metadata: "https://example.com/openid".into(), - allowed_tenants: tenants.into_iter().map(|s| s.to_string()).collect(), - } - } - - fn make_test_state() -> Arc { - let (event_tx, _rx) = tokio::sync::broadcast::channel(16); - - Arc::new(crate::AppState { - telegram_bot_token: None, - telegram_secret_token: None, - telegram_rich_messages: false, - line_channel_secret: None, - line_access_token: None, - teams: Some(TeamsAdapter::new(make_config(vec![]))), - teams_service_urls: tokio::sync::Mutex::new(std::collections::HashMap::new()), - feishu: None, - google_chat: None, - wecom: None, - ws_token: None, - event_tx, - reply_token_cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())), - line_webhook_semaphore: Arc::new(tokio::sync::Semaphore::new(crate::LINE_WEBHOOK_CONCURRENCY_MAX)), - client: reqwest::Client::new(), - }) - } - - fn make_activity_with_tenant(tenant_id: Option<&str>) -> Activity { - Activity { - activity_type: "message".into(), - id: Some("act1".into()), - timestamp: None, - service_url: Some("https://smba.trafficmanager.net/".into()), - channel_id: Some("msteams".into()), - from: None, - conversation: None, - text: Some("hello".into()), - tenant: tenant_id.map(|id| TenantInfo { - id: Some(id.into()), - }), - channel_data: None, - } - } - - // --- webhook body limit --- - - #[tokio::test] - async fn 
webhook_rejects_oversized_body_before_auth() { - let status = webhook( - State(make_test_state()), - HeaderMap::new(), - "x".repeat(WEBHOOK_BODY_LIMIT + 1), - ) - .await; - - assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE); - } - - #[tokio::test] - async fn webhook_allows_body_at_limit_to_reach_auth() { - let status = webhook( - State(make_test_state()), - HeaderMap::new(), - "x".repeat(WEBHOOK_BODY_LIMIT), - ) - .await; - - assert_eq!(status, StatusCode::UNAUTHORIZED); - } - - #[test] - fn tenant_allowed_when_list_empty() { - let adapter = TeamsAdapter::new(make_config(vec![])); - let activity = make_activity_with_tenant(Some("any-tenant")); - assert!(adapter.check_tenant(&activity)); - } - - #[test] - fn tenant_allowed_when_in_list() { - let adapter = TeamsAdapter::new(make_config(vec!["tenant-a", "tenant-b"])); - let activity = make_activity_with_tenant(Some("tenant-b")); - assert!(adapter.check_tenant(&activity)); - } - - #[test] - fn tenant_rejected_when_not_in_list() { - let adapter = TeamsAdapter::new(make_config(vec!["tenant-a"])); - let activity = make_activity_with_tenant(Some("tenant-x")); - assert!(!adapter.check_tenant(&activity)); - } - - #[test] - fn tenant_rejected_when_no_tenant_info() { - let adapter = TeamsAdapter::new(make_config(vec!["tenant-a"])); - let activity = make_activity_with_tenant(None); - assert!(!adapter.check_tenant(&activity)); - } - - #[test] - fn tenant_allowed_when_no_tenant_and_empty_list() { - let adapter = TeamsAdapter::new(make_config(vec![])); - let activity = make_activity_with_tenant(None); - assert!(adapter.check_tenant(&activity)); - } - - // --- resolved_tenant_id --- - - #[test] - fn resolved_tenant_falls_back_to_channel_data() { - // Teams personal/channel webhooks put tenant in channelData, not top-level - let json = r#"{ - "type": "message", - "channelData": {"tenant": {"id": "from-channel-data"}} - }"#; - let activity: Activity = serde_json::from_str(json).unwrap(); - assert_eq!(activity.resolved_tenant_id(), 
Some("from-channel-data")); - } - - #[test] - fn resolved_tenant_prefers_top_level_over_channel_data() { - let json = r#"{ - "type": "message", - "tenant": {"id": "top-level"}, - "channelData": {"tenant": {"id": "from-channel-data"}} - }"#; - let activity: Activity = serde_json::from_str(json).unwrap(); - assert_eq!(activity.resolved_tenant_id(), Some("top-level")); - } - - #[test] - fn resolved_tenant_falls_back_to_conversation_tenant_id() { - let json = r#"{ - "type": "message", - "conversation": {"id": "c1", "tenantId": "from-conversation"} - }"#; - let activity: Activity = serde_json::from_str(json).unwrap(); - assert_eq!(activity.resolved_tenant_id(), Some("from-conversation")); - } - - #[test] - fn resolved_tenant_returns_none_when_absent() { - let json = r#"{"type": "message"}"#; - let activity: Activity = serde_json::from_str(json).unwrap(); - assert_eq!(activity.resolved_tenant_id(), None); - } - - // --- validate_jwt error paths --- - - #[tokio::test] - async fn jwt_rejects_missing_bearer_prefix() { - let adapter = TeamsAdapter::new(make_config(vec![])); - let activity = make_activity_with_tenant(Some("t1")); - let result = adapter.validate_jwt("NotBearer xyz", &activity).await; - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Bearer")); - } - - #[tokio::test] - async fn jwt_rejects_empty_bearer() { - let adapter = TeamsAdapter::new(make_config(vec![])); - let activity = make_activity_with_tenant(Some("t1")); - let result = adapter.validate_jwt("Bearer ", &activity).await; - assert!(result.is_err()); - } - - #[tokio::test] - async fn jwt_rejects_garbage_token() { - let adapter = TeamsAdapter::new(make_config(vec![])); - let activity = make_activity_with_tenant(Some("t1")); - let result = adapter.validate_jwt("Bearer not.a.valid.jwt", &activity).await; - assert!(result.is_err()); - } - - // --- Activity deserialization --- - - #[test] - fn deserialize_minimal_activity() { - let json = r#"{"type": "message"}"#; - let activity: 
Activity = serde_json::from_str(json).unwrap(); - assert_eq!(activity.activity_type, "message"); - assert!(activity.text.is_none()); - assert!(activity.from.is_none()); - } - - #[test] - fn deserialize_full_activity() { - let json = r#"{ - "type": "message", - "id": "act123", - "serviceUrl": "https://smba.trafficmanager.net/", - "channelId": "msteams", - "from": {"id": "user1", "name": "Alice", "aadObjectId": "aad-123"}, - "conversation": {"id": "conv1", "conversationType": "personal", "isGroup": false}, - "text": "hello bot", - "tenant": {"id": "tenant-abc"} - }"#; - let activity: Activity = serde_json::from_str(json).unwrap(); - assert_eq!(activity.activity_type, "message"); - assert_eq!(activity.text.as_deref(), Some("hello bot")); - assert_eq!( - activity.from.as_ref().unwrap().name.as_deref(), - Some("Alice") - ); - assert_eq!( - activity.tenant.as_ref().unwrap().id.as_deref(), - Some("tenant-abc") - ); - } - - #[test] - fn deserialize_non_message_activity() { - let json = r#"{"type": "conversationUpdate"}"#; - let activity: Activity = serde_json::from_str(json).unwrap(); - assert_eq!(activity.activity_type, "conversationUpdate"); - } - - #[test] - fn deserialize_invalid_json_fails() { - let result = serde_json::from_str::("not json"); - assert!(result.is_err()); - } - - // --- TeamsConfig::from_env --- - - #[test] - fn config_from_env_returns_none_without_vars() { - // Ensure the env vars are not set (they shouldn't be in test) - std::env::remove_var("TEAMS_APP_ID"); - std::env::remove_var("TEAMS_APP_SECRET"); - assert!(TeamsConfig::from_env().is_none()); - } -} diff --git a/gateway/src/adapters/telegram.rs b/gateway/src/adapters/telegram.rs deleted file mode 100644 index 60a98bd06..000000000 --- a/gateway/src/adapters/telegram.rs +++ /dev/null @@ -1,782 +0,0 @@ -use crate::media::{resize_and_compress, MediaKind, AUDIO_MAX_DOWNLOAD, FILE_MAX_DOWNLOAD, IMAGE_MAX_DOWNLOAD}; -use crate::schema::*; -use crate::store; -use axum::extract::State; -use axum::Json; 
-use serde::Deserialize; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::Mutex; -use tracing::{error, info, warn}; - -/// Base URL for Telegram Bot API. Extracted as constant for consistency -/// with LINE's `LINE_API_BASE` and to enable future mock testing. -pub const TELEGRAM_API_BASE: &str = "https://api.telegram.org"; - -// --- Telegram types --- - -#[derive(Debug, Deserialize)] -pub struct TelegramUpdate { - message: Option, -} - -#[derive(Debug, Deserialize)] -struct TelegramMessage { - message_id: i64, - message_thread_id: Option, - chat: TelegramChat, - from: Option, - text: Option, - caption: Option, - #[serde(default)] - entities: Vec, - #[serde(default)] - caption_entities: Vec, - #[serde(default)] - photo: Vec, - document: Option, - voice: Option, - audio: Option, -} - -#[derive(Debug, Deserialize)] -struct TelegramPhoto { - file_id: String, - width: u32, - height: u32, -} - -#[derive(Debug, Deserialize)] -struct TelegramDocument { - file_id: String, - file_name: Option, - mime_type: Option, -} - -#[derive(Debug, Deserialize)] -struct TelegramVoice { - file_id: String, - #[allow(dead_code)] // TODO: use for Content-Type hint - mime_type: Option, -} - -#[derive(Debug, Deserialize)] -struct TelegramAudio { - file_id: String, - #[allow(dead_code)] // TODO: use for filename - file_name: Option, - #[allow(dead_code)] // TODO: use for Content-Type hint - mime_type: Option, -} - -#[derive(Debug, Deserialize)] -struct TelegramEntity { - #[serde(rename = "type")] - entity_type: String, - offset: usize, - length: usize, -} - -#[derive(Debug, Deserialize)] -struct TelegramChat { - id: i64, - #[serde(rename = "type")] - chat_type: String, - #[allow(dead_code)] - is_forum: Option, -} - -#[derive(Debug, Deserialize)] -struct TelegramUser { - id: i64, - first_name: String, - last_name: Option, - username: Option, - is_bot: bool, -} - -// --- Webhook handler --- - -pub async fn webhook( - State(state): State>, - headers: axum::http::HeaderMap, - 
Json(update): Json, -) -> axum::http::StatusCode { - if let Some(ref expected) = state.telegram_secret_token { - let provided = headers - .get("x-telegram-bot-api-secret-token") - .and_then(|v| v.to_str().ok()); - if provided != Some(expected.as_str()) { - warn!("webhook rejected: invalid or missing secret_token"); - return axum::http::StatusCode::UNAUTHORIZED; - } - } - - let Some(msg) = update.message else { - return axum::http::StatusCode::OK; - }; - let is_photo = !msg.photo.is_empty(); - let is_document = msg.document.is_some(); - let is_voice = msg.voice.is_some(); - let is_audio = msg.audio.is_some(); - let text = msg.text.as_deref().or(msg.caption.as_deref()).unwrap_or(""); - - if text.trim().is_empty() && !is_photo && !is_document && !is_voice && !is_audio { - return axum::http::StatusCode::OK; - } - - let mut attachments = Vec::new(); - if is_photo || is_document || is_voice || is_audio { - if let Some(ref token) = state.telegram_bot_token { - let client = &state.client; - if is_photo { - if let Some(largest) = msg.photo.iter().max_by_key(|p| p.width * p.height) { - if let Some(att) = - download_telegram_media(client, token, &largest.file_id, MediaKind::Image).await - { - attachments.push(att); - } - } - } else if let Some(doc) = msg.document { - let file_name = doc.file_name.unwrap_or_else(|| "unknown.txt".to_string()); - let mime_type = doc.mime_type.unwrap_or_else(|| "text/plain".to_string()); - if let Some(att) = - download_telegram_document(client, token, &doc.file_id, &file_name, &mime_type).await - { - attachments.push(att); - } - } else if let Some(voice) = msg.voice { - if let Some(att) = download_telegram_media(client, token, &voice.file_id, MediaKind::Audio).await { - attachments.push(att); - } - } else if let Some(audio) = msg.audio { - if let Some(att) = download_telegram_media(client, token, &audio.file_id, MediaKind::Audio).await { - attachments.push(att); - } - } - } - } - - let from = msg.from.as_ref(); - let sender_name = from - 
.and_then(|u| u.username.as_deref()) - .unwrap_or("unknown"); - let display_name = from - .map(|u| { - let mut n = u.first_name.clone(); - if let Some(last) = &u.last_name { - n.push(' '); - n.push_str(last); - } - n - }) - .unwrap_or_else(|| "Unknown".into()); - - let mentions: Vec = msg - .entities - .iter() - .chain(msg.caption_entities.iter()) - .filter(|e| e.entity_type == "mention") - .filter_map(|e| { - text.get(e.offset..e.offset + e.length) - .map(|s| s.trim_start_matches('@').to_string()) - }) - .collect(); - - let mut event = GatewayEvent::new( - "telegram", - ChannelInfo { - id: msg.chat.id.to_string(), - channel_type: msg.chat.chat_type.clone(), - thread_id: msg.message_thread_id.map(|id| id.to_string()), - }, - SenderInfo { - id: from.map(|u| u.id.to_string()).unwrap_or_default(), - name: sender_name.into(), - display_name, - is_bot: from.map(|u| u.is_bot).unwrap_or(false), - }, - text, - &msg.message_id.to_string(), - mentions, - ); - event.content.attachments = attachments; - - // Guard: skip empty events (no text + no attachments) - if event.content.text.trim().is_empty() && event.content.attachments.is_empty() { - return axum::http::StatusCode::OK; - } - - let json = serde_json::to_string(&event).unwrap(); - info!(chat_id = %msg.chat.id, sender = %sender_name, "telegram → gateway"); - let _ = state.event_tx.send(json); - axum::http::StatusCode::OK -} - -/// Split text into chunks of at most `limit` characters, breaking at newlines when possible. 
-fn chunk_text(text: &str, limit: usize) -> Vec { - if text.chars().count() <= limit { - return vec![text.to_string()]; - } - let mut chunks = Vec::new(); - let mut current = String::new(); - for line in text.lines() { - if !current.is_empty() && current.chars().count() + line.chars().count() + 1 > limit { - chunks.push(std::mem::take(&mut current)); - } - if !current.is_empty() { - current.push('\n'); - } - if line.chars().count() > limit { - // Line itself exceeds limit — hard split - for ch in line.chars() { - current.push(ch); - if current.chars().count() >= limit { - chunks.push(std::mem::take(&mut current)); - } - } - } else { - current.push_str(line); - } - } - if !current.is_empty() { - chunks.push(current); - } - chunks -} - -fn is_markdown_parse_error(description: &str) -> bool { - let desc_lower = description.to_lowercase(); - desc_lower.contains("can't find end") - || desc_lower.contains("can't parse") - || desc_lower.contains("parse entities") -} - -/// Returns true if the content is complex enough to benefit from sendRichMessage. -/// -/// Design decisions: -/// - We classify at the adapter layer (not agent) so agents don't need prompt changes. -/// - Conservative: only route to rich when legacy sendMessage would visibly break. -/// - False positives are acceptable (rich renders simple text fine too), but we avoid -/// unnecessary API switches for plain prose to reduce risk surface. -/// - LaTeX and blockquotes are intentionally omitted for now (Phase 2). -fn is_complex_markdown(text: &str) -> bool { - // 🟡 Code blocks intentionally NOT routed to rich — sendMessage preserves - // syntax highlighting (language header + copy button) which RichBlockPreformatted lacks. - - // sendMessage hard limit is 4096 chars. Rich messages support 32768. - if text.chars().count() > 4096 { - return true; - } - text.lines().any(|line| { - let trimmed = line.trim_start(); - // ATX headings (h1-h6): sendMessage has zero heading support. 
- if trimmed.starts_with("# ") - || trimmed.starts_with("## ") - || trimmed.starts_with("### ") - || trimmed.starts_with("#### ") - || trimmed.starts_with("##### ") - || trimmed.starts_with("###### ") - { - return true; - } - // GFM table separator row detection. - if trimmed.starts_with('|') && trimmed.ends_with('|') { - let inner = &trimmed[1..trimmed.len() - 1]; - if inner.split('|').all(|cell| { - let c = cell.trim().trim_matches(':'); - !c.is_empty() && c.chars().all(|ch| ch == '-') - }) { - return true; - } - } - false - }) -} - -/// Send a rich message via Bot API 10.1 sendRichMessage. -/// -/// Design: we pass agent markdown directly via InputRichMessage.markdown. -/// Rich Markdown is GFM-compatible, so no conversion layer is needed. -/// The API handles rendering (tables, syntax highlighting, headings, etc.) -async fn send_rich_message( - client: &reqwest::Client, - bot_token: &str, - chat_id: &str, - thread_id: &Option, - text: &str, -) -> Result { - let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/sendRichMessage"); - let body = serde_json::json!({ - "chat_id": chat_id, - "message_thread_id": thread_id, - "rich_message": { "markdown": text }, - }); - let resp = client.post(&url).json(&body).send().await.map_err(|e| e.to_string())?; - let json: serde_json::Value = resp.json().await.map_err(|e| e.to_string())?; - if json["ok"].as_bool() == Some(true) { - Ok(json) - } else { - Err(json["description"].as_str().unwrap_or("unknown error").to_string()) - } -} - -/// Stream a partial rich message via sendRichMessageDraft. -/// -/// Design: ephemeral 30-second preview. Caller must follow up with -/// sendRichMessage to persist. Same draft_id = animated transition. -/// Wired but unused until gateway streaming infrastructure integrates. 
-#[allow(dead_code)] -async fn send_rich_message_draft( - client: &reqwest::Client, - bot_token: &str, - chat_id: &str, - thread_id: &Option, - draft_id: i64, - text: &str, -) -> Result<(), String> { - let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/sendRichMessageDraft"); - let body = serde_json::json!({ - "chat_id": chat_id, - "message_thread_id": thread_id, - "draft_id": draft_id, - "rich_message": if text.contains(", - reaction_state: &Arc>>>, - rich_messages: bool, -) { - // Handle create_topic command - if reply.command.as_deref() == Some("create_topic") { - let req_id = reply.request_id.clone().unwrap_or_default(); - info!(chat_id = %reply.channel.id, "creating forum topic"); - let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/createForumTopic"); - let resp = client - .post(&url) - .json(&serde_json::json!({"chat_id": reply.channel.id, "name": reply.content.text})) - .send() - .await; - let gw_resp = match resp { - Ok(r) => { - let body: serde_json::Value = r.json().await.unwrap_or_default(); - if body["ok"].as_bool() == Some(true) { - let tid = body["result"]["message_thread_id"] - .as_i64() - .map(|id| id.to_string()); - info!(thread_id = ?tid, "forum topic created"); - GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id, - success: true, - thread_id: tid, - message_id: None, - error: None, - } - } else { - let err = body["description"] - .as_str() - .unwrap_or("unknown error") - .to_string(); - warn!(err = %err, "createForumTopic failed"); - GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id, - success: false, - thread_id: None, - message_id: None, - error: Some(err), - } - } - } - Err(e) => GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id, - success: false, - thread_id: None, - message_id: None, - error: Some(e.to_string()), - }, - }; - let json = serde_json::to_string(&gw_resp).unwrap(); - let _ = event_tx.send(json); - return; - } - - // 
Handle edit_message - if reply.command.as_deref() == Some("edit_message") { - if reply.reply_to == "draft" { - // Dummy "draft" ref from streaming without placeholder. - if rich_messages { - // Skip short updates — let thinking animation show until meaningful content arrives - if reply.content.text.len() < 30 { - return; - } - let text = if reply.content.text.len() > 32768 { - &reply.content.text[..reply.content.text.floor_char_boundary(32768)] - } else { - &reply.content.text - }; - // Combine channel + thread to avoid draft_id collision in forum topics - let chan: i64 = reply.channel.id.parse::().unwrap_or(1).abs(); - let tid: i64 = reply.channel.thread_id.as_deref().and_then(|t| t.parse::().ok()).unwrap_or(0).abs(); - let draft_id: i64 = (chan.wrapping_add(tid)) % 1_000_000 + 1; - let _ = send_rich_message_draft(client, bot_token, &reply.channel.id, &reply.channel.thread_id, draft_id, text).await; - } - // else: rich_messages=false with dummy ref — silently drop (no real msg to edit) - return; - } - // Real message_id — perform actual editMessageText (legacy streaming path) - let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/editMessageText"); - let _ = client - .post(&url) - .json(&serde_json::json!({ - "chat_id": reply.channel.id, - "message_id": reply.reply_to, - "text": &reply.content.text, - "parse_mode": "Markdown", - })) - .send() - .await; - return; - } - - // Handle add_reaction / remove_reaction - if reply.command.as_deref() == Some("add_reaction") - || reply.command.as_deref() == Some("remove_reaction") - { - // Send thinking draft on reaction changes — reflects agent state - if rich_messages && reply.command.as_deref() == Some("add_reaction") { - let thinking_text = match reply.content.text.as_str() { - "👀" => Some("Looking..."), - "🤔" => Some("Thinking..."), - "👨\u{200d}💻" => Some("Writing code..."), - "🔥" => Some("Working..."), - "⚡" => Some("Running tools..."), - _ => None, - }; - if let Some(text) = thinking_text { - let chan: i64 = 
reply.channel.id.parse::().unwrap_or(1).abs(); - let tid: i64 = reply.channel.thread_id.as_deref().and_then(|t| t.parse::().ok()).unwrap_or(0).abs(); - let draft_id: i64 = (chan.wrapping_add(tid)) % 1_000_000 + 1; - let _ = send_rich_message_draft( - client, bot_token, &reply.channel.id, &reply.channel.thread_id, draft_id, text, - ).await; - } - } - - let msg_key = format!("{}:{}", reply.channel.id, reply.reply_to); - let emoji = &reply.content.text; - let tg_emoji = match emoji.as_str() { - "🆗" => "👍", - other => other, - }; - let is_add = reply.command.as_deref() == Some("add_reaction"); - { - let mut reactions = reaction_state.lock().await; - let set = reactions.entry(msg_key.clone()).or_default(); - if is_add { - if !set.contains(&tg_emoji.to_string()) { - set.push(tg_emoji.to_string()); - } - } else { - set.retain(|e| e != tg_emoji); - } - } - let current: Vec = { - let reactions = reaction_state.lock().await; - reactions - .get(&msg_key) - .map(|v| { - v.iter() - .map(|e| serde_json::json!({"type": "emoji", "emoji": e})) - .collect() - }) - .unwrap_or_default() - }; - let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/setMessageReaction"); - let _ = client - .post(&url) - .json(&serde_json::json!({ - "chat_id": reply.channel.id, - "message_id": reply.reply_to, - "reaction": current, - })) - .send() - .await - .map_err(|e| error!("telegram reaction error: {e}")); - return; - } - - // Normal send_message - info!( - chat_id = %reply.channel.id, - thread_id = ?reply.channel.thread_id, - "gateway → telegram" - ); - - // --- Rich Message routing --- - // Design: try sendRichMessage first for complex content. On ANY failure - // (unsupported client, API version mismatch, network error), fall back to - // legacy sendMessage (chunked). This ensures zero-downtime rollout. - if rich_messages && is_complex_markdown(&reply.content.text) { - // Bot API limit: 32768 UTF-8 characters (not bytes). 
- let text = &reply.content.text; - let rich_text: String = if text.chars().count() > 32768 { - text.chars().take(32768).collect() - } else { - text.to_string() - }; - match send_rich_message(client, bot_token, &reply.channel.id, &reply.channel.thread_id, &rich_text).await { - Ok(_) => return, - Err(e) => warn!("sendRichMessage failed ({e}), falling back to sendMessage"), - } - } - - // Legacy sendMessage — chunk at 4096 chars to avoid rejection. - let chunks = chunk_text(&reply.content.text, 4096); - for chunk in &chunks { - let url = format!("{TELEGRAM_API_BASE}/bot{bot_token}/sendMessage"); - let resp = client - .post(&url) - .json(&serde_json::json!({ - "chat_id": reply.channel.id, - "text": chunk, - "message_thread_id": reply.channel.thread_id, - "parse_mode": "Markdown", - })) - .send() - .await; - - match resp { - Ok(r) => { - let body: serde_json::Value = r.json().await.unwrap_or_default(); - if body["ok"].as_bool() != Some(true) { - let desc = body["description"].as_str().unwrap_or("unknown error"); - if is_markdown_parse_error(desc) { - warn!("Markdown send failed: {desc}, retrying as plain text"); - match client - .post(&url) - .json(&serde_json::json!({ - "chat_id": reply.channel.id, - "text": chunk, - "message_thread_id": reply.channel.thread_id, - })) - .send() - .await - { - Ok(retry_r) => { - let retry_body: serde_json::Value = - retry_r.json().await.unwrap_or_default(); - if retry_body["ok"].as_bool() != Some(true) { - error!( - "telegram plain-text retry failed: {}", - retry_body["description"] - .as_str() - .unwrap_or("unknown error") - ); - } - } - Err(e) => error!("telegram plain-text send error: {e}"), - } - } else { - error!("telegram send failed: {desc}"); - } - } - } - Err(e) => error!("telegram send error: {e}"), - } - } -} - -/// Download media from Telegram via getFile → store to filesystem (colocate mode). 
-async fn download_telegram_media( - client: &reqwest::Client, - bot_token: &str, - file_id: &str, - kind: MediaKind, -) -> Option { - let get_file_url = format!("{TELEGRAM_API_BASE}/bot{}/getFile", bot_token); - let resp = client.get(&get_file_url).query(&[("file_id", file_id)]).send().await.ok()?; - let body: serde_json::Value = resp.json().await.ok()?; - let file_path = body["result"]["file_path"].as_str()?; - - let download_url = format!("{TELEGRAM_API_BASE}/file/bot{}/{}", bot_token, file_path); - let resp = client.get(&download_url).send().await.ok()?; - if !resp.status().is_success() { - return None; - } - - let max_size = match kind { - MediaKind::Image => IMAGE_MAX_DOWNLOAD, - MediaKind::Audio => AUDIO_MAX_DOWNLOAD, - }; - - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > max_size { - warn!(file_id, size, kind = ?kind, "Telegram media Content-Length exceeds limit"); - return None; - } - } - } - - let default_mime = match kind { - MediaKind::Image => "image/jpeg", - MediaKind::Audio => "audio/ogg", - }; - let content_type = resp - .headers() - .get(reqwest::header::CONTENT_TYPE) - .and_then(|h| h.to_str().ok()) - .unwrap_or(default_mime) - .to_string(); - - let bytes = resp.bytes().await.ok()?; - if bytes.len() as u64 > max_size { - warn!(file_id, size = bytes.len(), kind = ?kind, "Telegram media exceeds limit"); - return None; - } - - let (data_bytes, mime) = match kind { - MediaKind::Image => match resize_and_compress(&bytes) { - Ok((c, m)) => (c, m), - Err(e) => { - error!(err = %e, "Telegram image processing failed"); - return None; - } - }, - MediaKind::Audio => (bytes.to_vec(), content_type), - }; - - // Store to filesystem instead of base64 encoding - let path = store::store_media(&data_bytes).await?; - let att_type = match kind { - MediaKind::Image => "image", - MediaKind::Audio => "audio", - }; - info!(file_id, size = data_bytes.len(), kind = ?kind, 
"Telegram media stored"); - - Some(Attachment { - attachment_type: att_type.into(), - filename: format!("{}.{}", file_id, match kind { - MediaKind::Image => "jpg", - MediaKind::Audio => crate::media::audio_extension(&mime), - }), - mime_type: mime, - data: String::new(), // No base64 — using file path - size: data_bytes.len() as u64, - path: Some(path), - }) -} - -/// Download text document from Telegram → store to filesystem. -async fn download_telegram_document( - client: &reqwest::Client, - bot_token: &str, - file_id: &str, - file_name: &str, - mime_type: &str, -) -> Option { - if !crate::media::is_text_extension(file_name) { - tracing::debug!(file_name, "skipping non-text file attachment"); - return None; - } - - let get_file_url = format!("{TELEGRAM_API_BASE}/bot{}/getFile", bot_token); - let resp = client.get(&get_file_url).query(&[("file_id", file_id)]).send().await.ok()?; - let body: serde_json::Value = resp.json().await.ok()?; - let file_path = body["result"]["file_path"].as_str()?; - - let download_url = format!("{TELEGRAM_API_BASE}/file/bot{}/{}", bot_token, file_path); - let resp = client.get(&download_url).send().await.ok()?; - if !resp.status().is_success() { - return None; - } - - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > FILE_MAX_DOWNLOAD { - warn!(file_id, size, "Telegram document Content-Length exceeds limit"); - return None; - } - } - } - - let bytes = resp.bytes().await.ok()?; - if bytes.len() as u64 > FILE_MAX_DOWNLOAD { - warn!(file_id, size = bytes.len(), "Telegram document exceeds limit"); - return None; - } - - // Validate UTF-8 — reject binary files - if String::from_utf8(bytes.to_vec()).is_err() { - warn!(file_id, file_name, "Telegram document is not valid UTF-8, skipping"); - return None; - } - - let path = store::store_media(&bytes).await?; - info!(file_id, file_name, size = bytes.len(), "Telegram document stored"); - - Some(Attachment { 
- attachment_type: "text_file".into(), - filename: file_name.to_string(), - mime_type: mime_type.to_string(), - data: String::new(), - size: bytes.len() as u64, - path: Some(path), - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_is_markdown_parse_error() { - assert!(is_markdown_parse_error("Bad Request: can't find end of italic entity at byte offset 37")); - assert!(is_markdown_parse_error("Bad Request: can't parse entities: Can't find end of bold entity")); - assert!(is_markdown_parse_error("can't parse entities in message text")); - assert!(!is_markdown_parse_error("Unauthorized")); - assert!(!is_markdown_parse_error("Bad Request: chat not found")); - } - - #[test] - fn test_is_complex_markdown() { - // Tables - assert!(is_complex_markdown("| Col1 | Col2 |\n|---|---|\n| a | b |")); - assert!(is_complex_markdown("| Col1 | Col2 |\n| :--- | ---: |\n| a | b |")); - assert!(is_complex_markdown("| A | B |\n| :---: | :---: |\n| x | y |")); - // Code blocks — intentionally NOT complex (preserves syntax highlighting on legacy path) - assert!(!is_complex_markdown("```rust\nfn main() {}\n```")); - assert!(!is_complex_markdown("~~~\ncode\n~~~")); - // Headings - assert!(is_complex_markdown("# Heading\n\nSome text")); - assert!(is_complex_markdown("## Heading 2 at start")); - assert!(is_complex_markdown("### Heading 3 at start")); - assert!(is_complex_markdown("#### Heading 4")); - assert!(is_complex_markdown("text\n##### Heading 5")); - assert!(is_complex_markdown(" ## Indented heading")); - // Size - assert!(is_complex_markdown(&"x".repeat(4097))); - // Negatives - assert!(!is_complex_markdown("Hello world")); - assert!(!is_complex_markdown("*bold* and _italic_")); - assert!(!is_complex_markdown("#hashtag no space")); - assert!(!is_complex_markdown("| just | pipes |")); - } -} diff --git a/gateway/src/adapters/wecom.rs b/gateway/src/adapters/wecom.rs deleted file mode 100644 index e3e97ff17..000000000 --- a/gateway/src/adapters/wecom.rs +++ 
/dev/null @@ -1,1654 +0,0 @@ -use anyhow::Result; -use axum::extract::State; -use std::sync::Arc; -use tokio::sync::RwLock; -use tracing::{info, warn}; - -pub struct WecomConfig { - pub corp_id: String, - pub agent_id: String, - pub secret: String, - pub token: String, - pub encoding_aes_key: String, - pub webhook_path: String, - pub streaming_enabled: bool, - pub debounce_secs: u64, -} - -impl WecomConfig { - pub fn from_env() -> Option { - Self::from_reader(|k| std::env::var(k).ok()) - } - - /// Build config from an arbitrary string reader. Tests use this with a - /// HashMap so they don't mutate process-wide environment variables — - /// `env::set_var` races other tests under cargo's parallel runner. - fn from_reader Option>(read: F) -> Option { - let corp_id = read("WECOM_CORP_ID")?; - let secret = read("WECOM_SECRET")?; - let token = read("WECOM_TOKEN")?; - let encoding_aes_key = read("WECOM_ENCODING_AES_KEY")?; - let agent_id = read("WECOM_AGENT_ID")?; - if agent_id.parse::().is_err() { - warn!("WECOM_AGENT_ID must be a numeric value, got '{}'", agent_id); - return None; - } - let webhook_path = read("WECOM_WEBHOOK_PATH").unwrap_or_else(|| "/webhook/wecom".into()); - // Streaming opts-in: WeCom callback mode has no edit-message API, so - // streaming is implemented via thinking-placeholder + recall + resend, - // which causes a brief client flicker. Default off; set to true only if - // the UX tradeoff is acceptable. 
- let streaming_enabled = read("WECOM_STREAMING_ENABLED") - .map(|v| v == "true" || v == "1") - .unwrap_or(false); - let debounce_secs = read("WECOM_DEBOUNCE_SECS") - .and_then(|v| v.parse::().ok()) - .unwrap_or(3); - - if encoding_aes_key.len() != 43 { - warn!("WECOM_ENCODING_AES_KEY must be 43 characters, got {}", encoding_aes_key.len()); - return None; - } - - info!( - corp_id = %corp_id, - agent_id = %agent_id, - streaming_enabled, - debounce_secs, - "wecom adapter configured" - ); - Some(Self { - corp_id, - agent_id, - secret, - token, - encoding_aes_key, - webhook_path, - streaming_enabled, - debounce_secs, - }) - } -} - -fn decode_aes_key(encoding_aes_key: &str) -> anyhow::Result> { - use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig}; - use base64::Engine; - // WeCom's EncodingAESKey is 43 base64 chars without trailing padding. - // Append "=" to make it a 44-char standard base64 string before decoding. - // Indifferent + allow_trailing_bits accommodate WeCom's non-standard - // encoding: the 43rd char's last 2 bits are not part of the output and - // must be ignored rather than rejected. 
- let padded = format!("{}=", encoding_aes_key); - let config = GeneralPurposeConfig::new() - .with_decode_padding_mode(DecodePaddingMode::Indifferent) - .with_decode_allow_trailing_bits(true); - let engine = GeneralPurpose::new(&base64::alphabet::STANDARD, config); - let key = engine - .decode(&padded) - .map_err(|e| anyhow::anyhow!("encoding_aes_key base64 decode failed: {e}"))?; - anyhow::ensure!( - key.len() == 32, - "encoding_aes_key must decode to 32 bytes, got {}", - key.len() - ); - Ok(key) -} - -fn compute_signature(token: &str, timestamp: &str, nonce: &str, encrypt: &str) -> String { - use sha1::Digest; - let mut parts = [token, timestamp, nonce, encrypt]; - parts.sort_unstable(); - let joined: String = parts.concat(); - let hash = sha1::Sha1::digest(joined.as_bytes()); - format!("{:x}", hash) -} - -fn verify_signature( - token: &str, - timestamp: &str, - nonce: &str, - encrypt: &str, - expected: &str, -) -> bool { - let computed = compute_signature(token, timestamp, nonce, encrypt); - tracing::debug!( - computed = %computed, - expected = %expected, - token_len = token.len(), - encrypt_len = encrypt.len(), - "signature comparison" - ); - subtle::ConstantTimeEq::ct_eq(computed.as_bytes(), expected.as_bytes()).into() -} - -fn decrypt_message( - encoding_aes_key: &str, - encrypted: &str, - expected_corp_id: &str, -) -> anyhow::Result { - use aes::cipher::{BlockDecryptMut, KeyIvInit}; - use base64::Engine; - - let key = decode_aes_key(encoding_aes_key)?; - let iv = &key[..16]; - - let cipher_bytes = base64::engine::general_purpose::STANDARD - .decode(encrypted) - .map_err(|e| anyhow::anyhow!("base64 decode failed: {e}"))?; - - if cipher_bytes.is_empty() || cipher_bytes.len() % 16 != 0 { - anyhow::bail!("ciphertext length {} not a multiple of 16", cipher_bytes.len()); - } - - type Aes256CbcDec = cbc::Decryptor; - let decryptor = Aes256CbcDec::new_from_slices(&key, iv) - .map_err(|e| anyhow::anyhow!("aes init failed: {e}"))?; - - let mut buf = 
cipher_bytes.to_vec(); - // WeCom uses PKCS7 with block_size=32, not 16. Decrypt without padding validation - // and strip padding manually. - let plaintext = decryptor - .decrypt_padded_mut::(&mut buf) - .map_err(|e| anyhow::anyhow!("aes decrypt failed: {e}"))?; - - // Strip WeCom PKCS7 padding (block_size=32): last byte indicates pad length (1-32) - let pad_byte = *plaintext.last().ok_or_else(|| anyhow::anyhow!("empty plaintext"))? as usize; - if pad_byte == 0 || pad_byte > 32 || pad_byte > plaintext.len() { - anyhow::bail!("invalid wecom padding value: {pad_byte}"); - } - let pad_start = plaintext.len() - pad_byte; - if !plaintext[pad_start..].iter().all(|&b| b as usize == pad_byte) { - anyhow::bail!("invalid PKCS#7 padding: not all padding bytes match"); - } - let plaintext = &plaintext[..pad_start]; - - // Plaintext structure: random(16) + msg_len(4, big-endian) + msg + corp_id - if plaintext.len() < 20 { - anyhow::bail!("decrypted payload too short"); - } - let msg_len = - u32::from_be_bytes([plaintext[16], plaintext[17], plaintext[18], plaintext[19]]) as usize; - if plaintext.len() < 20 + msg_len { - anyhow::bail!("msg_len exceeds payload size"); - } - let msg = &plaintext[20..20 + msg_len]; - let corp_id = &plaintext[20 + msg_len..]; - - let corp_id_str = - std::str::from_utf8(corp_id).map_err(|e| anyhow::anyhow!("corp_id not utf8: {e}"))?; - if corp_id_str != expected_corp_id { - anyhow::bail!("corp_id mismatch: expected {expected_corp_id}, got {corp_id_str}"); - } - - String::from_utf8(msg.to_vec()).map_err(|e| anyhow::anyhow!("message not utf8: {e}")) -} - -// --- Deduplication --- - -const DEDUPE_TTL_SECS: u64 = 30; -const DEDUPE_MAX_SIZE: usize = 10_000; - -struct DedupeCache { - entries: std::sync::Mutex>, -} - -impl DedupeCache { - fn new() -> Self { - Self { - entries: std::sync::Mutex::new(std::collections::HashMap::new()), - } - } - - fn check_and_insert(&self, msg_id: &str) -> bool { - let mut entries = self.entries.lock().unwrap_or_else(|e| 
e.into_inner()); - let now = std::time::Instant::now(); - - if entries.len() >= DEDUPE_MAX_SIZE { - entries.retain(|_, t| now.duration_since(*t).as_secs() < DEDUPE_TTL_SECS); - } - - if let Some(t) = entries.get(msg_id) { - if now.duration_since(*t).as_secs() < DEDUPE_TTL_SECS { - return false; - } - } - - entries.insert(msg_id.to_string(), now); - true - } -} - -// --- Token cache --- - -pub const WECOM_API_BASE: &str = "https://qyapi.weixin.qq.com"; -const TOKEN_REFRESH_MARGIN_SECS: u64 = 300; - -pub struct WecomTokenCache { - inner: RwLock>, - base_url: String, -} - -impl WecomTokenCache { - fn new() -> Self { - Self { - inner: RwLock::new(None), - base_url: WECOM_API_BASE.into(), - } - } - - #[cfg(test)] - fn with_base_url(base_url: String) -> Self { - Self { - inner: RwLock::new(None), - base_url, - } - } - - pub async fn get_token( - &self, - client: &reqwest::Client, - corp_id: &str, - secret: &str, - ) -> Result { - // Fast path: read lock - { - let guard = self.inner.read().await; - if let Some((ref token, created_at, expires_in)) = *guard { - let elapsed = created_at.elapsed().as_secs(); - if elapsed + TOKEN_REFRESH_MARGIN_SECS < expires_in { - return Ok(token.clone()); - } - } - } - - // Slow path: write lock + refresh - let mut guard = self.inner.write().await; - // Double-check after acquiring write lock - if let Some((ref token, created_at, expires_in)) = *guard { - let elapsed = created_at.elapsed().as_secs(); - if elapsed + TOKEN_REFRESH_MARGIN_SECS < expires_in { - return Ok(token.clone()); - } - } - - // WeCom's gettoken API requires `corpsecret` as a query parameter — the - // protocol mandates this, we can't move it to a header. Operators must - // configure their reverse proxy / load balancer to redact query strings - // on `/cgi-bin/gettoken` paths before logging access logs. We do not log - // this URL anywhere from the gateway side. 
- let url = format!( - "{}/cgi-bin/gettoken?corpid={}&corpsecret={}", - self.base_url, corp_id, secret - ); - let resp: serde_json::Value = client.get(&url).send().await?.json().await?; - - let errcode = resp["errcode"].as_i64().unwrap_or(-1); - if errcode != 0 { - anyhow::bail!( - "wecom gettoken failed: errcode={}, errmsg={}", - errcode, - resp["errmsg"] - ); - } - - let token = resp["access_token"] - .as_str() - .ok_or_else(|| anyhow::anyhow!("missing access_token in response"))? - .to_string(); - let expires_in = resp["expires_in"].as_u64().unwrap_or(7200); - - *guard = Some((token.clone(), std::time::Instant::now(), expires_in)); - Ok(token) - } - - pub async fn force_refresh( - &self, - client: &reqwest::Client, - corp_id: &str, - secret: &str, - ) -> Result { - let mut guard = self.inner.write().await; - *guard = None; - drop(guard); - self.get_token(client, corp_id, secret).await - } -} - -// --- Adapter --- - -struct PendingStream { - text_watch: tokio::sync::watch::Sender, -} - -type PendingMap = Arc>>; - -pub struct WecomAdapter { - pub config: WecomConfig, - pub token_cache: Arc, - client: reqwest::Client, - dedupe: DedupeCache, - pending_streams: PendingMap, -} - -impl WecomAdapter { - pub fn new(config: WecomConfig) -> Self { - Self { - token_cache: Arc::new(WecomTokenCache::new()), - client: reqwest::Client::new(), - dedupe: DedupeCache::new(), - pending_streams: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())), - config, - } - } - - - pub async fn handle_reply( - &self, - reply: &crate::schema::GatewayReply, - event_tx: &tokio::sync::broadcast::Sender, - ) { - if let Some(cmd) = reply.command.as_deref() { - match cmd { - "add_reaction" | "remove_reaction" | "create_topic" => { - info!(command = cmd, "wecom: ignoring unsupported command"); - return; - } - "edit_message" => { - self.handle_edit_message(reply); - return; - } - _ => {} - } - } - - let text = &reply.content.text; - if text.is_empty() { - return; - } - - let to_user = 
reply - .channel - .id - .rsplit(':') - .next() - .unwrap_or(&reply.channel.id); - - let has_pending = { - let pending = self.pending_streams.lock().unwrap_or_else(|e| e.into_inner()); - pending.contains_key(&reply.channel.id) - }; - let is_streaming_placeholder = reply.request_id.is_some() && !has_pending; - if is_streaming_placeholder { - // Optionally send a thinking placeholder. With streaming disabled - // (default), buffer chunks silently and send the consolidated text - // when the debounce settles — no recall/flicker. - let placeholder_id = if self.config.streaming_enabled { - info!(to_user = to_user, "wecom: sending thinking placeholder"); - match self.send_text(to_user, "⏳...").await { - Ok(id) => Some(id), - Err(e) => { - warn!("wecom send thinking failed: {e}"); - return; - } - } - } else { - None - }; - - let (text_tx, text_rx) = tokio::sync::watch::channel(String::new()); - { - let mut pending = self.pending_streams.lock().unwrap_or_else(|e| e.into_inner()); - pending.insert(reply.channel.id.clone(), PendingStream { - text_watch: text_tx, - }); - } - let client = self.client.clone(); - let token_cache = self.token_cache.clone(); - let corp_id = self.config.corp_id.clone(); - let secret = self.config.secret.clone(); - let agent_id = self.config.agent_id.clone(); - let thinking_id = placeholder_id.clone(); - let flush_to_user = to_user.to_string(); - let channel_id_clone = reply.channel.id.clone(); - let pending_clone = self.pending_streams.clone(); - let debounce_secs = self.config.debounce_secs; - tokio::spawn(async move { - let mut rx = text_rx; - let debounce = std::time::Duration::from_secs(debounce_secs); - let mut last_text = String::new(); - let max_idle = std::time::Duration::from_secs(300); - let started = std::time::Instant::now(); - loop { - match tokio::time::timeout(debounce, rx.changed()).await { - Ok(Ok(())) => { - last_text = rx.borrow().clone(); - } - Ok(Err(_)) => break, - Err(_) => { - if !last_text.is_empty() { - break; - } - if 
started.elapsed() > max_idle { - warn!("wecom: debounce task timed out after 5 minutes"); - break; - } - } - } - } - // Acquire pending lock first, then capture any late writes - // that landed between the loop break and now. Holding the - // lock blocks handle_reply from sending more chunks for this - // channel, so this read is the last writeable moment. Then - // remove the entry, which drops text_tx and closes the channel. - { - let mut pending = pending_clone.lock().unwrap_or_else(|e| e.into_inner()); - let final_text = rx.borrow().clone(); - if !final_text.is_empty() { - last_text = final_text; - } - pending.remove(&channel_id_clone); - } - if last_text.is_empty() { - return; - } - flush_thinking( - &client, &token_cache, &corp_id, &secret, &agent_id, - thinking_id.as_deref(), &flush_to_user, &last_text, - ).await; - }); - - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success: true, - thread_id: None, - message_id: placeholder_id, - error: None, - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - return; - } - - if has_pending { - // Re-check under lock: the debounce task may have removed the entry - // between our earlier read of `has_pending` and now. If it did, - // fall through to the direct-send path so the chunk isn't lost. 
- let appended = { - let pending = self.pending_streams.lock().unwrap_or_else(|e| e.into_inner()); - if let Some(stream) = pending.get(&reply.channel.id) { - let current = stream.text_watch.borrow().clone(); - let combined = if current.is_empty() { - text.to_string() - } else { - format!("{}\n{}", current, text) - }; - let _ = stream.text_watch.send(combined); - true - } else { - false - } - }; - if appended { - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success: true, - thread_id: None, - message_id: None, - error: None, - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - return; - } - // Pending entry was already removed (debounce flushed) — fall - // through to direct-send below so this chunk still reaches the user. - } - - info!(to_user = to_user, "wecom: sending reply"); - let chunks = split_text_lines(text, 2048); - let mut msg_id = None; - - for chunk in &chunks { - match self.send_text(to_user, chunk).await { - Ok(id) => { - if msg_id.is_none() { - msg_id = Some(id); - } - } - Err(e) => warn!("wecom send failed: {e}"), - } - } - - if let Some(ref req_id) = reply.request_id { - let resp = crate::schema::GatewayResponse { - schema: "openab.gateway.response.v1".into(), - request_id: req_id.clone(), - success: msg_id.is_some(), - thread_id: None, - message_id: msg_id, - error: None, - }; - if let Ok(json) = serde_json::to_string(&resp) { - let _ = event_tx.send(json); - } - } - } - - fn handle_edit_message(&self, reply: &crate::schema::GatewayReply) { - let text = reply.content.text.trim(); - if text.is_empty() { - return; - } - let pending = self.pending_streams.lock().unwrap_or_else(|e| e.into_inner()); - if let Some(stream) = pending.get(&reply.channel.id) { - let _ = stream.text_watch.send(text.to_string()); - } - } - - - async fn send_text(&self, to_user: &str, text: &str) -> Result { - 
let agent_id: u64 = self.config.agent_id.parse().expect("agent_id validated at startup"); - let body = serde_json::json!({ - "touser": to_user, - "msgtype": "text", - "agentid": agent_id, - "text": { "content": text } - }); - - let resp = post_with_token_retry( - &self.client, - &self.token_cache, - &self.config.corp_id, - &self.config.secret, - "/cgi-bin/message/send", - &body, - ) - .await?; - Ok(resp["msgid"].as_str().unwrap_or("").to_string()) - } -} - -/// POST a JSON body to a WeCom API endpoint with automatic token refresh -/// on errcode 42001 (access_token expired). Used by both `send_text` and -/// the streaming flush path so a long-running stream can't lose its final -/// reply if the cached token expires mid-flight. -async fn post_with_token_retry( - client: &reqwest::Client, - token_cache: &WecomTokenCache, - corp_id: &str, - secret: &str, - api_path: &str, - body: &serde_json::Value, -) -> Result { - let token = token_cache.get_token(client, corp_id, secret).await?; - let url = format!("{}{}?access_token={}", token_cache.base_url, api_path, token); - let resp: serde_json::Value = client.post(&url).json(body).send().await?.json().await?; - let errcode = resp["errcode"].as_i64().unwrap_or(-1); - - if errcode == 42001 { - warn!(api_path, "wecom: access_token expired, refreshing and retrying"); - let new_token = token_cache.force_refresh(client, corp_id, secret).await?; - let retry_url = format!("{}{}?access_token={}", token_cache.base_url, api_path, new_token); - let retry_resp: serde_json::Value = - client.post(&retry_url).json(body).send().await?.json().await?; - let retry_code = retry_resp["errcode"].as_i64().unwrap_or(-1); - if retry_code != 0 { - anyhow::bail!( - "wecom {} retry failed: errcode={}, errmsg={}", - api_path, - retry_code, - retry_resp["errmsg"] - ); - } - Ok(retry_resp) - } else if errcode != 0 { - anyhow::bail!( - "wecom {} failed: errcode={}, errmsg={}", - api_path, - errcode, - resp["errmsg"] - ); - } else { - Ok(resp) - } -} - -// 
--- Handlers --- - -fn handle_verify_request( - token: &str, - encoding_aes_key: &str, - corp_id: &str, - msg_signature: &str, - timestamp: &str, - nonce: &str, - echostr: &str, -) -> anyhow::Result { - if !verify_signature(token, timestamp, nonce, echostr, msg_signature) { - anyhow::bail!("signature verification failed"); - } - decrypt_message(encoding_aes_key, echostr, corp_id) -} - -// --- XML parsing --- - -struct CallbackEnvelope { - to_user_name: String, - encrypt: String, -} - -struct WecomMessage { - from_user: String, - msg_type: String, - content: String, - msg_id: String, - pic_url: String, - media_id: String, - file_name: String, -} - -fn parse_envelope_xml(xml: &str) -> Result { - use quick_xml::events::Event; - use quick_xml::Reader; - - let mut reader = Reader::from_str(xml); - let mut to_user_name = String::new(); - let mut encrypt = String::new(); - let mut current_tag = String::new(); - - loop { - match reader.read_event() { - Ok(Event::Start(e)) => { - current_tag = String::from_utf8_lossy(e.name().as_ref()).to_string(); - } - Ok(Event::CData(e)) => { - let text = String::from_utf8_lossy(&e).to_string(); - match current_tag.as_str() { - "ToUserName" => to_user_name = text, - "Encrypt" => encrypt = text, - _ => {} - } - } - Ok(Event::Text(e)) => { - let text = e.unescape().unwrap_or_default().to_string(); - match current_tag.as_str() { - "ToUserName" if to_user_name.is_empty() => to_user_name = text, - "Encrypt" if encrypt.is_empty() => encrypt = text, - _ => {} - } - } - Ok(Event::End(_)) => { - current_tag.clear(); - } - Ok(Event::Eof) => break, - Err(e) => anyhow::bail!("xml parse error: {e}"), - _ => {} - } - } - - if encrypt.is_empty() { - anyhow::bail!("missing Encrypt field in callback XML"); - } - Ok(CallbackEnvelope { - to_user_name, - encrypt, - }) -} - -fn parse_message_xml(xml: &str) -> Result { - use quick_xml::events::Event; - use quick_xml::Reader; - - let mut reader = Reader::from_str(xml); - let mut from_user = String::new(); - 
let mut msg_type = String::new(); - let mut content = String::new(); - let mut msg_id = String::new(); - let mut pic_url = String::new(); - let mut media_id = String::new(); - let mut file_name = String::new(); - let mut current_tag = String::new(); - - loop { - match reader.read_event() { - Ok(Event::Start(e)) => { - current_tag = String::from_utf8_lossy(e.name().as_ref()).to_string(); - } - Ok(Event::CData(e)) => { - let text = String::from_utf8_lossy(&e).to_string(); - match current_tag.as_str() { - "FromUserName" => from_user = text, - "MsgType" => msg_type = text, - "Content" => content = text, - "MsgId" => msg_id = text, - "PicUrl" => pic_url = text, - "MediaId" => media_id = text, - "FileName" => file_name = text, - _ => {} - } - } - Ok(Event::Text(e)) => { - let text = e.unescape().unwrap_or_default().to_string(); - match current_tag.as_str() { - "FromUserName" if from_user.is_empty() => from_user = text, - "MsgType" if msg_type.is_empty() => msg_type = text, - "Content" if content.is_empty() => content = text, - "MsgId" if msg_id.is_empty() => msg_id = text, - "PicUrl" if pic_url.is_empty() => pic_url = text, - "MediaId" if media_id.is_empty() => media_id = text, - "FileName" if file_name.is_empty() => file_name = text, - _ => {} - } - } - Ok(Event::End(_)) => { - current_tag.clear(); - } - Ok(Event::Eof) => break, - Err(e) => anyhow::bail!("xml parse error: {e}"), - _ => {} - } - } - - Ok(WecomMessage { - from_user, - msg_type, - content, - msg_id, - pic_url, - media_id, - file_name, - }) -} - -#[allow(clippy::too_many_arguments)] -async fn flush_thinking( - client: &reqwest::Client, - token_cache: &WecomTokenCache, - corp_id: &str, - secret: &str, - agent_id: &str, - thinking_msg_id: Option<&str>, - to_user: &str, - text: &str, -) { - info!(?thinking_msg_id, text_len = text.len(), "wecom: flush_thinking starting"); - - // Recall thinking placeholder (only when streaming was enabled) - if let Some(id) = thinking_msg_id { - let body = serde_json::json!({ 
"msgid": id }); - match post_with_token_retry( - client, - token_cache, - corp_id, - secret, - "/cgi-bin/message/recall", - &body, - ) - .await - { - Ok(resp) => info!(body = %resp, "wecom: recall response"), - Err(e) => warn!(error = %e, "wecom: recall failed"), - } - } - - // Send final text. Each chunk goes through retry-on-token-expiry so a - // long stream that outlives the cached token still delivers its reply. - let aid = agent_id.parse::().unwrap_or(0); - let chunks = split_text_lines(text, 2048); - info!(chunk_count = chunks.len(), "wecom: sending final chunks"); - for (i, chunk) in chunks.iter().enumerate() { - let body = serde_json::json!({ - "touser": to_user, - "msgtype": "text", - "agentid": aid, - "text": { "content": chunk } - }); - match post_with_token_retry( - client, - token_cache, - corp_id, - secret, - "/cgi-bin/message/send", - &body, - ) - .await - { - Ok(val) => { - let msg_id = val["msgid"].as_str().unwrap_or(""); - info!(msg_id = %msg_id, chunk_idx = i, "wecom: sent final reply chunk"); - } - Err(e) => warn!(error = %e, chunk_idx = i, "wecom flush send failed"), - } - } -} - -/// Split `text` into chunks that each fit within `limit` bytes (WeCom's -/// `message/send` truncates server-side at 2048 bytes). Splits prefer -/// newline boundaries; lines that exceed the limit themselves are split at -/// UTF-8 char boundaries via `char_indices()` so multibyte characters are -/// never severed mid-codepoint. The `limit` and all `len()` comparisons in -/// this function are in **bytes**, matching WeCom's server-side check. 
-fn split_text_lines(text: &str, limit: usize) -> Vec { - if text.len() <= limit { - return vec![text.to_string()]; - } - let mut chunks = Vec::new(); - let mut current = String::new(); - for line in text.split('\n') { - if line.len() > limit { - if !current.is_empty() { - chunks.push(current); - current = String::new(); - } - // Split long line at char boundaries - let mut pos = 0; - for (i, ch) in line.char_indices() { - if i - pos + ch.len_utf8() > limit { - chunks.push(line[pos..i].to_string()); - pos = i; - } - } - if pos < line.len() { - current = line[pos..].to_string(); - } - continue; - } - let candidate_len = if current.is_empty() { - line.len() - } else { - current.len() + 1 + line.len() - }; - if candidate_len > limit && !current.is_empty() { - chunks.push(current); - current = String::new(); - } - if !current.is_empty() { - current.push('\n'); - } - current.push_str(line); - } - if !current.is_empty() { - chunks.push(current); - } - chunks -} - -pub async fn verify( - State(state): State>, - query: axum::extract::Query>, -) -> axum::response::Response { - use axum::response::IntoResponse; - - let wecom = match state.wecom.as_ref() { - Some(w) => w, - None => return axum::http::StatusCode::SERVICE_UNAVAILABLE.into_response(), - }; - - let msg_signature = query.get("msg_signature").map(|s| s.as_str()).unwrap_or(""); - let timestamp = query.get("timestamp").map(|s| s.as_str()).unwrap_or(""); - let nonce = query.get("nonce").map(|s| s.as_str()).unwrap_or(""); - let echostr = query.get("echostr").map(|s| s.as_str()).unwrap_or(""); - - info!( - msg_signature = %msg_signature, - timestamp = %timestamp, - nonce = %nonce, - echostr_len = echostr.len(), - "wecom verify request received" - ); - - match handle_verify_request( - &wecom.config.token, - &wecom.config.encoding_aes_key, - &wecom.config.corp_id, - msg_signature, - timestamp, - nonce, - echostr, - ) { - Ok(plaintext) => plaintext.into_response(), - Err(e) => { - warn!("wecom callback verification failed: 
{e}"); - axum::http::StatusCode::FORBIDDEN.into_response() - } - } -} - -pub async fn webhook( - State(state): State>, - query: axum::extract::Query>, - body: axum::body::Bytes, -) -> axum::response::Response { - use axum::response::IntoResponse; - - let wecom = match state.wecom.as_ref() { - Some(w) => w, - None => return axum::http::StatusCode::SERVICE_UNAVAILABLE.into_response(), - }; - - let msg_signature = query.get("msg_signature").map(|s| s.as_str()).unwrap_or(""); - let timestamp = query.get("timestamp").map(|s| s.as_str()).unwrap_or(""); - let nonce = query.get("nonce").map(|s| s.as_str()).unwrap_or(""); - - // Reject stale callbacks. WeCom retries within ~5s, our dedup window is - // 30s, so a 5-minute freshness check rejects replays without false- - // positives on legitimate retries. The signature itself doesn't bind a - // freshness expectation, so without this an attacker who captured a - // signed payload could replay it indefinitely. - if let Ok(ts) = timestamp.parse::() { - let now = chrono::Utc::now().timestamp(); - if (now - ts).abs() > 300 { - warn!(timestamp_age_secs = now - ts, "wecom webhook: rejecting stale callback"); - return axum::http::StatusCode::FORBIDDEN.into_response(); - } - } - - let body_str = match std::str::from_utf8(&body) { - Ok(s) => s, - Err(_) => return axum::http::StatusCode::BAD_REQUEST.into_response(), - }; - - let envelope = match parse_envelope_xml(body_str) { - Ok(e) => e, - Err(e) => { - warn!("wecom envelope parse error: {e}"); - return axum::http::StatusCode::BAD_REQUEST.into_response(); - } - }; - - // ToUserName in the outer envelope must match our configured Corp ID. - // The decrypt step also validates the inner Corp ID suffix; checking here - // first surfaces misrouted callbacks before we touch crypto. 
- if envelope.to_user_name != wecom.config.corp_id { - warn!( - envelope_to = %envelope.to_user_name, - expected = %wecom.config.corp_id, - "wecom webhook: envelope ToUserName mismatch" - ); - return axum::http::StatusCode::FORBIDDEN.into_response(); - } - - if !verify_signature( - &wecom.config.token, - timestamp, - nonce, - &envelope.encrypt, - msg_signature, - ) { - warn!("wecom webhook signature verification failed"); - return axum::http::StatusCode::FORBIDDEN.into_response(); - } - - info!(encrypt_len = envelope.encrypt.len(), "wecom: decrypting callback"); - let decrypted = match decrypt_message( - &wecom.config.encoding_aes_key, - &envelope.encrypt, - &wecom.config.corp_id, - ) { - Ok(d) => { - info!("wecom: decrypt ok"); - d - } - Err(e) => { - warn!(encrypt_len = envelope.encrypt.len(), "wecom decrypt failed: {e}"); - return "success".into_response(); - } - }; - - let msg = match parse_message_xml(&decrypted) { - Ok(m) => m, - Err(e) => { - warn!("wecom message parse error: {e}"); - return "success".into_response(); - } - }; - - info!( - msg_type = %msg.msg_type, - has_pic_url = !msg.pic_url.is_empty(), - msg_id = %msg.msg_id, - "wecom: parsed message" - ); - - if !matches!(msg.msg_type.as_str(), "text" | "image" | "file") { - return "success".into_response(); - } - - if !wecom.dedupe.check_and_insert(&msg.msg_id) { - return "success".into_response(); - } - - let text = match msg.msg_type.as_str() { - "text" => msg.content.clone(), - "image" => "Describe this image.".to_string(), - "file" => format!("User sent a file: {}", msg.file_name), - _ => String::new(), - }; - - let mut attachments = Vec::new(); - if msg.msg_type == "image" && !msg.pic_url.is_empty() { - match download_wecom_image(&wecom.client, &msg.pic_url).await { - Some(att) => attachments.push(att), - None => info!("wecom: image download failed, forwarding without attachment"), - } - } - if msg.msg_type == "file" && !msg.media_id.is_empty() { - match download_wecom_file( - &wecom.client, - 
&wecom.token_cache, - &wecom.config.corp_id, - &wecom.config.secret, - &msg.media_id, - &msg.file_name, - ) - .await - { - Some(att) => attachments.push(att), - None => info!("wecom: file download failed, forwarding without attachment"), - } - } - - if text.trim().is_empty() && attachments.is_empty() { - return "success".into_response(); - } - - let channel_id = format!("wecom:{}:{}", wecom.config.corp_id, msg.from_user); - let mut event = crate::schema::GatewayEvent::new( - "wecom", - crate::schema::ChannelInfo { - id: channel_id, - channel_type: "direct".into(), - thread_id: None, - }, - crate::schema::SenderInfo { - id: msg.from_user.clone(), - name: msg.from_user.clone(), - display_name: msg.from_user.clone(), - is_bot: false, - }, - &text, - &msg.msg_id, - vec![], - ); - event.content.attachments = attachments; - - let att_sizes: Vec = event.content.attachments.iter().map(|a| a.data.len()).collect(); - info!( - attachments = event.content.attachments.len(), - text_len = event.content.text.len(), - att_data_sizes = ?att_sizes, - att_mime = ?event.content.attachments.iter().map(|a| a.mime_type.as_str()).collect::>(), - "wecom: forwarding event to OAB" - ); - if let Ok(json) = serde_json::to_string(&event) { - info!( - json_len = json.len(), - has_attachments_in_json = json.contains("\"attachments\""), - "wecom: event JSON ready" - ); - let _ = state.event_tx.send(json); - } - - "success".into_response() -} - -const IMAGE_MAX_DOWNLOAD: u64 = 10 * 1024 * 1024; -const IMAGE_MAX_DIMENSION_PX: u32 = 1200; -const IMAGE_JPEG_QUALITY: u8 = 75; - -async fn download_wecom_image( - client: &reqwest::Client, - pic_url: &str, -) -> Option { - // Only fetch over HTTPS. WeCom's CDN serves images over HTTPS; rejecting - // non-HTTPS URLs prevents SSRF if the AES key is ever compromised and - // an attacker forges a callback with PicUrl pointing at an internal host. 
- if !pic_url.starts_with("https://") { - warn!(pic_url, "wecom: rejecting non-HTTPS pic_url"); - return None; - } - info!(pic_url, "wecom: downloading image"); - let resp = match client.get(pic_url).send().await { - Ok(r) => r, - Err(e) => { - warn!(error = %e, "wecom image download failed"); - return None; - } - }; - if !resp.status().is_success() { - warn!(status = %resp.status(), "wecom image download failed"); - return None; - } - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > IMAGE_MAX_DOWNLOAD { - warn!(size, "wecom image exceeds 10MB limit, skipping"); - return None; - } - } - } - let bytes = resp.bytes().await.ok()?; - if bytes.len() as u64 > IMAGE_MAX_DOWNLOAD { - warn!(size = bytes.len(), "wecom image exceeds 10MB limit"); - return None; - } - let (compressed, mime) = match resize_and_compress(&bytes) { - Ok(v) => v, - Err(e) => { - warn!(error = %e, "wecom: image resize/compress failed"); - return None; - } - }; - let path = crate::store::store_media(&compressed).await?; - let ext = if mime == "image/gif" { "gif" } else { "jpg" }; - Some(crate::schema::Attachment { - attachment_type: "image".into(), - filename: format!("wecom_{}.{}", chrono::Utc::now().timestamp(), ext), - mime_type: mime, - data: String::new(), - size: compressed.len() as u64, - path: Some(path), - }) -} - -const FILE_MAX_DOWNLOAD: u64 = 20 * 1024 * 1024; - -const TEXT_EXTENSIONS: &[&str] = &[ - "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", "rs", "py", "js", - "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", "rb", "sh", "bash", "zsh", "fish", - "ps1", "bat", "sql", "html", "css", "scss", "less", "ini", "cfg", "conf", "env", - "swift", "kt", "scala", "r", "pl", "lua", "graphql", "tsv", -]; - -const TEXT_FILENAMES: &[&str] = &[ - "dockerfile", "makefile", "justfile", "rakefile", "gemfile", - "procfile", "vagrantfile", ".gitignore", ".dockerignore", 
".editorconfig", -]; - -fn is_text_file(filename: &str) -> bool { - let lower = filename.to_lowercase(); - if lower.contains('.') { - if let Some(ext) = lower.rsplit('.').next() { - if TEXT_EXTENSIONS.contains(&ext) { - return true; - } - } - } - TEXT_FILENAMES.contains(&lower.as_str()) -} - -/// GET /cgi-bin/media/get with token-expiry retry. The media API returns -/// JSON `{"errcode":42001,...}` instead of binary when the token is stale, -/// so we sniff Content-Type and retry once with a force-refreshed token. -async fn fetch_media_with_retry( - client: &reqwest::Client, - token_cache: &WecomTokenCache, - corp_id: &str, - secret: &str, - media_id: &str, -) -> Result { - let token = token_cache.get_token(client, corp_id, secret).await?; - let url = format!( - "{}/cgi-bin/media/get?access_token={}&media_id={}", - token_cache.base_url, token, media_id - ); - let resp = client.get(&url).send().await?; - let content_type = resp - .headers() - .get(reqwest::header::CONTENT_TYPE) - .and_then(|v| v.to_str().ok()) - .unwrap_or("") - .to_string(); - if !content_type.contains("json") { - return Ok(resp); - } - // JSON body means error path. Inspect for 42001 and retry once. 
- let body = resp.text().await.unwrap_or_default(); - let val: serde_json::Value = serde_json::from_str(&body).unwrap_or_default(); - let errcode = val["errcode"].as_i64().unwrap_or(-1); - if errcode == 42001 { - warn!("wecom media: access_token expired, refreshing and retrying"); - let new_token = token_cache.force_refresh(client, corp_id, secret).await?; - let retry_url = format!( - "{}/cgi-bin/media/get?access_token={}&media_id={}", - token_cache.base_url, new_token, media_id - ); - return Ok(client.get(&retry_url).send().await?); - } - anyhow::bail!("wecom media error: {body}") -} - -async fn download_wecom_file( - client: &reqwest::Client, - token_cache: &WecomTokenCache, - corp_id: &str, - secret: &str, - media_id: &str, - filename: &str, -) -> Option { - info!(filename, media_id, "wecom: downloading file"); - let resp = match fetch_media_with_retry(client, token_cache, corp_id, secret, media_id).await { - Ok(r) => r, - Err(e) => { - warn!(error = %e, "wecom file download failed"); - return None; - } - }; - if !resp.status().is_success() { - warn!(status = %resp.status(), "wecom file download failed"); - return None; - } - if let Some(cl) = resp.headers().get(reqwest::header::CONTENT_LENGTH) { - if let Ok(size) = cl.to_str().unwrap_or("0").parse::() { - if size > FILE_MAX_DOWNLOAD { - warn!(size, "wecom file exceeds 20MB limit, skipping"); - return None; - } - } - } - let bytes = resp.bytes().await.ok()?; - if bytes.len() as u64 > FILE_MAX_DOWNLOAD { - warn!(size = bytes.len(), "wecom file exceeds 20MB limit"); - return None; - } - - if !is_text_file(filename) { - info!(filename, "wecom: skipping non-text file"); - return None; - } - - let text_content = match String::from_utf8(bytes.to_vec()) { - Ok(s) => s, - Err(_) => { - info!(filename, "wecom: file is not valid UTF-8, skipping"); - return None; - } - }; - - let path = crate::store::store_media(text_content.as_bytes()).await?; - let size = text_content.len() as u64; - - Some(crate::schema::Attachment { - 
attachment_type: "text_file".into(), - filename: filename.to_string(), - mime_type: "text/plain".into(), - data: String::new(), - size, - path: Some(path), - }) -} - -fn resize_and_compress(raw: &[u8]) -> Result<(Vec, String), image::ImageError> { - use image::ImageReader; - use std::io::Cursor; - - let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?; - let format = reader.format(); - if format == Some(image::ImageFormat::Gif) { - return Ok((raw.to_vec(), "image/gif".to_string())); - } - let img = reader.decode()?; - let (w, h) = (img.width(), img.height()); - let img = if w > IMAGE_MAX_DIMENSION_PX || h > IMAGE_MAX_DIMENSION_PX { - let max_side = std::cmp::max(w, h); - let ratio = f64::from(IMAGE_MAX_DIMENSION_PX) / f64::from(max_side); - let new_w = (f64::from(w) * ratio) as u32; - let new_h = (f64::from(h) * ratio) as u32; - img.resize(new_w, new_h, image::imageops::FilterType::Lanczos3) - } else { - img - }; - let mut buf = Cursor::new(Vec::new()); - let encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, IMAGE_JPEG_QUALITY); - img.write_with_encoder(encoder)?; - Ok((buf.into_inner(), "image/jpeg".to_string())) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn make_env(pairs: &[(&str, &str)]) -> impl Fn(&str) -> Option { - let map: std::collections::HashMap = pairs - .iter() - .map(|(k, v)| ((*k).to_string(), (*v).to_string())) - .collect(); - move |k: &str| map.get(k).cloned() - } - - #[test] - fn config_from_env_all_present() { - let env = make_env(&[ - ("WECOM_CORP_ID", "ww_test_corp"), - ("WECOM_SECRET", "test_secret"), - ("WECOM_TOKEN", "test_token"), - ("WECOM_ENCODING_AES_KEY", "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG"), - ("WECOM_AGENT_ID", "1000002"), - ]); - let config = WecomConfig::from_reader(env).unwrap(); - assert_eq!(config.corp_id, "ww_test_corp"); - assert_eq!(config.agent_id, "1000002"); - assert_eq!(config.webhook_path, "/webhook/wecom"); - assert!(!config.streaming_enabled, "streaming defaults 
off"); - assert_eq!(config.debounce_secs, 3); - } - - #[test] - fn config_from_env_missing_required() { - let env = make_env(&[]); - assert!(WecomConfig::from_reader(env).is_none()); - } - - fn encrypt_for_test(encoding_aes_key: &str, msg: &str, corp_id: &str) -> String { - use aes::cipher::{BlockEncryptMut, KeyIvInit}; - use base64::Engine; - - let key = decode_aes_key(encoding_aes_key).unwrap(); - let iv = &key[..16]; - - let msg_bytes = msg.as_bytes(); - let corp_id_bytes = corp_id.as_bytes(); - let msg_len = (msg_bytes.len() as u32).to_be_bytes(); - - let mut plaintext = Vec::new(); - plaintext.extend_from_slice(&[0u8; 16]); // random bytes (zeros for test) - plaintext.extend_from_slice(&msg_len); - plaintext.extend_from_slice(msg_bytes); - plaintext.extend_from_slice(corp_id_bytes); - - // WeCom uses PKCS7 padding with block_size=32 - let block_size = 32; - let pad_len = block_size - (plaintext.len() % block_size); - for _ in 0..pad_len { - plaintext.push(pad_len as u8); - } - - // Encrypt with NoPadding since we already padded manually - let total_len = plaintext.len(); - let mut buf = vec![0u8; total_len + 16]; // extra space just in case - buf[..total_len].copy_from_slice(&plaintext); - - type Aes256CbcEnc = cbc::Encryptor; - let encryptor = Aes256CbcEnc::new_from_slices(&key, iv).unwrap(); - let encrypted = encryptor - .encrypt_padded_mut::(&mut buf, total_len) - .unwrap(); - - base64::engine::general_purpose::STANDARD.encode(encrypted) - } - - #[test] - fn aes_key_decode() { - let key_str = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; - let key_bytes = decode_aes_key(key_str).unwrap(); - assert_eq!(key_bytes.len(), 32); - } - - #[test] - fn signature_verify() { - let token = "testtoken"; - let timestamp = "1409659813"; - let nonce = "1372623149"; - let encrypt = "msg_encrypt_content"; - - let sig = compute_signature(token, timestamp, nonce, encrypt); - assert!(verify_signature(token, timestamp, nonce, encrypt, &sig)); - assert!(!verify_signature( - 
token, - timestamp, - nonce, - encrypt, - "wrong_signature_value_here" - )); - } - - #[test] - fn decrypt_wecom_payload() { - let key_str = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; - let corp_id = "ww_test_corp"; - let msg = "hello world"; - - let encrypted = encrypt_for_test(key_str, msg, corp_id); - let decrypted = decrypt_message(key_str, &encrypted, corp_id).unwrap(); - assert_eq!(decrypted, msg); - } - - #[test] - fn verify_callback_echostr() { - let token = "testtoken"; - let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; - let corp_id = "ww_test_corp"; - let echostr_plain = "success_echo_string"; - - let echostr_encrypted = encrypt_for_test(encoding_aes_key, echostr_plain, corp_id); - let sig = compute_signature(token, "1409659813", "nonce123", &echostr_encrypted); - - let result = handle_verify_request( - token, - encoding_aes_key, - corp_id, - &sig, - "1409659813", - "nonce123", - &echostr_encrypted, - ); - assert_eq!(result.unwrap(), echostr_plain); - } - - #[test] - fn parse_text_message_xml() { - let xml = r#"134883186012345678901234561000002"#; - - let msg = parse_message_xml(xml).unwrap(); - assert_eq!(msg.from_user, "user001"); - assert_eq!(msg.msg_type, "text"); - assert_eq!(msg.content, "hello bot"); - assert_eq!(msg.msg_id, "1234567890123456"); - } - - #[test] - fn parse_callback_envelope() { - let xml = r#""#; - - let envelope = parse_envelope_xml(xml).unwrap(); - assert_eq!(envelope.to_user_name, "ww_test_corp"); - assert_eq!(envelope.encrypt, "some_encrypted_base64"); - } - - #[test] - fn dedupe_rejects_duplicates() { - let cache = DedupeCache::new(); - assert!(cache.check_and_insert("msg_001")); - assert!(!cache.check_and_insert("msg_001")); - assert!(cache.check_and_insert("msg_002")); - } - - #[tokio::test] - async fn token_refresh_success() { - use wiremock::matchers::{method, query_param}; - use wiremock::{Mock, MockServer, ResponseTemplate}; - - let server = MockServer::start().await; - Mock::given(method("GET")) 
- .and(query_param("corpid", "ww_test_corp")) - .and(query_param("corpsecret", "test_secret")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "errcode": 0, - "errmsg": "ok", - "access_token": "test_token_abc", - "expires_in": 7200 - }))) - .expect(1) - .mount(&server) - .await; - - let cache = WecomTokenCache::with_base_url(server.uri()); - let client = reqwest::Client::new(); - let token = cache.get_token(&client, "ww_test_corp", "test_secret").await.unwrap(); - assert_eq!(token, "test_token_abc"); - - // Second call uses cache (mock expects exactly 1 call) - let token2 = cache.get_token(&client, "ww_test_corp", "test_secret").await.unwrap(); - assert_eq!(token2, "test_token_abc"); - } - - #[test] - fn split_text_lines_multi() { - let text = "line1\nline2\nline3"; - let chunks = split_text_lines(text, 11); - assert_eq!(chunks.len(), 2); - assert_eq!(chunks[0], "line1\nline2"); - assert_eq!(chunks[1], "line3"); - } - - #[test] - fn split_text_lines_within_limit() { - let text = "short"; - let chunks = split_text_lines(text, 100); - assert_eq!(chunks, vec!["short"]); - } - - #[test] - fn split_text_lines_long_line() { - let text = "abcdefghij"; - let chunks = split_text_lines(text, 4); - assert_eq!(chunks, vec!["abcd", "efgh", "ij"]); - } - - #[test] - fn split_text_lines_long_line_utf8() { - let text = "你好世界測試"; // 18 bytes, 6 chars - let chunks = split_text_lines(text, 6); - assert_eq!(chunks, vec!["你好", "世界", "測試"]); - } - - #[test] - fn is_text_file_check() { - assert!(is_text_file("readme.md")); - assert!(is_text_file("config.json")); - assert!(is_text_file("data.csv")); - assert!(is_text_file("MAIN.PY")); - assert!(!is_text_file("photo.png")); - assert!(!is_text_file("archive.zip")); - assert!(!is_text_file("doc.pdf")); - } - - #[test] - fn parse_file_message() { - let xml = r#"134883186066661000002"#; - let msg = parse_message_xml(xml).unwrap(); - assert_eq!(msg.msg_type, "file"); - assert_eq!(msg.media_id, "media_abc123"); - 
assert_eq!(msg.file_name, "report.csv"); - } - - #[test] - fn full_webhook_decrypt_and_parse() { - let token = "testtoken"; - let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; - let corp_id = "ww_test_corp"; - let timestamp = "1409659813"; - let nonce = "nonce123"; - - // Simulate the inner message - let inner_xml = "134883186099991000002"; - - // Encrypt it - let encrypted = encrypt_for_test(encoding_aes_key, inner_xml, corp_id); - - // Compute signature - let sig = compute_signature(token, timestamp, nonce, &encrypted); - - // Verify signature - assert!(verify_signature(token, timestamp, nonce, &encrypted, &sig)); - - // Decrypt - let decrypted = decrypt_message(encoding_aes_key, &encrypted, corp_id).unwrap(); - assert_eq!(decrypted, inner_xml); - - // Parse - let msg = parse_message_xml(&decrypted).unwrap(); - assert_eq!(msg.from_user, "user42"); - assert_eq!(msg.msg_type, "text"); - assert_eq!(msg.content, "ping"); - assert_eq!(msg.msg_id, "9999"); - } - - #[test] - fn parse_image_message() { - let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; - let corp_id = "ww_test_corp"; - - let inner_xml = "134883186088881000002"; - - let encrypted = encrypt_for_test(encoding_aes_key, inner_xml, corp_id); - let decrypted = decrypt_message(encoding_aes_key, &encrypted, corp_id).unwrap(); - let msg = parse_message_xml(&decrypted).unwrap(); - assert_eq!(msg.msg_type, "image"); - assert_eq!(msg.pic_url, "http://example.com/pic.jpg"); - assert_eq!(msg.from_user, "user42"); - } - - #[test] - fn unsupported_msg_type_skipped() { - let xml = "134883186077771000002"; - let msg = parse_message_xml(xml).unwrap(); - assert_eq!(msg.msg_type, "voice"); - assert!(!matches!(msg.msg_type.as_str(), "text" | "image")); - } - - #[test] - fn verify_rejects_wrong_signature() { - let token = "testtoken"; - let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; - let corp_id = "ww_test_corp"; - let echostr_plain = "test_echo"; - - let 
echostr_encrypted = encrypt_for_test(encoding_aes_key, echostr_plain, corp_id); - - let result = handle_verify_request( - token, - encoding_aes_key, - corp_id, - "completely_wrong_signature", - "1409659813", - "nonce123", - &echostr_encrypted, - ); - assert!(result.is_err()); - } - - #[test] - fn decrypt_with_large_padding_value() { - // Verifies decryption works when WeCom's 32-byte padding exceeds 16 - let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; - let corp_id = "ww_test_corp"; - // Choose a message where (16 + 4 + msg_len + corp_id_len) % 32 < 16, - // producing a pad value > 16 which would fail with PKCS7/block_size=16. - // 16 + 4 + 1 + 12 = 33 → 33 % 32 = 1 → pad = 31 - let msg = "x"; - let encrypted = encrypt_for_test(encoding_aes_key, msg, corp_id); - let decrypted = decrypt_message(encoding_aes_key, &encrypted, corp_id).unwrap(); - assert_eq!(decrypted, msg); - } - - #[test] - fn decrypt_rejects_wrong_corp_id() { - let encoding_aes_key = "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUE"; - let corp_id = "ww_test_corp"; - let msg = "hello"; - - let encrypted = encrypt_for_test(encoding_aes_key, msg, corp_id); - let result = decrypt_message(encoding_aes_key, &encrypted, "ww_other_corp"); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("corp_id mismatch")); - } -} diff --git a/gateway/src/main.rs b/gateway/src/main.rs deleted file mode 100644 index b5cad2f6a..000000000 --- a/gateway/src/main.rs +++ /dev/null @@ -1,801 +0,0 @@ -mod adapters; -mod media; -mod schema; -pub mod store; - -use anyhow::Result; -use axum::{ - extract::State, - response::IntoResponse, - routing::{get, post}, - Router, -}; -use futures_util::{SinkExt, StreamExt}; -use schema::GatewayReply; -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Instant; -use tokio::sync::{broadcast, Mutex, Semaphore}; -use tracing::{info, warn}; - -// --- Reply token cache for LINE hybrid Reply/Push dispatch --- - -/// Cache entry for 
LINE reply tokens: (replyToken, insertion_time). -/// Uses std::sync::Mutex — critical sections are short (insert/remove/retain) -/// and never held across .await, so async Mutex overhead is unnecessary. -pub type ReplyTokenCache = Arc>>; - -/// Maximum age (in seconds) before a cached reply token is considered expired. -/// LINE tokens are valid for ~1 minute; we use 50s as a conservative margin. -pub const REPLY_TOKEN_TTL_SECS: u64 = 50; - -/// Maximum number of cached reply tokens. Prevents unbounded memory growth -/// if webhooks arrive faster than OAB can reply (e.g. OAB offline, spam burst). -pub const REPLY_TOKEN_CACHE_MAX: usize = 10_000; - -/// Maximum number of post-ack LINE webhook payloads processed concurrently. -/// Keeps image download/decode work bounded during bursts without giving up the -/// fast 200 OK response path. -pub const LINE_WEBHOOK_CONCURRENCY_MAX: usize = 8; - -// --- App state (shared across all adapters) --- - -pub struct AppState { - /// Telegram bot token (None if Telegram disabled) - pub telegram_bot_token: Option, - /// Telegram webhook secret token for request validation - pub telegram_secret_token: Option, - /// Use sendRichMessage for complex content (Bot API 10.1+) - pub telegram_rich_messages: bool, - /// LINE channel secret for signature validation - pub line_channel_secret: Option, - /// LINE channel access token for reply API - pub line_access_token: Option, - /// Teams adapter (None if Teams disabled) - pub teams: Option, - /// service_url cache for Teams reply routing (conversation_id → (service_url, last_seen)) - pub teams_service_urls: Mutex>, - /// Feishu adapter (None if Feishu disabled) - pub feishu: Option, - /// Google Chat adapter (None if Google Chat disabled) - pub google_chat: Option, - pub wecom: Option, - /// WebSocket authentication token - pub ws_token: Option, - /// Broadcast channel: gateway → OAB (events from all platforms) - pub event_tx: broadcast::Sender, - /// Cache: event_id → (LINE replyToken, 
timestamp). - /// Global across all OAB WebSocket clients. LINE reply tokens are single-use: - /// the first client to `remove()` a token wins the free Reply API call; - /// other clients for the same event naturally fall back to Push API. - pub reply_token_cache: ReplyTokenCache, - /// Limits concurrent post-ack LINE webhook processing so image bursts do not - /// turn into unbounded download/decode work. - pub line_webhook_semaphore: Arc, - /// Shared HTTP client for media downloads and API calls - pub client: reqwest::Client, -} - -// --- WebSocket handler (OAB connects here) --- - -async fn ws_handler( - State(state): State>, - query: axum::extract::Query>, - ws: axum::extract::WebSocketUpgrade, -) -> axum::response::Response { - if let Some(ref expected) = state.ws_token { - let provided = query.get("token").map(|s| s.as_str()); - if provided != Some(expected.as_str()) { - warn!("WebSocket rejected: invalid or missing token"); - return axum::http::StatusCode::UNAUTHORIZED.into_response(); - } - } - ws.on_upgrade(move |socket| handle_oab_connection(state, socket)) -} - -async fn handle_oab_connection(state: Arc, socket: axum::extract::ws::WebSocket) { - use axum::extract::ws::Message; - - let (mut ws_tx, mut ws_rx) = socket.split(); - let mut event_rx = state.event_tx.subscribe(); - - info!("OAB client connected via WebSocket"); - - // Forward gateway events → OAB - let send_task = tokio::spawn(async move { - loop { - tokio::select! 
{ - Ok(event_json) = event_rx.recv() => { - if ws_tx.send(Message::Text(event_json.into())).await.is_err() { - break; - } - } - } - } - }); - - // Receive OAB replies → route to correct platform - let state_for_recv = state.clone(); - // Track per-message reaction state (Telegram replaces all reactions atomically) - let reaction_state: Arc>>> = - Arc::new(Mutex::new(HashMap::new())); - let recv_task = tokio::spawn(async move { - let client = reqwest::Client::new(); - while let Some(Ok(msg)) = ws_rx.next().await { - if let Message::Text(text) = msg { - match serde_json::from_str::(&text) { - Ok(reply) => { - info!( - platform = %reply.platform, - channel = %reply.channel.id, - command = ?reply.command.as_deref(), - "OAB → gateway reply" - ); - match reply.platform.as_str() { - "telegram" => { - if let Some(ref token) = state_for_recv.telegram_bot_token { - adapters::telegram::handle_reply( - &reply, - token, - &client, - &state_for_recv.event_tx, - &reaction_state, - state_for_recv.telegram_rich_messages, - ) - .await; - } else { - warn!("reply for telegram but adapter not configured"); - } - } - "line" => { - if let Some(ref access_token) = state_for_recv.line_access_token { - adapters::line::dispatch_line_reply( - &client, - access_token, - &state_for_recv.reply_token_cache, - &reply, - adapters::line::LINE_API_BASE, - ) - .await; - } else { - warn!("reply for line but adapter not configured"); - } - } - "teams" => { - if let Some(ref teams) = state_for_recv.teams { - adapters::teams::handle_reply( - &reply, - teams, - &state_for_recv.teams_service_urls, - ) - .await; - } else { - warn!("reply for teams but adapter not configured"); - } - } - "feishu" => { - if let Some(ref feishu) = state_for_recv.feishu { - adapters::feishu::handle_reply( - &reply, - feishu, - &state_for_recv.event_tx, - ) - .await; - } else { - warn!("reply for feishu but adapter not configured"); - } - } - "googlechat" => { - if let Some(ref gc) = state_for_recv.google_chat { - 
gc.handle_reply(&reply, &state_for_recv.event_tx).await; - } else { - warn!("reply for googlechat but adapter not configured"); - } - } - "wecom" => { - if let Some(ref wecom) = state_for_recv.wecom { - wecom.handle_reply(&reply, &state_for_recv.event_tx).await; - } else { - warn!("reply for wecom but adapter not configured"); - } - } - other => warn!(platform = other, "unknown reply platform"), - } - } - Err(e) => warn!("invalid reply from OAB: {e}"), - } - } - } - }); - - tokio::select! { - _ = send_task => {}, - _ = recv_task => {}, - } - info!("OAB client disconnected"); -} - -async fn health() -> &'static str { - "ok" -} - -#[tokio::main] -async fn main() -> Result<()> { - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), - ) - .init(); - - let listen_addr = std::env::var("GATEWAY_LISTEN").unwrap_or_else(|_| "0.0.0.0:8080".into()); - let ws_token = std::env::var("GATEWAY_WS_TOKEN").ok(); - - if ws_token.is_none() { - warn!("GATEWAY_WS_TOKEN not set — WebSocket connections are NOT authenticated (insecure)"); - } - - let (event_tx, _) = broadcast::channel::(256); - let reply_token_cache: ReplyTokenCache = Arc::new(std::sync::Mutex::new(HashMap::new())); - - let mut app = Router::new() - .route("/ws", get(ws_handler)) - .route("/health", get(health)); - - // Telegram adapter - let telegram_bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok(); - let telegram_secret_token = std::env::var("TELEGRAM_SECRET_TOKEN").ok(); - let telegram_rich_messages = std::env::var("TELEGRAM_RICH_MESSAGES") - .map(|v| v != "0" && !v.eq_ignore_ascii_case("false")) - .unwrap_or(true); - if telegram_bot_token.is_some() { - let webhook_path = - std::env::var("TELEGRAM_WEBHOOK_PATH").unwrap_or_else(|_| "/webhook/telegram".into()); - if telegram_secret_token.is_none() { - warn!("TELEGRAM_SECRET_TOKEN not set — webhook requests are NOT validated (insecure)"); - } - info!(path = %webhook_path, "telegram 
adapter enabled"); - app = app.route(&webhook_path, post(adapters::telegram::webhook)); - } - - // LINE adapter — route is always mounted so inbound webhooks are accepted - // even without an access token (signature validation only needs LINE_CHANNEL_SECRET). - let line_channel_secret = std::env::var("LINE_CHANNEL_SECRET").ok(); - let line_access_token = std::env::var("LINE_CHANNEL_ACCESS_TOKEN").ok(); - info!("line adapter enabled"); - app = app.route("/webhook/line", post(adapters::line::webhook)); - - // Teams adapter - let teams = adapters::teams::TeamsConfig::from_env().map(|config| { - info!("teams adapter enabled"); - adapters::teams::TeamsAdapter::new(config) - }); - if teams.is_some() { - let webhook_path = - std::env::var("TEAMS_WEBHOOK_PATH").unwrap_or_else(|_| "/webhook/teams".into()); - info!(path = %webhook_path, "teams webhook registered"); - app = app.route(&webhook_path, post(adapters::teams::webhook)); - } - - // Feishu adapter - let feishu_config = adapters::feishu::FeishuConfig::from_env(); - let feishu_ws_mode = feishu_config - .as_ref() - .map(|c| c.connection_mode == adapters::feishu::ConnectionMode::Websocket) - .unwrap_or(false); - if let Some(ref config) = feishu_config { - match config.connection_mode { - adapters::feishu::ConnectionMode::Websocket => { - info!("feishu adapter enabled (websocket) — will connect after state init"); - } - adapters::feishu::ConnectionMode::Webhook => { - let path = config.webhook_path.clone(); - info!(path = %path, "feishu adapter enabled (webhook)"); - app = app.route(&path, post(adapters::feishu::webhook)); - } - } - } - let feishu = feishu_config.map(adapters::feishu::FeishuAdapter::new); - - // Resolve feishu bot identity early (needed for mention gating in both modes) - if let Some(ref f) = feishu { - f.resolve_bot_identity().await; - } - - // Google Chat adapter - let google_chat_enabled = std::env::var("GOOGLE_CHAT_ENABLED") - .map(|v| v == "true" || v == "1") - .unwrap_or(false); - let google_chat = 
if google_chat_enabled { - let token_cache = std::env::var("GOOGLE_CHAT_SA_KEY_JSON") - .ok() - .or_else(|| { - std::env::var("GOOGLE_CHAT_SA_KEY_FILE") - .ok() - .and_then(|path| std::fs::read_to_string(&path).ok()) - }) - .and_then(|json| { - adapters::googlechat::GoogleChatTokenCache::new(&json) - .map_err(|e| warn!("googlechat SA key error: {e}")) - .ok() - }); - let access_token = std::env::var("GOOGLE_CHAT_ACCESS_TOKEN").ok(); - let jwt_verifier = std::env::var("GOOGLE_CHAT_AUDIENCE").ok().map(|aud| { - info!("googlechat webhook JWT verification enabled (audience={aud})"); - adapters::googlechat::GoogleChatJwtVerifier::new(aud) - }); - - let webhook_path = std::env::var("GOOGLE_CHAT_WEBHOOK_PATH") - .unwrap_or_else(|_| "/webhook/googlechat".into()); - info!(path = %webhook_path, "googlechat adapter enabled"); - app = app.route(&webhook_path, post(adapters::googlechat::webhook)); - - if token_cache.is_some() { - info!("googlechat service account configured — token auto-refresh enabled"); - } else if access_token.is_some() { - warn!("googlechat using static access token — will expire in ~1 hour"); - } else { - warn!("GOOGLE_CHAT_ACCESS_TOKEN / GOOGLE_CHAT_SA_KEY_JSON not set — replies will be logged but not sent"); - } - if jwt_verifier.is_none() { - warn!( - "GOOGLE_CHAT_AUDIENCE not set — webhook requests are NOT authenticated (insecure)" - ); - } - - Some(adapters::googlechat::GoogleChatAdapter::new( - token_cache, - access_token, - jwt_verifier, - )) - } else { - None - }; - - // WeCom adapter - let wecom = adapters::wecom::WecomConfig::from_env().map(|config| { - let path = config.webhook_path.clone(); - info!(path = %path, "wecom adapter enabled"); - adapters::wecom::WecomAdapter::new(config) - }); - if let Some(ref w) = wecom { - app = app - .route( - &w.config.webhook_path, - axum::routing::get(adapters::wecom::verify), - ) - .route(&w.config.webhook_path, post(adapters::wecom::webhook)); - } - - if telegram_bot_token.is_none() - && 
line_access_token.is_none() - && teams.is_none() - && feishu.is_none() - && google_chat.is_none() - && wecom.is_none() - { - warn!("no adapters configured — set TELEGRAM_BOT_TOKEN, LINE_CHANNEL_ACCESS_TOKEN, TEAMS_APP_ID + TEAMS_APP_SECRET, FEISHU_APP_ID + FEISHU_APP_SECRET, GOOGLE_CHAT_ENABLED=true, and/or WECOM_CORP_ID + WECOM_SECRET + WECOM_TOKEN + WECOM_ENCODING_AES_KEY + WECOM_AGENT_ID"); - } - - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) - .build() - .expect("HTTP client must build"); - - let state = Arc::new(AppState { - telegram_bot_token, - telegram_secret_token, - telegram_rich_messages, - line_channel_secret, - line_access_token, - teams, - teams_service_urls: Mutex::new(HashMap::new()), - feishu, - google_chat, - wecom, - ws_token, - event_tx, - reply_token_cache, - line_webhook_semaphore: Arc::new(Semaphore::new(LINE_WEBHOOK_CONCURRENCY_MAX)), - client, - }); - - // Background task: sweep expired reply tokens every REPLY_TOKEN_TTL_SECS - { - let cache_state = state.clone(); - tokio::spawn(async move { - loop { - tokio::time::sleep(std::time::Duration::from_secs(REPLY_TOKEN_TTL_SECS)).await; - let mut cache = cache_state - .reply_token_cache - .lock() - .unwrap_or_else(|e| e.into_inner()); - let before = cache.len(); - cache.retain(|_, (_, t)| t.elapsed().as_secs() < REPLY_TOKEN_TTL_SECS); - let after = cache.len(); - if before != after { - info!( - removed = before - after, - remaining = after, - "reply token cache sweep" - ); - } - } - }); - } - - // Periodic cleanup of stale Teams service_url entries (TTL: 4 hours) - { - let state_for_cleanup = state.clone(); - tokio::spawn(async move { - loop { - tokio::time::sleep(std::time::Duration::from_secs(300)).await; - let mut urls = state_for_cleanup.teams_service_urls.lock().await; - let before = urls.len(); - urls.retain(|_, (_, t)| t.elapsed().as_secs() < 4 * 3600); - let after = urls.len(); - if before != after { - info!( - removed = before - after, - 
remaining = after, - "teams service_url cache cleanup" - ); - } - } - }); - } - - let app = app.with_state(state.clone()); - - // Background task: evict expired media files (colocate store, TTL 2 min) - tokio::spawn(store::eviction_loop()); - - // Spawn feishu WebSocket long-connection if configured - // feishu_shutdown_tx must remain alive for the lifetime of main() — dropping - // it signals shutdown to the WS task via feishu_shutdown_rx. - let (feishu_shutdown_tx, feishu_shutdown_rx) = tokio::sync::watch::channel(false); - if feishu_ws_mode { - if let Some(ref feishu) = state.feishu { - match adapters::feishu::start_websocket( - feishu, - state.event_tx.clone(), - feishu_shutdown_rx, - ) - .await - { - Ok(_handle) => info!("feishu websocket task spawned"), - Err(e) => tracing::error!(err = %e, "feishu websocket startup failed"), - } - } - } - - info!(addr = %listen_addr, "gateway starting"); - let listener = tokio::net::TcpListener::bind(&listen_addr).await?; - axum::serve(listener, app).await?; - drop(feishu_shutdown_tx); - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::time::Duration; - use wiremock::matchers::{body_json, header, method, path}; - use wiremock::{Mock, MockServer, ResponseTemplate}; - - fn make_reply(event_id: &str) -> schema::GatewayReply { - schema::GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: event_id.into(), - platform: "line".into(), - channel: schema::ReplyChannel { - id: "U1234".into(), - thread_id: None, - }, - content: schema::Content { - content_type: "text".into(), - text: "hello".into(), - attachments: Vec::new(), - }, - command: None, - request_id: None, - quote_message_id: None, - } - } - - fn make_reply_with_command(event_id: &str, command: &str, text: &str) -> schema::GatewayReply { - schema::GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: event_id.into(), - platform: "line".into(), - channel: schema::ReplyChannel { - id: "U1234".into(), - thread_id: None, - }, - 
content: schema::Content { - content_type: "text".into(), - text: text.into(), - attachments: Vec::new(), - }, - command: Some(command.into()), - request_id: None, - quote_message_id: None, - } - } - - fn make_cache() -> ReplyTokenCache { - Arc::new(std::sync::Mutex::new(HashMap::new())) - } - - /// Cache hit: uses Reply API with correct replyToken, bearer token, and message body. - /// Does NOT call Push API. - #[tokio::test] - async fn cache_hit_uses_reply_api() { - let server = MockServer::start().await; - let _reply = Mock::given(method("POST")) - .and(path("/v2/bot/message/reply")) - .and(header("authorization", "Bearer test_access_token")) - .and(body_json(serde_json::json!({ - "replyToken": "tok_abc", - "messages": [{"type": "text", "text": "hello"}] - }))) - .respond_with(ResponseTemplate::new(200).set_body_string("{}")) - .expect(1) - .mount_as_scoped(&server) - .await; - let _push = Mock::given(method("POST")) - .and(path("/v2/bot/message/push")) - .respond_with(ResponseTemplate::new(200)) - .expect(0) - .mount_as_scoped(&server) - .await; - - let cache = make_cache(); - cache - .lock() - .unwrap() - .insert("evt_1".into(), ("tok_abc".into(), Instant::now())); - - let client = reqwest::Client::new(); - let used = adapters::line::dispatch_line_reply( - &client, - "test_access_token", - &cache, - &make_reply("evt_1"), - &server.uri(), - ) - .await; - - assert!(used, "should report Reply API was used"); - } - - /// All unsupported LINE commands should be ignored without consuming the cached reply token. 
- #[tokio::test] - async fn line_ignores_unsupported_commands_without_touching_cache() { - let unsupported = &["add_reaction", "remove_reaction", "create_topic"]; - - for cmd in unsupported { - let server = MockServer::start().await; - let _reply = Mock::given(method("POST")) - .and(path("/v2/bot/message/reply")) - .respond_with(ResponseTemplate::new(200)) - .expect(0) - .mount_as_scoped(&server) - .await; - let _push = Mock::given(method("POST")) - .and(path("/v2/bot/message/push")) - .respond_with(ResponseTemplate::new(200)) - .expect(0) - .mount_as_scoped(&server) - .await; - - let cache = make_cache(); - cache - .lock() - .unwrap() - .insert("evt_unsup".into(), ("tok_unsup".into(), Instant::now())); - - let client = reqwest::Client::new(); - let used = adapters::line::dispatch_line_reply( - &client, - "test_access_token", - &cache, - &make_reply_with_command("evt_unsup", cmd, "payload"), - &server.uri(), - ) - .await; - - assert!(!used, "{cmd}: should not report reply usage"); - assert!( - cache.lock().unwrap().contains_key("evt_unsup"), - "{cmd}: should not consume the cached reply token" - ); - } - } - - /// Cache miss: falls back to Push API with correct "to", bearer token, and message body. 
- #[tokio::test] - async fn cache_miss_uses_push_api() { - let server = MockServer::start().await; - let _reply = Mock::given(method("POST")) - .and(path("/v2/bot/message/reply")) - .respond_with(ResponseTemplate::new(200)) - .expect(0) - .mount_as_scoped(&server) - .await; - let _push = Mock::given(method("POST")) - .and(path("/v2/bot/message/push")) - .and(header("authorization", "Bearer test_access_token")) - .and(body_json(serde_json::json!({ - "to": "U1234", - "messages": [{"type": "text", "text": "hello"}] - }))) - .respond_with(ResponseTemplate::new(200).set_body_string("{}")) - .expect(1) - .mount_as_scoped(&server) - .await; - - let cache = make_cache(); - - let client = reqwest::Client::new(); - let used = adapters::line::dispatch_line_reply( - &client, - "test_access_token", - &cache, - &make_reply("evt_miss"), - &server.uri(), - ) - .await; - - assert!(!used, "should report Push API was used (no reply token)"); - } - - /// Expired cached token: falls back to Push API. - #[tokio::test] - async fn expired_token_uses_push_api() { - let server = MockServer::start().await; - let _reply = Mock::given(method("POST")) - .and(path("/v2/bot/message/reply")) - .respond_with(ResponseTemplate::new(200)) - .expect(0) - .mount_as_scoped(&server) - .await; - let _push = Mock::given(method("POST")) - .and(path("/v2/bot/message/push")) - .and(header("authorization", "Bearer test_access_token")) - .and(body_json(serde_json::json!({ - "to": "U1234", - "messages": [{"type": "text", "text": "hello"}] - }))) - .respond_with(ResponseTemplate::new(200).set_body_string("{}")) - .expect(1) - .mount_as_scoped(&server) - .await; - - let cache = make_cache(); - let expired_time = Instant::now() - Duration::from_secs(REPLY_TOKEN_TTL_SECS + 10); - cache - .lock() - .unwrap() - .insert("evt_exp".into(), ("tok_old".into(), expired_time)); - - let client = reqwest::Client::new(); - let used = adapters::line::dispatch_line_reply( - &client, - "test_access_token", - &cache, - 
&make_reply("evt_exp"), - &server.uri(), - ) - .await; - - assert!(!used, "should report Push API was used (expired token)"); - } - - /// Reply API 400 with invalid/expired reply token: falls back to Push API. - #[tokio::test] - async fn reply_400_invalid_token_falls_back_to_push() { - let server = MockServer::start().await; - let _reply = Mock::given(method("POST")) - .and(path("/v2/bot/message/reply")) - .and(header("authorization", "Bearer test_access_token")) - .respond_with( - ResponseTemplate::new(400).set_body_string(r#"{"message":"Invalid reply token"}"#), - ) - .expect(1) - .mount_as_scoped(&server) - .await; - let _push = Mock::given(method("POST")) - .and(path("/v2/bot/message/push")) - .and(header("authorization", "Bearer test_access_token")) - .and(body_json(serde_json::json!({ - "to": "U1234", - "messages": [{"type": "text", "text": "hello"}] - }))) - .respond_with(ResponseTemplate::new(200).set_body_string("{}")) - .expect(1) - .mount_as_scoped(&server) - .await; - - let cache = make_cache(); - cache - .lock() - .unwrap() - .insert("evt_400".into(), ("tok_bad".into(), Instant::now())); - - let client = reqwest::Client::new(); - let used = adapters::line::dispatch_line_reply( - &client, - "test_access_token", - &cache, - &make_reply("evt_400"), - &server.uri(), - ) - .await; - - assert!(!used, "should fall back to Push on 400 invalid token"); - } - - /// Reply API 5xx: does NOT fall back to Push (duplicate risk). 
- #[tokio::test] - async fn reply_5xx_does_not_fallback() { - let server = MockServer::start().await; - let _reply = Mock::given(method("POST")) - .and(path("/v2/bot/message/reply")) - .and(header("authorization", "Bearer test_access_token")) - .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error")) - .expect(1) - .mount_as_scoped(&server) - .await; - let _push = Mock::given(method("POST")) - .and(path("/v2/bot/message/push")) - .respond_with(ResponseTemplate::new(200)) - .expect(0) - .mount_as_scoped(&server) - .await; - - let cache = make_cache(); - cache - .lock() - .unwrap() - .insert("evt_5xx".into(), ("tok_5xx".into(), Instant::now())); - - let client = reqwest::Client::new(); - let used = adapters::line::dispatch_line_reply( - &client, - "test_access_token", - &cache, - &make_reply("evt_5xx"), - &server.uri(), - ) - .await; - - assert!(used, "should NOT fall back to Push on 5xx"); - } - - /// Reply API network/timeout error: does NOT fall back to Push (duplicate risk). - #[tokio::test] - async fn reply_network_error_does_not_fallback() { - let bad_base = "http://127.0.0.1:1"; - - let cache = make_cache(); - cache - .lock() - .unwrap() - .insert("evt_net".into(), ("tok_net".into(), Instant::now())); - - let client = reqwest::Client::builder() - .timeout(Duration::from_millis(100)) - .build() - .unwrap(); - let used = adapters::line::dispatch_line_reply( - &client, - "test_access_token", - &cache, - &make_reply("evt_net"), - bad_base, - ) - .await; - - assert!(used, "should NOT fall back to Push on network error"); - } -} diff --git a/gateway/src/media.rs b/gateway/src/media.rs deleted file mode 100644 index f6eb88565..000000000 --- a/gateway/src/media.rs +++ /dev/null @@ -1,123 +0,0 @@ -use image::ImageReader; -use std::io::Cursor; - -/// Media type for download functions — avoids stringly-typed branching. 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MediaKind { - Image, - Audio, -} - -pub const IMAGE_MAX_DIMENSION_PX: u32 = 1200; -pub const IMAGE_JPEG_QUALITY: u8 = 75; -pub const IMAGE_MAX_DOWNLOAD: u64 = 10 * 1024 * 1024; // 10 MB -pub const FILE_MAX_DOWNLOAD: u64 = 20 * 1024 * 1024; // 20 MB (same as store cap) -pub const AUDIO_MAX_DOWNLOAD: u64 = 20 * 1024 * 1024; // 20 MB -pub const GIF_MAX_SIZE: usize = 5 * 1024 * 1024; // 5 MB — prevents base64 bloat exceeding LLM payload limits - -/// Resize image so longest side <= 1200px, then encode as JPEG. -/// GIFs under 5MB are passed through unchanged to preserve animation. -pub fn resize_and_compress(raw: &[u8]) -> Result<(Vec, String), image::ImageError> { - let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?; - let format = reader.format(); - if format == Some(image::ImageFormat::Gif) { - if raw.len() > GIF_MAX_SIZE { - return Err(image::ImageError::Limits( - image::error::LimitError::from_kind(image::error::LimitErrorKind::DimensionError), - )); - } - return Ok((raw.to_vec(), "image/gif".to_string())); - } - let img = reader.decode()?; - let (w, h) = (img.width(), img.height()); - let img = if w > IMAGE_MAX_DIMENSION_PX || h > IMAGE_MAX_DIMENSION_PX { - let max_side = std::cmp::max(w, h); - let ratio = f64::from(IMAGE_MAX_DIMENSION_PX) / f64::from(max_side); - let new_w = (f64::from(w) * ratio) as u32; - let new_h = (f64::from(h) * ratio) as u32; - img.resize(new_w, new_h, image::imageops::FilterType::Lanczos3) - } else { - img - }; - let mut buf = Cursor::new(Vec::new()); - let encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, IMAGE_JPEG_QUALITY); - img.write_with_encoder(encoder)?; - Ok((buf.into_inner(), "image/jpeg".to_string())) -} - -/// Derive file extension from Content-Type for audio files. 
-pub fn audio_extension(content_type: &str) -> &'static str { - if content_type.contains("mpeg") || content_type.contains("mp3") { - "mp3" - } else if content_type.contains("m4a") || content_type.contains("mp4") { - "m4a" - } else { - "ogg" - } -} - -/// Check if a filename has a text-like extension suitable for reading as UTF-8. -pub fn is_text_extension(filename: &str) -> bool { - const TEXT_EXTS: &[&str] = &[ - "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", "rs", "py", - "js", "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", "rb", "sh", "bash", - "sql", "html", "css", "ini", "cfg", "conf", - ]; - let ext = filename.rsplit('.').next().unwrap_or("").to_lowercase(); - TEXT_EXTS.contains(&ext.as_str()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn gif_under_limit_passes_through() { - let gif = b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00!\xf9\x04\x00\x00\x00\x00\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;"; - let result = resize_and_compress(gif); - assert!(result.is_ok()); - let (data, mime) = result.unwrap(); - assert_eq!(mime, "image/gif"); - assert_eq!(data, gif); - } - - #[test] - fn gif_over_limit_returns_error() { - let mut data = b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00".to_vec(); - data.resize(GIF_MAX_SIZE + 1, 0); - let result = resize_and_compress(&data); - assert!(result.is_err()); - } - - #[test] - fn small_jpeg_not_resized() { - let img = image::RgbImage::from_pixel(2, 2, image::Rgb([255, 0, 0])); - let mut buf = std::io::Cursor::new(Vec::new()); - img.write_to(&mut buf, image::ImageFormat::Jpeg).unwrap(); - let result = resize_and_compress(&buf.into_inner()); - assert!(result.is_ok()); - assert_eq!(result.unwrap().1, "image/jpeg"); - } - - #[test] - fn large_image_gets_resized() { - let img = image::RgbImage::from_pixel(2000, 2000, image::Rgb([0, 128, 255])); - let mut buf = std::io::Cursor::new(Vec::new()); - img.write_to(&mut buf, 
image::ImageFormat::Png).unwrap(); - let result = resize_and_compress(&buf.into_inner()); - assert!(result.is_ok()); - let (data, mime) = result.unwrap(); - assert_eq!(mime, "image/jpeg"); - let decoded = image::load_from_memory(&data).unwrap(); - assert!(decoded.width() <= IMAGE_MAX_DIMENSION_PX); - assert!(decoded.height() <= IMAGE_MAX_DIMENSION_PX); - } - - #[test] - fn text_extension_check() { - assert!(is_text_extension("main.rs")); - assert!(is_text_extension("data.csv")); - assert!(!is_text_extension("archive.zip")); - assert!(!is_text_extension("photo.jpg")); - } -} diff --git a/gateway/src/schema.rs b/gateway/src/schema.rs deleted file mode 100644 index 740d0fab8..000000000 --- a/gateway/src/schema.rs +++ /dev/null @@ -1,126 +0,0 @@ -use serde::{Deserialize, Serialize}; - -// --- Event schema (ADR openab.gateway.event.v1) --- - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct GatewayEvent { - pub schema: String, - pub event_id: String, - pub timestamp: String, - pub platform: String, - pub event_type: String, - pub channel: ChannelInfo, - pub sender: SenderInfo, - pub content: Content, - pub mentions: Vec, - pub message_id: String, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ChannelInfo { - pub id: String, - #[serde(rename = "type")] - pub channel_type: String, - pub thread_id: Option, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct SenderInfo { - pub id: String, - pub name: String, - pub display_name: String, - pub is_bot: bool, -} - -#[derive(Clone, Debug, Default, Serialize, Deserialize)] -pub struct Content { - #[serde(rename = "type")] - pub content_type: String, - pub text: String, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub attachments: Vec, -} - -#[derive(Clone, Debug, Default, Serialize, Deserialize)] -pub struct Attachment { - #[serde(rename = "type")] - pub attachment_type: String, // "image", "text_file", "audio" - pub filename: String, - pub mime_type: String, - /// 
Base64-encoded data (deprecated — use `path` for colocate mode). - /// Kept for backward compatibility; Core prefers `path` when present. - #[serde(default, skip_serializing_if = "String::is_empty")] - pub data: String, - pub size: u64, // size in bytes (after compression for images) - /// Local file path for colocate mode (gateway + core share filesystem). - /// When set, Core reads bytes directly from this path instead of decoding `data`. - /// Path format: ~/.openab/media/inbound/ (no extension, MIME in mime_type). - #[serde(default, skip_serializing_if = "Option::is_none")] - pub path: Option, -} - -// --- Reply schema (ADR openab.gateway.reply.v1) --- - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct GatewayReply { - pub schema: String, - pub reply_to: String, - pub platform: String, - pub channel: ReplyChannel, - pub content: Content, - #[serde(default)] - pub command: Option, - #[serde(default)] - pub request_id: Option, - /// When set, send this message as a reply/quote to the specified platform message ID. - /// Unlike `reply_to` (which identifies the triggering event for routing/dedup), - /// this field controls the visual reply/quote UI on the platform. - /// If quoting fails, the gateway MUST fall back to sending without quoting. - #[serde(default)] - pub quote_message_id: Option, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ReplyChannel { - pub id: String, - pub thread_id: Option, -} - -/// Response from gateway back to OAB for commands (e.g. 
create_topic) -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct GatewayResponse { - pub schema: String, - pub request_id: String, - pub success: bool, - pub thread_id: Option, - pub message_id: Option, - pub error: Option, -} - -impl GatewayEvent { - pub fn new( - platform: &str, - channel: ChannelInfo, - sender: SenderInfo, - text: &str, - message_id: &str, - mentions: Vec, - ) -> Self { - Self { - schema: "openab.gateway.event.v1".into(), - event_id: format!("evt_{}", uuid::Uuid::new_v4()), - timestamp: chrono::Utc::now().to_rfc3339(), - platform: platform.into(), - event_type: "message".into(), - channel, - sender, - content: Content { - content_type: "text".into(), - text: text.into(), - attachments: Vec::new(), - }, - mentions, - message_id: message_id.into(), - } - } -} diff --git a/gateway/src/store.rs b/gateway/src/store.rs deleted file mode 100644 index b08e69903..000000000 --- a/gateway/src/store.rs +++ /dev/null @@ -1,132 +0,0 @@ -use std::path::{Path, PathBuf}; -use tokio::fs; -use tracing::{error, info}; -use uuid::Uuid; - -/// Inbound media directory under $HOME. -/// Pattern follows OpenClaw's `~/.openclaw/media/inbound/`. -/// -/// # Security Considerations -/// -/// - **Path traversal prevention**: Filenames are always server-generated UUIDs, -/// never user-supplied. No extension, no special characters — eliminates path -/// traversal attacks (e.g. `../../etc/passwd`). -/// -/// - **No auth token leakage**: Platform media URLs (Telegram getFile, LINE Content API) -/// contain bot tokens or require auth headers. By downloading in the gateway and -/// storing locally, tokens never reach Core or the agent. -/// -/// - **TTL auto-eviction**: Files are evicted after 2 minutes. Prevents disk exhaustion -/// from accumulated media and limits the window for any leaked file to be exploited. -/// -/// - **Colocate trust boundary**: This module assumes gateway and core share the same -/// filesystem (same pod / same $HOME). 
The file path is passed over the internal WS -/// connection — never exposed externally. If gateway and core are separated in the -/// future, switch to HTTP media proxy with internal-only binding. -/// -/// - **Size limits enforced before write**: Callers must validate file size against -/// IMAGE_MAX_DOWNLOAD / AUDIO_MAX_DOWNLOAD / FILE_MAX_DOWNLOAD before calling -/// `store_media()`. This module does NOT re-validate — it trusts the caller. -/// -/// - **No executable content**: Stored files are raw bytes (images, audio, text). -/// Core reads them as data only — never executed. The `mime_type` in the event -/// payload determines processing path, not the file content or name. -const MEDIA_INBOUND_DIR: &str = ".openab/media/inbound"; - -/// TTL for stored media files (2 minutes) -const TTL_SECS: u64 = 120; - -/// Get the inbound media directory path, creating it if needed. -pub async fn media_dir() -> PathBuf { - let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".into()); - let dir = Path::new(&home).join(MEDIA_INBOUND_DIR); - if !dir.exists() { - let _ = fs::create_dir_all(&dir).await; - } - dir -} - -/// Maximum file size accepted by store (defense-in-depth, callers should pre-check). -const MAX_STORE_SIZE: usize = 20 * 1024 * 1024; // 20 MB (matches AUDIO_MAX_DOWNLOAD) - -/// Store media bytes to disk, return the absolute file path. -/// Filename is UUID only (no extension) — MIME type is carried in the event payload. -/// Rejects files exceeding MAX_STORE_SIZE as a defense-in-depth measure. 
-pub async fn store_media(bytes: &[u8]) -> Option { - if bytes.len() > MAX_STORE_SIZE { - error!(size = bytes.len(), max = MAX_STORE_SIZE, "store_media rejected: exceeds size limit"); - return None; - } - let dir = media_dir().await; - let filename = Uuid::new_v4().to_string(); - let path = dir.join(&filename); - match fs::write(&path, bytes).await { - Ok(_) => { - info!(path = %path.display(), size = bytes.len(), "media stored"); - Some(path.to_string_lossy().into_owned()) - } - Err(e) => { - error!(error = %e, "failed to store media file"); - None - } - } -} - -/// Background task: evict files older than TTL_SECS. -pub async fn eviction_loop() { - let mut interval = tokio::time::interval(std::time::Duration::from_secs(30)); - loop { - interval.tick().await; - if let Err(e) = evict_expired().await { - error!(error = %e, "media eviction error"); - } - } -} - -async fn evict_expired() -> std::io::Result<()> { - let dir = media_dir().await; - if !dir.exists() { - return Ok(()); - } - let mut entries = fs::read_dir(&dir).await?; - let now = std::time::SystemTime::now(); - while let Some(entry) = entries.next_entry().await? 
{ - if let Ok(meta) = entry.metadata().await { - if let Ok(modified) = meta.modified() { - if let Ok(age) = now.duration_since(modified) { - if age.as_secs() > TTL_SECS { - let path = entry.path(); - let _ = fs::remove_file(&path).await; - tracing::debug!(path = %path.display(), "evicted expired media"); - } - } - } - } - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn store_and_read_back() { - let data = b"hello media"; - let path = store_media(data).await.unwrap(); - let read_back = fs::read(&path).await.unwrap(); - assert_eq!(read_back, data); - // Cleanup - let _ = fs::remove_file(&path).await; - } - - #[tokio::test] - async fn filename_is_uuid_no_extension() { - let path = store_media(b"test").await.unwrap(); - let filename = Path::new(&path).file_name().unwrap().to_str().unwrap(); - // UUID v4 format: 8-4-4-4-12 hex chars - assert_eq!(filename.len(), 36); - assert!(!filename.contains('.')); - let _ = fs::remove_file(&path).await; - } -} diff --git a/src/acp/agentcore.rs b/src/acp/agentcore.rs deleted file mode 100644 index 86696d357..000000000 --- a/src/acp/agentcore.rs +++ /dev/null @@ -1,722 +0,0 @@ -//! AgentCore ACP bridge — stdin/stdout subprocess that bridges ACP JSON-RPC -//! to AgentCore's InvokeAgentRuntimeCommandShell WebSocket API. -//! -//! Invoked as: `openab --agentcore-bridge --runtime-arn ARN --region REGION` -//! -//! Opens a persistent PTY shell in the microVM, launches `kiro-cli acp -//! --trust-all-tools`, and forwards JSON-RPC bidirectionally. 
- -use anyhow::{anyhow, Result}; -use aws_credential_types::provider::ProvideCredentials; -use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings}; -use aws_sigv4::sign::v4; -use futures_util::{SinkExt, StreamExt}; -use serde_json::{json, Value}; -use sha2::{Digest, Sha256}; -use std::collections::HashMap; -use std::sync::Arc; -use std::time::SystemTime; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tokio::sync::Mutex; -use tokio_tungstenite::tungstenite::http; -use tokio_tungstenite::tungstenite::protocol::Message; -use tracing::info; - -const AGENT_CMD_PREFIX: &str = "stty -echo 2>/dev/null; mkdir -p /tmp/kiro-cli && cp -n /mnt/agent/.local/share/kiro-cli/data.sqlite3 /tmp/kiro-cli/ 2>/dev/null; export XDG_DATA_HOME=/tmp; exec "; - -/// WebSocket binary frame channel bytes (1-byte prefix protocol). -const CHANNEL_STDIN: u8 = 0x00; -const CHANNEL_STDOUT: u8 = 0x01; -const CHANNEL_STDERR: u8 = 0x02; - -/// Extract a complete JSON object from a line that may have PTY prefix noise. -/// Uses brace-counting to find matching `{}` pairs, robust against partial JSON -/// or embedded `{` in prompt text. 
-fn extract_json_object(line: &str) -> Option { - let bytes = line.as_bytes(); - let start = bytes.iter().position(|&b| b == b'{')?; - - let mut depth: i32 = 0; - let mut in_string = false; - let mut escape = false; - - for (i, &c) in bytes.iter().enumerate().skip(start) { - if escape { - escape = false; - continue; - } - if c == b'\\' && in_string { - escape = true; - continue; - } - if c == b'"' { - in_string = !in_string; - continue; - } - if in_string { - continue; - } - if c == b'{' { - depth += 1; - } else if c == b'}' { - depth -= 1; - if depth == 0 { - let candidate = &line[start..=i]; - // Validate it's actually valid JSON - if serde_json::from_str::(candidate).is_ok() { - return Some(candidate.to_string()); - } - // Not valid — try next `{` - return extract_json_object(&line[start + 1..]); - } - } - } - None -} - -/// Entry point for the agentcore bridge subprocess. -pub async fn run_bridge(runtime_arn: &str, region: &str, agent_command: &str) -> Result<()> { - let stdin = BufReader::new(tokio::io::stdin()); - let stdout = tokio::io::stdout(); - - let mut bridge = Bridge::new(runtime_arn, region, agent_command, stdin, stdout); - bridge.run().await -} - -struct Bridge { - runtime_arn: String, - region: String, - agent_command: String, - stdin: R, - stdout: W, - sessions: HashMap, - next_id: u64, -} - -struct ShellHandle { - /// Sender for writing to the WebSocket (stdin of shell) - ws_write: Arc>, - /// Buffered output from kiro-cli (stdout of shell via WebSocket) - line_rx: tokio::sync::mpsc::UnboundedReceiver, - /// Pump task handle - _pump: tokio::task::JoinHandle<()>, - /// Runtime session ID (for future reconnect support). 
- #[allow(dead_code)] - runtime_session_id: String, - /// kiro-cli's internal ACP session ID - kiro_session_id: String, -} - -type WsSink = futures_util::stream::SplitSink< - tokio_tungstenite::WebSocketStream>, - Message, ->; - -impl Bridge -where - R: AsyncBufReadExt + Unpin, - W: AsyncWriteExt + Unpin, -{ - fn new(runtime_arn: &str, region: &str, agent_command: &str, stdin: R, stdout: W) -> Self { - Self { - runtime_arn: runtime_arn.to_string(), - region: region.to_string(), - agent_command: agent_command.to_string(), - stdin, - stdout, - sessions: HashMap::new(), - next_id: 1000, - } - } - - fn alloc_id(&mut self) -> u64 { - self.next_id += 1; - self.next_id - } - - async fn write_msg(&mut self, msg: &Value) -> Result<()> { - let data = serde_json::to_string(msg)?; - self.stdout.write_all(data.as_bytes()).await?; - self.stdout.write_all(b"\n").await?; - self.stdout.flush().await?; - Ok(()) - } - - async fn write_response(&mut self, id: &Value, result: Value) -> Result<()> { - self.write_msg(&json!({"jsonrpc": "2.0", "id": id, "result": result})) - .await - } - - async fn write_error(&mut self, id: &Value, code: i32, message: &str) -> Result<()> { - self.write_msg( - &json!({"jsonrpc": "2.0", "id": id, "error": {"code": code, "message": message}}), - ) - .await - } - - async fn run(&mut self) -> Result<()> { - let mut line = String::new(); - loop { - line.clear(); - let n = self.stdin.read_line(&mut line).await?; - if n == 0 { - break; // EOF - } - let trimmed = line.trim(); - if trimmed.is_empty() { - continue; - } - let msg: Value = match serde_json::from_str(trimmed) { - Ok(v) => v, - Err(_) => continue, - }; - - let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or(""); - let id = msg.get("id").cloned().unwrap_or(Value::Null); - let params = msg.get("params").cloned().unwrap_or(json!({})); - - // Skip messages without a method (e.g. 
stray responses) — same fix as Python F1 - if method.is_empty() { - continue; - } - - match method { - "initialize" => { - self.write_response( - &id, - json!({ - "protocolVersion": 1, - "agentInfo": {"name": "agentcore-shell-bridge", "version": "0.2.0"}, - "agentCapabilities": {"loadSession": true} - }), - ) - .await?; - } - "session/new" => { - let ts = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(); - let acp_sid = format!("agentcore-{ts}"); - let runtime_sid = format!("oab-session-{ts:020}-{ts:013x}"); - - // Eagerly open shell + initialize the agent - match self.open_shell(&runtime_sid).await { - Ok(handle) => { - self.sessions.insert(acp_sid.clone(), handle); - self.write_response(&id, json!({"sessionId": acp_sid})) - .await?; - } - Err(e) => { - self.write_error(&id, -32000, &format!("shell init failed: {e}")) - .await?; - } - } - } - "session/load" => { - let acp_sid = params - .get("sessionId") - .and_then(|s| s.as_str()) - .unwrap_or("") - .to_string(); - self.write_response(&id, json!({"sessionId": acp_sid})) - .await?; - } - "session/prompt" => { - self.handle_prompt(&id, ¶ms).await?; - } - "session/cancel" | "cancel" => { - self.handle_cancel(¶ms).await; - } - "session/destroy" | "session/stop" => { - let acp_sid = params - .get("sessionId") - .and_then(|s| s.as_str()) - .unwrap_or("") - .to_string(); - self.sessions.remove(&acp_sid); - if id != Value::Null { - self.write_response(&id, json!({})).await?; - } - } - "session/request_permission" => { - if id != Value::Null { - self.write_response(&id, json!({"approved": true})).await?; - } - } - _ => { - if id != Value::Null { - self.write_error(&id, -32601, &format!("unknown method: {method}")) - .await?; - } - } - } - } - Ok(()) - } - - async fn handle_prompt(&mut self, id: &Value, params: &Value) -> Result<()> { - let acp_sid = params - .get("sessionId") - .and_then(|s| s.as_str()) - .unwrap_or("") - .to_string(); - - // Reconnect if session 
was lost (shell closed unexpectedly) - if !self.sessions.contains_key(&acp_sid) { - info!("session lost, reconnecting shell..."); - let ts = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(); - let runtime_sid = format!("oab-reconnect-{ts:020}-{ts:013x}"); - match self.open_shell(&runtime_sid).await { - Ok(handle) => { self.sessions.insert(acp_sid.clone(), handle); } - Err(e) => { - self.write_error(id, -32000, &format!("reconnect failed: {e}")).await?; - return Ok(()); - } - } - } - - // Allocate ID before borrowing sessions - let kiro_id = self.alloc_id(); - let kiro_sid = self.sessions.get(&acp_sid) - .map(|s| s.kiro_session_id.clone()) - .unwrap_or_default(); - let mut fwd_params = params.clone(); - if let Some(obj) = fwd_params.as_object_mut() { - obj.insert("sessionId".to_string(), json!(kiro_sid)); - } - let kiro_msg = json!({ - "jsonrpc": "2.0", - "id": kiro_id, - "method": "session/prompt", - "params": fwd_params, - }); - let data = format!("{}\n", serde_json::to_string(&kiro_msg)?); - - // Send prompt to kiro-cli - { - let shell = self.sessions.get_mut(&acp_sid).unwrap(); - let mut w = shell.ws_write.lock().await; - let mut frame = Vec::with_capacity(1 + data.len()); - frame.push(CHANNEL_STDIN); - frame.extend_from_slice(data.as_bytes()); - let _ = w.send(Message::Binary(frame)).await; - } - - // Read responses/notifications from kiro-cli until we get the response for our id. - // We take line_rx out of the session to avoid holding &mut self across await points. 
- let mut line_rx = match self.sessions.get_mut(&acp_sid) { - Some(s) => std::mem::replace(&mut s.line_rx, tokio::sync::mpsc::unbounded_channel().1), - None => { - self.write_error(id, -32000, "session lost").await?; - return Ok(()); - } - }; - - let result = loop { - match tokio::time::timeout(std::time::Duration::from_secs(300), line_rx.recv()).await { - Ok(Some(line)) => { - let msg: Value = match serde_json::from_str(&line) { - Ok(v) => v, - Err(_) => continue, - }; - if msg.get("id").and_then(|i| i.as_u64()) == Some(kiro_id) { - if let Some(err) = msg.get("error") { - self.write_msg(&json!({"jsonrpc": "2.0", "id": id, "error": err})) - .await?; - } else { - let r = msg - .get("result") - .cloned() - .unwrap_or(json!({"type": "success"})); - self.write_response(id, r).await?; - } - break Some(line_rx); - } - if msg.get("method").is_some() { - self.write_msg(&msg).await?; - } - } - Ok(None) => { - self.write_error(id, -32000, "shell connection closed") - .await?; - self.sessions.remove(&acp_sid); - break None; - } - Err(_) => { - self.write_error(id, -32000, "timeout waiting for agent response") - .await?; - break Some(line_rx); - } - } - }; - - // Put line_rx back - if let Some(rx) = result { - if let Some(s) = self.sessions.get_mut(&acp_sid) { - s.line_rx = rx; - } - } - Ok(()) - } - - async fn handle_cancel(&mut self, params: &Value) { - let acp_sid = params - .get("sessionId") - .and_then(|s| s.as_str()) - .unwrap_or(""); - if let Some(shell) = self.sessions.get(acp_sid) { - let cancel_msg = json!({ - "jsonrpc": "2.0", - "method": "session/cancel", - "params": params, - }); - let data = format!( - "{}\n", - serde_json::to_string(&cancel_msg).unwrap_or_default() - ); - let mut frame = Vec::with_capacity(1 + data.len()); - frame.push(CHANNEL_STDIN); - frame.extend_from_slice(data.as_bytes()); - let mut w = shell.ws_write.lock().await; - let _ = w.send(Message::Binary(frame)).await; - } - } - - #[allow(dead_code)] - fn derive_runtime_session_id(&self, params: 
&Value) -> String { - // Try to extract from sender_context in prompt blocks - if let Some(blocks) = params.get("prompt").and_then(|p| p.as_array()) { - for block in blocks { - if let Some(text) = block.get("text").and_then(|t| t.as_str()) { - if let Some(start) = text.find("") { - if let Some(end) = text.find("") { - let ctx_str = &text[start + 16..end]; - if let Ok(ctx) = serde_json::from_str::(ctx_str.trim()) { - let platform = ctx - .get("channel") - .and_then(|c| c.as_str()) - .unwrap_or("unknown"); - let thread_id = ctx - .get("thread_id") - .or_else(|| ctx.get("channel_id")) - .and_then(|t| t.as_str()) - .unwrap_or(""); - let mut sid = format!("oab-{platform}-thread-{thread_id}"); - while sid.len() < 33 { - sid.push('0'); - } - return sid; - } - } - } - } - } - } - // Fallback - let mut sid = format!("oab-fallback-{}", uuid::Uuid::new_v4()); - while sid.len() < 33 { - sid.push('0'); - } - sid - } - - async fn open_shell(&self, session_id: &str) -> Result { - let (request, host) = build_signed_request(&self.runtime_arn, session_id, &self.region).await?; - - // Manual TLS connection — gives us full control, avoids connect_async host override - let tcp = tokio::net::TcpStream::connect(format!("{host}:443")) - .await - .map_err(|e| anyhow!("TCP connect to {host}:443 failed: {e}"))?; - - let connector = tokio_tungstenite::Connector::Rustls(std::sync::Arc::new( - rustls::ClientConfig::builder() - .with_root_certificates(rustls::RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), - }) - .with_no_client_auth(), - )); - - let tls_stream = match connector { - tokio_tungstenite::Connector::Rustls(cfg) => { - let domain = rustls::pki_types::ServerName::try_from(host.as_str()) - .map_err(|e| anyhow!("bad DNS: {e}"))? - .to_owned(); - tokio_rustls::TlsConnector::from(cfg) - .connect(domain, tcp) - .await - .map_err(|e| anyhow!("TLS failed: {e}"))? 
- } - _ => unreachable!(), - }; - - // client_async performs the WebSocket upgrade using our exact request - let (ws_stream, _) = tokio_tungstenite::client_async(request, tls_stream) - .await - .map_err(|e| anyhow!("WebSocket upgrade failed: {e}"))?; - - info!(session_id, "AgentCore shell connected"); - - let (ws_write, mut ws_read) = ws_stream.split(); - let ws_write = Arc::new(Mutex::new(ws_write)); - - // Send agent launch command - let shell_cmd = format!("{}{}\n", AGENT_CMD_PREFIX, self.agent_command); - { - let mut frame = Vec::with_capacity(1 + shell_cmd.len()); - frame.push(CHANNEL_STDIN); - frame.extend_from_slice(shell_cmd.as_bytes()); - let mut w = ws_write.lock().await; - w.send(Message::Binary(frame)) - .await - .map_err(|e| anyhow!("failed to send launch cmd: {e}"))?; - } - - // Channel for forwarding parsed JSON-RPC lines - let (line_tx, mut line_rx) = tokio::sync::mpsc::unbounded_channel::(); - - // Spawn reader pump - let pump = tokio::spawn(async move { - let mut buf = String::new(); - while let Some(Ok(msg)) = ws_read.next().await { - match msg { - Message::Binary(data) => { - if data.len() < 2 { - continue; - } - if data[0] == CHANNEL_STDOUT { - // stdout - if let Ok(s) = std::str::from_utf8(&data[1..]) { - buf.push_str(s); - while let Some(nl) = buf.find('\n') { - let line = buf[..nl].to_string(); - buf = buf[nl + 1..].to_string(); - let trimmed = line.trim().to_string(); - if trimmed.is_empty() { - continue; - } - // Extract JSON object using brace-counting (handles PTY prefix noise) - if let Some(json_str) = extract_json_object(&trimmed) { - if line_tx.send(json_str).is_err() { - return; // receiver dropped — exit pump - } - } - } - } - } else if data[0] == CHANNEL_STDERR { - // stderr — log - if let Ok(s) = std::str::from_utf8(&data[1..]) { - let t = s.trim(); - if !t.is_empty() { - eprintln!("[agentcore] {t}"); - } - } - } - } - Message::Close(_) => break, - _ => {} - } - } - }); - - // Send ACP initialize to the agent (it will respond once 
booted) - - let init_msg = serde_json::json!({ - "jsonrpc": "2.0", - "id": 0, - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": {"name": "openab-agentcore-bridge", "version": env!("CARGO_PKG_VERSION")} - } - }); - let init_data = format!("{}\n", serde_json::to_string(&init_msg)?); - - // Send initialize and wait for response (retry if agent hasn't booted yet) - let mut initialized = false; - for attempt in 0..5 { - { - let mut w = ws_write.lock().await; - let mut frame = Vec::with_capacity(1 + init_data.len()); - frame.push(CHANNEL_STDIN); - frame.extend_from_slice(init_data.as_bytes()); - if let Err(e) = w.send(Message::Binary(frame)).await { - if attempt < 4 { - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - continue; - } - return Err(anyhow!("failed to send initialize: {e}")); - } - } - // Wait up to 10s for response — skip notifications (lines without "id":0) - let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(10); - loop { - let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); - if remaining.is_zero() { - info!(attempt, "no initialize response, retrying..."); - break; - } - match tokio::time::timeout(remaining, line_rx.recv()).await { - Ok(Some(line)) => { - // Check if this is the initialize response (has "id":0 or "id": 0) - if let Ok(v) = serde_json::from_str::(&line) { - if v.get("id").and_then(|i| i.as_u64()) == Some(0) && v.get("result").is_some() { - info!(attempt, "agent initialized"); - initialized = true; - break; - } - } - // Skip notifications and other non-response lines - continue; - } - Ok(None) => return Err(anyhow!("agent closed before initialize response")), - Err(_) => { - info!(attempt, "no initialize response, retrying..."); - break; - } - } - } - if initialized { break; } - } - if !initialized { - return Err(anyhow!("agent failed to respond to initialize after 5 attempts")); - } - - // Send session/new to 
kiro-cli to create a session - let sess_msg = serde_json::json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "session/new", - "params": {"cwd": "/home/agent", "mcpServers": []} - }); - let sess_data = format!("{}\n", serde_json::to_string(&sess_msg)?); - { - let mut w = ws_write.lock().await; - let mut frame = Vec::with_capacity(1 + sess_data.len()); - frame.push(CHANNEL_STDIN); - frame.extend_from_slice(sess_data.as_bytes()); - w.send(Message::Binary(frame)).await?; - } - - // Wait for session/new response — skip notifications (up to 120s) - let kiro_session_id = { - let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(120); - let mut sid = String::from("default"); - loop { - let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); - if remaining.is_zero() { - info!("session/new timed out, using default session"); - break; - } - match tokio::time::timeout(remaining, line_rx.recv()).await { - Ok(Some(line)) => { - if let Ok(v) = serde_json::from_str::(&line) { - if v.get("id").and_then(|i| i.as_u64()) == Some(1) { - sid = v.pointer("/result/sessionId") - .and_then(|s| s.as_str()) - .unwrap_or("default") - .to_string(); - info!(kiro_session_id = %sid, "agent session created"); - break; - } - } - // Skip notifications - continue; - } - Ok(None) => return Err(anyhow!("agent closed before session/new response")), - Err(_) => { - info!("session/new timed out, using default session"); - break; - } - } - } - sid - }; - - Ok(ShellHandle { - ws_write, - line_rx, - _pump: pump, - runtime_session_id: session_id.to_string(), - kiro_session_id, - }) - } -} - -/// Build a WebSocket upgrade request with SigV4 Authorization header. 
-async fn build_signed_request( - arn: &str, - session_id: &str, - region: &str, -) -> Result<(http::Request<()>, String)> { - let config = aws_config::defaults(aws_config::BehaviorVersion::latest()) - .region(aws_config::Region::new(region.to_string())) - .load() - .await; - - let creds = config - .credentials_provider() - .ok_or_else(|| anyhow!("No AWS credentials found"))? - .provide_credentials() - .await - .map_err(|e| anyhow!("Failed to get credentials: {e}"))?; - - let identity = creds.into(); - - let encoded_arn = urlencoding::encode(arn); - let host = format!("bedrock-agentcore.{region}.amazonaws.com"); - let path = format!("/runtimes/{encoded_arn}/ws/shells"); - - // Deterministic shell_id from session_id - let hash = Sha256::digest(session_id.as_bytes()); - let shell_id = format!("oab-{}", hex::encode(&hash[..8])); - - let query = format!("qualifier=DEFAULT&shellId={shell_id}"); - let uri = format!("https://{host}{path}?{query}"); - - // Header-based SigV4 auth - let mut settings = SigningSettings::default(); - settings.expires_in = None; - settings.uri_path_normalization_mode = - aws_sigv4::http_request::UriPathNormalizationMode::Enabled; - - let signing_params = v4::SigningParams::builder() - .identity(&identity) - .region(region) - .name("bedrock-agentcore") - .time(SystemTime::now()) - .settings(settings) - .build()?; - - let headers = [ - ("host", host.as_str()), - ("x-amzn-bedrock-agentcore-runtime-session-id", session_id), - ]; - let signable = SignableRequest::new("GET", &uri, headers.into_iter(), SignableBody::empty())?; - let (instructions, _sig) = sign(signable, &signing_params.into())?.into_parts(); - - let wss_uri = format!("wss://{host}{path}?{query}"); - - // Build request with auth headers + WebSocket headers - let mut builder = http::Request::builder() - .method("GET") - .uri(&wss_uri) - .header("host", &host) - .header("x-amzn-bedrock-agentcore-runtime-session-id", session_id) - .header("connection", "Upgrade") - .header("upgrade", 
"websocket") - .header("sec-websocket-version", "13") - .header("sec-websocket-key", tokio_tungstenite::tungstenite::handshake::client::generate_key()); - - // Add SigV4 auth headers (x-amz-date, authorization) - for (name, value) in instructions.headers() { - builder = builder.header(name, value); - } - - let request = builder.body(())?; - Ok((request, host)) -} diff --git a/src/acp/connection.rs b/src/acp/connection.rs deleted file mode 100644 index 8df3451f4..000000000 --- a/src/acp/connection.rs +++ /dev/null @@ -1,937 +0,0 @@ -use crate::acp::protocol::{ - parse_config_options, ConfigOption, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, -}; -use anyhow::{anyhow, Result}; -use serde_json::{json, Value}; -use std::collections::HashMap; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; -use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; -use tokio::process::{Child, ChildStdin}; -use tokio::sync::{mpsc, oneshot, Mutex}; -use tokio::task::JoinHandle; -use tracing::{debug, error, info, trace}; - -/// Pick the most permissive selectable permission option from ACP options. -fn pick_best_option(options: &[Value]) -> Option { - let mut fallback: Option<&Value> = None; - - for kind in ["allow_always", "allow_once"] { - if let Some(option) = options - .iter() - .find(|option| option.get("kind").and_then(|k| k.as_str()) == Some(kind)) - { - return option - .get("optionId") - .and_then(|id| id.as_str()) - .map(str::to_owned); - } - } - - for option in options { - let kind = option.get("kind").and_then(|k| k.as_str()); - if kind == Some("reject_once") || kind == Some("reject_always") { - continue; - } - fallback = Some(option); - break; - } - - fallback - .and_then(|option| option.get("optionId")) - .and_then(|id| id.as_str()) - .map(str::to_owned) -} - -/// Build a spec-compliant permission response with backward-compatible fallback. 
-fn build_permission_response(params: Option<&Value>) -> Value { - match params - .and_then(|p| p.get("options")) - .and_then(|options| options.as_array()) - { - None => json!({ - "outcome": { - "outcome": "selected", - "optionId": "allow_always" - } - }), - Some(options) => { - if let Some(option_id) = pick_best_option(options) { - json!({ - "outcome": { - "outcome": "selected", - "optionId": option_id - } - }) - } else { - json!({ - "outcome": { - "outcome": "cancelled" - } - }) - } - } - } -} - -fn expand_env(val: &str) -> String { - if val.starts_with("${") && val.ends_with('}') { - let key = &val[2..val.len() - 1]; - std::env::var(key).unwrap_or_default() - } else { - val.to_string() - } -} -use tokio::time::Instant; - -/// A content block for the ACP prompt — either text or image. -#[derive(Debug, Clone)] -pub enum ContentBlock { - Text { text: String }, - Image { media_type: String, data: String }, -} - -impl ContentBlock { - pub fn to_json(&self) -> Value { - match self { - ContentBlock::Text { text } => json!({ - "type": "text", - "text": text - }), - ContentBlock::Image { media_type, data } => json!({ - "type": "image", - "data": data, - "mimeType": media_type - }), - } - } -} - -pub struct AcpConnection { - _proc: Child, - /// PID of the direct child, used as the process group ID for cleanup. - child_pgid: Option, - stdin: Arc>, - next_id: AtomicU64, - pending: Arc>>>, - notify_tx: Arc>>>, - pub acp_session_id: Option, - pub supports_load_session: bool, - pub config_options: Vec, - pub last_active: Instant, - pub session_reset: bool, - _reader_handle: JoinHandle<()>, - _stderr_handle: Option>, -} - -/// Build the final set of env vars for the agent subprocess. -/// `explicit` ([agent].env) takes precedence over `inherit` ([agent].inherit_env). -/// Returns (merged env map, list of keys that were inherited from the process). 
-fn build_agent_env( - explicit: &std::collections::HashMap, - inherit_keys: &[String], -) -> (std::collections::HashMap, Vec) { - let mut result: std::collections::HashMap = std::collections::HashMap::new(); - let mut inherited: Vec = Vec::new(); - - for (k, v) in explicit { - result.insert(k.clone(), expand_env(v)); - } - - for key in inherit_keys { - if !result.contains_key(key) { - if let Ok(v) = std::env::var(key) { - result.insert(key.clone(), v); - inherited.push(key.clone()); - } - } - } - - (result, inherited) -} - -/// Reader loop body: reads JSON-RPC messages from `reader`, auto-replies -/// `session/request_permission` via `writer`, resolves pending responses, -/// and forwards notifications + stale id-bearing messages to the active -/// subscriber. Extracted as a free generic function so unit tests can drive -/// it with `tokio::io::duplex()` halves instead of a real child process. -pub(crate) async fn run_reader_loop( - reader: R, - writer: Arc>, - pending: Arc>>>, - notify_tx: Arc>>>, -) where - R: AsyncRead + Unpin + Send + 'static, - W: AsyncWrite + Unpin + Send + 'static, -{ - let mut reader = BufReader::new(reader); - let mut line = String::new(); - loop { - line.clear(); - match reader.read_line(&mut line).await { - Ok(0) => break, // EOF - Ok(_) => {} - Err(e) => { - error!("reader error: {e}"); - break; - } - } - let msg: JsonRpcMessage = match serde_json::from_str(line.trim()) { - Ok(m) => m, - Err(_) => continue, - }; - debug!(line = line.trim(), "acp_recv"); - - // Auto-reply session/request_permission - if msg.method.as_deref() == Some("session/request_permission") { - if let Some(id) = msg.id { - let title = msg - .params - .as_ref() - .and_then(|p| p.get("toolCall")) - .and_then(|t| t.get("title")) - .and_then(|t| t.as_str()) - .unwrap_or("?"); - - let outcome = build_permission_response(msg.params.as_ref()); - info!(title, %outcome, "auto-respond permission"); - let reply = JsonRpcResponse::new(id, outcome); - if let Ok(data) = 
serde_json::to_string(&reply) { - let mut w = writer.lock().await; - let _ = w.write_all(format!("{data}\n").as_bytes()).await; - let _ = w.flush().await; - } - } - continue; - } - - // Response (has id) → resolve pending AND forward to subscriber - if let Some(id) = msg.id { - let mut map = pending.lock().await; - if let Some(tx) = map.remove(&id) { - // Forward to subscriber so they see the completion - let sub = notify_tx.lock().await; - if let Some(ntx) = sub.as_ref() { - // Clone the essential fields for the subscriber - let _ = ntx.send(JsonRpcMessage { - id: Some(id), - method: None, - result: msg.result.clone(), - error: msg.error.clone(), - params: None, - }); - } - let _ = tx.send(msg); - continue; - } - // Stale id (#732): pending was already abandoned. Falls through - // to subscriber forwarding; the adapter recv loop filters by - // request_id so it can't leak into the next prompt. - trace!(request_id = id, "stale id-bearing message after abandon"); - } - - // Notification → forward to subscriber - let sub = notify_tx.lock().await; - if let Some(tx) = sub.as_ref() { - let _ = tx.send(msg); - } - } - - // Connection closed — resolve all pending with error - let mut map = pending.lock().await; - for (_, tx) in map.drain() { - let _ = tx.send(JsonRpcMessage { - id: None, - method: None, - result: None, - error: Some(crate::acp::protocol::JsonRpcError { - code: -1, - message: "connection closed".into(), - data: None, - }), - params: None, - }); - } - // Close the notify channel so rx.recv() returns None - let mut sub = notify_tx.lock().await; - *sub = None; -} - -impl AcpConnection { - pub async fn spawn( - command: &str, - args: &[String], - working_dir: &str, - env: &std::collections::HashMap, - inherit_env: &[String], - ) -> Result { - info!(cmd = command, ?args, cwd = working_dir, "spawning agent"); - - let mut cmd = tokio::process::Command::new(command); - cmd.args(args) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - 
.stderr(std::process::Stdio::piped()) - .current_dir(working_dir); - // Create a new process group so we can kill the entire tree. - // SAFETY: setpgid is async-signal-safe (POSIX.1-2008) and called - // before exec. Return value checked — failure means the child won't - // have its own process group, so kill(-pgid) would be unsafe. - #[cfg(unix)] - unsafe { - cmd.pre_exec(|| { - if libc::setpgid(0, 0) != 0 { - return Err(std::io::Error::last_os_error()); - } - Ok(()) - }); - } - #[cfg(windows)] - { - cmd.creation_flags(0x00000200); // CREATE_NEW_PROCESS_GROUP - } - // Clear inherited env to prevent credential leakage (e.g. DISCORD_BOT_TOKEN). - // Only [agent].env values + essential baseline vars are passed through. - cmd.env_clear(); - // Preserve the real HOME so agents can find OAuth/auth files (~/.codex, - // ~/.claude, ~/.config/gh, etc.). working_dir is already set via - // current_dir() above and is not necessarily the user's home directory. - cmd.env( - "HOME", - std::env::var("HOME").unwrap_or_else(|_| working_dir.into()), - ); - cmd.env( - "PATH", - std::env::var("PATH").unwrap_or_else(|_| "/usr/local/bin:/usr/bin:/bin".into()), - ); - #[cfg(unix)] - { - cmd.env( - "USER", - std::env::var("USER").unwrap_or_else(|_| "agent".into()), - ); - } - #[cfg(windows)] - { - // Windows requires SystemRoot for DLL loading and basic OS functionality. - // USERPROFILE is the Windows equivalent of HOME. - cmd.env( - "USERPROFILE", - std::env::var("USERPROFILE").unwrap_or_else(|_| working_dir.into()), - ); - cmd.env( - "USERNAME", - std::env::var("USERNAME").unwrap_or_else(|_| "agent".into()), - ); - if let Ok(v) = std::env::var("SystemRoot") { - cmd.env("SystemRoot", v); - } - if let Ok(v) = std::env::var("SystemDrive") { - cmd.env("SystemDrive", v); - } - } - for (k, v) in env { - cmd.env(k, expand_env(v)); - } - // Inherit selected env vars from the OAB process (e.g. vars injected - // via Kubernetes envFrom). 
Keys already in [agent].env are skipped — - // explicit values take precedence. - let (agent_env, inherited_keys) = build_agent_env(env, inherit_env); - for (k, v) in &agent_env { - cmd.env(k, v); - } - if !agent_env.is_empty() { - let explicit_keys: Vec<&String> = env.keys().collect(); - tracing::warn!( - ?explicit_keys, - ?inherited_keys, - "[agent].env/inherit_env is set -- these values are accessible to the agent and could be exfiltrated via prompt injection" - ); - } - let mut proc = cmd - .spawn() - .map_err(|e| anyhow!("failed to spawn {command}: {e}"))?; - let child_pgid = proc.id().and_then(|pid| i32::try_from(pid).ok()); - - let stdout = proc.stdout.take().ok_or_else(|| anyhow!("no stdout"))?; - let stdin = proc.stdin.take().ok_or_else(|| anyhow!("no stdin"))?; - let stdin = Arc::new(Mutex::new(stdin)); - - // Capture agent stderr and log it (ACP spec: agents MAY write to stderr - // for logging; clients MAY capture or ignore it). - let stderr_handle = if let Some(stderr) = proc.stderr.take() { - let cmd_name = command.to_string(); - Some(tokio::spawn(async move { - let mut reader = BufReader::new(stderr); - let mut line = String::new(); - loop { - line.clear(); - match reader.read_line(&mut line).await { - Ok(0) => break, - Ok(_) => { - let trimmed = line.trim(); - if !trimmed.is_empty() { - let sanitized: String = trimmed.chars() - .filter(|c| !c.is_control() || *c == '\t') - .collect(); - if !sanitized.is_empty() { - tracing::warn!(agent = %cmd_name, "{sanitized}"); - } - } - } - Err(_) => break, - } - } - })) - } else { - None - }; - - let pending: Arc>>> = - Arc::new(Mutex::new(HashMap::new())); - let notify_tx: Arc>>> = - Arc::new(Mutex::new(None)); - - let reader_handle = tokio::spawn(run_reader_loop( - stdout, - stdin.clone(), - pending.clone(), - notify_tx.clone(), - )); - - Ok(Self { - _proc: proc, - child_pgid, - stdin, - next_id: AtomicU64::new(1), - pending, - notify_tx, - acp_session_id: None, - supports_load_session: false, - 
config_options: Vec::new(), - last_active: Instant::now(), - session_reset: false, - _reader_handle: reader_handle, - _stderr_handle: stderr_handle, - }) - } - - fn next_id(&self) -> u64 { - self.next_id.fetch_add(1, Ordering::Relaxed) - } - - pub(crate) async fn send_raw(&self, data: &str) -> Result<()> { - debug!(data = data.trim(), "acp_send"); - let mut w = self.stdin.lock().await; - w.write_all(data.as_bytes()).await?; - w.write_all(b"\n").await?; - w.flush().await?; - Ok(()) - } - - async fn send_request(&self, method: &str, params: Option) -> Result { - let id = self.next_id(); - let req = JsonRpcRequest::new(id, method, params); - let data = serde_json::to_string(&req)?; - - let (tx, rx) = oneshot::channel(); - self.pending.lock().await.insert(id, tx); - - self.send_raw(&data).await?; - - let timeout_secs = if method == "session/new" { 120 } else { 30 }; - let resp = tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), rx) - .await - .map_err(|_| anyhow!("timeout waiting for {method} response"))? 
- .map_err(|_| anyhow!("channel closed waiting for {method}"))?; - - if let Some(err) = &resp.error { - return Err(anyhow!("{err}")); - } - Ok(resp) - } - - pub async fn initialize(&mut self) -> Result<()> { - let resp = self - .send_request( - "initialize", - Some(json!({ - "protocolVersion": 1, - "clientCapabilities": {}, - "clientInfo": {"name": "openab", "version": "0.1.0"}, - })), - ) - .await?; - - let result = resp.result.as_ref(); - let agent_name = result - .and_then(|r| r.get("agentInfo")) - .and_then(|a| a.get("name")) - .and_then(|n| n.as_str()) - .unwrap_or("unknown"); - self.supports_load_session = result - .and_then(|r| r.get("agentCapabilities")) - .and_then(|c| c.get("loadSession")) - .and_then(|v| v.as_bool()) - .unwrap_or(false); - info!( - agent = agent_name, - load_session = self.supports_load_session, - "initialized" - ); - Ok(()) - } - - pub async fn session_new(&mut self, cwd: &str) -> Result { - let resp = self - .send_request("session/new", Some(json!({"cwd": cwd, "mcpServers": []}))) - .await?; - - let session_id = resp - .result - .as_ref() - .and_then(|r| r.get("sessionId")) - .and_then(|s| s.as_str()) - .ok_or_else(|| anyhow!("no sessionId in session/new response"))? - .to_string(); - - info!(session_id = %session_id, "session created"); - self.acp_session_id = Some(session_id.clone()); - if let Some(result) = resp.result.as_ref() { - self.config_options = parse_config_options(result); - if !self.config_options.is_empty() { - info!(count = self.config_options.len(), "parsed configOptions"); - } - } - Ok(session_id) - } - - /// Set a config option (e.g. model, mode) via ACP session/set_config_option. - /// Returns the updated list of all config options. - pub async fn set_config_option( - &mut self, - config_id: &str, - value: &str, - ) -> Result> { - let session_id = self - .acp_session_id - .as_ref() - .ok_or_else(|| anyhow!("no session"))? 
- .clone(); - - let resp = self - .send_request( - "session/set_config_option", - Some(json!({ - "sessionId": session_id, - "configId": config_id, - "value": value, - })), - ) - .await; - - match resp { - Ok(r) => { - if let Some(result) = r.result.as_ref() { - self.config_options = parse_config_options(result); - } - info!(config_id, value, "config option set"); - } - Err(_) => { - // Fall back: send as a slash command (e.g. "/model claude-sonnet-4") - let cmd = format!("/{config_id} {value}"); - info!( - cmd, - "set_config_option not supported, falling back to prompt" - ); - let _resp = self - .send_request( - "session/prompt", - Some(json!({ - "sessionId": session_id, - "prompt": [{"type": "text", "text": cmd}], - })), - ) - .await?; - for opt in &mut self.config_options { - if opt.id == config_id { - opt.current_value = value.to_string(); - } - } - } - } - - Ok(self.config_options.clone()) - } - - /// Send a prompt with content blocks (text and/or images) and return a receiver - /// for streaming notifications. The final message on the channel will have id set - /// (the prompt response). - pub async fn session_prompt( - &mut self, - content_blocks: Vec, - ) -> Result<(mpsc::UnboundedReceiver, u64)> { - self.last_active = Instant::now(); - - let session_id = self - .acp_session_id - .as_ref() - .ok_or_else(|| anyhow!("no session"))?; - - let (tx, rx) = mpsc::unbounded_channel(); - *self.notify_tx.lock().await = Some(tx); - - let id = self.next_id(); - - // Convert content blocks to JSON - let prompt_json: Vec = content_blocks.iter().map(|b| b.to_json()).collect(); - - let req = JsonRpcRequest::new( - id, - "session/prompt", - Some(json!({ - "sessionId": session_id, - "prompt": prompt_json, - })), - ); - let data = serde_json::to_string(&req)?; - - let (resp_tx, _resp_rx) = oneshot::channel(); - self.pending.lock().await.insert(id, resp_tx); - - self.send_raw(&data).await?; - Ok((rx, id)) - } - - /// Call after prompt streaming is done to clean up subscriber. 
- pub async fn prompt_done(&mut self) { - *self.notify_tx.lock().await = None; - self.last_active = Instant::now(); - } - - /// Drop the pending entry for `request_id` and best-effort send - /// `session/cancel` as a JSON-RPC notification (no id; per ACP spec the - /// agent does not reply). Errors are swallowed: the agent process may - /// already be dead, in which case the stdin write fails harmlessly. - /// See #732. - pub async fn abandon_request(&self, request_id: u64) { - self.pending.lock().await.remove(&request_id); - let Some(session_id) = self.acp_session_id.as_deref() else { - return; - }; - let req = json!({ - "jsonrpc": "2.0", - "method": "session/cancel", - "params": {"sessionId": session_id}, - }); - if let Ok(data) = serde_json::to_string(&req) { - let _ = self.send_raw(&data).await; - } - } - - /// Return a clone of the stdin handle for lock-free cancel. - pub fn cancel_handle(&self) -> Arc> { - Arc::clone(&self.stdin) - } - - pub fn alive(&self) -> bool { - !self._reader_handle.is_finished() - } - - /// Resume a previous session by ID. Returns Ok(()) if the agent accepted - /// the load, or an error if it failed (caller should fall back to session/new). - pub async fn session_load(&mut self, session_id: &str, cwd: &str) -> Result<()> { - let resp = self - .send_request( - "session/load", - Some(json!({"sessionId": session_id, "cwd": cwd, "mcpServers": []})), - ) - .await?; - // Accept any non-error response as success - if resp.error.is_some() { - return Err(anyhow!("session/load rejected")); - } - info!(session_id, "session loaded"); - self.acp_session_id = Some(session_id.to_string()); - if let Some(result) = resp.result.as_ref() { - self.config_options = parse_config_options(result); - } - Ok(()) - } - - /// Kill the entire process group: SIGTERM → SIGKILL. - /// Uses std::thread (not tokio::spawn) so SIGKILL fires even during - /// runtime shutdown or panic unwinding. 
- fn kill_process_group(&mut self) { - let pgid = match self.child_pgid { - Some(pid) if pid > 0 => pid, - _ => return, - }; - #[cfg(unix)] - { - // Stage 1: SIGTERM the process group - unsafe { - libc::kill(-pgid, libc::SIGTERM); - } - // Stage 2: SIGKILL after brief grace (std::thread survives runtime shutdown) - std::thread::spawn(move || { - std::thread::sleep(std::time::Duration::from_millis(1500)); - unsafe { - libc::kill(-pgid, libc::SIGKILL); - } - }); - } - #[cfg(not(unix))] - { - let _ = pgid; // suppress unused warning on Windows - } - } -} - -impl Drop for AcpConnection { - fn drop(&mut self) { - if let Some(handle) = self._stderr_handle.take() { - handle.abort(); - } - self.kill_process_group(); - } -} - -#[cfg(test)] -mod tests { - use super::{build_agent_env, build_permission_response, pick_best_option}; - use serde_json::json; - - #[test] - fn picks_allow_always_over_other_options() { - let options = vec![ - json!({"kind": "allow_once", "optionId": "once"}), - json!({"kind": "allow_always", "optionId": "always"}), - json!({"kind": "reject_once", "optionId": "reject"}), - ]; - - assert_eq!(pick_best_option(&options), Some("always".to_string())); - } - - #[test] - fn falls_back_to_first_unknown_non_reject_kind() { - let options = vec![ - json!({"kind": "reject_once", "optionId": "reject"}), - json!({"kind": "workspace_write", "optionId": "workspace-write"}), - ]; - - assert_eq!( - pick_best_option(&options), - Some("workspace-write".to_string()) - ); - } - - #[test] - fn selects_bypass_permissions_for_exit_plan_mode() { - let options = vec![ - json!({"optionId": "bypassPermissions", "kind": "allow_always"}), - json!({"optionId": "acceptEdits", "kind": "allow_always"}), - json!({"optionId": "default", "kind": "allow_once"}), - json!({"optionId": "plan", "kind": "reject_once"}), - ]; - - assert_eq!( - pick_best_option(&options), - Some("bypassPermissions".to_string()) - ); - } - - #[test] - fn returns_none_when_only_reject_options_exist() { - let 
options = vec![ - json!({"kind": "reject_once", "optionId": "reject-once"}), - json!({"kind": "reject_always", "optionId": "reject-always"}), - ]; - - assert_eq!(pick_best_option(&options), None); - } - - #[test] - fn builds_cancelled_outcome_when_no_selectable_option_exists() { - let response = build_permission_response(Some(&json!({ - "options": [ - {"kind": "reject_once", "optionId": "reject-once"} - ] - }))); - - assert_eq!(response, json!({"outcome": {"outcome": "cancelled"}})); - } - - #[test] - fn builds_cancelled_when_options_array_is_empty() { - let response = build_permission_response(Some(&json!({ - "options": [] - }))); - - assert_eq!(response, json!({"outcome": {"outcome": "cancelled"}})); - } - - #[test] - fn falls_back_to_allow_always_when_options_are_missing() { - let response = build_permission_response(Some(&json!({ - "toolCall": {"title": "legacy"} - }))); - - assert_eq!( - response, - json!({"outcome": {"outcome": "selected", "optionId": "allow_always"}}) - ); - } - - #[test] - fn falls_back_to_allow_always_when_params_is_none() { - let response = build_permission_response(None); - - assert_eq!( - response, - json!({"outcome": {"outcome": "selected", "optionId": "allow_always"}}) - ); - } - - #[test] - fn explicit_env_takes_precedence_over_inherit_env() { - let key = "OAB_TEST_PRECEDENCE"; - std::env::set_var(key, "from_process"); - let mut explicit = std::collections::HashMap::new(); - explicit.insert(key.to_string(), "from_config".to_string()); - let inherit = vec![key.to_string()]; - - let (result, inherited) = build_agent_env(&explicit, &inherit); - - assert_eq!(result.get(key).unwrap(), "from_config"); - assert!(!inherited.contains(&key.to_string())); - std::env::remove_var(key); - } - - #[test] - fn inherit_env_copies_from_process() { - let key = "OAB_TEST_INHERIT"; - std::env::set_var(key, "process_value"); - let explicit = std::collections::HashMap::new(); - let inherit = vec![key.to_string()]; - - let (result, inherited) = 
build_agent_env(&explicit, &inherit); - - assert_eq!(result.get(key).unwrap(), "process_value"); - assert!(inherited.contains(&key.to_string())); - std::env::remove_var(key); - } - - #[test] - fn inherit_env_skips_missing_vars() { - let explicit = std::collections::HashMap::new(); - let inherit = vec!["OAB_TEST_NONEXISTENT_VAR_12345".to_string()]; - - let (result, inherited) = build_agent_env(&explicit, &inherit); - - assert!(!result.contains_key("OAB_TEST_NONEXISTENT_VAR_12345")); - assert!(inherited.is_empty()); - } -} - -#[cfg(test)] -mod reader_loop_tests { - use super::*; - use std::collections::HashMap; - use std::sync::Arc; - use tokio::io::{duplex, AsyncWriteExt}; - use tokio::sync::{mpsc, oneshot, Mutex}; - - /// #732 stale-id path: when a response arrives for an id the broker has - /// already abandoned, the reader must (a) not crash, (b) leave `pending` - /// untouched, and (c) still forward the message to whoever is currently - /// subscribed — the adapter recv loop is responsible for filtering by - /// request_id so the stray response never leaks into the next prompt. 
- #[tokio::test] - async fn stale_id_response_is_forwarded_without_pending_entry() { - let (mut agent_stdout_writer, agent_stdout_reader) = duplex(8 * 1024); - let (agent_stdin_writer, _agent_stdin_reader) = duplex(8 * 1024); - - let pending: Arc>>> = - Arc::new(Mutex::new(HashMap::new())); - let notify_tx: Arc>>> = - Arc::new(Mutex::new(None)); - - let (sub_tx, mut sub_rx) = mpsc::unbounded_channel(); - *notify_tx.lock().await = Some(sub_tx); - - let writer = Arc::new(Mutex::new(agent_stdin_writer)); - let handle = tokio::spawn(run_reader_loop( - agent_stdout_reader, - writer, - pending.clone(), - notify_tx.clone(), - )); - - let stale = b"{\"jsonrpc\":\"2.0\",\"id\":42,\"result\":{\"stopReason\":\"ok\"}}\n"; - agent_stdout_writer.write_all(stale).await.unwrap(); - agent_stdout_writer.flush().await.unwrap(); - - let forwarded = tokio::time::timeout( - std::time::Duration::from_secs(2), - sub_rx.recv(), - ) - .await - .expect("subscriber should receive stale message before timeout") - .expect("subscriber channel should not be closed"); - assert_eq!(forwarded.id, Some(42)); - assert!(pending.lock().await.is_empty()); - - drop(agent_stdout_writer); - handle.await.unwrap(); - } - - /// Matched-id path: when a response's id is in `pending`, the loop must - /// resolve the oneshot AND forward a copy to the subscriber so the - /// adapter's recv loop sees the completion. Guards against regressions - /// that would suppress the forward branch while keeping resolve. 
- #[tokio::test] - async fn matched_id_response_resolves_pending_and_forwards() { - let (mut agent_stdout_writer, agent_stdout_reader) = duplex(8 * 1024); - let (agent_stdin_writer, _agent_stdin_reader) = duplex(8 * 1024); - - let pending: Arc>>> = - Arc::new(Mutex::new(HashMap::new())); - let notify_tx: Arc>>> = - Arc::new(Mutex::new(None)); - - let (resp_tx, resp_rx) = oneshot::channel(); - pending.lock().await.insert(7, resp_tx); - - let (sub_tx, mut sub_rx) = mpsc::unbounded_channel(); - *notify_tx.lock().await = Some(sub_tx); - - let writer = Arc::new(Mutex::new(agent_stdin_writer)); - let handle = tokio::spawn(run_reader_loop( - agent_stdout_reader, - writer, - pending.clone(), - notify_tx.clone(), - )); - - let payload = b"{\"jsonrpc\":\"2.0\",\"id\":7,\"result\":{\"stopReason\":\"end_turn\"}}\n"; - agent_stdout_writer.write_all(payload).await.unwrap(); - agent_stdout_writer.flush().await.unwrap(); - - let resolved = tokio::time::timeout(std::time::Duration::from_secs(2), resp_rx) - .await - .expect("oneshot should resolve") - .expect("oneshot should not be cancelled"); - assert_eq!(resolved.id, Some(7)); - - let forwarded = tokio::time::timeout(std::time::Duration::from_secs(2), sub_rx.recv()) - .await - .expect("subscriber should receive forwarded copy") - .expect("subscriber channel should not be closed"); - assert_eq!(forwarded.id, Some(7)); - assert!(pending.lock().await.is_empty()); - - drop(agent_stdout_writer); - handle.await.unwrap(); - } -} diff --git a/src/acp/mod.rs b/src/acp/mod.rs deleted file mode 100644 index b6a60eaaf..000000000 --- a/src/acp/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[cfg(feature = "agentcore")] -pub mod agentcore; -pub mod connection; -pub mod pool; -pub mod protocol; - -pub use connection::ContentBlock; -pub use pool::SessionPool; -pub use protocol::{classify_notification, AcpEvent}; diff --git a/src/acp/pool.rs b/src/acp/pool.rs deleted file mode 100644 index d97397169..000000000 --- a/src/acp/pool.rs +++ /dev/null @@ -1,622 
+0,0 @@ -use crate::acp::connection::AcpConnection; -use crate::acp::protocol::ConfigOption; -use crate::config::AgentConfig; -use anyhow::{anyhow, Result}; -use std::collections::HashMap; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use tokio::sync::{Mutex, RwLock}; -use tokio::time::Instant; -use tracing::{info, warn}; - -/// Combined state protected by a single lock to prevent deadlocks. -/// Lock ordering: never await a per-connection mutex while holding `state`. -struct PoolState { - /// Active connections: thread_key → AcpConnection handle. - active: HashMap>>, - /// Lock-free cancel handles: thread_key → (stdin, session_id). - /// Stored separately so cancel can work without locking the connection. - cancel_handles: HashMap>, String)>, - /// Suspended sessions: thread_key → ACP sessionId. - /// Used at runtime to decide which thread can be resumed via `session/load` - /// because it no longer has a live in-memory connection. - suspended: HashMap, - /// Persisted resumable sessions: thread_key → ACP sessionId. - /// Includes both suspended sessions and active sessions so a process restart - /// can recover any live thread via `session/load`. - persisted: HashMap, - /// Serializes create/resume work per thread so rapid same-thread requests - /// cannot race each other into duplicate `session/load` attempts. - creating: HashMap>>, - /// Per-session working directory overrides (from control directives). - /// thread_key → canonical workspace path. 
- session_workdirs: HashMap, -} - -pub struct SessionPool { - state: RwLock, - config: AgentConfig, - max_sessions: usize, - mapping_path: PathBuf, - meta_path: PathBuf, -} - -type EvictionCandidate = (String, Arc>, Instant, Option); - -fn remove_if_same_handle( - map: &mut HashMap>>, - key: &str, - expected: &Arc>, -) -> Option>> { - let should_remove = map - .get(key) - .is_some_and(|current| Arc::ptr_eq(current, expected)); - if should_remove { - map.remove(key) - } else { - None - } -} - -fn get_or_insert_gate(map: &mut HashMap>>, key: &str) -> Arc> { - map.entry(key.to_string()) - .or_insert_with(|| Arc::new(Mutex::new(()))) - .clone() -} - -impl SessionPool { - pub fn new(config: AgentConfig, max_sessions: usize) -> Self { - let openab_dir = std::env::var("HOME") - .map(PathBuf::from) - .unwrap_or_else(|_| PathBuf::from("/tmp")) - .join(".openab"); - let _ = std::fs::create_dir_all(&openab_dir); - let mapping_path = openab_dir.join("thread_map.json"); - let meta_path = openab_dir.join("session_meta.json"); - let suspended = Self::load_mapping(&mapping_path); - let session_workdirs = Self::load_mapping(&meta_path); - Self { - state: RwLock::new(PoolState { - active: HashMap::new(), - cancel_handles: HashMap::new(), - persisted: suspended.clone(), - suspended, - creating: HashMap::new(), - session_workdirs, - }), - config, - max_sessions, - mapping_path, - meta_path, - } - } - - fn load_mapping(path: &Path) -> HashMap { - match std::fs::read_to_string(path) { - Ok(data) => serde_json::from_str(&data).unwrap_or_else(|e| { - warn!(path = %path.display(), error = %e, "corrupt mapping file, starting fresh"); - HashMap::new() - }), - Err(_) => HashMap::new(), - } - } - - fn save_mapping(&self, persisted: &HashMap) { - let data = match serde_json::to_string_pretty(persisted) { - Ok(d) => d, - Err(e) => { - warn!(error = %e, "failed to serialize thread mapping"); - return; - } - }; - let tmp = self.mapping_path.with_extension("json.tmp"); - if let Err(e) = - 
std::fs::write(&tmp, &data).and_then(|_| std::fs::rename(&tmp, &self.mapping_path)) - { - warn!(path = %self.mapping_path.display(), error = %e, "failed to persist thread mapping"); - } - } - - fn save_meta(&self, workdirs: &HashMap) { - let data = match serde_json::to_string_pretty(workdirs) { - Ok(d) => d, - Err(e) => { - warn!(error = %e, "failed to serialize session metadata"); - return; - } - }; - let tmp = self.meta_path.with_extension("json.tmp"); - if let Err(e) = - std::fs::write(&tmp, &data).and_then(|_| std::fs::rename(&tmp, &self.meta_path)) - { - warn!(path = %self.meta_path.display(), error = %e, "failed to persist session metadata"); - } - } - - /// Check if session state exists for this thread (active, suspended, or persisted). - #[allow(dead_code)] - pub async fn has_active_session(&self, thread_id: &str) -> bool { - let state = self.state.read().await; - // Any of these means the thread already has session state. - if state.suspended.contains_key(thread_id) || state.persisted.contains_key(thread_id) { - return true; - } - if let Some(conn) = state.active.get(thread_id) { - match conn.try_lock() { - Ok(c) => return c.alive(), - Err(_) => return true, // lock held = connection busy streaming = alive - } - } - false - } - - pub async fn get_or_create( - &self, - thread_id: &str, - working_dir_override: Option<&str>, - ) -> Result { - let create_gate = { - let mut state = self.state.write().await; - get_or_insert_gate(&mut state.creating, thread_id) - }; - let _create_guard = create_gate.lock().await; - - let (existing, saved_session_id) = { - let state = self.state.read().await; - ( - state.active.get(thread_id).cloned(), - state.suspended.get(thread_id).cloned(), - ) - }; - - let had_existing = existing.is_some(); - let mut saved_session_id = saved_session_id; - if let Some(conn) = existing.clone() { - let conn = conn.lock().await; - if conn.alive() { - return Ok(false); - } - if saved_session_id.is_none() { - saved_session_id = 
conn.acp_session_id.clone(); - } - } - - // Snapshot active handles so we can inspect them outside the state lock. - let snapshot: Vec<(String, Arc>)> = { - let state = self.state.read().await; - state - .active - .iter() - .map(|(k, v)| (k.clone(), Arc::clone(v))) - .collect() - }; - - let mut eviction_candidate: Option = None; - let mut skipped_locked_candidates = 0usize; - for (key, conn) in snapshot { - if key == thread_id { - continue; - } - let conn_handle = Arc::clone(&conn); - let Ok(conn) = conn.try_lock() else { - skipped_locked_candidates += 1; - continue; - }; - let candidate = ( - key, - conn_handle, - conn.last_active, - conn.acp_session_id.clone(), - ); - match &eviction_candidate { - Some((_, _, oldest_last_active, _)) if candidate.2 >= *oldest_last_active => {} - _ => eviction_candidate = Some(candidate), - } - } - - // Resolve effective working directory: stored per-session > explicit override > global config. - // Stored value has highest priority to enforce immutability (ADR §4.5). - let stored_workdir = { - let state = self.state.read().await; - state.session_workdirs.get(thread_id).cloned() - }; - - let effective_workdir = if let Some(stored) = stored_workdir { - stored - } else if let Some(wd) = working_dir_override { - wd.to_string() - } else { - self.config.working_dir.clone() - }; - - // Build the replacement connection outside the state lock so one stuck - // initialization does not block all unrelated sessions. 
- let mut new_conn = AcpConnection::spawn( - &self.config.command, - &self.config.args, - &effective_workdir, - &self.config.env, - &self.config.inherit_env, - ) - .await?; - - new_conn.initialize().await?; - - let mut resumed = false; - if let Some(ref sid) = saved_session_id { - if new_conn.supports_load_session { - match new_conn.session_load(sid, &effective_workdir).await { - Ok(()) => { - info!(thread_id, session_id = %sid, "session resumed via session/load"); - resumed = true; - } - Err(e) => { - warn!(thread_id, session_id = %sid, error = %e, "session/load failed, creating new session"); - } - } - } - } - - if !resumed { - new_conn.session_new(&effective_workdir).await?; - // Surface the reset banner both for restored sessions and for stale - // live entries that died before we could recover a resumable - // session id. In both cases the caller is continuing after an - // unexpected session loss. - if had_existing || saved_session_id.is_some() { - new_conn.session_reset = true; - } - } - - let cancel_handle = new_conn.cancel_handle(); - let cancel_session_id = new_conn.acp_session_id.clone().unwrap_or_default(); - let new_conn = Arc::new(Mutex::new(new_conn)); - - let mut state = self.state.write().await; - - // Another task may have created a healthy connection while we were - // initializing this one. 
- if let Some(existing) = state.active.get(thread_id).cloned() { - let Ok(existing) = existing.try_lock() else { - return Ok(false); - }; - if existing.alive() { - return Ok(false); - } - warn!(thread_id, "stale connection, rebuilding"); - drop(existing); - state.active.remove(thread_id); - state.cancel_handles.remove(thread_id); - } - - if state.active.len() >= self.max_sessions { - if let Some((key, expected_conn, _, sid)) = eviction_candidate { - if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() { - state.cancel_handles.remove(&key); - info!(evicted = %key, "pool full, suspending oldest idle session"); - if let Some(sid) = sid { - state.persisted.insert(key.clone(), sid.clone()); - state.suspended.insert(key, sid); - } else { - state.persisted.remove(&key); - } - } else { - warn!(evicted = %key, "pool full but eviction candidate changed before removal"); - } - } else if skipped_locked_candidates > 0 { - warn!( - max_sessions = self.max_sessions, - skipped_locked_candidates, - "pool full but all other sessions were busy during eviction scan" - ); - } - } - - if state.active.len() >= self.max_sessions { - return Err(anyhow!("pool exhausted ({} sessions)", self.max_sessions)); - } - - if cancel_session_id.is_empty() { - state.persisted.remove(thread_id); - } else { - state - .persisted - .insert(thread_id.to_string(), cancel_session_id.clone()); - } - state.suspended.remove(thread_id); - state.active.insert(thread_id.to_string(), new_conn); - if !cancel_session_id.is_empty() { - state - .cancel_handles - .insert(thread_id.to_string(), (cancel_handle, cancel_session_id)); - } - self.save_mapping(&state.persisted); - - // Persist workspace override only after session spawn succeeded (口渡 F2). 
- if working_dir_override.is_some() { - state - .session_workdirs - .entry(thread_id.to_string()) - .or_insert_with(|| effective_workdir.clone()); - self.save_meta(&state.session_workdirs); - } - - // Return true only for genuinely new sessions — not resumed or reconnected ones. - // A session with prior state (saved_session_id or had_existing) is a resume, - // even if we had to spawn a new ACP process. ADR §2.2: directives are first-message-only. - let is_fresh = !had_existing && saved_session_id.is_none(); - Ok(is_fresh) - } - - /// Get mutable access to a connection. Caller must have called get_or_create first. - /// - /// Only the per-connection `Mutex` is held during `f`; the pool-level - /// `RwLock` is acquired briefly (read-only) to look up the `Arc` and then - /// released, so other connections can be used concurrently. - pub async fn with_connection(&self, thread_id: &str, f: F) -> Result - where - F: for<'a> FnOnce( - &'a mut AcpConnection, - ) -> std::pin::Pin< - Box> + Send + 'a>, - >, - { - let conn = { - let state = self.state.read().await; - state - .active - .get(thread_id) - .cloned() - .ok_or_else(|| anyhow!("no connection for thread {thread_id}"))? - }; - - let mut conn = conn.lock().await; - f(&mut conn).await - } - - /// Get cached configOptions for a session (e.g. available models). - pub async fn get_config_options(&self, thread_id: &str) -> Vec { - let state = self.state.read().await; - let conn = match state.active.get(thread_id) { - Some(c) => c.clone(), - None => return Vec::new(), - }; - drop(state); - let conn = conn.lock().await; - conn.config_options.clone() - } - - /// Set a config option (e.g. model) via ACP and return updated options. - pub async fn set_config_option( - &self, - thread_id: &str, - config_id: &str, - value: &str, - ) -> Result> { - let conn = { - let state = self.state.read().await; - state - .active - .get(thread_id) - .cloned() - .ok_or_else(|| anyhow!("no connection for thread {thread_id}"))? 
- }; - let mut conn = conn.lock().await; - conn.set_config_option(config_id, value).await - } - - /// Cancel the current in-flight operation for a session. - /// Uses pre-stored cancel handles to avoid locking the connection (which is held during streaming). - pub async fn cancel_session(&self, thread_id: &str) -> Result<()> { - let (stdin, session_id) = { - let state = self.state.read().await; - state - .cancel_handles - .get(thread_id) - .cloned() - .ok_or_else(|| anyhow!("no session for thread {thread_id}"))? - }; - let data = serde_json::to_string(&serde_json::json!({ - "jsonrpc": "2.0", - "method": "session/cancel", - "params": {"sessionId": session_id} - }))?; - tracing::info!(session_id, "sending session/cancel"); - use tokio::io::AsyncWriteExt; - let mut w = stdin.lock().await; - w.write_all(data.as_bytes()).await?; - w.write_all(b"\n").await?; - w.flush().await?; - Ok(()) - } - - /// Reset a session: cancel any in-flight operation, remove the active connection, - /// and clear all suspended state. The ACP process will be killed once the last - /// Arc reference is dropped (after streaming finishes). The next message will - /// trigger a fresh `get_or_create` with a new ACP session. - pub async fn reset_session(&self, thread_id: &str) -> Result<()> { - // Send session/cancel via the lock-free stdin handle first. - // This stops in-flight streaming even while with_connection() holds the - // connection mutex, so the old process finishes promptly. 
- if let Some((stdin, session_id)) = { - let state = self.state.read().await; - state.cancel_handles.get(thread_id).cloned() - } { - let data = serde_json::to_string(&serde_json::json!({ - "jsonrpc": "2.0", - "method": "session/cancel", - "params": {"sessionId": session_id} - }))?; - tracing::info!(session_id, "reset: sending session/cancel"); - use tokio::io::AsyncWriteExt; - let mut w = stdin.lock().await; - let _ = w.write_all(data.as_bytes()).await; - let _ = w.write_all(b"\n").await; - let _ = w.flush().await; - } - - let mut state = self.state.write().await; - let had_active = state.active.remove(thread_id).is_some(); - state.cancel_handles.remove(thread_id); - state.suspended.remove(thread_id); - state.persisted.remove(thread_id); - state.creating.remove(thread_id); - state.session_workdirs.remove(thread_id); - self.save_mapping(&state.persisted); - self.save_meta(&state.session_workdirs); - if had_active { - info!(thread_id, "session reset"); - Ok(()) - } else { - Err(anyhow!("no session for thread {thread_id}")) - } - } - - pub async fn cleanup_idle(&self, ttl_secs: u64) { - let cutoff = Instant::now() - std::time::Duration::from_secs(ttl_secs); - - let snapshot: Vec<(String, Arc>)> = { - let state = self.state.read().await; - state - .active - .iter() - .map(|(k, v)| (k.clone(), Arc::clone(v))) - .collect() - }; - - let mut stale = Vec::new(); - for (key, conn) in snapshot { - // Skip active sessions for this cleanup round instead of waiting on - // their per-connection mutex. A busy session is not idle. 
- let conn_handle = Arc::clone(&conn); - let Ok(conn) = conn.try_lock() else { - continue; - }; - if conn.last_active < cutoff || !conn.alive() { - stale.push((key, conn_handle, conn.acp_session_id.clone())); - } - } - - if stale.is_empty() { - return; - } - - let mut state = self.state.write().await; - for (key, expected_conn, sid) in stale { - if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() { - info!(thread_id = %key, "cleaning up idle session"); - state.cancel_handles.remove(&key); - if let Some(sid) = sid { - state.persisted.insert(key.clone(), sid.clone()); - state.suspended.insert(key, sid); - } else { - state.persisted.remove(&key); - state.session_workdirs.remove(&key); - } - } - } - self.save_mapping(&state.persisted); - self.save_meta(&state.session_workdirs); - } - - pub async fn shutdown(&self) { - // Snapshot active handles, then drop state lock before awaiting - // per-connection mutexes (lock ordering: never hold state while - // awaiting a connection lock). 
- let snapshot: Vec<(String, Arc>)> = { - let state = self.state.read().await; - state - .active - .iter() - .map(|(k, v)| (k.clone(), Arc::clone(v))) - .collect() - }; - - let mut session_ids: Vec<(String, String)> = Vec::new(); - for (key, conn) in snapshot { - let conn = conn.lock().await; - if let Some(sid) = conn.acp_session_id.clone() { - session_ids.push((key, sid)); - } - } - - let mut state = self.state.write().await; - for (key, sid) in session_ids { - state.persisted.insert(key.clone(), sid.clone()); - state.suspended.insert(key, sid); - } - self.save_mapping(&state.persisted); - let count = state.active.len(); - state.active.clear(); - state.cancel_handles.clear(); - info!(count, "pool shutdown complete"); - } -} - -#[cfg(test)] -mod tests { - use super::{get_or_insert_gate, remove_if_same_handle}; - use std::collections::HashMap; - use std::sync::Arc; - use tokio::sync::Mutex; - - #[test] - fn remove_if_same_handle_removes_matching_entry() { - let expected = Arc::new(Mutex::new(1_u8)); - let mut map = HashMap::from([("thread".to_string(), Arc::clone(&expected))]); - - let removed = remove_if_same_handle(&mut map, "thread", &expected); - - assert!(removed.is_some()); - assert!(map.is_empty()); - } - - #[test] - fn remove_if_same_handle_keeps_replaced_entry() { - let stale = Arc::new(Mutex::new(1_u8)); - let fresh = Arc::new(Mutex::new(2_u8)); - let mut map = HashMap::from([("thread".to_string(), Arc::clone(&fresh))]); - - let removed = remove_if_same_handle(&mut map, "thread", &stale); - - assert!(removed.is_none()); - let current = map.get("thread").expect("entry should remain"); - assert!(Arc::ptr_eq(current, &fresh)); - } - - #[test] - fn get_or_insert_gate_reuses_gate_for_same_thread() { - let mut map = HashMap::new(); - - let first = get_or_insert_gate(&mut map, "thread"); - let second = get_or_insert_gate(&mut map, "thread"); - - assert!(Arc::ptr_eq(&first, &second)); - assert_eq!(map.len(), 1); - } - - #[test] - fn 
persisted_mapping_can_include_active_and_suspended_sessions() { - let persisted = HashMap::from([ - ("active-thread".to_string(), "session-active".to_string()), - ( - "suspended-thread".to_string(), - "session-suspended".to_string(), - ), - ]); - - let serialized = - serde_json::to_string_pretty(&persisted).expect("serialize persisted mapping"); - let roundtrip: HashMap = - serde_json::from_str(&serialized).expect("deserialize persisted mapping"); - - assert_eq!( - roundtrip.get("active-thread"), - Some(&"session-active".to_string()) - ); - assert_eq!( - roundtrip.get("suspended-thread"), - Some(&"session-suspended".to_string()) - ); - } -} diff --git a/src/acp/protocol.rs b/src/acp/protocol.rs deleted file mode 100644 index 099d98b71..000000000 --- a/src/acp/protocol.rs +++ /dev/null @@ -1,406 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -// --- Outgoing --- - -#[derive(Debug, Serialize)] -pub struct JsonRpcRequest { - pub jsonrpc: &'static str, - pub id: u64, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -impl JsonRpcRequest { - pub fn new(id: u64, method: impl Into, params: Option) -> Self { - Self { - jsonrpc: "2.0", - id, - method: method.into(), - params, - } - } -} - -#[derive(Debug, Serialize)] -pub struct JsonRpcResponse { - pub jsonrpc: &'static str, - pub id: u64, - pub result: Value, -} - -impl JsonRpcResponse { - pub fn new(id: u64, result: Value) -> Self { - Self { - jsonrpc: "2.0", - id, - result, - } - } -} - -// --- Incoming --- - -#[derive(Debug, Deserialize)] -pub struct JsonRpcMessage { - pub id: Option, - pub method: Option, - pub result: Option, - pub error: Option, - pub params: Option, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct JsonRpcError { - pub code: i64, - pub message: String, - /// Optional structured data from the agent (JSON-RPC `error.data`). - /// Agents like codex-acp include `{"message": "...", "codex_error_info": "..."}`. 
- pub data: Option, -} - -impl JsonRpcError { - /// Extract a human-readable detail from `error.data.message` if present. - /// - /// The `"message"` key is a convention used by codex-acp and aligns with - /// common JSON-RPC practice, but is NOT mandated by the ACP spec. - /// Other agents may use `"detail"`, `"reason"`, etc. — extend here if needed. - pub fn data_message(&self) -> Option<&str> { - self.data - .as_ref() - .and_then(|d| d.get("message")) - .and_then(|m| m.as_str()) - } -} - -impl std::fmt::Display for JsonRpcError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JSON-RPC error {}: {}", self.code, self.message)?; - if let Some(detail) = self.data_message() { - write!(f, " — {detail}")?; - } - Ok(()) - } -} - -// --- ACP configOptions (session-level configuration) --- - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ConfigOptionValue { - pub value: String, - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ConfigOption { - pub id: String, - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub category: Option, - #[serde(rename = "type")] - pub option_type: String, - pub current_value: String, - pub options: Vec, -} - -/// Extract configOptions from a JSON-RPC result value. -/// Supports standard `configOptions` and kiro-cli's `models`/`modes` fallback. 
-pub fn parse_config_options(result: &Value) -> Vec { - if let Some(opts) = result - .get("configOptions") - .and_then(|v| serde_json::from_value::>(v.clone()).ok()) - { - if !opts.is_empty() { - return opts; - } - } - - // Kiro-cli fallback: parse models/modes format - let mut options = Vec::new(); - - if let Some(models) = result.get("models") { - let current = models - .get("currentModelId") - .and_then(|v| v.as_str()) - .unwrap_or(""); - if let Some(available) = models.get("availableModels").and_then(|v| v.as_array()) { - let values: Vec = available - .iter() - .filter_map(|m| { - let id = m - .get("modelId") - .or_else(|| m.get("id")) - .and_then(|v| v.as_str())?; - let name = m.get("name").and_then(|v| v.as_str()).unwrap_or(id); - Some(ConfigOptionValue { - value: id.to_string(), - name: name.to_string(), - description: m - .get("description") - .and_then(|v| v.as_str()) - .map(String::from), - }) - }) - .collect(); - if !values.is_empty() { - options.push(ConfigOption { - id: "model".to_string(), - name: "Model".to_string(), - description: Some("AI model selection".to_string()), - category: Some("model".to_string()), - option_type: "enum".to_string(), - current_value: current.to_string(), - options: values, - }); - } - } - } - - if let Some(modes) = result.get("modes") { - let current = modes - .get("currentModeId") - .and_then(|v| v.as_str()) - .unwrap_or(""); - if let Some(available) = modes.get("availableModes").and_then(|v| v.as_array()) { - let values: Vec = available - .iter() - .filter_map(|m| { - let id = m.get("id").and_then(|v| v.as_str())?; - let name = m.get("name").and_then(|v| v.as_str()).unwrap_or(id); - Some(ConfigOptionValue { - value: id.to_string(), - name: name.to_string(), - description: m - .get("description") - .and_then(|v| v.as_str()) - .map(String::from), - }) - }) - .collect(); - if !values.is_empty() { - options.push(ConfigOption { - id: "agent".to_string(), - name: "Agent".to_string(), - description: Some("Agent mode 
selection".to_string()), - category: Some("agent".to_string()), - option_type: "enum".to_string(), - current_value: current.to_string(), - options: values, - }); - } - } - } - - options -} - -// --- ACP notification classification --- - -#[derive(Debug)] -pub enum AcpEvent { - Text(String), - Thinking, - ToolStart { - id: String, - title: String, - }, - ToolDone { - id: String, - title: String, - status: String, - }, - ConfigUpdate { - options: Vec, - }, - Status, -} - -pub fn classify_notification(msg: &JsonRpcMessage) -> Option { - let params = msg.params.as_ref()?; - let update = params.get("update")?; - let session_update = update.get("sessionUpdate")?.as_str()?; - - // toolCallId is the stable identity across tool_call → tool_call_update - // events for the same tool invocation. claude-agent-acp emits the first - // event before the input fields are streamed in (so the title falls back - // to "Terminal" / "Edit" / etc.) and refines them in a later - // tool_call_update; without the id we can't tell those events belong to - // the same call and end up rendering placeholder + refined as two - // separate lines. 
- let tool_id = update - .get("toolCallId") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - match session_update { - "agent_message_chunk" => { - let text = update.get("content")?.get("text")?.as_str()?; - Some(AcpEvent::Text(text.to_string())) - } - "agent_thought_chunk" => Some(AcpEvent::Thinking), - "tool_call" => { - let title = update - .get("title") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - Some(AcpEvent::ToolStart { id: tool_id, title }) - } - "tool_call_update" => { - let title = update - .get("title") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let status = update - .get("status") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - if status == "completed" || status == "failed" { - Some(AcpEvent::ToolDone { - id: tool_id, - title, - status, - }) - } else { - Some(AcpEvent::ToolStart { id: tool_id, title }) - } - } - "plan" => Some(AcpEvent::Status), - "config_option_update" => { - let options = parse_config_options(update); - Some(AcpEvent::ConfigUpdate { options }) - } - _ => None, - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn parse_standard_config_options() { - let result = json!({ - "configOptions": [{ - "id": "model", - "name": "Model", - "type": "enum", - "currentValue": "claude-sonnet-4", - "options": [ - {"value": "claude-sonnet-4", "name": "Sonnet 4"}, - {"value": "claude-opus-4", "name": "Opus 4"} - ] - }] - }); - let opts = parse_config_options(&result); - assert_eq!(opts.len(), 1); - assert_eq!(opts[0].id, "model"); - assert_eq!(opts[0].current_value, "claude-sonnet-4"); - assert_eq!(opts[0].options.len(), 2); - } - - #[test] - fn parse_kiro_models_fallback() { - let result = json!({ - "models": { - "currentModelId": "m1", - "availableModels": [ - {"modelId": "m1", "name": "Model One"}, - {"modelId": "m2", "name": "Model Two"} - ] - } - }); - let opts = parse_config_options(&result); - assert_eq!(opts.len(), 1); - 
assert_eq!(opts[0].id, "model"); - assert_eq!(opts[0].category.as_deref(), Some("model")); - assert_eq!(opts[0].current_value, "m1"); - assert_eq!(opts[0].options.len(), 2); - } - - #[test] - fn parse_kiro_modes_fallback() { - let result = json!({ - "modes": { - "currentModeId": "default", - "availableModes": [ - {"id": "default", "name": "Default"}, - {"id": "planner", "name": "Planner"} - ] - } - }); - let opts = parse_config_options(&result); - assert_eq!(opts.len(), 1); - assert_eq!(opts[0].id, "agent"); - assert_eq!(opts[0].category.as_deref(), Some("agent")); - assert_eq!(opts[0].current_value, "default"); - } - - #[test] - fn parse_kiro_models_and_modes() { - let result = json!({ - "models": { - "currentModelId": "m1", - "availableModels": [{"modelId": "m1", "name": "M1"}] - }, - "modes": { - "currentModeId": "default", - "availableModes": [{"id": "default", "name": "Default"}] - } - }); - let opts = parse_config_options(&result); - assert_eq!(opts.len(), 2); - assert_eq!(opts[0].id, "model"); - assert_eq!(opts[1].id, "agent"); - } - - #[test] - fn parse_standard_takes_precedence_over_kiro() { - let result = json!({ - "configOptions": [{ - "id": "model", - "name": "Model", - "type": "enum", - "currentValue": "standard", - "options": [{"value": "standard", "name": "Standard"}] - }], - "models": { - "currentModelId": "kiro", - "availableModels": [{"modelId": "kiro", "name": "Kiro"}] - } - }); - let opts = parse_config_options(&result); - assert_eq!(opts.len(), 1); - assert_eq!(opts[0].current_value, "standard"); - } - - #[test] - fn parse_empty_result() { - let opts = parse_config_options(&json!({})); - assert!(opts.is_empty()); - } - - #[test] - fn parse_empty_config_options_falls_through_to_kiro() { - let result = json!({ - "configOptions": [], - "models": { - "currentModelId": "m1", - "availableModels": [{"modelId": "m1", "name": "M1"}] - } - }); - let opts = parse_config_options(&result); - assert_eq!(opts.len(), 1); - assert_eq!(opts[0].id, "model"); - } 
-} diff --git a/src/adapter.rs b/src/adapter.rs deleted file mode 100644 index 8b77242b5..000000000 --- a/src/adapter.rs +++ /dev/null @@ -1,1659 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use serde::Serialize; -use std::sync::Arc; -use tracing::{error, warn}; - -use crate::acp::{classify_notification, AcpEvent, ContentBlock, SessionPool}; -use crate::config::{ReactionsConfig, ToolDisplay}; -use crate::error_display::{format_coded_error, format_user_error}; -use crate::format; -use crate::markdown::{self, TableMode}; -use crate::reactions::StatusReactionController; - -// --- Output directive parsing --- - -/// Parsed directives from agent output header block. -/// Consecutive `[[key:value]]` lines at the start of output are directives. -#[derive(Default, Debug)] -pub struct OutputDirectives { - /// Message ID to reply to (Discord: message_reference) - pub reply_to: Option, -} - -/// Parse `[[key:value]]` directives from the beginning of agent output. -/// Returns parsed directives and the remaining content (directives stripped). -pub fn parse_output_directives(content: &str) -> (OutputDirectives, String) { - let mut directives = OutputDirectives::default(); - let mut content_start = 0; - let mut trailing_content: Option<&str> = None; - - for line in content.lines() { - let trimmed = line.trim(); - // Try to match [[key:value]] at the start of the line (lenient: allows trailing content) - if let Some(after_open) = trimmed.strip_prefix("[[") { - if let Some(close_pos) = after_open.find("]]") { - let inner = &after_open[..close_pos]; - if let Some((key, value)) = inner.split_once(':') { - match key.trim() { - "reply_to" => { - let v = value.trim(); - // Validate: non-empty, reasonable length, no whitespace/control chars - if !v.is_empty() && v.len() <= 64 && v.chars().all(|c| c.is_ascii_alphanumeric() || c == '.' 
|| c == '-' || c == '_') { - directives.reply_to = Some(v.to_string()); - } - } - _ => { - tracing::debug!(key = key.trim(), "unknown output directive ignored"); - } - } - // Check for trailing content after ]] - let remainder = after_open[close_pos + 2..].trim(); - if !remainder.is_empty() { - trailing_content = Some(remainder); - // Advance past this line - content_start += line.len(); - if content.as_bytes().get(content_start) == Some(&b'\r') { - content_start += 1; - } - if content.as_bytes().get(content_start) == Some(&b'\n') { - content_start += 1; - } - break; // Trailing content ends directive header - } - // Advance past this line + its line ending (handles both \n and \r\n) - content_start += line.len(); - if content.as_bytes().get(content_start) == Some(&b'\r') { - content_start += 1; - } - if content.as_bytes().get(content_start) == Some(&b'\n') { - content_start += 1; - } - } else { - // [[X]] without colon — not a directive, stop parsing - break; - } - } else { - // No closing ]] found — not a directive, stop parsing - break; - } - } else { - break; - } - } - - let remaining = if let Some(trailing) = trailing_content { - if content_start < content.len() { - format!("{}\n{}", trailing, &content[content_start..]) - } else { - trailing.to_string() - } - } else if content_start < content.len() { - content[content_start..].to_string() - } else { - String::new() - }; - (directives, remaining) -} - -// --- Platform-agnostic types --- - -/// Identifies a channel or thread across platforms. -/// -/// Used for **routing**: `channel_id` is the ID the adapter sends messages to. -/// For Discord threads, this is the thread's own channel ID (Discord API -/// requires it for `say`/`edit`). Use `parent_id` to find the parent channel. -/// -/// Compare with `SenderContext`, which is **metadata for the agent**: there -/// `channel_id` is the parent channel and `thread_id` is the thread, -/// matching Slack's model for cross-platform consistency. 
-#[derive(Clone, Debug)] -pub struct ChannelRef { - pub platform: String, - pub channel_id: String, - /// Thread within a channel (e.g. Slack thread_ts, Telegram topic_id). - /// For Discord, threads are separate channels so this is None. - pub thread_id: Option, - /// Parent channel if this is a thread-as-channel (Discord). - pub parent_id: Option, - /// Originating gateway event ID, propagated back in `GatewayReply.reply_to` - /// so the gateway can correlate replies with inbound events (e.g. LINE reply tokens). - /// Excluded from Hash/Eq — two ChannelRefs pointing to the same channel are - /// equal regardless of which event they originated from. - pub origin_event_id: Option, -} - -impl PartialEq for ChannelRef { - fn eq(&self, other: &Self) -> bool { - self.platform == other.platform - && self.channel_id == other.channel_id - && self.thread_id == other.thread_id - && self.parent_id == other.parent_id - } -} - -impl Eq for ChannelRef {} - -impl std::hash::Hash for ChannelRef { - fn hash(&self, state: &mut H) { - self.platform.hash(state); - self.channel_id.hash(state); - self.thread_id.hash(state); - self.parent_id.hash(state); - } -} - -/// Identifies a message across platforms. -#[derive(Clone, Debug)] -pub struct MessageRef { - pub channel: ChannelRef, - pub message_id: String, -} - -/// Bundles per-message parameters for `AdapterRouter::handle_message`. -/// -/// Introduced to reduce parameter count and make the signature extensible -/// (e.g. streaming policy, rate limit hints) without breaking call sites. -pub struct MessageContext { - pub thread_channel: ChannelRef, - pub sender_json: String, - pub prompt: String, - pub extra_blocks: Vec, - pub trigger_msg: MessageRef, - pub other_bot_present: bool, -} - -/// Sender identity injected into prompts for downstream agent context. -/// -/// This is **metadata for the agent** — `channel_id` always refers to the -/// logical parent channel, and `thread_id` identifies the thread (if any). 
-/// This convention is consistent across platforms (Slack, Discord, Telegram). -/// -/// Compare with `ChannelRef`, which is used for **routing**: there -/// `channel_id` is the ID the adapter sends messages to (for Discord -/// threads, that's the thread's own channel ID, not the parent). -#[derive(Clone, Debug, Serialize)] -pub struct SenderContext { - pub schema: String, - pub sender_id: String, - pub sender_name: String, - pub display_name: String, - pub channel: String, - pub channel_id: String, - /// Thread identifier, if the message is inside a thread. - /// Slack: thread_ts. Discord: thread channel ID (channel_id holds the parent). - #[serde(skip_serializing_if = "Option::is_none")] - pub thread_id: Option, - pub is_bot: bool, - /// Platform message creation time (ISO 8601 UTC), if available. - /// Discord/Slack: platform timestamp. Gateway: broker receive time (best-effort). - /// Additive optional field — schema version stays openab.sender.v1 (no consumer - /// breakage). If future additions require breaking changes, bump to v1.1+. - #[serde(skip_serializing_if = "Option::is_none")] - pub timestamp: Option, - /// Platform message ID. Agents can use this to reply to a specific message - /// via the `[[reply_to:]]` output directive. - #[serde(skip_serializing_if = "Option::is_none")] - pub message_id: Option, - /// The platform user ID of the receiving bot/agent. - /// Enables agents to identify themselves when multiple agents share the same backend. - #[serde(skip_serializing_if = "Option::is_none")] - pub receiver_id: Option, -} - -// --- ChatAdapter trait --- - -#[async_trait] -pub trait ChatAdapter: Send + Sync + 'static { - /// Platform name for logging and session key namespacing. - fn platform(&self) -> &'static str; - - /// Maximum message length (chars) for this platform; the router splits longer - /// replies into multiple messages at this bound. Platform-specific (e.g. 2000 - /// for Discord; Slack uses its Block Kit `markdown` block cap). 
- fn message_limit(&self) -> usize; - - /// Send a new message, returns a reference to the sent message. - async fn send_message(&self, channel: &ChannelRef, content: &str) -> Result; - - /// Create a thread from a trigger message, returns the thread channel ref. - async fn create_thread( - &self, - channel: &ChannelRef, - trigger_msg: &MessageRef, - title: &str, - ) -> Result; - - /// Add a reaction/emoji to a message. - async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()>; - - /// Remove a reaction/emoji from a message. - async fn remove_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()>; - - /// Edit an existing message in-place (for streaming updates). - /// Default: unsupported (send-once only). - async fn edit_message(&self, _msg: &MessageRef, _content: &str) -> Result<()> { - Err(anyhow::anyhow!("edit_message not supported")) - } - - /// Send a message as a reply to a specific message (Discord: message_reference). - /// Default: falls back to plain send_message (ignores reply_to). - async fn send_message_with_reply( - &self, - channel: &ChannelRef, - content: &str, - reply_to_message_id: &str, - ) -> Result { - let _ = reply_to_message_id; // unused in default impl - self.send_message(channel, content).await - } - - /// Rename the thread/channel title. Default: no-op (not all platforms support it). - async fn rename_thread(&self, _channel: &ChannelRef, _title: &str) -> Result<()> { - Ok(()) - } - - /// Delete a message. Used to remove streaming placeholders when reply_to is set. - /// Default: edits to zero-width space (fallback for platforms without delete support). - async fn delete_message(&self, msg: &MessageRef) -> Result<()> { - self.edit_message(msg, "\u{200b}").await - } - - /// Whether this adapter streams via a native streaming API (Slack - /// chat.startStream) rather than the post+edit loop. Default: false. 
- /// `other_bot_present` lets adapters fall back to send-once in multi-bot - /// threads (mirrors `use_streaming`'s #534 rule). - fn uses_native_streaming(&self, _other_bot_present: bool) -> bool { - false - } - - /// Begin a native stream. The returned MessageRef is the handle for - /// subsequent `stream_append` / `stream_finish`. - /// Default: delegate to send_message (only called when uses_native_streaming). - /// `recipient` is the per-turn `(user_id, team_id)` for platforms (Slack) that - /// need it for the native stream open; ignored by the default impl. - async fn stream_begin( - &self, - channel: &ChannelRef, - _recipient: Option<(String, String)>, - ) -> Result { - self.send_message(channel, "…").await - } - - /// Append an INCREMENTAL delta to a native stream. - /// Default: best-effort edit (only called when uses_native_streaming). - async fn stream_append(&self, msg: &MessageRef, delta: &str) -> Result<()> { - self.edit_message(msg, delta).await - } - - /// Finish a native stream and write the COMPLETE final content. - /// Default: delegate to edit_message. - async fn stream_finish(&self, msg: &MessageRef, final_content: &str) -> Result<()> { - self.edit_message(msg, final_content).await - } - - /// Whether this adapter uses a status API (e.g. assistant.threads.setStatus) - /// instead of emoji reactions for thinking/tool indicators. Independent of - /// `uses_native_streaming` — status can work without content streaming. - /// Default: false. - fn uses_assistant_status(&self) -> bool { - false - } - - /// Set an ephemeral status line (e.g. "Thinking…", "Using …"). - /// Empty string clears it. Default: no-op (platforms without a status API). - async fn set_status(&self, _channel: &ChannelRef, _status: &str) -> Result<()> { - Ok(()) - } - - /// Whether this platform renders Markdown tables natively. 
When `true`, the - /// router skips the `convert_tables` pre-pass (which rewrites tables into - /// code blocks / bullet lists for platforms that cannot render them) and - /// lets the platform render the raw Markdown table itself. - /// Default: `false` (keep converting). Overridden by Slack (Block Kit - /// `markdown` blocks / `markdown_text` stream chunks render tables natively). - fn renders_native_tables(&self) -> bool { - false - } - - /// Whether this adapter should use streaming edit (true) or send-once (false). - /// `other_bot_present` indicates if another bot has posted in the current thread. - /// Streaming should be disabled in multi-bot threads to avoid edit interference. - /// NOTE: Slight race window exists — the multibot cache is checked before - /// handle_message, so a bot arriving between the check and the response will - /// not be detected until the next message. This is acceptable: the first - /// response may stream, but subsequent ones will correctly use send-once. - fn use_streaming(&self, other_bot_present: bool) -> bool; - - /// Whether to send the "…" placeholder message before streaming starts. - /// Default: true. Platforms using drafts (e.g. Telegram Rich Messages) can - /// return false to suppress the placeholder. - fn show_streaming_placeholder(&self) -> bool { - true - } -} - -// --- AdapterRouter --- - -/// Shared logic for routing messages to ACP agents, managing sessions, -/// streaming edits, and controlling reactions. Platform-independent. -pub struct AdapterRouter { - pool: Arc, - reactions_config: ReactionsConfig, - table_mode: TableMode, - prompt_hard_timeout: std::time::Duration, - /// Polling cadence for the recv-loop liveness check (#732). - liveness_check_interval: std::time::Duration, - /// Workspace aliases from `[workspace.aliases]` config. - workspace_aliases: std::collections::HashMap, - /// Bot home directory (security boundary for workspace directives). 
- bot_home: std::path::PathBuf, -} - -impl AdapterRouter { - pub fn new( - pool: Arc, - reactions_config: ReactionsConfig, - table_mode: TableMode, - prompt_hard_timeout_secs: u64, - liveness_check_secs: u64, - workspace_aliases: std::collections::HashMap, - bot_home: std::path::PathBuf, - ) -> Self { - if liveness_check_secs >= prompt_hard_timeout_secs { - warn!( - liveness_check_secs, - prompt_hard_timeout_secs, - "pool.liveness_check_secs >= pool.prompt_hard_timeout_secs; \ - the hard ceiling will only fire after the next liveness tick \ - and may be effectively bypassed. Lower liveness_check_secs." - ); - } - Self { - pool, - reactions_config, - table_mode, - prompt_hard_timeout: std::time::Duration::from_secs(prompt_hard_timeout_secs), - liveness_check_interval: std::time::Duration::from_secs(liveness_check_secs), - workspace_aliases, - bot_home, - } - } - - /// Access the underlying session pool (e.g. for config option queries). - pub fn pool(&self) -> &Arc { - &self.pool - } - - /// Access the reactions config (used by dispatch.rs). - pub fn reactions_config(&self) -> &ReactionsConfig { - &self.reactions_config - } - - /// Workspace aliases for control directive resolution. - pub fn workspace_aliases_map(&self) -> std::collections::HashMap { - self.workspace_aliases.clone() - } - - /// Bot home path for workspace security boundary. - pub fn bot_home_path(&self) -> std::path::PathBuf { - self.bot_home.clone() - } - - /// Pack one arrival event into ContentBlocks. Per-arrival layout: - /// Text { "\n{json}\n" } <- delimiter - /// [Text blocks from extra_blocks (e.g. STT transcripts)] - /// Text { "{prompt}" } <- omitted if empty - /// [non-Text blocks from extra_blocks (e.g. Image)] - /// - /// The sender_context block stands alone so it can serve as a structural - /// delimiter between arrivals in batched dispatch — agents can scan for - /// `` openers to find arrival boundaries. 
Within an arrival, - /// transcript text precedes the typed prompt to match pre-batching adapter - /// behavior (voice content first), and images trail the prompt as before. - /// This is the single packing code path for both per-message and batched - /// dispatch (ADR §3.5). For a batch of N messages, call this N times and - /// concatenate. - pub fn pack_arrival_event( - sender_json: &str, - prompt: &str, - extra_blocks: Vec, - ) -> Vec { - let header = format!("\n{}\n", sender_json); - let (texts, others): (Vec<_>, Vec<_>) = extra_blocks - .into_iter() - .partition(|b| matches!(b, ContentBlock::Text { .. })); - let mut blocks = Vec::with_capacity(2 + texts.len() + others.len()); - blocks.push(ContentBlock::Text { text: header }); - blocks.extend(texts); - if !prompt.is_empty() { - blocks.push(ContentBlock::Text { - text: prompt.to_string(), - }); - } - blocks.extend(others); - blocks - } - - /// Handle an incoming user message. The adapter is responsible for - /// filtering, resolving the thread, and building the SenderContext. - /// This method handles sender context injection, session management, and streaming. - pub async fn handle_message( - &self, - adapter: &Arc, - ctx: MessageContext, - ) -> Result<()> { - tracing::debug!(platform = adapter.platform(), "processing message"); - - let content_blocks = - Self::pack_arrival_event(&ctx.sender_json, &ctx.prompt, ctx.extra_blocks); - - let thread_key = format!( - "{}:{}", - adapter.platform(), - ctx.thread_channel - .thread_id - .as_deref() - .unwrap_or(&ctx.thread_channel.channel_id) - ); - - if let Err(e) = self.pool.get_or_create(&thread_key, None).await { - let msg = format_user_error(&e.to_string()); - let _ = adapter - .send_message(&ctx.thread_channel, &format!("⚠️ {msg}")) - .await; - error!("pool error: {e}"); - return Err(e); - } - - // In assistant-status mode (e.g. 
Slack assistant_mode), status is conveyed - // via assistant.threads.setStatus, so the emoji-reaction lifecycle is skipped - // entirely — mirrors dispatch_batch so per-message and batched modes agree. - let assistant_status = adapter.uses_assistant_status(); - - let reactions = Arc::new(StatusReactionController::new( - self.reactions_config.enabled, - adapter.clone(), - ctx.trigger_msg.clone(), - self.reactions_config.emojis.clone(), - self.reactions_config.timing.clone(), - )); - if !assistant_status { - reactions.set_queued().await; - } - - let result = self - .stream_prompt( - adapter, - &thread_key, - content_blocks, - &ctx.thread_channel, - reactions.clone(), - ctx.other_bot_present, - ) - .await; - - if !assistant_status { - match &result { - Ok(()) => reactions.set_done().await, - Err(_) => reactions.set_error().await, - } - - let hold_ms = if result.is_ok() { - self.reactions_config.timing.done_hold_ms - } else { - self.reactions_config.timing.error_hold_ms - }; - if self.reactions_config.remove_after_reply { - let reactions = reactions; - tokio::spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis(hold_ms)).await; - reactions.clear().await; - }); - } - } - - if let Err(ref e) = result { - let _ = adapter - .send_message(&ctx.thread_channel, &format!("⚠️ {e}")) - .await; - } - - result - } - - async fn stream_prompt( - &self, - adapter: &Arc, - thread_key: &str, - content_blocks: Vec, - thread_channel: &ChannelRef, - reactions: Arc, - other_bot_present: bool, - ) -> Result<()> { - self.stream_prompt_blocks( - adapter, - thread_key, - content_blocks, - thread_channel, - reactions, - other_bot_present, - // handle_message path (e.g. cron) is never Slack assistant-mode native - // streaming, so no per-turn recipient — degrades to post+edit if it were. - None, - ) - .await - } - - /// Drive one ACP turn with the given pre-packed ContentBlocks. 
- /// Called by both `handle_message` (per-message mode) and `dispatch::dispatch_batch` - /// (batched mode). - #[allow(clippy::too_many_arguments)] - pub async fn stream_prompt_blocks( - &self, - adapter: &Arc, - thread_key: &str, - content_blocks: Vec, - thread_channel: &ChannelRef, - reactions: Arc, - other_bot_present: bool, - recipient: Option<(String, String)>, - ) -> Result<()> { - let adapter = adapter.clone(); - let thread_channel = thread_channel.clone(); - let message_limit = adapter.message_limit(); - let streaming = adapter.use_streaming(other_bot_present); - let native = adapter.uses_native_streaming(other_bot_present); - let assistant_status = adapter.uses_assistant_status(); - // Platforms that render Markdown tables natively (e.g. Slack Block Kit - // `markdown` blocks / `markdown_text` stream chunks) skip the - // table→code/bullets pre-pass so the raw table renders natively. - let table_mode = if adapter.renders_native_tables() { - TableMode::Off - } else { - self.table_mode - }; - let tool_display = self.reactions_config.tool_display; - let prompt_hard_timeout = self.prompt_hard_timeout; - let liveness_check_interval = self.liveness_check_interval; - - self.pool - .with_connection(thread_key, |conn| { - let content_blocks = content_blocks.clone(); - Box::pin(async move { - let reset = conn.session_reset; - conn.session_reset = false; - - let (mut rx, request_id) = conn.session_prompt(content_blocks).await?; - if assistant_status { - let _ = adapter.set_status(&thread_channel, "Thinking…").await; - } else { - reactions.set_thinking().await; - } - - let mut text_buf = String::new(); - let mut tool_lines: Vec = Vec::new(); - - if reset { - text_buf.push_str("⚠️ _Session expired, starting fresh..._\n\n"); - } - - // Native streaming: defer stream_begin until first Text event - // so the thinking phase only shows set_status (no placeholder msg). 
- let mut native_msg: Option = None; - // Once stream_begin fails, stop retrying for this turn to avoid - // hammering the API on transient failures. - let mut stream_begin_failed = false; - // Native delta coalescing state (used only when `native`). - let mut native_pending = String::new(); - let mut native_last_flush = tokio::time::Instant::now(); - const NATIVE_FLUSH_MS: u128 = 400; - - // Streaming edit: send placeholder, spawn edit loop - let (buf_tx, placeholder_msg, edit_handle) = if streaming && !native { - let initial = if reset { - "⚠️ _Session expired, starting fresh..._\n\n…".to_string() - } else { - "…".to_string() - }; - let msg = if adapter.show_streaming_placeholder() { - adapter.send_message(&thread_channel, &initial).await? - } else { - // Dummy ref for edit loop — gateway uses drafts, doesn't need real msg_id - MessageRef { - message_id: "draft".to_string(), - channel: thread_channel.clone(), - } - }; - let (tx, rx) = tokio::sync::watch::channel(initial); - let edit_adapter = adapter.clone(); - let edit_msg = msg.clone(); - let limit = message_limit; - let mut buf_rx = rx; - let edit_handle = tokio::spawn(async move { - let mut last = String::new(); - // Track consecutive edit failures so we can abort cosmetic - // streaming when the platform stops accepting edits (e.g. - // Feishu's 20-edits-per-message hard cap, errcode 230072). - // Once aborted, the final delivery path still runs and the - // user sees the complete content at turn end. 
- let mut consecutive_failures: u32 = 0; - const MAX_CONSECUTIVE_FAILURES: u32 = 3; - loop { - tokio::time::sleep(std::time::Duration::from_millis(1500)).await; - if buf_rx.has_changed().unwrap_or(false) { - let content = buf_rx.borrow_and_update().clone(); - if content != last { - let display = if content.chars().count() > limit - 100 { - format!( - "…{}", - format::truncate_chars_tail(&content, limit - 100) - ) - } else { - content.clone() - }; - match edit_adapter - .edit_message(&edit_msg, &display) - .await - { - Ok(_) => { - consecutive_failures = 0; - last = content; - } - Err(e) => { - consecutive_failures += 1; - tracing::debug!( - message_id = %edit_msg.message_id, - platform = %edit_msg.channel.platform, - error = ?e, - consecutive_failures, - "mid-stream cosmetic edit failed" - ); - if consecutive_failures - >= MAX_CONSECUTIVE_FAILURES - { - tracing::warn!( - message_id = %edit_msg.message_id, - platform = %edit_msg.channel.platform, - consecutive_failures, - "mid-stream cosmetic edit aborted; \ - final content will be delivered at turn end" - ); - break; - } - } - } - } - } - if buf_rx.has_changed().is_err() { - break; - } - } - }); - (Some(tx), Some(msg), Some(edit_handle)) - } else { - (None, None, None) - }; - - // (#732) Liveness-aware recv loop. Filters stale id-bearing - // messages and abandons cleanly on dead agent / hard ceiling - // so late responses cannot leak into the next prompt. - let mut response_error: Option = None; - let prompt_start = tokio::time::Instant::now(); - loop { - let notification = tokio::select! { - msg = rx.recv() => match msg { - Some(n) => n, - // Reader saw EOF and already drained pending; nothing to abandon. 
- None => break, - }, - _ = tokio::time::sleep(liveness_check_interval) => { - if !conn.alive() { - response_error = Some("Agent process died".into()); - conn.abandon_request(request_id).await; - break; - } - if prompt_start.elapsed() > prompt_hard_timeout { - response_error = Some(format!( - "Agent exceeded hard timeout ({}s)", - prompt_hard_timeout.as_secs(), - )); - conn.abandon_request(request_id).await; - break; - } - continue; - } - }; - if let Some(notification_id) = notification.id { - if notification_id != request_id { - // Stale response from a previously-abandoned prompt. - // No automated test seam: this path only triggers when a - // real subprocess emits a late response after the broker - // already called abandon_request — covered by manual - // repro against a live agent (see #732 PR description). - continue; - } - if let Some(ref err) = notification.error { - response_error = Some(format_coded_error(err.code, &err.message, err.data_message())); - } - break; - } - - if let Some(event) = classify_notification(¬ification) { - match event { - AcpEvent::Text(t) => { - text_buf.push_str(&t); - if native { - // Lazy stream_begin: open the stream on first text. 
- if native_msg.is_none() && !stream_begin_failed { - match adapter.stream_begin(&thread_channel, recipient.clone()).await { - Ok(m) => { native_msg = Some(m); } - Err(e) => { - tracing::error!(error = ?e, "stream_begin failed on first text; will not retry this turn"); - stream_begin_failed = true; - } - } - } - if let Some(msg) = &native_msg { - native_pending.push_str(&t); - if native_last_flush.elapsed().as_millis() - >= NATIVE_FLUSH_MS - && !native_pending.is_empty() - { - let _ = adapter - .stream_append(msg, &native_pending) - .await; - native_pending.clear(); - native_last_flush = tokio::time::Instant::now(); - } - } - } else if let Some(tx) = &buf_tx { - let _ = tx.send(compose_display( - &tool_lines, - &text_buf, - true, - tool_display, - )); - } - } - AcpEvent::Thinking => { - if assistant_status { - let _ = adapter - .set_status(&thread_channel, "Thinking…") - .await; - } else { - reactions.set_thinking().await; - } - } - AcpEvent::ToolStart { id, title } if !title.is_empty() => { - // Live indicator: assistant status line vs emoji reaction. - if assistant_status { - let _ = adapter - .set_status( - &thread_channel, - &format!("Using {title}…"), - ) - .await; - } else { - reactions.set_tool(&title).await; - } - // Record the tool in BOTH modes so the finalized message keeps - // a tool summary (compose_display, gated by tool_display). In - // assistant_mode the status line is transient and cleared before - // the reply, so without this the message would retain no record - // of which tools ran. - let title = sanitize_title(&title); - if let Some(slot) = - tool_lines.iter_mut().find(|e| e.id == id) - { - slot.title = title; - slot.state = ToolState::Running; - } else { - tool_lines.push(ToolEntry { - id, - title, - state: ToolState::Running, - }); - } - // Post+edit live update (no-op under native streaming: buf_tx is None). 
- if let Some(tx) = &buf_tx { - let _ = tx.send(compose_display( - &tool_lines, - &text_buf, - true, - tool_display, - )); - } - } - AcpEvent::ToolDone { id, title, status } => { - // Live indicator: assistant status line vs emoji reaction. - if assistant_status { - let _ = adapter - .set_status(&thread_channel, "Thinking…") - .await; - } else { - reactions.set_thinking().await; - } - // Update the tool's state in BOTH modes (see ToolStart) so the - // finalized message's tool summary reflects completion/failure. - let new_state = if status == "completed" { - ToolState::Completed - } else { - ToolState::Failed - }; - if let Some(slot) = - tool_lines.iter_mut().find(|e| e.id == id) - { - if !title.is_empty() { - slot.title = sanitize_title(&title); - } - slot.state = new_state; - } else if !title.is_empty() { - tool_lines.push(ToolEntry { - id, - title: sanitize_title(&title), - state: new_state, - }); - } - if let Some(tx) = &buf_tx { - let _ = tx.send(compose_display( - &tool_lines, - &text_buf, - true, - tool_display, - )); - } - } - AcpEvent::ConfigUpdate { options } => { - conn.config_options = options; - } - _ => {} - } - } - } - - conn.prompt_done().await; - // Stop the cosmetic edit loop before the finalize write path - // issues its authoritative edit. Dropping buf_tx closes the watch - // channel so the loop breaks on its next check, but it may be - // mid-edit (a single edit can now block up to the gateway response - // timeout). Without an explicit abort+join, a cosmetic edit issued - // just before close could land *after* the finalize edit and - // overwrite it with stale, mid-stream content (#1122 review NEW-1). - // - // abort() cancels any cosmetic edit that has not yet been put on - // the wire and interrupts the inter-flush sleep immediately; the - // await confirms the task is gone before we proceed. 
This narrows - // the race to near zero — it does NOT fully eliminate it: a PUT - // already flushed microseconds before abort cannot be recalled, - // and if finalize's PUT travels a different pooled connection the - // server-side arrival order is not strictly guaranteed. That - // residual window is display-only (stale tail briefly shown) and - // far narrower than before this join existed. - drop(buf_tx); - if let Some(handle) = edit_handle { - handle.abort(); - let _ = handle.await; - } - - // Parse output directives from raw text_buf BEFORE compose_display. - // Directives are agent meta-layer, not content — must be stripped - // before tool lines are composed into the display output. - let (directives, stripped_text) = parse_output_directives(&text_buf); - let text_buf = stripped_text; - - // Build final content - let final_content = - compose_display(&tool_lines, &text_buf, false, tool_display); - let final_content = if final_content.is_empty() { - if let Some(err) = response_error { - format!("⚠️ {err}") - } else { - "_(no response)_".to_string() - } - } else if let Some(err) = response_error { - format!("⚠️ {err}\n\n{final_content}") - } else { - final_content - }; - - let final_content = markdown::convert_tables(&final_content, table_mode); - let chunks = format::split_message(&final_content, message_limit); - // Track delivery health across all final write paths. Any failure - // here means the user's view is incomplete; we propagate Err at the - // end of the closure so dispatch surfaces set_error (❌) instead of - // silently calling set_done (🆗) over a half-delivered turn. - let mut delivery_failed = false; - // Clear the assistant status line before delivering the final message. 
- if assistant_status { - let _ = adapter.set_status(&thread_channel, "").await; - } - if native { - if let Some(msg) = &native_msg { - if !native_pending.is_empty() { - if let Err(e) = - adapter.stream_append(msg, &native_pending).await - { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "native finalize stream_append failed"); - delivery_failed = true; - } - } - // Finalize the streamed message with the first chunk (full-replace), - // then post any overflow chunks as new in-thread messages — mirrors - // the post+edit path so long replies aren't truncated at message_limit. - // NOTE: the reply_to directive is intentionally NOT honored in native - // streaming mode — the streamed message is the in-thread reply. - match chunks.first() { - Some(first) => { - if let Err(e) = adapter.stream_finish(msg, first).await { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "native stream_finish failed"); - delivery_failed = true; - } - for chunk in chunks.iter().skip(1) { - if let Err(e) = - adapter.send_message(&thread_channel, chunk).await - { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "native overflow chunk send failed"); - delivery_failed = true; - } - } - } - None => { - if let Err(e) = - adapter.stream_finish(msg, &final_content).await - { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "native stream_finish (no chunks) failed"); - delivery_failed = true; - } - } - } - } else { - // native_msg is None — either no Text event ever arrived - // (tool-only or empty turn) so lazy stream_begin never - // fired, or stream_begin failed on the first Text event - // and we stopped retrying for this turn. 
In both cases no - // native stream was opened, so deliver the final content - // (which may be the "_(no response)_" sentinel, or the - // accumulated text_buf) as plain in-thread messages so - // the turn is never silently dropped. - for chunk in &chunks { - if let Err(e) = - adapter.send_message(&thread_channel, chunk).await - { - tracing::warn!(error = ?e, platform = %thread_channel.platform, "native fallback chunk send failed"); - delivery_failed = true; - } - } - } - } else if let Some(msg) = placeholder_msg { - if let Some(ref reply_id) = directives.reply_to { - // reply_to directive: send reply first, then delete placeholder. - // Only delete if send succeeds — preserves placeholder on failure. - let mut send_ok = false; - let mut first = true; - for chunk in &chunks { - if first { - match adapter.send_message_with_reply( - &thread_channel, - chunk, - reply_id, - ).await { - Ok(_) => { send_ok = true; } - Err(e) => { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "reply_to send failed; preserving placeholder"); - delivery_failed = true; - } - } - } else if let Err(e) = - adapter.send_message(&thread_channel, chunk).await - { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "reply_to overflow chunk send failed"); - delivery_failed = true; - } - first = false; - } - if send_ok { - if let Err(e) = adapter.delete_message(&msg).await { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "delete placeholder failed; placeholder will remain visible"); - } - } - } else if adapter.platform() == "discord" - && contains_bot_mention(&final_content) - { - // Discord-specific: bot mention detected. Delete placeholder - // and send as new message so Discord emits MESSAGE_CREATE — - // otherwise the mentioned bot won't receive the gateway - // event since MESSAGE_UPDATE skips notifications (#1110). 
- let mut send_ok = false; - if let Some(first) = chunks.first() { - match adapter.send_message(&thread_channel, first).await { - Ok(_) => { - send_ok = true; - } - Err(e) => { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "discord bot-mention first chunk send failed"); - delivery_failed = true; - } - } - } - for chunk in chunks.iter().skip(1) { - if let Err(e) = adapter.send_message(&thread_channel, chunk).await { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "streaming overflow chunk send failed"); - delivery_failed = true; - } - } - if send_ok { - let _ = adapter.delete_message(&msg).await; - } - } else { - // Normal streaming: edit first chunk into placeholder, send rest. - // If placeholder is a dummy "draft" ref (no real message), send as - // new message instead — the gateway will persist via sendRichMessage. - if msg.message_id == "draft" { - for chunk in &chunks { - if let Err(e) = - adapter.send_message(&thread_channel, chunk).await - { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "draft placeholder fallback chunk send failed"); - delivery_failed = true; - } - } - } else if let Some(first) = chunks.first() { - // If the placeholder edit fails (e.g. Feishu's - // 20-edits-per-message cap was hit during - // cosmetic streaming and the gateway reports - // edit_cap_reached), fall back to deleting the - // half-edited placeholder and sending the first - // chunk as a fresh message so the user sees the - // complete reply without overlap. If delete - // fails the placeholder simply remains — same - // UX as pre-recovery, not a hard failure. 
- if let Err(e) = adapter.edit_message(&msg, first).await { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "final streaming edit failed; deleting placeholder and sending fresh"); - if let Err(de) = adapter.delete_message(&msg).await { - tracing::warn!(error = ?de, platform = %thread_channel.platform, message_id = %msg.message_id, "delete placeholder failed; user will see overlap"); - } - if let Err(e2) = - adapter.send_message(&thread_channel, first).await - { - tracing::error!(error = ?e2, platform = %thread_channel.platform, message_id = %msg.message_id, "fallback send_message also failed"); - delivery_failed = true; - } - } - for chunk in chunks.iter().skip(1) { - if let Err(e) = - adapter.send_message(&thread_channel, chunk).await - { - tracing::warn!(error = ?e, platform = %thread_channel.platform, message_id = %msg.message_id, "streaming overflow chunk send failed"); - delivery_failed = true; - } - } - } - } - } else { - // Send-once: all chunks as new messages - // First chunk uses reply_to directive if present - let mut first = true; - for chunk in &chunks { - if first { - if let Some(ref reply_id) = directives.reply_to { - if let Err(e) = adapter.send_message_with_reply( - &thread_channel, - chunk, - reply_id, - ).await { - tracing::warn!(error = ?e, platform = %thread_channel.platform, "send-once reply_to first chunk failed"); - delivery_failed = true; - } - } else if let Err(e) = - adapter.send_message(&thread_channel, chunk).await - { - tracing::warn!(error = ?e, platform = %thread_channel.platform, "send-once first chunk failed"); - delivery_failed = true; - } - } else if let Err(e) = - adapter.send_message(&thread_channel, chunk).await - { - tracing::warn!(error = ?e, platform = %thread_channel.platform, "send-once subsequent chunk failed"); - delivery_failed = true; - } - first = false; - } - } - - if delivery_failed { - Err(anyhow::anyhow!( - "streaming finalization had delivery failures; user view is 
incomplete" - )) - } else { - Ok(()) - } - }) - }) - .await - } -} - -/// Returns true if `content` contains a Discord user/bot mention (`<@123>`, `<@!123>`) -/// or a role mention (`<@&123>`). -/// Used to detect cross-bot mentions so the streaming path can switch from -/// edit (MESSAGE_UPDATE, no mention notification) to delete+send (MESSAGE_CREATE). -fn contains_bot_mention(content: &str) -> bool { - let mut i = 0; - let bytes = content.as_bytes(); - while i + 2 < bytes.len() { - if bytes[i] == b'<' && bytes[i + 1] == b'@' { - // Skip optional '!' (nickname mention) or '&' (role mention) - let start = if i + 2 < bytes.len() - && (bytes[i + 2] == b'!' || bytes[i + 2] == b'&') - { - i + 3 - } else { - i + 2 - }; - if start < bytes.len() && bytes[start].is_ascii_digit() { - if let Some(end) = content[start..].find('>') { - if content[start..start + end].chars().all(|c| c.is_ascii_digit()) { - return true; - } - } - } - i = start; - } else { - i += 1; - } - } - false -} - -/// Flatten a tool-call title into a single line safe for inline-code spans. -fn sanitize_title(title: &str) -> String { - title - .replace('\r', "") - .replace('\n', " ; ") - .replace('`', "'") -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ToolState { - Running, - Completed, - Failed, -} - -#[derive(Debug, Clone)] -struct ToolEntry { - id: String, - title: String, - state: ToolState, -} - -impl ToolEntry { - fn render(&self) -> String { - let icon = match self.state { - ToolState::Running => "🔧", - ToolState::Completed => "✅", - ToolState::Failed => "❌", - }; - let suffix = if self.state == ToolState::Running { - "..." - } else { - "" - }; - format!("{icon} `{}`{}", self.title, suffix) - } -} - -/// Maximum number of finished tool entries to show individually -/// during streaming before collapsing into a summary line. 
-const TOOL_COLLAPSE_THRESHOLD: usize = 3; - -fn compose_display( - tool_lines: &[ToolEntry], - text: &str, - streaming: bool, - tool_display: ToolDisplay, -) -> String { - let mut out = String::new(); - if !tool_lines.is_empty() && tool_display != ToolDisplay::None { - let done = tool_lines - .iter() - .filter(|e| e.state == ToolState::Completed) - .count(); - let failed = tool_lines - .iter() - .filter(|e| e.state == ToolState::Failed) - .count(); - let running = tool_lines - .iter() - .filter(|e| e.state == ToolState::Running) - .count(); - let finished = done + failed; - - match tool_display { - ToolDisplay::Compact => { - // Always show count summary, never per-tool details - let mut parts = Vec::new(); - if done > 0 { - parts.push(format!("✅ {done}")); - } - if failed > 0 { - parts.push(format!("❌ {failed}")); - } - if running > 0 { - parts.push(format!("🔧 {running}")); - } - if !parts.is_empty() { - out.push_str(&format!("{} tool(s)\n", parts.join(" · "))); - } - } - ToolDisplay::Full => { - if streaming { - let running_entries: Vec<_> = tool_lines - .iter() - .filter(|e| e.state == ToolState::Running) - .collect(); - - if finished <= TOOL_COLLAPSE_THRESHOLD { - for entry in tool_lines.iter().filter(|e| e.state != ToolState::Running) { - out.push_str(&entry.render()); - out.push('\n'); - } - } else { - let mut parts = Vec::new(); - if done > 0 { - parts.push(format!("✅ {done}")); - } - if failed > 0 { - parts.push(format!("❌ {failed}")); - } - out.push_str(&format!("{} tool(s) completed\n", parts.join(" · "))); - } - - if running_entries.len() <= TOOL_COLLAPSE_THRESHOLD { - for entry in &running_entries { - out.push_str(&entry.render()); - out.push('\n'); - } - } else { - let hidden = running_entries.len() - TOOL_COLLAPSE_THRESHOLD; - out.push_str(&format!("🔧 {hidden} more running\n")); - for entry in running_entries.iter().skip(hidden) { - out.push_str(&entry.render()); - out.push('\n'); - } - } - } else { - for entry in tool_lines { - 
out.push_str(&entry.render()); - out.push('\n'); - } - } - } - ToolDisplay::None => {} // guarded above, but safe no-op - } - if !out.is_empty() { - out.push('\n'); - } - } - out.push_str(text.trim_end()); - out -} - -#[cfg(test)] -mod tests { - use super::*; - - /// Compile-time regression guard: use_streaming() is a required trait method - /// (no default). Any adapter that forgets to implement it will fail to compile. - /// This test documents the contract — see PR #503 / issue #502 for context. - #[test] - fn use_streaming_is_required_method() { - // If use_streaming() had a default impl, this test module would still - // compile even if an adapter forgot to override it. The real guard is - // the trait definition itself — this test exists as documentation and - // to catch if someone re-adds a default. - struct TestAdapter; - - #[async_trait] - impl ChatAdapter for TestAdapter { - fn platform(&self) -> &'static str { - "test" - } - fn message_limit(&self) -> usize { - 2000 - } - async fn send_message(&self, _: &ChannelRef, _: &str) -> Result { - unimplemented!() - } - async fn create_thread( - &self, - _: &ChannelRef, - _: &MessageRef, - _: &str, - ) -> Result { - unimplemented!() - } - async fn add_reaction(&self, _: &MessageRef, _: &str) -> Result<()> { - Ok(()) - } - async fn remove_reaction(&self, _: &MessageRef, _: &str) -> Result<()> { - Ok(()) - } - // use_streaming() MUST be declared — removing this line should fail compilation - fn use_streaming(&self, _other_bot_present: bool) -> bool { - false - } - } - - let adapter = TestAdapter; - // Verify the method is callable and returns the declared value - assert!(!adapter.use_streaming(false)); - // renders_native_tables defaults to false: platforms that don't override - // it keep the table→code/bullets conversion (e.g. Discord, Gateway). 
- assert!(!adapter.renders_native_tables()); - } - - #[test] - fn origin_event_id_excluded_from_eq() { - let a = ChannelRef { - platform: "line".into(), - channel_id: "U123".into(), - thread_id: None, - parent_id: None, - origin_event_id: Some("evt_aaa".into()), - }; - let b = ChannelRef { - platform: "line".into(), - channel_id: "U123".into(), - thread_id: None, - parent_id: None, - origin_event_id: Some("evt_bbb".into()), - }; - assert_eq!(a, b, "same channel with different event IDs must be equal"); - } - - #[test] - fn origin_event_id_excluded_from_hash() { - use std::collections::HashMap; - let a = ChannelRef { - platform: "line".into(), - channel_id: "U123".into(), - thread_id: None, - parent_id: None, - origin_event_id: Some("evt_aaa".into()), - }; - let b = ChannelRef { - platform: "line".into(), - channel_id: "U123".into(), - thread_id: None, - parent_id: None, - origin_event_id: Some("evt_bbb".into()), - }; - let mut map = HashMap::new(); - map.insert(a, "first"); - // b should hit the same bucket and overwrite - map.insert(b, "second"); - assert_eq!(map.len(), 1); - assert_eq!(map.values().next(), Some(&"second")); - } - - #[test] - fn origin_event_id_survives_clone() { - let ch = ChannelRef { - platform: "line".into(), - channel_id: "U123".into(), - thread_id: None, - parent_id: None, - origin_event_id: Some("evt_abc".into()), - }; - // Simulates create_thread propagation: clone preserves origin_event_id - let thread_ch = ChannelRef { - thread_id: Some("topic_1".into()), - origin_event_id: ch.origin_event_id.clone(), - ..ch.clone() - }; - assert_eq!(thread_ch.origin_event_id.as_deref(), Some("evt_abc")); - } - - fn tool(id: &str, title: &str, state: ToolState) -> ToolEntry { - ToolEntry { - id: id.into(), - title: title.into(), - state, - } - } - - #[test] - fn compose_display_full_shows_complete_title() { - let tools = vec![tool( - "1", - "curl -s https://example.com", - ToolState::Completed, - )]; - let out = compose_display(&tools, "done", false, 
ToolDisplay::Full); - assert!(out.contains("`curl -s https://example.com`")); - } - - #[test] - fn compose_display_compact_shows_count_summary() { - let tools = vec![ - tool("1", "curl -s https://example.com", ToolState::Completed), - tool("2", "grep -r pattern src/", ToolState::Completed), - tool("3", "cat /etc/hosts", ToolState::Failed), - ]; - let out = compose_display(&tools, "done", false, ToolDisplay::Compact); - assert!(out.contains("✅ 2"), "expected completed count: {out}"); - assert!(out.contains("❌ 1"), "expected failed count: {out}"); - assert!(out.contains("tool(s)"), "expected tool(s) label: {out}"); - // Must NOT contain individual tool names - assert!(!out.contains("curl"), "should not show tool names: {out}"); - assert!(!out.contains("grep"), "should not show tool names: {out}"); - } - - #[test] - fn compose_display_compact_shows_running_count() { - let tools = vec![ - tool("1", "curl", ToolState::Completed), - tool("2", "npm install", ToolState::Running), - ]; - let out = compose_display(&tools, "", true, ToolDisplay::Compact); - assert!(out.contains("✅ 1"), "expected completed count: {out}"); - assert!(out.contains("🔧 1"), "expected running count: {out}"); - } - - #[test] - fn compose_display_none_hides_tools() { - let tools = vec![tool( - "1", - "curl -s https://example.com", - ToolState::Completed, - )]; - let out = compose_display(&tools, "response text", false, ToolDisplay::None); - assert_eq!(out, "response text"); - } - - #[test] - fn contains_bot_mention_user() { - assert!(contains_bot_mention("hello <@1234567890> world")); - } - - #[test] - fn contains_bot_mention_nickname() { - assert!(contains_bot_mention("hey <@!9876543210>")); - } - - #[test] - fn contains_bot_mention_role() { - assert!(contains_bot_mention("calling <@&1496247626675257384>")); - } - - #[test] - fn contains_bot_mention_no_match() { - assert!(!contains_bot_mention("hello world")); - assert!(!contains_bot_mention("email user@example.com")); - 
assert!(!contains_bot_mention("<@not_a_number>")); - assert!(!contains_bot_mention("<#123456>")); // channel mention - } - - #[test] - fn contains_bot_mention_embedded() { - assert!(contains_bot_mention("請問 <@1501788608439386172> 1+1=?")); - } -} - -#[cfg(test)] -mod directive_tests { - use super::parse_output_directives; - - #[test] - fn parse_reply_to_directive() { - let input = "[[reply_to:1502606076451885136]]\nHello world"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("1502606076451885136".to_string())); - assert_eq!(content, "Hello world"); - } - - #[test] - fn parse_no_directives() { - let input = "Just plain content\nwith multiple lines"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, None); - assert_eq!(content, input); - } - - #[test] - fn parse_multiple_directives() { - let input = "[[reply_to:123456]]\n[[unknown_key:value]]\nContent here"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("123456".to_string())); - assert_eq!(content, "Content here"); - } - - #[test] - fn parse_invalid_reply_to_rejects_whitespace() { - let input = "[[reply_to:has spaces]]\nContent"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, None); - assert_eq!(content, "Content"); - } - - #[test] - fn parse_slack_ts_format_accepted() { - let input = "[[reply_to:1234567890.123456]]\nContent"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("1234567890.123456".to_string())); - assert_eq!(content, "Content"); - } - - #[test] - fn parse_empty_reply_to() { - let input = "[[reply_to:]]\nContent"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, None); - assert_eq!(content, "Content"); - } - - #[test] - fn parse_crlf_line_endings() { - let input = "[[reply_to:999]]\r\nContent with CRLF"; - 
let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("999".to_string())); - assert_eq!(content, "Content with CRLF"); - } - - #[test] - fn parse_directive_only_no_content() { - let input = "[[reply_to:123]]"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("123".to_string())); - assert_eq!(content, ""); - } - - #[test] - fn parse_non_directive_line_stops_parsing() { - let input = "Normal first line\n[[reply_to:123]]\nMore content"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, None); - assert_eq!(content, input); - } - - #[test] - fn parse_duplicate_reply_to_last_wins() { - let input = "[[reply_to:111]]\n[[reply_to:222]]\nContent"; - let (directives, content) = parse_output_directives(input); - // Last value wins - assert_eq!(directives.reply_to, Some("222".to_string())); - assert_eq!(content, "Content"); - } - - #[test] - fn parse_crlf_multiple_directives() { - let input = "[[reply_to:456]]\r\n[[unknown:x]]\r\nContent after CRLF"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("456".to_string())); - assert_eq!(content, "Content after CRLF"); - } - - #[test] - fn parse_bracket_without_colon_preserved() { - // [[Note]] has no colon — not a directive, preserved as content - let input = "[[Summary]]\nThis is body text"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, None); - assert_eq!(content, input); - } - - #[test] - fn parse_reply_to_with_inline_content() { - // Agent puts content on same line as directive — should still parse - let input = "[[reply_to:1502724086474870926]] @BOT I'm on standby"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("1502724086474870926".to_string())); - assert_eq!(content, "@BOT I'm on standby"); - } - - #[test] - fn 
parse_reply_to_inline_with_more_lines() { - let input = "[[reply_to:123]] First line\nSecond line\nThird line"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("123".to_string())); - assert_eq!(content, "First line\nSecond line\nThird line"); - } - - #[test] - fn parse_reply_to_no_space_before_content() { - // No space between ]] and content - let input = "[[reply_to:1502724086474870926]]收到"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("1502724086474870926".to_string())); - assert_eq!(content, "收到"); - } - - #[test] - fn parse_reply_to_inline_with_mention() { - // Real-world case: directive followed by Discord mention - let input = "[[reply_to:1502724086474870926]] <@1490365068863606784> 我 standby"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("1502724086474870926".to_string())); - assert_eq!(content, "<@1490365068863606784> 我 standby"); - } - - #[test] - fn parse_reply_to_inline_only_spaces() { - // Trailing spaces only — no real content, should be empty - let input = "[[reply_to:123]] "; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("123".to_string())); - assert_eq!(content, ""); - } - - #[test] - fn parse_reply_to_with_brackets_in_content() { - // Content after ]] contains brackets — should not confuse parser - let input = "[[reply_to:456]] 看看 [[這個]] 怎麼樣"; - let (directives, content) = parse_output_directives(input); - assert_eq!(directives.reply_to, Some("456".to_string())); - assert_eq!(content, "看看 [[這個]] 怎麼樣"); - } -} diff --git a/src/bot_turns.rs b/src/bot_turns.rs deleted file mode 100644 index 130fa717b..000000000 --- a/src/bot_turns.rs +++ /dev/null @@ -1,368 +0,0 @@ -//! Per-thread bot turn tracking for runaway-loop prevention. -//! -//! Shared between Discord and Slack adapters so both platforms apply the same -//! 
soft/hard limit semantics. Both counters reset on a human message in the -//! thread. Runs before self-check so a bot's own messages count too — this -//! means `soft_limit=20` caps the *total* bot messages in a thread, not per-bot. - -use std::collections::HashMap; - -/// Absolute per-thread cap on consecutive bot turns without human intervention. -/// A human message resets both soft and hard counters to 0, allowing bots to -/// resume. This is *not* a lifetime total — it guards against runaway loops -/// between human resets. -pub const HARD_BOT_TURN_LIMIT: u32 = 1000; - -/// Stable prefix used in all bot turn limit warning messages. -/// Referenced by the dedup check in the Discord adapter — changing this -/// string requires updating the dedup check too. -pub const BOT_TURN_LIMIT_WARNING_PREFIX: &str = "⚠️ Bot turn limit reached"; - -#[derive(Debug, PartialEq, Eq)] -pub enum TurnResult { - /// Counter below limits — continue normally. - Ok, - /// Counter == soft_limit — warn once, then stop. - SoftLimit(u32), - /// Counter > soft_limit — silently stop (already warned). - Throttled, - /// Counter == HARD_BOT_TURN_LIMIT — warn once, then stop. - HardLimit, - /// Counter > HARD_BOT_TURN_LIMIT — silently stop (already warned). 
- Stopped, -} - -pub struct BotTurnTracker { - soft_limit: u32, - counts: HashMap, -} - -impl BotTurnTracker { - pub fn new(soft_limit: u32) -> Self { - Self { - soft_limit, - counts: HashMap::new(), - } - } - - pub fn on_bot_message(&mut self, thread_id: &str) -> TurnResult { - let (soft, hard) = self.counts.entry(thread_id.to_string()).or_insert((0, 0)); - *soft += 1; - *hard += 1; - if *hard > HARD_BOT_TURN_LIMIT { - TurnResult::Stopped - } else if *hard == HARD_BOT_TURN_LIMIT { - TurnResult::HardLimit - } else if *soft > self.soft_limit { - TurnResult::Throttled - } else if *soft == self.soft_limit { - TurnResult::SoftLimit(*soft) - } else { - TurnResult::Ok - } - } - - pub fn on_human_message(&mut self, thread_id: &str) { - if let Some((soft, hard)) = self.counts.get_mut(thread_id) { - *soft = 0; - *hard = 0; - } - } - - /// High-level decision for a bot message: increments the counter and - /// returns what the adapter should do. Collapses the warn-once semantics - /// and user-facing message formatting so Discord/Slack (and future adapters) - /// don't duplicate the match. - pub fn classify_bot_message(&mut self, thread_id: &str) -> TurnAction { - match self.on_bot_message(thread_id) { - TurnResult::Ok => TurnAction::Continue, - TurnResult::SoftLimit(n) => TurnAction::WarnAndStop { - severity: TurnSeverity::Soft, - turns: n, - user_message: format!( - "{} ({n}/{soft}). \ - A human must reply in this thread to continue bot-to-bot conversation.", - BOT_TURN_LIMIT_WARNING_PREFIX, - soft = self.soft_limit, - ), - }, - TurnResult::HardLimit => TurnAction::WarnAndStop { - severity: TurnSeverity::Hard, - turns: HARD_BOT_TURN_LIMIT, - user_message: format!( - "🛑 Hard bot turn limit reached ({HARD_BOT_TURN_LIMIT}). \ - A human must reply to continue." - ), - }, - TurnResult::Throttled | TurnResult::Stopped => TurnAction::SilentStop, - } - } -} - -/// Log severity hint for `TurnAction::WarnAndStop`. 
-#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum TurnSeverity { - /// Soft limit — typically logged at `info!`. - Soft, - /// Hard absolute cap — typically logged at `warn!`. - Hard, -} - -/// High-level action for a bot message after calling -/// [`BotTurnTracker::classify_bot_message`]. -#[derive(Debug, PartialEq, Eq, Clone)] -pub enum TurnAction { - /// Safe to continue processing this bot message. - Continue, - /// Stop processing; if the message did not come from our own bot, the - /// caller should post `user_message` to the thread so humans see why - /// the bot went quiet. `turns` is the counter value at the warning - /// point — useful as a structured log field. - WarnAndStop { - severity: TurnSeverity, - turns: u32, - user_message: String, - }, - /// Stop processing silently — the warning was already sent on a previous - /// turn; further warnings would spam the thread. - SilentStop, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn bot_turns_increment() { - let mut t = BotTurnTracker::new(5); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - - #[test] - fn soft_limit_triggers() { - let mut t = BotTurnTracker::new(3); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); - } - - #[test] - fn human_resets_both_counters() { - let mut t = BotTurnTracker::new(3); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - t.on_human_message("t1"); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); - } - - #[test] - fn hard_limit_triggers() { - let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); - for _ in 0..HARD_BOT_TURN_LIMIT - 1 { - assert_eq!(t.on_bot_message("t1"), 
TurnResult::Ok); - } - assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit); - } - - #[test] - fn hard_limit_does_not_fire_at_legacy_100() { - let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); - for i in 1..=100 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok, "turn {i}"); - } - } - - #[test] - fn hard_limit_resets_on_human() { - let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); - for _ in 0..HARD_BOT_TURN_LIMIT - 1 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - t.on_human_message("t1"); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - - #[test] - fn hard_before_soft_when_equal() { - let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT); - for _ in 0..HARD_BOT_TURN_LIMIT - 1 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit); - } - - #[test] - fn threads_are_independent() { - let mut t = BotTurnTracker::new(3); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); - assert_eq!(t.on_bot_message("t2"), TurnResult::Ok); - } - - #[test] - fn human_on_unknown_thread_is_noop() { - let mut t = BotTurnTracker::new(5); - t.on_human_message("unknown"); - } - - #[test] - fn two_bot_pingpong_hits_soft_limit() { - let mut t = BotTurnTracker::new(20); - for i in 1..20 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok, "turn {i}"); - } - assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(20)); - } - - #[test] - fn two_bot_pingpong_human_resets() { - let mut t = BotTurnTracker::new(20); - for _ in 0..15 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - t.on_human_message("t1"); - for _ in 0..15 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - for _ in 0..4 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(20)); - } - - #[test] - fn 
soft_limit_warn_once_semantics() { - let mut t = BotTurnTracker::new(20); - for _ in 0..19 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(20)); - assert_eq!(t.on_bot_message("t1"), TurnResult::Throttled); - assert_eq!(t.on_bot_message("t1"), TurnResult::Throttled); - } - - #[test] - fn hard_limit_warn_once_semantics() { - let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); - for _ in 0..HARD_BOT_TURN_LIMIT - 1 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit); - assert_eq!(t.on_bot_message("t1"), TurnResult::Stopped); - } - - // System messages (thread created, pin, etc.) must not reset the counter. - // Filtering happens at the call site; this verifies the counter stays put - // when on_human_message is never called. Regression for openabdev/openab#497. - #[test] - fn system_message_does_not_reset_counter() { - let mut t = BotTurnTracker::new(3); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); - } - - #[test] - fn classify_returns_continue_under_limits() { - let mut t = BotTurnTracker::new(5); - assert_eq!(t.classify_bot_message("t1"), TurnAction::Continue); - } - - #[test] - fn classify_returns_warn_and_stop_on_soft_limit() { - let mut t = BotTurnTracker::new(3); - let _ = t.classify_bot_message("t1"); - let _ = t.classify_bot_message("t1"); - assert_eq!( - t.classify_bot_message("t1"), - TurnAction::WarnAndStop { - severity: TurnSeverity::Soft, - turns: 3, - user_message: format!( - "{} (3/3). 
\ - A human must reply in this thread to continue bot-to-bot conversation.", - BOT_TURN_LIMIT_WARNING_PREFIX, - ), - }, - ); - } - - #[test] - fn classify_returns_silent_stop_past_soft_limit() { - let mut t = BotTurnTracker::new(2); - let _ = t.classify_bot_message("t1"); - let _ = t.classify_bot_message("t1"); - assert_eq!(t.classify_bot_message("t1"), TurnAction::SilentStop); - assert_eq!(t.classify_bot_message("t1"), TurnAction::SilentStop); - } - - #[test] - fn classify_returns_warn_and_stop_on_hard_limit() { - let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); - for _ in 0..HARD_BOT_TURN_LIMIT - 1 { - let _ = t.classify_bot_message("t1"); - } - assert_eq!( - t.classify_bot_message("t1"), - TurnAction::WarnAndStop { - severity: TurnSeverity::Hard, - turns: HARD_BOT_TURN_LIMIT, - user_message: format!( - "🛑 Hard bot turn limit reached ({HARD_BOT_TURN_LIMIT}). \ - A human must reply to continue." - ), - }, - ); - assert_eq!(t.classify_bot_message("t1"), TurnAction::SilentStop); - } - - #[test] - fn classify_is_per_thread_independent() { - let mut t = BotTurnTracker::new(2); - assert_eq!(t.classify_bot_message("t1"), TurnAction::Continue); - assert!(matches!( - t.classify_bot_message("t1"), - TurnAction::WarnAndStop { - severity: TurnSeverity::Soft, - .. - }, - )); - assert_eq!(t.classify_bot_message("t2"), TurnAction::Continue); - assert!(matches!( - t.classify_bot_message("t2"), - TurnAction::WarnAndStop { - severity: TurnSeverity::Soft, - .. - }, - )); - } - - // End-to-end: human message must fully reset classify behavior on the - // same thread, including unlocking new `Continue` responses. - #[test] - fn classify_resumes_after_human_message() { - let mut t = BotTurnTracker::new(2); - let _ = t.classify_bot_message("t1"); // Continue - assert!(matches!( - t.classify_bot_message("t1"), - TurnAction::WarnAndStop { .. }, - )); - // Without a human message, the next classify is silent. 
- assert_eq!(t.classify_bot_message("t1"), TurnAction::SilentStop); - // Human resets — classify starts at Continue again. - t.on_human_message("t1"); - assert_eq!(t.classify_bot_message("t1"), TurnAction::Continue); - assert!(matches!( - t.classify_bot_message("t1"), - TurnAction::WarnAndStop { - severity: TurnSeverity::Soft, - turns: 2, - .. - }, - )); - } -} diff --git a/src/config.rs b/src/config.rs deleted file mode 100644 index 991071164..000000000 --- a/src/config.rs +++ /dev/null @@ -1,1500 +0,0 @@ -use crate::markdown::TableMode; -use regex::Regex; -use serde::Deserialize; -use std::collections::HashMap; -use std::path::Path; - -/// Controls how incoming messages are dispatched to ACP turns. -/// -/// - `Message` (default): each message becomes its own ACP turn (v0.8.2-beta.1 behaviour). -/// - `Thread`: one buffer per thread; all senders in a thread share a single batch and -/// produce one ACP turn per turn boundary. -/// - `Lane`: one buffer per (thread, sender); each sender batches independently and gets -/// its own ACP turn — no silent-drop risk when multiple senders address the same thread. -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub enum MessageProcessingMode { - #[default] - Message, - Thread, - Lane, -} - -impl<'de> Deserialize<'de> for MessageProcessingMode { - fn deserialize>(deserializer: D) -> Result { - let s = String::deserialize(deserializer)?; - match s.to_lowercase().replace('-', "_").as_str() { - "per_message" => Ok(Self::Message), - "per_thread" => Ok(Self::Thread), - "per_lane" => Ok(Self::Lane), - other => Err(serde::de::Error::unknown_variant( - other, - &["per-message", "per-thread", "per-lane"], - )), - } - } -} - -/// Controls whether the bot processes messages from other Discord bots. 
-/// -/// Inspired by Hermes Agent's `DISCORD_ALLOW_BOTS` 3-value design: -/// - `Off` (default): ignore all bot messages (safe default, no behavior change) -/// - `Mentions`: only process bot messages that @mention this bot (natural loop breaker) -/// - `All`: process all bot messages (hard-capped at 1000 consecutive bot turns) -/// -/// The bot's own messages are always ignored regardless of this setting. -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub enum AllowBots { - #[default] - Off, - Mentions, - All, -} - -impl<'de> Deserialize<'de> for AllowBots { - fn deserialize>(deserializer: D) -> Result { - let s = String::deserialize(deserializer)?; - match s.to_lowercase().as_str() { - "off" | "none" | "false" => Ok(Self::Off), - "mentions" => Ok(Self::Mentions), - "all" | "true" => Ok(Self::All), - other => Err(serde::de::Error::unknown_variant( - other, - &["off", "mentions", "all"], - )), - } - } -} - -#[derive(Debug, Clone, Deserialize)] -pub struct AgentCoreConfig { - /// AgentCore Runtime ARN (required) - pub runtime_arn: String, - /// ACP agent command to run in the PTY shell (default: kiro-cli acp --trust-all-tools) - #[serde(default = "default_agentcore_shell_command")] - pub shell_command: String, - /// Cancel strategy: "noop" or "stop" (default: stop) - #[serde(default = "default_agentcore_cancel_strategy")] - #[allow(dead_code)] - pub cancel_strategy: AgentCoreCancelStrategy, -} - -fn default_agentcore_shell_command() -> String { - "kiro-cli acp --trust-all-tools".to_string() -} - -impl AgentCoreConfig { - /// Extract region from ARN: arn:aws:bedrock-agentcore:REGION:ACCOUNT:runtime/ID - pub fn region(&self) -> String { - let parts: Vec<&str> = self.runtime_arn.split(':').collect(); - if parts.len() >= 4 && !parts[3].is_empty() { - return parts[3].to_string(); - } - "us-east-1".into() // fallback (should never hit with valid ARN) - } -} - -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub enum AgentCoreCancelStrategy { - #[default] 
- Stop, - Noop, -} - -impl<'de> Deserialize<'de> for AgentCoreCancelStrategy { - fn deserialize>(deserializer: D) -> Result { - let s = String::deserialize(deserializer)?; - match s.to_lowercase().as_str() { - "stop" => Ok(Self::Stop), - "noop" => Ok(Self::Noop), - other => Err(serde::de::Error::unknown_variant(other, &["stop", "noop"])), - } - } -} - -impl std::fmt::Display for AgentCoreCancelStrategy { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Stop => write!(f, "stop"), - Self::Noop => write!(f, "noop"), - } - } -} - -fn default_agentcore_cancel_strategy() -> AgentCoreCancelStrategy { - AgentCoreCancelStrategy::Stop -} - -#[derive(Debug, Deserialize)] -pub struct Config { - pub discord: Option, - pub slack: Option, - pub gateway: Option, - pub agentcore: Option, - #[serde(default)] - pub agent: AgentConfig, - #[serde(default)] - pub pool: PoolConfig, - #[serde(default)] - pub reactions: ReactionsConfig, - #[serde(default)] - pub stt: SttConfig, - #[serde(default)] - pub markdown: MarkdownConfig, - #[serde(default)] - pub cron: CronConfig, - #[serde(default)] - pub hooks: HooksConfig, - #[serde(default)] - pub workspace: WorkspaceConfig, - #[serde(default)] - pub secrets: SecretsConfig, -} - -#[derive(Debug, Clone, Default, Deserialize)] -pub struct WorkspaceConfig { - /// Workspace aliases: `name = "~/path/to/project"` - /// Used with `[[ws:@alias]]` control directives. - #[serde(default)] - pub aliases: std::collections::HashMap, -} - -#[derive(Debug, Clone, Default, Deserialize)] -pub struct SecretsConfig { - /// AWS Secrets Manager configuration. - #[serde(default)] - pub aws: AwsSecretsConfig, - /// Exec provider configuration. - #[serde(default)] - pub exec: ExecSecretsConfig, - /// Secret references: key = "aws-sm://..." or "exec://..." 
- #[serde(default)] - pub refs: HashMap, -} - -#[derive(Debug, Clone, Default, Deserialize)] -pub struct AwsSecretsConfig { - /// Override AWS region (otherwise uses default credential chain). - pub region: Option, - /// Override endpoint URL (for LocalStack or VPC endpoints). - pub endpoint_url: Option, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct ExecSecretsConfig { - /// Per-invocation timeout in seconds (default: 10). - #[serde(default = "default_exec_timeout")] - pub timeout_seconds: u64, -} - -impl Default for ExecSecretsConfig { - fn default() -> Self { - Self { timeout_seconds: 10 } - } -} - -fn default_exec_timeout() -> u64 { - 10 -} - -#[derive(Debug, Clone, Default, Deserialize)] -pub struct HooksConfig { - pub pre_boot: Option, - pub pre_shutdown: Option, -} - -/// Failure policy for a hook. -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub enum OnFailure { - #[default] - Abort, - Warn, -} - -impl<'de> Deserialize<'de> for OnFailure { - fn deserialize>(deserializer: D) -> Result { - let s = String::deserialize(deserializer)?; - match s.to_lowercase().as_str() { - "abort" => Ok(Self::Abort), - "warn" => Ok(Self::Warn), - other => Err(serde::de::Error::unknown_variant(other, &["abort", "warn"])), - } - } -} - -/// Configuration for a single hook. Exactly one of `script`, `inline`, or `url` must be set. -#[derive(Debug, Clone, Deserialize)] -pub struct HookConfig { - /// Absolute path to an executable script. - pub script: Option, - /// Inline script content (written to temp file and executed). - pub inline: Option, - /// Remote script URL (fetched and executed). - pub url: Option, - /// SHA-256 checksum of the remote script (required with `url`). - pub sha256: Option, - /// Max wall-clock seconds. Default: 60. - #[serde(default = "default_hook_timeout")] - pub timeout_seconds: u64, - /// Failure policy. Default: abort. 
- #[serde(default)] - pub on_failure: OnFailure, -} - -fn default_hook_timeout() -> u64 { - 60 -} - -#[derive(Debug, Clone, Default, Deserialize)] -pub struct CronConfig { - /// Enable usercron hot-reload (default: false). Must be explicitly set to true. - #[serde(default)] - pub usercron_enabled: bool, - /// Path to an external cronjob.toml for hot-reloadable user-managed schedules. - pub usercron_path: Option, - /// Baseline cronjob definitions: `[[cron.jobs]]` - #[serde(default)] - pub jobs: Vec, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct SttConfig { - #[serde(default)] - pub enabled: bool, - #[serde(default)] - pub api_key: String, - #[serde(default = "default_stt_model")] - pub model: String, - #[serde(default = "default_stt_base_url")] - pub base_url: String, - /// Echo the transcribed text back to the thread (no mentions) before - /// dispatching the prompt to the agent. Lets users verify STT accuracy. - #[serde(default = "default_echo_transcript")] - pub echo_transcript: bool, -} - -impl Default for SttConfig { - fn default() -> Self { - Self { - enabled: false, - api_key: String::new(), - model: default_stt_model(), - base_url: default_stt_base_url(), - echo_transcript: default_echo_transcript(), - } - } -} - -fn default_stt_model() -> String { - "whisper-large-v3-turbo".into() -} -fn default_stt_base_url() -> String { - "https://api.groq.com/openai/v1".into() -} -fn default_echo_transcript() -> bool { - false -} - -#[derive(Debug, Deserialize)] -pub struct DiscordConfig { - pub bot_token: String, - /// Explicit flag: true = allow all channels, false = check allowed_channels list. - /// When not set, auto-detected: non-empty list → false, empty list → true. - pub allow_all_channels: Option, - /// Explicit flag: true = allow all users, false = check allowed_users list. - /// When not set, auto-detected: non-empty list → false, empty list → true. 
- pub allow_all_users: Option, - #[serde(default)] - pub allowed_channels: Vec, - #[serde(default)] - pub allowed_users: Vec, - #[serde(default)] - pub allow_bot_messages: AllowBots, - /// When non-empty, only bot messages from these IDs pass the bot gate. - /// Combines with `allow_bot_messages`: the mode check runs first, then - /// the allowlist filters further. Empty = allow any bot (mode permitting). - /// Only relevant when `allow_bot_messages` is `"mentions"` or `"all"`; - /// ignored when `"off"` since all bot messages are rejected before this check. - /// - /// **Admission override**: a trusted bot that explicitly @mentions this bot - /// bypasses the `allow_bot_messages` mode entirely (treated as human @mention). - /// This allows trusted bots to pull this bot into threads regardless of mode. - #[serde(default)] - pub trusted_bot_ids: Vec, - #[serde(default)] - pub allow_user_messages: AllowUsers, - /// Max consecutive bot turns (without human intervention) before throttling. - /// Human message resets the counter. Default: 100. - #[serde(default = "default_max_bot_turns")] - pub max_bot_turns: u32, - /// Role IDs that trigger the bot (same as direct @mention). - /// When a message mentions a role in this list, it is treated as a bot trigger. - /// Empty (default) = role mentions do not trigger the bot. - #[serde(default)] - pub allowed_role_ids: Vec, - /// Allow the bot to respond to Discord direct messages (DMs). - /// Default: false (opt-in). `allowed_users` still applies in DMs. - #[serde(default)] - pub allow_dm: bool, - /// Message dispatch mode. Default: per-message (v0.8.2-beta.1 behaviour). - #[serde(default)] - pub message_processing_mode: MessageProcessingMode, - /// Batched mode only: per-thread channel capacity. Default: 10. - #[serde(default = "default_max_buffered_messages")] - pub max_buffered_messages: usize, - /// Batched mode only: soft token cap for greedy drain. Default: 24000. 
- #[serde(default = "default_max_batch_tokens")] - pub max_batch_tokens: usize, -} - -fn default_max_bot_turns() -> u32 { - 100 -} -fn default_max_buffered_messages() -> usize { - 10 -} -fn default_max_batch_tokens() -> usize { - 24_000 -} - -/// Controls whether the bot responds to user messages in threads without @mention. -/// -/// - `Involved`: respond to thread messages only if the bot has participated -/// in the thread (posted at least one message, or the thread parent @mentions the bot). -/// Channel/MPDM messages always require @mention. DMs always process (implicit mention). -/// - `Mentions`: always require @mention, even in threads the bot is participating in. -/// - `MultibotMentions` (default): same as `Involved` in single-bot threads; falls back to -/// `Mentions` when other bots have also posted in the thread. -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub enum AllowUsers { - Involved, - Mentions, - #[default] - MultibotMentions, -} - -impl<'de> Deserialize<'de> for AllowUsers { - fn deserialize>(deserializer: D) -> Result { - let s = String::deserialize(deserializer)?; - match s.to_lowercase().replace('-', "_").as_str() { - "involved" => Ok(Self::Involved), - "mentions" => Ok(Self::Mentions), - "multibot_mentions" => Ok(Self::MultibotMentions), - other => Err(serde::de::Error::unknown_variant( - other, - &["involved", "mentions", "multibot-mentions"], - )), - } - } -} - -#[derive(Debug, Deserialize)] -pub struct SlackConfig { - pub bot_token: String, - pub app_token: String, - /// Explicit flag: true = allow all channels, false = check allowed_channels list. - /// When not set, auto-detected: non-empty list → false, empty list → true. - pub allow_all_channels: Option, - /// Explicit flag: true = allow all users, false = check allowed_users list. - /// When not set, auto-detected: non-empty list → false, empty list → true. 
- pub allow_all_users: Option, - #[serde(default)] - pub allowed_channels: Vec, - #[serde(default)] - pub allowed_users: Vec, - #[serde(default)] - pub allow_bot_messages: AllowBots, - /// Bot User IDs (U...) allowed to interact when allow_bot_messages is - /// "mentions" or "all". Find via Slack UI: click bot profile → Copy member ID. - /// Empty = allow any bot (mode permitting). - #[serde(default)] - pub trusted_bot_ids: Vec, - #[serde(default)] - pub allow_user_messages: AllowUsers, - /// Max consecutive bot turns (without human intervention) before throttling. - /// Human message resets the counter. Default: 100. - #[serde(default = "default_max_bot_turns")] - pub max_bot_turns: u32, - /// Message dispatch mode. Default: per-message. - #[serde(default)] - pub message_processing_mode: MessageProcessingMode, - /// Batched mode only: per-thread channel capacity. Default: 10. - #[serde(default = "default_max_buffered_messages")] - pub max_buffered_messages: usize, - /// Batched mode only: soft token cap for greedy drain. Default: 24000. - #[serde(default = "default_max_batch_tokens")] - pub max_batch_tokens: usize, - /// Slack "AI app / Assistant" mode: stream replies via chat.startStream + - /// assistant.threads.setStatus instead of post+edit + emoji reactions. - /// Requires the Slack app to be an AI app (assistant feature enabled) with - /// the `assistant:write` scope. Default: true — set to false for Slack apps - /// that are not AI apps (no `assistant:write`) to keep emoji-reaction status. - #[serde(default = "default_true")] - pub assistant_mode: bool, -} - -#[derive(Debug, Deserialize)] -pub struct GatewayConfig { - /// WebSocket URL of the custom gateway (e.g. ws://gateway:8080/ws) - pub url: String, - /// Platform name for session key namespacing (e.g. 
"telegram", "line") - #[serde(default = "default_gateway_platform")] - pub platform: String, - /// Shared token for WebSocket authentication (optional but recommended) - pub token: Option, - /// Bot username for @mention gating in groups (e.g. "my_bot") - pub bot_username: Option, - /// Explicit flag: true = allow all channels, false = check allowed_channels list. - /// When not set, auto-detected: non-empty list → false, empty list → true. - pub allow_all_channels: Option, - /// Explicit flag: true = allow all users, false = check allowed_users list. - /// When not set, auto-detected: non-empty list → false, empty list → true. - pub allow_all_users: Option, - #[serde(default)] - pub allowed_channels: Vec, - #[serde(default)] - pub allowed_users: Vec, - /// Enable streaming (typewriter) mode — requires gateway platform to support message editing. - #[serde(default)] - pub streaming: bool, - /// Show "…" placeholder at streaming start. Default: true. Set false for platforms using drafts. - #[serde(default = "default_true")] - pub streaming_placeholder: bool, - /// Message dispatch mode. Default: per-message. - #[serde(default)] - pub message_processing_mode: MessageProcessingMode, - /// Batched mode only: per-thread channel capacity. Default: 10. - #[serde(default = "default_max_buffered_messages")] - pub max_buffered_messages: usize, - /// Batched mode only: soft token cap for greedy drain. Default: 24000. - #[serde(default = "default_max_batch_tokens")] - pub max_batch_tokens: usize, -} - -fn default_gateway_platform() -> String { - "telegram".into() -} - -/// Raw intermediate struct for serde — uses `Option` to detect explicit fields. 
-#[derive(Debug, Deserialize)] -#[serde(default)] -struct AgentConfigRaw { - command: Option, - args: Option>, - working_dir: String, - env: HashMap, - inherit_env: Vec, -} - -impl Default for AgentConfigRaw { - fn default() -> Self { - Self { - command: None, - args: None, - working_dir: default_working_dir(), - env: HashMap::new(), - inherit_env: Vec::new(), - } - } -} - -#[derive(Debug)] -pub struct AgentConfig { - pub command: String, - pub args: Vec, - pub working_dir: String, - pub env: HashMap, - pub inherit_env: Vec, - /// Whether the command was explicitly set in config (vs defaulted from env/fallback). - pub command_explicit: bool, -} - -impl Default for AgentConfig { - fn default() -> Self { - Self { - command: default_agent_command(), - args: default_agent_args(), - working_dir: default_working_dir(), - env: HashMap::new(), - inherit_env: Vec::new(), - command_explicit: false, - } - } -} - -impl<'de> serde::Deserialize<'de> for AgentConfig { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let raw = AgentConfigRaw::deserialize(deserializer)?; - let cmd_explicit = raw.command.is_some(); - let command = raw.command.unwrap_or_else(default_agent_command); - // If command was explicitly set but args was not, default args to [] - // to avoid leaking env-var args into a custom command. 
- let args = match (cmd_explicit, raw.args) { - (_, Some(args)) => args, // args explicitly set → use them - (true, None) => Vec::new(), // command set, args omitted → empty - (false, None) => default_agent_args(), // neither set → env var - }; - Ok(AgentConfig { - command, - args, - working_dir: raw.working_dir, - env: raw.env, - inherit_env: raw.inherit_env, - command_explicit: cmd_explicit, - }) - } -} - -#[derive(Debug, Deserialize)] -pub struct PoolConfig { - #[serde(default = "default_max_sessions")] - pub max_sessions: usize, - #[serde(default = "default_ttl_hours")] - pub session_ttl_hours: u64, - /// Hard ceiling for a single prompt (#732). Once exceeded, the broker - /// abandons the in-flight request, sends `session/cancel` to the agent, - /// and clears the pending entry so late responses cannot leak into the - /// next prompt's subscriber. - /// - /// Precision: checked every `liveness_check_secs`, so actual cutoff is - /// ±`liveness_check_secs` from this value. - #[serde(default = "default_prompt_hard_timeout_secs")] - pub prompt_hard_timeout_secs: u64, - /// Polling cadence (seconds) for the recv-loop liveness check (#732). - /// Lower = faster reaction to a dead agent / hard ceiling at the cost of - /// more wakeups while the agent is streaming normally. - #[serde(default = "default_liveness_check_secs")] - pub liveness_check_secs: u64, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct CronJobConfig { - /// Stable ID for usercron jobs that need scheduler writeback. 
- pub id: Option, - /// Whether this cronjob is active (default: true) - #[serde(default = "default_true")] - pub enabled: bool, - /// Cron expression (5-field POSIX format) - pub schedule: String, - /// Target channel ID - pub channel: String, - /// Message to send to the agent - pub message: String, - /// Target platform (default: "discord") - #[serde(default = "default_cron_platform")] - pub platform: String, - /// Sender name for attribution (default: "openab-cron") - #[serde(default = "default_cron_sender")] - pub sender_name: String, - /// Optional thread ID (post to existing thread) - pub thread_id: Option, - /// Timezone (default: "UTC") - #[serde(default = "default_cron_timezone")] - pub timezone: String, - /// Usercron-only: command to run before firing. Exit 0 plus a matching - /// `disable_on_success_match` means the goal is complete and the scheduler - /// disables the job in the usercron file. - pub disable_on_success: Option, - /// Usercron-only: required output marker for `disable_on_success`. - pub disable_on_success_match: Option, - /// Usercron-only: timeout for `disable_on_success`. - #[serde(default = "default_disable_on_success_timeout_secs")] - pub disable_on_success_timeout_secs: u64, - /// Usercron-only: working directory for `disable_on_success`. - pub disable_on_success_working_dir: Option, -} - -fn default_cron_platform() -> String { - "discord".into() -} -fn default_cron_sender() -> String { - "openab-cron".into() -} -fn default_cron_timezone() -> String { - "UTC".into() -} -fn default_disable_on_success_timeout_secs() -> u64 { - 60 -} - -/// Controls how tool calls are rendered in chat messages. -/// -/// - `full`: show complete tool title including arguments (default, original behavior) -/// - `compact`: show only a count summary, e.g. 
`✅ 3 · 🔧 1 tool(s)` -/// - `none`: hide tool lines entirely, only show final response -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub enum ToolDisplay { - #[default] - Full, - Compact, - None, -} - -impl<'de> Deserialize<'de> for ToolDisplay { - fn deserialize>(deserializer: D) -> Result { - let s = String::deserialize(deserializer)?; - match s.to_lowercase().as_str() { - "full" => Ok(Self::Full), - "compact" => Ok(Self::Compact), - "none" | "off" | "hidden" => Ok(Self::None), - other => Err(serde::de::Error::unknown_variant( - other, - &["full", "compact", "none"], - )), - } - } -} - -#[derive(Debug, Clone, Deserialize)] -pub struct ReactionsConfig { - #[serde(default = "default_true")] - pub enabled: bool, - #[serde(default)] - pub remove_after_reply: bool, - #[serde(default)] - pub tool_display: ToolDisplay, - #[serde(default)] - pub emojis: ReactionEmojis, - #[serde(default)] - pub timing: ReactionTiming, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct ReactionEmojis { - #[serde(default = "emoji_queued")] - pub queued: String, - #[serde(default = "emoji_thinking")] - pub thinking: String, - #[serde(default = "emoji_tool")] - pub tool: String, - #[serde(default = "emoji_coding")] - pub coding: String, - #[serde(default = "emoji_web")] - pub web: String, - #[serde(default = "emoji_done")] - pub done: String, - #[serde(default = "emoji_error")] - pub error: String, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct ReactionTiming { - #[serde(default = "default_debounce_ms")] - pub debounce_ms: u64, - #[serde(default = "default_stall_soft_ms")] - pub stall_soft_ms: u64, - #[serde(default = "default_stall_hard_ms")] - pub stall_hard_ms: u64, - #[serde(default = "default_done_hold_ms")] - pub done_hold_ms: u64, - #[serde(default = "default_error_hold_ms")] - pub error_hold_ms: u64, -} - -// --- defaults --- - -fn default_working_dir() -> String { - std::env::var("HOME").unwrap_or_else(|_| "/tmp".into()) -} -fn default_agent_command() -> String 
{ - if let Ok(val) = std::env::var("OPENAB_AGENT_COMMAND") { - if let Some(cmd) = val.split_whitespace().next() { - return cmd.to_string(); - } - } - "openab-agent".into() -} -fn default_agent_args() -> Vec { - if let Ok(val) = std::env::var("OPENAB_AGENT_COMMAND") { - let parts: Vec<&str> = val.split_whitespace().collect(); - if parts.len() > 1 { - return parts[1..].iter().map(|s| s.to_string()).collect(); - } - } - Vec::new() -} -fn default_max_sessions() -> usize { - 10 -} -fn default_ttl_hours() -> u64 { - 4 -} -pub(crate) fn default_prompt_hard_timeout_secs() -> u64 { - 30 * 60 -} -pub(crate) fn default_liveness_check_secs() -> u64 { - 30 -} -fn default_true() -> bool { - true -} - -fn emoji_queued() -> String { - "👀".into() -} -fn emoji_thinking() -> String { - "🤔".into() -} -fn emoji_tool() -> String { - "🔥".into() -} -fn emoji_coding() -> String { - "👨‍💻".into() -} -fn emoji_web() -> String { - "⚡".into() -} -fn emoji_done() -> String { - "🆗".into() -} -fn emoji_error() -> String { - "😱".into() -} - -fn default_debounce_ms() -> u64 { - 700 -} -fn default_stall_soft_ms() -> u64 { - 10_000 -} -fn default_stall_hard_ms() -> u64 { - 30_000 -} -fn default_done_hold_ms() -> u64 { - 1_500 -} -fn default_error_hold_ms() -> u64 { - 2_500 -} - -impl Default for PoolConfig { - fn default() -> Self { - Self { - max_sessions: default_max_sessions(), - session_ttl_hours: default_ttl_hours(), - prompt_hard_timeout_secs: default_prompt_hard_timeout_secs(), - liveness_check_secs: default_liveness_check_secs(), - } - } -} - -impl Default for ReactionsConfig { - fn default() -> Self { - Self { - enabled: true, - remove_after_reply: false, - tool_display: ToolDisplay::default(), - emojis: ReactionEmojis::default(), - timing: ReactionTiming::default(), - } - } -} - -impl Default for ReactionEmojis { - fn default() -> Self { - Self { - queued: emoji_queued(), - thinking: emoji_thinking(), - tool: emoji_tool(), - coding: emoji_coding(), - web: emoji_web(), - done: emoji_done(), - 
error: emoji_error(), - } - } -} - -impl Default for ReactionTiming { - fn default() -> Self { - Self { - debounce_ms: default_debounce_ms(), - stall_soft_ms: default_stall_soft_ms(), - stall_hard_ms: default_stall_hard_ms(), - done_hold_ms: default_done_hold_ms(), - error_hold_ms: default_error_hold_ms(), - } - } -} - -// --- markdown --- - -#[derive(Debug, Clone, Default, Deserialize)] -pub struct MarkdownConfig { - #[serde(default)] - pub tables: TableMode, -} - -// --- loading --- - -/// Resolve an allow_all flag: if explicitly set, use it; otherwise infer from the list. -/// Non-empty list → false (respect the list), empty list → true (allow all). -pub fn resolve_allow_all(flag: Option, list: &[String]) -> bool { - flag.unwrap_or(list.is_empty()) -} - -fn expand_env_vars(raw: &str) -> String { - let re = Regex::new(r"\$\{(\w+)\}").unwrap(); - re.replace_all(raw, |caps: ®ex::Captures| { - std::env::var(&caps[1]).unwrap_or_default() - }) - .into_owned() -} - -/// Load raw config text from a file path (env vars expanded but secrets NOT resolved). -pub fn load_config_raw(path: &Path) -> anyhow::Result { - let raw = std::fs::read_to_string(path) - .map_err(|e| anyhow::anyhow!("failed to read {}: {e}", path.display()))?; - Ok(expand_env_vars(&raw)) -} - -/// Load raw config text from a URL (env vars expanded but secrets NOT resolved). 
-pub async fn load_config_raw_from_url(url: &str) -> anyhow::Result { - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(10)) - .build()?; - let resp = client - .get(url) - .send() - .await - .map_err(|e| anyhow::anyhow!("failed to fetch remote config from {url}: {e}"))?; - let status = resp.status(); - if !status.is_success() { - anyhow::bail!("remote config request to {url} returned HTTP {status}"); - } - let bytes = resp - .bytes() - .await - .map_err(|e| anyhow::anyhow!("failed to read response body from {url}: {e}"))?; - const MAX_CONFIG_BYTES: usize = 1024 * 1024; - if bytes.len() > MAX_CONFIG_BYTES { - anyhow::bail!( - "remote config from {url} exceeds 1 MiB limit ({} bytes)", - bytes.len() - ); - } - let raw = String::from_utf8(bytes.to_vec()) - .map_err(|e| anyhow::anyhow!("remote config from {url} is not valid UTF-8: {e}"))?; - Ok(expand_env_vars(&raw)) -} - -/// Parse config from already-expanded text. -pub fn parse_config_str(expanded: &str, source: &str) -> anyhow::Result { - parse_config_inner(expanded, source) -} - -#[cfg(test)] -fn parse_config(raw: &str, source: &str) -> anyhow::Result { - let expanded = expand_env_vars(raw); - parse_config_inner(&expanded, source) -} - -#[cfg(test)] -fn load_config(path: &Path) -> anyhow::Result { - let raw = std::fs::read_to_string(path) - .map_err(|e| anyhow::anyhow!("failed to read {}: {e}", path.display()))?; - parse_config(&raw, path.display().to_string().as_str()) -} - -#[cfg(test)] -async fn load_config_from_url(url: &str) -> anyhow::Result { - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(10)) - .build()?; - let resp = client - .get(url) - .send() - .await - .map_err(|e| anyhow::anyhow!("failed to fetch remote config from {url}: {e}"))?; - let status = resp.status(); - if !status.is_success() { - anyhow::bail!("remote config request to {url} returned HTTP {status}"); - } - let bytes = resp - .bytes() - .await - .map_err(|e| 
anyhow::anyhow!("failed to read response body from {url}: {e}"))?; - let raw = String::from_utf8(bytes.to_vec()) - .map_err(|e| anyhow::anyhow!("remote config from {url} is not valid UTF-8: {e}"))?; - parse_config(&raw, url) -} - -fn parse_config_inner(expanded: &str, source: &str) -> anyhow::Result { - let mut config: Config = toml::from_str(expanded) - .map_err(|e| anyhow::anyhow!("failed to parse config from {source}: {e}"))?; - - // If [agentcore] is set and [agent] command was not explicitly provided, - // synthesize agent config to spawn the bundled agentcore-acp adapter. - if let Some(ref ac) = config.agentcore { - // Validate ARN format: arn:aws:bedrock-agentcore:REGION:ACCOUNT:runtime/ID - let parts: Vec<&str> = ac.runtime_arn.split(':').collect(); - anyhow::ensure!( - parts.len() >= 6 - && parts[0] == "arn" - && parts[2] == "bedrock-agentcore" - && !parts[3].is_empty() - && parts[5].starts_with("runtime/"), - "agentcore.runtime_arn is not a valid AgentCore Runtime ARN \ - (expected arn:aws:bedrock-agentcore:REGION:ACCOUNT:runtime/ID, got \"{}\")", - ac.runtime_arn - ); - - if !config.agent.command_explicit { - // Use native Rust bridge (agentcore feature) or fall back to Python adapter - #[cfg(feature = "agentcore")] - let (cmd, args) = { - let self_exe = std::env::current_exe() - .map(|p| p.to_string_lossy().to_string()) - .unwrap_or_else(|_| "openab".to_string()); - ( - self_exe, - vec![ - "agentcore-bridge".into(), - "--runtime-arn".into(), - ac.runtime_arn.clone(), - "--region".into(), - ac.region(), - "--command".into(), - ac.shell_command.clone(), - ], - ) - }; - #[cfg(not(feature = "agentcore"))] - let (cmd, args) = ( - "uv".to_string(), - vec![ - "run".into(), - "--script".into(), - "/opt/agentcore/acp/agentcore_acp.py".into(), - "--runtime-arn".into(), - ac.runtime_arn.clone(), - "--region".into(), - ac.region(), - "--cancel-strategy".into(), - ac.cancel_strategy.to_string(), - ], - ); - config.agent = AgentConfig { - command: cmd, - args, - 
working_dir: config.agent.working_dir.clone(), - env: config.agent.env.clone(), - inherit_env: config.agent.inherit_env.clone(), - command_explicit: true, // synthesized counts as explicit - }; - } - } - - // Validate max_buffered_messages > 0 (tokio::sync::mpsc::channel panics on 0) - // and max_batch_tokens > 0 (otherwise the consumer's token-cap check forces every - // batch to size 1 — functionally per-message via a confusing path). - if let Some(ref d) = config.discord { - anyhow::ensure!( - d.max_buffered_messages > 0, - "discord.max_buffered_messages must be > 0" - ); - anyhow::ensure!( - d.max_batch_tokens > 0, - "discord.max_batch_tokens must be > 0" - ); - } - if let Some(ref s) = config.slack { - anyhow::ensure!( - s.max_buffered_messages > 0, - "slack.max_buffered_messages must be > 0" - ); - anyhow::ensure!(s.max_batch_tokens > 0, "slack.max_batch_tokens must be > 0"); - } - if let Some(ref g) = config.gateway { - anyhow::ensure!( - g.max_buffered_messages > 0, - "gateway.max_buffered_messages must be > 0" - ); - anyhow::ensure!( - g.max_batch_tokens > 0, - "gateway.max_batch_tokens must be > 0" - ); - } - anyhow::ensure!( - config.pool.liveness_check_secs > 0, - "pool.liveness_check_secs must be > 0 (zero would spin the recv loop)" - ); - - Ok(config) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::io::Write; - - const MINIMAL_TOML: &str = r#" -[discord] -bot_token = "test-token" - -[agent] -command = "echo" -"#; - - #[test] - fn parse_minimal_config() { - let cfg = parse_config(MINIMAL_TOML, "test").unwrap(); - assert_eq!(cfg.discord.unwrap().bot_token, "test-token"); - assert_eq!(cfg.agent.command, "echo"); - assert_eq!(cfg.pool.max_sessions, 10); - assert!(cfg.reactions.enabled); - } - - #[test] - fn expand_env_vars_replaces_known_var() { - std::env::set_var("AB_TEST_VAR", "hello"); - let result = expand_env_vars("token=${AB_TEST_VAR}"); - assert_eq!(result, "token=hello"); - std::env::remove_var("AB_TEST_VAR"); - } - - #[test] - fn 
expand_env_vars_unknown_becomes_empty() { - let result = expand_env_vars("token=${AB_NONEXISTENT_12345}"); - assert_eq!(result, "token="); - } - - #[test] - fn expand_env_vars_in_config() { - std::env::set_var("AB_TEST_TOKEN", "secret-bot-token"); - let toml = r#" -[discord] -bot_token = "${AB_TEST_TOKEN}" - -[agent] -command = "echo" -"#; - let cfg = parse_config(toml, "test").unwrap(); - assert_eq!(cfg.discord.unwrap().bot_token, "secret-bot-token"); - std::env::remove_var("AB_TEST_TOKEN"); - } - - #[test] - fn parse_invalid_toml_returns_error() { - let result = parse_config("not valid toml {{{}}", "test"); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("failed to parse config from test")); - } - - #[test] - fn load_config_missing_file_returns_error() { - let result = load_config(Path::new("/tmp/agent-broker-nonexistent.toml")); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("failed to read")); - } - - #[test] - fn load_config_from_file() { - let mut tmp = tempfile::NamedTempFile::new().unwrap(); - write!(tmp, "{}", MINIMAL_TOML).unwrap(); - let cfg = load_config(tmp.path()).unwrap(); - assert_eq!(cfg.discord.unwrap().bot_token, "test-token"); - } - - #[tokio::test] - async fn load_config_from_url_invalid_host() { - let result = load_config_from_url("https://invalid.test.example/config.toml").await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("failed to fetch remote config")); - } - - #[test] - fn parse_gateway_config_defaults() { - let toml = r#" -[gateway] -url = "ws://gw:8080/ws" - -[agent] -command = "echo" -"#; - let cfg = parse_config(toml, "test").unwrap(); - let gw = cfg.gateway.unwrap(); - assert_eq!(gw.url, "ws://gw:8080/ws"); - assert_eq!(gw.platform, "telegram"); - assert!(gw.allowed_users.is_empty()); - assert!(gw.allowed_channels.is_empty()); - assert!(gw.allow_all_users.is_none()); - assert!(gw.allow_all_channels.is_none()); - // 
resolve_allow_all: empty lists → allow all - assert!(resolve_allow_all(gw.allow_all_users, &gw.allowed_users)); - assert!(resolve_allow_all( - gw.allow_all_channels, - &gw.allowed_channels - )); - } - - #[test] - fn parse_gateway_config_with_allowlists() { - let toml = r#" -[gateway] -url = "ws://gw:8080/ws" -platform = "line" -allowed_users = ["U1", "U2"] -allowed_channels = ["C1"] - -[agent] -command = "echo" -"#; - let cfg = parse_config(toml, "test").unwrap(); - let gw = cfg.gateway.unwrap(); - assert_eq!(gw.platform, "line"); - assert_eq!(gw.allowed_users, vec!["U1", "U2"]); - assert_eq!(gw.allowed_channels, vec!["C1"]); - // resolve_allow_all: non-empty lists → restricted - assert!(!resolve_allow_all(gw.allow_all_users, &gw.allowed_users)); - assert!(!resolve_allow_all( - gw.allow_all_channels, - &gw.allowed_channels - )); - } - - #[test] - fn tool_display_default_is_full() { - assert_eq!(ToolDisplay::default(), ToolDisplay::Full); - } - - #[test] - fn message_processing_mode_parses_per_message() { - let toml = r#" -[discord] -bot_token = "t" -message_processing_mode = "per-message" - -[agent] -command = "echo" -"#; - let cfg = parse_config(toml, "test").unwrap(); - assert_eq!( - cfg.discord.unwrap().message_processing_mode, - MessageProcessingMode::Message - ); - } - - #[test] - fn message_processing_mode_parses_per_thread() { - let toml = r#" -[discord] -bot_token = "t" -message_processing_mode = "per-thread" - -[agent] -command = "echo" -"#; - let cfg = parse_config(toml, "test").unwrap(); - assert_eq!( - cfg.discord.unwrap().message_processing_mode, - MessageProcessingMode::Thread - ); - } - - #[test] - fn message_processing_mode_parses_per_lane() { - let toml = r#" -[discord] -bot_token = "t" -message_processing_mode = "per-lane" - -[agent] -command = "echo" -"#; - let cfg = parse_config(toml, "test").unwrap(); - assert_eq!( - cfg.discord.unwrap().message_processing_mode, - MessageProcessingMode::Lane - ); - } - - // The legacy alias "batched" was 
removed: only per-message / per-thread / per-lane - // are accepted. Configs still using "batched" must migrate to an explicit value. - #[test] - fn message_processing_mode_batched_is_rejected() { - let toml = r#" -[discord] -bot_token = "t" -message_processing_mode = "batched" - -[agent] -command = "echo" -"#; - assert!(parse_config(toml, "test").is_err()); - } - - #[test] - fn message_processing_mode_default_is_per_message() { - let cfg = parse_config(MINIMAL_TOML, "test").unwrap(); - assert_eq!( - cfg.discord.unwrap().message_processing_mode, - MessageProcessingMode::Message - ); - } - - #[test] - fn message_processing_mode_unknown_value_errors() { - let toml = r#" -[discord] -bot_token = "t" -message_processing_mode = "bogus" - -[agent] -command = "echo" -"#; - assert!(parse_config(toml, "test").is_err()); - } - - #[test] - fn parse_gateway_config_explicit_allow_all_overrides_list() { - let toml = r#" -[gateway] -url = "ws://gw:8080/ws" -allow_all_users = true -allowed_users = ["U1"] - -[agent] -command = "echo" -"#; - let cfg = parse_config(toml, "test").unwrap(); - let gw = cfg.gateway.unwrap(); - // explicit flag overrides non-empty list - assert!(resolve_allow_all(gw.allow_all_users, &gw.allowed_users)); - } - - #[test] - fn stt_echo_transcript_defaults_to_false() { - let cfg = SttConfig::default(); - assert!( - !cfg.echo_transcript, - "echo_transcript should default to false" - ); - } - - #[test] - fn stt_echo_transcript_respects_explicit_false() { - let toml = r#" -[agent] -command = "echo" - -[stt] -enabled = true -api_key = "test" -echo_transcript = false -"#; - let cfg = parse_config(toml, "test").unwrap(); - assert!(cfg.stt.enabled); - assert!(!cfg.stt.echo_transcript); - } - - #[test] - fn parse_secrets_config() { - let toml = r#" -[discord] -bot_token = "${secrets.discord_token}" - -[agent] -command = "echo" - -[secrets.refs] -discord_token = "aws-sm://openab/prod#discord_bot_token" -github_pat = "exec:///home/agent/.local/bin/get-secret.sh 
vault/openab github_pat" - -[secrets.aws] -region = "ap-northeast-1" -endpoint_url = "http://localhost:4566" - -[secrets.exec] -timeout_seconds = 15 -"#; - let cfg = parse_config(toml, "test").unwrap(); - assert_eq!(cfg.secrets.refs.len(), 2); - assert_eq!( - cfg.secrets.refs.get("discord_token").unwrap(), - "aws-sm://openab/prod#discord_bot_token" - ); - assert_eq!( - cfg.secrets.refs.get("github_pat").unwrap(), - "exec:///home/agent/.local/bin/get-secret.sh vault/openab github_pat" - ); - assert_eq!(cfg.secrets.aws.region.as_deref(), Some("ap-northeast-1")); - assert_eq!( - cfg.secrets.aws.endpoint_url.as_deref(), - Some("http://localhost:4566") - ); - assert_eq!(cfg.secrets.exec.timeout_seconds, 15); - } - - #[test] - fn parse_secrets_config_defaults() { - let toml = r#" -[discord] -bot_token = "test" - -[agent] -command = "echo" -"#; - let cfg = parse_config(toml, "test").unwrap(); - assert!(cfg.secrets.refs.is_empty()); - assert!(cfg.secrets.aws.region.is_none()); - assert!(cfg.secrets.aws.endpoint_url.is_none()); - assert_eq!(cfg.secrets.exec.timeout_seconds, 10); - } - - #[test] - fn slack_assistant_mode_defaults_true_and_parses_false() { - let cfg: SlackConfig = toml::from_str("bot_token = \"x\"\napp_token = \"y\"\n").unwrap(); - assert!(cfg.assistant_mode, "assistant_mode must default to true"); - - let cfg2: SlackConfig = - toml::from_str("bot_token = \"x\"\napp_token = \"y\"\nassistant_mode = false\n") - .unwrap(); - assert!(!cfg2.assistant_mode); - } - - #[test] - fn agentcore_config_synthesizes_agent_command() { - let toml = r#" -[discord] -bot_token = "t" - -[agentcore] -runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/my-agent" -"#; - let cfg = parse_config(toml, "test").unwrap(); - #[cfg(feature = "agentcore")] - { - // With agentcore feature, spawns self with agentcore-bridge subcommand - assert!(cfg.agent.args.contains(&"agentcore-bridge".to_string())); - } - #[cfg(not(feature = "agentcore"))] - { - 
assert_eq!(cfg.agent.command, "uv"); - } - assert!(cfg.agent.args.contains(&"--runtime-arn".to_string())); - assert!(cfg - .agent - .args - .contains(&"arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/my-agent".to_string())); - } - - #[test] - fn agentcore_config_does_not_override_explicit_agent() { - let toml = r#" -[discord] -bot_token = "t" - -[agent] -command = "my-custom-agent" - -[agentcore] -runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/my-agent" -"#; - let cfg = parse_config(toml, "test").unwrap(); - assert_eq!(cfg.agent.command, "my-custom-agent"); - } - - #[test] - fn agentcore_config_defaults() { - let toml = r#" -[discord] -bot_token = "t" - -[agentcore] -runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/test" -"#; - let cfg = parse_config(toml, "test").unwrap(); - let ac = cfg.agentcore.unwrap(); - assert_eq!(ac.region(), "us-east-1"); - assert_eq!(ac.cancel_strategy, AgentCoreCancelStrategy::Stop); - } - - #[test] - fn agentcore_rejects_invalid_arn() { - let toml = r#" -[discord] -bot_token = "t" - -[agentcore] -runtime_arn = "not-a-valid-arn" -"#; - let err = parse_config(toml, "test").unwrap_err(); - assert!(err.to_string().contains("not a valid AgentCore Runtime ARN")); - } - - #[test] - fn agentcore_rejects_arn_wrong_service() { - let toml = r#" -[discord] -bot_token = "t" - -[agentcore] -runtime_arn = "arn:aws:s3:us-east-1:123456789012:bucket/my-bucket" -"#; - let err = parse_config(toml, "test").unwrap_err(); - assert!(err.to_string().contains("not a valid AgentCore Runtime ARN")); - } - - #[test] - fn agentcore_rejects_arn_missing_runtime_prefix() { - let toml = r#" -[discord] -bot_token = "t" - -[agentcore] -runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:agent/my-agent" -"#; - let err = parse_config(toml, "test").unwrap_err(); - assert!(err.to_string().contains("not a valid AgentCore Runtime ARN")); - } - - #[test] - fn agentcore_rejects_invalid_cancel_strategy() { - let 
toml = r#" -[discord] -bot_token = "t" - -[agentcore] -runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/test" -cancel_strategy = "stopp" -"#; - let err = parse_config(toml, "test").unwrap_err(); - assert!(err.to_string().contains("unknown variant")); - } - - #[test] - fn agentcore_extracts_region_from_arn() { - let toml = r#" -[discord] -bot_token = "t" - -[agentcore] -runtime_arn = "arn:aws:bedrock-agentcore:ap-northeast-1:123456789012:runtime/tokyo-agent" -"#; - let cfg = parse_config(toml, "test").unwrap(); - assert!(cfg.agent.args.contains(&"ap-northeast-1".to_string())); - } - - #[test] - fn agentcore_cancel_strategy_noop() { - let toml = r#" -[discord] -bot_token = "t" - -[agentcore] -runtime_arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/test" -cancel_strategy = "noop" -"#; - let cfg = parse_config(toml, "test").unwrap(); - let ac = cfg.agentcore.unwrap(); - assert_eq!(ac.cancel_strategy, AgentCoreCancelStrategy::Noop); - } -} diff --git a/src/cron.rs b/src/cron.rs deleted file mode 100644 index db5828b22..000000000 --- a/src/cron.rs +++ /dev/null @@ -1,1768 +0,0 @@ -use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, SenderContext}; -use crate::config::CronJobConfig; -use crate::format; -use chrono::{Timelike, Utc}; -use chrono_tz::Tz; -use cron::Schedule; -use std::collections::{HashMap, HashSet}; -use std::path::{Path, PathBuf}; -use std::str::FromStr; -use std::sync::Arc; -use std::time::SystemTime; -use tokio::process::Command; -use tokio::sync::Mutex; -use toml_edit::{value, DocumentMut}; -use tracing::{debug, error, info, warn}; - -/// Parse a 5-field POSIX cron expression into a `Schedule`. -/// -/// The `cron` crate expects a 6-field expression (with seconds), so we prepend "0". -/// -/// POSIX numeric day-of-week values (0..=7, where 0 or 7 = Sunday) are translated -/// to the `cron` crate's 1-based form (1..=7, where 1 = Sunday) before being handed -/// to the underlying parser. 
Without this, numeric day-of-week values are off by one -/// — e.g. `1-5` (Mon-Fri in POSIX) would be evaluated as Sun-Thu. See the -/// [`translate_posix_dow_field`] doc comment for details. -/// -/// Name-based day-of-week tokens (`Mon`, `Sun`, `Mon-Fri`, ...) are passed through -/// unchanged — the `cron` crate's internal name-to-ordinal map is consistent. -pub fn parse_cron_expr(expr: &str) -> Result { - let translated = translate_posix_cron_expr(expr)?; - let six_field = format!("0 {}", translated); - Schedule::from_str(&six_field).map_err(|e| e.to_string()) -} - -/// Translate a 5-field POSIX cron expression so the day-of-week field uses the -/// numeric convention of the `cron` crate. -/// -/// Only the 5th field (day-of-week) is rewritten; the other four fields pass -/// through unchanged. -fn translate_posix_cron_expr(expr: &str) -> Result { - let fields: Vec<&str> = expr.split_whitespace().collect(); - if fields.len() != 5 { - return Err(format!( - "expected 5 whitespace-separated cron fields, got {}: {:?}", - fields.len(), - expr - )); - } - let translated_dow = translate_posix_dow_field(fields[4])?; - Ok(format!( - "{} {} {} {} {}", - fields[0], fields[1], fields[2], fields[3], translated_dow - )) -} - -/// Translate a POSIX day-of-week field to the `cron` crate's numeric form. -/// -/// # Background -/// -/// POSIX cron (and Linux crontab, Kubernetes CronJob, GitHub Actions) uses -/// `0..=7` where `0` or `7` = Sunday, `1` = Monday, ..., `6` = Saturday. -/// -/// The `cron` crate uses `1..=7` where `1` = Sunday, `2` = Monday, ..., `7` = Saturday -/// (it matches via chrono's `Weekday::number_from_sunday()`). 
Without translation, -/// every numeric day-of-week value fires one day early: -/// -/// | POSIX intent | Without translation (cron crate reads as) | -/// |---------------|-------------------------------------------| -/// | `0`, `7` (Sun) | out-of-range / Sat | -/// | `1` (Mon) | Sun | -/// | `5` (Fri) | Thu | -/// | `1-5` (Mon-Fri) | Sun-Thu | -/// -/// # Algorithm -/// -/// 1. If the field contains any ASCII letter (e.g. `Mon-Fri`), pass it through — -/// the cron crate's name-to-ordinal map is internally consistent. -/// 2. Otherwise, expand each comma-separated component into the set of POSIX -/// day values it represents. Ranges (`a-b`) and step values (`a/s`, `a-b/s`, -/// `*/s`) are expanded here. `7` is normalized to `0` (both = Sunday) to -/// avoid duplication. -/// 3. If the resulting set covers all 7 days, emit `*` for brevity. -/// 4. Otherwise, shift each value by `+1` (POSIX `{0..=6}` → cron crate -/// `{1..=7}`) and emit as a comma-separated list, compacting contiguous -/// runs into ranges for readability. -/// -/// # Mixed numeric and name notation -/// -/// Mixing numeric and name tokens in the same field (e.g. `1,Mon`) is not -/// supported and will return an error. Use either all numeric (POSIX) or all -/// name-based notation. -fn translate_posix_dow_field(field: &str) -> Result { - use std::collections::BTreeSet; - - // Name-based notation is internally consistent in the cron crate — pass through. - // But reject mixed numeric+name notation (e.g. "1,Mon") which would leave the - // numeric part untranslated and silently wrong. 
- let has_alpha = field.chars().any(|c| c.is_ascii_alphabetic()); - let has_digit = field.chars().any(|c| c.is_ascii_digit()); - if has_alpha && has_digit { - return Err(format!( - "mixed numeric and name notation is not supported in day-of-week field: {:?}", - field - )); - } - if has_alpha { - return Ok(field.to_string()); - } - - if field.is_empty() { - return Err("empty day-of-week field".to_string()); - } - - let mut days: BTreeSet = BTreeSet::new(); - - for part in field.split(',') { - if part.is_empty() { - return Err(format!("empty component in day-of-week field: {:?}", field)); - } - - // Split off optional step: `a/s`, `a-b/s`, `*/s`. - let (range_part, step) = match part.split_once('/') { - Some((r, s)) => { - let step_n: u32 = s - .parse() - .map_err(|_| format!("invalid step value in {:?}", part))?; - if step_n == 0 { - return Err(format!("step value cannot be zero in {:?}", part)); - } - (r, step_n) - } - None => (part, 1u32), - }; - - // Expand range_part to the list of POSIX day values it represents. - // Values may include 7 (Sunday alias for 0); normalization happens below. - let raw_values: Vec = if range_part == "*" { - (0..=6).collect() - } else if let Some((a, b)) = range_part.split_once('-') { - let a_n: u32 = a - .parse() - .map_err(|_| format!("invalid range start in {:?}", part))?; - let b_n: u32 = b - .parse() - .map_err(|_| format!("invalid range end in {:?}", part))?; - if a_n > 7 || b_n > 7 { - return Err(format!( - "day-of-week value out of range (0-7) in {:?}", - part - )); - } - if a_n > b_n { - return Err(format!("invalid range {:?}: start > end", part)); - } - (a_n..=b_n).collect() - } else { - let n: u32 = range_part - .parse() - .map_err(|_| format!("invalid number in {:?}", part))?; - if n > 7 { - return Err(format!("day-of-week value out of range (0-7): {}", n)); - } - if step > 1 { - // n/step means "from n through end-of-domain, stepping by step" - // Normalize 7 (Sunday alias) to 0 before expansion. 
- let start = if n == 7 { 0 } else { n }; - (start..=6).collect() - } else { - vec![n] - } - }; - - // Apply step filter, normalize 7 → 0, collect into the set. - for (i, &v) in raw_values.iter().enumerate() { - if (i as u32).is_multiple_of(step) { - let normalized = if v == 7 { 0 } else { v }; - days.insert(normalized); - } - } - } - - if days.is_empty() { - return Err(format!("empty day-of-week field: {:?}", field)); - } - - // All 7 days → emit `*` for brevity. - if days.len() == 7 { - return Ok("*".to_string()); - } - - // Shift POSIX {0..=6} → cron crate {1..=7} and emit, compacting contiguous runs. - let shifted: Vec = days.iter().map(|d| d + 1).collect(); - Ok(compact_ordinal_set(&shifted)) -} - -/// Compact a sorted list of ordinals into cron-style comma-list with ranges, -/// e.g. `[2,3,4,5,6]` → `"2-6"`, `[1,3,5]` → `"1,3,5"`, `[1,2,4,5]` → `"1-2,4-5"`. -fn compact_ordinal_set(sorted: &[u32]) -> String { - if sorted.is_empty() { - return String::new(); - } - let mut out: Vec = Vec::new(); - let mut start = sorted[0]; - let mut end = sorted[0]; - for &v in &sorted[1..] { - if v == end + 1 { - end = v; - } else { - out.push(render_run(start, end)); - start = v; - end = v; - } - } - out.push(render_run(start, end)); - out.join(",") -} - -fn render_run(start: u32, end: u32) -> String { - if start == end { - format!("{}", start) - } else { - format!("{}-{}", start, end) - } -} - -/// Check whether a cron schedule should fire right now. -/// Truncates the current time to the minute boundary and checks if the -/// schedule has an event at exactly that minute. -pub fn should_fire(schedule: &Schedule, tz: Tz) -> bool { - let now = Utc::now().with_timezone(&tz); - let minute_start = now.with_second(0).unwrap().with_nanosecond(0).unwrap(); - let query_from = minute_start - chrono::Duration::seconds(1); - schedule - .after(&query_from) - .next() - .map(|next| next == minute_start) - .unwrap_or(false) -} - -/// Known platforms that have adapter support. 
-const VALID_PLATFORMS: &[&str] = &["discord", "slack"]; - -/// Validate all cronjob configs (fail-fast on bad cron expressions or timezones). -pub fn validate_cronjobs( - cronjobs: &[CronJobConfig], - configured_platforms: &[&str], -) -> anyhow::Result<()> { - for (i, job) in cronjobs.iter().enumerate() { - if !job.enabled { - continue; - } - parse_cron_expr(&job.schedule).map_err(|e| { - anyhow::anyhow!( - "cronjobs[{i}]: invalid cron expression {:?}: {e}", - job.schedule - ) - })?; - job.timezone.parse::().map_err(|e| { - anyhow::anyhow!("cronjobs[{i}]: invalid timezone {:?}: {e}", job.timezone) - })?; - if !VALID_PLATFORMS.contains(&job.platform.as_str()) { - anyhow::bail!( - "cronjobs[{i}]: unknown platform {:?} (expected one of: {VALID_PLATFORMS:?})", - job.platform - ); - } - if !configured_platforms.contains(&job.platform.as_str()) { - anyhow::bail!( - "cronjobs[{i}]: platform {:?} is not configured — add [{}] to config.toml", - job.platform, - job.platform - ); - } - if job.disable_on_success.is_some() { - anyhow::bail!( - "cronjobs[{i}]: disable_on_success is only supported in usercron [[jobs]], not baseline [[cron.jobs]]" - ); - } - } - Ok(()) -} - -// --------------------------------------------------------------------------- -// Usercron hot-reload -// --------------------------------------------------------------------------- - -/// Wrapper for deserializing cronjob.toml which contains `[[jobs]]`. -#[derive(serde::Deserialize)] -struct UsercronFile { - #[serde(default)] - jobs: Vec, -} - -/// Load and validate cronjobs from an external TOML file. -/// Returns an empty vec if the file doesn't exist. -/// Logs and skips individual invalid entries rather than failing entirely. 
-pub fn load_usercron_file(path: &Path, configured_platforms: &[&str]) -> Vec { - let content = match std::fs::read_to_string(path) { - Ok(c) => c, - Err(e) if e.kind() == std::io::ErrorKind::NotFound => return vec![], - Err(e) => { - warn!(path = %path.display(), error = %e, "failed to read usercron file"); - return vec![]; - } - }; - let parsed: UsercronFile = match toml::from_str(&content) { - Ok(f) => f, - Err(e) => { - warn!(path = %path.display(), error = %e, "failed to parse usercron file, skipping all entries"); - return vec![]; - } - }; - // Validate each entry individually — keep valid ones, skip bad ones - parsed.jobs.into_iter().enumerate().filter(|(i, job)| { - if let Err(e) = parse_cron_expr(&job.schedule) { - warn!(index = i, schedule = %job.schedule, error = %e, "usercron: invalid cron expression, skipping"); - return false; - } - if job.timezone.parse::().is_err() { - warn!(index = i, timezone = %job.timezone, "usercron: invalid timezone, skipping"); - return false; - } - if !VALID_PLATFORMS.contains(&job.platform.as_str()) { - warn!(index = i, platform = %job.platform, "usercron: unknown platform, skipping"); - return false; - } - if !configured_platforms.contains(&job.platform.as_str()) { - warn!(index = i, platform = %job.platform, "usercron: platform not configured, skipping"); - return false; - } - if job.disable_on_success.as_deref().is_some_and(|s| !s.trim().is_empty()) { - if job.id.as_deref().is_none_or(|s| s.trim().is_empty()) { - warn!(index = i, "usercron: disable_on_success requires id, skipping"); - return false; - } - if job - .disable_on_success_match - .as_deref() - .is_none_or(|s| s.trim().is_empty()) - { - warn!(index = i, "usercron: disable_on_success requires disable_on_success_match, skipping"); - return false; - } - } - true - }).map(|(_, job)| job).collect() -} - -/// Get file mtime, returns None if file doesn't exist or metadata fails. 
-fn file_mtime(path: &Path) -> Option { - std::fs::metadata(path).ok().and_then(|m| m.modified().ok()) -} - -/// A parsed, ready-to-evaluate cron job. -struct ParsedJob { - schedule: Schedule, - tz: Tz, - config: CronJobConfig, - usercron_path: Option, -} - -/// Parse a list of CronJobConfig into ParsedJob, filtering out disabled/invalid entries. -fn parse_job_list( - configs: &[CronJobConfig], - source: &str, - usercron_path: Option<&Path>, -) -> Vec { - configs.iter().filter(|job| { - if !job.enabled { - info!(schedule = %job.schedule, channel = %job.channel, source, "cronjob disabled, skipping"); - } - job.enabled - }).filter_map(|job| { - let schedule = match parse_cron_expr(&job.schedule) { - Ok(s) => s, - Err(e) => { - error!(schedule = %job.schedule, error = %e, source, "invalid cron expression, skipping"); - return None; - } - }; - let tz: Tz = match job.timezone.parse() { - Ok(t) => t, - Err(e) => { - error!(timezone = %job.timezone, error = %e, source, "invalid timezone, skipping"); - return None; - } - }; - info!( - schedule = %job.schedule, timezone = %job.timezone, - channel = %job.channel, platform = %job.platform, - message = %job.message, source, - "cronjob registered" - ); - Some(ParsedJob { - schedule, - tz, - config: job.clone(), - usercron_path: usercron_path.map(Path::to_path_buf), - }) - }).collect() -} - -/// Run the internal cron scheduler. Evaluates cron expressions once per minute. -/// `usercron_path` enables hot-reload of an external cronjob.toml file. 
-pub async fn run_scheduler( - cronjobs: Vec, - usercron_path: Option, - configured_platforms: Vec, - router: Arc, - adapters: HashMap>, - mut shutdown_rx: tokio::sync::watch::Receiver, -) { - let platform_refs: Vec<&str> = configured_platforms.iter().map(|s| s.as_str()).collect(); - - // Parse baseline jobs from config.toml - let baseline_jobs = parse_job_list(&cronjobs, "config.toml", None); - - // Load initial usercron jobs - let mut usercron_jobs = if let Some(ref path) = usercron_path { - let configs = load_usercron_file(path, &platform_refs); - if !configs.is_empty() { - info!(count = configs.len(), path = %path.display(), "loaded usercron jobs"); - } - parse_job_list(&configs, "cronjob.toml", Some(path.as_path())) - } else { - vec![] - }; - let mut last_usercron_mtime: Option = usercron_path.as_deref().and_then(file_mtime); - - if baseline_jobs.is_empty() && usercron_jobs.is_empty() { - if usercron_path.is_some() { - info!( - "no cronjobs yet, but usercron_path is set — scheduler will watch for cronjob.toml" - ); - } else { - debug!("no cronjobs configured, scheduler not started"); - return; - } - } - - let total = baseline_jobs.len() + usercron_jobs.len(); - info!( - baseline = baseline_jobs.len(), - usercron = usercron_jobs.len(), - total, - "cron scheduler started" - ); - - let in_flight: Arc>> = Arc::new(Mutex::new(HashSet::new())); - // Serialize usercron read-modify-write updates so concurrent jobs do not - // overwrite each other's enabled/thread_id changes. 
- let usercron_write_lock: Arc> = Arc::new(Mutex::new(())); - - // Align to next minute boundary - let now = Utc::now(); - let secs_into_minute = now.timestamp() % 60; - let align_delay = if secs_into_minute == 0 { - 0 - } else { - 60 - secs_into_minute as u64 - }; - if align_delay > 0 { - debug!(align_secs = align_delay, "aligning to next minute boundary"); - tokio::time::sleep(std::time::Duration::from_secs(align_delay)).await; - } - let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60)); - ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - let mut tasks: tokio::task::JoinSet<()> = tokio::task::JoinSet::new(); - - loop { - tokio::select! { - _ = ticker.tick() => { - // Hot-reload usercron file if mtime changed - if let Some(ref path) = usercron_path { - let current_mtime = file_mtime(path); - if current_mtime != last_usercron_mtime { - let configs = load_usercron_file(path, &platform_refs); - info!(count = configs.len(), path = %path.display(), "usercron file changed, reloading"); - // Keep in-flight indices across reload. A scheduler writeback - // (thread_id or enabled=false) changes mtime deterministically; - // clearing usercron indices here would allow the same job to - // overlap on the next tick while its previous run is still active. 
- usercron_jobs = - parse_job_list(&configs, "cronjob.toml", Some(path.as_path())); - last_usercron_mtime = current_mtime; - } - } - - // Evaluate all jobs: baseline first, then usercron - let all_jobs = baseline_jobs.iter().chain(usercron_jobs.iter()); - for (idx, job) in all_jobs.enumerate() { - if !should_fire(&job.schedule, job.tz) { - continue; - } - { - let running = in_flight.lock().await; - if running.contains(&idx) { - warn!(schedule = %job.config.schedule, channel = %job.config.channel, "skipping cronjob, previous execution still running"); - continue; - } - } - info!( - schedule = %job.config.schedule, - channel = %job.config.channel, - platform = %job.config.platform, - message = %job.config.message, - sender = %job.config.sender_name, - "🔔 cronjob fired" - ); - in_flight.lock().await.insert(idx); - - let config = job.config.clone(); - let usercron_path = job.usercron_path.clone(); - let router = router.clone(); - let adapters = adapters.clone(); - let in_flight = in_flight.clone(); - let usercron_write_lock = usercron_write_lock.clone(); - tasks.spawn(async move { - fire_cronjob( - idx, - &config, - usercron_path, - &router, - &adapters, - in_flight, - usercron_write_lock, - ) - .await; - }); - } - while tasks.try_join_next().is_some() {} - } - _ = shutdown_rx.changed() => { - if *shutdown_rx.borrow() { - info!("cron scheduler shutting down, waiting for in-flight tasks"); - let drain = async { while tasks.join_next().await.is_some() {} }; - let _ = tokio::time::timeout(std::time::Duration::from_secs(30), drain).await; - return; - } - } - } - } -} - -/// RAII guard that removes a job index from the in-flight set on drop. 
-struct InFlightGuard { - idx: usize, - set: Arc>>, -} - -impl Drop for InFlightGuard { - fn drop(&mut self) { - let idx = self.idx; - let set = self.set.clone(); - tokio::spawn(async move { - set.lock().await.remove(&idx); - }); - } -} - -async fn fire_cronjob( - idx: usize, - job: &CronJobConfig, - usercron_path: Option, - router: &Arc, - adapters: &HashMap>, - in_flight: Arc>>, - usercron_write_lock: Arc>, -) { - let _guard = InFlightGuard { - idx, - set: in_flight, - }; - - let adapter = match adapters.get(&job.platform) { - Some(a) => a.clone(), - None => { - error!(platform = %job.platform, "no adapter for platform, skipping cronjob"); - return; - } - }; - - if let Some(command) = non_empty_opt(job.disable_on_success.as_deref()) { - let marker = match non_empty_opt(job.disable_on_success_match.as_deref()) { - Some(marker) => marker, - None => { - warn!( - id = job.id.as_deref().unwrap_or(""), - "disable_on_success configured without disable_on_success_match, treating as not achieved" - ); - "" - } - }; - if !marker.is_empty() { - match check_disable_on_success(job, command, marker).await { - DisableOnSuccessResult::Achieved => { - let channel = ChannelRef { - platform: job.platform.clone(), - channel_id: job.channel.clone(), - thread_id: job.thread_id.clone(), - parent_id: None, - origin_event_id: None, - }; - if let Err(e) = adapter - .send_message( - &channel, - &format!( - "✅ Goal achieved: `{}` matched `{}`. 
Disabling cronjob.", - command, marker - ), - ) - .await - { - error!(channel = %job.channel, error = %e, "failed to send goal achieved message"); - } - - if let (Some(path), Some(id)) = - (usercron_path.as_deref(), non_empty_opt(job.id.as_deref())) - { - let _write_guard = usercron_write_lock.lock().await; - if let Err(e) = update_usercron_job(path, id, Some(false), None) { - error!(path = %path.display(), id, error = %e, "failed to disable completed usercron job"); - } - } else { - warn!("completed disable_on_success job has no usercron path or id, cannot write enabled=false"); - } - return; - } - DisableOnSuccessResult::NotAchieved(reason) => { - info!( - id = job.id.as_deref().unwrap_or(""), - reason, - "disable_on_success not achieved, firing cronjob normally" - ); - } - } - } - } - - let thread_channel = ChannelRef { - platform: job.platform.clone(), - channel_id: job.channel.clone(), - thread_id: job.thread_id.clone(), - parent_id: None, - origin_event_id: None, - }; - - let trigger_msg = match adapter - .send_message( - &thread_channel, - &format!("🕐 [{}]: {}", job.sender_name, job.message), - ) - .await - { - Ok(msg) => msg, - Err(e) => { - error!(channel = %job.channel, error = %e, "failed to send cron message"); - return; - } - }; - - let reply_channel = if job.thread_id.is_some() { - thread_channel.clone() - } else { - let thread_name = format::shorten_thread_name(&job.message); - match adapter - .create_thread(&thread_channel, &trigger_msg, &thread_name) - .await - { - Ok(ch) => { - if let (Some(path), Some(id), Some(thread_id)) = ( - usercron_path.as_deref(), - non_empty_opt(job.id.as_deref()), - ch.thread_id.as_deref().or(Some(ch.channel_id.as_str())), - ) { - let _write_guard = usercron_write_lock.lock().await; - if let Err(e) = update_usercron_job(path, id, None, Some(thread_id)) { - warn!(path = %path.display(), id, error = %e, "failed to persist usercron thread_id"); - } - } - ch - } - Err(e) => { - error!(channel = %job.channel, error = %e, 
"failed to create cron thread"); - let _ = adapter - .send_message( - &thread_channel, - &format!("⚠️ cronjob: failed to create thread: {e}"), - ) - .await; - return; - } - } - }; - - let sender = SenderContext { - schema: "openab.sender.v1".into(), - sender_id: "openab-cron".into(), - sender_name: job.sender_name.clone(), - display_name: job.sender_name.clone(), - channel: job.platform.clone(), - channel_id: reply_channel - .parent_id - .as_deref() - .unwrap_or(&reply_channel.channel_id) - .to_string(), - thread_id: reply_channel - .thread_id - .clone() - .or(Some(reply_channel.channel_id.clone())), - is_bot: true, - timestamp: Some(Utc::now().to_rfc3339()), - message_id: None, // cron jobs don't originate from a message - receiver_id: None, // cron jobs are self-triggered, no external receiver - }; - let sender_json = match serde_json::to_string(&sender) { - Ok(j) => j, - Err(e) => { - warn!(error = %e, "failed to serialize cron sender context, skipping"); - return; - } - }; - - if let Err(e) = router - .handle_message( - &adapter, - crate::adapter::MessageContext { - thread_channel: reply_channel.clone(), - sender_json, - prompt: job.message.clone(), - extra_blocks: vec![], - trigger_msg, - other_bot_present: false, - }, - ) - .await - { - error!("cron handle_message error: {e}"); - let _ = adapter - .send_message(&reply_channel, &format!("⚠️ cronjob error: {e}")) - .await; - } -} - -enum DisableOnSuccessResult { - Achieved, - NotAchieved(&'static str), -} - -fn non_empty_opt(value: Option<&str>) -> Option<&str> { - value.and_then(|s| { - let trimmed = s.trim(); - if trimmed.is_empty() { - None - } else { - Some(trimmed) - } - }) -} - -async fn check_disable_on_success( - job: &CronJobConfig, - command: &str, - marker: &str, -) -> DisableOnSuccessResult { - let timeout_secs = job.disable_on_success_timeout_secs.max(1); - let mut cmd = shell_command(command); - if let Some(dir) = non_empty_opt(job.disable_on_success_working_dir.as_deref()) { - 
cmd.current_dir(dir); - } - cmd.stdout(std::process::Stdio::piped()); - cmd.stderr(std::process::Stdio::piped()); - - let mut child = match cmd.spawn() { - Ok(child) => child, - Err(e) => { - warn!( - id = job.id.as_deref().unwrap_or(""), - command, - error = %e, - "disable_on_success command failed to start" - ); - return DisableOnSuccessResult::NotAchieved("command failed to start"); - } - }; - - // Take stdout/stderr handles and drain them concurrently to prevent pipe buffer deadlock. - let stdout_handle = child.stdout.take(); - let stderr_handle = child.stderr.take(); - - let stdout_task = tokio::spawn(async move { - let mut buf = Vec::new(); - if let Some(mut out) = stdout_handle { - let _ = tokio::io::AsyncReadExt::read_to_end(&mut out, &mut buf).await; - } - buf - }); - let stderr_task = tokio::spawn(async move { - let mut buf = Vec::new(); - if let Some(mut err) = stderr_handle { - let _ = tokio::io::AsyncReadExt::read_to_end(&mut err, &mut buf).await; - } - buf - }); - - let deadline = tokio::time::sleep(std::time::Duration::from_secs(timeout_secs)); - tokio::pin!(deadline); - - tokio::select! 
{ - status = child.wait() => { - let status = match status { - Ok(s) => s, - Err(e) => { - warn!( - id = job.id.as_deref().unwrap_or(""), - command, - error = %e, - "disable_on_success command wait failed" - ); - stdout_task.abort(); - stderr_task.abort(); - return DisableOnSuccessResult::NotAchieved("command wait failed"); - } - }; - if !status.success() { - stdout_task.abort(); - stderr_task.abort(); - return DisableOnSuccessResult::NotAchieved("command exited non-zero"); - } - let stdout_buf = stdout_task.await.unwrap_or_default(); - let stderr_buf = stderr_task.await.unwrap_or_default(); - let stdout = String::from_utf8_lossy(&stdout_buf); - let stderr = String::from_utf8_lossy(&stderr_buf); - if stdout.contains(marker) || stderr.contains(marker) { - DisableOnSuccessResult::Achieved - } else { - DisableOnSuccessResult::NotAchieved("success marker not found") - } - } - _ = &mut deadline => { - // Timeout — kill the child to avoid orphan processes. - let _ = child.kill().await; - stdout_task.abort(); - stderr_task.abort(); - warn!( - id = job.id.as_deref().unwrap_or(""), - command, - timeout_secs, - "disable_on_success command timed out" - ); - DisableOnSuccessResult::NotAchieved("command timed out") - } - } -} - -fn shell_command(command: &str) -> Command { - #[cfg(windows)] - { - let mut child = Command::new("cmd"); - child.arg("/C").arg(command); - child - } - #[cfg(not(windows))] - { - let mut child = Command::new("sh"); - child.arg("-c").arg(command); - child - } -} - -fn update_usercron_job( - path: &Path, - id: &str, - enabled: Option, - thread_id: Option<&str>, -) -> anyhow::Result<()> { - let content = std::fs::read_to_string(path)?; - let mut doc = content.parse::()?; - let jobs = doc - .get_mut("jobs") - .and_then(|item| item.as_array_of_tables_mut()) - .ok_or_else(|| anyhow::anyhow!("usercron file has no [[jobs]] array"))?; - - let mut found = false; - for table in jobs.iter_mut() { - if table.get("id").and_then(|item| item.as_str()) != Some(id) { - 
continue; - } - if let Some(enabled) = enabled { - table["enabled"] = value(enabled); - } - if let Some(thread_id) = thread_id { - table["thread_id"] = value(thread_id); - } - found = true; - break; - } - - if !found { - anyhow::bail!("usercron job id {:?} not found", id); - } - - // Atomic write: write to temp file then rename to avoid corruption on crash. - let tmp = path.with_extension("toml.tmp"); - std::fs::write(&tmp, doc.to_string())?; - std::fs::rename(&tmp, path)?; - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use chrono::{Datelike, Timelike}; - - // --- POSIX day-of-week translator --- - - #[test] - fn translate_dow_star_passes_through() { - assert_eq!(translate_posix_dow_field("*").unwrap(), "*"); - } - - #[test] - fn translate_dow_single_sunday_zero() { - assert_eq!(translate_posix_dow_field("0").unwrap(), "1"); - } - - #[test] - fn translate_dow_single_sunday_seven() { - assert_eq!(translate_posix_dow_field("7").unwrap(), "1"); - } - - #[test] - fn translate_dow_single_monday() { - assert_eq!(translate_posix_dow_field("1").unwrap(), "2"); - } - - #[test] - fn translate_dow_single_saturday() { - assert_eq!(translate_posix_dow_field("6").unwrap(), "7"); - } - - #[test] - fn translate_dow_weekday_range() { - // POSIX 1-5 (Mon-Fri) -> cron crate 2-6 - assert_eq!(translate_posix_dow_field("1-5").unwrap(), "2-6"); - } - - #[test] - fn translate_dow_all_days_zero_to_six() { - assert_eq!(translate_posix_dow_field("0-6").unwrap(), "*"); - } - - #[test] - fn translate_dow_all_days_zero_to_seven() { - // POSIX `0-7` is a quirky but valid "all days" expression. - assert_eq!(translate_posix_dow_field("0-7").unwrap(), "*"); - } - - #[test] - fn translate_dow_all_days_one_to_seven() { - // POSIX `1-7` covers Mon..Sun = all 7 days. 
- assert_eq!(translate_posix_dow_field("1-7").unwrap(), "*"); - } - - #[test] - fn translate_dow_range_three_to_five() { - // POSIX 3-5 (Wed-Fri) -> cron crate 4-6 - assert_eq!(translate_posix_dow_field("3-5").unwrap(), "4-6"); - } - - #[test] - fn translate_dow_list_dedupes_zero_and_seven() { - // Both 0 and 7 = Sunday; output is a single value. - assert_eq!(translate_posix_dow_field("0,7").unwrap(), "1"); - } - - #[test] - fn translate_dow_list_non_contiguous() { - // POSIX 1,3,5 (Mon,Wed,Fri) -> cron crate 2,4,6 - assert_eq!(translate_posix_dow_field("1,3,5").unwrap(), "2,4,6"); - } - - #[test] - fn translate_dow_list_compacts_contiguous_runs() { - // POSIX 1,2,4,5 -> cron crate 2,3,5,6 -> "2-3,5-6" - assert_eq!(translate_posix_dow_field("1,2,4,5").unwrap(), "2-3,5-6"); - } - - #[test] - fn translate_dow_step_from_star() { - // POSIX */2 = 0,2,4,6 = Sun,Tue,Thu,Sat -> cron crate 1,3,5,7 - assert_eq!(translate_posix_dow_field("*/2").unwrap(), "1,3,5,7"); - } - - #[test] - fn translate_dow_step_from_range() { - // POSIX 1-5/2 = 1,3,5 = Mon,Wed,Fri -> cron crate 2,4,6 - assert_eq!(translate_posix_dow_field("1-5/2").unwrap(), "2,4,6"); - } - - #[test] - fn translate_dow_names_pass_through() { - assert_eq!(translate_posix_dow_field("Mon-Fri").unwrap(), "Mon-Fri"); - assert_eq!( - translate_posix_dow_field("Mon,Wed,Fri").unwrap(), - "Mon,Wed,Fri" - ); - assert_eq!(translate_posix_dow_field("Sun").unwrap(), "Sun"); - } - - #[test] - fn translate_dow_step_from_singleton() { - // POSIX 1/2 = from Mon through Sat, step 2 = {1,3,5} = Mon,Wed,Fri -> cron crate 2,4,6 - assert_eq!(translate_posix_dow_field("1/2").unwrap(), "2,4,6"); - } - - #[test] - fn translate_dow_step_from_singleton_sunday() { - // POSIX 0/3 = from Sun through Sat, step 3 = {0,3,6} = Sun,Wed,Sat -> cron crate 1,4,7 - assert_eq!(translate_posix_dow_field("0/3").unwrap(), "1,4,7"); - } - - #[test] - fn translate_dow_step_from_singleton_seven() { - // POSIX 7/2 = Sunday alias, same as 0/2 = {0,2,4,6} = 
Sun,Tue,Thu,Sat -> cron crate 1,3,5,7 - assert_eq!(translate_posix_dow_field("7/2").unwrap(), "1,3,5,7"); - } - - #[test] - fn translate_dow_rejects_mixed_notation() { - assert!(translate_posix_dow_field("1,Mon").is_err()); - assert!(translate_posix_dow_field("Mon,1").is_err()); - assert!(translate_posix_dow_field("1-Fri").is_err()); - } - - #[test] - fn translate_dow_rejects_out_of_range() { - assert!(translate_posix_dow_field("8").is_err()); - assert!(translate_posix_dow_field("0-8").is_err()); - } - - #[test] - fn translate_dow_rejects_reversed_range() { - assert!(translate_posix_dow_field("5-3").is_err()); - } - - #[test] - fn translate_dow_rejects_empty() { - assert!(translate_posix_dow_field("").is_err()); - assert!(translate_posix_dow_field(",1").is_err()); - assert!(translate_posix_dow_field("1,").is_err()); - } - - #[test] - fn translate_dow_rejects_zero_step() { - assert!(translate_posix_dow_field("*/0").is_err()); - } - - // --- parse_cron_expr rejects wrong number of fields --- - - #[test] - fn parse_rejects_too_few_fields() { - assert!(parse_cron_expr("* * * *").is_err()); - } - - // --- POSIX-semantic Schedule behavior (regression for #784) --- - - #[test] - fn weekday_schedule_does_not_fire_on_sunday() { - use chrono::TimeZone; - // Regression for the reported bug: "0 7 * * 1-5" with timezone Asia/Taipei - // was firing on Sunday 2026-05-10 because the cron crate's `1-5` means - // Sun-Thu without translation. 
- let schedule = parse_cron_expr("0 7 * * 1-5").unwrap(); - let tz: Tz = "Asia/Taipei".parse().unwrap(); - let sunday = tz.with_ymd_and_hms(2026, 5, 10, 7, 0, 0).unwrap(); - let before = sunday - chrono::Duration::seconds(1); - let next = schedule.after(&before).next(); - assert_ne!( - next, - Some(sunday), - "POSIX 1-5 must not fire on Sunday (got next = {:?})", - next - ); - } - - #[test] - fn weekday_schedule_fires_on_monday() { - use chrono::TimeZone; - let schedule = parse_cron_expr("0 7 * * 1-5").unwrap(); - let tz: Tz = "Asia/Taipei".parse().unwrap(); - let monday = tz.with_ymd_and_hms(2026, 5, 11, 7, 0, 0).unwrap(); - let before = monday - chrono::Duration::seconds(1); - let next = schedule.after(&before).next(); - assert_eq!(next, Some(monday), "POSIX 1-5 must fire on Monday"); - } - - #[test] - fn weekday_schedule_fires_on_friday_not_saturday() { - use chrono::TimeZone; - let schedule = parse_cron_expr("0 7 * * 1-5").unwrap(); - let tz: Tz = "Asia/Taipei".parse().unwrap(); - // 2026-05-15 is Friday - let friday = tz.with_ymd_and_hms(2026, 5, 15, 7, 0, 0).unwrap(); - let before = friday - chrono::Duration::seconds(1); - let next = schedule.after(&before).next(); - assert_eq!(next, Some(friday), "POSIX 1-5 must fire on Friday"); - - // 2026-05-16 is Saturday - should not fire - let saturday = tz.with_ymd_and_hms(2026, 5, 16, 7, 0, 0).unwrap(); - let before = saturday - chrono::Duration::seconds(1); - let next = schedule.after(&before).next(); - assert_ne!(next, Some(saturday), "POSIX 1-5 must not fire on Saturday"); - } - - #[test] - fn sunday_schedule_fires_on_sunday_via_zero() { - use chrono::TimeZone; - let schedule = parse_cron_expr("0 7 * * 0").unwrap(); - let tz: Tz = "Asia/Taipei".parse().unwrap(); - let sunday = tz.with_ymd_and_hms(2026, 5, 10, 7, 0, 0).unwrap(); - let before = sunday - chrono::Duration::seconds(1); - let next = schedule.after(&before).next(); - assert_eq!(next, Some(sunday), "POSIX `0` must fire on Sunday"); - } - - #[test] - fn 
sunday_schedule_fires_on_sunday_via_seven() { - use chrono::TimeZone; - let schedule = parse_cron_expr("0 7 * * 7").unwrap(); - let tz: Tz = "Asia/Taipei".parse().unwrap(); - let sunday = tz.with_ymd_and_hms(2026, 5, 10, 7, 0, 0).unwrap(); - let before = sunday - chrono::Duration::seconds(1); - let next = schedule.after(&before).next(); - assert_eq!(next, Some(sunday), "POSIX `7` must also fire on Sunday"); - } - - #[test] - fn saturday_schedule_fires_on_saturday_via_six() { - use chrono::TimeZone; - let schedule = parse_cron_expr("0 7 * * 6").unwrap(); - let tz: Tz = "Asia/Taipei".parse().unwrap(); - // 2026-05-16 is Saturday - let saturday = tz.with_ymd_and_hms(2026, 5, 16, 7, 0, 0).unwrap(); - let before = saturday - chrono::Duration::seconds(1); - let next = schedule.after(&before).next(); - assert_eq!(next, Some(saturday), "POSIX `6` must fire on Saturday"); - } - - #[test] - fn name_based_weekday_still_works() { - use chrono::TimeZone; - // Name-based notation should be unaffected by the translation. 
- let schedule = parse_cron_expr("0 7 * * Mon-Fri").unwrap(); - let tz: Tz = "Asia/Taipei".parse().unwrap(); - let monday = tz.with_ymd_and_hms(2026, 5, 11, 7, 0, 0).unwrap(); - let before = monday - chrono::Duration::seconds(1); - let next = schedule.after(&before).next(); - assert_eq!(next, Some(monday)); - - let sunday = tz.with_ymd_and_hms(2026, 5, 10, 7, 0, 0).unwrap(); - let before = sunday - chrono::Duration::seconds(1); - let next = schedule.after(&before).next(); - assert_ne!(next, Some(sunday)); - } - - #[test] - fn parse_valid_cron_expression() { - let schedule = parse_cron_expr("0 9 * * 1-5").unwrap(); - let next = schedule.upcoming(chrono_tz::UTC).next(); - assert!(next.is_some()); - } - - #[test] - fn parse_every_minute_cron() { - let schedule = parse_cron_expr("* * * * *").unwrap(); - let next = schedule.upcoming(chrono_tz::UTC).next(); - assert!(next.is_some()); - } - - #[test] - fn parse_invalid_cron_expression() { - assert!(parse_cron_expr("not a cron").is_err()); - } - - #[test] - fn parse_invalid_cron_too_many_fields() { - assert!(parse_cron_expr("0 0 9 * * 1-5").is_err()); - } - - #[test] - fn valid_timezone_parses() { - assert!("Asia/Taipei".parse::().is_ok()); - } - - #[test] - fn invalid_timezone_fails() { - assert!("Mars/Olympus".parse::().is_err()); - } - - #[test] - fn utc_timezone_parses() { - assert!("UTC".parse::().is_ok()); - } - - #[test] - fn should_fire_every_minute_returns_true() { - let schedule = parse_cron_expr("* * * * *").unwrap(); - assert!(should_fire(&schedule, chrono_tz::UTC)); - } - - #[test] - fn should_fire_returns_false_for_distant_schedule() { - let schedule = parse_cron_expr("0 0 1 1 *").unwrap(); - let now = chrono::Utc::now(); - if now.month() != 1 || now.day() != 1 || now.hour() != 0 { - assert!(!should_fire(&schedule, chrono_tz::UTC)); - } - } - - #[test] - fn should_fire_respects_timezone() { - let schedule = parse_cron_expr("* * * * *").unwrap(); - let tz: Tz = "Asia/Taipei".parse().unwrap(); - 
assert!(should_fire(&schedule, tz)); - } - - #[test] - fn cronjob_config_defaults() { - let toml_str = r#" -[[jobs]] -schedule = "0 9 * * 1-5" -channel = "123" -message = "hello" -"#; - let cfg: UsercronFile = toml::from_str(toml_str).unwrap(); - let job = &cfg.jobs[0]; - assert_eq!(job.enabled, true); - assert_eq!(job.platform, "discord"); - assert_eq!(job.sender_name, "openab-cron"); - assert_eq!(job.timezone, "UTC"); - assert!(job.thread_id.is_none()); - assert!(job.id.is_none()); - assert!(job.disable_on_success.is_none()); - assert!(job.disable_on_success_match.is_none()); - assert_eq!(job.disable_on_success_timeout_secs, 60); - assert!(job.disable_on_success_working_dir.is_none()); - } - - #[test] - fn cronjob_config_disabled() { - let toml_str = r#" -[[jobs]] -enabled = false -schedule = "0 9 * * 1-5" -channel = "123" -message = "hello" -"#; - let cfg: UsercronFile = toml::from_str(toml_str).unwrap(); - assert_eq!(cfg.jobs[0].enabled, false); - } - - #[test] - fn cronjob_config_custom_values() { - let toml_str = r#" -[[jobs]] -schedule = "0 18 * * 1-5" -channel = "456" -message = "report" -platform = "slack" -sender_name = "DailyOps" -timezone = "Asia/Taipei" -thread_id = "789" -id = "daily-report" -disable_on_success = "npm test" -disable_on_success_match = "SUCCESS" -disable_on_success_timeout_secs = 30 -disable_on_success_working_dir = "/tmp/project" -"#; - let cfg: UsercronFile = toml::from_str(toml_str).unwrap(); - let job = &cfg.jobs[0]; - assert_eq!(job.platform, "slack"); - assert_eq!(job.sender_name, "DailyOps"); - assert_eq!(job.timezone, "Asia/Taipei"); - assert_eq!(job.thread_id.as_deref(), Some("789")); - assert_eq!(job.id.as_deref(), Some("daily-report")); - assert_eq!(job.disable_on_success.as_deref(), Some("npm test")); - assert_eq!(job.disable_on_success_match.as_deref(), Some("SUCCESS")); - assert_eq!(job.disable_on_success_timeout_secs, 30); - assert_eq!( - job.disable_on_success_working_dir.as_deref(), - Some("/tmp/project") - ); - } - - 
#[test] - fn load_usercron_nonexistent_returns_empty() { - let jobs = load_usercron_file(Path::new("/tmp/nonexistent-usercron.toml"), &["discord"]); - assert!(jobs.is_empty()); - } - - #[test] - fn load_usercron_valid_file() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("cronjob.toml"); - std::fs::write( - &path, - r#" -[[jobs]] -schedule = "* * * * *" -channel = "123" -message = "ping" -"#, - ) - .unwrap(); - let jobs = load_usercron_file(&path, &["discord"]); - assert_eq!(jobs.len(), 1); - assert_eq!(jobs[0].message, "ping"); - } - - #[test] - fn load_usercron_invalid_toml_returns_empty() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("cronjob.toml"); - std::fs::write(&path, "not valid toml {{{").unwrap(); - let jobs = load_usercron_file(&path, &["discord"]); - assert!(jobs.is_empty()); - } - - #[test] - fn load_usercron_skips_invalid_entries() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("cronjob.toml"); - std::fs::write( - &path, - r#" -[[jobs]] -schedule = "* * * * *" -channel = "123" -message = "good" - -[[jobs]] -schedule = "bad cron" -channel = "456" -message = "bad" -"#, - ) - .unwrap(); - let jobs = load_usercron_file(&path, &["discord"]); - assert_eq!(jobs.len(), 1); - assert_eq!(jobs[0].message, "good"); - } - - #[test] - fn load_usercron_skips_unconfigured_platform() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("cronjob.toml"); - std::fs::write( - &path, - r#" -[[jobs]] -schedule = "* * * * *" -channel = "123" -message = "discord job" - -[[jobs]] -schedule = "* * * * *" -channel = "456" -message = "slack job" -platform = "slack" -"#, - ) - .unwrap(); - // Only discord configured - let jobs = load_usercron_file(&path, &["discord"]); - assert_eq!(jobs.len(), 1); - assert_eq!(jobs[0].message, "discord job"); - } - - #[test] - fn load_usercron_skips_disable_on_success_without_id() { - let dir = tempfile::tempdir().unwrap(); - let path = 
dir.path().join("cronjob.toml"); - std::fs::write( - &path, - r#" -[[jobs]] -schedule = "* * * * *" -channel = "123" -message = "missing id" -disable_on_success = "echo SUCCESS" -disable_on_success_match = "SUCCESS" -"#, - ) - .unwrap(); - let jobs = load_usercron_file(&path, &["discord"]); - assert!(jobs.is_empty()); - } - - #[test] - fn load_usercron_skips_disable_on_success_without_match() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("cronjob.toml"); - std::fs::write( - &path, - r#" -[[jobs]] -id = "goal" -schedule = "* * * * *" -channel = "123" -message = "missing marker" -disable_on_success = "echo SUCCESS" -"#, - ) - .unwrap(); - let jobs = load_usercron_file(&path, &["discord"]); - assert!(jobs.is_empty()); - } - - #[test] - fn validate_cronjobs_rejects_baseline_disable_on_success() { - let jobs = vec![CronJobConfig { - id: Some("baseline-goal".into()), - enabled: true, - schedule: "* * * * *".into(), - channel: "123".into(), - message: "hi".into(), - platform: "discord".into(), - sender_name: "test".into(), - thread_id: None, - timezone: "UTC".into(), - disable_on_success: Some("echo SUCCESS".into()), - disable_on_success_match: Some("SUCCESS".into()), - disable_on_success_timeout_secs: 60, - disable_on_success_working_dir: None, - }]; - let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); - assert!(err.to_string().contains("only supported in usercron")); - } - - #[test] - fn update_usercron_job_sets_enabled_and_thread_id_by_id() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("cronjob.toml"); - std::fs::write( - &path, - r#" -[[jobs]] -id = "goal-a" -enabled = true -schedule = "* * * * *" -channel = "123" -message = "a" - -[[jobs]] -id = "goal-b" -enabled = true -schedule = "* * * * *" -channel = "456" -message = "b" -"#, - ) - .unwrap(); - - update_usercron_job(&path, "goal-b", Some(false), Some("thread-456")).unwrap(); - - let updated = std::fs::read_to_string(&path).unwrap(); - let doc = 
updated.parse::().unwrap(); - let jobs = doc["jobs"].as_array_of_tables().unwrap(); - let job_a = jobs.iter().next().unwrap(); - let job_b = jobs.iter().nth(1).unwrap(); - assert_eq!(job_a["id"].as_str(), Some("goal-a")); - assert_eq!(job_a["enabled"].as_bool(), Some(true)); - assert!(job_a.get("thread_id").is_none()); - assert_eq!(job_b["id"].as_str(), Some("goal-b")); - assert_eq!(job_b["enabled"].as_bool(), Some(false)); - assert_eq!(job_b["thread_id"].as_str(), Some("thread-456")); - } - - #[test] - fn update_usercron_job_errors_for_missing_id() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("cronjob.toml"); - std::fs::write( - &path, - r#" -[[jobs]] -id = "goal-a" -schedule = "* * * * *" -channel = "123" -message = "a" -"#, - ) - .unwrap(); - let err = update_usercron_job(&path, "missing", Some(false), None).unwrap_err(); - assert!(err.to_string().contains("not found")); - } - - #[tokio::test] - async fn disable_on_success_requires_exit_zero_and_marker() { - let mut job = test_cron_job(); - job.disable_on_success_timeout_secs = 5; - - assert!(matches!( - check_disable_on_success(&job, "printf SUCCESS", "SUCCESS").await, - DisableOnSuccessResult::Achieved - )); - assert!(matches!( - check_disable_on_success(&job, "printf DONE", "SUCCESS").await, - DisableOnSuccessResult::NotAchieved("success marker not found") - )); - assert!(matches!( - check_disable_on_success(&job, "printf SUCCESS; exit 1", "SUCCESS").await, - DisableOnSuccessResult::NotAchieved("command exited non-zero") - )); - } - - #[tokio::test] - async fn disable_on_success_kills_child_on_timeout() { - let mut job = test_cron_job(); - job.disable_on_success_timeout_secs = 1; - - let result = check_disable_on_success(&job, "sleep 999", "SUCCESS").await; - assert!(matches!( - result, - DisableOnSuccessResult::NotAchieved("command timed out") - )); - } - - fn test_cron_job() -> CronJobConfig { - CronJobConfig { - id: Some("goal".into()), - enabled: true, - schedule: "* * * * 
*".into(), - channel: "123".into(), - message: "hi".into(), - platform: "discord".into(), - sender_name: "test".into(), - thread_id: None, - timezone: "UTC".into(), - disable_on_success: Some("echo SUCCESS".into()), - disable_on_success_match: Some("SUCCESS".into()), - disable_on_success_timeout_secs: 60, - disable_on_success_working_dir: None, - } - } - - // --- validate_cronjobs tests --- - - #[test] - fn validate_cronjobs_valid_passes() { - let jobs = vec![CronJobConfig { - id: None, - enabled: true, - schedule: "0 9 * * 1-5".into(), - channel: "123".into(), - message: "hi".into(), - platform: "discord".into(), - sender_name: "test".into(), - thread_id: None, - timezone: "UTC".into(), - disable_on_success: None, - disable_on_success_match: None, - disable_on_success_timeout_secs: 60, - disable_on_success_working_dir: None, - }]; - assert!(validate_cronjobs(&jobs, &["discord"]).is_ok()); - } - - #[test] - fn validate_cronjobs_invalid_cron_fails() { - let jobs = vec![CronJobConfig { - id: None, - enabled: true, - schedule: "bad".into(), - channel: "123".into(), - message: "hi".into(), - platform: "discord".into(), - sender_name: "test".into(), - thread_id: None, - timezone: "UTC".into(), - disable_on_success: None, - disable_on_success_match: None, - disable_on_success_timeout_secs: 60, - disable_on_success_working_dir: None, - }]; - let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); - assert!(err.to_string().contains("invalid cron expression")); - } - - #[test] - fn validate_cronjobs_invalid_timezone_fails() { - let jobs = vec![CronJobConfig { - id: None, - enabled: true, - schedule: "* * * * *".into(), - channel: "123".into(), - message: "hi".into(), - platform: "discord".into(), - sender_name: "test".into(), - thread_id: None, - timezone: "Mars/Olympus".into(), - disable_on_success: None, - disable_on_success_match: None, - disable_on_success_timeout_secs: 60, - disable_on_success_working_dir: None, - }]; - let err = validate_cronjobs(&jobs, 
&["discord"]).unwrap_err(); - assert!(err.to_string().contains("invalid timezone")); - } - - #[test] - fn validate_cronjobs_unknown_platform_fails() { - let jobs = vec![CronJobConfig { - id: None, - enabled: true, - schedule: "* * * * *".into(), - channel: "123".into(), - message: "hi".into(), - platform: "telegram".into(), - sender_name: "test".into(), - thread_id: None, - timezone: "UTC".into(), - disable_on_success: None, - disable_on_success_match: None, - disable_on_success_timeout_secs: 60, - disable_on_success_working_dir: None, - }]; - let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); - assert!(err.to_string().contains("unknown platform")); - } - - #[test] - fn validate_cronjobs_unconfigured_platform_fails() { - let jobs = vec![CronJobConfig { - id: None, - enabled: true, - schedule: "* * * * *".into(), - channel: "123".into(), - message: "hi".into(), - platform: "slack".into(), - sender_name: "test".into(), - thread_id: None, - timezone: "UTC".into(), - disable_on_success: None, - disable_on_success_match: None, - disable_on_success_timeout_secs: 60, - disable_on_success_working_dir: None, - }]; - let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); - assert!(err.to_string().contains("not configured")); - } - - #[test] - fn validate_cronjobs_disabled_with_invalid_cron_passes() { - let jobs = vec![CronJobConfig { - id: None, - enabled: false, - schedule: "bad".into(), - channel: "123".into(), - message: "hi".into(), - platform: "discord".into(), - sender_name: "test".into(), - thread_id: None, - timezone: "UTC".into(), - disable_on_success: None, - disable_on_success_match: None, - disable_on_success_timeout_secs: 60, - disable_on_success_working_dir: None, - }]; - assert!(validate_cronjobs(&jobs, &["discord"]).is_ok()); - } - - #[test] - fn validate_cronjobs_enabled_with_invalid_cron_still_fails() { - let jobs = vec![CronJobConfig { - id: None, - enabled: true, - schedule: "bad".into(), - channel: "123".into(), - message: 
"hi".into(), - platform: "discord".into(), - sender_name: "test".into(), - thread_id: None, - timezone: "UTC".into(), - disable_on_success: None, - disable_on_success_match: None, - disable_on_success_timeout_secs: 60, - disable_on_success_working_dir: None, - }]; - assert!(validate_cronjobs(&jobs, &["discord"]).is_err()); - } - - // --- file_mtime tests --- - - #[test] - fn file_mtime_detects_change() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("test.toml"); - assert!(file_mtime(&path).is_none()); // doesn't exist yet - std::fs::write(&path, "v1").unwrap(); - let m1 = file_mtime(&path); - assert!(m1.is_some()); - // Sleep briefly to ensure mtime differs - std::thread::sleep(std::time::Duration::from_millis(50)); - std::fs::write(&path, "v2").unwrap(); - let m2 = file_mtime(&path); - assert!(m2.is_some()); - assert!(m2 != m1); - } - - // --- CronConfig TOML deserialization --- - - #[test] - fn cron_config_toml_parses() { - use crate::config::Config; - let toml_str = r#" -[agent] -command = "echo" - -[cron] -usercron_enabled = true -usercron_path = "cronjob.toml" - -[[cron.jobs]] -schedule = "0 9 * * 1-5" -channel = "123" -message = "hello" - -[[cron.jobs]] -schedule = "*/30 * * * *" -channel = "456" -message = "ping" -platform = "slack" -"#; - let cfg: Config = toml::from_str(toml_str).unwrap(); - assert!(cfg.cron.usercron_enabled); - assert_eq!(cfg.cron.usercron_path.as_deref(), Some("cronjob.toml")); - assert_eq!(cfg.cron.jobs.len(), 2); - assert_eq!(cfg.cron.jobs[0].message, "hello"); - assert_eq!(cfg.cron.jobs[1].platform, "slack"); - } - - #[test] - fn cron_config_defaults_when_omitted() { - use crate::config::Config; - let toml_str = r#" -[agent] -command = "echo" -"#; - let cfg: Config = toml::from_str(toml_str).unwrap(); - assert!(!cfg.cron.usercron_enabled); - assert!(cfg.cron.usercron_path.is_none()); - assert!(cfg.cron.jobs.is_empty()); - } - - // --- load_usercron empty file --- - - #[test] - fn 
load_usercron_empty_file_returns_empty() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("cronjob.toml"); - std::fs::write(&path, "").unwrap(); - let jobs = load_usercron_file(&path, &["discord"]); - assert!(jobs.is_empty()); - } -} diff --git a/src/directives.rs b/src/directives.rs deleted file mode 100644 index c1d5c27f2..000000000 --- a/src/directives.rs +++ /dev/null @@ -1,314 +0,0 @@ -//! Control Directives parser (ADR: control-directives.md). -//! -//! Extracts leading `[[key:value]]` directives from the first message in a -//! session, strips them from the prompt, and returns structured metadata. - -use regex::Regex; -use std::collections::HashMap; -use std::path::{Path, PathBuf}; -use std::sync::LazyLock; -use tracing::warn; - -static DIRECTIVE_RE: LazyLock = - LazyLock::new(|| Regex::new(r"^\s*\[\[([a-z_]+):([^\]]*)\]\]").unwrap()); - -/// Parsed control directives from a session's first message. -#[derive(Debug, Clone, Default)] -pub struct SessionMetadata { - /// Resolved canonical workspace path (None = use default working_dir). - #[allow(dead_code)] - pub workspace: Option, - /// Thread title override (None = use generated title). - pub title: Option, - /// Raw directives map for forward-compatible unknown keys. - pub raw: HashMap, -} - -/// Result of parsing directives from a prompt. -pub struct ParseResult { - /// The prompt with leading directives stripped. - pub prompt: String, - /// Parsed session metadata. - pub metadata: SessionMetadata, -} - -/// Parse leading `[[key:value]]` directives from a prompt string. -/// -/// Directives must appear at the start of the message (after optional -/// whitespace). The first line/token that is not a directive stops parsing; -/// any `[[key:value]]` text after that point is preserved verbatim. 
-pub fn parse_directives(input: &str) -> ParseResult { - let mut raw: HashMap = HashMap::new(); - let mut remaining = input; - - loop { - remaining = remaining.trim_start_matches([' ', '\t']); - if remaining.starts_with('\n') || remaining.starts_with("\r\n") { - // A blank line after directives = end of header - let next = remaining.trim_start_matches(['\r', '\n']); - let next_trimmed = next.trim_start_matches([' ', '\t']); - if !next_trimmed.starts_with("[[") { - remaining = next; - break; - } - remaining = remaining.trim_start_matches(['\r', '\n']); - } - if let Some(caps) = DIRECTIVE_RE.captures(remaining) { - let full_match = caps.get(0).unwrap(); - let key = caps[1].to_string(); - let value = caps[2].to_string(); - // Last value wins for duplicate keys - raw.insert(key, value); - remaining = &remaining[full_match.end()..]; - } else { - break; - } - } - - let prompt = remaining.trim().to_string(); - let metadata = SessionMetadata { - workspace: None, // resolved later by resolve_workspace - title: raw.get("title").cloned(), - raw, - }; - - ParseResult { prompt, metadata } -} - -/// Resolve the `[[ws:...]]` directive value into a canonical path. -/// -/// Supports: -/// - Raw paths: `~/projects/foo` or `/home/bot/projects/foo` -/// - Aliases: `@alias_name` → looked up in `aliases` map -/// -/// Returns `Err` with a user-visible message on failure. -pub fn resolve_workspace( - raw_value: &str, - aliases: &HashMap, - bot_home: &Path, -) -> Result { - let path_str = if let Some(alias) = raw_value.strip_prefix('@') { - match aliases.get(alias) { - Some(resolved) => resolved.as_str(), - None => { - let available: Vec<&str> = aliases.keys().map(|s| s.as_str()).collect(); - return Err(format!( - "Unknown workspace alias `@{alias}`. 
Available: {}", - if available.is_empty() { - "(none configured)".to_string() - } else { - available.join(", ") - } - )); - } - } - } else { - raw_value - }; - - // Rule 1: reject relative paths - if !path_str.starts_with('~') && !path_str.starts_with('/') { - return Err(format!( - "Workspace path must be absolute (start with `~` or `/`): `{path_str}`" - )); - } - - // Rule 2: expand ~ - let expanded = if let Some(rest) = path_str.strip_prefix('~') { - let rest = rest.strip_prefix('/').unwrap_or(rest); - bot_home.join(rest) - } else { - PathBuf::from(path_str) - }; - - // Rule 3: canonicalize both paths - let canonical_home = bot_home.canonicalize().map_err(|e| { - warn!(path = %bot_home.display(), error = %e, "cannot canonicalize bot home"); - "Internal error: cannot resolve bot home directory".to_string() - })?; - - let canonical_target = expanded.canonicalize().map_err(|e| { - warn!(path = %expanded.display(), error = %e, "cannot canonicalize workspace path"); - format!( - "Workspace path does not exist: `{path_str}` (expanded to `{}`)", - expanded.display() - ) - })?; - - // Rule 4+5: verify within bot home subtree - if !canonical_target.starts_with(&canonical_home) { - return Err(format!( - "Workspace path is outside allowed directory: `{path_str}`" - )); - } - - // Rule 6: must be a directory (not a file) - if !canonical_target.is_dir() { - return Err(format!( - "Workspace path is not a directory: `{}`", - canonical_target.display() - )); - } - - Ok(canonical_target) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::fs; - use tempfile::TempDir; - - #[test] - fn parse_basic_directives() { - let input = "[[ws:~/projects/foo]] [[title:Bug fix]]\ninvestigate the build failure"; - let result = parse_directives(input); - assert_eq!(result.prompt, "investigate the build failure"); - assert_eq!(result.metadata.raw.get("ws").unwrap(), "~/projects/foo"); - assert_eq!(result.metadata.title.as_deref(), Some("Bug fix")); - } - - #[test] - fn 
parse_directives_multiline_header() { - let input = "[[ws:@openab]]\n[[title:Review PR]]\nplease review this change"; - let result = parse_directives(input); - assert_eq!(result.prompt, "please review this change"); - assert_eq!(result.metadata.raw.get("ws").unwrap(), "@openab"); - assert_eq!(result.metadata.title.as_deref(), Some("Review PR")); - } - - #[test] - fn parse_preserves_body_directives() { - let input = "[[title:Test]]\nHere is some code with [[key:value]] in it"; - let result = parse_directives(input); - assert_eq!(result.prompt, "Here is some code with [[key:value]] in it"); - assert_eq!(result.metadata.title.as_deref(), Some("Test")); - assert!(!result.metadata.raw.contains_key("key")); - } - - #[test] - fn parse_no_directives() { - let input = "just a regular message"; - let result = parse_directives(input); - assert_eq!(result.prompt, "just a regular message"); - assert!(result.metadata.raw.is_empty()); - } - - #[test] - fn parse_duplicate_keys_last_wins() { - let input = "[[title:First]] [[title:Second]]\ndo stuff"; - let result = parse_directives(input); - assert_eq!(result.metadata.title.as_deref(), Some("Second")); - } - - #[test] - fn parse_empty_value() { - let input = "[[title:]]\ndo stuff"; - let result = parse_directives(input); - assert_eq!(result.metadata.title.as_deref(), Some("")); - } - - #[test] - fn parse_unknown_keys_ignored() { - let input = "[[foo:bar]] [[ws:~/x]]\ndo stuff"; - let result = parse_directives(input); - assert_eq!(result.metadata.raw.get("foo").unwrap(), "bar"); - assert_eq!(result.prompt, "do stuff"); - } - - #[test] - fn resolve_alias_success() { - let tmp = TempDir::new().unwrap(); - let projects = tmp.path().join("projects").join("openab"); - fs::create_dir_all(&projects).unwrap(); - - let mut aliases = HashMap::new(); - aliases.insert( - "openab".to_string(), - format!("{}/projects/openab", tmp.path().display()), - ); - - let result = resolve_workspace("@openab", &aliases, tmp.path()).unwrap(); - 
assert_eq!(result, projects.canonicalize().unwrap()); - } - - #[test] - fn resolve_alias_not_found() { - let tmp = TempDir::new().unwrap(); - let aliases = HashMap::new(); - let result = resolve_workspace("@nope", &aliases, tmp.path()); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("Unknown workspace alias")); - } - - #[test] - fn resolve_relative_path_rejected() { - let tmp = TempDir::new().unwrap(); - let aliases = HashMap::new(); - let result = resolve_workspace("relative/path", &aliases, tmp.path()); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("must be absolute")); - } - - #[test] - fn resolve_outside_home_rejected() { - let tmp = TempDir::new().unwrap(); - let aliases = HashMap::new(); - let result = resolve_workspace("/tmp", &aliases, tmp.path()); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("outside allowed directory")); - } - - #[test] - fn resolve_tilde_expansion() { - let tmp = TempDir::new().unwrap(); - let projects = tmp.path().join("myapp"); - fs::create_dir_all(&projects).unwrap(); - - let aliases = HashMap::new(); - let result = resolve_workspace("~/myapp", &aliases, tmp.path()).unwrap(); - assert_eq!(result, projects.canonicalize().unwrap()); - } - - #[test] - fn resolve_nonexistent_path() { - let tmp = TempDir::new().unwrap(); - let aliases = HashMap::new(); - let result = resolve_workspace("~/does_not_exist", &aliases, tmp.path()); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("does not exist")); - } - - #[test] - fn parse_directives_leading_spaces_on_newline() { - let input = "[[ws:@openab]]\n [[title:Fix CI]]\nhelp me debug"; - let result = parse_directives(input); - assert_eq!(result.prompt, "help me debug"); - assert_eq!(result.metadata.raw.get("ws").unwrap(), "@openab"); - assert_eq!(result.metadata.title.as_deref(), Some("Fix CI")); - } - - #[test] - fn resolve_file_path_rejected() { - let tmp = TempDir::new().unwrap(); - let file_path = 
tmp.path().join("Cargo.toml"); - fs::write(&file_path, "").unwrap(); - - let aliases = HashMap::new(); - let result = resolve_workspace(&format!("{}", file_path.display()), &aliases, tmp.path()); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("not a directory")); - } - - #[test] - fn resolve_error_shows_expanded_path() { - let tmp = TempDir::new().unwrap(); - let aliases = HashMap::new(); - let result = resolve_workspace("~/no_such_dir", &aliases, tmp.path()); - assert!(result.is_err()); - let err = result.unwrap_err(); - // Error should contain both the original and expanded path - assert!(err.contains("~/no_such_dir")); - assert!(err.contains(&tmp.path().display().to_string())); - } -} diff --git a/src/discord.rs b/src/discord.rs deleted file mode 100644 index 12281afad..000000000 --- a/src/discord.rs +++ /dev/null @@ -1,3203 +0,0 @@ -use crate::acp::protocol::ConfigOption; -use crate::acp::ContentBlock; -use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef, SenderContext}; -use crate::bot_turns::{BotTurnTracker, TurnAction, TurnSeverity, BOT_TURN_LIMIT_WARNING_PREFIX}; -use crate::config::{AllowBots, AllowUsers, SttConfig}; -use crate::format; -use crate::media; -use crate::remind::{self, ReminderStore}; -use async_trait::async_trait; -use serenity::builder::{ - CreateActionRow, CreateAttachment, CreateButton, CreateCommand, CreateCommandOption, - CreateInteractionResponse, CreateInteractionResponseFollowup, CreateInteractionResponseMessage, - CreateSelectMenu, CreateSelectMenuKind, CreateSelectMenuOption, CreateThread, EditChannel, - EditMessage, GetMessages, -}; -use serenity::http::Http; -use serenity::model::application::ButtonStyle; -use serenity::model::application::{Command, CommandOptionType, ComponentInteractionDataKind, Interaction}; -use serenity::model::channel::{AutoArchiveDuration, Message, MessageType, ReactionType}; -use serenity::model::gateway::Ready; -use serenity::model::id::{ChannelId, MessageId, 
UserId}; -use serenity::prelude::*; -use std::collections::{HashMap, HashSet}; -use std::sync::LazyLock; -use std::sync::{Arc, OnceLock}; -use tracing::{debug, error, info, warn}; - -/// Hard cap on consecutive bot messages in a channel or thread. -/// Prevents runaway loops between multiple bots in "all" mode. -const MAX_CONSECUTIVE_BOT_TURNS: u32 = 1000; - -/// Maximum entries in the participation cache before eviction. -const PARTICIPATION_CACHE_MAX: usize = 1000; - -/// Discord StringSelectMenu hard limit on options. -const SELECT_MENU_PAGE_SIZE: usize = 25; - -/// Avoid unbounded Discord history exports from very large threads. -const THREAD_EXPORT_MESSAGE_LIMIT: usize = 5000; - -// --- DiscordAdapter: implements ChatAdapter for Discord via serenity --- - -pub struct DiscordAdapter { - http: Arc, -} - -impl DiscordAdapter { - pub fn new(http: Arc) -> Self { - Self { http } - } - - /// Resolve the effective Discord channel ID from a ChannelRef. - /// Discord threads are channels, so prefer thread_id when set. 
- fn resolve_channel(channel: &ChannelRef) -> &str { - channel.thread_id.as_deref().unwrap_or(&channel.channel_id) - } -} - -#[async_trait] -impl ChatAdapter for DiscordAdapter { - fn platform(&self) -> &'static str { - "discord" - } - - fn message_limit(&self) -> usize { - 2000 - } - - async fn send_message( - &self, - channel: &ChannelRef, - content: &str, - ) -> anyhow::Result { - let ch_id: u64 = Self::resolve_channel(channel).parse()?; - let msg = ChannelId::new(ch_id).say(&self.http, content).await?; - Ok(MessageRef { - channel: channel.clone(), - message_id: msg.id.to_string(), - }) - } - - async fn send_message_with_reply( - &self, - channel: &ChannelRef, - content: &str, - reply_to_message_id: &str, - ) -> anyhow::Result { - let ch_id: u64 = Self::resolve_channel(channel).parse()?; - let msg_id: u64 = reply_to_message_id.parse().unwrap_or(0); - if msg_id == 0 { - // Invalid message ID, fall back to plain send - return self.send_message(channel, content).await; - } - let builder = serenity::builder::CreateMessage::new() - .content(content) - .reference_message((ChannelId::new(ch_id), MessageId::new(msg_id))); - match ChannelId::new(ch_id) - .send_message(&self.http, builder) - .await - { - Ok(msg) => Ok(MessageRef { - channel: channel.clone(), - message_id: msg.id.to_string(), - }), - Err(e) => { - // Fallback to plain send if reply fails (e.g. 
unknown message, cross-channel) - tracing::warn!(error = ?e, reply_to = reply_to_message_id, "reply_to failed, falling back to plain send"); - self.send_message(channel, content).await - } - } - } - - async fn delete_message(&self, msg: &MessageRef) -> anyhow::Result<()> { - let ch_id: u64 = Self::resolve_channel(&msg.channel).parse()?; - let msg_id: u64 = msg.message_id.parse()?; - self.http - .delete_message(ChannelId::new(ch_id), MessageId::new(msg_id), None) - .await?; - Ok(()) - } - - async fn edit_message(&self, msg: &MessageRef, content: &str) -> anyhow::Result<()> { - let ch_id: u64 = Self::resolve_channel(&msg.channel).parse()?; - let msg_id: u64 = msg.message_id.parse()?; - ChannelId::new(ch_id) - .edit_message( - &self.http, - MessageId::new(msg_id), - EditMessage::new().content(content), - ) - .await?; - Ok(()) - } - - fn use_streaming(&self, other_bot_present: bool) -> bool { - !other_bot_present - } - - async fn create_thread( - &self, - channel: &ChannelRef, - trigger_msg: &MessageRef, - title: &str, - ) -> anyhow::Result { - let ch_id: u64 = channel.channel_id.parse()?; - let msg_id: u64 = trigger_msg.message_id.parse()?; - let thread = ChannelId::new(ch_id) - .create_thread_from_message( - &self.http, - MessageId::new(msg_id), - CreateThread::new(title).auto_archive_duration(AutoArchiveDuration::OneDay), - ) - .await?; - Ok(ChannelRef { - platform: "discord".into(), - channel_id: thread.id.to_string(), - thread_id: None, - parent_id: Some(channel.channel_id.clone()), - origin_event_id: None, - }) - } - - async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> anyhow::Result<()> { - let ch_id: u64 = Self::resolve_channel(&msg.channel).parse()?; - let msg_id: u64 = msg.message_id.parse()?; - self.http - .create_reaction( - ChannelId::new(ch_id), - MessageId::new(msg_id), - &ReactionType::Unicode(emoji.to_string()), - ) - .await?; - Ok(()) - } - - async fn remove_reaction(&self, msg: &MessageRef, emoji: &str) -> anyhow::Result<()> { - let 
ch_id: u64 = Self::resolve_channel(&msg.channel).parse()?; - let msg_id: u64 = msg.message_id.parse()?; - self.http - .delete_reaction_me( - ChannelId::new(ch_id), - MessageId::new(msg_id), - &ReactionType::Unicode(emoji.to_string()), - ) - .await?; - Ok(()) - } - - async fn rename_thread(&self, channel: &ChannelRef, title: &str) -> anyhow::Result<()> { - let ch_id: u64 = Self::resolve_channel(channel).parse()?; - // Truncate at char boundary to avoid panic on multi-byte chars (中文/Emoji). - let truncated: &str = if title.chars().count() > 100 { - let end = title.char_indices().nth(100).map(|(i, _)| i).unwrap_or(title.len()); - &title[..end] - } else { - title - }; - ChannelId::new(ch_id) - .edit(&self.http, EditChannel::new().name(truncated)) - .await?; - Ok(()) - } -} - -// --- Handler: serenity EventHandler that delegates to AdapterRouter --- - -pub struct Handler { - pub router: Arc, - pub allow_all_channels: bool, - pub allow_all_users: bool, - pub allowed_channels: HashSet, - pub allowed_users: HashSet, - pub stt_config: SttConfig, - pub adapter: OnceLock>, - pub allow_bot_messages: AllowBots, - pub trusted_bot_ids: HashSet, - pub allow_user_messages: AllowUsers, - /// Role IDs that trigger the bot (same as direct @mention). - pub allowed_role_ids: HashSet, - /// Positive-only cache: thread channel_id → cached_at for threads where bot has participated. - pub participated_threads: tokio::sync::Mutex>, - /// Positive-only cache: thread channel_id → cached_at for threads where other bots have posted. - /// Like participation, a thread becoming multi-bot is irreversible (bot messages don't disappear). - pub multibot_threads: tokio::sync::Mutex>, - /// Persistent disk cache for multibot thread detection (survives restarts). - pub multibot_cache: crate::multibot_cache::MultibotCache, - /// TTL for participation cache entries (from pool.session_ttl_hours). 
- pub session_ttl: std::time::Duration, - /// Configurable soft limit on bot turns per thread (reset by human message). - pub max_bot_turns: u32, - /// Per-thread bot turn tracker. Both counters reset on human msg. - pub bot_turns: tokio::sync::Mutex, - /// Allow the bot to respond to Discord DMs. - pub allow_dm: bool, - /// Per-thread dispatcher (Message mode uses cap=1 for FIFO; Thread/Lane use configured cap). - pub dispatcher: Arc, - /// Reminder store for /remind slash command. - pub reminder_store: ReminderStore, - /// Track scheduled reminder IDs to prevent duplicate scheduling on reconnect. - pub scheduled_ids: tokio::sync::Mutex>, -} - -impl Handler { - /// Check if the bot has participated in a Discord thread, and whether - /// other bots have also posted in it. - /// Returns `(involved, other_bot_present)`. - /// Fail-closed: returns `(false, false)` on API error. - /// Caches positive results only (both participation and multi-bot status are irreversible). - async fn bot_participated_in_thread( - &self, - http: &Http, - channel_id: ChannelId, - bot_id: UserId, - ) -> (bool, bool) { - let key = channel_id.to_string(); - - // Check positive caches - let cached_involved = { - let cache = self.participated_threads.lock().await; - cache - .get(&key) - .is_some_and(|ts| ts.elapsed() < self.session_ttl) - }; - let cached_multibot = { - let cache = self.multibot_threads.lock().await; - cache - .get(&key) - .is_some_and(|ts| ts.elapsed() < self.session_ttl) - } || self.multibot_cache.is_multibot(&key); - - // Both cached → skip fetch entirely - // With early detection from msg.author, multibot_threads is populated - // eagerly — no need to fetch just to check for other bots. 
- if cached_involved { - return (true, cached_multibot); - } - - // Fetch recent messages - let messages = match channel_id - .messages(http, serenity::builder::GetMessages::new().limit(200)) - .await - { - Ok(msgs) => msgs, - Err(e) => { - tracing::warn!( - channel_id = %channel_id, - error = %e, - "failed to fetch thread messages for participation check, rejecting (fail-closed)" - ); - return (false, false); - } - }; - - let involved = cached_involved || messages.iter().any(|m| m.author.id == bot_id); - // other_bot_present relies solely on early detection + disk cache; - // no longer scanned from fetched messages (200-msg window was unreliable). - let other_bot_present = cached_multibot; - - if involved && !cached_involved { - let mut cache = self.participated_threads.lock().await; - cache.insert(key.clone(), tokio::time::Instant::now()); - - // Evict if over capacity - if cache.len() > PARTICIPATION_CACHE_MAX { - cache.retain(|_, ts| ts.elapsed() < self.session_ttl); - if cache.len() > PARTICIPATION_CACHE_MAX { - let mut entries: Vec<_> = cache.iter().map(|(k, v)| (k.clone(), *v)).collect(); - entries.sort_by_key(|(_, ts)| *ts); - let evict_count = entries.len() / 2; - for (k, _) in entries.into_iter().take(evict_count) { - cache.remove(&k); - } - } - } - } - - (involved, other_bot_present) - } -} - -#[serenity::async_trait] -impl EventHandler for Handler { - async fn message(&self, ctx: Context, msg: Message) { - let bot_id = ctx.cache.current_user().id; - - // Early multibot detection: cache that another bot is present. - // Runs before self-check and bot gating so we always detect other bots. 
(#481) - if msg.author.bot && msg.author.id != bot_id { - let key = msg.channel_id.to_string(); - { - let mut cache = self.multibot_threads.lock().await; - cache.entry(key.clone()).or_insert_with(tokio::time::Instant::now); - } - // Persist to disk — multibot is irreversible - self.multibot_cache.mark_multibot(&key).await; - } - - // Bot turn counting: runs before self-check so ALL bot messages - // (including own) count toward the per-thread limit. This means - // soft_limit=20 = 20 total bot messages in the thread (~10 per bot - // in a two-bot ping-pong). (#483) - { - let thread_key = msg.channel_id.to_string(); - let mut tracker = self.bot_turns.lock().await; - if msg.author.bot { - match tracker.classify_bot_message(&thread_key) { - TurnAction::Continue => {} - TurnAction::SilentStop => return, - TurnAction::WarnAndStop { - severity, - turns, - user_message, - } => { - match severity { - TurnSeverity::Hard => tracing::warn!( - channel_id = %msg.channel_id, - turns, - "hard bot turn limit reached", - ), - TurnSeverity::Soft => tracing::info!( - channel_id = %msg.channel_id, - turns, - max = self.max_bot_turns, - "soft bot turn limit reached", - ), - } - // Only post the warning if this bot is allowed in the channel/thread. - // Bot turn counting intentionally runs before channel gating so ALL - // bot messages are counted, but the *warning message* must respect - // channel permissions — otherwise bots that never participated in a - // thread will spam it with warnings. - // - // Must match the full thread allowlist semantics: a thread is allowed - // if its own channel_id OR its parent_id is in allowed_channels. - let ch = msg.channel_id.get(); - let in_allowed_channel = self.allowed_channels.contains(&ch); - let mut allowed_here = self.allow_all_channels || in_allowed_channel; - if !allowed_here { - // Reuse detect_thread() for thread allowlist semantics. - // Only called on the WarnAndStop path (once per soft/hard - // limit hit), not on every bot message. 
- if let Ok(serenity::model::channel::Channel::Guild(gc)) = - msg.channel_id.to_channel(&ctx.http).await - { - let (in_thread, _) = detect_thread( - gc.thread_metadata.is_some(), - gc.parent_id.map(|id| id.get()), - gc.owner_id.map(|id| id.get()), - bot_id.get(), - &self.allowed_channels, - self.allow_all_channels, - in_allowed_channel, - ); - if in_thread { - allowed_here = true; - } - } - } - if msg.author.id != bot_id && allowed_here { - // Only warn if this bot actually participated in the - // thread — prevents uninvolved bots from spamming - // warnings in shared channels. (#727) - // Second value is `is_multibot`; not needed here. - let (participated, _) = self - .bot_participated_in_thread(&ctx.http, msg.channel_id, bot_id) - .await; - if participated { - // Dedup: skip if another bot already posted the same - // warning in this thread. Prevents N duplicate warnings - // when N bot processes each hit the soft limit. (#530) - let recent = msg - .channel_id - .messages( - &ctx.http, - serenity::builder::GetMessages::new().limit(10), - ) - .await - .unwrap_or_default(); - let pairs: Vec<(bool, &str)> = recent - .iter() - .map(|m| (m.author.bot, m.content.as_str())) - .collect(); - let already_warned = turn_limit_warning_present(&pairs); - if !already_warned { - let _ = msg.channel_id.say(&ctx.http, &user_message).await; - } - } - } - return; - } - } - } else if matches!(msg.kind, MessageType::Regular | MessageType::InlineReply) - && !msg.content.is_empty() - { - tracker.on_human_message(&thread_key); - } - } - - // Ignore own messages (after counting toward bot turns above) - if msg.author.id == bot_id { - return; - } - - let adapter = self - .adapter - .get_or_init(|| Arc::new(DiscordAdapter::new(ctx.http.clone()))) - .clone(); - - let channel_id = msg.channel_id.get(); - let in_allowed_channel = - self.allow_all_channels || self.allowed_channels.contains(&channel_id); - - let is_mentioned = msg.mentions_user_id(bot_id) - || 
msg.content.contains(&format!("<@{}>", bot_id)) - || (!self.allowed_role_ids.is_empty() - && msg - .mention_roles - .iter() - .any(|r| self.allowed_role_ids.contains(&r.get()))); - - // Bot message gating (from upstream #321) - if msg.author.bot { - // Trusted bot admission override: when a bot listed in `trusted_bot_ids` - // explicitly @mentions this bot, bypass the entire `allow_bot_messages` - // mode check. This treats the trusted bot's @mention identically to a - // human @mention — the bot becomes involved in the thread and the message - // is dispatched regardless of the `allow_bot_messages` setting. - // - // Rationale: `trusted_bot_ids` expresses admin-level trust. A trusted bot - // that @mentions this bot is performing a deliberate handoff/coordination - // action, equivalent to a human pulling the bot into a conversation. - // - // Safety: requires both (1) explicit @mention AND (2) sender in - // trusted_bot_ids. Messages from trusted bots without @mention still - // follow normal gating. Empty trusted_bot_ids (default) disables this - // entirely — no behavioral change for existing deployments. 
- let trusted_mention = is_mentioned - && !self.trusted_bot_ids.is_empty() - && self.trusted_bot_ids.contains(&msg.author.id.get()); - - if !trusted_mention { - match self.allow_bot_messages { - AllowBots::Off => return, - AllowBots::Mentions => { - if !is_mentioned { - return; - } - } - AllowBots::All => { - let cap = MAX_CONSECUTIVE_BOT_TURNS as usize; - let limit = std::cmp::min(MAX_CONSECUTIVE_BOT_TURNS, 100) as u8; - let history = ctx - .cache - .channel_messages(msg.channel_id) - .map(|msgs| { - let mut recent: Vec<_> = msgs - .iter() - .filter(|(mid, _)| **mid < msg.id) - .map(|(_, m)| m.clone()) - .collect(); - recent.sort_unstable_by_key(|m| std::cmp::Reverse(m.id)); - recent.truncate(cap); - recent - }) - .filter(|msgs| !msgs.is_empty()); - - let recent = if let Some(cached) = history { - cached - } else { - match msg - .channel_id - .messages( - &ctx.http, - serenity::builder::GetMessages::new() - .before(msg.id) - .limit(limit), - ) - .await - { - Ok(msgs) => msgs, - Err(e) => { - tracing::warn!(channel_id = %msg.channel_id, error = %e, "failed to fetch history for bot turn cap, rejecting (fail-closed)"); - return; - } - } - }; - - let consecutive_bot = recent - .iter() - .take_while(|m| m.author.bot && m.author.id != bot_id) - .count(); - if consecutive_bot >= cap { - tracing::warn!(channel_id = %msg.channel_id, cap, "bot turn cap reached, ignoring"); - return; - } - } - } - - if !self.trusted_bot_ids.is_empty() - && !self.trusted_bot_ids.contains(&msg.author.id.get()) - { - tracing::debug!(bot_id = %msg.author.id, "bot not in trusted_bot_ids, ignoring"); - return; - } - } - } - - // Thread detection: single to_channel() call for both allowed and - // non-allowed channels. Uses thread_metadata (not parent_id) to - // identify threads — see detect_thread() doc comments for rationale. 
- let (in_thread, bot_owns_thread, thread_parent_id, is_dm) = match msg - .channel_id - .to_channel(&ctx.http) - .await - { - Ok(serenity::model::channel::Channel::Guild(gc)) => { - let parent = gc.parent_id.map(|id| id.get().to_string()); - let result = detect_thread( - gc.thread_metadata.is_some(), - gc.parent_id.map(|id| id.get()), - gc.owner_id.map(|id| id.get()), - bot_id.get(), - &self.allowed_channels, - self.allow_all_channels, - in_allowed_channel, - ); - tracing::debug!( - channel_id = %msg.channel_id, - parent_id = ?gc.parent_id, - owner_id = ?gc.owner_id, - has_thread_metadata = gc.thread_metadata.is_some(), - in_thread = result.0, - bot_owns = ?result.1, - "thread check" - ); - ( - result.0, - result.1.unwrap_or(false), - if result.0 { parent } else { None }, - false, - ) - } - Ok(serenity::model::channel::Channel::Private(_)) => { - tracing::debug!(channel_id = %msg.channel_id, "DM channel"); - (false, false, None, true) - } - Ok(other) => { - tracing::debug!(channel_id = %msg.channel_id, kind = ?other, "not a guild thread"); - (false, false, None, false) - } - Err(e) => { - tracing::debug!(channel_id = %msg.channel_id, error = %e, "to_channel failed"); - (false, false, None, false) - } - }; - - // DM gating: allow_dm must be true, otherwise reject - if is_dm && !self.allow_dm { - tracing::debug!(channel_id = %msg.channel_id, "DM rejected (allow_dm=false)"); - return; - } - - if !is_dm && !in_allowed_channel && !in_thread { - return; - } - - // User message gating (mirrors Slack's AllowUsers logic). - // Mentions: always require @mention, even in bot's own threads. - // Involved (default): skip @mention if the bot owns the thread - // (Option A) OR has previously posted in it (Option B). - // MultibotMentions: same as Involved, but if other bots are also - // in the thread, require @mention to avoid all bots responding. - // DMs are treated as implicit @mention (mirrors Slack behavior). 
- if !is_mentioned && !is_dm { - match self.allow_user_messages { - AllowUsers::Mentions => return, - AllowUsers::Involved => { - if !in_thread { - return; - } - let (involved, _) = if bot_owns_thread { - (true, false) // other_bot_present not needed for Involved mode - } else { - self.bot_participated_in_thread(&ctx.http, msg.channel_id, bot_id) - .await - }; - if !involved { - tracing::debug!(channel_id = %msg.channel_id, "bot not involved in thread, ignoring"); - return; - } - } - AllowUsers::MultibotMentions => { - if !in_thread { - return; - } - let (involved, other_bot) = if bot_owns_thread { - // Still need to check for other bots - let (_, other) = self - .bot_participated_in_thread(&ctx.http, msg.channel_id, bot_id) - .await; - (true, other) - } else { - self.bot_participated_in_thread(&ctx.http, msg.channel_id, bot_id) - .await - }; - if !involved { - tracing::debug!(channel_id = %msg.channel_id, "bot not involved in thread, ignoring"); - return; - } - if other_bot { - tracing::debug!(channel_id = %msg.channel_id, "multi-bot thread, requiring @mention"); - return; - } - } - } - } - - if is_denied_user( - msg.author.bot, - self.allow_all_users, - &self.allowed_users, - msg.author.id.get(), - ) { - tracing::info!(user_id = %msg.author.id, "denied user, ignoring"); - let msg_ref = discord_msg_ref(&msg); - let _ = adapter.add_reaction(&msg_ref, "🚫").await; - return; - } - - let prompt = resolve_mentions(&msg.content, bot_id, &self.allowed_role_ids); - - // No text and no attachments → skip - if prompt.is_empty() && msg.attachments.is_empty() { - return; - } - - let display_name = msg - .member - .as_ref() - .and_then(|m| m.nick.as_ref()) - .or(msg.author.global_name.as_ref()) - .unwrap_or(&msg.author.name); - let sender = build_sender_context( - &msg.author.id.to_string(), - &msg.author.name, - display_name, - &msg.channel_id.to_string(), - thread_parent_id.as_deref(), - msg.author.bot, - &msg.timestamp.to_rfc3339().unwrap_or_default(), - &msg.id.to_string(), 
- &bot_id.to_string(), - ); - - // Build extra content blocks from attachments (audio -> STT, text -> inline, - // image -> encode, video -> URL for agent-side inspection). - let mut extra_blocks = Vec::new(); - let mut echo_entries: Vec = Vec::new(); - let mut failed_image_files: Vec = Vec::new(); - let mut text_file_bytes: u64 = 0; - let mut text_file_count: u32 = 0; - const TEXT_TOTAL_CAP: u64 = 1024 * 1024; // 1 MB total for all text file attachments - const TEXT_FILE_COUNT_CAP: u32 = 5; - - for attachment in &msg.attachments { - let mime = attachment.content_type.as_deref().unwrap_or(""); - if media::is_audio_mime(mime) { - if self.stt_config.enabled { - let mime_clean = mime.split(';').next().unwrap_or(mime).trim(); - match media::download_and_transcribe( - &attachment.url, - &attachment.filename, - mime_clean, - u64::from(attachment.size), - &self.stt_config, - None, - ) - .await - { - Some(transcript) => { - debug!(filename = %attachment.filename, chars = transcript.len(), "voice transcript injected"); - extra_blocks.insert( - 0, - ContentBlock::Text { - text: format!("[Voice message transcript]: {transcript}"), - }, - ); - echo_entries.push(crate::stt::EchoEntry::Success(transcript)); - } - None => { - warn!(filename = %attachment.filename, "STT failed for voice attachment"); - echo_entries.push(crate::stt::EchoEntry::Failed); - } - } - } else { - tracing::warn!(filename = %attachment.filename, "skipping audio attachment (STT disabled)"); - let msg_ref = discord_msg_ref(&msg); - let _ = adapter.add_reaction(&msg_ref, "🎤").await; - } - } else if media::is_text_file(&attachment.filename, attachment.content_type.as_deref()) - { - if text_file_count >= TEXT_FILE_COUNT_CAP { - tracing::warn!(filename = %attachment.filename, count = text_file_count, "text file count cap reached, skipping"); - continue; - } - // Pre-check with Discord-reported size (fast path, avoids unnecessary download). - // Running total uses actual downloaded bytes for accurate accounting. 
- if text_file_bytes + u64::from(attachment.size) > TEXT_TOTAL_CAP { - tracing::warn!(filename = %attachment.filename, total = text_file_bytes, "text attachments total exceeds 1MB cap, skipping remaining"); - continue; - } - if let Some((block, actual_bytes)) = media::download_and_read_text_file( - &attachment.url, - &attachment.filename, - u64::from(attachment.size), - None, - ) - .await - { - text_file_bytes += actual_bytes; - text_file_count += 1; - debug!(filename = %attachment.filename, "adding text file attachment"); - extra_blocks.push(block); - } - } else { - match media::download_and_encode_image( - &attachment.url, - attachment.content_type.as_deref(), - &attachment.filename, - u64::from(attachment.size), - None, - ) - .await - { - Ok(block) => { - debug!(url = %attachment.url, filename = %attachment.filename, "adding image attachment"); - extra_blocks.push(block); - } - Err(media::MediaFetchError::NotAnImage) => { - if media::is_video_file( - &attachment.filename, - attachment.content_type.as_deref(), - ) { - debug!(url = %attachment.url, filename = %attachment.filename, "adding video attachment link"); - extra_blocks.push(video_attachment_block( - &attachment.filename, - attachment.content_type.as_deref(), - u64::from(attachment.size), - &attachment.url, - )); - } - } - Err(e) => { - tracing::warn!( - url = %attachment.url, - filename = %attachment.filename, - error = %e, - "image attachment failed" - ); - failed_image_files.push(attachment.filename.clone()); - } - } - } - } - - tracing::debug!( - num_extra_blocks = extra_blocks.len(), - num_attachments = msg.attachments.len(), - in_thread, - "processing" - ); - - let thread_channel = if in_thread || is_dm { - // DMs use the DM channel directly (no threads in DMs). 
- ChannelRef { - platform: "discord".into(), - channel_id: msg.channel_id.get().to_string(), - thread_id: None, - parent_id: thread_parent_id.clone(), - origin_event_id: None, - } - } else { - match get_or_create_thread(&ctx, &adapter, &msg, &prompt).await { - Ok(ch) => ch, - Err(e) => { - error!("failed to create thread: {e}"); - return; - } - } - }; - - // Notify user if any images couldn't be processed. - if !failed_image_files.is_empty() { - let file_list = failed_image_files - .iter() - .map(|n| format!("`{}`", n.replace('`', "'"))) - .collect::>() - .join(", "); - let warn_msg = format!( - ":warning: I couldn't process the image(s) you shared ({}). \ - The files may be inaccessible or in an unsupported format (PNG/JPEG/GIF/WebP only).", - file_list - ); - if let Err(e) = adapter.send_message(&thread_channel, &warn_msg).await { - tracing::warn!(error = %e, "failed to send image warning to user"); - } - } - - let trigger_msg = discord_msg_ref(&msg); - - // Per-thread streaming: check if another bot is present in this thread - let other_bot_present_flag = { - let cache = self.multibot_threads.lock().await; - cache.contains_key(&msg.channel_id.to_string()) - } || self.multibot_cache.is_multibot(&msg.channel_id.to_string()); - - // Backfill thread_id: when OAB just created a new thread, the sender - // was built before the thread existed. Patch it so the agent sees - // thread_id on the very first turn. - let mut sender = sender; - if sender.thread_id.is_none() && thread_channel.parent_id.is_some() { - sender.thread_id = Some(thread_channel.channel_id.clone()); - } - - let dispatcher = self.dispatcher.clone(); - let stt_cfg = self.stt_config.clone(); - - tokio::spawn(async move { - // Best-effort echo before the agent reply so the user can verify STT. 
- crate::stt::post_echo( - &adapter, - &thread_channel, - &trigger_msg, - &echo_entries, - &stt_cfg, - ) - .await; - - let sender_id = sender.sender_id.clone(); - let sender_name = sender.sender_name.clone(); - let sender_json = serde_json::to_string(&sender).unwrap(); - let thread_key = dispatcher.key("discord", &thread_channel.channel_id, &sender_id); - let estimated_tokens = crate::dispatch::estimate_tokens(&prompt, &extra_blocks); - let buf_msg = crate::dispatch::BufferedMessage { - sender_json, - sender_name, - prompt, - extra_blocks, - trigger_msg, - arrived_at: std::time::Instant::now(), - estimated_tokens, - other_bot_present: other_bot_present_flag, - recipient: None, // Slack-only (assistant mode); N/A for Discord - }; - if let Err(e) = dispatcher - .submit(thread_key, thread_channel, adapter, buf_msg) - .await - { - error!("dispatcher submit error: {e}"); - } - }); - } - - async fn ready(&self, ctx: Context, ready: Ready) { - info!(user = %ready.user.name, "discord bot connected"); - - // Build the shared command list once. - let commands = vec![ - CreateCommand::new("models").description("Select the AI model for this session"), - CreateCommand::new("agents").description("Select the agent mode for this session"), - CreateCommand::new("cancel").description("Cancel the current operation"), - CreateCommand::new("cancel-all") - .description("Cancel current operation and drop all buffered messages"), - CreateCommand::new("reset").description("Reset the conversation session"), - CreateCommand::new("remind") - .description("Set a one-shot reminder to mention users/roles after a delay") - .add_option(CreateCommandOption::new( - CommandOptionType::String, - "targets", - "Users/roles to mention (e.g. 
@user1 @role1)", - ).required(true)) - .add_option(CreateCommandOption::new( - CommandOptionType::String, - "message", - "Reminder message", - ).required(true)) - .add_option(CreateCommandOption::new( - CommandOptionType::String, - "delay", - "Delay before firing (e.g. 30m, 2h, 1d)", - ).required(true)), - CreateCommand::new("export-thread") - .description("Download this thread as a text file") - .add_option(CreateCommandOption::new( - CommandOptionType::Integer, - "limit", - "Export only the most recent N messages (1–5000)", - )) - .add_option(CreateCommandOption::new( - CommandOptionType::String, - "since", - "Export messages after this message ID", - )) - .add_option(CreateCommandOption::new( - CommandOptionType::Integer, - "days", - "Export messages from the last N days (1–365)", - )) - .add_option(CreateCommandOption::new( - CommandOptionType::Boolean, - "all", - "Export all messages (up to 5000). Default is last 100.", - )), - ]; - - // Register global commands only. Registering the same commands per-guild - // makes Discord show duplicate slash commands in guild command pickers. - if let Err(e) = Command::set_global_commands(&ctx.http, commands.clone()).await { - tracing::warn!(error = %e, "failed to register global slash commands"); - } else { - info!("registered global slash commands"); - } - - // One-time migration cleanup: older versions registered the same - // slash commands per-guild, and Discord persists those server-side. - // Keep guild command sets empty so only global commands are shown. - for guild in &ready.guilds { - let guild_id = guild.id; - if let Err(e) = guild_id.set_commands(&ctx.http, Vec::new()).await { - tracing::warn!( - %guild_id, - error = %e, - "failed to clear stale guild slash commands" - ); - } - } - - // Re-schedule any pending reminders that survived a restart. 
- let pending = self.reminder_store.pending().await; - if !pending.is_empty() { - let mut scheduled = self.scheduled_ids.lock().await; - let mut count = 0; - for r in pending { - if scheduled.insert(r.id.clone()) { - remind::schedule_reminder(ctx.http.clone(), self.reminder_store.clone(), r); - count += 1; - } - } - if count > 0 { - info!(count, "re-scheduled pending reminders"); - } - } - } - - async fn interaction_create(&self, ctx: Context, interaction: Interaction) { - match interaction { - Interaction::Command(cmd) if cmd.data.name == "models" => { - self.handle_config_command(&ctx, &cmd, "model", "model") - .await; - } - Interaction::Command(cmd) if cmd.data.name == "agents" => { - self.handle_config_command(&ctx, &cmd, "agent", "agent") - .await; - } - Interaction::Command(cmd) if cmd.data.name == "cancel" => { - self.handle_cancel_command(&ctx, &cmd).await; - } - Interaction::Command(cmd) if cmd.data.name == "cancel-all" => { - self.handle_cancel_all_command(&ctx, &cmd).await; - } - Interaction::Command(cmd) if cmd.data.name == "reset" => { - self.handle_reset_command(&ctx, &cmd).await; - } - Interaction::Command(cmd) if cmd.data.name == "remind" => { - self.handle_remind_command(&ctx, &cmd).await; - } - Interaction::Command(cmd) if cmd.data.name == "export-thread" => { - self.handle_export_thread_command(&ctx, &cmd).await; - } - Interaction::Component(comp) if comp.data.custom_id.starts_with("acp_config_") => { - self.handle_config_select(&ctx, &comp).await; - } - Interaction::Component(comp) if comp.data.custom_id.starts_with("acp_pg:") => { - self.handle_pagination(&ctx, &comp).await; - } - _ => {} - } - } -} - -// --- Slash command & interaction handlers --- - -impl Handler { - /// Build a Discord select menu from ACP configOptions with the given category. - /// Paginates options in pages of 25 (Discord limit). The current selection is - /// always placed first so it appears on page 0. 
- fn build_config_select( - options: &[ConfigOption], - category: &str, - page: usize, - ) -> Option { - let opt = options - .iter() - .find(|o| o.category.as_deref() == Some(category))?; - - // Put current selection first so it always lands on page 0, - // then fill remaining slots in original order. - let sorted: Vec<_> = opt - .options - .iter() - .filter(|o| o.value == opt.current_value) - .chain(opt.options.iter().filter(|o| o.value != opt.current_value)) - .collect(); - - let menu_options: Vec = sorted - .iter() - .skip(page * SELECT_MENU_PAGE_SIZE) - .take(SELECT_MENU_PAGE_SIZE) - .map(|o| { - let mut item = CreateSelectMenuOption::new(&o.name, &o.value); - if let Some(desc) = &o.description { - item = item.description(desc); - } - if o.value == opt.current_value { - item = item.default_selection(true); - } - item - }) - .collect(); - - if menu_options.is_empty() { - return None; - } - - let current_name = opt - .options - .iter() - .find(|o| o.value == opt.current_value) - .map(|o| o.name.as_str()) - .unwrap_or(&opt.current_value); - let total_pages = sorted.len().div_ceil(SELECT_MENU_PAGE_SIZE); - let placeholder = if total_pages > 1 { - format!( - "Current: {} (page {}/{})", - current_name, - page + 1, - total_pages - ) - } else { - format!("Current: {}", current_name) - }; - - Some( - CreateSelectMenu::new( - format!("acp_config_{}", opt.id), - CreateSelectMenuKind::String { - options: menu_options, - }, - ) - .placeholder(placeholder), - ) - } - - /// Build ◀/▶ pagination buttons. Returns None when only one page exists. 
- fn build_pagination_buttons( - category: &str, - page: usize, - total_pages: usize, - ) -> Option { - if total_pages <= 1 { - return None; - } - let prev = CreateButton::new(format!("acp_pg:{}:{}", category, page.saturating_sub(1))) - .label("◀") - .style(ButtonStyle::Secondary) - .disabled(page == 0); - let next = CreateButton::new(format!("acp_pg:{}:{}", category, page + 1)) - .label("▶") - .style(ButtonStyle::Secondary) - .disabled(page + 1 >= total_pages); - let indicator = CreateButton::new("acp_pg_noop") - .label(format!("{}/{}", page + 1, total_pages)) - .style(ButtonStyle::Secondary) - .disabled(true); - Some(CreateActionRow::Buttons(vec![prev, indicator, next])) - } - - /// Build the full component rows (select menu + optional pagination) for a config category. - /// When `page` is `None`, auto-selects the page containing the current value. - fn build_config_components( - options: &[ConfigOption], - category: &str, - page: Option, - ) -> Option> { - let opt = options - .iter() - .find(|o| o.category.as_deref() == Some(category))?; - let total_pages = opt.options.len().div_ceil(SELECT_MENU_PAGE_SIZE); - let page = match page { - Some(p) => p.min(total_pages.saturating_sub(1)), - None => opt - .options - .iter() - .position(|o| o.value == opt.current_value) - .map(|i| i / SELECT_MENU_PAGE_SIZE) - .unwrap_or(0), - }; - - let select = Self::build_config_select(options, category, page)?; - let mut rows = vec![CreateActionRow::SelectMenu(select)]; - if let Some(buttons) = Self::build_pagination_buttons(category, page, total_pages) { - rows.push(buttons); - } - Some(rows) - } - - async fn handle_config_command( - &self, - ctx: &Context, - cmd: &serenity::model::application::CommandInteraction, - category: &str, - label: &str, - ) { - let thread_key = format!("discord:{}", cmd.channel_id.get()); - let config_options = self.router.pool().get_config_options(&thread_key).await; - - let response = match Self::build_config_components(&config_options, category, None) 
{ - Some(rows) => CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content(format!("🔧 Select a {label}:")) - .components(rows) - .ephemeral(true), - ), - None => CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content(format!("⚠️ No {label} options available. Start a conversation first by @mentioning the bot.")) - .ephemeral(true), - ), - }; - - if let Err(e) = cmd.create_response(&ctx.http, response).await { - tracing::error!(error = %e, category, "failed to respond to slash command"); - } - } - - async fn handle_cancel_command( - &self, - ctx: &Context, - cmd: &serenity::model::application::CommandInteraction, - ) { - let thread_key = format!("discord:{}", cmd.channel_id.get()); - let result = self.router.pool().cancel_session(&thread_key).await; - - let msg = match result { - Ok(()) => "🛑 Cancel signal sent.".to_string(), - Err(e) => format!("⚠️ {e}"), - }; - - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content(msg) - .ephemeral(true), - ); - if let Err(e) = cmd.create_response(&ctx.http, response).await { - tracing::error!(error = %e, "failed to respond to /cancel command"); - } - } - - async fn handle_cancel_all_command( - &self, - ctx: &Context, - cmd: &serenity::model::application::CommandInteraction, - ) { - // /cancel-all is the nuclear escape hatch: stop the in-flight turn AND clear - // every lane's buffer in this thread, so a human can intervene from a clean slate. - let session_key = format!("discord:{}", cmd.channel_id.get()); - let dropped = self - .dispatcher - .cancel_buffered_thread("discord", &cmd.channel_id.get().to_string()); - - let cancel_result = self.router.pool().cancel_session(&session_key).await; - - // Buffer count is approximate (sweep races with new arrivals) so we surface - // a binary "cleared / nothing" signal rather than a misleading exact number. 
- let msg = match (cancel_result, dropped) { - (Ok(()), 0) => "🛑 Cancel signal sent.".to_string(), - (Ok(()), _) => "🛑 Cancel signal sent. Buffered messages cleared.".to_string(), - (Err(_), 0) => { - "⚠️ Nothing to cancel — no active session and no buffered messages.".to_string() - } - (Err(_), _) => "🛑 Buffered messages cleared. No active session to cancel.".to_string(), - }; - - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content(msg) - .ephemeral(true), - ); - if let Err(e) = cmd.create_response(&ctx.http, response).await { - tracing::error!(error = %e, "failed to respond to /cancel-all command"); - } - } - - async fn handle_reset_command( - &self, - ctx: &Context, - cmd: &serenity::model::application::CommandInteraction, - ) { - // /reset clears every lane's buffer in this thread and tears down the shared - // ACP session — the next message in the thread starts a fresh conversation. - let session_key = format!("discord:{}", cmd.channel_id.get()); - let dropped = self - .dispatcher - .cancel_buffered_thread("discord", &cmd.channel_id.get().to_string()); - - let result = self.router.pool().reset_session(&session_key).await; - - let msg = match result { - Ok(()) if dropped > 0 => { - format!("🔄 Session reset. Dropped {dropped} buffered message(s). Start a new conversation!") - } - Ok(()) => "🔄 Session reset. Start a new conversation!".to_string(), - Err(_) if dropped > 0 => { - format!("🔄 Dropped {dropped} buffered message(s). No active session to reset.") - } - Err(_) => { - "⚠️ No active session to reset. Start a conversation first by @mentioning the bot." 
- .to_string() - } - }; - - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content(msg) - .ephemeral(true), - ); - if let Err(e) = cmd.create_response(&ctx.http, response).await { - tracing::error!(error = %e, "failed to respond to /reset command"); - } - } - - async fn handle_remind_command( - &self, - ctx: &Context, - cmd: &serenity::model::application::CommandInteraction, - ) { - // Only humans can use /remind - if cmd.user.bot { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("⚠️ Only humans can set reminders.") - .ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - - // Extract options - let opts = &cmd.data.options; - let targets_raw = opts.iter() - .find(|o| o.name == "targets") - .and_then(|o| o.value.as_str()) - .unwrap_or(""); - let message = opts.iter() - .find(|o| o.name == "message") - .and_then(|o| o.value.as_str()) - .unwrap_or(""); - let delay_raw = opts.iter() - .find(|o| o.name == "delay") - .and_then(|o| o.value.as_str()) - .unwrap_or(""); - - if targets_raw.is_empty() || message.is_empty() || delay_raw.is_empty() { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("⚠️ All fields (targets, message, delay) are required.") - .ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - - // Parse delay - let delay_secs = match remind::parse_delay(delay_raw) { - Ok(s) => s, - Err(e) => { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content(format!("⚠️ Invalid delay: {e}")) - .ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - }; - - if let Err(e) = remind::validate_message(message) { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content(format!("⚠️ {e}")) - 
.ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - - // Strip @everyone / @here to prevent unintended mass pings. - let message = remind::sanitize_message(message); - - // Extract mention strings from targets (keep raw — Discord renders them) - let targets: Vec = targets_raw - .split_whitespace() - .filter(|t| t.starts_with("<@") && t.ends_with('>')) - .map(|t| t.to_string()) - .collect(); - - if targets.is_empty() { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("⚠️ No valid mentions found in targets. Use @user or @role.") - .ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - - if targets.len() > remind::MAX_TARGETS { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content(format!("⚠️ Too many targets (max {}). Use a @role instead.", remind::MAX_TARGETS)) - .ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - - // F4: Per-user rate limit (max 5 active reminders) - let user_id = cmd.user.id.get(); - let pending = self.reminder_store.pending().await; - let user_count = pending.iter().filter(|r| r.sender_id == user_id).count(); - if user_count >= 5 { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("⚠️ You already have 5 active reminders. 
Wait for some to fire before adding more.") - .ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - - let fire_at = chrono::Utc::now() + chrono::Duration::seconds(delay_secs as i64); - let reminder = remind::Reminder { - id: uuid::Uuid::new_v4().to_string(), - channel_id: cmd.channel_id.get(), - sender_id: cmd.user.id.get(), - targets: targets.clone(), - message: message.clone(), - fire_at, - created_at: chrono::Utc::now(), - }; - - // Persist and schedule - self.reminder_store.add(reminder.clone()).await; - self.scheduled_ids.lock().await.insert(reminder.id.clone()); - remind::schedule_reminder(ctx.http.clone(), self.reminder_store.clone(), reminder); - - let delay_str = remind::format_delay(delay_secs); - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content(format!( - "⏰ Reminder set! Will fire in **{delay_str}** and mention {}", - targets.join(" ") - )) - .ephemeral(true), - ); - if let Err(e) = cmd.create_response(&ctx.http, response).await { - tracing::error!(error = %e, "failed to respond to /remind command"); - } - } - - async fn handle_export_thread_command( - &self, - ctx: &Context, - cmd: &serenity::model::application::CommandInteraction, - ) { - if is_denied_user( - false, - self.allow_all_users, - &self.allowed_users, - cmd.user.id.get(), - ) { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("🚫 You are not allowed to use this bot.") - .ephemeral(true), - ); - if let Err(e) = cmd.create_response(&ctx.http, response).await { - tracing::error!(error = %e, "failed to deny /export-thread command"); - } - return; - } - - let channel_id = cmd.channel_id; - let (export_allowed, export_name) = match channel_id.to_channel(&ctx.http).await { - Ok(serenity::model::channel::Channel::Guild(gc)) => { - let in_allowed_channel = - self.allow_all_channels || self.allowed_channels.contains(&channel_id.get()); - let 
(in_thread, _) = detect_thread( - gc.thread_metadata.is_some(), - gc.parent_id.map(|id| id.get()), - gc.owner_id.map(|id| id.get()), - ctx.cache.current_user().id.get(), - &self.allowed_channels, - self.allow_all_channels, - in_allowed_channel, - ); - (in_thread, gc.name.clone()) - } - Ok(serenity::model::channel::Channel::Private(_)) => { - (self.allow_dm, "dm".to_string()) - } - Ok(_) => (false, "channel".to_string()), - Err(e) => { - tracing::warn!(channel_id = %channel_id, error = %e, "failed to inspect channel for export"); - (false, "channel".to_string()) - } - }; - - if !export_allowed { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("⚠️ Run this command inside an allowed Discord thread or DM.") - .ephemeral(true), - ); - if let Err(e) = cmd.create_response(&ctx.http, response).await { - tracing::error!(error = %e, "failed to respond to /export-thread rejection"); - } - return; - } - - // --- Parse and validate filter params (mutual exclusion) --- - let opts = &cmd.data.options; - let limit_opt = opts.iter().find(|o| o.name == "limit").and_then(|o| o.value.as_i64()); - let since_opt = opts.iter().find(|o| o.name == "since").and_then(|o| o.value.as_str()); - let days_opt = opts.iter().find(|o| o.name == "days").and_then(|o| o.value.as_i64()); - let all_opt = opts.iter().find(|o| o.name == "all").and_then(|o| o.value.as_bool()).unwrap_or(false); - - let filter_count = limit_opt.is_some() as u8 + since_opt.is_some() as u8 + days_opt.is_some() as u8 + all_opt as u8; - if filter_count > 1 { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("⚠️ Please specify only one filter: `limit`, `since`, `days`, or `all`.") - .ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - - let filter = if all_opt { - ExportFilter::All - } else if let Some(n) = limit_opt { - if !(1..=5000).contains(&n) { - let response = 
CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("⚠️ `limit` must be between 1 and 5000.") - .ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - ExportFilter::Limit(n as usize) - } else if let Some(id_str) = since_opt { - match id_str.parse::() { - Ok(id) if id > 0 => ExportFilter::After(MessageId::new(id)), - _ => { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("⚠️ `since` must be a valid message ID (right-click a message → Copy Message ID).") - .ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - } - } else if let Some(d) = days_opt { - if !(1..=365).contains(&d) { - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("⚠️ `days` must be between 1 and 365.") - .ephemeral(true), - ); - let _ = cmd.create_response(&ctx.http, response).await; - return; - } - let since_ts = chrono::Utc::now() - chrono::Duration::days(d); - let ts_ms = since_ts.timestamp_millis() as u64; - ExportFilter::After(timestamp_ms_to_snowflake(ts_ms)) - } else { - // Default: export last 100 messages (use limit:N or all:true for more) - ExportFilter::Limit(100) - }; - - let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new() - .content("Preparing thread export...") - .ephemeral(true), - ); - if let Err(e) = cmd.create_response(&ctx.http, response).await { - tracing::error!(error = %e, "failed to acknowledge /export-thread command"); - return; - } - - match export_channel_messages( - &ctx.http, - channel_id, - &export_name, - cmd.attachment_size_limit, - filter, - ) - .await - { - Ok(result) => { - let mut content = format!("Exported {} messages.", result.written); - if result.hit_cap { - content.push_str(&format!( - " Only the most recent {} messages were fetched — older messages were not included.", - result.fetched - 
)); - } - if result.byte_truncated { - content.push_str(&format!( - " Transcript truncated to fit Discord's attachment size limit ({} of {} fetched messages included).", - result.written, result.fetched - )); - } - let attachment = - CreateAttachment::bytes(result.transcript.into_bytes(), result.filename); - let followup = CreateInteractionResponseFollowup::new() - .content(content) - .add_file(attachment) - .ephemeral(true); - if let Err(e) = cmd.create_followup(&ctx.http, followup).await { - tracing::error!(error = %e, "failed to send /export-thread attachment"); - } - } - Err(e) => { - tracing::warn!(channel_id = %channel_id, error = %e, "failed to export thread"); - let followup = CreateInteractionResponseFollowup::new() - .content(format!("⚠️ Failed to export thread: {e}")) - .ephemeral(true); - if let Err(e) = cmd.create_followup(&ctx.http, followup).await { - tracing::error!(error = %e, "failed to send /export-thread error"); - } - } - } - } - - async fn handle_config_select( - &self, - ctx: &Context, - comp: &serenity::model::application::ComponentInteraction, - ) { - let config_id = comp - .data - .custom_id - .strip_prefix("acp_config_") - .unwrap_or("") - .to_string(); - - if config_id.is_empty() { - return; - } - - let selected_value = match &comp.data.kind { - ComponentInteractionDataKind::StringSelect { values } => match values.first() { - Some(v) => v.clone(), - None => return, - }, - _ => return, - }; - - let thread_key = format!("discord:{}", comp.channel_id.get()); - - let result = self - .router - .pool() - .set_config_option(&thread_key, &config_id, &selected_value) - .await; - - let response_msg = match result { - Ok(updated_options) => { - let display_name = updated_options - .iter() - .find(|o| o.id == config_id) - .and_then(|o| o.options.iter().find(|v| v.value == selected_value)) - .map(|v| v.name.as_str()) - .unwrap_or(&selected_value); - format!("✅ Switched to **{}**", display_name) - } - Err(e) => { - tracing::error!(error = %e, "failed 
to set config option"); - format!("❌ Failed to switch: {}", e) - } - }; - - let response = CreateInteractionResponse::UpdateMessage( - CreateInteractionResponseMessage::new() - .content(response_msg) - .components(vec![]), - ); - - if let Err(e) = comp.create_response(&ctx.http, response).await { - tracing::error!(error = %e, "failed to respond to config select"); - } - } - - async fn handle_pagination( - &self, - ctx: &Context, - comp: &serenity::model::application::ComponentInteraction, - ) { - // Parse custom_id format: acp_pg:{category}:{page} - let parts: Vec<&str> = comp.data.custom_id.splitn(3, ':').collect(); - let (category, page) = match parts.as_slice() { - [_, cat, pg] => match pg.parse::() { - Ok(p) => (*cat, p), - Err(_) => return, - }, - _ => return, - }; - - // Only allow known config categories. - if !matches!(category, "model" | "agent") { - return; - } - - let thread_key = format!("discord:{}", comp.channel_id.get()); - let config_options = self.router.pool().get_config_options(&thread_key).await; - - let response = match Self::build_config_components(&config_options, category, Some(page)) { - Some(rows) => CreateInteractionResponse::UpdateMessage( - CreateInteractionResponseMessage::new() - .content(format!("🔧 Select a {category}:")) - .components(rows), - ), - None => CreateInteractionResponse::UpdateMessage( - CreateInteractionResponseMessage::new() - .content(format!("⚠️ No {category} options available.")) - .components(vec![]), - ), - }; - - if let Err(e) = comp.create_response(&ctx.http, response).await { - tracing::error!(error = %e, category, "failed to respond to pagination"); - } - } -} - -// --- Discord-specific helpers --- - -fn discord_msg_ref(msg: &Message) -> MessageRef { - MessageRef { - channel: ChannelRef { - platform: "discord".into(), - channel_id: msg.channel_id.get().to_string(), - thread_id: None, - parent_id: None, - origin_event_id: None, - }, - message_id: msg.id.to_string(), - } -} - -struct ExportResult { - filename: 
String, - transcript: String, - /// Messages successfully pulled from Discord. - fetched: usize, - /// Messages that fit in the transcript (≤ `fetched`; differs when the - /// attachment-size limit truncates). - written: usize, - /// We stopped fetching because we hit the message cap and the thread still - /// has more messages we did not include. - hit_cap: bool, - /// Transcript was cut to keep the attachment under Discord's size limit. - byte_truncated: bool, -} - -/// Filter mode for export_channel_messages. -enum ExportFilter { - /// Fetch all messages (newest-first via `before`), capped at THREAD_EXPORT_MESSAGE_LIMIT. - All, - /// Fetch the most recent N messages (newest-first via `before`). - Limit(usize), - /// Fetch messages after a synthetic snowflake (newest-first via `before`, with boundary filtering). - After(MessageId), -} - -/// Discord epoch: 2015-01-01T00:00:00Z in milliseconds. -const DISCORD_EPOCH_MS: u64 = 1_420_070_400_000; - -/// Convert a UTC timestamp (in milliseconds since Unix epoch) to a synthetic -/// Discord snowflake suitable for use as an `after` cursor. -fn timestamp_ms_to_snowflake(timestamp_ms: u64) -> MessageId { - let discord_ms = timestamp_ms.saturating_sub(DISCORD_EPOCH_MS); - // Snowflake IDs use NonZeroU64 in serenity; ensure at least 1. - MessageId::new((discord_ms << 22).max(1)) -} - -async fn export_channel_messages( - http: &Http, - channel_id: ChannelId, - channel_name: &str, - attachment_size_limit: u32, - filter: ExportFilter, -) -> anyhow::Result { - let cap = match &filter { - ExportFilter::Limit(n) => *n, - _ => THREAD_EXPORT_MESSAGE_LIMIT, - }; - - let mut messages = Vec::new(); - let mut hit_cap = false; - - match &filter { - ExportFilter::All | ExportFilter::Limit(_) => { - // Fetch newest-first using `before` pagination, then reverse. 
- let mut before = None; - loop { - if messages.len() >= cap { - hit_cap = true; - break; - } - let remaining = cap - messages.len(); - let limit = remaining.min(100) as u8; - let mut request = GetMessages::new().limit(limit); - if let Some(before_id) = before { - request = request.before(before_id); - } - let batch = channel_id.messages(http, request).await?; - if batch.is_empty() { - break; - } - before = batch.last().map(|m| m.id); - let batch_len = batch.len(); - messages.extend(batch); - if batch_len < limit as usize { - break; - } - } - // Probe to confirm we actually left messages behind. - if hit_cap { - let probe = GetMessages::new().limit(1); - let probe = if let Some(before_id) = before { - probe.before(before_id) - } else { - probe - }; - if matches!(channel_id.messages(http, probe).await, Ok(b) if b.is_empty()) { - hit_cap = false; - } - } - messages.reverse(); - } - ExportFilter::After(after_id) => { - // Fetch newest-first using `before` pagination, stop when we hit - // messages at or before the filter boundary. This ensures that when - // the cap is reached, we keep the *newest* messages in the window. - let mut before = None; - loop { - if messages.len() >= cap { - hit_cap = true; - break; - } - let remaining = cap - messages.len(); - let limit = remaining.min(100) as u8; - let mut request = GetMessages::new().limit(limit); - if let Some(before_id) = before { - request = request.before(before_id); - } - let batch = channel_id.messages(http, request).await?; - if batch.is_empty() { - break; - } - before = batch.last().map(|m| m.id); - let batch_len = batch.len(); - // Filter out messages at or before the boundary. - let filtered: Vec<_> = batch.into_iter().filter(|m| m.id > *after_id).collect(); - let hit_boundary = filtered.len() < batch_len; - messages.extend(filtered); - if hit_boundary { - // We've reached the time boundary; no need to fetch older. 
- break; - } - if batch_len < limit as usize { - break; - } - } - // Probe only if we stopped due to cap (not boundary). - if hit_cap { - let probe = GetMessages::new().limit(1); - let probe = if let Some(before_id) = before { - probe.before(before_id) - } else { - probe - }; - if let Ok(batch) = channel_id.messages(http, probe).await { - // If the next message is beyond our filter boundary, - // we didn't actually leave relevant messages behind. - let has_more_in_window = batch.iter().any(|m| m.id > *after_id); - if !has_more_in_window { - hit_cap = false; - } - } - } - messages.reverse(); - } - } - - let filename = export_filename(channel_id, channel_name); - if attachment_size_limit < 2048 { - tracing::warn!(attachment_size_limit, "attachment_size_limit is very small; export will likely be truncated"); - } - let max_bytes = usize::try_from(attachment_size_limit) - .unwrap_or(8 * 1024 * 1024) - .saturating_sub(1024) - .max(1024); - let (transcript, written, byte_truncated) = - format_thread_export(channel_id, channel_name, &messages, max_bytes); - let fetched = messages.len(); - - Ok(ExportResult { - filename, - transcript, - fetched, - written, - hit_cap, - byte_truncated, - }) -} - -fn format_thread_export( - channel_id: ChannelId, - channel_name: &str, - messages: &[Message], - max_bytes: usize, -) -> (String, usize, bool) { - let header = format!( - "Discord thread export\nChannel: {channel_name} ({channel_id})\nMessages: {}\n\n", - messages.len() - ); - let entries: Vec = messages.iter().map(format_export_message).collect(); - assemble_export(&header, &entries, max_bytes) -} - -/// Build the transcript body from a pre-rendered header and a list of -/// already-formatted message entries, honouring `max_bytes`. -/// -/// Returns `(transcript, written, truncated)` where `written` is the number of -/// entries actually included. 
Split out from `format_thread_export` so the -/// truncation boundary logic can be unit-tested without constructing real -/// `serenity::model::channel::Message` values. -fn assemble_export(header: &str, entries: &[String], max_bytes: usize) -> (String, usize, bool) { - let mut out = String::from(header); - let mut written = 0; - let mut truncated = false; - - for entry in entries { - if out.len() + entry.len() > max_bytes { - truncated = true; - break; - } - out.push_str(entry); - written += 1; - } - - if truncated { - let note = "\n[Export truncated to fit Discord attachment size limit]\n"; - let room = max_bytes.saturating_sub(out.len()); - if room >= note.len() { - out.push_str(note); - } - } - - (out, written, truncated) -} - -fn format_export_message(msg: &Message) -> String { - let bot_marker = if msg.author.bot { " [bot]" } else { "" }; - let mut out = format!( - "[{}] {}{} ({})\n", - msg.timestamp, - msg.author.name, - bot_marker, - msg.author.id - ); - - if msg.content.is_empty() { - out.push_str("(no text)\n"); - } else { - out.push_str(&msg.content); - out.push('\n'); - } - - for attachment in &msg.attachments { - let mime = attachment.content_type.as_deref().unwrap_or("unknown"); - out.push_str(&format!( - "[attachment] {} ({} bytes, {}): {}\n", - attachment.filename, attachment.size, mime, attachment.url - )); - } - - out.push('\n'); - out -} - -fn export_filename(channel_id: ChannelId, channel_name: &str) -> String { - let safe_name = sanitize_filename_component(channel_name); - format!("discord-thread-{safe_name}-{channel_id}.txt") -} - -/// Reduce a free-form Discord channel/thread name to a safe ASCII filename -/// fragment. -/// -/// Non-ASCII characters are dropped silently — a purely-Chinese thread name -/// like "扈三娘的房間" yields a date-based fallback (e.g. `"20260512"`). 
-/// The caller appends the channel ID, which already guarantees uniqueness, -/// and an ASCII fragment plays nicer with downstream tools (mail attachments, -/// S3 keys, browser save-as dialogs). The 64-byte cap leaves room for the -/// `discord-thread-` prefix and the channel-ID suffix within typical -/// filesystem limits. -fn sanitize_filename_component(input: &str) -> String { - let mut safe = String::with_capacity(input.len()); - for ch in input.chars() { - if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_') { - safe.push(ch); - } else if ch.is_whitespace() || matches!(ch, '.' | '/') { - safe.push('-'); - } - } - let safe = safe.trim_matches('-'); - if safe.is_empty() { - // Use current date as a human-friendly fallback when the thread name - // is entirely non-ASCII. - chrono::Utc::now().format("%Y%m%d").to_string() - } else { - safe.chars().take(64).collect() - } -} - -async fn get_or_create_thread( - ctx: &Context, - adapter: &Arc, - msg: &Message, - prompt: &str, -) -> anyhow::Result { - let channel = msg.channel_id.to_channel(&ctx.http).await?; - if let serenity::model::channel::Channel::Guild(ref gc) = channel { - // Already in a thread — reuse it. Uses thread_metadata (see detect_thread()). - if gc.thread_metadata.is_some() { - return Ok(ChannelRef { - platform: "discord".into(), - channel_id: msg.channel_id.get().to_string(), - thread_id: None, - parent_id: None, - origin_event_id: None, - }); - } - } - - let thread_name = format::shorten_thread_name(prompt); - let parent = ChannelRef { - platform: "discord".into(), - channel_id: msg.channel_id.get().to_string(), - thread_id: None, - parent_id: None, - origin_event_id: None, - }; - let trigger_ref = discord_msg_ref(msg); - match adapter - .create_thread(&parent, &trigger_ref, &thread_name) - .await - { - Ok(ch) => Ok(ch), - Err(e) if is_thread_already_exists_error(&e) => { - // Another bot won the race from the same trigger message. 
Discord - // only allows one thread per message, so refetch the message and - // join the thread our sibling just created. - let refreshed = msg - .channel_id - .message(&ctx.http, msg.id) - .await - .map_err(|fe| { - anyhow::anyhow!("thread_already_exists (race), but refetch failed: {fe}") - })?; - let existing = refreshed.thread.ok_or_else(|| { - anyhow::anyhow!( - "thread_already_exists (race), but message has no thread after refetch" - ) - })?; - tracing::info!( - channel_id = %msg.channel_id, - thread_id = %existing.id, - "joining thread created by sibling bot from same trigger message" - ); - Ok(ChannelRef { - platform: "discord".into(), - channel_id: existing.id.to_string(), - thread_id: None, - parent_id: Some(msg.channel_id.get().to_string()), - origin_event_id: None, - }) - } - Err(e) => Err(e), - } -} - -/// Detect Discord's "A thread has already been created for this message" error -/// (JSON error code 160004). Triggered when two bots responding to the same -/// @-mention race to create a thread from the same trigger message. -/// -/// Uses string matching because serenity surfaces Discord API errors as -/// formatted strings — there is no structured error code we can match on. -/// Unit tests pin the expected patterns so serenity formatting changes are caught. -fn is_thread_already_exists_error(err: &anyhow::Error) -> bool { - let msg = err.to_string(); - msg.contains("160004") || msg.contains("already been created") -} - -static ROLE_MENTION_RE: LazyLock = - LazyLock::new(|| regex::Regex::new(r"<@&\d+>").unwrap()); - -fn resolve_mentions(content: &str, bot_id: UserId, allowed_role_ids: &HashSet) -> String { - // 1. Strip the bot's own trigger mention - let out = content - .replace(&format!("<@{}>", bot_id), "") - .replace(&format!("<@!{}>", bot_id), ""); - // 2. 
Strip allowed role mentions (they triggered the bot, not useful in prompt) - let out = if allowed_role_ids.is_empty() { - out - } else { - allowed_role_ids - .iter() - .fold(out, |s, id| s.replace(&format!("<@&{}>", id), "")) - }; - // 3. Other user mentions: keep <@UID> as-is so the LLM can mention back - // 4. Fallback: replace remaining role mentions only (user mentions are preserved) - let out = ROLE_MENTION_RE.replace_all(&out, "@(role)").to_string(); - out.trim().to_string() -} - -fn video_attachment_block( - filename: &str, - content_type: Option<&str>, - size: u64, - url: &str, -) -> ContentBlock { - ContentBlock::Text { - text: format!( - "[Video attachment]\nfilename: {}\ncontent_type: {}\nsize_bytes: {}\nurl: {}", - filename, - content_type.unwrap_or("unknown"), - size, - url - ), - } -} - -/// Build a `SenderContext` for Discord messages. -/// -/// Pure function extracted from `EventHandler::message` for testability. -/// When `thread_parent_id` is `Some`, the message is inside a thread: -/// - `channel_id` → parent channel (where the thread lives) -/// - `thread_id` → thread's own channel ID -/// -/// This mirrors Slack's model where `channel_id` is always the parent -/// channel and `thread_id` (thread_ts) identifies the thread. -/// -/// Note: `ChannelRef.channel_id` uses the *opposite* convention — it holds -/// the thread's channel ID for routing (Discord API sends to thread by its -/// channel ID). See `ChannelRef` doc comments for details. 
-#[allow(clippy::too_many_arguments)] -fn build_sender_context( - sender_id: &str, - sender_name: &str, - display_name: &str, - msg_channel_id: &str, - thread_parent_id: Option<&str>, - is_bot: bool, - timestamp: &str, - message_id: &str, - receiver_id: &str, -) -> SenderContext { - SenderContext { - schema: "openab.sender.v1".into(), - sender_id: sender_id.to_string(), - sender_name: sender_name.to_string(), - display_name: display_name.to_string(), - channel: "discord".into(), - channel_id: thread_parent_id.unwrap_or(msg_channel_id).to_string(), - thread_id: thread_parent_id.map(|_| msg_channel_id.to_string()), - is_bot, - timestamp: Some(timestamp.to_string()), - message_id: Some(message_id.to_string()), - receiver_id: Some(receiver_id.to_string()), - } -} - -/// Pure thread detection: determines whether a channel is a Discord thread -/// in an allowed parent, and whether the bot owns it. -/// -/// Returns `(in_allowed_thread, bot_owns)`: -/// - `in_allowed_thread`: true only if the channel IS a thread AND its parent -/// is permitted (via allowlist, `allow_all_channels`, or `in_allowed_channel`). -/// - `bot_owns`: `None` if the channel is not a thread (ownership is meaningless); -/// `Some(true/false)` if it IS a thread, indicating whether the bot owns it. -/// -/// Uses `thread_metadata.is_some()` — the canonical way to identify threads. -/// `parent_id` is NOT reliable for thread detection: category children also -/// have `parent_id` set. `parent_id` is only used here for the allowlist check. 
-/// -/// Discord API refs: -/// - Channel Object (parent_id / thread_metadata fields): -/// https://docs.discord.com/developers/resources/channel#channel-object -/// - Thread Metadata ("thread-specific fields not needed by other channels"): -/// https://docs.discord.com/developers/resources/channel#thread-metadata-object -fn detect_thread( - has_thread_metadata: bool, - parent_id: Option, - owner_id: Option, - bot_id: u64, - allowed_channels: &HashSet, - allow_all_channels: bool, - in_allowed_channel: bool, -) -> (bool, Option) { - if !has_thread_metadata { - return (false, None); - } - let in_allowed_thread = in_allowed_channel - || allow_all_channels - || parent_id.is_some_and(|pid| allowed_channels.contains(&pid)); - let bot_owns = owner_id.is_some_and(|oid| oid == bot_id); - (in_allowed_thread, Some(bot_owns)) -} - -/// Returns `true` if the author should be denied by the user allowlist. -/// Bot authors skip this check — they are gated by `allow_bot_messages` + `trusted_bot_ids`. -fn is_denied_user( - is_bot: bool, - allow_all_users: bool, - allowed_users: &HashSet, - user_id: u64, -) -> bool { - !is_bot && !allow_all_users && !allowed_users.contains(&user_id) -} - -/// Returns `true` if a bot message should bypass the `allow_bot_messages` mode check. -/// A trusted bot that @mentions this bot is treated the same as a human @mention — -/// it can pull the bot into a thread regardless of the `allow_bot_messages` setting. -#[cfg(test)] -fn is_trusted_bot_mention( - is_mentioned: bool, - trusted_bot_ids: &HashSet, - author_id: u64, -) -> bool { - is_mentioned && !trusted_bot_ids.is_empty() && trusted_bot_ids.contains(&author_id) -} - -/// Pure decision function: should a DM be processed? -/// Returns `true` if the DM should be processed (bot responds). 
-/// Mirrors the DM gating logic in EventHandler::message: -/// - `allow_dm` must be true -/// - `allowed_users` still applies (checked separately via `is_denied_user`) -/// - DMs bypass `allowed_channels` and `@mention` requirements -#[cfg(test)] -fn should_process_dm(allow_dm: bool) -> bool { - allow_dm -} - -/// Pure decision function: should thread creation be skipped? -/// Returns `true` when the message should reuse the current channel -/// directly (existing thread or DM), `false` when a new thread should -/// be created. Pins the invariant that DMs never call -/// `get_or_create_thread()` — Discord DM channels cannot create threads. -#[cfg(test)] -fn should_skip_thread_creation(in_thread: bool, is_dm: bool) -> bool { - in_thread || is_dm -} - -/// Pure decision function: should this message be processed or ignored? -/// Returns `true` if the message should be processed (bot responds). -/// Extracted from the EventHandler::message gating logic for testability. -#[cfg(test)] -fn should_process_user_message( - mode: AllowUsers, - is_mentioned: bool, - in_thread: bool, - involved: bool, - other_bot_present: bool, -) -> bool { - if is_mentioned { - return true; - } - match mode { - AllowUsers::Mentions => false, - AllowUsers::Involved => in_thread && involved, - AllowUsers::MultibotMentions => { - if !in_thread || !involved { - return false; - } - !other_bot_present - } - } -} - -/// Returns true if any bot message in `messages` contains a turn limit warning. -/// Used to dedup `WarnAndStop` across multiple bot processes sharing a thread. (#530) -/// Note: this is best-effort — a narrow race window exists where two bots fetch -/// simultaneously and both see no warning, resulting in a duplicate. For most -/// deployments this is acceptable; strict once-only semantics would require -/// shared state (e.g. gateway-owned emission or distributed lock). 
-/// -/// Accepts `(is_bot, content)` pairs so the logic can be unit-tested without -/// constructing `serenity::model::channel::Message` values (see existing test -/// boundary comment at `format_thread_export`). -fn turn_limit_warning_present(messages: &[(bool, &str)]) -> bool { - messages - .iter() - .any(|(is_bot, content)| *is_bot && content.contains(BOT_TURN_LIMIT_WARNING_PREFIX)) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::bot_turns::{TurnResult, HARD_BOT_TURN_LIMIT, BOT_TURN_LIMIT_WARNING_PREFIX}; - - // --- resolve_mentions tests --- - - /// Bot's own <@UID> mention is stripped from the prompt. - #[test] - fn resolve_mentions_strips_bot_mention() { - let bot_id = UserId::new(111); - let result = resolve_mentions("hello <@111> world", bot_id, &HashSet::new()); - assert_eq!(result, "hello world"); - } - - /// Bot's own legacy <@!UID> mention is also stripped. - #[test] - fn resolve_mentions_strips_bot_mention_legacy() { - let bot_id = UserId::new(111); - let result = resolve_mentions("hello <@!111> world", bot_id, &HashSet::new()); - assert_eq!(result, "hello world"); - } - - /// Other users' <@UID> mentions are preserved so the LLM can mention them back. - #[test] - fn resolve_mentions_preserves_other_user_mentions() { - let bot_id = UserId::new(111); - let result = resolve_mentions("<@111> say hi to <@222>", bot_id, &HashSet::new()); - assert_eq!(result, "say hi to <@222>"); - } - - /// Role mentions <@&UID> are replaced with @(role) placeholder. - #[test] - fn resolve_mentions_replaces_role_mentions() { - let bot_id = UserId::new(111); - let result = resolve_mentions("hello <@&999>", bot_id, &HashSet::new()); - assert_eq!(result, "hello @(role)"); - } - - /// Message containing only the bot mention results in empty string. 
- #[test] - fn resolve_mentions_empty_after_strip() { - let bot_id = UserId::new(111); - let result = resolve_mentions("<@111>", bot_id, &HashSet::new()); - assert_eq!(result, ""); - } - - /// Allowed role mentions are stripped from prompt (not replaced with @(role)). - #[test] - fn resolve_mentions_strips_allowed_role() { - let bot_id = UserId::new(111); - let roles: HashSet = [999].into_iter().collect(); - let result = resolve_mentions("hello <@&999> world", bot_id, &roles); - assert_eq!(result, "hello world"); - } - - /// Non-allowed role mentions are still replaced with @(role). - #[test] - fn resolve_mentions_keeps_other_roles_as_placeholder() { - let bot_id = UserId::new(111); - let roles: HashSet = [999].into_iter().collect(); - let result = resolve_mentions("<@&999> check <@&888>", bot_id, &roles); - assert_eq!(result, "check @(role)"); - } - - #[test] - fn video_attachment_block_includes_actionable_metadata() { - let block = video_attachment_block( - "demo.mp4", - Some("video/mp4"), - 12345, - "https://cdn.discordapp.com/attachments/demo.mp4", - ); - - let ContentBlock::Text { text } = block else { - panic!("video attachments must be forwarded as text metadata"); - }; - - assert!(text.contains("[Video attachment]")); - assert!(text.contains("filename: demo.mp4")); - assert!(text.contains("content_type: video/mp4")); - assert!(text.contains("size_bytes: 12345")); - assert!(text.contains("url: https://cdn.discordapp.com/attachments/demo.mp4")); - } - - // --- thread-race error detection --- - - /// Detects the Discord error code for "thread already exists" (160004). - #[test] - fn is_thread_already_exists_matches_code() { - let err = anyhow::Error::msg( - r#"HTTP error: {"code": 160004, "message": "A thread has already been created for this message."}"#, - ); - assert!(is_thread_already_exists_error(&err)); - } - - /// Detects the human-readable form of the error in case serenity renders - /// it without the numeric code. 
- #[test] - fn is_thread_already_exists_matches_message() { - let err = anyhow::anyhow!("A thread has already been created for this message."); - assert!(is_thread_already_exists_error(&err)); - } - - /// Unrelated errors do not match — we don't want the fallback path - /// swallowing real failures like permission denied. - #[test] - fn is_thread_already_exists_ignores_other_errors() { - let err = anyhow::anyhow!("Missing Permissions"); - assert!(!is_thread_already_exists_error(&err)); - let err = anyhow::anyhow!("rate limit exceeded"); - assert!(!is_thread_already_exists_error(&err)); - } - - // --- thread export helpers --- - - #[test] - fn sanitize_filename_component_keeps_safe_ascii() { - assert_eq!( - sanitize_filename_component("release notes_v2"), - "release-notes_v2" - ); - } - - #[test] - fn sanitize_filename_component_falls_back_for_empty_result() { - let result = sanitize_filename_component("///..."); - // Fallback is a YYYYMMDD date string - assert_eq!(result.len(), 8); - assert!(result.chars().all(|c| c.is_ascii_digit())); - } - - // --- assemble_export --- - // Split out from format_thread_export so we can test the truncation - // boundary without constructing serenity::model::channel::Message values. - - #[test] - fn assemble_export_empty_entries_returns_header_only() { - let (out, written, truncated) = assemble_export("HDR\n", &[], 1024); - assert_eq!(out, "HDR\n"); - assert_eq!(written, 0); - assert!(!truncated); - } - - #[test] - fn assemble_export_single_oversized_entry_writes_zero_and_marks_truncated() { - let entries = vec!["x".repeat(200)]; - let (out, written, truncated) = assemble_export("h\n", &entries, 50); - assert_eq!(written, 0); - assert!(truncated); - // Footer needs ~56 bytes; max_bytes 50 leaves ≤48 of room, so it is - // intentionally omitted (it can't be appended without exceeding the - // limit). The header is still present. 
- assert!(out.starts_with("h\n")); - assert!(!out.contains("xx")); - } - - #[test] - fn assemble_export_entry_at_exact_boundary_is_included() { - // header(2) + entry(3) == max_bytes(5); the strict-greater check - // keeps the entry in. - let (out, written, truncated) = assemble_export("h\n", &["abc".to_string()], 5); - assert_eq!(written, 1); - assert!(!truncated); - assert_eq!(out, "h\nabc"); - } - - #[test] - fn assemble_export_entry_one_byte_over_boundary_is_excluded() { - // header(2) + entry(4) == 6 > max_bytes(5); entry is dropped. - let (out, written, truncated) = assemble_export("h\n", &["abcd".to_string()], 5); - assert_eq!(written, 0); - assert!(truncated); - assert!(out.starts_with("h\n")); - assert!(!out.contains("abcd")); - } - - #[test] - fn assemble_export_appends_footer_when_room_remains() { - // First two short entries fit; the long third entry would overflow, - // and the remaining headroom is enough for the truncation footer. - let entries = vec!["a\n".to_string(), "b\n".to_string(), "c".repeat(500)]; - let (out, written, truncated) = assemble_export("h\n", &entries, 200); - assert_eq!(written, 2); - assert!(truncated); - assert!(out.contains("[Export truncated")); - } - - // --- snowflake conversion --- - - #[test] - fn timestamp_ms_to_snowflake_known_value() { - // 2026-05-10 00:00:00 UTC = 1778572800000 ms since Unix epoch - // Discord ms = 1778572800000 - 1420070400000 = 358502400000 - // Snowflake = 358502400000 << 22 = 1503238553600000000 (approx) - let ts_ms: u64 = 1_778_572_800_000; - let snowflake = timestamp_ms_to_snowflake(ts_ms); - // Verify round-trip: extract timestamp back from snowflake - let extracted_ms = (snowflake.get() >> 22) + DISCORD_EPOCH_MS; - assert_eq!(extracted_ms, ts_ms); - } - - #[test] - fn timestamp_ms_to_snowflake_at_discord_epoch_is_one() { - // At exactly the Discord epoch, discord_ms=0, shifted=0, clamped to 1 - let snowflake = timestamp_ms_to_snowflake(DISCORD_EPOCH_MS); - assert_eq!(snowflake.get(), 1); - } 
- - #[test] - fn timestamp_ms_to_snowflake_before_epoch_saturates() { - // Timestamp before Discord epoch should saturate to 1 - let snowflake = timestamp_ms_to_snowflake(1_000_000_000_000); - assert_eq!(snowflake.get(), 1); - } - - // --- ExportFilter cap logic --- - - #[test] - fn export_filter_default_cap_is_100() { - // Default (no params) uses Limit(100) - let filter = ExportFilter::Limit(100); - let cap = match &filter { - ExportFilter::Limit(n) => *n, - _ => THREAD_EXPORT_MESSAGE_LIMIT, - }; - assert_eq!(cap, 100); - } - - #[test] - fn export_filter_all_cap_is_5000() { - let filter = ExportFilter::All; - let cap = match &filter { - ExportFilter::Limit(n) => *n, - _ => THREAD_EXPORT_MESSAGE_LIMIT, - }; - assert_eq!(cap, THREAD_EXPORT_MESSAGE_LIMIT); - assert_eq!(cap, 5000); - } - - #[test] - fn export_filter_limit_uses_custom_cap() { - let filter = ExportFilter::Limit(250); - let cap = match &filter { - ExportFilter::Limit(n) => *n, - _ => THREAD_EXPORT_MESSAGE_LIMIT, - }; - assert_eq!(cap, 250); - } - - #[test] - fn export_filter_after_uses_global_cap() { - let filter = ExportFilter::After(MessageId::new(123456789)); - let cap = match &filter { - ExportFilter::Limit(n) => *n, - _ => THREAD_EXPORT_MESSAGE_LIMIT, - }; - assert_eq!(cap, THREAD_EXPORT_MESSAGE_LIMIT); - } - - // --- should_process_user_message tests (GIVEN/WHEN/THEN) --- - // Tests the multibot-mentions gating logic extracted from EventHandler::message. - // The bug in #481 was that other bots' messages were filtered by bot gating - // before multibot detection could run, so the bot never learned the thread - // was multi-bot and responded without @mention. 
- - /// GIVEN: multibot-mentions mode, single-bot thread, bot is involved - /// WHEN: human sends message without @mention - /// THEN: bot responds (natural conversation) - #[test] - fn multibot_mentions_single_bot_thread_no_mention() { - assert!(should_process_user_message( - AllowUsers::MultibotMentions, - false, // is_mentioned - true, // in_thread - true, // involved - false, // other_bot_present - )); - } - - /// GIVEN: multibot-mentions mode, multi-bot thread (other bot has posted) - /// WHEN: human sends message without @mention - /// THEN: bot does NOT respond (requires @mention in multi-bot thread) - /// This is the exact scenario from bug #481. - #[test] - fn multibot_mentions_multi_bot_thread_no_mention() { - assert!(!should_process_user_message( - AllowUsers::MultibotMentions, - false, // is_mentioned - true, // in_thread - true, // involved - true, // other_bot_present ← another bot posted - )); - } - - /// GIVEN: multibot-mentions mode, multi-bot thread - /// WHEN: human sends message WITH @mention - /// THEN: bot responds (explicit @mention always works) - #[test] - fn multibot_mentions_multi_bot_thread_with_mention() { - assert!(should_process_user_message( - AllowUsers::MultibotMentions, - true, // is_mentioned - true, // in_thread - true, // involved - true, // other_bot_present - )); - } - - /// GIVEN: multibot-mentions mode, not in a thread (main channel) - /// WHEN: human sends message without @mention - /// THEN: bot does NOT respond (main channel always requires @mention) - #[test] - fn multibot_mentions_main_channel_no_mention() { - assert!(!should_process_user_message( - AllowUsers::MultibotMentions, - false, // is_mentioned - false, // in_thread (main channel) - false, // involved - false, // other_bot_present - )); - } - - /// GIVEN: multibot-mentions mode, in thread but bot is NOT involved - /// WHEN: human sends message without @mention - /// THEN: bot does NOT respond (not participating in this thread) - #[test] - fn 
multibot_mentions_not_involved() { - assert!(!should_process_user_message( - AllowUsers::MultibotMentions, - false, // is_mentioned - true, // in_thread - false, // involved ← bot hasn't posted here - false, // other_bot_present - )); - } - - /// GIVEN: involved mode, multi-bot thread - /// WHEN: human sends message without @mention - /// THEN: bot responds (involved mode ignores multi-bot status) - #[test] - fn involved_mode_ignores_multibot() { - assert!(should_process_user_message( - AllowUsers::Involved, - false, // is_mentioned - true, // in_thread - true, // involved - true, // other_bot_present ← ignored in involved mode - )); - } - - /// GIVEN: mentions mode - /// WHEN: human sends message without @mention (even in own thread) - /// THEN: bot does NOT respond (always requires @mention) - #[test] - fn mentions_mode_always_requires_mention() { - assert!(!should_process_user_message( - AllowUsers::Mentions, - false, // is_mentioned - true, // in_thread - true, // involved - false, // other_bot_present - )); - } - - /// After soft limit fires once (n==20), subsequent bot messages still return - /// SoftLimit but with n>20. The caller warns only when n==max (exact hit), - /// preventing warning messages from ping-ponging between bots. - #[test] - fn soft_limit_warn_once_semantics() { - let mut t = BotTurnTracker::new(20); - for _ in 0..19 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - // n==20: exact hit — caller should send warning - assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(20)); - // n==21: past limit — caller should silently return (no warning) - assert_eq!(t.on_bot_message("t1"), TurnResult::Throttled); - // n==22: still past — still silent - assert_eq!(t.on_bot_message("t1"), TurnResult::Throttled); - } - - /// Hard limit also carries count for warn-once semantics. 
- #[test] - fn hard_limit_warn_once_semantics() { - let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); // soft > hard so hard fires first - for _ in 0..HARD_BOT_TURN_LIMIT - 1 { - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - } - // Exact hit — warn - assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit); - // Past — silent - assert_eq!(t.on_bot_message("t1"), TurnResult::Stopped); - } - - /// Regression test for #497: system messages (thread created, pin, etc.) - /// should NOT reset the bot turn counter. The filtering happens at the - /// call site (MessageType check); this verifies the counter stays put - /// when on_human_message is never called. - #[test] - fn system_message_does_not_reset_counter() { - let mut t = BotTurnTracker::new(3); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); - // No on_human_message (system message filtered out at call site) - assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); - } - - // --- build_sender_context tests (regression for #581 → #584) --- - // PR #583 fixed SenderContext to use parent channel_id when in a thread. - // These tests verify the pure function extracted from EventHandler::message. - - /// In-thread message: channel_id = parent, thread_id = thread channel ID. - #[test] - fn build_sender_context_in_thread() { - let ctx = build_sender_context( - "user1", - "alice", - "Alice", - "thread_ch", - Some("parent_ch"), - false, - "2026-05-01T00:00:00Z", - "msg123", - "bot99", - ); - assert_eq!(ctx.channel_id, "parent_ch"); - assert_eq!(ctx.thread_id, Some("thread_ch".to_string())); - assert_eq!(ctx.channel, "discord"); - assert_eq!(ctx.sender_id, "user1"); - assert!(!ctx.is_bot); - assert_eq!(ctx.receiver_id, Some("bot99".to_string())); - } - - /// Non-thread message: channel_id = message channel, thread_id = None. 
- #[test] - fn build_sender_context_not_in_thread() { - let ctx = build_sender_context( - "user1", - "alice", - "Alice", - "main_ch", - None, - false, - "2026-05-01T00:00:00Z", - "msg456", - "bot99", - ); - assert_eq!(ctx.channel_id, "main_ch"); - assert_eq!(ctx.thread_id, None); - } - - /// Bot sender: is_bot flag propagated correctly. - #[test] - fn build_sender_context_bot_sender() { - let ctx = build_sender_context( - "bot1", - "mybot", - "MyBot", - "ch", - Some("parent"), - true, - "2026-05-01T00:00:00Z", - "msg789", - "bot99", - ); - assert!(ctx.is_bot); - assert_eq!(ctx.channel_id, "parent"); - assert_eq!(ctx.thread_id, Some("ch".to_string())); - } - - // --- detect_thread tests (regression for #506 → #518 → #519) --- - // PR #506 used parent_id.is_some() to detect threads, but category text - // channels also have parent_id (pointing to the category). This caused - // the bot to skip thread creation for normal channels inside categories. - // - // detect_thread() uses thread_metadata.is_some() — the canonical check - // per Discord API docs. Table-driven to cover all channel scenarios. - - const BOT: u64 = 1000; - const OTHER: u64 = 2000; - const PARENT_CH: u64 = 100; - const CATEGORY: u64 = 200; - - /// Helper: build an allowed_channels set from a slice. - fn allowed(ids: &[u64]) -> HashSet { - ids.iter().copied().collect() - } - - /// Table-driven: each row is a realistic Discord channel scenario. 
- #[test] - fn detect_thread_table() { - struct Case { - name: &'static str, - has_thread_metadata: bool, - parent_id: Option, - owner_id: Option, - bot_id: u64, - allowed_channels: HashSet, - allow_all: bool, - in_allowed: bool, - expect: (bool, Option), // (in_thread, bot_owns) - } - - let cases = vec![ - // --- Non-thread channels: thread_metadata = None --- - Case { - name: "text channel under category (regression #506)", - has_thread_metadata: false, - parent_id: Some(CATEGORY), // points to category, NOT a thread - owner_id: None, - bot_id: BOT, - allowed_channels: allowed(&[]), - allow_all: false, - in_allowed: true, - expect: (false, None), - }, - Case { - name: "top-level text channel (no category)", - has_thread_metadata: false, - parent_id: None, - owner_id: None, - bot_id: BOT, - allowed_channels: allowed(&[]), - allow_all: false, - in_allowed: true, - expect: (false, None), - }, - Case { - name: "voice channel under category", - has_thread_metadata: false, - parent_id: Some(CATEGORY), - owner_id: None, - bot_id: BOT, - allowed_channels: allowed(&[]), - allow_all: false, - in_allowed: false, - expect: (false, None), - }, - // --- Thread channels: thread_metadata = Some --- - Case { - name: "public thread, parent in allowlist, bot owns", - has_thread_metadata: true, - parent_id: Some(PARENT_CH), - owner_id: Some(BOT), - bot_id: BOT, - allowed_channels: allowed(&[PARENT_CH]), - allow_all: false, - in_allowed: false, - expect: (true, Some(true)), - }, - Case { - name: "public thread, parent in allowlist, other user owns", - has_thread_metadata: true, - parent_id: Some(PARENT_CH), - owner_id: Some(OTHER), - bot_id: BOT, - allowed_channels: allowed(&[PARENT_CH]), - allow_all: false, - in_allowed: false, - expect: (true, Some(false)), - }, - Case { - name: "thread, parent NOT in allowlist, not allow_all", - has_thread_metadata: true, - parent_id: Some(PARENT_CH), - owner_id: Some(BOT), - bot_id: BOT, - allowed_channels: allowed(&[]), - allow_all: false, - 
in_allowed: false, - expect: (false, Some(true)), - }, - Case { - name: "thread, allow_all_channels = true", - has_thread_metadata: true, - parent_id: Some(PARENT_CH), - owner_id: Some(OTHER), - bot_id: BOT, - allowed_channels: allowed(&[]), - allow_all: true, - in_allowed: false, - expect: (true, Some(false)), - }, - Case { - name: "thread, in_allowed_channel = true (parent is the allowed channel)", - has_thread_metadata: true, - parent_id: Some(PARENT_CH), - owner_id: None, - bot_id: BOT, - allowed_channels: allowed(&[]), - allow_all: false, - in_allowed: true, - expect: (true, Some(false)), - }, - // --- Defensive: partial data --- - Case { - name: "thread with parent_id = None (defensive, partial API data)", - has_thread_metadata: true, - parent_id: None, - owner_id: Some(BOT), - bot_id: BOT, - allowed_channels: allowed(&[PARENT_CH]), - allow_all: false, - in_allowed: false, - expect: (false, Some(true)), // can't verify parent → not allowed, but bot still owns - }, - ]; - - for c in &cases { - let result = detect_thread( - c.has_thread_metadata, - c.parent_id, - c.owner_id, - c.bot_id, - &c.allowed_channels, - c.allow_all, - c.in_allowed, - ); - assert_eq!(result, c.expect, "FAILED: {}", c.name); - } - } - - // --- WarnAndStop regression test (#633) --- - // The WarnAndStop path now delegates to detect_thread(). This test pins - // the exact scenario from #633: a category child channel whose category - // ID is in another bot's allowed_channels must NOT be treated as allowed. - #[test] - fn detect_thread_rejects_category_child_in_warn_and_stop() { - let category_id: u64 = 200; - let allowed = HashSet::from([category_id]); - // Category child: has parent_id (the category) but NO thread_metadata. 
- let (in_thread, _) = - detect_thread(false, Some(category_id), None, 1000, &allowed, false, false); - assert!( - !in_thread, - "category child must not match allowed_channels via parent_id" - ); - } - - // --- Per-thread streaming tests (#534) --- - // Streaming ON by default, OFF when another bot is detected in the thread. - - /// Single bot thread: streaming enabled. - #[test] - fn discord_streams_when_no_other_bot() { - let adapter = super::DiscordAdapter::new(Arc::new(super::Http::new(""))); - assert!(adapter.use_streaming(false)); - } - - /// Multi-bot thread: send-once to avoid edit interference. - #[test] - fn discord_no_stream_when_other_bot_present() { - let adapter = super::DiscordAdapter::new(Arc::new(super::Http::new(""))); - assert!(!adapter.use_streaming(true)); - } - - // --- resolve_channel tests --- - - #[test] - fn resolve_channel_uses_channel_id_when_no_thread() { - let ch = ChannelRef { - platform: "discord".into(), - channel_id: "111".into(), - thread_id: None, - parent_id: None, - origin_event_id: None, - }; - assert_eq!(DiscordAdapter::resolve_channel(&ch), "111"); - } - - #[test] - fn resolve_channel_prefers_thread_id_when_set() { - let ch = ChannelRef { - platform: "discord".into(), - channel_id: "111".into(), - thread_id: Some("222".into()), - parent_id: None, - origin_event_id: None, - }; - assert_eq!(DiscordAdapter::resolve_channel(&ch), "222"); - } - - // --- is_denied_user tests (regression for #604) --- - - /// Human not in allowlist → denied. - #[test] - fn denied_user_human_not_in_allowlist() { - let allowed = HashSet::from([100]); - assert!(is_denied_user(false, false, &allowed, 999)); - } - - /// Human in allowlist → allowed. - #[test] - fn denied_user_human_in_allowlist() { - let allowed = HashSet::from([100]); - assert!(!is_denied_user(false, false, &allowed, 100)); - } - - /// Bot not in allowlist → allowed (bots skip user gate). This is the #604 fix. 
- #[test] - fn denied_user_bot_skips_allowlist() { - let allowed = HashSet::from([100]); - assert!(!is_denied_user(true, false, &allowed, 999)); - } - - // --- Trusted bot mention bypass tests --- - // A trusted bot @mentioning this bot bypasses allow_bot_messages mode, - // treating the mention the same as a human @mention. - - /// GIVEN: trusted bot @mentions this bot - /// THEN: bypass is granted (treated as human mention) - #[test] - fn trusted_bot_mention_bypasses_gate() { - let trusted = HashSet::from([42]); - assert!(is_trusted_bot_mention(true, &trusted, 42)); - } - - /// GIVEN: untrusted bot @mentions this bot - /// THEN: no bypass (normal bot gating applies) - #[test] - fn untrusted_bot_mention_no_bypass() { - let trusted = HashSet::from([42]); - assert!(!is_trusted_bot_mention(true, &trusted, 99)); - } - - /// GIVEN: trusted bot sends message WITHOUT @mention - /// THEN: no bypass (must explicitly @mention) - #[test] - fn trusted_bot_no_mention_no_bypass() { - let trusted = HashSet::from([42]); - assert!(!is_trusted_bot_mention(false, &trusted, 42)); - } - - /// GIVEN: empty trusted_bot_ids (feature not configured) - /// THEN: no bypass regardless of mention - #[test] - fn empty_trusted_ids_no_bypass() { - let trusted: HashSet = HashSet::new(); - assert!(!is_trusted_bot_mention(true, &trusted, 42)); - } - - // --- Trusted bot admission integration tests --- - // These test the full bot gating decision path: allow_bot_messages mode + - // trusted_bot_ids + trusted mention bypass, mirroring the actual logic in - // EventHandler::message. - - /// Simulates the bot admission decision from EventHandler::message. - /// Returns `true` if the bot message would be processed (not dropped). 
- fn should_admit_bot_message( - allow_bot_messages: AllowBots, - is_mentioned: bool, - trusted_bot_ids: &HashSet, - author_id: u64, - ) -> bool { - let trusted_mention = is_mentioned - && !trusted_bot_ids.is_empty() - && trusted_bot_ids.contains(&author_id); - - if !trusted_mention { - match allow_bot_messages { - AllowBots::Off => return false, - AllowBots::Mentions => { - if !is_mentioned { - return false; - } - } - AllowBots::All => {} // would check consecutive cap, skip for unit test - } - - if !trusted_bot_ids.is_empty() && !trusted_bot_ids.contains(&author_id) { - return false; - } - } - true - } - - /// GIVEN: allow_bot_messages=Off, trusted bot @mentions this bot - /// THEN: admitted (trusted mention overrides Off mode) - #[test] - fn bot_admission_trusted_mention_overrides_off() { - let trusted = HashSet::from([42]); - assert!(should_admit_bot_message(AllowBots::Off, true, &trusted, 42)); - } - - /// GIVEN: allow_bot_messages=Off, untrusted bot @mentions this bot - /// THEN: rejected (Off mode blocks) - #[test] - fn bot_admission_untrusted_mention_blocked_by_off() { - let trusted = HashSet::from([42]); - assert!(!should_admit_bot_message(AllowBots::Off, true, &trusted, 99)); - } - - /// GIVEN: allow_bot_messages=Off, trusted bot without @mention - /// THEN: rejected (no mention = no bypass) - #[test] - fn bot_admission_trusted_no_mention_blocked_by_off() { - let trusted = HashSet::from([42]); - assert!(!should_admit_bot_message(AllowBots::Off, false, &trusted, 42)); - } - - /// GIVEN: allow_bot_messages=Off, empty trusted_bot_ids, bot @mentions - /// THEN: rejected (feature not configured) - #[test] - fn bot_admission_empty_trusted_ids_off_mode() { - let trusted: HashSet = HashSet::new(); - assert!(!should_admit_bot_message(AllowBots::Off, true, &trusted, 42)); - } - - /// GIVEN: allow_bot_messages=Mentions, trusted bot @mentions - /// THEN: admitted (would pass anyway, but bypass also works) - #[test] - fn bot_admission_mentions_mode_trusted_mention() { 
- let trusted = HashSet::from([42]); - assert!(should_admit_bot_message(AllowBots::Mentions, true, &trusted, 42)); - } - - /// GIVEN: allow_bot_messages=All, untrusted bot (not in trusted_bot_ids) - /// THEN: rejected by trusted_bot_ids filter - #[test] - fn bot_admission_all_mode_untrusted_bot_rejected() { - let trusted = HashSet::from([42]); - assert!(!should_admit_bot_message(AllowBots::All, false, &trusted, 99)); - } - - // --- DM gating tests (#656) --- - // DMs are gated by `allow_dm` config. When allowed, DMs bypass - // `allowed_channels` and treat the message as implicit @mention. - - /// GIVEN: allow_dm = false - /// WHEN: user sends a DM - /// THEN: DM is rejected - #[test] - fn dm_rejected_when_allow_dm_false() { - assert!(!should_process_dm(false)); - } - - /// GIVEN: allow_dm = true - /// WHEN: user sends a DM - /// THEN: DM is accepted - #[test] - fn dm_accepted_when_allow_dm_true() { - assert!(should_process_dm(true)); - } - - /// GIVEN: allow_dm = true, user NOT in allowed_users - /// WHEN: user sends a DM - /// THEN: user is denied (allowed_users still enforced in DMs) - #[test] - fn dm_denied_user_still_enforced() { - let allowed = HashSet::from([100]); - // DM passes allow_dm gate, but user gate still applies - assert!(should_process_dm(true)); - assert!(is_denied_user(false, false, &allowed, 999)); - } - - /// GIVEN: allow_dm = true, user in allowed_users - /// WHEN: user sends a DM - /// THEN: user is allowed - #[test] - fn dm_allowed_user_passes() { - let allowed = HashSet::from([100]); - assert!(should_process_dm(true)); - assert!(!is_denied_user(false, false, &allowed, 100)); - } - - /// DMs are treated as implicit @mention — should_process_user_message - /// is never called for DMs (the `!is_dm` guard skips it). - /// This test verifies the Involved mode would reject a non-thread, - /// non-mentioned message — confirming DMs MUST bypass this check. 
- #[test] - fn dm_must_bypass_user_message_gating() { - // Without the `!is_dm` bypass, a DM would be rejected by Involved mode - // because is_mentioned=false and in_thread=false. - assert!(!should_process_user_message( - AllowUsers::Involved, - false, // is_mentioned (DMs don't have @mention) - false, // in_thread (DMs are not threads) - false, // involved - false, // other_bot_present - )); - } - - // --- Thread creation skip tests (regression for #656 DM bug) --- - // Pins the invariant: DMs must never call get_or_create_thread(). - // Discord DM channels do not support thread creation. - - /// GIVEN: is_dm = true, not in a thread - /// THEN: skip thread creation (use DM channel directly) - #[test] - fn dm_skips_thread_creation() { - assert!(should_skip_thread_creation(false, true)); - } - - /// GIVEN: already in a thread, not a DM - /// THEN: skip thread creation (reuse existing thread) - #[test] - fn existing_thread_skips_thread_creation() { - assert!(should_skip_thread_creation(true, false)); - } - - /// GIVEN: not in a thread, not a DM (normal channel message) - /// THEN: do NOT skip — create a new thread - #[test] - fn normal_channel_creates_thread() { - assert!(!should_skip_thread_creation(false, false)); - } - - // --- WarnAndStop dedup tests (#530) --- - - #[test] - fn dedup_detects_existing_bot_warning() { - let msg = format!("{} (20/20). A human must reply.", BOT_TURN_LIMIT_WARNING_PREFIX); - assert!(turn_limit_warning_present(&[(true, &msg)])); - } - - #[test] - fn dedup_ignores_human_warning_text() { - let msg = format!("{} (20/20). 
A human must reply.", BOT_TURN_LIMIT_WARNING_PREFIX); - assert!(!turn_limit_warning_present(&[(false, &msg)])); - } - - #[test] - fn dedup_returns_false_when_no_warning() { - assert!(!turn_limit_warning_present(&[(true, "hello"), (false, "world")])); - } - - #[test] - fn dedup_returns_false_for_empty_messages() { - assert!(!turn_limit_warning_present(&[])); - } -} diff --git a/src/dispatch.rs b/src/dispatch.rs deleted file mode 100644 index 97d5f25e3..000000000 --- a/src/dispatch.rs +++ /dev/null @@ -1,1727 +0,0 @@ -//! Turn-boundary message batching dispatcher. -//! -//! See ADR: docs/adr/turn-boundary-batching.md for full design rationale. -//! -//! # Invariants -//! - I1: First message after idle has zero added latency. -//! - I2: At most one in-flight ACP turn per thread. -//! - I3: Broker structural fidelity — no merging, splitting, reordering, or -//! semantic transformation of arrival events. - -use std::collections::HashMap; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; -use std::time::{Duration, Instant}; - -use anyhow::Result; -use async_trait::async_trait; -use tracing::{debug, error, info, info_span, warn}; - -use crate::acp::ContentBlock; -use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef}; -use crate::config::ReactionsConfig; -use crate::error_display::format_user_error; -use crate::reactions::StatusReactionController; - -// --------------------------------------------------------------------------- -// Public types -// --------------------------------------------------------------------------- - -/// One arrival event buffered for a future ACP turn. -pub struct BufferedMessage { - /// Serialised SenderContext JSON (already built by the platform adapter). - pub sender_json: String, - /// Author display name — denormalised from `sender_json` so observability - /// fields (per-event tracing in `dispatch_batch`) don't pay a JSON parse. - /// Per ADR §2.3 each arrival event carries its sender name. 
- pub sender_name: String, - /// User-visible prompt text (verbatim, never transformed). - pub prompt: String, - /// Attachment blocks (images, STT transcripts) in arrival order. - pub extra_blocks: Vec, - /// Anchor for reactions (👀 / ❌). - pub trigger_msg: MessageRef, - /// Broker receive time — used for `buffer_wait_ms` observability. - pub arrived_at: Instant, - /// Rough token estimate for `max_batch_tokens` cap. - pub estimated_tokens: usize, - /// Snapshot at submit time. Captured per-message so a batch reflects the - /// freshest known state; `dispatch_batch` reads `batch.last()`. - pub other_bot_present: bool, - /// Slack streaming recipient `(user_id, team_id)` for `chat.startStream`, - /// captured at message-arrival time (after allow-list) and bound to this turn - /// — no shared thread cache, so no cross-turn race. Populated for real-user - /// Slack turns regardless of `assistant_mode`; only *consumed* when assistant - /// mode's native streaming is active. `None` for non-Slack platforms and - /// bot-authored turns. - pub recipient: Option<(String, String)>, -} - -/// How `thread_key` is built for the dispatcher's per-thread map. -/// -/// - `Thread`: one mpsc per thread → all senders in a thread share one batch → one -/// ACP turn per batch (cheaper, but risks silent drop when the agent's single reply -/// forgets to address some senders). -/// - `Lane`: one mpsc per (thread, sender) → each sender batches independently and -/// gets a dedicated ACP turn. Sessions are still shared per-thread; turns serialise -/// through the shared session. -/// -/// Derived from `config::MessageProcessingMode` in `main.rs`. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum BatchGrouping { - Thread, - Lane, -} - -/// Error returned by `Dispatcher::submit`. -#[derive(Debug)] -pub enum DispatchError { - /// The per-thread consumer task has exited unexpectedly. 
- ConsumerDead, -} - -impl std::fmt::Display for DispatchError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::ConsumerDead => write!(f, "dispatch consumer exited unexpectedly"), - } - } -} - -impl std::error::Error for DispatchError {} - -// --------------------------------------------------------------------------- -// Internal types -// --------------------------------------------------------------------------- - -struct ThreadHandle { - tx: tokio::sync::mpsc::Sender, - consumer: tokio::task::JoinHandle<()>, - /// Race-safe eviction counter (§2.5). Plain u64 — all reads/writes under per_thread lock. - generation: u64, - channel_id: String, - adapter_kind: String, -} - -impl ThreadHandle { - /// Approximate number of messages still buffered in the mpsc — used for - /// shutdown / cancel logging. Not exact: tokio's mpsc has no sync `.len()`. - fn pending_count(&self) -> usize { - self.tx.max_capacity() - self.tx.capacity() - } -} - -// --------------------------------------------------------------------------- -// DispatchTarget — trait seam between Dispatcher and AdapterRouter -// --------------------------------------------------------------------------- - -/// Surface that `consumer_loop` / `dispatch_batch` need from the underlying -/// router. Extracted as a trait so the dispatcher can be unit-tested without -/// spinning up a real `SessionPool` (which forks ACP CLI subprocesses). -/// `AdapterRouter` is the production implementor; tests use a mock that -/// records calls. -#[async_trait] -pub trait DispatchTarget: Send + Sync + 'static { - fn reactions_config(&self) -> &ReactionsConfig; - - /// Workspace aliases from config (for `[[ws:@alias]]` resolution). - fn workspace_aliases(&self) -> std::collections::HashMap; - - /// Bot home directory (security boundary for workspace resolution). - fn bot_home(&self) -> std::path::PathBuf; - - /// Ensure the ACP session for `session_key` exists (idempotent). 
- /// Returns `true` if a new session was created, `false` if it already existed. - async fn ensure_session(&self, session_key: &str, working_dir: Option<&str>) -> Result; - - /// Destroy the session for `session_key` (used to rollback on directive failure). - async fn reset_session(&self, session_key: &str); - - /// Drive one ACP turn with the pre-packed `content_blocks`. - #[allow(clippy::too_many_arguments)] - async fn stream_prompt_blocks( - &self, - adapter: &Arc, - session_key: &str, - content_blocks: Vec, - thread_channel: &ChannelRef, - reactions: Arc, - other_bot_present: bool, - recipient: Option<(String, String)>, - ) -> Result<()>; -} - -#[async_trait] -impl DispatchTarget for AdapterRouter { - fn reactions_config(&self) -> &ReactionsConfig { - AdapterRouter::reactions_config(self) - } - - fn workspace_aliases(&self) -> std::collections::HashMap { - self.workspace_aliases_map() - } - - fn bot_home(&self) -> std::path::PathBuf { - self.bot_home_path() - } - - async fn ensure_session(&self, session_key: &str, working_dir: Option<&str>) -> Result { - self.pool().get_or_create(session_key, working_dir).await - } - - async fn reset_session(&self, session_key: &str) { - let _ = self.pool().reset_session(session_key).await; - } - - async fn stream_prompt_blocks( - &self, - adapter: &Arc, - session_key: &str, - content_blocks: Vec, - thread_channel: &ChannelRef, - reactions: Arc, - other_bot_present: bool, - recipient: Option<(String, String)>, - ) -> Result<()> { - AdapterRouter::stream_prompt_blocks( - self, - adapter, - session_key, - content_blocks, - thread_channel, - reactions, - other_bot_present, - recipient, - ) - .await - } -} - -// --------------------------------------------------------------------------- -// Dispatcher -// --------------------------------------------------------------------------- - -/// Default idle timeout for per-thread consumer tasks in batched modes (Thread / Lane). 
-/// When no message arrives within this window the consumer exits, allowing `per_thread` -/// map cleanup on the next `submit` (via `SendError` → `try_evict_locked`). Prevents -/// unbounded task/memory growth from one-shot thread keys (e.g. Slack non-thread messages). -/// -/// Batched modes need a longer window so a lane that's between trigger arrivals isn't -/// torn down and respawned on every message. -pub const DEFAULT_CONSUMER_IDLE_TIMEOUT: Duration = Duration::from_secs(300); - -/// Idle timeout for per-message mode (cap=1, no batching). Per-message dispatchers -/// don't benefit from holding consumers across message gaps — there is no batch -/// window to preserve — so a much shorter timeout reduces idle resource footprint -/// from one-shot thread keys (Little's Law: steady-state idle count = arrival rate -/// × idle window). -pub const PER_MESSAGE_CONSUMER_IDLE_TIMEOUT: Duration = Duration::from_secs(10); - -/// Resolve `(cap, grouping, idle_timeout)` for a given processing mode. -/// -/// Per-message mode forces cap=1 + Thread grouping + the short per-message idle -/// (one-shot threads shouldn't pin a consumer for 5 min); Thread / Lane modes -/// use the configured `max_buffered` and the default idle window. -pub fn dispatch_params( - mode: &crate::config::MessageProcessingMode, - max_buffered: usize, -) -> (usize, BatchGrouping, Duration) { - use crate::config::MessageProcessingMode; - match mode { - MessageProcessingMode::Message => { - (1, BatchGrouping::Thread, PER_MESSAGE_CONSUMER_IDLE_TIMEOUT) - } - MessageProcessingMode::Thread => ( - max_buffered, - BatchGrouping::Thread, - DEFAULT_CONSUMER_IDLE_TIMEOUT, - ), - MessageProcessingMode::Lane => ( - max_buffered, - BatchGrouping::Lane, - DEFAULT_CONSUMER_IDLE_TIMEOUT, - ), - } -} - -/// Per-thread message dispatcher for batched mode. -/// -/// Constructed once in `main.rs` and shared via `Arc`. Platform adapters call -/// `submit()` from their per-message `tokio::spawn`'d tasks. 
-pub struct Dispatcher { - /// std::sync::Mutex — critical section has no .await; tokio::Mutex buys nothing here. - per_thread: Mutex>, - /// Monotonic counter for `ThreadHandle.generation` (§2.5). Pre-fetched on - /// every `submit` and consumed only when a fresh handle is inserted; wasted - /// values are fine because generations need only be monotonic, not contiguous. - next_generation: AtomicU64, - target: Arc, - max_buffered_messages: usize, - max_batch_tokens: usize, - grouping: BatchGrouping, - idle_timeout: Duration, -} - -impl Dispatcher { - /// Construct a dispatcher with an explicit consumer idle timeout. Per-mode - /// callers in `main.rs` pass `PER_MESSAGE_CONSUMER_IDLE_TIMEOUT` for cap=1 - /// dispatchers and `DEFAULT_CONSUMER_IDLE_TIMEOUT` for batched modes. - pub fn with_idle_timeout( - target: Arc, - max_buffered_messages: usize, - max_batch_tokens: usize, - grouping: BatchGrouping, - idle_timeout: Duration, - ) -> Self { - Self { - per_thread: Mutex::new(HashMap::new()), - next_generation: AtomicU64::new(0), - target, - max_buffered_messages, - max_batch_tokens, - grouping, - idle_timeout, - } - } - - /// Build the dispatcher key for a (platform, thread, sender) tuple. - /// - /// In `Thread` mode the sender is ignored; in `Lane` mode the sender is appended - /// so each (thread, sender) pair gets its own mpsc and consumer. - /// - /// Note: this is the *dispatcher* key, not the *session pool* key. Session pool keys - /// are always `:` regardless of grouping (the ACP session is - /// shared per-thread by design). - pub fn key(&self, platform: &str, thread_id: &str, sender_id: &str) -> String { - match self.grouping { - BatchGrouping::Thread => format!("{platform}:{thread_id}"), - BatchGrouping::Lane => format!("{platform}:{thread_id}:{sender_id}"), - } - } - - /// Build the shared session pool key for a routed channel. - /// - /// Unlike dispatcher keys, session keys never include sender identity. 
- /// They track the logical conversation thread across all grouping modes. - fn session_key(thread_channel: &ChannelRef) -> String { - let logical_thread_id = thread_channel - .thread_id - .as_deref() - .unwrap_or(&thread_channel.channel_id); - format!("{}:{}", thread_channel.platform, logical_thread_id) - } - - /// Submit one arrival event for the given thread. - /// - /// - If the thread has no active consumer, one is spawned lazily. - /// - If the channel is full, this future parks until space is available - /// (backpressure — no data loss, no error). - /// - If the consumer has died (`SendError`), surfaces ❌ + ⚠️ and returns - /// `Err(DispatchError::ConsumerDead)` (§2.5). - /// - /// `adapter` is passed per-call (not stored on `Dispatcher`) because the - /// Discord adapter is constructed inside serenity's `ready` callback via - /// `OnceLock` — after the Dispatcher is built in `main.rs`. - pub async fn submit( - &self, - thread_key: String, - thread_channel: ChannelRef, - adapter: Arc, - msg: BufferedMessage, - ) -> Result<(), DispatchError> { - let cap = self.max_buffered_messages; - let target = Arc::clone(&self.target); - let max_tokens = self.max_batch_tokens; - let idle_timeout = self.idle_timeout; - - // Pre-fetch a generation in case we end up inserting a fresh handle. - // Wasted if the entry already exists; generations need only be monotonic. - let next_g = self.next_generation.fetch_add(1, Ordering::Relaxed); - - let (tx, my_generation) = { - // SAFETY: no .await while this guard is held — guard drops at end of block. - let mut map = self.per_thread.lock().unwrap(); - - // Proactive stale-entry cleanup: if the consumer has exited (idle - // timeout or unexpected), remove the entry so `or_insert_with` - // creates a fresh one. Prevents map leak from one-shot thread keys - // and avoids the first-message-after-idle being treated as an error. 
- if let Some(handle) = map.get(&thread_key) { - if handle.consumer.is_finished() { - map.remove(&thread_key); - } - } - - let entry = map.entry(thread_key.clone()).or_insert_with(|| { - let (tx, rx) = tokio::sync::mpsc::channel(cap); - let consumer = tokio::spawn(consumer_loop( - thread_key.clone(), - thread_channel.clone(), - rx, - Arc::clone(&target), - Arc::clone(&adapter), - cap, - max_tokens, - idle_timeout, - )); - ThreadHandle { - tx, - consumer, - generation: next_g, - channel_id: thread_channel.channel_id.clone(), - adapter_kind: adapter.platform().to_string(), - } - }); - (entry.tx.clone(), entry.generation) - }; - - if let Err(e) = tx.send(msg).await { - // Consumer has exited between our check and the send — race-safe - // eviction under lock (§2.5), then transparent retry once. - // - // Safe to re-acquire `per_thread` here: the first lock guard above - // was dropped before `tx.send().await`, so this acquisition cannot - // deadlock against the await point. The same property holds for the - // retry acquisition below. - { - // SAFETY: no .await while this guard is held. - let mut map = self.per_thread.lock().unwrap(); - Self::try_evict_locked(&mut map, &thread_key, my_generation); - } - let failed_msg = e.0; - - // Retry: spawn a fresh consumer and re-send. If this also fails, - // surface the error to the user. - let retry_g = self.next_generation.fetch_add(1, Ordering::Relaxed); - let (retry_tx, retry_gen) = { - // SAFETY: no .await while this guard is held — guard drops at end of block. 
- let mut map = self.per_thread.lock().unwrap(); - let entry = map.entry(thread_key.clone()).or_insert_with(|| { - let (tx, rx) = tokio::sync::mpsc::channel(cap); - let consumer = tokio::spawn(consumer_loop( - thread_key.clone(), - thread_channel.clone(), - rx, - Arc::clone(&target), - Arc::clone(&adapter), - cap, - max_tokens, - idle_timeout, - )); - ThreadHandle { - tx, - consumer, - generation: retry_g, - channel_id: thread_channel.channel_id.clone(), - adapter_kind: adapter.platform().to_string(), - } - }); - (entry.tx.clone(), entry.generation) - }; - - if let Err(e2) = retry_tx.send(failed_msg).await { - // Retry also failed — truly unexpected. Surface error. - { - // SAFETY: no .await while this guard is held. - let mut map = self.per_thread.lock().unwrap(); - Self::try_evict_locked(&mut map, &thread_key, retry_gen); - } - let failed_msg = e2.0; - let _ = adapter - .add_reaction( - &failed_msg.trigger_msg, - &self.target.reactions_config().emojis.error, - ) - .await; - let _ = adapter - .send_message( - &thread_channel, - &format!( - "⚠️ {}", - format_user_error("dispatch consumer exited unexpectedly") - ), - ) - .await; - return Err(DispatchError::ConsumerDead); - } - } - Ok(()) - } - - /// Drop all per-thread handles whose key belongs to `(platform, thread_id)`, - /// regardless of grouping, and abort each consumer (§2.5 / §4.4). Returns - /// the total number of buffered messages discarded across all lanes. - /// - /// Matches both Thread keys (`:`) and Lane keys - /// (`::`). Used by `/reset` and - /// `/cancel-all` to clear the entire thread, not just one lane. - /// - /// Disjoint from SendError recovery: removal happens *before* abort, so any - /// fresh `submit` after this returns lands on a lazily-constructed new handle - /// instead of observing `SendError`. 
- pub fn cancel_buffered_thread(&self, platform: &str, thread_id: &str) -> usize { - let prefix = format!("{platform}:{thread_id}"); - let lane_prefix = format!("{prefix}:"); - // SAFETY: no .await while this guard is held — function is sync. - let mut map = self.per_thread.lock().unwrap(); - let keys: Vec = map - .keys() - .filter(|k| k.as_str() == prefix || k.starts_with(&lane_prefix)) - .cloned() - .collect(); - let mut dropped = 0; - for k in keys { - if let Some(handle) = map.remove(&k) { - dropped += handle.pending_count(); - handle.consumer.abort(); - } - } - dropped - } - - /// §2.5 race-safe eviction. Caller must hold the `per_thread` mutex. - /// Removes the entry only if its generation matches `my_generation` — - /// protects against evicting a fresh handle that another `submit` lazily - /// inserted between this caller's failed `tx.send` and this call. - /// Returns true if the entry was removed. - fn try_evict_locked( - map: &mut HashMap, - thread_key: &str, - my_generation: u64, - ) -> bool { - if let Some(handle) = map.get(thread_key) { - if handle.generation == my_generation { - map.remove(thread_key); - return true; - } - } - false - } - - /// Remove map entries whose consumer task has finished (idle timeout or - /// unexpected exit). Called periodically from the cleanup task in main.rs - /// to prevent unbounded map growth from one-shot thread keys that never - /// receive a second `submit()`. Returns the number of entries swept. - pub fn sweep_stale(&self) -> usize { - // SAFETY: no .await while this guard is held — function is sync. - let mut map = self.per_thread.lock().unwrap(); - let before = map.len(); - map.retain(|_, handle| !handle.consumer.is_finished()); - before - map.len() - } - - /// Log buffered-message counts and drop all handles (called on SIGTERM). - pub fn shutdown(&self) { - // SAFETY: no .await while this guard is held — function is sync. 
- let mut map = self.per_thread.lock().unwrap(); - for (thread_id, handle) in map.iter() { - let pending = handle.pending_count(); - if pending > 0 { - warn!( - thread_id = %thread_id, - channel = %handle.channel_id, - adapter = %handle.adapter_kind, - buffered_lost = pending, - "shutdown dropped pending messages without dispatch", - ); - } - handle.consumer.abort(); - } - map.clear(); - } -} - -// --------------------------------------------------------------------------- -// consumer_loop -// --------------------------------------------------------------------------- - -#[allow(clippy::too_many_arguments)] -async fn consumer_loop( - thread_key: String, - thread_channel: ChannelRef, - mut rx: tokio::sync::mpsc::Receiver, - target: Arc, - adapter: Arc, - max_batch: usize, - max_tokens: usize, - idle_timeout: Duration, -) { - // `pending` holds a message that exceeded the token cap for the current batch; - // it becomes the first message of the next batch, preserving FIFO. - let mut pending: Option = None; - - loop { - // I1: block until at least one message arrives (zero latency for first message). - // Idle timeout: if no message arrives within `idle_timeout` the consumer - // exits, freeing the task and mpsc. The next `submit` for this thread_key - // will observe `SendError`, evict the stale entry, and lazily spawn a - // fresh consumer (§2.5 generation check prevents mis-eviction). - let first = match pending.take() { - Some(msg) => msg, - None => match tokio::time::timeout(idle_timeout, rx.recv()).await { - Ok(Some(msg)) => msg, - Ok(None) => { - // All senders dropped → shutdown() or cancel_buffered_thread(). - break; - } - Err(_elapsed) => { - debug!( - thread_key = %thread_key, - channel = %thread_channel.channel_id, - "consumer idle timeout, exiting" - ); - break; - } - }, - }; - - // Greedy drain up to max_batch messages or max_tokens. 
- let mut batch = vec![first]; - let mut cumulative_tokens = batch[0].estimated_tokens; - - while batch.len() < max_batch { - match rx.try_recv() { - Ok(more) => { - if cumulative_tokens + more.estimated_tokens > max_tokens { - // Token cap — save for next turn (FIFO preserved). - pending = Some(more); - break; - } - cumulative_tokens += more.estimated_tokens; - batch.push(more); - } - Err(_) => break, - } - } - - // §2.6: read the freshest snapshot in the batch (batch is non-empty). - let bot_present = batch.last().unwrap().other_bot_present; - - dispatch_batch( - &thread_key, - &thread_channel, - &target, - &adapter, - batch, - bot_present, - ) - .await; - } -} - -// --------------------------------------------------------------------------- -// dispatch_batch -// --------------------------------------------------------------------------- - -async fn dispatch_batch( - thread_key: &str, - thread_channel: &ChannelRef, - target: &Arc, - adapter: &Arc, - batch: Vec, - other_bot_present: bool, -) { - let dispatch_start = Instant::now(); - let batch_size = batch.len(); - let session_key = Dispatcher::session_key(thread_channel); - - // Apply 👀 reaction to every message in the batch before dispatch (§6.7). - // Skip when assistant status API is active — uses - // assistant.threads.setStatus instead of emoji reactions. - let assistant_status = adapter.uses_assistant_status(); - if !assistant_status { - let queued_emoji = &target.reactions_config().emojis.queued; - for msg in batch.iter() { - let _ = adapter.add_reaction(&msg.trigger_msg, queued_emoji).await; - } - } - - // Collect per-event observability data (before consuming the batch). - let tokens_per_event: Vec = batch.iter().map(|m| m.estimated_tokens).collect(); - let wait_ms: Vec = batch - .iter() - .map(|m| m.arrived_at.elapsed().as_millis()) - .collect(); - let senders: Vec = batch.iter().map(|m| m.sender_name.clone()).collect(); - - // Native-streaming recipient is bound to the turn (captured per-message). 
A - // batch attributes to the most recent sender; None for non-Slack/bot turns. - let recipient: Option<(String, String)> = batch.last().and_then(|m| m.recipient.clone()); - - // Anchor reactions on the last message in the batch (before consuming). - let trigger_msg = batch.last().unwrap().trigger_msg.clone(); - let dispatch_channel = ChannelRef { - // Reply correlation is event-scoped, but the dispatcher consumer is - // thread-scoped. Rebuild the per-dispatch channel from the stable - // thread route plus the freshest event ID so gateway replies (e.g. - // LINE reply-token lookup) target the current inbound event. - origin_event_id: trigger_msg.channel.origin_event_id.clone(), - ..thread_channel.clone() - }; - - // Pack all arrival events into one Vec (§3.3). - // Uses into_iter() to avoid deep-copying extra_blocks (may contain base64 image data). - let mut content_blocks: Vec = Vec::new(); - - // Parse control directives from the first message in the batch (ADR: control-directives). - // Directives are only processed on the session's first message (§2.2). - // - // Strategy: - // 1. Parse directives (cheap text extraction — no mutation, no I/O) - // 2. Attempt workspace resolution if [[ws:...]] present (may fail gracefully) - // 3. Call ensure_session with resolved workspace — returns created_now - // 4. Only strip prompt and apply title/workspace if created_now == true - // 5. If created_now == false, the [[...]] text is preserved verbatim - let mut batch = batch; - let parse_result = batch - .first() - .map(|first_msg| crate::directives::parse_directives(&first_msg.prompt)); - - // Tentatively resolve [[ws:...]] — if resolution fails and the session turns out to - // be new, we abort. If the session already existed, resolution failure is irrelevant. 
- let ws_resolved: Option> = parse_result.as_ref().and_then(|pr| { - pr.metadata.raw.get("ws").map(|ws_value| { - let aliases = target.workspace_aliases(); - let bot_home = target.bot_home(); - crate::directives::resolve_workspace(ws_value, &aliases, &bot_home) - .map(|p| p.display().to_string()) - }) - }); - - // Extract workspace path for ensure_session (None if no directive or resolution failed). - let workspace_override: Option = - ws_resolved.as_ref().and_then(|r| r.as_ref().ok().cloned()); - - // Ensure session exists. The create_gate mutex inside get_or_create serializes - // concurrent callers — only the winner gets created_now == true. - let created_now = match target - .ensure_session(&session_key, workspace_override.as_deref()) - .await - { - Ok(created) => created, - Err(e) => { - let user_msg = format_user_error(&e.to_string()); - let _ = adapter - .send_message(&dispatch_channel, &format!("⚠️ {user_msg}")) - .await; - error!("pool error in dispatch_batch: {e}"); - return; - } - }; - - // Only apply directives if this is genuinely the first message (fresh session). - if created_now { - if let Some(pr) = parse_result { - if !pr.metadata.raw.is_empty() { - // Apply [[title:...]] independently — works regardless of ws outcome. - let title_to_apply = pr.metadata.title.clone(); - - // If workspace resolution failed on a NEW session, rollback and abort. - // Reset FIRST to minimize TOCTOU window (擺渡 F1), then rename. - if let Some(Err(e)) = ws_resolved { - target.reset_session(&session_key).await; - // Apply title after reset so the thread is identifiable. 
- if let Some(ref title) = title_to_apply { - if !title.is_empty() { - let _ = adapter.rename_thread(&dispatch_channel, title).await; - } - } - let _ = adapter - .send_message(&dispatch_channel, &format!("⚠️ {e}")) - .await; - error!(session_key, error = %e, "workspace directive rejected"); - return; - } - - // Strip directives from the prompt - if let Some(first_msg) = batch.first_mut() { - first_msg.prompt = pr.prompt; - } - - // Apply title on success path. - if let Some(ref title) = title_to_apply { - if !title.is_empty() { - if let Err(e) = adapter.rename_thread(&dispatch_channel, title).await { - warn!(session_key, error = %e, "failed to apply title directive"); - } - } - } - } - } - } - - for msg in batch { - let mut event_blocks = - AdapterRouter::pack_arrival_event(&msg.sender_json, &msg.prompt, msg.extra_blocks); - content_blocks.append(&mut event_blocks); - } - let packed_block_count = content_blocks.len(); - - let reactions_config = target.reactions_config().clone(); - let reactions = Arc::new(StatusReactionController::new( - reactions_config.enabled, - adapter.clone(), - trigger_msg, - reactions_config.emojis.clone(), - reactions_config.timing.clone(), - )); - // 👀 already applied above; skip set_queued() to avoid double-reaction. - - let result = target - .stream_prompt_blocks( - adapter, - &session_key, - content_blocks, - &dispatch_channel, - reactions.clone(), - other_bot_present, - recipient, - ) - .await; - - // In assistant status mode, all status is conveyed via - // assistant.threads.setStatus — skip emoji reactions entirely. 
- if !assistant_status { - match &result { - Ok(()) => reactions.set_done().await, - Err(_) => reactions.set_error().await, - } - - let hold_ms = if result.is_ok() { - reactions_config.timing.done_hold_ms - } else { - reactions_config.timing.error_hold_ms - }; - if reactions_config.remove_after_reply { - let reactions = reactions; - tokio::spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis(hold_ms)).await; - reactions.clear().await; - }); - } - } - - if let Err(ref e) = result { - let _ = adapter - .send_message(&dispatch_channel, &format!("⚠️ {e}")) - .await; - } - - let agent_dispatch_ms = dispatch_start.elapsed().as_millis(); - let span = info_span!( - "dispatch", - channel = %thread_channel.channel_id, - adapter = adapter.platform(), - ); - let _enter = span.enter(); - info!( - thread_key = %thread_key, - events_per_dispatch = batch_size, - packed_block_count = packed_block_count, - agent_dispatch_ms = agent_dispatch_ms, - tokens_per_event = ?tokens_per_event, - wait_ms = ?wait_ms, - senders = ?senders, - "batch dispatched", - ); -} - -// --------------------------------------------------------------------------- -// Token estimation -// --------------------------------------------------------------------------- - -/// Rough char-to-token ratio for English-ish text. Coarse on purpose — the goal -/// is a guard rail for `max_batch_tokens`, not an exact pre-flight. -const CHARS_PER_TOKEN_ESTIMATE: usize = 4; -/// Conservative per-image token budget. Larger than typical Claude image cost -/// so the cap trips before we hand the model an oversized batch. -const TOKENS_PER_IMAGE_ESTIMATE: usize = 512; - -/// Rough token estimate for a buffered message (used for `max_batch_tokens` cap). -/// Intentionally coarse — the goal is a guard rail, not an exact pre-flight. 
-pub fn estimate_tokens(prompt: &str, extra_blocks: &[ContentBlock]) -> usize { - let text_tokens = prompt.len() / CHARS_PER_TOKEN_ESTIMATE + 1; - let block_tokens: usize = extra_blocks - .iter() - .map(|b| match b { - ContentBlock::Text { text } => text.len() / CHARS_PER_TOKEN_ESTIMATE + 1, - ContentBlock::Image { .. } => TOKENS_PER_IMAGE_ESTIMATE, - }) - .sum(); - text_tokens + block_tokens -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn estimate_tokens_empty() { - assert!(estimate_tokens("", &[]) >= 1); - } - - #[test] - fn estimate_tokens_text() { - // 400 chars ≈ 100 tokens - let s = "a".repeat(400); - assert_eq!(estimate_tokens(&s, &[]), 101); - } - - #[test] - fn estimate_tokens_image_block() { - let blocks = vec![ContentBlock::Image { - media_type: "image/png".into(), - data: "base64data".into(), - }]; - assert_eq!(estimate_tokens("", &blocks), 1 + 512); - } - - #[test] - fn pack_arrival_event_single() { - let blocks = - AdapterRouter::pack_arrival_event(r#"{"schema":"openab.sender.v1"}"#, "hello", vec![]); - // sender_context delimiter + prompt = 2 blocks - assert_eq!(blocks.len(), 2); - if let ContentBlock::Text { text } = &blocks[0] { - assert!(text.contains("")); - assert!(text.contains("")); - // Header is delimiter only — prompt lives in its own block. 
- assert!(!text.contains("hello")); - } else { - panic!("expected Text delimiter block"); - } - if let ContentBlock::Text { text } = &blocks[1] { - assert_eq!(text, "hello"); - } else { - panic!("expected Text prompt block"); - } - } - - #[test] - fn pack_arrival_event_with_extra_blocks() { - let extra = vec![ - ContentBlock::Text { - text: "[Voice transcript]: hi".into(), - }, - ContentBlock::Image { - media_type: "image/png".into(), - data: "abc".into(), - }, - ]; - let blocks = AdapterRouter::pack_arrival_event("{}", "prompt", extra); - // delimiter + transcript + prompt + image = 4 blocks - assert_eq!(blocks.len(), 4); - assert!( - matches!(&blocks[0], ContentBlock::Text { text } if text.contains("")) - ); - assert!( - matches!(&blocks[1], ContentBlock::Text { text } if text.contains("Voice transcript")) - ); - assert!(matches!(&blocks[2], ContentBlock::Text { text } if text == "prompt")); - assert!(matches!(&blocks[3], ContentBlock::Image { .. })); - } - - #[test] - fn pack_arrival_event_batch_n2() { - // Two arrival events concatenated → 2 (header + prompt) pairs = 4 blocks. - let mut all: Vec = Vec::new(); - all.extend(AdapterRouter::pack_arrival_event( - r#"{"ts":"T1"}"#, - "msg1", - vec![], - )); - all.extend(AdapterRouter::pack_arrival_event( - r#"{"ts":"T2"}"#, - "msg2", - vec![], - )); - assert_eq!(all.len(), 4); - if let ContentBlock::Text { text } = &all[0] { - assert!(text.contains(r#""ts":"T1""#)); - assert!(!text.contains("msg1")); - } - if let ContentBlock::Text { text } = &all[1] { - assert_eq!(text, "msg1"); - } - if let ContentBlock::Text { text } = &all[2] { - assert!(text.contains(r#""ts":"T2""#)); - assert!(!text.contains("msg2")); - } - if let ContentBlock::Text { text } = &all[3] { - assert_eq!(text, "msg2"); - } - } - - // ADR §3.6 Scenario B — text in one message, image in the next, same author. 
- // Broker preserves structural truth: image stays in M2 alone, both messages - // carry the same sender_id so the agent can semantically link them. - #[test] - fn pack_arrival_event_scenario_b_image_in_separate_message() { - let mut all: Vec = Vec::new(); - // M1 (alice): "see this image" - all.extend(AdapterRouter::pack_arrival_event( - r#"{"sender_id":"A","ts":"T1"}"#, - "see this image", - vec![], - )); - // M2 (alice): image, no text - all.extend(AdapterRouter::pack_arrival_event( - r#"{"sender_id":"A","ts":"T2"}"#, - "", - vec![ContentBlock::Image { - media_type: "image/png".into(), - data: "imgB".into(), - }], - )); - // header(M1) + prompt(M1) + header(M2) + image(M2) = 4 blocks - // (M2 has empty prompt, so its prompt block is omitted) - assert_eq!(all.len(), 4); - if let ContentBlock::Text { text } = &all[0] { - assert!(text.contains(r#""sender_id":"A""#)); - assert!(text.contains(r#""ts":"T1""#)); - } else { - panic!("expected Text delimiter for M1"); - } - if let ContentBlock::Text { text } = &all[1] { - assert_eq!(text, "see this image"); - } else { - panic!("expected Text prompt for M1"); - } - if let ContentBlock::Text { text } = &all[2] { - assert!(text.contains(r#""ts":"T2""#)); - } else { - panic!("expected Text delimiter for M2"); - } - // M2's image follows immediately after its delimiter (no prompt block). - assert!(matches!(&all[3], ContentBlock::Image { .. })); - } - - // ADR §3.6 Scenario C — fragmented multi-author batch. - // Repeated sender_id is preserved across non-adjacent messages; bob's interjection - // is kept as-is (no silent drop, no temporal reorder). 
- #[test] - fn pack_arrival_event_scenario_c_multi_author_interleaved() { - let mut all: Vec = Vec::new(); - all.extend(AdapterRouter::pack_arrival_event( - r#"{"sender_id":"A","ts":"T1"}"#, - "see this image", - vec![], - )); - all.extend(AdapterRouter::pack_arrival_event( - r#"{"sender_id":"B","ts":"T2"}"#, - "what?", - vec![], - )); - all.extend(AdapterRouter::pack_arrival_event( - r#"{"sender_id":"A","ts":"T3"}"#, - "", - vec![ContentBlock::Image { - media_type: "image/png".into(), - data: "imgC".into(), - }], - )); - // M1: header + prompt = 2 blocks - // M2: header + prompt = 2 blocks - // M3: header + image = 2 blocks (empty prompt → no prompt block) - // total = 6 - assert_eq!(all.len(), 6); - let h1 = match &all[0] { - ContentBlock::Text { text } => text, - _ => panic!("expected Text delimiter for M1"), - }; - let p1 = match &all[1] { - ContentBlock::Text { text } => text, - _ => panic!("expected Text prompt for M1"), - }; - let h2 = match &all[2] { - ContentBlock::Text { text } => text, - _ => panic!("expected Text delimiter for M2"), - }; - let p2 = match &all[3] { - ContentBlock::Text { text } => text, - _ => panic!("expected Text prompt for M2"), - }; - let h3 = match &all[4] { - ContentBlock::Text { text } => text, - _ => panic!("expected Text delimiter for M3"), - }; - assert!(h1.contains(r#""sender_id":"A""#) && h1.contains(r#""ts":"T1""#)); - assert_eq!(p1, "see this image"); - assert!(h2.contains(r#""sender_id":"B""#) && h2.contains(r#""ts":"T2""#)); - assert_eq!(p2, "what?"); - assert!(h3.contains(r#""sender_id":"A""#) && h3.contains(r#""ts":"T3""#)); - // M3's image attached to M3 only. - assert!(matches!(&all[5], ContentBlock::Image { .. })); - } - - // ADR §3.6 Scenario D — voice-only message in a batch. - // Within each arrival, transcript Text blocks precede the prompt block so the - // agent sees voice content before any typed text. The sender_context delimiter - // still opens each arrival. 
- #[test] - fn pack_arrival_event_scenario_d_voice_only() { - let mut all: Vec = Vec::new(); - all.extend(AdapterRouter::pack_arrival_event( - r#"{"sender_id":"A","ts":"T1"}"#, - "look at this", - vec![ContentBlock::Image { - media_type: "image/png".into(), - data: "scr".into(), - }], - )); - all.extend(AdapterRouter::pack_arrival_event( - r#"{"sender_id":"A","ts":"T2"}"#, - "", - vec![ContentBlock::Text { - text: "[Voice message transcript]: hey can we sync about the deploy".into(), - }], - )); - all.extend(AdapterRouter::pack_arrival_event( - r#"{"sender_id":"B","ts":"T3"}"#, - "what?", - vec![], - )); - // M1: header + prompt + image = 3 - // M2: header + transcript = 2 (empty prompt → no prompt block) - // M3: header + prompt = 2 - // total = 7 - assert_eq!(all.len(), 7); - if let ContentBlock::Text { text } = &all[0] { - assert!(text.contains(r#""ts":"T1""#)); - assert!(!text.contains("look at this")); - } - if let ContentBlock::Text { text } = &all[1] { - assert_eq!(text, "look at this"); - } - assert!(matches!(&all[2], ContentBlock::Image { .. })); - if let ContentBlock::Text { text } = &all[3] { - assert!(text.contains(r#""ts":"T2""#)); - } - // Transcript precedes prompt (and prompt is omitted here because empty). - if let ContentBlock::Text { text } = &all[4] { - assert!(text.contains("Voice message transcript")); - assert!(text.contains("sync about the deploy")); - } else { - panic!("expected transcript Text block after M2 delimiter"); - } - if let ContentBlock::Text { text } = &all[5] { - assert!(text.contains(r#""sender_id":"B""#)); - } - if let ContentBlock::Text { text } = &all[6] { - assert_eq!(text, "what?"); - } - } - - // Token-cap math: a single message that already exceeds max_batch_tokens still - // dispatches alone (the consumer_loop logic admits the first message before - // checking the cap). Verifies estimate_tokens scales with input length. 
- #[test] - fn estimate_tokens_oversized_single_message() { - // ~24k token text (96000 chars / 4 chars-per-token). - let big = "x".repeat(96_000); - let est = estimate_tokens(&big, &[]); - assert!(est > 24_000, "expected >24k tokens, got {est}"); - } - - // Cumulative token math: two messages whose sum exceeds max_batch_tokens. - // The consumer_loop reads first, then peeks at the next; if cumulative tokens - // > cap, the second is held over to the next batch (FIFO preserved). - #[test] - fn estimate_tokens_cumulative_exceeds_cap() { - let max_tokens = 24_000_usize; - let m1 = estimate_tokens(&"a".repeat(80_000), &[]); - let m2 = estimate_tokens(&"b".repeat(50_000), &[]); - assert!(m1 < max_tokens); - assert!(m1 + m2 > max_tokens, "{m1} + {m2} should exceed cap"); - } - - // ADR §2.5 race-safe eviction. The full SendError path requires a real - // AdapterRouter (concrete struct, not a trait — no easy mock seam), so we - // unit-test the eviction predicate in isolation. End-to-end consumer-death - // recovery is exercised by the manual staging smoke documented in the ADR. - fn dummy_handle(generation: u64) -> ThreadHandle { - let (tx, _rx) = tokio::sync::mpsc::channel::(1); - let consumer = tokio::spawn(async {}); - ThreadHandle { - tx, - consumer, - generation, - channel_id: "C".into(), - adapter_kind: "discord".into(), - } - } - - #[tokio::test] - async fn try_evict_locked_removes_when_generation_matches() { - let mut map: HashMap = HashMap::new(); - map.insert("t".into(), dummy_handle(7)); - assert!(Dispatcher::try_evict_locked(&mut map, "t", 7)); - assert!(map.is_empty()); - } - - // The bug §2.5 prevents: a stale producer (my_gen=7) observing SendError - // must not remove a freshly inserted handle (gen=8) created by another - // submit between the failed send and the eviction attempt. 
- #[tokio::test] - async fn try_evict_locked_keeps_when_generation_differs() { - let mut map: HashMap = HashMap::new(); - map.insert("t".into(), dummy_handle(8)); - assert!(!Dispatcher::try_evict_locked(&mut map, "t", 7)); - assert_eq!(map.len(), 1); - assert_eq!(map.get("t").unwrap().generation, 8); - } - - #[tokio::test] - async fn try_evict_locked_returns_false_when_absent() { - let mut map: HashMap = HashMap::new(); - assert!(!Dispatcher::try_evict_locked(&mut map, "missing", 0)); - } - - // BatchGrouping → thread_key shape. - fn make_dispatcher(grouping: BatchGrouping) -> Dispatcher { - // The router is wrapped in Arc but never used by `key()` itself; we use - // a dummy AdapterRouter built via the same path main.rs would use. - // For a pure-keying test we'd ideally not need it, but the constructor demands one. - // Construct a minimal router via the public test helpers in adapter.rs if available; - // otherwise we fall back to building one with a dummy SessionPool. - use crate::acp::SessionPool; - let agent_cfg = crate::config::AgentConfig { - command: "/bin/true".into(), - args: vec![], - working_dir: "/tmp".into(), - env: std::collections::HashMap::new(), - inherit_env: vec![], - command_explicit: true, - }; - let pool = Arc::new(SessionPool::new(agent_cfg, 1)); - let router = Arc::new(AdapterRouter::new( - pool, - crate::config::ReactionsConfig::default(), - crate::markdown::TableMode::Off, - crate::config::default_prompt_hard_timeout_secs(), - crate::config::default_liveness_check_secs(), - std::collections::HashMap::new(), - std::path::PathBuf::from("/tmp"), - )); - Dispatcher::with_idle_timeout(router, 10, 24_000, grouping, DEFAULT_CONSUMER_IDLE_TIMEOUT) - } - - #[tokio::test] - async fn key_per_thread_ignores_sender() { - let d = make_dispatcher(BatchGrouping::Thread); - assert_eq!(d.key("discord", "T1", "userA"), "discord:T1"); - assert_eq!(d.key("discord", "T1", "userB"), "discord:T1"); - } - - #[tokio::test] - async fn 
key_per_lane_includes_sender() { - let d = make_dispatcher(BatchGrouping::Lane); - assert_eq!(d.key("discord", "T1", "userA"), "discord:T1:userA"); - assert_eq!(d.key("discord", "T1", "userB"), "discord:T1:userB"); - // Different threads remain distinct. - assert_eq!(d.key("slack", "T2", "userA"), "slack:T2:userA"); - } - - fn insert_dummy_handle(d: &Dispatcher, key: &str) { - let (tx, _rx) = tokio::sync::mpsc::channel::(10); - let consumer = tokio::spawn(async {}); - let handle = ThreadHandle { - tx, - consumer, - generation: 0, - channel_id: "c".into(), - adapter_kind: "discord".into(), - }; - d.per_thread.lock().unwrap().insert(key.to_string(), handle); - } - - #[tokio::test] - async fn cancel_buffered_thread_drops_per_thread_key() { - let d = make_dispatcher(BatchGrouping::Thread); - insert_dummy_handle(&d, "discord:T1"); - insert_dummy_handle(&d, "discord:T2"); // different thread, must survive - assert_eq!(d.cancel_buffered_thread("discord", "T1"), 0); // no buffered msgs - let map = d.per_thread.lock().unwrap(); - assert!(!map.contains_key("discord:T1")); - assert!(map.contains_key("discord:T2")); - } - - #[tokio::test] - async fn cancel_buffered_thread_drops_all_lanes() { - let d = make_dispatcher(BatchGrouping::Lane); - insert_dummy_handle(&d, "discord:T1:userA"); - insert_dummy_handle(&d, "discord:T1:userB"); - insert_dummy_handle(&d, "discord:T2:userA"); // different thread - insert_dummy_handle(&d, "slack:T1:userA"); // different platform - d.cancel_buffered_thread("discord", "T1"); - let map = d.per_thread.lock().unwrap(); - assert!(!map.contains_key("discord:T1:userA")); - assert!(!map.contains_key("discord:T1:userB")); - assert!(map.contains_key("discord:T2:userA")); - assert!(map.contains_key("slack:T1:userA")); - } - - #[tokio::test] - async fn cancel_buffered_thread_does_not_match_thread_id_prefix() { - // T1 must not match T10 / T11 (substring trap). 
- let d = make_dispatcher(BatchGrouping::Lane); - insert_dummy_handle(&d, "discord:T1:userA"); - insert_dummy_handle(&d, "discord:T10:userA"); - d.cancel_buffered_thread("discord", "T1"); - let map = d.per_thread.lock().unwrap(); - assert!(!map.contains_key("discord:T1:userA")); - assert!(map.contains_key("discord:T10:userA")); - } - - // Long-running consumer that parks until aborted — used by sweep_stale / - // shutdown tests to exercise the "still alive" path. - fn alive_consumer_handle() -> ThreadHandle { - let (tx, _rx) = tokio::sync::mpsc::channel::(10); - let consumer = tokio::spawn(async { - std::future::pending::<()>().await; - }); - ThreadHandle { - tx, - consumer, - generation: 0, - channel_id: "c".into(), - adapter_kind: "discord".into(), - } - } - - #[tokio::test] - async fn sweep_stale_removes_finished_consumers() { - let d = make_dispatcher(BatchGrouping::Thread); - insert_dummy_handle(&d, "discord:T1"); - insert_dummy_handle(&d, "discord:T2"); - // Yield so the empty-body spawned tasks actually run to completion - // before is_finished() is checked. - tokio::time::sleep(Duration::from_millis(10)).await; - let swept = d.sweep_stale(); - assert_eq!(swept, 2); - assert!(d.per_thread.lock().unwrap().is_empty()); - } - - #[tokio::test] - async fn sweep_stale_keeps_running_consumers() { - let d = make_dispatcher(BatchGrouping::Thread); - let abort = { - let h = alive_consumer_handle(); - let a = h.consumer.abort_handle(); - d.per_thread.lock().unwrap().insert("alive".into(), h); - a - }; - let swept = d.sweep_stale(); - assert_eq!(swept, 0); - assert!(d.per_thread.lock().unwrap().contains_key("alive")); - // Cleanup so the parked task doesn't linger across tests. 
- abort.abort(); - } - - #[tokio::test] - async fn shutdown_clears_all_handles() { - let d = make_dispatcher(BatchGrouping::Thread); - insert_dummy_handle(&d, "k1"); - insert_dummy_handle(&d, "k2"); - insert_dummy_handle(&d, "k3"); - d.shutdown(); - assert!(d.per_thread.lock().unwrap().is_empty()); - } - - #[tokio::test] - async fn shutdown_aborts_running_consumers() { - let d = make_dispatcher(BatchGrouping::Thread); - let abort = { - let h = alive_consumer_handle(); - let a = h.consumer.abort_handle(); - d.per_thread.lock().unwrap().insert("k".into(), h); - a - }; - d.shutdown(); - // Give the runtime a tick to process abort + map drop. - tokio::time::sleep(Duration::from_millis(10)).await; - assert!(abort.is_finished()); - } - - // ----------------------------------------------------------------------- - // consumer_loop / dispatch_batch integration tests (NIT 2) - // - // These drive `consumer_loop` directly with a pre-populated mpsc, using - // `MockDispatchTarget` to record the calls that would otherwise hit a - // real `AdapterRouter` (and through it, ACP CLI subprocesses). This - // gives deterministic coverage of the orchestration paths the existing - // unit tests don't reach: greedy drain, token-cap overflow, idle timeout. - // ----------------------------------------------------------------------- - - /// One recorded `stream_prompt_blocks` invocation. - #[derive(Clone)] - struct RecordedDispatch { - block_count: usize, - other_bot_present: bool, - dispatch_channel: ChannelRef, - } - - /// Mock `DispatchTarget` — records calls; never touches a real session pool. - struct MockDispatchTarget { - reactions: ReactionsConfig, - calls: Mutex>, - /// If set, `ensure_session` returns this error once. - ensure_err: Mutex>, - /// If set, `stream_prompt_blocks` returns this error once. 
- stream_err: Mutex>, - } - - impl MockDispatchTarget { - fn new() -> Self { - Self { - reactions: ReactionsConfig::default(), - calls: Mutex::new(Vec::new()), - ensure_err: Mutex::new(None), - stream_err: Mutex::new(None), - } - } - - fn calls(&self) -> Vec { - self.calls.lock().unwrap().clone() - } - } - - #[async_trait] - impl DispatchTarget for MockDispatchTarget { - fn reactions_config(&self) -> &ReactionsConfig { - &self.reactions - } - - fn workspace_aliases(&self) -> std::collections::HashMap { - std::collections::HashMap::new() - } - - fn bot_home(&self) -> std::path::PathBuf { - std::path::PathBuf::from("/tmp") - } - - async fn ensure_session( - &self, - _session_key: &str, - _working_dir: Option<&str>, - ) -> Result { - if let Some(msg) = self.ensure_err.lock().unwrap().take() { - return Err(anyhow::anyhow!(msg)); - } - Ok(true) - } - - async fn reset_session(&self, _session_key: &str) {} - - async fn stream_prompt_blocks( - &self, - _adapter: &Arc, - _session_key: &str, - content_blocks: Vec, - thread_channel: &ChannelRef, - _reactions: Arc, - other_bot_present: bool, - _recipient: Option<(String, String)>, - ) -> Result<()> { - self.calls.lock().unwrap().push(RecordedDispatch { - block_count: content_blocks.len(), - other_bot_present, - dispatch_channel: thread_channel.clone(), - }); - if let Some(msg) = self.stream_err.lock().unwrap().take() { - return Err(anyhow::anyhow!(msg)); - } - Ok(()) - } - } - - /// Mock `ChatAdapter` — every method is a no-op success. The dispatch loop - /// invokes `add_reaction` (queued 👀), `platform`, and on the error path - /// `send_message`; nothing else needs real behavior here. 
- struct MockChatAdapter; - - #[async_trait] - impl ChatAdapter for MockChatAdapter { - fn platform(&self) -> &'static str { - "mock" - } - fn message_limit(&self) -> usize { - 2000 - } - - async fn send_message(&self, channel: &ChannelRef, _content: &str) -> Result { - Ok(MessageRef { - channel: channel.clone(), - message_id: "mock-msg".into(), - }) - } - - async fn create_thread( - &self, - channel: &ChannelRef, - _trigger_msg: &MessageRef, - _title: &str, - ) -> Result { - Ok(channel.clone()) - } - - async fn add_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { - Ok(()) - } - async fn remove_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { - Ok(()) - } - fn use_streaming(&self, _other_bot_present: bool) -> bool { - false - } - } - - fn make_channel(thread: &str) -> ChannelRef { - ChannelRef { - platform: "mock".into(), - channel_id: thread.into(), - thread_id: Some(thread.into()), - parent_id: None, - origin_event_id: None, - } - } - - fn make_msg(prompt: &str, tokens: usize) -> BufferedMessage { - BufferedMessage { - sender_json: r#"{"schema":"openab.sender.v1","sender_id":"u","sender_name":"u"}"# - .into(), - sender_name: "u".into(), - prompt: prompt.into(), - extra_blocks: vec![], - trigger_msg: MessageRef { - channel: make_channel("T"), - message_id: format!("m-{prompt}"), - }, - arrived_at: Instant::now(), - estimated_tokens: tokens, - other_bot_present: false, - recipient: None, - } - } - - /// Pre-load `msgs` into a fresh mpsc, drop the sender, and run - /// `consumer_loop` to completion. Returns the recorded dispatches. 
- async fn run_consumer_with_messages( - msgs: Vec, - max_batch: usize, - max_tokens: usize, - ) -> Vec { - let mock = Arc::new(MockDispatchTarget::new()); - let target: Arc = mock.clone(); - let adapter: Arc = Arc::new(MockChatAdapter); - let (tx, rx) = tokio::sync::mpsc::channel::(msgs.len().max(1)); - for m in msgs { - tx.send(m).await.unwrap(); - } - drop(tx); - - consumer_loop( - "mock:T".into(), - make_channel("T"), - rx, - target, - adapter, - max_batch, - max_tokens, - Duration::from_secs(60), - ) - .await; - - mock.calls() - } - - #[tokio::test] - async fn consumer_dispatches_single_message_as_one_batch() { - let calls = run_consumer_with_messages(vec![make_msg("hi", 10)], 10, 24_000).await; - assert_eq!(calls.len(), 1); - // pack_arrival_event with no extra_blocks → delimiter + prompt = 2 blocks. - assert_eq!(calls[0].block_count, 2); - assert!(!calls[0].other_bot_present); - } - - #[tokio::test] - async fn consumer_greedy_drain_combines_queued_messages_into_one_batch() { - // 3 messages already in the queue when the consumer wakes → greedy - // drain pulls all 3, packs them into one batch, dispatches once. - let calls = run_consumer_with_messages( - vec![make_msg("a", 50), make_msg("b", 50), make_msg("c", 50)], - 10, - 24_000, - ) - .await; - assert_eq!(calls.len(), 1, "expected a single batched dispatch"); - // 3 arrivals × (delimiter + prompt) = 6 blocks. - assert_eq!(calls[0].block_count, 6); - } - - #[tokio::test] - async fn consumer_token_cap_splits_batch_preserving_fifo() { - // max_tokens=100, two 80-token messages → cumulative 160 > 100, so - // msg2 becomes `pending` and is dispatched in the next batch. - let calls = - run_consumer_with_messages(vec![make_msg("a", 80), make_msg("b", 80)], 10, 100).await; - assert_eq!(calls.len(), 2, "token cap should split into two batches"); - // Each batch holds one arrival → delimiter + prompt = 2 blocks. 
- assert_eq!(calls[0].block_count, 2); - assert_eq!(calls[1].block_count, 2); - } - - #[tokio::test] - async fn consumer_dispatch_uses_last_event_origin_event_id_for_merged_batch() { - let mut first = make_msg("a", 80); - first.trigger_msg.channel.origin_event_id = Some("evt-first".into()); - let mut second = make_msg("b", 80); - second.trigger_msg.channel.origin_event_id = Some("evt-second".into()); - - let calls = run_consumer_with_messages(vec![first, second], 10, 200).await; - assert_eq!(calls.len(), 1); - assert_eq!( - calls[0].dispatch_channel.origin_event_id.as_deref(), - Some("evt-second") - ); - } - - #[tokio::test] - async fn consumer_dispatch_preserves_thread_route_while_refreshing_origin_event_id() { - let mock = Arc::new(MockDispatchTarget::new()); - let target: Arc = mock.clone(); - let adapter: Arc = Arc::new(MockChatAdapter); - let (tx, rx) = tokio::sync::mpsc::channel::(1); - - let mut msg = make_msg("hi", 10); - msg.trigger_msg.channel = ChannelRef { - platform: "mock".into(), - channel_id: "parent-channel".into(), - thread_id: None, - parent_id: None, - origin_event_id: Some("evt-fresh".into()), - }; - tx.send(msg).await.unwrap(); - drop(tx); - - consumer_loop( - "mock:topic-42".into(), - ChannelRef { - platform: "mock".into(), - channel_id: "topic-42".into(), - thread_id: Some("topic-42".into()), - parent_id: Some("parent-channel".into()), - origin_event_id: Some("evt-stale".into()), - }, - rx, - target, - adapter, - 10, - 24_000, - Duration::from_secs(60), - ) - .await; - - let calls = mock.calls(); - assert_eq!(calls.len(), 1); - assert_eq!(calls[0].dispatch_channel.channel_id, "topic-42"); - assert_eq!( - calls[0].dispatch_channel.thread_id.as_deref(), - Some("topic-42") - ); - assert_eq!( - calls[0].dispatch_channel.parent_id.as_deref(), - Some("parent-channel") - ); - assert_eq!( - calls[0].dispatch_channel.origin_event_id.as_deref(), - Some("evt-fresh") - ); - } - - #[tokio::test] - async fn 
consumer_exits_after_idle_timeout_with_no_messages() { - // No messages ever arrive; consumer should exit once `idle_timeout` - // elapses. Keep `tx` alive so the exit path is the timeout, not the - // "all senders dropped" branch. - let mock = Arc::new(MockDispatchTarget::new()); - let target: Arc = mock.clone(); - let adapter: Arc = Arc::new(MockChatAdapter); - let (tx, rx) = tokio::sync::mpsc::channel::(1); - let consumer = tokio::spawn(consumer_loop( - "mock:T".into(), - make_channel("T"), - rx, - target, - adapter, - 10, - 24_000, - Duration::from_millis(50), - )); - // Wait enough for the timeout branch + a tick for the task to finish. - tokio::time::sleep(Duration::from_millis(150)).await; - assert!( - consumer.is_finished(), - "consumer should exit after idle timeout" - ); - // No dispatches should have been recorded. - assert!(mock.calls().is_empty()); - drop(tx); - } - - #[tokio::test] - async fn submit_evicts_dead_handle_and_retries_with_fresh_consumer() { - // §2.5: if `tx.send()` returns `SendError` (consumer's rx dropped - // mid-flight), `submit` evicts the stale entry under lock and spawns - // a fresh consumer. Manufacture this state by inserting a handle - // whose consumer is still parked but whose rx has been dropped. 
- let mock = Arc::new(MockDispatchTarget::new()); - let target: Arc = mock.clone(); - let d = Dispatcher::with_idle_timeout( - target, - 10, - 24_000, - BatchGrouping::Thread, - DEFAULT_CONSUMER_IDLE_TIMEOUT, - ); - let adapter: Arc = Arc::new(MockChatAdapter); - - let key = "mock:T".to_string(); - let parked = { - let (tx, rx) = tokio::sync::mpsc::channel::(10); - drop(rx); // closes the channel → next tx.send() yields SendError - let consumer = tokio::spawn(std::future::pending::<()>()); - let abort = consumer.abort_handle(); - let handle = ThreadHandle { - tx, - consumer, - generation: 999, - channel_id: "T".into(), - adapter_kind: "mock".into(), - }; - d.per_thread.lock().unwrap().insert(key.clone(), handle); - abort - }; - - d.submit(key, make_channel("T"), adapter, make_msg("hello", 10)) - .await - .expect("retry should spawn a fresh consumer"); - // Give the freshly spawned consumer time to drain + dispatch. - tokio::time::sleep(Duration::from_millis(50)).await; - - let calls = mock.calls(); - assert_eq!( - calls.len(), - 1, - "fresh consumer should have dispatched the retry" - ); - // pack_arrival_event with no extra_blocks → delimiter + prompt = 2 blocks. - assert_eq!(calls[0].block_count, 2); - - parked.abort(); - } -} diff --git a/src/error_display.rs b/src/error_display.rs deleted file mode 100644 index c8826dcbf..000000000 --- a/src/error_display.rs +++ /dev/null @@ -1,323 +0,0 @@ -/// Format any error for user display in Discord. -/// -/// Handles two error categories: -/// - **Coded errors** (code != 0): JSON-RPC or HTTP status codes from upstream agent. -/// - **Startup/connection errors** (code == 0): Errors from pool.rs or connection.rs -/// where only the message string is available. -/// -/// Provider-agnostic: no provider-specific strings, message text passed through verbatim. 
-pub fn format_user_error(message: &str) -> String { - let msg_lower = message.to_lowercase(); - - // Startup / connection errors (code == 0 from anyhow) - if msg_lower.contains("timeout waiting for") { - // Use msg_lower for extraction to stay case-insistent with the match above. - // msg_lower and message are the same length, so byte offsets are valid. - if let Some(start) = msg_lower.find("timeout waiting for ") { - let rest = &message[start + "timeout waiting for ".len()..]; - let method = rest.split_whitespace().next().unwrap_or("request"); - return format!( - "**Request Timeout**\nTimeout waiting for {}, please try again.", - method - ); - } - return "**Request Timeout**\nTimeout waiting for a response, please try again." - .to_string(); - } - if msg_lower.contains("connection closed") || msg_lower.contains("channel closed") { - return "**Connection Lost**\nThe connection to the agent was lost, please try again." - .to_string(); - } - if msg_lower.contains("failed to spawn") || msg_lower.contains("no such file") { - return "**Agent Not Found**\nCould not start the agent — please check your configuration." - .to_string(); - } - if msg_lower.contains("pool exhausted") { - return "**Service Busy**\nAll agent sessions are in use, please try again shortly." - .to_string(); - } - if msg_lower.contains("invalid api key") || msg_lower.contains("unauthorized") { - return "**Unauthorized**\nPlease check your API key configuration.".to_string(); - } - - // Unknown error — pass through as-is - if message.is_empty() { - "**Error**\nAn unknown error occurred.".to_string() - } else { - format!("**Error**\n{}", message) - } -} - -/// Format coded error from ACP agent for display in Discord. -/// Used for response errors that have a JSON-RPC or HTTP status code. -/// `data_message` is the optional detail extracted from `error.data.message`. -/// Public for reuse by other adapters (e.g. Slack). 
-pub fn format_coded_error(code: i64, message: &str, data_message: Option<&str>) -> String { - let prefix = match code { - 400 => "**Bad Request**", - 401 => "**Unauthorized**", - 403 => "**Forbidden**", - 404 => "**Not Found**", - 408 => "**Request Timeout**", - 429 => "**Rate Limited**", - 500 => "**Internal Server Error**", - 502 => "**Bad Gateway**", - 503 => "**Service Unavailable**", - 504 => "**Gateway Timeout**", - -32600 => "**Invalid Request**", - -32601 => "**Method Not Found**", - -32602 => "**Invalid Params**", - -32603 => "**Internal Error**", - -32099..=-32000 => "**Server Error**", - _ => "**Error**", - }; - let mut out = if message.is_empty() { - format!("{} (code: {})", prefix, code) - } else { - format!("{} (code: {})\n{}", prefix, code, message) - }; - let detail = data_message.filter(|s| !s.trim().is_empty()); - if let Some(detail) = detail { - if !message.contains(detail) { - out.push_str("\n> "); - out.push_str(detail); - } - } else if code == -32603 { - out.push_str( - "\n\n_The agent did not return any error details. 
\ - Please check the agent's own logs for more information._", - ); - } - out -} - -#[cfg(test)] -mod tests { - use super::*; - - // ─── format_user_error tests ───────────────────────────────────────────── - - #[test] - fn format_user_error_timeout() { - let result = format_user_error("timeout waiting for session/new response"); - assert!(result.contains("Request Timeout")); - assert!(result.contains("session/new")); - } - - #[test] - fn format_user_error_connection_closed() { - let result = format_user_error("connection closed"); - assert!(result.contains("Connection Lost")); - } - - #[test] - fn format_user_error_channel_closed() { - let result = format_user_error("channel closed"); - assert!(result.contains("Connection Lost")); - } - - #[test] - fn format_user_error_failed_to_spawn() { - let result = format_user_error("failed to spawn /some/path: No such file"); - assert!(result.contains("Agent Not Found")); - assert!(result.contains("the agent")); // generic, no provider name - } - - #[test] - fn format_user_error_no_such_file() { - let result = format_user_error("binary /usr/bin/nonexistent: no such file"); - assert!(result.contains("Agent Not Found")); - } - - #[test] - fn format_user_error_pool_exhausted() { - let result = format_user_error("pool exhausted (5 sessions)"); - assert!(result.contains("Service Busy")); - } - - #[test] - fn format_user_error_invalid_api_key() { - let result = format_user_error("invalid api key"); - assert!(result.contains("Unauthorized")); - } - - #[test] - fn format_user_error_unauthorized() { - let result = format_user_error("unauthorized: token rejected"); - assert!(result.contains("Unauthorized")); - } - - #[test] - fn format_user_error_unknown() { - let result = format_user_error("something went wrong"); - assert!(result.contains("Error")); - assert!(result.contains("something went wrong")); - } - - #[test] - fn format_user_error_empty() { - let result = format_user_error(""); - assert!(result.contains("Error")); - 
assert!(result.contains("unknown")); - } - - #[test] - fn format_user_error_case_insensitive() { - assert!(format_user_error("TIMEOUT WAITING FOR foo").contains("Timeout")); - assert!(format_user_error("CONNECTION CLOSED").contains("Connection")); - assert!(format_user_error("POOL EXHAUSTED").contains("Busy")); - } - - #[test] - fn format_user_error_mixed_case_timeout() { - // Case-insensitive matching should still extract method correctly - let result = format_user_error("Timeout Waiting For custom/method"); - assert!(result.contains("Request Timeout")); - assert!(result.contains("custom/method")); - } - - // ─── format_coded_error tests ─────────────────────────────────────────── - - #[test] - fn format_coded_error_401() { - let result = format_coded_error(401, "invalid token", None); - assert!(result.contains("Unauthorized")); - assert!(result.contains("401")); - assert!(result.contains("invalid token")); - } - - #[test] - fn format_coded_error_429() { - let result = format_coded_error(429, "", None); - assert!(result.contains("Rate Limited")); - assert!(result.contains("429")); - assert!(!result.contains("\n")); // no message, no newline - } - - #[test] - fn format_coded_error_503() { - let result = format_coded_error(503, "service unavailable", None); - assert!(result.contains("Service Unavailable")); - assert!(result.contains("503")); - assert!(result.contains("service unavailable")); - } - - #[test] - fn format_coded_error_json_rpc() { - let result = format_coded_error(-32602, "missing required parameter", None); - assert!(result.contains("Invalid Params")); - assert!(result.contains("-32602")); - } - - #[test] - fn format_coded_error_server_error_range() { - let result = format_coded_error(-32050, "internal failure", None); - assert!(result.contains("Server Error")); - assert!(result.contains("-32050")); - } - - #[test] - fn format_coded_error_connection_error() { - let result = format_coded_error(-32000, "connection refused", None); - 
assert!(result.contains("Server Error")); // -32000 falls in -32099..=-32000 range - assert!(result.contains("-32000")); - } - - #[test] - fn format_coded_error_unknown_code() { - let result = format_coded_error(999, "something happened", None); - assert!(result.contains("Error")); - assert!(result.contains("999")); - assert!(result.contains("something happened")); - } - - #[test] - fn format_coded_error_with_data_message() { - let result = format_coded_error(-32603, "Internal error", Some("model not supported")); - assert!(result.contains("Internal Error")); - assert!(result.contains("model not supported")); - } - - #[test] - fn format_coded_error_data_message_not_duplicated() { - // If data_message is already in message, don't repeat it - let result = format_coded_error(-32603, "model not supported", Some("model not supported")); - assert_eq!(result.matches("model not supported").count(), 1); - } - - #[test] - fn format_coded_error_32603_no_detail_shows_fallback() { - let result = format_coded_error(-32603, "Internal error", None); - assert!(result.contains("Internal Error")); - assert!(result.contains("did not return any error details")); - assert!(result.contains("agent's own logs")); - } - - #[test] - fn format_coded_error_32603_with_detail_no_fallback() { - let result = format_coded_error(-32603, "Internal error", Some("model not found")); - assert!(result.contains("model not found")); - assert!(!result.contains("did not return any error details")); - } - - #[test] - fn format_coded_error_32603_empty_detail_shows_fallback() { - let result = format_coded_error(-32603, "Internal error", Some("")); - assert!(result.contains("did not return any error details")); - } - - #[test] - fn format_coded_error_other_code_no_detail_no_fallback() { - // Fallback only applies to -32603 - let result = format_coded_error(-32602, "bad params", None); - assert!(!result.contains("did not return any error details")); - } - - #[test] - fn 
format_coded_error_32603_empty_message_still_shows_fallback() { - // Even when message is empty, fallback should appear - let result = format_coded_error(-32603, "", None); - assert!(result.contains("Internal Error")); - assert!(result.contains("did not return any error details")); - } - - #[test] - fn format_coded_error_32603_whitespace_detail_shows_fallback() { - // Whitespace-only detail should be treated as empty - let result = format_coded_error(-32603, "Internal error", Some(" ")); - assert!(result.contains("Internal Error")); - assert!(result.contains("did not return any error details")); - } - - #[test] - fn format_coded_error_500_no_detail_no_fallback() { - // HTTP 500 without detail should NOT get the ACP-specific hint - let result = format_coded_error(500, "server error", None); - assert!(result.contains("Internal Server Error")); - assert!(!result.contains("did not return any error details")); - } - - #[test] - fn format_coded_error_32603_fallback_does_not_duplicate_with_detail() { - // When detail is present, no fallback appears — mutually exclusive - let result = format_coded_error(-32603, "Internal error", Some("rate limit exceeded")); - assert!(result.contains("rate limit exceeded")); - assert!(!result.contains("did not return any error details")); - assert!(!result.contains("agent's own logs")); - } - - #[test] - fn format_coded_error_server_error_range_no_fallback() { - // Other JSON-RPC server error codes should NOT get the hint - let result = format_coded_error(-32099, "custom error", None); - assert!(!result.contains("did not return any error details")); - } - - #[test] - fn format_coded_error_32603_fallback_message_is_italic() { - // Verify Discord markdown italic formatting - let result = format_coded_error(-32603, "Internal error", None); - assert!(result.contains("_The agent did not return")); - assert!(result.ends_with("_")); - } -} diff --git a/src/format.rs b/src/format.rs deleted file mode 100644 index d39410f15..000000000 --- 
a/src/format.rs +++ /dev/null @@ -1,327 +0,0 @@ -/// Split text into chunks at line boundaries, each <= limit Unicode characters (UTF-8 safe). -/// Discord's message limit counts Unicode characters, not bytes. -/// -/// Fenced code blocks (``` ... ```) are handled specially: if a split falls inside a -/// code block, the current chunk is closed with ``` and the next chunk is reopened with -/// the original opener (preserving language tag), so each chunk renders correctly. -/// -/// Invariant: every returned chunk satisfies `chunk.chars().count() <= limit`. -pub fn split_message(text: &str, limit: usize) -> Vec { - if text.chars().count() <= limit { - return vec![text.to_string()]; - } - - let mut chunks = Vec::new(); - let mut current = String::new(); - let mut current_len: usize = 0; - // When inside a fenced code block, holds the full opener line (e.g. "```rust"). - let mut fence_opener: Option = None; - - // Cost of appending "\n```" to close a fence before emitting a chunk. - const CLOSE_COST: usize = 4; // '\n' + '`' + '`' + '`' - - for line in text.split('\n') { - let line_chars = line.chars().count(); - let is_fence_line = line.starts_with("```"); - - // Determine overhead that must be reserved when inside a fence. - let close_reserve = if fence_opener.is_some() && !is_fence_line { - CLOSE_COST - } else { - 0 - }; - - // Check whether appending this line (+ newline separator + close reserve) overflows. - if !current.is_empty() && current_len + 1 + line_chars + close_reserve > limit { - // Emit current chunk, closing fence if needed. - if let Some(ref opener) = fence_opener { - if !is_fence_line { - current.push_str("\n```"); - } - chunks.push(std::mem::take(&mut current)); - // Reopen fence in next chunk with full opener (preserves language tag). - current.push_str(opener); - current_len = opener.chars().count(); - - if is_fence_line { - // The closing fence marker itself triggers the split. 
- fence_opener = None; - current.push('\n'); - current_len += 1; - current.push_str(line); - current_len += line_chars; - continue; - } else if current_len + 1 + line_chars + CLOSE_COST <= limit { - // Line fits in the reopened chunk (with room for \n + line + close marker). - current.push('\n'); - current_len += 1; - current.push_str(line); - current_len += line_chars; - continue; - } - // Otherwise: line doesn't fit even in a fresh reopened chunk. - // Fall through to the normal line-processing logic below, - // which will hit the hard-split path if line_chars > limit, - // or the normal append path otherwise. - } else { - chunks.push(std::mem::take(&mut current)); - current_len = 0; - } - } - - // Newline separator between lines within a chunk. - if !current.is_empty() { - current.push('\n'); - current_len += 1; - } - - // Track fence state. - if is_fence_line { - if fence_opener.is_some() { - fence_opener = None; - } else { - fence_opener = Some(line.to_string()); - } - } - - // Hard-split: single line exceeds available space. - // This triggers when the line itself is longer than limit, OR when the - // line doesn't fit in the current chunk even after accounting for fence - // close overhead (e.g. after a reopen where opener already consumed space). - let effective_avail = if fence_opener.is_some() { - limit.saturating_sub(current_len + CLOSE_COST) - } else { - limit.saturating_sub(current_len) - }; - if line_chars > effective_avail { - let overhead = if let Some(ref opener) = fence_opener { - // opener + '\n' at start, '\n```' at end - opener.chars().count() + 1 + CLOSE_COST - } else { - 0 - }; - // If limit can't even fit overhead, fall back to unfenced hard-split. 
- let capacity = limit.saturating_sub(overhead); - if let Some(opener) = fence_opener.as_ref().filter(|_| capacity > 0) { - // Fenced hard-split: each mid chunk = opener\n + chars + \n``` - let opener_len = opener.chars().count(); - let mut chars = line.chars().peekable(); - - // Fill remaining space in current chunk first. - let avail_first = if current_len > 0 { - limit.saturating_sub(current_len + CLOSE_COST) - } else { - capacity - }; - for _ in 0..avail_first { - if let Some(ch) = chars.next() { - current.push(ch); - current_len += 1; - } else { - break; - } - } - - while chars.peek().is_some() { - // Close current fenced chunk. - current.push_str("\n```"); - chunks.push(std::mem::take(&mut current)); - // Reopen. - current.push_str(opener); - current.push('\n'); - current_len = opener_len + 1; - for _ in 0..capacity { - if let Some(ch) = chars.next() { - current.push(ch); - current_len += 1; - } else { - break; - } - } - } - } else { - // Plain hard-split (no fence or limit too small for fence wrapping). - for ch in line.chars() { - if current_len >= limit { - chunks.push(std::mem::take(&mut current)); - current_len = 0; - } - current.push(ch); - current_len += 1; - } - } - } else { - current.push_str(line); - current_len += line_chars; - } - } - - if !current.is_empty() { - // Close any trailing open fence. - if fence_opener.is_some() { - current.push_str("\n```"); - } - chunks.push(current); - } - chunks -} - -/// Shorten a prompt into a thread title: collapse GitHub URLs and cap at 40 chars. 
-pub fn shorten_thread_name(prompt: &str) -> String { - use std::sync::LazyLock; - static GH_RE: LazyLock = LazyLock::new(|| { - regex::Regex::new(r"https?://github\.com/([^/]+/[^/]+)/(issues|pull)/(\d+)").unwrap() - }); - // Strip @(role) and @(user) placeholders left by resolve_mentions() - let cleaned = prompt.replace("@(role)", "").replace("@(user)", ""); - let shortened = GH_RE.replace_all(cleaned.trim(), "$1#$3"); - let name: String = shortened.chars().take(40).collect(); - if name.len() < shortened.len() { - format!("{name}...") - } else { - name - } -} - -/// Truncate a string to at most `limit` Unicode characters, keeping the tail -/// (most recent output) for better streaming UX. -pub fn truncate_chars_tail(s: &str, limit: usize) -> String { - let total = s.chars().count(); - if total <= limit { - return s.to_string(); - } - s.chars().skip(total - limit).collect() -} - -#[cfg(test)] -mod tests { - use super::*; - - /// Helper: assert every chunk respects the limit. - fn assert_length_invariant(chunks: &[String], limit: usize) { - for (i, chunk) in chunks.iter().enumerate() { - let len = chunk.chars().count(); - assert!( - len <= limit, - "chunk {i} has {len} chars, exceeds limit {limit}:\n{chunk}" - ); - } - } - - #[test] - fn no_split_under_limit() { - let text = "hello\nworld"; - let chunks = split_message(text, 100); - assert_eq!(chunks.len(), 1); - assert_eq!(chunks[0], text); - } - - #[test] - fn plain_text_split_respects_limit() { - let text = "aaaa\nbbbb\ncccc\ndddd"; - let chunks = split_message(text, 10); - assert_length_invariant(&chunks, 10); - assert!(chunks.len() > 1); - } - - #[test] - fn fenced_split_preserves_language_tag() { - // ```rust\n + 1990 chars of content + \n``` — should split - let content_line = "x".repeat(1990); - let text = format!("```rust\n{content_line}\nanother line here\n```"); - let chunks = split_message(&text, 2000); - assert_length_invariant(&chunks, 2000); - // First chunk should start with ```rust - 
assert!(chunks[0].starts_with("```rust")); - // If split happened, second chunk should reopen with ```rust - if chunks.len() > 1 { - assert!( - chunks[1].starts_with("```rust"), - "second chunk should reopen with language tag: {}", - &chunks[1][..chunks[1].len().min(20)] - ); - } - } - - #[test] - fn fenced_split_close_overhead_budgeted() { - // Construct a fenced block where content + close marker would overflow - // without proper budgeting. - // limit=50, opener="```" (3), close="\n```" (4) - // Available for content per chunk: 50 - 3 - 1 - 4 = 42 (with opener+newline+close) - let line1 = "a".repeat(40); - let line2 = "b".repeat(40); - let text = format!("```\n{line1}\n{line2}\n```"); - let chunks = split_message(&text, 50); - assert_length_invariant(&chunks, 50); - } - - #[test] - fn reopen_path_no_overflow() { - // Regression: limit=2000, fenced block with a 1996-char line. - // Old code would produce 2004-char chunk due to reopen + extra \n. - let content = "x".repeat(1990); - let text = format!("```rust\n{content}\nshort\n```"); - let chunks = split_message(&text, 2000); - assert_length_invariant(&chunks, 2000); - } - - #[test] - fn hard_split_fenced_respects_limit() { - // A single very long line inside a fence. - let long_line = "x".repeat(100); - let text = format!("```\n{long_line}\n```"); - let chunks = split_message(&text, 20); - assert_length_invariant(&chunks, 20); - // All content should be present - let total_x: usize = chunks - .iter() - .map(|c| c.chars().filter(|&ch| ch == 'x').count()) - .sum(); - assert_eq!(total_x, 100); - } - - #[test] - fn hard_split_plain_respects_limit() { - let long_line = "y".repeat(50); - let text = format!("before\n{long_line}\nafter"); - let chunks = split_message(&text, 10); - assert_length_invariant(&chunks, 10); - } - - #[test] - fn closing_fence_triggers_split() { - // The closing ``` itself pushes over the limit. 
- let content = "a".repeat(44); - // "```\n" + 44 chars + "\n```" = 3 + 1 + 44 + 1 + 3 = 52 - let text = format!("```\n{content}\n```"); - let chunks = split_message(&text, 50); - assert_length_invariant(&chunks, 50); - } - - #[test] - fn multi_fence_blocks() { - let text = "text\n```python\ncode1\ncode2\n```\nmore text\n```js\ncode3\n```"; - let chunks = split_message(text, 25); - assert_length_invariant(&chunks, 25); - } - - #[test] - fn fence_balance_across_chunks() { - // Every chunk should have balanced fences (even number of ``` lines). - let content = (0..20) - .map(|i| format!("line {i}")) - .collect::>() - .join("\n"); - let text = format!("```\n{content}\n```"); - let chunks = split_message(&text, 30); - assert_length_invariant(&chunks, 30); - for (i, chunk) in chunks.iter().enumerate() { - let fence_count = chunk.lines().filter(|l| l.starts_with("```")).count(); - assert!( - fence_count % 2 == 0, - "chunk {i} has unbalanced fences ({fence_count}):\n{chunk}" - ); - } - } -} diff --git a/src/gateway.rs b/src/gateway.rs deleted file mode 100644 index d73819ee4..000000000 --- a/src/gateway.rs +++ /dev/null @@ -1,1054 +0,0 @@ -use crate::acp::ContentBlock; -use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef, SenderContext}; -use anyhow::Result; -use async_trait::async_trait; -use futures_util::{SinkExt, StreamExt}; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::Mutex; -use tokio_tungstenite::tungstenite::Message; -use tracing::{error, info, warn}; - -/// Timeout for waiting on gateway reply acknowledgement. -const GATEWAY_REPLY_TIMEOUT_SECS: u64 = 5; - -/// Platforms whose gateway adapter emits a `GatewayResponse` for `edit_message` -/// so core can observe edit success or failure (used to gate the per-edit -/// response-wait below). 
-/// -/// Today only Feishu does, because it is the only adapter with a known -/// per-message edit cap (errcode 230072) that requires core-side recovery, and -/// the only one wired to ack edits. -/// -/// NOTE: this gates the `edit_message` response-wait only. `delete_message` is -/// unconditionally fire-and-forget (the recovery path sends fresh content -/// regardless of the delete outcome), so it does not consult this list. -/// -/// TECH DEBT: this is platform-identity standing in for a *capability*. The -/// right model is a capability handshake at gateway-connect time ("does this -/// adapter acknowledge edits?") rather than a hardcoded platform name. We -/// accept the hardcode now because there is no handshake protocol yet; when one -/// lands, replace this allowlist with a negotiated capability flag. Any new -/// adapter that wires request/response for edits MUST be added here, or its -/// edit failures stay invisible to core (silent failure mode). -const EDIT_RESPONSE_PLATFORMS: &[&str] = &["feishu"]; - -/// Whether `platform` acknowledges `edit_message` with a `GatewayResponse`. -/// See `EDIT_RESPONSE_PLATFORMS`. 
-fn platform_acks_writes(platform: &str) -> bool { - EDIT_RESPONSE_PLATFORMS.contains(&platform) -} - -// --- Gateway event/reply schemas (mirrors gateway service) --- - -#[derive(Clone, Debug, Deserialize)] -struct GatewayEvent { - #[allow(dead_code)] - schema: String, - event_id: String, - #[allow(dead_code)] - timestamp: String, - platform: String, - channel: GwChannel, - sender: GwSender, - content: GwContent, - #[serde(default)] - #[allow(dead_code)] - mentions: Vec, - message_id: String, -} - -#[derive(Clone, Debug, Deserialize)] -struct GwChannel { - id: String, - #[serde(rename = "type")] - channel_type: String, - thread_id: Option, -} - -#[derive(Clone, Debug, Deserialize)] -struct GwSender { - id: String, - name: String, - display_name: String, - is_bot: bool, -} - -#[derive(Clone, Debug, Deserialize)] -struct GwContent { - #[allow(dead_code)] - #[serde(rename = "type")] - content_type: String, - text: String, - #[serde(default)] - attachments: Vec, -} - -#[derive(Clone, Debug, Deserialize)] -struct GwAttachment { - #[serde(rename = "type")] - attachment_type: String, - filename: String, - mime_type: String, - #[serde(default)] - data: String, - #[allow(dead_code)] - size: u64, - /// Colocate mode: local file path (preferred over base64 `data` when present) - #[serde(default)] - path: Option, -} - -#[derive(Serialize)] -struct GatewayReply { - schema: String, - reply_to: String, - platform: String, - channel: ReplyChannel, - content: ReplyContent, - #[serde(skip_serializing_if = "Option::is_none")] - command: Option, - #[serde(skip_serializing_if = "Option::is_none")] - request_id: Option, - /// When set, the gateway should send this message as a reply/quote to the specified message ID. - /// Unlike `reply_to` (routing/dedup identifier for the triggering event), this field controls - /// the visual reply/quote UI on the platform. Falls back to plain send on failure. 
- #[serde(skip_serializing_if = "Option::is_none")] - quote_message_id: Option, -} - -#[derive(Serialize)] -struct ReplyChannel { - id: String, - #[serde(skip_serializing_if = "Option::is_none")] - thread_id: Option, -} - -#[derive(Serialize)] -struct ReplyContent { - #[serde(rename = "type")] - content_type: String, - text: String, -} - -#[derive(Clone, Debug, Deserialize)] -struct GatewayResponse { - #[allow(dead_code)] - schema: String, - request_id: String, - success: bool, - thread_id: Option, - message_id: Option, - error: Option, -} - -// --- GatewayAdapter: ChatAdapter over WebSocket --- - -type PendingRequests = Arc>>>; -type SharedWsTx = Arc< - Mutex< - futures_util::stream::SplitSink< - tokio_tungstenite::WebSocketStream< - tokio_tungstenite::MaybeTlsStream, - >, - Message, - >, - >, ->; - -pub struct GatewayAdapter { - ws_tx: SharedWsTx, - pending: PendingRequests, - platform_name: &'static str, - streaming: bool, - streaming_placeholder: bool, -} - -impl GatewayAdapter { - fn new( - ws_tx: SharedWsTx, - pending: PendingRequests, - platform_name: &'static str, - streaming: bool, - streaming_placeholder: bool, - ) -> Self { - Self { - ws_tx, - pending, - platform_name, - streaming, - streaming_placeholder, - } - } - - /// Internal helper for send_message / send_message_with_reply. 
- async fn send_gateway_reply( - &self, - channel: &ChannelRef, - content: &str, - quote_message_id: Option<&str>, - ) -> Result { - let req_id = if self.streaming { - Some(format!("req_{}", uuid::Uuid::new_v4())) - } else { - None - }; - let pending_rx = if let Some(ref id) = req_id { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.pending.lock().await.insert(id.clone(), tx); - Some(rx) - } else { - None - }; - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: channel.origin_event_id.clone().unwrap_or_default(), - platform: channel.platform.clone(), - channel: ReplyChannel { - id: channel.channel_id.clone(), - thread_id: channel.thread_id.clone(), - }, - content: ReplyContent { - content_type: "text".into(), - text: content.into(), - }, - command: None, - request_id: req_id.clone(), - quote_message_id: quote_message_id.map(|s| s.to_string()), - }; - let json = serde_json::to_string(&reply)?; - if let Err(e) = self.ws_tx.lock().await.send(Message::Text(json)).await { - if let Some(ref id) = req_id { - self.pending.lock().await.remove(id); - } - return Err(e.into()); - } - let msg_id = if let (Some(rx), Some(ref id)) = (pending_rx, &req_id) { - match tokio::time::timeout(std::time::Duration::from_secs(GATEWAY_REPLY_TIMEOUT_SECS), rx).await { - Ok(Ok(resp)) if resp.success => resp.message_id.unwrap_or_else(|| "gw_sent".into()), - Ok(Ok(resp)) => { - // Gateway explicitly reported failure (success=false). Surface - // as Err so dispatch sets ❌ instead of 🆗 over an incomplete - // delivery. Examples: Feishu edit cap reached after append-new - // fallback also failed; chunked send delivered N/M chunks. 
- let err_msg = resp.error.clone() - .unwrap_or_else(|| "gateway reported failure".to_string()); - tracing::warn!(request_id = %id, error = %err_msg, "gateway replied with failure"); - return Err(anyhow::anyhow!("gateway reported failure: {err_msg}")); - } - Ok(Err(_)) => { - // Channel closed (gateway shutting down or pending dropped). - // Maintain legacy behavior — adapters that don't implement - // GatewayResponse for all reply types (LINE, Teams) rely on - // this for non-failure outcomes. - tracing::warn!(request_id = %id, "gateway response channel closed"); - "gw_sent".into() - } - Err(_) => { - // Timeout. Many adapters (LINE, Teams) intentionally do not - // emit GatewayResponse for replies, so timeout is the expected - // path for them. Maintain legacy behavior to avoid breaking - // platforms that have not yet wired request/response feedback. - tracing::warn!(request_id = %id, "gateway reply timed out"); - self.pending.lock().await.remove(id); - "gw_sent".into() - } - } - } else { - "gw_sent".into() - }; - Ok(MessageRef { - channel: channel.clone(), - message_id: msg_id, - }) - } -} - -/// Send a fire-and-forget reply via the shared WebSocket (no request-response). -/// Used for slash command responses where we don't need message_id back. 
-async fn send_fire_and_forget( - ws_tx: &SharedWsTx, - channel: &ChannelRef, - content: &str, -) -> Result<()> { - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: channel.origin_event_id.clone().unwrap_or_default(), - platform: channel.platform.clone(), - channel: ReplyChannel { - id: channel.channel_id.clone(), - thread_id: channel.thread_id.clone(), - }, - content: ReplyContent { - content_type: "text".into(), - text: content.into(), - }, - command: None, - request_id: None, - quote_message_id: None, - }; - let json = serde_json::to_string(&reply)?; - ws_tx.lock().await.send(Message::Text(json)).await?; - Ok(()) -} - -/// Handle `/models` or `/agents` text commands for gateway platforms. -/// Returns the response message, or None if the command was not recognized. -/// -/// Supported syntax: -/// /model list — numbered list of available models -/// /model set — switch by exact name or number -/// /models — alias of /model list -/// /agent list — numbered list of available agents -/// /agent set — switch by exact name or number -/// /agents — alias of /agent list -async fn handle_config_command( - trimmed: &str, - router: &AdapterRouter, - thread_key: &str, -) -> Option { - // Parse command: /model or /models (alias) - let (category, label, action, arg) = if trimmed == "/models" { - ("model", "model", "list", "") - } else if trimmed == "/agents" { - ("agent", "agent", "list", "") - } else if trimmed.starts_with("/model ") { - let rest = trimmed.strip_prefix("/model ").unwrap().trim(); - let (action, arg) = rest.split_once(' ').unwrap_or((rest, "")); - ("model", "model", action, arg.trim()) - } else if trimmed.starts_with("/agent ") { - let rest = trimmed.strip_prefix("/agent ").unwrap().trim(); - let (action, arg) = rest.split_once(' ').unwrap_or((rest, "")); - ("agent", "agent", action, arg.trim()) - } else if trimmed == "/model" { - ("model", "model", "list", "") - } else if trimmed == "/agent" { - ("agent", "agent", "list", 
"") - } else { - return None; - }; - - // Support both "agent" and "mode" categories (kiro-cli vs cursor-agent) - let categories: &[&str] = if category == "agent" { - &["agent", "mode"] - } else { - &[category] - }; - - let options = router.pool().get_config_options(thread_key).await; - let filtered: Vec<_> = options - .iter() - .filter(|o| { - o.category - .as_deref() - .is_some_and(|c| categories.contains(&c)) - }) - .collect(); - - if filtered.is_empty() { - return Some(format!( - "⚠️ No {label} options available. Start a conversation first." - )); - } - - // Collect all values with index for numbered list / set-by-number - let mut all_values: Vec<(String, String, String, bool)> = Vec::new(); // (config_id, value, name, is_current) - for opt in &filtered { - for v in &opt.options { - all_values.push(( - opt.id.clone(), - v.value.clone(), - v.name.clone(), - v.value == opt.current_value, - )); - } - } - - match action { - "list" => { - let mut lines = vec![format!("🔧 Available {label}s:")]; - for (i, (_, _, name, is_current)) in all_values.iter().enumerate() { - let marker = if *is_current { " ✅" } else { "" }; - lines.push(format!(" {}. {}{}", i + 1, name, marker)); - } - lines.push(format!("\nUsage: /{label} set ")); - Some(lines.join("\n")) - } - "set" => { - if arg.is_empty() { - return Some(format!("Usage: /{label} set ")); - } - // Try number first - if let Ok(num) = arg.parse::() { - if num >= 1 && num <= all_values.len() { - let (ref config_id, ref value, ref name, _) = all_values[num - 1]; - return match router - .pool() - .set_config_option(thread_key, config_id, value) - .await - { - Ok(_) => Some(format!("✅ Switched to **{name}**")), - Err(e) => Some(format!("❌ Failed to switch: {e}")), - }; - } else { - return Some(format!("⚠️ Invalid number. 
Use 1–{}.", all_values.len())); - } - } - // Exact match on value or name - let arg_lower = arg.to_lowercase(); - for (config_id, value, name, _) in &all_values { - if value.to_lowercase() == arg_lower || name.to_lowercase() == arg_lower { - return match router - .pool() - .set_config_option(thread_key, config_id, value) - .await - { - Ok(_) => Some(format!("✅ Switched to **{name}**")), - Err(e) => Some(format!("❌ Failed to switch: {e}")), - }; - } - } - Some(format!( - "⚠️ No {label} matching \"{arg}\". Use /{label} list to see options." - )) - } - _ => Some(format!( - "Unknown action \"{action}\". Usage: /{label} list | /{label} set " - )), - } -} - -#[async_trait] -impl ChatAdapter for GatewayAdapter { - fn platform(&self) -> &'static str { - self.platform_name - } - - fn message_limit(&self) -> usize { - 4096 // Telegram limit - } - - async fn send_message(&self, channel: &ChannelRef, content: &str) -> Result { - self.send_gateway_reply(channel, content, None).await - } - - async fn send_message_with_reply( - &self, - channel: &ChannelRef, - content: &str, - reply_to_message_id: &str, - ) -> Result { - self.send_gateway_reply(channel, content, Some(reply_to_message_id)).await - } - - async fn create_thread( - &self, - channel: &ChannelRef, - _trigger_msg: &MessageRef, - title: &str, - ) -> Result { - // Send create_topic command to gateway - let req_id = format!("req_{}", uuid::Uuid::new_v4()); - let (tx, rx) = tokio::sync::oneshot::channel(); - self.pending.lock().await.insert(req_id.clone(), tx); - - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: String::new(), - platform: channel.platform.clone(), - channel: ReplyChannel { - id: channel.channel_id.clone(), - thread_id: None, - }, - content: ReplyContent { - content_type: "text".into(), - text: title.into(), - }, - command: Some("create_topic".into()), - request_id: Some(req_id.clone()), - quote_message_id: None, - }; - let json = serde_json::to_string(&reply)?; - 
self.ws_tx.lock().await.send(Message::Text(json)).await?; - - // Wait for response (5s timeout) - match tokio::time::timeout(std::time::Duration::from_secs(5), rx).await { - Ok(Ok(resp)) if resp.success => Ok(ChannelRef { - platform: channel.platform.clone(), - channel_id: channel.channel_id.clone(), - thread_id: resp.thread_id, - parent_id: None, - origin_event_id: channel.origin_event_id.clone(), - }), - Ok(Ok(resp)) => { - warn!(err = ?resp.error, "create_topic failed, falling back to same channel"); - Ok(channel.clone()) - } - _ => { - warn!("create_topic timeout, falling back to same channel"); - self.pending.lock().await.remove(&req_id); - Ok(channel.clone()) - } - } - } - - async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: msg.message_id.clone(), - platform: msg.channel.platform.clone(), - channel: ReplyChannel { - id: msg.channel.channel_id.clone(), - thread_id: msg.channel.thread_id.clone(), - }, - content: ReplyContent { - content_type: "text".into(), - text: emoji.into(), - }, - command: Some("add_reaction".into()), - quote_message_id: None, - request_id: None, - }; - let json = serde_json::to_string(&reply)?; - self.ws_tx.lock().await.send(Message::Text(json)).await?; - Ok(()) - } - - async fn remove_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: msg.message_id.clone(), - platform: msg.channel.platform.clone(), - channel: ReplyChannel { - id: msg.channel.channel_id.clone(), - thread_id: msg.channel.thread_id.clone(), - }, - content: ReplyContent { - content_type: "text".into(), - text: emoji.into(), - }, - command: Some("remove_reaction".into()), - quote_message_id: None, - request_id: None, - }; - let json = serde_json::to_string(&reply)?; - self.ws_tx.lock().await.send(Message::Text(json)).await?; - Ok(()) - } - - async fn edit_message(&self, 
msg: &MessageRef, content: &str) -> Result<()> { - // Use a short request/response cycle so we can react to platform-level - // edit failures (e.g. Feishu's 20-edits-per-message cap, errcode 230072). - // Without this, edit_message was fire-and-forget and core never saw cap - // signals — cosmetic streaming would keep flushing forever and the final - // edit fallback to send_message could not trigger. - // - // Scope intentionally limited to platforms that ack writes (see - // EDIT_RESPONSE_PLATFORMS). Other adapters (LINE, Teams, Slack, Discord, - // …) keep the original fire-and-forget path so cosmetic streaming on - // those platforms does not pay a response-wait penalty per flush. - const EDIT_RESPONSE_TIMEOUT_MS: u64 = 800; - let needs_response = self.streaming && platform_acks_writes(&msg.channel.platform); - - let req_id = if needs_response { - Some(format!("req_{}", uuid::Uuid::new_v4())) - } else { - None - }; - let pending_rx = if let Some(ref id) = req_id { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.pending.lock().await.insert(id.clone(), tx); - Some(rx) - } else { - None - }; - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: msg.message_id.clone(), - platform: msg.channel.platform.clone(), - channel: ReplyChannel { - id: msg.channel.channel_id.clone(), - thread_id: msg.channel.thread_id.clone(), - }, - content: ReplyContent { - content_type: "text".into(), - text: content.into(), - }, - command: Some("edit_message".into()), - quote_message_id: None, - request_id: req_id.clone(), - }; - let json = serde_json::to_string(&reply)?; - if let Err(e) = self.ws_tx.lock().await.send(Message::Text(json)).await { - if let Some(ref id) = req_id { - self.pending.lock().await.remove(id); - } - return Err(e.into()); - } - if let (Some(rx), Some(ref id)) = (pending_rx, &req_id) { - match tokio::time::timeout( - std::time::Duration::from_millis(EDIT_RESPONSE_TIMEOUT_MS), - rx, - ).await { - Ok(Ok(resp)) if resp.success => 
Ok(()), - Ok(Ok(resp)) => { - let err_msg = resp.error.clone() - .unwrap_or_else(|| "gateway reported edit failure".to_string()); - tracing::warn!(request_id = %id, error = %err_msg, "edit_message gateway replied failure"); - Err(anyhow::anyhow!("edit failure: {err_msg}")) - } - Ok(Err(_)) => { - tracing::debug!(request_id = %id, "edit_message gateway response channel closed"); - Ok(()) - } - Err(_) => { - // Timeout — feishu didn't respond within the window - // (probably a slow API). Treat as success to avoid - // false-positive ❌; the cap-reached path already short- - // circuits much faster (gateway returns immediately). - self.pending.lock().await.remove(id); - Ok(()) - } - } - } else { - // Non-feishu (or non-streaming): fire-and-forget, no added latency. - Ok(()) - } - } - - /// Override default delete_message (which falls back to edit-to-zero-width) - /// so platforms with native delete APIs (e.g. Feishu DELETE /im/v1/messages/{id}) - /// can perform real deletions. Critical for the streaming-edit-cap recovery - /// path: when Feishu's 20-edits-per-message cap is hit and we send full - /// content as a fresh message, we need to remove the half-edited placeholder - /// to avoid duplicated content. The default zero-width-edit fallback would - /// itself fail on a cap-reached message, leaving the placeholder visible. - /// - /// Fire-and-forget: gateway adapters that don't implement delete will simply - /// ignore the command. Failure is non-fatal — if delete fails, the user sees - /// the placeholder remain (same behavior as before this override). We do not - /// wait on a response here: the recovery path sends fresh content regardless - /// of whether the delete landed, so a response would only buy an extra log - /// line at the cost of a per-finalize wait. 
- async fn delete_message(&self, msg: &MessageRef) -> Result<()> { - let reply = GatewayReply { - schema: "openab.gateway.reply.v1".into(), - reply_to: msg.message_id.clone(), - platform: msg.channel.platform.clone(), - channel: ReplyChannel { - id: msg.channel.channel_id.clone(), - thread_id: msg.channel.thread_id.clone(), - }, - content: ReplyContent { - content_type: "text".into(), - text: String::new(), - }, - command: Some("delete_message".into()), - quote_message_id: None, - request_id: None, - }; - let json = serde_json::to_string(&reply)?; - self.ws_tx.lock().await.send(Message::Text(json)).await?; - Ok(()) - } - - fn use_streaming(&self, _other_bot_present: bool) -> bool { - self.streaming - } - - fn show_streaming_placeholder(&self) -> bool { - self.streaming_placeholder - } -} - -// --- Run the gateway adapter (connects to gateway WS, routes events to AdapterRouter) --- - -/// Resolved gateway configuration passed to the adapter at startup. -pub struct GatewayParams { - pub url: String, - pub platform: String, - pub token: Option, - pub bot_username: Option, - pub allow_all_channels: bool, - pub allowed_channels: Vec, - pub allow_all_users: bool, - pub allowed_users: Vec, - pub streaming: bool, - pub streaming_placeholder: bool, - pub stt: crate::config::SttConfig, -} - -pub async fn run_gateway_adapter( - params: GatewayParams, - mut shutdown_rx: tokio::sync::watch::Receiver, - dispatcher: Arc, - router: Arc, -) -> Result<()> { - let platform: &'static str = Box::leak(params.platform.into_boxed_str()); - - // Append auth token as query param if configured - let gateway_url = params.url; - let bot_username = params.bot_username; - let allow_all_channels = params.allow_all_channels; - let allowed_channels = params.allowed_channels; - let allow_all_users = params.allow_all_users; - let allowed_users = params.allowed_users; - let streaming = params.streaming; - let streaming_placeholder = params.streaming_placeholder; - let stt_config = params.stt; - - let 
connect_url = match ¶ms.token { - Some(token) => { - let sep = if gateway_url.contains('?') { "&" } else { "?" }; - format!("{gateway_url}{sep}token={token}") - } - None => { - warn!("gateway.token not set — WebSocket connection is NOT authenticated"); - gateway_url.clone() - } - }; - let mut backoff_secs = 1u64; - const MAX_BACKOFF: u64 = 30; - - loop { - // Check shutdown before connecting - if *shutdown_rx.borrow() { - info!("gateway adapter shutting down"); - return Ok(()); - } - - info!(url = %gateway_url, "connecting to custom gateway"); - - let ws_stream = match tokio_tungstenite::connect_async(&connect_url).await { - Ok((stream, _)) => { - backoff_secs = 1; // reset on success - info!("connected to gateway"); - stream - } - Err(e) => { - error!(err = %e, backoff = backoff_secs, "gateway connection failed, retrying"); - tokio::select! { - _ = tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)) => {} - _ = shutdown_rx.changed() => { return Ok(()); } - } - backoff_secs = (backoff_secs * 2).min(MAX_BACKOFF); - continue; - } - }; - - let (ws_tx, mut ws_rx) = ws_stream.split(); - let ws_tx: SharedWsTx = Arc::new(Mutex::new(ws_tx)); - let pending: PendingRequests = Arc::new(Mutex::new(HashMap::new())); - let adapter: Arc = Arc::new(GatewayAdapter::new( - ws_tx.clone(), - pending.clone(), - platform, - streaming, - streaming_placeholder, - )); - let slash_ws_tx = ws_tx.clone(); // for fire-and-forget slash command responses - let mut tasks: tokio::task::JoinSet<()> = tokio::task::JoinSet::new(); - - loop { - tokio::select! 
{ - msg = ws_rx.next() => { - match msg { - Some(Ok(Message::Text(text))) => { - let text_str: &str = &text; - - // Check if it's a response to a pending command - if let Ok(resp) = serde_json::from_str::(text_str) { - if resp.schema == "openab.gateway.response.v1" { - if let Some(tx) = pending.lock().await.remove(&resp.request_id) { - let _ = tx.send(resp); - } - continue; - } - } - - match serde_json::from_str::(text_str) { - Ok(event) => { - // TODO: gateway adapters (feishu) do their own bot filtering - // via AllowBots + trusted_bot_ids, but Telegram does not. - // When Feishu lifts the bot-to-bot delivery restriction, - // this guard needs to become adapter-aware (e.g. a field on - // GatewayEvent indicating the adapter already filtered bots). - if event.sender.is_bot { - continue; - } - - // Channel allowlist gate - if !allow_all_channels && !allowed_channels.contains(&event.channel.id) { - info!(channel = %event.channel.id, "gateway: channel not in allowed_channels, skipping"); - continue; - } - - // User allowlist gate - if !allow_all_users && !allowed_users.contains(&event.sender.id) { - info!(sender = %event.sender.id, "gateway: user not in allowed_users, skipping"); - continue; - } - - // @mention gating: in groups, only respond if bot is mentioned - // DMs (private) and thread replies always pass through - let is_group = event.channel.channel_type == "group" - || event.channel.channel_type == "supergroup"; - let in_thread = event.channel.thread_id.is_some(); - if is_group && !in_thread { - if let Some(ref bot_name) = bot_username { - let mentioned = event.mentions.iter().any(|m| m == bot_name); - if !mentioned { - continue; // skip non-mentioned group messages - } - } - } - - info!( - platform = %event.platform, - sender = %event.sender.name, - channel = %event.channel.id, - "gateway event received" - ); - - let channel = ChannelRef { - platform: event.platform.clone(), - channel_id: event.channel.id.clone(), - thread_id: 
event.channel.thread_id.clone(), - parent_id: None, - origin_event_id: Some(event.event_id.clone()), - }; - - let sender_ctx = SenderContext { - schema: "openab.sender.v1".into(), - sender_id: event.sender.id.clone(), - sender_name: event.sender.name.clone(), - display_name: event.sender.display_name.clone(), - channel: event.channel.channel_type.clone(), - channel_id: event.channel.id.clone(), - thread_id: event.channel.thread_id.clone(), - is_bot: event.sender.is_bot, - // Gateway: use event timestamp if available, else broker receive time - timestamp: Some(if event.timestamp.is_empty() { - crate::timestamp::now_iso8601() - } else { - event.timestamp.clone() - }), - message_id: if event.message_id.is_empty() { None } else { Some(event.message_id.clone()) }, - receiver_id: None, // gateway does not yet resolve receiver identity - }; - let sender_json = serde_json::to_string(&sender_ctx) - .unwrap_or_default(); - - let trigger_msg = MessageRef { - channel: channel.clone(), - message_id: event.message_id.clone(), - }; - - let adapter = adapter.clone(); - let prompt = event.content.text.clone(); - let sender_name = event.sender.name.clone(); - let sender_id = event.sender.id.clone(); - let dispatcher = dispatcher.clone(); - - // Convert gateway attachments to ContentBlocks - let mut extra_blocks = Vec::new(); - for att in &event.content.attachments { - // Read bytes: prefer file path (colocate), fallback to base64 - let bytes_result = if let Some(ref path) = att.path { - tokio::fs::read(path).await.map_err(|e| e.to_string()) - } else if !att.data.is_empty() { - use base64::Engine; - base64::engine::general_purpose::STANDARD - .decode(&att.data) - .map_err(|e| e.to_string()) - } else { - Err("no path or data".into()) - }; - - match att.attachment_type.as_str() { - "image" => { - match bytes_result { - Ok(bytes) => { - use base64::Engine; - let b64 = base64::engine::general_purpose::STANDARD.encode(&bytes); - extra_blocks.push(ContentBlock::Image { - media_type: 
att.mime_type.clone(), - data: b64, - }); - } - Err(e) => { - tracing::warn!(filename = %att.filename, error = %e, "gateway image read failed"); - } - } - } - "text_file" => { - if let Ok(bytes) = bytes_result { - let text = String::from_utf8_lossy(&bytes); - extra_blocks.push(ContentBlock::Text { - text: format!("```{}\n{}\n```", att.filename, text), - }); - } - } - "audio" if stt_config.enabled => { - match bytes_result { - Ok(bytes) => { - match crate::stt::transcribe( - &crate::media::HTTP_CLIENT, - &stt_config, - bytes, - att.filename.clone(), - &att.mime_type, - ).await { - Some(transcript) => { - extra_blocks.push(ContentBlock::Text { - text: format!("[Voice message transcript]: {transcript}"), - }); - } - None => { - tracing::warn!(filename = %att.filename, "gateway audio STT failed"); - extra_blocks.push(ContentBlock::Text { - text: format!( - "[Voice message — transcription failed for {}]", - att.filename - ), - }); - } - } - } - Err(e) => { - tracing::warn!(filename = %att.filename, error = %e, "gateway audio read failed"); - extra_blocks.push(ContentBlock::Text { - text: format!( - "[Voice message — read failed for {}]", - att.filename - ), - }); - } - } - } - "audio" => { - tracing::debug!(filename = %att.filename, "audio attachment skipped — STT not enabled"); - } - _ => {} - } - } - - // Slash command interception for gateway platforms - // (Feishu/LINE/Telegram don't have native slash commands) - // Use fire-and-forget send — slash command responses don't - // need message_id for streaming edits. - let trimmed = prompt.trim(); - if trimmed == "/reset" { - let thread_id_str = event.channel.thread_id.as_deref().unwrap_or(&event.channel.id); - let thread_key = format!("{}:{}", event.platform, thread_id_str); - let dropped = dispatcher.cancel_buffered_thread(event.platform.as_str(), thread_id_str); - let msg = match (router.pool().reset_session(&thread_key).await, dropped) { - (Ok(()), 0) => "🔄 Session reset. 
Start a new conversation!".to_string(), - (Ok(()), n) => format!("🔄 Session reset. Dropped {n} buffered message(s). Start a new conversation!"), - (Err(_), 0) => "⚠️ No active session to reset.".to_string(), - (Err(_), n) => format!("🔄 Dropped {n} buffered message(s). No active session to reset."), - }; - let _ = send_fire_and_forget(&slash_ws_tx, &channel, &msg).await; - continue; - } - if trimmed == "/cancel" { - let thread_key = format!("{}:{}", event.platform, event.channel.thread_id.as_deref().unwrap_or(&event.channel.id)); - let msg = match router.pool().cancel_session(&thread_key).await { - Ok(()) => "🛑 Cancel signal sent.".to_string(), - Err(e) => format!("⚠️ {e}"), - }; - let _ = send_fire_and_forget(&slash_ws_tx, &channel, &msg).await; - continue; - } - { - let thread_key = format!("{}:{}", event.platform, event.channel.thread_id.as_deref().unwrap_or(&event.channel.id)); - if let Some(msg) = handle_config_command(trimmed, &router, &thread_key).await { - let _ = send_fire_and_forget(&slash_ws_tx, &channel, &msg).await; - continue; - } - } - - tasks.spawn(async move { - // If supergroup with no thread_id, create a forum topic - let thread_channel = if event.channel.channel_type == "supergroup" - && channel.thread_id.is_none() - { - let title = crate::format::shorten_thread_name(&prompt); - match adapter.create_thread(&channel, &trigger_msg, &title).await { - Ok(tc) => tc, - Err(e) => { - warn!("create_thread failed, using channel: {e}"); - channel.clone() - } - } - } else { - channel.clone() - }; - - let thread_id = thread_channel - .thread_id - .as_deref() - .unwrap_or(&thread_channel.channel_id); - let thread_key = dispatcher.key( - &thread_channel.platform, - thread_id, - &sender_id, - ); - let estimated_tokens = - crate::dispatch::estimate_tokens(&prompt, &extra_blocks); - let buf_msg = crate::dispatch::BufferedMessage { - sender_json, - sender_name, - prompt, - extra_blocks, - trigger_msg, - arrived_at: std::time::Instant::now(), - estimated_tokens, - 
// TODO: implement gateway multibot detection - other_bot_present: false, - recipient: None, // Slack-only (assistant mode); N/A for gateway - }; - if let Err(e) = dispatcher - .submit(thread_key, thread_channel, adapter, buf_msg) - .await - { - error!("gateway dispatcher submit error: {e}"); - } - }); - } - Err(e) => warn!("invalid gateway event: {e}"), - } - } - Some(Ok(Message::Close(_))) | None => { - warn!("gateway WebSocket closed, will reconnect"); - break; - } - Some(Err(e)) => { - error!("gateway WebSocket error: {e}, will reconnect"); - break; - } - _ => {} - } - } - _ = shutdown_rx.changed() => { - if *shutdown_rx.borrow() { - info!("gateway adapter shutting down, waiting for {} in-flight tasks", tasks.len()); - while tasks.join_next().await.is_some() {} - return Ok(()); - } - } - } - } // inner loop — break here means reconnect - - // Drain in-flight tasks before reconnecting - while tasks.join_next().await.is_some() {} - - warn!(backoff = backoff_secs, "reconnecting to gateway"); - tokio::select! { - _ = tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)) => {} - _ = shutdown_rx.changed() => { return Ok(()); } - } - backoff_secs = (backoff_secs * 2).min(MAX_BACKOFF); - } // outer reconnect loop -} diff --git a/src/hooks.rs b/src/hooks.rs deleted file mode 100644 index 164ce39d0..000000000 --- a/src/hooks.rs +++ /dev/null @@ -1,425 +0,0 @@ -use crate::config::{HookConfig, OnFailure}; -use sha2::{Digest, Sha256}; -use std::io::Write; -use std::path::PathBuf; -use tokio::process::Command; -use tracing::{error, info, warn}; - -/// Maximum size for a remote hook script (1 MiB). -const MAX_SCRIPT_SIZE: usize = 1024 * 1024; - -/// Run a hook. Returns Ok(()) if the hook succeeds or is not configured. -/// Returns Err only if on_failure=abort and the hook fails. 
-pub async fn run_hook(name: &str, hook: &HookConfig) -> anyhow::Result<()> { - info!(hook = name, "running hook"); - - let resolved = match resolve_script(name, hook).await { - Ok(r) => r, - Err(e) => return handle_failure(name, hook.on_failure, e), - }; - - let result = execute(&resolved.path, hook.timeout_seconds).await; - - // Clean up temp files - if resolved.temp { - let _ = std::fs::remove_file(&resolved.path); - } - - match result { - Ok(()) => { - info!(hook = name, "hook completed successfully"); - Ok(()) - } - Err(e) => handle_failure(name, hook.on_failure, e), - } -} - -/// Validate hook config at parse time. -pub fn validate_hook(name: &str, hook: &HookConfig) -> anyhow::Result<()> { - let sources = [ - hook.script.is_some(), - hook.inline.is_some(), - hook.url.is_some(), - ]; - let count = sources.iter().filter(|&&b| b).count(); - if count == 0 { - anyhow::bail!("hooks.{name}: exactly one of script, inline, or url must be set"); - } - if count > 1 { - anyhow::bail!( - "hooks.{name}: only one of script, inline, or url may be set (found {count})" - ); - } - if hook.url.is_some() && hook.sha256.is_none() { - anyhow::bail!("hooks.{name}: sha256 is required when using url"); - } - if let Some(ref path) = hook.script { - if !PathBuf::from(path).is_absolute() { - anyhow::bail!("hooks.{name}: script path must be absolute, got: {path}"); - } - } - Ok(()) -} - -struct ResolvedScript { - path: PathBuf, - temp: bool, -} - -async fn resolve_script(name: &str, hook: &HookConfig) -> anyhow::Result { - if let Some(ref path) = hook.script { - let p = PathBuf::from(path); - if !p.exists() { - anyhow::bail!("hooks.{name}: script not found: {path}"); - } - return Ok(ResolvedScript { - path: p, - temp: false, - }); - } - - if let Some(ref content) = hook.inline { - let path = write_temp_script(name, content)?; - return Ok(ResolvedScript { path, temp: true }); - } - - if let Some(ref url) = hook.url { - let expected_hash = hook.sha256.as_deref().unwrap(); - let content = 
fetch_and_verify(url, expected_hash).await?; - let path = write_temp_script(name, &content)?; - return Ok(ResolvedScript { path, temp: true }); - } - - anyhow::bail!("hooks.{name}: no script source configured"); -} - -fn write_temp_script(name: &str, content: &str) -> anyhow::Result { - #[cfg(unix)] - let suffix = ".sh"; - #[cfg(windows)] - let suffix = ".cmd"; - - let prefix = format!("openab-hook-{name}-"); - let mut builder = tempfile::Builder::new(); - builder.prefix(prefix.as_str()).suffix(suffix); - - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - builder.permissions(std::fs::Permissions::from_mode(0o700)); - } - - let mut f = builder.tempfile()?; - f.write_all(content.as_bytes())?; - let path = f.into_temp_path().keep().map_err(|e| { - anyhow::anyhow!("failed to persist temp script: {}", e.error) - })?; - Ok(path) -} - -async fn fetch_and_verify(url: &str, expected_hex: &str) -> anyhow::Result { - info!(url = url, "fetching hook script from URL"); - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) - .build()?; - let resp = client.get(url).send().await?; - if !resp.status().is_success() { - anyhow::bail!("hook url returned HTTP {}", resp.status()); - } - let content_length = resp.content_length().unwrap_or(0) as usize; - if content_length > MAX_SCRIPT_SIZE { - anyhow::bail!( - "hook script too large: {content_length} bytes (max {MAX_SCRIPT_SIZE})" - ); - } - let body = resp.bytes().await?; - if body.len() > MAX_SCRIPT_SIZE { - anyhow::bail!( - "hook script too large: {} bytes (max {MAX_SCRIPT_SIZE})", - body.len() - ); - } - - let mut hasher = Sha256::new(); - hasher.update(&body); - let actual_hex = format!("{:x}", hasher.finalize()); - - if actual_hex != expected_hex.to_lowercase() { - anyhow::bail!("hook sha256 mismatch: expected {expected_hex}, got {actual_hex}"); - } - - Ok(String::from_utf8(body.to_vec())?) 
-} - -async fn execute(path: &PathBuf, timeout_secs: u64) -> anyhow::Result<()> { - let mut cmd = Command::new(path); - cmd.env_clear(); - - // Baseline env (same as agent subprocess) - if let Ok(v) = std::env::var("HOME") { - cmd.env("HOME", &v); - } - if let Ok(v) = std::env::var("PATH") { - cmd.env("PATH", &v); - } - #[cfg(unix)] - if let Ok(v) = std::env::var("USER") { - cmd.env("USER", &v); - } - #[cfg(windows)] - { - if let Ok(v) = std::env::var("USERPROFILE") { - cmd.env("USERPROFILE", &v); - } - if let Ok(v) = std::env::var("USERNAME") { - cmd.env("USERNAME", &v); - } - if let Ok(v) = std::env::var("SystemRoot") { - cmd.env("SystemRoot", &v); - } - if let Ok(v) = std::env::var("SystemDrive") { - cmd.env("SystemDrive", &v); - } - } - - // Pass through cloud credential env vars for IAM-based auth (IRSA, Workload Identity, ECS task role) - for (key, val) in std::env::vars() { - let pass = key.starts_with("AWS_") - || key.starts_with("AMAZON_") - || key.starts_with("ECS_CONTAINER_METADATA_URI") - || key.starts_with("GOOGLE_") - || key.starts_with("GCLOUD_") - || key.starts_with("CLOUDSDK_") - || key.starts_with("AZURE_") - || key == "BOOTSTRAP_URI" - || key == "BOOTSTRAP_BASE_URI" - || key == "BOOTSTRAP_PERSONAL_URI" - || key == "STATE_BUCKET" - || key == "TASK_FAMILY" - || key == "OPENAB_AGENT_NAME" - || key == "OPENAB_BACKEND_AGENT"; - if pass { - cmd.env(&key, &val); - } - } - - let mut child = cmd.spawn()?; - - if timeout_secs == 0 { - let status = child.wait().await?; - if !status.success() { - anyhow::bail!("hook exited with {status}"); - } - return Ok(()); - } - - let timeout = std::time::Duration::from_secs(timeout_secs); - match tokio::time::timeout(timeout, child.wait()).await { - Ok(Ok(status)) => { - if !status.success() { - anyhow::bail!("hook exited with {status}"); - } - Ok(()) - } - Ok(Err(e)) => anyhow::bail!("hook process error: {e}"), - Err(_) => { - let _ = child.kill().await; - anyhow::bail!("hook timed out after {timeout_secs}s"); - } - } 
-} - -fn handle_failure(name: &str, policy: OnFailure, err: anyhow::Error) -> anyhow::Result<()> { - match policy { - OnFailure::Abort => { - error!(hook = name, error = %err, "hook failed (on_failure=abort)"); - Err(err) - } - OnFailure::Warn => { - warn!(hook = name, error = %err, "hook failed (on_failure=warn), continuing"); - Ok(()) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::config::{HookConfig, OnFailure}; - - fn hook_with_script(path: &str) -> HookConfig { - HookConfig { - script: Some(path.into()), - inline: None, - url: None, - sha256: None, - timeout_seconds: 60, - on_failure: OnFailure::Abort, - } - } - - fn hook_with_inline(content: &str) -> HookConfig { - HookConfig { - script: None, - inline: Some(content.into()), - url: None, - sha256: None, - timeout_seconds: 60, - on_failure: OnFailure::Abort, - } - } - - #[test] - fn validate_rejects_no_source() { - let hook = HookConfig { - script: None, - inline: None, - url: None, - sha256: None, - timeout_seconds: 60, - on_failure: OnFailure::Abort, - }; - assert!(validate_hook("test", &hook).is_err()); - } - - #[test] - fn validate_rejects_multiple_sources() { - let hook = HookConfig { - script: Some("/bin/true".into()), - inline: Some("echo hi".into()), - url: None, - sha256: None, - timeout_seconds: 60, - on_failure: OnFailure::Abort, - }; - assert!(validate_hook("test", &hook).is_err()); - } - - #[test] - fn validate_rejects_url_without_sha256() { - let hook = HookConfig { - script: None, - inline: None, - url: Some("https://example.com/script.sh".into()), - sha256: None, - timeout_seconds: 60, - on_failure: OnFailure::Abort, - }; - assert!(validate_hook("test", &hook).is_err()); - } - - #[test] - fn validate_rejects_relative_script_path() { - let hook = hook_with_script("relative/path.sh"); - assert!(validate_hook("test", &hook).is_err()); - } - - #[test] - fn validate_accepts_absolute_script_path() { - let hook = hook_with_script("/usr/local/bin/bootstrap.sh"); - 
assert!(validate_hook("test", &hook).is_ok()); - } - - #[test] - fn validate_accepts_inline() { - let hook = hook_with_inline("#!/bin/sh\necho hello"); - assert!(validate_hook("test", &hook).is_ok()); - } - - #[test] - fn validate_accepts_url_with_sha256() { - let hook = HookConfig { - script: None, - inline: None, - url: Some("https://example.com/script.sh".into()), - sha256: Some("abc123".into()), - timeout_seconds: 60, - on_failure: OnFailure::Abort, - }; - assert!(validate_hook("test", &hook).is_ok()); - } - - #[tokio::test] - async fn run_inline_script_success() { - let hook = hook_with_inline("#!/bin/sh\nexit 0"); - let result = run_hook("test", &hook).await; - assert!(result.is_ok()); - } - - #[tokio::test] - async fn run_inline_script_failure_abort() { - let hook = hook_with_inline("#!/bin/sh\nexit 1"); - let result = run_hook("test", &hook).await; - assert!(result.is_err()); - } - - #[tokio::test] - async fn run_inline_script_failure_warn() { - let hook = HookConfig { - script: None, - inline: Some("#!/bin/sh\nexit 1".into()), - url: None, - sha256: None, - timeout_seconds: 60, - on_failure: OnFailure::Warn, - }; - let result = run_hook("test", &hook).await; - assert!(result.is_ok()); // warn mode continues - } - - #[tokio::test] - async fn run_inline_script_timeout() { - let hook = HookConfig { - script: None, - inline: Some("#!/bin/sh\nsleep 10".into()), - url: None, - sha256: None, - timeout_seconds: 1, - on_failure: OnFailure::Abort, - }; - let result = run_hook("test", &hook).await; - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("timed out")); - } - - #[tokio::test] - async fn run_script_file_success() { - let dir = std::env::temp_dir(); - let path = dir.join("openab-test-hook-success.sh"); - std::fs::write(&path, "#!/bin/sh\nexit 0").unwrap(); - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o700)).unwrap(); - } - let hook = 
hook_with_script(path.to_str().unwrap()); - let result = run_hook("test", &hook).await; - let _ = std::fs::remove_file(&path); - assert!(result.is_ok()); - } - - #[tokio::test] - async fn run_script_file_not_found() { - let hook = hook_with_script("/tmp/openab-nonexistent-hook-12345.sh"); - let result = run_hook("test", &hook).await; - assert!(result.is_err()); - } - - #[test] - fn config_parses_hooks() { - let toml_str = "[agent]\ncommand = \"echo\"\n\n[hooks.pre_boot]\ninline = \"echo hello\"\ntimeout_seconds = 30\non_failure = \"warn\"\n"; - let cfg: crate::config::Config = toml::from_str(toml_str).unwrap(); - let hook = cfg.hooks.pre_boot.unwrap(); - assert_eq!(hook.inline.unwrap(), "echo hello"); - assert_eq!(hook.timeout_seconds, 30); - assert_eq!(hook.on_failure, OnFailure::Warn); - } - - #[test] - fn config_parses_no_hooks() { - let toml_str = "[agent]\ncommand = \"echo\"\n"; - let cfg: crate::config::Config = toml::from_str(toml_str).unwrap(); - assert!(cfg.hooks.pre_boot.is_none()); - assert!(cfg.hooks.pre_shutdown.is_none()); - } -} diff --git a/src/main.rs b/src/main.rs index 600028368..7902a5316 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,39 +1,32 @@ -mod acp; -mod adapter; -mod bot_turns; -mod config; -mod cron; -mod directives; -mod discord; -mod dispatch; -mod error_display; -mod format; -mod gateway; -mod hooks; -mod markdown; -mod media; -mod multibot_cache; -mod reactions; -mod remind; -mod secrets; -mod setup; -mod slack; -mod stt; -mod timestamp; - -use adapter::AdapterRouter; +use openab_core::acp; +use openab_core::adapter::{self, AdapterRouter}; +use openab_core::bot_turns; +use openab_core::config; +use openab_core::cron; +#[cfg(feature = "discord")] +use openab_core::discord; +use openab_core::dispatch; +use openab_core::gateway; +use openab_core::hooks; +use openab_core::multibot_cache; +use openab_core::remind; +use openab_core::secrets; +use openab_core::setup; +#[cfg(feature = "slack")] +use openab_core::slack; +use 
openab_core::stt; + use clap::Parser; +#[cfg(feature = "discord")] use serenity::gateway::GatewayError; +#[cfg(feature = "discord")] use serenity::prelude::*; use std::collections::HashSet; use std::path::PathBuf; use std::sync::{Arc, Mutex}; use tracing::{error, info, warn}; -/// Wait for SIGINT (ctrl_c) or, on unix, SIGTERM. SIGTERM is what Kubernetes -/// sends during pod termination, so handling it lets us run the full cleanup -/// path (shard manager, ACP pool drain) instead of getting SIGKILL'd after the -/// grace period. +/// Wait for SIGINT (ctrl_c) or, on unix, SIGTERM. async fn shutdown_signal() { #[cfg(unix)] { @@ -205,11 +198,6 @@ async fn main() -> anyhow::Result<()> { // Shutdown signal for Slack adapter let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); - // Dispatcher handles tracked here so SIGTERM cleanup can call shutdown() on each (ADR §6.8). - // Also shared with the cleanup task for periodic stale-entry sweeping. - // Arc>> because: outer Arc shared with cleanup task + shutdown, - // Mutex guards startup-time pushes, inner Arc shared with each adapter. - // All pushes happen at startup; runtime access is read-only (lock is uncontended). let dispatchers: Arc>>> = Arc::new(Mutex::new(Vec::new())); // Spawn cleanup task @@ -219,22 +207,25 @@ async fn main() -> anyhow::Result<()> { loop { tokio::time::sleep(std::time::Duration::from_secs(60)).await; cleanup_pool.cleanup_idle(ttl_secs).await; - // Sweep stale per-thread dispatcher entries (idle-exited consumers). 
for d in cleanup_dispatchers.lock().unwrap().iter() { d.sweep_stale(); } } }); - // Pre-build shared adapters for cron scheduler (avoids duplicate Http clients / rate-limit buckets) + // Pre-build shared adapters for cron scheduler + #[cfg(feature = "discord")] let shared_discord_adapter: Option> = cfg.discord.as_ref().map(|dc| { let http = Arc::new(serenity::http::Http::new(&dc.bot_token)); Arc::new(discord::DiscordAdapter::new(http)) as Arc }); + #[cfg(not(feature = "discord"))] + let shared_discord_adapter: Option> = None; + let session_ttl_dur = std::time::Duration::from_secs(ttl_secs); - // Initialize multibot cache (persists to $HOME/.openab/cache/threads.json) + // Initialize multibot cache let multibot_cache_path = std::env::var("HOME") .map(std::path::PathBuf::from) .unwrap_or_default() @@ -243,6 +234,7 @@ async fn main() -> anyhow::Result<()> { .join("threads.json"); let multibot_cache = multibot_cache::MultibotCache::load(multibot_cache_path); + #[cfg(feature = "slack")] let shared_slack_adapter: Option> = cfg.slack.as_ref().map(|s| { Arc::new(slack::SlackAdapter::new( s.bot_token.clone(), @@ -252,8 +244,10 @@ async fn main() -> anyhow::Result<()> { multibot_cache.clone(), )) }); + #[cfg(not(feature = "slack"))] + let shared_slack_adapter: Option> = None; - // Validate cronjob config at startup (fail-fast on bad cron expressions or timezones) + // Validate cronjob config at startup let mut configured_platforms: Vec<&str> = Vec::new(); if cfg.discord.is_some() { configured_platforms.push("discord"); @@ -264,6 +258,7 @@ async fn main() -> anyhow::Result<()> { cron::validate_cronjobs(&cfg.cron.jobs, &configured_platforms)?; // Spawn Slack adapter (background task) + #[cfg(feature = "slack")] let slack_handle = if let Some(slack_cfg) = cfg.slack { let allow_all_channels = config::resolve_allow_all(slack_cfg.allow_all_channels, &slack_cfg.allowed_channels); @@ -288,9 +283,6 @@ async fn main() -> anyhow::Result<()> { let adapter = shared_slack_adapter .clone() 
.expect("shared_slack_adapter must exist when slack config is present"); - // Dispatcher is the sole serialization path for all modes. Message = cap 1 - // (each message dispatches alone, FIFO). Thread / Lane = configured cap; - // grouping decides whether senders share a buffer or get their own lane. let (slack_cap, slack_grouping, slack_idle) = dispatch::dispatch_params( &slack_cfg.message_processing_mode, slack_cfg.max_buffered_messages, @@ -327,6 +319,8 @@ async fn main() -> anyhow::Result<()> { } else { None }; + #[cfg(not(feature = "slack"))] + let slack_handle: Option> = None; // Spawn Gateway adapter (background task) let gateway_handle = if let Some(gw_cfg) = cfg.gateway { @@ -376,14 +370,13 @@ async fn main() -> anyhow::Result<()> { None }; - // Spawn cron scheduler (background task) — reuses shared adapters + // Spawn cron scheduler (background task) let usercron_path = if cfg.cron.usercron_enabled { cfg.cron.usercron_path.as_ref().map(|p| { let path = std::path::PathBuf::from(p); if path.is_absolute() { path } else { - // Relative paths resolve from $HOME/.openab/ (e.g. 
"cronjob.toml" → "$HOME/.openab/cronjob.toml") std::env::var("HOME") .map(std::path::PathBuf::from) .unwrap_or_default() @@ -404,6 +397,7 @@ async fn main() -> anyhow::Result<()> { if let Some(ref a) = shared_discord_adapter { cron_adapters.insert("discord".into(), a.clone()); } + #[cfg(feature = "slack")] if let Some(ref a) = shared_slack_adapter { cron_adapters.insert("slack".into(), a.clone() as Arc); } @@ -426,6 +420,7 @@ async fn main() -> anyhow::Result<()> { }; // Run Discord adapter (foreground, blocking) or wait for ctrl_c + #[cfg(feature = "discord")] if let Some(discord_cfg) = cfg.discord { let allow_all_channels = config::resolve_allow_all( discord_cfg.allow_all_channels, @@ -469,7 +464,7 @@ async fn main() -> anyhow::Result<()> { )); dispatchers.lock().unwrap().push(discord_dispatcher.clone()); - // Initialize reminder store (persists to $HOME/.openab/reminders.json) + // Initialize reminder store let reminder_path = std::env::var("HOME") .map(std::path::PathBuf::from) .unwrap_or_default() @@ -512,7 +507,6 @@ async fn main() -> anyhow::Result<()> { .event_handler(handler) .await?; - // Graceful Discord shutdown on ctrl_c let shard_manager = client.shard_manager.clone(); tokio::spawn(async move { shutdown_signal().await; @@ -541,7 +535,12 @@ async fn main() -> anyhow::Result<()> { Ok(_) => {} } } else { - // No Discord — wait for SIGINT or SIGTERM + info!("running without discord, press ctrl+c to stop"); + shutdown_signal().await; + info!("shutdown signal received"); + } + #[cfg(not(feature = "discord"))] + { info!("running without discord, press ctrl+c to stop"); shutdown_signal().await; info!("shutdown signal received"); @@ -549,7 +548,6 @@ async fn main() -> anyhow::Result<()> { // Cleanup cleanup_handle.abort(); - // Signal Slack adapter to shut down gracefully let _ = shutdown_tx.send(true); if let Some(handle) = slack_handle { let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await; @@ -558,16 +556,13 @@ async fn main() -> 
anyhow::Result<()> { let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await; } if let Some(handle) = cron_handle { - // cron.rs drains in-flight tasks for up to 30s, so wait slightly longer let _ = tokio::time::timeout(std::time::Duration::from_secs(35), handle).await; } - // Drain per-thread dispatchers and log buffered_lost counts before pool shutdown (ADR §6.8). for d in dispatchers.lock().unwrap().iter() { d.shutdown(); } let shutdown_pool = pool; shutdown_pool.shutdown().await; - // Run pre_shutdown hook after pool shutdown to guarantee no active sessions are writing. if let Some(ref hook) = shutdown_hook { if let Err(e) = hooks::run_hook("pre_shutdown", hook).await { error!(error = %e, "pre_shutdown hook failed"); @@ -604,7 +599,7 @@ mod tests { #[test] fn cli_no_args_defaults_to_run() { let cli = Cli::try_parse_from(["openab"]).unwrap(); - assert!(cli.command.is_none()); // None → unwrap_or(Run { config: None }) + assert!(cli.command.is_none()); } #[test] diff --git a/src/markdown.rs b/src/markdown.rs deleted file mode 100644 index 32398cc25..000000000 --- a/src/markdown.rs +++ /dev/null @@ -1,349 +0,0 @@ -use pulldown_cmark::{Event, Options, Parser, Tag, TagEnd}; -use serde::Deserialize; -use std::fmt; -use unicode_width::UnicodeWidthStr; - -/// How to render markdown tables for a given channel. -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum TableMode { - /// Wrap the table in a fenced code block (default). - #[default] - Code, - /// Convert each row into bullet points. - Bullets, - /// Pass through unchanged. 
- Off, -} - -impl fmt::Display for TableMode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Code => write!(f, "code"), - Self::Bullets => write!(f, "bullets"), - Self::Off => write!(f, "off"), - } - } -} - -// ── IR types ──────────────────────────────────────────────────────── - -/// A parsed table: header row + data rows, each cell is plain text. -struct Table { - headers: Vec, - rows: Vec>, -} - -/// Segment of the document — either verbatim text or a parsed table. -enum Segment { - Text(String), - Table(Table), -} - -// ── Public API ────────────────────────────────────────────────────── - -/// Parse markdown, detect tables via pulldown-cmark, and render them -/// according to `mode`. Non-table content passes through unchanged. -pub fn convert_tables(markdown: &str, mode: TableMode) -> String { - if mode == TableMode::Off || markdown.is_empty() { - return markdown.to_string(); - } - - let segments = parse_segments(markdown); - - let mut out = String::with_capacity(markdown.len()); - for seg in segments { - match seg { - Segment::Text(t) => out.push_str(&t), - Segment::Table(table) => match mode { - TableMode::Code => render_table_code(&table, &mut out), - TableMode::Bullets => render_table_bullets(&table, &mut out), - TableMode::Off => unreachable!(), - }, - } - } - out -} - -// ── Parser ────────────────────────────────────────────────────────── - -/// Walk the markdown source with pulldown-cmark and split it into -/// text segments and parsed Table segments. -fn parse_segments(markdown: &str) -> Vec { - let mut opts = Options::empty(); - opts.insert(Options::ENABLE_TABLES); - - let mut segments: Vec = Vec::new(); - let mut in_table = false; - let mut in_head = false; - let mut headers: Vec = Vec::new(); - let mut rows: Vec> = Vec::new(); - let mut current_row: Vec = Vec::new(); - let mut cell_buf = String::new(); - let mut last_table_end: usize = 0; - - // We need byte offsets to grab non-table text verbatim. 
- let parser_with_offsets = Parser::new_ext(markdown, opts).into_offset_iter(); - - for (event, range) in parser_with_offsets { - match event { - Event::Start(Tag::Table(_)) => { - // Flush text before this table - let before = &markdown[last_table_end..range.start]; - if !before.is_empty() { - push_text(&mut segments, before); - } - in_table = true; - headers.clear(); - rows.clear(); - } - Event::End(TagEnd::Table) => { - let table = Table { - headers: std::mem::take(&mut headers), - rows: std::mem::take(&mut rows), - }; - segments.push(Segment::Table(table)); - in_table = false; - last_table_end = range.end; - } - Event::Start(Tag::TableHead) => { - in_head = true; - current_row.clear(); - } - Event::End(TagEnd::TableHead) => { - headers = std::mem::take(&mut current_row); - in_head = false; - } - Event::Start(Tag::TableRow) => { - current_row.clear(); - } - Event::End(TagEnd::TableRow) if !in_head => { - rows.push(std::mem::take(&mut current_row)); - } - Event::Start(Tag::TableCell) => { - cell_buf.clear(); - } - Event::End(TagEnd::TableCell) => { - current_row.push(cell_buf.trim().to_string()); - cell_buf.clear(); - } - Event::Text(t) if in_table => { - cell_buf.push_str(&t); - } - Event::Code(t) if in_table => { - cell_buf.push('`'); - cell_buf.push_str(&t); - cell_buf.push('`'); - } - // Inline markup inside cells: collect text, ignore tags - Event::SoftBreak if in_table => { - cell_buf.push(' '); - } - Event::HardBreak if in_table => { - cell_buf.push(' '); - } - // Start/End of inline tags (bold, italic, link, etc.) — skip the - // tag markers but keep processing their child text events above. - Event::Start(Tag::Emphasis) - | Event::Start(Tag::Strong) - | Event::Start(Tag::Strikethrough) - | Event::Start(Tag::Link { .. 
}) - | Event::End(TagEnd::Emphasis) - | Event::End(TagEnd::Strong) - | Event::End(TagEnd::Strikethrough) - | Event::End(TagEnd::Link) - if in_table => {} - _ => {} - } - } - - // Remaining text after last table - if last_table_end < markdown.len() { - let tail = &markdown[last_table_end..]; - if !tail.is_empty() { - push_text(&mut segments, tail); - } - } - - segments -} - -fn push_text(segments: &mut Vec, text: &str) { - if let Some(Segment::Text(ref mut prev)) = segments.last_mut() { - prev.push_str(text); - } else { - segments.push(Segment::Text(text.to_string())); - } -} - -// ── Renderers ─────────────────────────────────────────────────────── - -/// Render table as a fenced code block with aligned columns. -fn render_table_code(table: &Table, out: &mut String) { - let col_count = table - .headers - .len() - .max(table.rows.iter().map(|r| r.len()).max().unwrap_or(0)); - if col_count == 0 { - return; - } - - // Strip backticks from cells — inside a code fence they render as literals. 
- let strip = |s: &str| s.replace('`', ""); - let headers: Vec = table.headers.iter().map(|h| strip(h)).collect(); - let rows: Vec> = table - .rows - .iter() - .map(|r| r.iter().map(|c| strip(c)).collect()) - .collect(); - - // Compute column widths (using display width for CJK/emoji) - let mut widths = vec![0usize; col_count]; - for (i, h) in headers.iter().enumerate() { - widths[i] = widths[i].max(UnicodeWidthStr::width(h.as_str())); - } - for row in &rows { - for (i, cell) in row.iter().enumerate() { - if i < col_count { - widths[i] = widths[i].max(UnicodeWidthStr::width(cell.as_str())); - } - } - } - // Minimum width 3 for the divider - for w in &mut widths { - *w = (*w).max(3); - } - - out.push_str("```\n"); - - // Header row - write_row(out, &headers, &widths, col_count); - // Divider - out.push('|'); - for w in &widths { - out.push(' '); - for _ in 0..*w { - out.push('-'); - } - out.push_str(" |"); - } - out.push('\n'); - // Data rows - for row in &rows { - write_row(out, row, &widths, col_count); - } - - out.push_str("```\n"); -} - -fn write_row(out: &mut String, cells: &[String], widths: &[usize], col_count: usize) { - out.push('|'); - for (i, w) in widths.iter().enumerate().take(col_count) { - out.push(' '); - let cell = cells.get(i).map(|s| s.as_str()).unwrap_or(""); - out.push_str(cell); - let display_width = UnicodeWidthStr::width(cell); - let pad = w.saturating_sub(display_width); - for _ in 0..pad { - out.push(' '); - } - out.push_str(" |"); - } - out.push('\n'); -} - -/// Render table as bullet points: `• header: value` per cell. 
-fn render_table_bullets(table: &Table, out: &mut String) { - for (row_idx, row) in table.rows.iter().enumerate() { - for (i, cell) in row.iter().enumerate() { - if cell.is_empty() { - continue; - } - out.push_str("• "); - if let Some(h) = table.headers.get(i) { - if !h.is_empty() { - out.push_str(h); - out.push_str(": "); - } - } - out.push_str(cell); - out.push('\n'); - } - // Blank line between rows, but not after the last one - if row_idx + 1 < table.rows.len() { - out.push('\n'); - } - } -} - -// ── Tests ─────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - - const TABLE_MD: &str = "\ -Some text before. - -| Name | Age | -|-------|-----| -| Alice | 30 | -| Bob | 25 | - -Some text after. -"; - - #[test] - fn off_mode_passes_through() { - let result = convert_tables(TABLE_MD, TableMode::Off); - assert_eq!(result, TABLE_MD); - } - - #[test] - fn code_mode_wraps_in_codeblock() { - let result = convert_tables(TABLE_MD, TableMode::Code); - assert!(result.contains("```\n")); - assert!(result.contains("| Alice")); - assert!(result.contains("Some text before.")); - assert!(result.contains("Some text after.")); - } - - #[test] - fn bullets_mode_converts_to_bullets() { - let result = convert_tables(TABLE_MD, TableMode::Bullets); - assert!(result.contains("• Name: Alice")); - assert!(result.contains("• Age: 30")); - assert!(!result.contains("```")); - } - - #[test] - fn no_table_passes_through() { - let plain = "Hello world\nNo tables here."; - let result = convert_tables(plain, TableMode::Code); - assert_eq!(result, plain); - } - - #[test] - fn code_mode_strips_backticks_from_code_cells() { - let md = "| col |\n|-----|\n| `value` |\n"; - let result = convert_tables(md, TableMode::Code); - // The table is inside a ``` block — backtick wrapping must be stripped. - assert!(result.contains("value"), "cell content should be present"); - // Only the fence markers themselves should contain backticks. 
- let inner = result.trim_start_matches("```\n").trim_end_matches("```\n"); - assert!( - !inner.contains('`'), - "no backticks should appear inside the code fence: {result:?}" - ); - } - - #[test] - fn bullets_mode_keeps_backticks_in_code_cells() { - let md = "| col |\n|-----|\n| `value` |\n"; - let result = convert_tables(md, TableMode::Bullets); - assert!( - result.contains("`value`"), - "backticks should be kept in bullets mode" - ); - } -} diff --git a/src/media.rs b/src/media.rs deleted file mode 100644 index 33ea59010..000000000 --- a/src/media.rs +++ /dev/null @@ -1,846 +0,0 @@ -use crate::acp::ContentBlock; -use crate::config::SttConfig; -use base64::engine::general_purpose::STANDARD as BASE64; -use base64::Engine; -use image::codecs::gif::GifDecoder; -use image::{AnimationDecoder, ImageReader}; -use std::io::Cursor; -use std::sync::LazyLock; -use tracing::{debug, error, warn}; - -/// Reusable HTTP client for downloading attachments (shared across adapters). -pub static HTTP_CLIENT: LazyLock = LazyLock::new(|| { - reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) - .build() - .expect("static HTTP client must build") -}); - -/// Maximum dimension (width or height) for resized images. -const IMAGE_MAX_DIMENSION_PX: u32 = 1200; - -/// JPEG quality for compressed output. -const IMAGE_JPEG_QUALITY: u8 = 75; - -/// Error variants for `download_and_encode_image`. -#[derive(Debug)] -pub enum MediaFetchError { - /// URL empty or MIME/filename doesn't indicate an image; skip silently. - NotAnImage, - /// HTTP response Content-Type is not a supported image format. - UnsupportedResponseType { actual: Option }, - /// Response body magic bytes don't match a supported image format. - InvalidImageBody { magic_prefix_hex: String }, - /// File exceeds the configured size limit. - SizeExceeded { actual: u64, limit: u64 }, - /// Network-level error (send or body-read). - Network(reqwest::Error), - /// Server returned a non-success HTTP status. 
- HttpStatus(reqwest::StatusCode), - /// Body was a valid image but post-processing (resize/compress) failed. - /// Unlike `InvalidImageBody`, the bytes decoded successfully — this is an - /// unexpected processing error, not a content validation failure. Both the - /// Slack and Discord adapters surface this as a user-facing warning alongside - /// other image-validation failures. - ProcessingFailed(image::ImageError), -} - -impl std::fmt::Display for MediaFetchError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::NotAnImage => write!(f, "not an image attachment"), - Self::UnsupportedResponseType { actual } => write!( - f, - "server returned unexpected content type (actual: {})", - actual.as_deref().unwrap_or("none"), - ), - Self::InvalidImageBody { magic_prefix_hex } => write!( - f, - "response body is not a valid image (first 8 bytes: {magic_prefix_hex})" - ), - Self::SizeExceeded { actual, limit } => { - write!(f, "file size {actual} exceeds limit {limit}") - } - Self::Network(e) => write!(f, "network error: {e}"), - Self::HttpStatus(s) => write!(f, "HTTP {s}"), - Self::ProcessingFailed(e) => write!(f, "image processing failed: {e}"), - } - } -} - -/// Strip MIME parameters and trim whitespace. `"image/png; charset=binary"` → `"image/png"`. -pub(crate) fn strip_mime_params(mime: &str) -> &str { - mime.split(';').next().unwrap_or(mime).trim() -} - -/// Format the first 8 bytes of a buffer as lowercase hex (no separator). -fn hex_prefix(body: &[u8]) -> String { - body.iter() - .take(8) - .map(|b| format!("{b:02x}")) - .collect::>() - .concat() -} - -/// Validate the HTTP response Content-Type and body magic bytes. -/// -/// If Content-Type is present and explicitly text-typed (e.g. `text/html` from -/// Slack's auth redirect when `files:read` scope is missing), rejects immediately. 
-/// Generic types such as `application/octet-stream` and absent headers pass through -/// to the magic-byte check, which is the authoritative gate for image validity. -/// -/// Content-Type is filtered with a block-list (`text/*`) rather than an allow-list -/// (`image/*`) because CDNs commonly serve any file type as `application/octet-stream`; -/// rejecting that header would silently break real downloads. The magic-byte check -/// examines the actual bytes regardless of what the server claims. -fn validate_image_response( - content_type: Option<&str>, - body: &[u8], -) -> Result<(), MediaFetchError> { - // Reject explicitly-text responses early (e.g. Slack HTML login page at HTTP 200). - // application/octet-stream and other generic types pass through to magic-byte check. - if let Some(ct) = content_type { - let base = strip_mime_params(ct).to_lowercase(); - if base.starts_with("text/") { - return Err(MediaFetchError::UnsupportedResponseType { actual: Some(base) }); - } - } - - let reader = match ImageReader::new(Cursor::new(body)).with_guessed_format() { - Ok(r) => r, - Err(e) => { - error!(error = %e, "image format detection I/O error"); - return Err(MediaFetchError::InvalidImageBody { - magic_prefix_hex: hex_prefix(body), - }); - } - }; - - match reader.format() { - Some(image::ImageFormat::Png | image::ImageFormat::Jpeg | image::ImageFormat::WebP) => { - Ok(()) - } - Some(image::ImageFormat::Gif) => { - validate_gif_body(body).map_err(|e| { - warn!(error = %e, "GIF validation failed"); - MediaFetchError::InvalidImageBody { - magic_prefix_hex: hex_prefix(body), - } - })?; - Ok(()) - } - _ => Err(MediaFetchError::InvalidImageBody { - magic_prefix_hex: hex_prefix(body), - }), - } -} - -/// Validate a GIF body by attempting to decode exactly one frame. 
-/// -/// Decoding only the first frame is intentional: the GIF header and colour tables -/// must be valid before the first frame can be decoded, so this catches truncated -/// or corrupt payloads without the CPU/memory cost of decoding a large animated GIF -/// in full. -/// -/// Creates its own `Cursor` over `raw`; the caller can independently re-read the -/// same slice for resizing. -fn validate_gif_body(raw: &[u8]) -> image::ImageResult<()> { - let decoder = GifDecoder::new(Cursor::new(raw))?; - let mut frames = decoder.into_frames(); - frames.next().ok_or_else(|| { - image::ImageError::Decoding(image::error::DecodingError::new( - image::error::ImageFormatHint::Exact(image::ImageFormat::Gif), - "GIF has no frames", - )) - })??; - Ok(()) -} - -/// Download an image from a URL, resize/compress it, and return as a ContentBlock. -/// -/// Returns `Err(MediaFetchError::NotAnImage)` when the URL or MIME hint don't -/// indicate an image — callers should skip silently. Returns -/// `Err(MediaFetchError::SizeExceeded)` when the declared `size` exceeds the limit -/// before any request is made, or when the downloaded body exceeds the limit. Returns -/// other `Err` variants (`Network`, `HttpStatus`, `UnsupportedResponseType`, -/// `InvalidImageBody`) after a request attempt — callers should surface these to the user. Returns -/// `Err(MediaFetchError::ProcessingFailed)` when the body is a valid image but -/// resize/compression fails — callers should warn the user and skip. -/// -/// Pass `auth_token` for platforms that require authentication (e.g. Slack private files). 
-pub async fn download_and_encode_image( - url: &str, - mime_hint: Option<&str>, - filename: &str, - size: u64, - auth_token: Option<&str>, -) -> Result { - const MAX_SIZE: u64 = 10 * 1024 * 1024; // 10 MB - - if url.is_empty() { - return Err(MediaFetchError::NotAnImage); - } - - let mime = mime_hint.or_else(|| { - filename - .rsplit('.') - .next() - .and_then(|ext| match ext.to_lowercase().as_str() { - "png" => Some("image/png"), - "jpg" | "jpeg" => Some("image/jpeg"), - "gif" => Some("image/gif"), - "webp" => Some("image/webp"), - _ => None, - }) - }); - - let Some(mime) = mime else { - debug!(filename, "skipping non-image attachment"); - return Err(MediaFetchError::NotAnImage); - }; - let mime = mime.split(';').next().unwrap_or(mime).trim(); - if !mime.starts_with("image/") { - debug!(filename, mime, "skipping non-image attachment"); - return Err(MediaFetchError::NotAnImage); - } - - if size > MAX_SIZE { - error!(filename, size, "image exceeds 10MB limit"); - return Err(MediaFetchError::SizeExceeded { - actual: size, - limit: MAX_SIZE, - }); - } - - let mut req = HTTP_CLIENT.get(url); - if let Some(token) = auth_token { - req = req.header("Authorization", format!("Bearer {token}")); - } - - let response = match req.send().await { - Ok(resp) => resp, - Err(e) => { - error!(url, error = %e, "download failed"); - return Err(MediaFetchError::Network(e)); - } - }; - if !response.status().is_success() { - error!(url, status = %response.status(), "HTTP error downloading image"); - return Err(MediaFetchError::HttpStatus(response.status())); - } - - // Capture Content-Type BEFORE .bytes() consumes the response. 
- let content_type = response - .headers() - .get(reqwest::header::CONTENT_TYPE) - .and_then(|v| v.to_str().ok()) - .map(str::to_string); - - let bytes = match response.bytes().await { - Ok(b) => b, - Err(e) => { - error!(url, error = %e, "read failed"); - return Err(MediaFetchError::Network(e)); - } - }; - - if bytes.len() as u64 > MAX_SIZE { - error!( - filename, - size = bytes.len(), - "downloaded image exceeds limit" - ); - return Err(MediaFetchError::SizeExceeded { - actual: bytes.len() as u64, - limit: MAX_SIZE, - }); - } - - // Guard against HTTP 200 responses that are error pages (e.g. Slack auth redirect - // when files:read scope is missing), and against corrupted or mislabeled bodies. - if let Err(e) = validate_image_response(content_type.as_deref(), &bytes) { - error!( - filename, - mime_hint = mime, - content_type = content_type.as_deref().unwrap_or("none"), - magic = hex_prefix(&bytes), - error = %e, - "image validation failed — body is not a supported image" - ); - return Err(e); - } - - let (output_bytes, output_mime) = match resize_and_compress(&bytes) { - Ok(result) => result, - Err(e) => { - error!( - filename, - error = %e, - size = bytes.len(), - "resize failed after successful validation" - ); - return Err(MediaFetchError::ProcessingFailed(e)); - } - }; - - debug!( - filename, - original_size = bytes.len(), - compressed_size = output_bytes.len(), - "image processed" - ); - - let encoded = BASE64.encode(&output_bytes); - Ok(ContentBlock::Image { - media_type: output_mime, - data: encoded, - }) -} - -/// Download an audio file and transcribe it via the configured STT provider. -/// Pass `auth_token` for platforms that require authentication. 
-pub async fn download_and_transcribe( - url: &str, - filename: &str, - mime_type: &str, - size: u64, - stt_config: &SttConfig, - auth_token: Option<&str>, -) -> Option { - const MAX_SIZE: u64 = 25 * 1024 * 1024; // 25 MB (Whisper API limit) - - if size > MAX_SIZE { - error!(filename, size, "audio exceeds 25MB limit"); - return None; - } - - let mut req = HTTP_CLIENT.get(url); - if let Some(token) = auth_token { - req = req.header("Authorization", format!("Bearer {token}")); - } - - let resp = match req.send().await { - Ok(r) => r, - Err(e) => { - error!(url, error = %e, "audio download request failed"); - return None; - } - }; - if !resp.status().is_success() { - error!(url, status = %resp.status(), "audio download failed"); - return None; - } - let bytes = match resp.bytes().await { - Ok(b) => b.to_vec(), - Err(e) => { - error!(url, error = %e, "audio body read failed"); - return None; - } - }; - - if bytes.len() as u64 > MAX_SIZE { - error!(filename, size = bytes.len(), "downloaded audio exceeds 25MB limit"); - return None; - } - - crate::stt::transcribe( - &HTTP_CLIENT, - stt_config, - bytes, - filename.to_string(), - mime_type, - ) - .await -} - -/// Resize image so longest side <= IMAGE_MAX_DIMENSION_PX, then encode as JPEG. -/// GIFs are passed through unchanged to preserve animation. 
-pub fn resize_and_compress(raw: &[u8]) -> Result<(Vec, String), image::ImageError> { - let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?; - - let format = reader.format(); - - if format == Some(image::ImageFormat::Gif) { - return Ok((raw.to_vec(), "image/gif".to_string())); - } - - let img = reader.decode()?; - let (w, h) = (img.width(), img.height()); - - let img = if w > IMAGE_MAX_DIMENSION_PX || h > IMAGE_MAX_DIMENSION_PX { - let max_side = std::cmp::max(w, h); - let ratio = f64::from(IMAGE_MAX_DIMENSION_PX) / f64::from(max_side); - let new_w = (f64::from(w) * ratio) as u32; - let new_h = (f64::from(h) * ratio) as u32; - img.resize(new_w, new_h, image::imageops::FilterType::Lanczos3) - } else { - img - }; - - let mut buf = Cursor::new(Vec::new()); - let encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, IMAGE_JPEG_QUALITY); - img.write_with_encoder(encoder)?; - - Ok((buf.into_inner(), "image/jpeg".to_string())) -} - -/// Check if a MIME type is audio. -pub fn is_audio_mime(mime: &str) -> bool { - mime.starts_with("audio/") -} - -/// Check if an attachment is a video file. -pub fn is_video_file(filename: &str, content_type: Option<&str>) -> bool { - let mime = content_type.unwrap_or(""); - let mime_base = mime.split(';').next().unwrap_or(mime).trim(); - if mime_base.starts_with("video/") { - return true; - } - - filename - .rsplit('.') - .next() - .map(|ext| { - matches!( - ext.to_lowercase().as_str(), - "mp4" | "mov" | "m4v" | "webm" | "mkv" | "avi" - ) - }) - .unwrap_or(false) -} - -/// Extensions recognised as text-based files that can be inlined into the prompt. 
-const TEXT_EXTENSIONS: &[&str] = &[ - "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", "rs", "py", "js", - "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", "rb", "sh", "bash", "zsh", "fish", - "ps1", "bat", "sql", "html", "css", "scss", "less", "ini", "cfg", "conf", "env", -]; - -/// Exact filenames (no extension) recognised as text files. -const TEXT_FILENAMES: &[&str] = &[ - "dockerfile", - "makefile", - "justfile", - "rakefile", - "gemfile", - "procfile", - "vagrantfile", - ".gitignore", - ".dockerignore", - ".editorconfig", -]; - -/// MIME types recognised as text-based (beyond `text/*`). -const TEXT_MIME_TYPES: &[&str] = &[ - "application/json", - "application/xml", - "application/javascript", - "application/x-yaml", - "application/x-sh", - "application/toml", - "application/x-toml", -]; - -/// Check if a file is text-based and can be inlined into the prompt. -pub fn is_text_file(filename: &str, content_type: Option<&str>) -> bool { - let mime = content_type.unwrap_or(""); - let mime_base = mime.split(';').next().unwrap_or(mime).trim(); - if mime_base.starts_with("text/") || TEXT_MIME_TYPES.contains(&mime_base) { - return true; - } - // Check extension - if filename.contains('.') { - if let Some(ext) = filename.rsplit('.').next() { - if TEXT_EXTENSIONS.contains(&ext.to_lowercase().as_str()) { - return true; - } - } - } - // Check exact filename (Dockerfile, Makefile, etc.) - TEXT_FILENAMES.contains(&filename.to_lowercase().as_str()) -} - -/// Download a text-based file and return it as a ContentBlock::Text. -/// Files larger than 512 KB are skipped to avoid bloating the prompt. -/// -/// Pass `auth_token` for platforms that require authentication (e.g. Slack private files). -/// -/// Note: the caller already guards total size via a total cap; the per-file -/// MAX_SIZE check here is intentional defense-in-depth so this function remains -/// self-contained and safe when called from other contexts. 
-pub async fn download_and_read_text_file( - url: &str, - filename: &str, - size: u64, - auth_token: Option<&str>, -) -> Option<(ContentBlock, u64)> { - const MAX_SIZE: u64 = 512 * 1024; // 512 KB - - if size > MAX_SIZE { - tracing::warn!(filename, size, "text file exceeds 512KB limit, skipping"); - return None; - } - - let mut req = HTTP_CLIENT.get(url); - if let Some(token) = auth_token { - req = req.header("Authorization", format!("Bearer {token}")); - } - - let resp = match req.send().await { - Ok(r) => r, - Err(e) => { - tracing::warn!(url, error = %e, "text file download failed"); - return None; - } - }; - if !resp.status().is_success() { - tracing::warn!(url, status = %resp.status(), "text file download failed"); - return None; - } - let bytes = match resp.bytes().await { - Ok(b) => b, - Err(e) => { - tracing::warn!(url, error = %e, "text file body read failed"); - return None; - } - }; - let actual_size = bytes.len() as u64; - - // Defense-in-depth: verify actual download size - if actual_size > MAX_SIZE { - tracing::warn!( - filename, - size = actual_size, - "downloaded text file exceeds 512KB limit, skipping" - ); - return None; - } - - // from_utf8_lossy returns Cow::Borrowed for valid UTF-8 (zero-copy) - let text = String::from_utf8_lossy(&bytes).into_owned(); - - // Dynamic fence: keep adding backticks until the fence doesn't appear in content - let mut fence = "```".to_string(); - while text.contains(fence.as_str()) { - fence.push('`'); - } - - debug!(filename, bytes = text.len(), "text file inlined"); - Some(( - ContentBlock::Text { - text: format!("[File: {filename}]\n{fence}\n{text}\n{fence}"), - }, - actual_size, - )) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn make_png(width: u32, height: u32) -> Vec { - let img = image::RgbImage::new(width, height); - let mut buf = Cursor::new(Vec::new()); - img.write_to(&mut buf, image::ImageFormat::Png).unwrap(); - buf.into_inner() - } - - fn make_jpeg(width: u32, height: u32) -> Vec { - let img = 
image::RgbImage::new(width, height); - let mut buf = Cursor::new(Vec::new()); - img.write_to(&mut buf, image::ImageFormat::Jpeg).unwrap(); - buf.into_inner() - } - - fn make_gif() -> Vec { - vec![ - 0x47, 0x49, 0x46, 0x38, 0x39, 0x61, 0x01, 0x00, 0x01, 0x00, 0x80, 0x00, 0x00, 0x00, - 0x00, 0x00, 0xff, 0xff, 0xff, 0x2C, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, - 0x00, 0x02, 0x02, 0x44, 0x01, 0x00, 0x3B, - ] - } - - #[test] - fn large_image_resized_to_max_dimension() { - let png = make_png(3000, 2000); - let (compressed, mime) = resize_and_compress(&png).unwrap(); - - assert_eq!(mime, "image/jpeg"); - let result = image::load_from_memory(&compressed).unwrap(); - assert!(result.width() <= IMAGE_MAX_DIMENSION_PX); - assert!(result.height() <= IMAGE_MAX_DIMENSION_PX); - } - - #[test] - fn small_image_keeps_original_dimensions() { - let png = make_png(800, 600); - let (compressed, mime) = resize_and_compress(&png).unwrap(); - - assert_eq!(mime, "image/jpeg"); - let result = image::load_from_memory(&compressed).unwrap(); - assert_eq!(result.width(), 800); - assert_eq!(result.height(), 600); - } - - #[test] - fn landscape_image_respects_aspect_ratio() { - let png = make_png(4000, 2000); - let (compressed, _) = resize_and_compress(&png).unwrap(); - - let result = image::load_from_memory(&compressed).unwrap(); - assert_eq!(result.width(), 1200); - assert_eq!(result.height(), 600); - } - - #[test] - fn portrait_image_respects_aspect_ratio() { - let png = make_png(2000, 4000); - let (compressed, _) = resize_and_compress(&png).unwrap(); - - let result = image::load_from_memory(&compressed).unwrap(); - assert_eq!(result.width(), 600); - assert_eq!(result.height(), 1200); - } - - #[test] - fn compressed_output_is_smaller_than_original() { - let png = make_png(3000, 2000); - let (compressed, _) = resize_and_compress(&png).unwrap(); - - assert!( - compressed.len() < png.len(), - "compressed {} should be < original {}", - compressed.len(), - png.len() - ); - } - - #[test] - 
fn gif_passes_through_unchanged() { - let gif = make_gif(); - let (output, mime) = resize_and_compress(&gif).unwrap(); - - assert_eq!(mime, "image/gif"); - assert_eq!(output, gif); - } - - #[test] - fn invalid_data_returns_error() { - let garbage = vec![0x00, 0x01, 0x02, 0x03]; - assert!(resize_and_compress(&garbage).is_err()); - } - - #[test] - fn video_file_detects_mime_and_common_extensions() { - assert!(is_video_file("clip.bin", Some("video/mp4"))); - assert!(is_video_file("clip.mp4", None)); - assert!(is_video_file("clip.MOV", None)); - assert!(!is_video_file("notes.txt", Some("text/plain"))); - } - - // --- validate_image_response tests --- - - #[test] - fn validate_accepts_png_with_matching_content_type() { - let png = make_png(1, 1); - assert!(validate_image_response(Some("image/png"), &png).is_ok()); - } - - #[test] - fn validate_accepts_jpeg_with_matching_content_type() { - let jpeg = make_jpeg(1, 1); - assert!(validate_image_response(Some("image/jpeg"), &jpeg).is_ok()); - } - - #[test] - fn validate_accepts_gif_with_matching_content_type() { - let gif = make_gif(); - assert!(validate_image_response(Some("image/gif"), &gif).is_ok()); - } - - #[test] - fn validate_rejects_corrupt_gif_body() { - let corrupt_gif = b"GIF89a\x01\x00\x01\x00\x00\x00\x00"; - let result = validate_image_response(Some("image/gif"), corrupt_gif); - assert!(matches!( - result, - Err(MediaFetchError::InvalidImageBody { .. }) - )); - } - - #[test] - fn validate_accepts_missing_content_type_with_valid_png() { - // When Content-Type header is absent, fall back to magic-byte detection. - let png = make_png(1, 1); - assert!(validate_image_response(None, &png).is_ok()); - } - - #[test] - fn validate_content_type_strips_params() { - // "image/png; charset=binary" is a real header value — must be accepted. 
- let png = make_png(1, 1); - assert!(validate_image_response(Some("image/png; charset=binary"), &png).is_ok()); - } - - /// Exact reproduction of issue #776: Slack serves the workspace login HTML - /// page at HTTP 200 when the bot token lacks the `files:read` scope. - /// The Slack file metadata says `mimetype: image/png`; the response body - /// magic bytes are `Slack login"; - let result = validate_image_response(Some("image/png"), html_body); - match result { - Err(MediaFetchError::InvalidImageBody { magic_prefix_hex }) => { - assert_eq!(magic_prefix_hex, "3c21444f43545950"); - } - other => panic!("expected InvalidImageBody, got {other:?}"), - } - } - - #[test] - fn validate_rejects_text_html_content_type() { - // Even if the body were a valid image, a text/html Content-Type must be rejected. - let png = make_png(1, 1); - let result = validate_image_response(Some("text/html; charset=utf-8"), &png); - assert!(matches!( - result, - Err(MediaFetchError::UnsupportedResponseType { .. }) - )); - } - - #[test] - fn validate_rejects_mixed_case_text_content_type() { - // Mixed-case Content-Type must be normalised before rejection. - let png = make_png(1, 1); - let result = validate_image_response(Some("Text/HTML; Charset=utf-8"), &png); - assert!(matches!( - result, - Err(MediaFetchError::UnsupportedResponseType { .. }) - )); - } - - /// Regression test for the application/octet-stream fix: CDNs and generic - /// file download endpoints commonly serve any file with this Content-Type. - /// The old allow-list incorrectly rejected it before magic-byte check. - #[test] - fn validate_accepts_octet_stream_with_valid_png() { - let png = make_png(1, 1); - assert!( - validate_image_response(Some("application/octet-stream"), &png).is_ok(), - "application/octet-stream must pass through to magic-byte check" - ); - } - - /// application/json body is rejected by magic bytes, not by Content-Type. 
- #[test] - fn validate_rejects_json_body_by_magic_bytes() { - let json_body = b"{\"error\":\"invalid_auth\",\"ok\":false}"; - let result = validate_image_response(Some("application/json"), json_body); - assert!(matches!( - result, - Err(MediaFetchError::InvalidImageBody { .. }) - )); - } - - /// Missing Content-Type with invalid body: CDN stripping the header should - /// still be caught by magic-byte detection. - #[test] - fn validate_rejects_html_body_with_missing_content_type() { - let html_body = b"error page"; - let result = validate_image_response(None, html_body); - assert!(matches!( - result, - Err(MediaFetchError::InvalidImageBody { .. }) - )); - } - - #[test] - fn validate_rejects_empty_body() { - let result = validate_image_response(Some("image/png"), &[]); - assert!(matches!( - result, - Err(MediaFetchError::InvalidImageBody { .. }) - )); - } - - #[test] - fn validate_rejects_truncated_png_header() { - // PNG magic is 8 bytes; 4 bytes is not enough to identify the format. - let truncated = [0x89u8, 0x50, 0x4e, 0x47]; - let result = validate_image_response(Some("image/png"), &truncated); - assert!(matches!( - result, - Err(MediaFetchError::InvalidImageBody { .. }) - )); - } - - #[test] - fn truncated_png_body_must_not_produce_content_block() { - // Valid PNG magic bytes (8 bytes) + partial IHDR -- body is too short to decode. - // Previously: the <=1MB fallback in download_and_encode_image forwarded raw bytes - // after resize_and_compress failed, reproducing the #776 poisoning class. - // After removing the fallback, resize_and_compress failure must propagate as Err. 
- let truncated: &[u8] = &[ - 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, // PNG magic - 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, // partial IHDR - ]; - assert!( - validate_image_response(Some("image/png"), truncated).is_ok(), - "magic-byte check still passes for truncated body" - ); - assert!( - resize_and_compress(truncated).is_err(), - "truncated PNG must fail at decode -- no raw-byte fallback allowed" - ); - } - - #[test] - fn media_fetch_error_display_renders() { - let _ = MediaFetchError::NotAnImage.to_string(); - let _ = MediaFetchError::UnsupportedResponseType { - actual: Some("text/html".into()), - } - .to_string(); - let s = MediaFetchError::UnsupportedResponseType { actual: None }.to_string(); - assert!(s.contains("none"), "None branch should render as 'none'"); - let _ = MediaFetchError::InvalidImageBody { - magic_prefix_hex: "3c21444f43545950".into(), - } - .to_string(); - let _ = MediaFetchError::SizeExceeded { - actual: 11_000_000, - limit: 10_000_000, - } - .to_string(); - let _ = MediaFetchError::HttpStatus(reqwest::StatusCode::UNAUTHORIZED).to_string(); - let _ = MediaFetchError::ProcessingFailed(image::ImageError::Unsupported( - image::error::UnsupportedError::from_format_and_kind( - image::error::ImageFormatHint::Unknown, - image::error::UnsupportedErrorKind::Color(image::ExtendedColorType::Rgba16), - ), - )) - .to_string(); - } - - #[test] - fn validate_accepts_webp_by_magic_bytes() { - let img = image::RgbImage::new(1, 1); - let mut buf = std::io::Cursor::new(Vec::new()); - img.write_to(&mut buf, image::ImageFormat::WebP).unwrap(); - let webp_body = buf.into_inner(); - assert!(validate_image_response(Some("image/webp"), &webp_body).is_ok()); - } - - #[test] - fn hex_prefix_formats_first_8_bytes() { - let bytes = b""; - assert_eq!(hex_prefix(bytes), "3c21444f43545950"); - } - - #[test] - fn hex_prefix_handles_short_buffer() { - let bytes = [0xffu8, 0xd8]; - assert_eq!(hex_prefix(&bytes), "ffd8"); - } -} diff --git 
a/src/multibot_cache.rs b/src/multibot_cache.rs deleted file mode 100644 index a9bd8c82e..000000000 --- a/src/multibot_cache.rs +++ /dev/null @@ -1,85 +0,0 @@ -//! Persistent disk cache for multibot thread detection. -//! -//! Once a thread is identified as multi-bot (irreversible), it is stored in -//! `~/.openab/cache/threads.json` so the detection survives restarts and -//! in-memory TTL expiry. - -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::path::PathBuf; -use std::sync::{Arc, Mutex}; -use tracing::{info, warn}; - -#[derive(Serialize, Deserialize, Clone)] -struct Entry { - detected_at: DateTime, -} - -/// Shared multibot thread cache with file persistence. -#[derive(Clone)] -pub struct MultibotCache { - threads: Arc>>, - path: PathBuf, -} - -impl MultibotCache { - /// Load or create the cache from `~/.openab/cache/threads.json`. - pub fn load(path: PathBuf) -> Self { - let threads = match std::fs::read_to_string(&path) { - Ok(data) => serde_json::from_str(&data).unwrap_or_else(|e| { - warn!(error = %e, "failed to parse threads.json, starting empty"); - HashMap::new() - }), - Err(_) => HashMap::new(), - }; - info!(count = threads.len(), path = %path.display(), "loaded multibot cache"); - Self { - threads: Arc::new(Mutex::new(threads)), - path, - } - } - - /// Check if a thread is known to be multi-bot. - pub fn is_multibot(&self, thread_id: &str) -> bool { - self.threads.lock().unwrap().contains_key(thread_id) - } - - /// Mark a thread as multi-bot and persist to disk (non-blocking). 
- pub async fn mark_multibot(&self, thread_id: &str) { - let snapshot = { - let mut threads = self.threads.lock().unwrap(); - if threads.contains_key(thread_id) { - return; - } - threads.insert( - thread_id.to_string(), - Entry { - detected_at: Utc::now(), - }, - ); - threads.clone() - }; - let path = self.path.clone(); - tokio::task::spawn_blocking(move || persist(&path, &snapshot)).await.ok(); - } -} - -fn persist(path: &PathBuf, threads: &HashMap) { - if let Some(parent) = path.parent() { - if let Err(e) = std::fs::create_dir_all(parent) { - warn!(error = %e, "failed to create cache directory"); - return; - } - } - match serde_json::to_string_pretty(threads) { - Ok(data) => { - if let Err(e) = std::fs::write(path, data) { - warn!(error = %e, "failed to persist threads.json"); - } - } - Err(e) => { - warn!(error = %e, "failed to serialize multibot cache"); - } - } -} diff --git a/src/reactions.rs b/src/reactions.rs deleted file mode 100644 index 6e68f90b6..000000000 --- a/src/reactions.rs +++ /dev/null @@ -1,276 +0,0 @@ -use crate::adapter::{ChatAdapter, MessageRef}; -use crate::config::{ReactionEmojis, ReactionTiming}; -use std::sync::Arc; -use tokio::sync::Mutex; -use tokio::time::Duration; - -const CODING_TOKENS: &[&str] = &["exec", "process", "read", "write", "edit", "bash", "shell"]; -const WEB_TOKENS: &[&str] = &[ - "web_search", - "web_fetch", - "web-search", - "web-fetch", - "browser", -]; - -fn classify_tool<'a>(name: &str, emojis: &'a ReactionEmojis) -> &'a str { - let n = name.to_lowercase(); - if WEB_TOKENS.iter().any(|t| n.contains(t)) { - &emojis.web - } else if CODING_TOKENS.iter().any(|t| n.contains(t)) { - &emojis.coding - } else { - &emojis.tool - } -} - -struct Inner { - adapter: Arc, - message: MessageRef, - emojis: ReactionEmojis, - timing: ReactionTiming, - current: String, - finished: bool, - debounce_handle: Option>, - stall_soft_handle: Option>, - stall_hard_handle: Option>, -} - -pub struct StatusReactionController { - inner: Arc>, - 
enabled: bool, -} - -impl StatusReactionController { - pub fn new( - enabled: bool, - adapter: Arc, - message: MessageRef, - emojis: ReactionEmojis, - timing: ReactionTiming, - ) -> Self { - Self { - inner: Arc::new(Mutex::new(Inner { - adapter, - message, - emojis, - timing, - current: String::new(), - finished: false, - debounce_handle: None, - stall_soft_handle: None, - stall_hard_handle: None, - })), - enabled, - } - } - - pub async fn set_queued(&self) { - if !self.enabled { - return; - } - let emoji = { self.inner.lock().await.emojis.queued.clone() }; - self.apply_immediate(&emoji).await; - } - - pub async fn set_thinking(&self) { - if !self.enabled { - return; - } - let emoji = { self.inner.lock().await.emojis.thinking.clone() }; - self.schedule_debounced(&emoji).await; - } - - pub async fn set_tool(&self, tool_name: &str) { - if !self.enabled { - return; - } - let emoji = { - let inner = self.inner.lock().await; - classify_tool(tool_name, &inner.emojis).to_string() - }; - self.schedule_debounced(&emoji).await; - } - - pub async fn set_done(&self) { - if !self.enabled { - return; - } - let emoji = { self.inner.lock().await.emojis.done.clone() }; - self.finish(&emoji).await; - // Add a random mood face - let faces = ["😊", "😎", "🫡", "🤓", "😏", "✌️", "💪", "🦾"]; - let face = faces[rand::random::() % faces.len()]; - let inner = self.inner.lock().await; - let _ = inner.adapter.add_reaction(&inner.message, face).await; - } - - pub async fn set_error(&self) { - if !self.enabled { - return; - } - let emoji = { self.inner.lock().await.emojis.error.clone() }; - self.finish(&emoji).await; - } - - pub async fn clear(&self) { - if !self.enabled { - return; - } - let mut inner = self.inner.lock().await; - cancel_timers(&mut inner); - let current = inner.current.clone(); - if !current.is_empty() { - let _ = inner - .adapter - .remove_reaction(&inner.message, ¤t) - .await; - inner.current.clear(); - } - } - - async fn apply_immediate(&self, emoji: &str) { - let mut inner = 
self.inner.lock().await; - if inner.finished || emoji == inner.current { - return; - } - cancel_debounce(&mut inner); - let old = inner.current.clone(); - inner.current = emoji.to_string(); - let adapter = inner.adapter.clone(); - let msg = inner.message.clone(); - let new = emoji.to_string(); - drop(inner); - - let _ = adapter.add_reaction(&msg, &new).await; - if !old.is_empty() && old != new { - let _ = adapter.remove_reaction(&msg, &old).await; - } - self.reset_stall_timers().await; - } - - async fn schedule_debounced(&self, emoji: &str) { - let mut inner = self.inner.lock().await; - if inner.finished || emoji == inner.current { - self.reset_stall_timers_inner(&mut inner); - return; - } - cancel_debounce(&mut inner); - - let emoji = emoji.to_string(); - let ctrl = self.inner.clone(); - let debounce_ms = inner.timing.debounce_ms; - inner.debounce_handle = Some(tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(debounce_ms)).await; - let mut inner = ctrl.lock().await; - if inner.finished { - return; - } - let old = inner.current.clone(); - inner.current = emoji.clone(); - let adapter = inner.adapter.clone(); - let msg = inner.message.clone(); - drop(inner); - - let _ = adapter.add_reaction(&msg, &emoji).await; - if !old.is_empty() && old != emoji { - let _ = adapter.remove_reaction(&msg, &old).await; - } - })); - self.reset_stall_timers_inner(&mut inner); - } - - async fn finish(&self, emoji: &str) { - let mut inner = self.inner.lock().await; - if inner.finished { - return; - } - inner.finished = true; - cancel_timers(&mut inner); - - let old = inner.current.clone(); - inner.current = emoji.to_string(); - let adapter = inner.adapter.clone(); - let msg = inner.message.clone(); - let new = emoji.to_string(); - drop(inner); - - let _ = adapter.add_reaction(&msg, &new).await; - if !old.is_empty() && old != new { - let _ = adapter.remove_reaction(&msg, &old).await; - } - } - - async fn reset_stall_timers(&self) { - let mut inner = 
self.inner.lock().await; - self.reset_stall_timers_inner(&mut inner); - } - - fn reset_stall_timers_inner(&self, inner: &mut Inner) { - if let Some(h) = inner.stall_soft_handle.take() { - h.abort(); - } - if let Some(h) = inner.stall_hard_handle.take() { - h.abort(); - } - - let soft_ms = inner.timing.stall_soft_ms; - let hard_ms = inner.timing.stall_hard_ms; - let ctrl = self.inner.clone(); - - inner.stall_soft_handle = Some(tokio::spawn({ - let ctrl = ctrl.clone(); - async move { - tokio::time::sleep(Duration::from_millis(soft_ms)).await; - let mut inner = ctrl.lock().await; - if inner.finished { - return; - } - let old = inner.current.clone(); - inner.current = "🥱".to_string(); - let adapter = inner.adapter.clone(); - let msg = inner.message.clone(); - drop(inner); - let _ = adapter.add_reaction(&msg, "🥱").await; - if !old.is_empty() && old != "🥱" { - let _ = adapter.remove_reaction(&msg, &old).await; - } - } - })); - - inner.stall_hard_handle = Some(tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(hard_ms)).await; - let mut inner = ctrl.lock().await; - if inner.finished { - return; - } - let old = inner.current.clone(); - inner.current = "😨".to_string(); - let adapter = inner.adapter.clone(); - let msg = inner.message.clone(); - drop(inner); - let _ = adapter.add_reaction(&msg, "😨").await; - if !old.is_empty() && old != "😨" { - let _ = adapter.remove_reaction(&msg, &old).await; - } - })); - } -} - -fn cancel_debounce(inner: &mut Inner) { - if let Some(h) = inner.debounce_handle.take() { - h.abort(); - } -} - -fn cancel_timers(inner: &mut Inner) { - if let Some(h) = inner.debounce_handle.take() { - h.abort(); - } - if let Some(h) = inner.stall_soft_handle.take() { - h.abort(); - } - if let Some(h) = inner.stall_hard_handle.take() { - h.abort(); - } -} diff --git a/src/remind.rs b/src/remind.rs deleted file mode 100644 index 9472c53d8..000000000 --- a/src/remind.rs +++ /dev/null @@ -1,399 +0,0 @@ -//! 
One-shot `/remind` slash command — schedules a delayed mention in a Discord channel. -//! -//! Persistence: reminders are stored in `reminders.json` and reloaded on startup. - -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use serenity::http::Http; -use serenity::model::id::ChannelId; -use std::path::PathBuf; -use std::sync::Arc; -use tokio::sync::Mutex; -use tracing::{error, info, warn}; - -/// A single pending reminder. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Reminder { - pub id: String, - pub channel_id: u64, - pub sender_id: u64, - /// Raw mention strings (e.g. "<@123>", "<@&456>") - pub targets: Vec, - pub message: String, - pub fire_at: DateTime, - pub created_at: DateTime, -} - -/// Shared reminder store with file persistence. -#[derive(Clone)] -pub struct ReminderStore { - reminders: Arc>>, - path: PathBuf, -} - -impl ReminderStore { - /// Load or create the reminder store from the given path. - pub fn load(path: PathBuf) -> Self { - let reminders = match std::fs::read_to_string(&path) { - Ok(data) => serde_json::from_str(&data).unwrap_or_else(|e| { - warn!(error = %e, "failed to parse reminders.json, starting empty"); - Vec::new() - }), - Err(_) => Vec::new(), - }; - info!(count = reminders.len(), path = %path.display(), "loaded reminders"); - Self { - reminders: Arc::new(Mutex::new(reminders)), - path, - } - } - - /// Add a reminder and persist to disk. - pub async fn add(&self, reminder: Reminder) { - let snapshot = { - let mut reminders = self.reminders.lock().await; - reminders.push(reminder); - reminders.clone() - }; - self.persist(&snapshot); - } - - /// Remove a reminder by ID and persist. - pub async fn remove(&self, id: &str) { - let snapshot = { - let mut reminders = self.reminders.lock().await; - reminders.retain(|r| r.id != id); - reminders.clone() - }; - self.persist(&snapshot); - } - - /// Get all pending reminders (for startup re-scheduling). 
- pub async fn pending(&self) -> Vec { - self.reminders.lock().await.clone() - } - - fn persist(&self, reminders: &[Reminder]) { - match serde_json::to_string_pretty(reminders) { - Ok(data) => { - if let Some(parent) = self.path.parent() { - if let Err(e) = std::fs::create_dir_all(parent) { - error!(error = %e, "failed to create reminders directory"); - return; - } - } - if let Err(e) = std::fs::write(&self.path, data) { - error!(error = %e, "failed to persist reminders.json"); - } - } - Err(e) => { - error!(error = %e, "failed to serialize reminders, skipping persist"); - } - } - } -} - -/// Maximum allowed message length for reminders. -pub const MAX_MESSAGE_LEN: usize = 1800; - -/// Maximum number of mention targets per reminder. -pub const MAX_TARGETS: usize = 10; - -/// Sanitize reminder message: neutralize @everyone/@here. -pub fn sanitize_message(msg: &str) -> String { - msg.replace("@everyone", "@\u{200b}everyone") - .replace("@here", "@\u{200b}here") -} - -/// Validate reminder message length. -pub fn validate_message(msg: &str) -> Result<(), String> { - if msg.len() > MAX_MESSAGE_LEN { - Err(format!("message too long (max {MAX_MESSAGE_LEN} characters)")) - } else { - Ok(()) - } -} - -/// Parse a human delay string like "30m", "2h", "7d" into seconds. -/// Supports combinations: "1h30m", "2d12h". -/// Range: 1m (60s) to 30d (2_592_000s). 
-pub fn parse_delay(input: &str) -> Result { - let s = input.trim().to_lowercase(); - if s.is_empty() { - return Err("empty delay".into()); - } - - let mut total_secs: u64 = 0; - let mut num_buf = String::new(); - - for ch in s.chars() { - if ch.is_ascii_digit() { - num_buf.push(ch); - } else { - let n: u64 = num_buf.parse().map_err(|_| format!("invalid number in delay: {input}"))?; - num_buf.clear(); - let multiplier = match ch { - 'm' => 60, - 'h' => 3600, - 'd' => 86400, - _ => return Err(format!("unknown unit '{ch}' in delay (use m/h/d)")), - }; - total_secs += n * multiplier; - } - } - - // Handle bare number (default to minutes) - if !num_buf.is_empty() { - let n: u64 = num_buf.parse().map_err(|_| format!("invalid number in delay: {input}"))?; - total_secs += n * 60; // default unit = minutes - } - - if total_secs < 60 { - return Err("minimum delay is 1m".into()); - } - if total_secs > 2_592_000 { - return Err("maximum delay is 30d".into()); - } - - Ok(total_secs) -} - -/// Format seconds into a human-readable string like "2h 30m". -pub fn format_delay(secs: u64) -> String { - let d = secs / 86400; - let h = (secs % 86400) / 3600; - let m = (secs % 3600) / 60; - let mut parts = Vec::new(); - if d > 0 { parts.push(format!("{d}d")); } - if h > 0 { parts.push(format!("{h}h")); } - if m > 0 { parts.push(format!("{m}m")); } - if parts.is_empty() { "< 1m".into() } else { parts.join(" ") } -} - -/// Spawn a tokio task that fires the reminder after the delay. 
-pub fn schedule_reminder( - http: Arc, - store: ReminderStore, - reminder: Reminder, -) { - let now = Utc::now(); - let delay = if reminder.fire_at > now { - (reminder.fire_at - now).to_std().unwrap_or_default() - } else { - std::time::Duration::ZERO - }; - - let id = reminder.id.clone(); - tokio::spawn(async move { - tokio::time::sleep(delay).await; - - let targets_str = reminder.targets.join(" "); - let content = format!( - "⏰ **Reminder** from <@{}>:\n\"{}\"\ncc {}", - reminder.sender_id, reminder.message, targets_str - ); - - let channel = ChannelId::new(reminder.channel_id); - match channel.say(&http, &content).await { - Ok(_) => { - info!(id = %id, channel = reminder.channel_id, "reminder fired"); - store.remove(&id).await; - } - Err(e) => { - error!(error = %e, id = %id, "failed to send reminder — keeping for retry on next restart"); - } - } - }); -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_delay_minutes() { - assert_eq!(parse_delay("5m").unwrap(), 300); - assert_eq!(parse_delay("1m").unwrap(), 60); - } - - #[test] - fn parse_delay_hours() { - assert_eq!(parse_delay("2h").unwrap(), 7200); - } - - #[test] - fn parse_delay_days() { - assert_eq!(parse_delay("1d").unwrap(), 86400); - assert_eq!(parse_delay("30d").unwrap(), 2_592_000); - } - - #[test] - fn parse_delay_combined() { - assert_eq!(parse_delay("1h30m").unwrap(), 5400); - assert_eq!(parse_delay("1d12h").unwrap(), 129_600); - } - - #[test] - fn parse_delay_bare_number_defaults_to_minutes() { - assert_eq!(parse_delay("10").unwrap(), 600); - } - - #[test] - fn parse_delay_too_short() { - assert!(parse_delay("0m").is_err()); - assert!(parse_delay("0h").is_err()); - } - - #[test] - fn parse_delay_too_long() { - assert!(parse_delay("31d").is_err()); - } - - #[test] - fn format_delay_basic() { - assert_eq!(format_delay(3600), "1h"); - assert_eq!(format_delay(5400), "1h 30m"); - assert_eq!(format_delay(90000), "1d 1h"); - } - - #[test] - fn parse_delay_empty() { - 
assert!(parse_delay("").is_err()); - assert!(parse_delay(" ").is_err()); - } - - #[test] - fn parse_delay_invalid_unit() { - assert!(parse_delay("2x").is_err()); - assert!(parse_delay("abc").is_err()); - assert!(parse_delay("5s").is_err()); - } - - #[test] - fn parse_delay_case_insensitive() { - assert_eq!(parse_delay("2H").unwrap(), 7200); - assert_eq!(parse_delay("1D30M").unwrap(), 88200); - } - - #[test] - fn parse_delay_whitespace_trimmed() { - assert_eq!(parse_delay(" 5m ").unwrap(), 300); - } - - #[test] - fn parse_delay_bare_number_boundary() { - assert_eq!(parse_delay("1").unwrap(), 60); // 1 min - assert_eq!(parse_delay("30").unwrap(), 1800); // 30 min - } - - #[test] - fn parse_delay_exact_boundaries() { - // Exactly 1m (minimum) - assert_eq!(parse_delay("1m").unwrap(), 60); - // Exactly 30d (maximum) - assert_eq!(parse_delay("30d").unwrap(), 2_592_000); - // Just over 30d - assert!(parse_delay("30d1m").is_err()); - } - - #[test] - fn format_delay_zero() { - assert_eq!(format_delay(0), "< 1m"); - } - - #[test] - fn format_delay_pure_units() { - assert_eq!(format_delay(86400), "1d"); - assert_eq!(format_delay(120), "2m"); - assert_eq!(format_delay(7200), "2h"); - } - - #[tokio::test] - async fn reminder_store_add_remove() { - let dir = std::env::temp_dir().join(format!("remind_test_{}", std::process::id())); - std::fs::create_dir_all(&dir).unwrap(); - let path = dir.join("reminders.json"); - - let store = ReminderStore::load(path.clone()); - assert_eq!(store.pending().await.len(), 0); - - let r = Reminder { - id: "test-1".into(), - channel_id: 123, - sender_id: 456, - targets: vec!["<@789>".into()], - message: "hello".into(), - fire_at: Utc::now() + chrono::Duration::hours(1), - created_at: Utc::now(), - }; - - store.add(r).await; - assert_eq!(store.pending().await.len(), 1); - - store.remove("test-1").await; - assert_eq!(store.pending().await.len(), 0); - - // Verify persistence - let store2 = ReminderStore::load(path.clone()); - 
assert_eq!(store2.pending().await.len(), 0); - - std::fs::remove_dir_all(&dir).ok(); - } - - #[tokio::test] - async fn reminder_store_persists_across_reload() { - let dir = std::env::temp_dir().join(format!("remind_test2_{}", std::process::id())); - std::fs::create_dir_all(&dir).unwrap(); - let path = dir.join("reminders.json"); - - let store = ReminderStore::load(path.clone()); - let r = Reminder { - id: "persist-1".into(), - channel_id: 100, - sender_id: 200, - targets: vec!["<@300>".into()], - message: "persist test".into(), - fire_at: Utc::now() + chrono::Duration::hours(2), - created_at: Utc::now(), - }; - store.add(r).await; - - // Reload from disk - let store2 = ReminderStore::load(path.clone()); - let pending = store2.pending().await; - assert_eq!(pending.len(), 1); - assert_eq!(pending[0].id, "persist-1"); - assert_eq!(pending[0].message, "persist test"); - - std::fs::remove_dir_all(&dir).ok(); - } - - #[test] - fn sanitize_message_strips_everyone_here() { - assert_eq!(sanitize_message("hello @everyone"), "hello @\u{200b}everyone"); - assert_eq!(sanitize_message("hey @here check"), "hey @\u{200b}here check"); - assert_eq!(sanitize_message("@everyone @here"), "@\u{200b}everyone @\u{200b}here"); - } - - #[test] - fn sanitize_message_no_change() { - assert_eq!(sanitize_message("normal message"), "normal message"); - assert_eq!(sanitize_message("<@123> hello"), "<@123> hello"); - } - - #[test] - fn validate_message_ok() { - assert!(validate_message("short message").is_ok()); - assert!(validate_message(&"a".repeat(1800)).is_ok()); - } - - #[test] - fn validate_message_too_long() { - assert!(validate_message(&"a".repeat(1801)).is_err()); - } - - #[test] - fn max_targets_constant() { - assert_eq!(MAX_TARGETS, 10); - } -} diff --git a/src/secrets.rs b/src/secrets.rs deleted file mode 100644 index e6a7967fd..000000000 --- a/src/secrets.rs +++ /dev/null @@ -1,479 +0,0 @@ -use std::collections::HashMap; -use tracing::{error, info}; - -use 
crate::config::SecretsConfig; - -/// Resolved secrets: mapping from key name to plaintext value. -pub type ResolvedSecrets = HashMap; - -/// Resolve all secret references in the [secrets] config table. -/// Returns a map of key → resolved value. -pub async fn resolve(cfg: &SecretsConfig) -> anyhow::Result { - let mut resolved = HashMap::new(); - - // Build AWS client once if any refs use aws-sm:// - #[cfg(feature = "secrets-aws")] - let aws_client = if cfg.refs.values().any(|v| v.starts_with("aws-sm://")) { - Some(build_aws_client(cfg).await) - } else { - None - }; - - for (key, uri) in &cfg.refs { - let value = if uri.starts_with("aws-sm://") { - #[cfg(feature = "secrets-aws")] - { - let client = aws_client.as_ref().ok_or_else(|| { - anyhow::anyhow!("secret '{key}': AWS client not initialized") - })?; - resolve_aws_sm(key, uri, client).await? - } - #[cfg(not(feature = "secrets-aws"))] - { - anyhow::bail!( - "secret '{key}' uses aws-sm:// but the 'secrets-aws' feature is not enabled" - ); - } - } else if uri.starts_with("exec://") { - resolve_exec(key, uri, cfg).await? 
- } else { - anyhow::bail!( - "secret '{key}': unrecognized URI scheme in '{uri}' (expected aws-sm:// or exec://)" - ); - }; - resolved.insert(key.clone(), value); - } - - if !resolved.is_empty() { - info!(count = resolved.len(), "secrets resolved"); - } - Ok(resolved) -} - -// -- AWS Secrets Manager provider -- - -#[cfg(feature = "secrets-aws")] -async fn build_aws_client(cfg: &SecretsConfig) -> aws_sdk_secretsmanager::Client { - let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest()); - if let Some(ref region) = cfg.aws.region { - config_loader = config_loader.region(aws_config::Region::new(region.clone())); - } - if let Some(ref endpoint) = cfg.aws.endpoint_url { - config_loader = config_loader.endpoint_url(endpoint); - } - let sdk_config = config_loader.load().await; - aws_sdk_secretsmanager::Client::new(&sdk_config) -} - -#[cfg(feature = "secrets-aws")] -async fn resolve_aws_sm( - key: &str, - uri: &str, - client: &aws_sdk_secretsmanager::Client, -) -> anyhow::Result { - let (secret_id, json_key) = parse_aws_sm_uri(uri) - .ok_or_else(|| anyhow::anyhow!("secret '{key}': invalid aws-sm:// URI '{uri}' — expected aws-sm://#"))?; - - let resp = client - .get_secret_value() - .secret_id(&secret_id) - .send() - .await - .map_err(|e| { - error!(secret = key, secret_id = %secret_id, "AWS Secrets Manager error"); - anyhow::anyhow!("secret '{key}': failed to fetch '{secret_id}' from AWS Secrets Manager: {e}") - })?; - - let secret_string = resp - .secret_string() - .ok_or_else(|| anyhow::anyhow!("secret '{key}': '{secret_id}' has no string value (binary secrets not supported)"))?; - - // Parse as JSON and extract the key - let json: serde_json::Value = serde_json::from_str(secret_string) - .map_err(|e| anyhow::anyhow!("secret '{key}': '{secret_id}' is not valid JSON: {e}"))?; - - let value = json - .get(&json_key) - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("secret '{key}': JSON key '{json_key}' not found in '{secret_id}'"))?; 
- - Ok(value.to_owned()) -} - -/// Parse `aws-sm://secret-id#json-key` into (secret_id, json_key). -#[cfg(feature = "secrets-aws")] -fn parse_aws_sm_uri(uri: &str) -> Option<(String, String)> { - let rest = uri.strip_prefix("aws-sm://")?; - let (secret_id, json_key) = rest.rsplit_once('#')?; - if secret_id.is_empty() || json_key.is_empty() { - return None; - } - Some((secret_id.to_owned(), json_key.to_owned())) -} - -// -- Exec provider -- -// Note: script path is delimited by the first space. Paths containing spaces are not supported. - -async fn resolve_exec(key: &str, uri: &str, cfg: &SecretsConfig) -> anyhow::Result { - let rest = uri.strip_prefix("exec://").unwrap(); - let mut parts_iter = rest.splitn(3, ' '); - let script = parts_iter.next().ok_or_else(|| { - anyhow::anyhow!("secret '{key}': exec:// URI missing script path") - })?; - if script.is_empty() { - anyhow::bail!("secret '{key}': exec:// URI has empty script path"); - } - - let mut cmd = tokio::process::Command::new(script); - cmd.kill_on_drop(true); - - // Sanitized environment (same as pre_boot hooks — no unrelated tokens leak) - cmd.env_clear(); - if let Ok(v) = std::env::var("HOME") { - cmd.env("HOME", &v); - } - if let Ok(v) = std::env::var("PATH") { - cmd.env("PATH", &v); - } - #[cfg(unix)] - if let Ok(v) = std::env::var("USER") { - cmd.env("USER", &v); - } - // Pass through cloud credential env vars for IAM-based auth - for (key, val) in std::env::vars() { - let pass = key.starts_with("AWS_") - || key.starts_with("AMAZON_") - || key.starts_with("ECS_CONTAINER_METADATA_URI") - || key.starts_with("GOOGLE_") - || key.starts_with("GCLOUD_") - || key.starts_with("CLOUDSDK_") - || key.starts_with("AZURE_"); - if pass { - cmd.env(&key, &val); - } - } - - // Pass remaining parts as arguments (key, attribute) - for arg in parts_iter { - if !arg.is_empty() { - cmd.arg(arg); - } - } - - let timeout = std::time::Duration::from_secs(cfg.exec.timeout_seconds); - let output = tokio::time::timeout(timeout, 
cmd.output()) - .await - .map_err(|_| { - anyhow::anyhow!( - "secret '{key}': exec script '{script}' timed out after {}s", - cfg.exec.timeout_seconds - ) - })? - .map_err(|e| { - if e.kind() == std::io::ErrorKind::NotFound { - anyhow::anyhow!( - "secret '{key}': exec script '{script}' not found — did [hooks.pre_boot] run successfully?" - ) - } else { - anyhow::anyhow!("secret '{key}': failed to execute '{script}': {e}") - } - })?; - - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr); - error!(secret = key, script, %stderr, "exec provider failed"); - anyhow::bail!( - "secret '{key}': exec script '{script}' exited with {}", - output.status - ); - } - - let value = String::from_utf8(output.stdout) - .map_err(|e| anyhow::anyhow!("secret '{key}': exec output is not valid UTF-8: {e}"))?; - Ok(value.trim_end_matches('\n').to_owned()) -} - -/// Substitute `${secrets.}` references in the raw config text with resolved values. -/// Uses single-pass replacement to avoid double-substitution if a secret value -/// itself contains `${secrets.*}` patterns. -/// Values are escaped for use within TOML double-quoted strings. -pub fn substitute(raw: &str, secrets: &ResolvedSecrets) -> String { - let re = regex::Regex::new(r"\$\{secrets\.([^}]+)\}").unwrap(); - re.replace_all(raw, |caps: ®ex::Captures| { - let key = &caps[1]; - secrets - .get(key) - .map(|v| escape_toml_value(v)) - .unwrap_or_else(|| caps[0].to_owned()) - }) - .into_owned() -} - -/// Escape a string value so it is safe inside a TOML double-quoted string. 
-fn escape_toml_value(s: &str) -> String { - let mut out = String::with_capacity(s.len()); - for ch in s.chars() { - match ch { - '\\' => out.push_str("\\\\"), - '"' => out.push_str("\\\""), - '\n' => out.push_str("\\n"), - '\r' => out.push_str("\\r"), - '\t' => out.push_str("\\t"), - _ => out.push(ch), - } - } - out -} - -#[cfg(test)] -mod tests { - use super::*; - - #[cfg(feature = "secrets-aws")] - #[test] - fn parse_aws_sm_uri_valid() { - let (id, key) = parse_aws_sm_uri("aws-sm://openab/prod#discord_bot_token").unwrap(); - assert_eq!(id, "openab/prod"); - assert_eq!(key, "discord_bot_token"); - } - - #[cfg(feature = "secrets-aws")] - #[test] - fn parse_aws_sm_uri_with_arn() { - let uri = "aws-sm://arn:aws:secretsmanager:us-east-1:123456789:secret:my-secret-abc123#api_key"; - let (id, key) = parse_aws_sm_uri(uri).unwrap(); - assert_eq!(id, "arn:aws:secretsmanager:us-east-1:123456789:secret:my-secret-abc123"); - assert_eq!(key, "api_key"); - } - - #[cfg(feature = "secrets-aws")] - #[test] - fn parse_aws_sm_uri_missing_key() { - assert!(parse_aws_sm_uri("aws-sm://openab/prod").is_none()); - assert!(parse_aws_sm_uri("aws-sm://openab/prod#").is_none()); - assert!(parse_aws_sm_uri("aws-sm://#key").is_none()); - } - - #[test] - fn substitute_replaces_secrets() { - let mut secrets = HashMap::new(); - secrets.insert("token".to_owned(), "my-secret-value".to_owned()); - let input = r#"bot_token = "${secrets.token}""#; - let output = substitute(input, &secrets); - assert_eq!(output, r#"bot_token = "my-secret-value""#); - } - - #[test] - fn substitute_escapes_special_chars() { - let mut secrets = HashMap::new(); - secrets.insert("key".to_owned(), "has\"quotes\\and\nnewlines".to_owned()); - let input = r#"value = "${secrets.key}""#; - let output = substitute(input, &secrets); - assert_eq!(output, r#"value = "has\"quotes\\and\nnewlines""#); - } - - #[test] - fn substitute_no_match_unchanged() { - let secrets = HashMap::new(); - let input = r#"bot_token = 
"${DISCORD_BOT_TOKEN}""#; - let output = substitute(input, &secrets); - assert_eq!(output, input); - } - - #[test] - fn substitute_unknown_key_left_intact() { - let mut secrets = HashMap::new(); - secrets.insert("known".to_owned(), "val".to_owned()); - let input = r#"a = "${secrets.known}" b = "${secrets.unknown}""#; - let output = substitute(input, &secrets); - assert_eq!(output, r#"a = "val" b = "${secrets.unknown}""#); - } - - #[test] - fn substitute_no_double_replacement() { - let mut secrets = HashMap::new(); - // Secret value itself contains a ${secrets.*} pattern - secrets.insert("a".to_owned(), "${secrets.b}".to_owned()); - secrets.insert("b".to_owned(), "should-not-appear".to_owned()); - let input = r#"val = "${secrets.a}""#; - let output = substitute(input, &secrets); - // The literal ${secrets.b} should be escaped, not re-substituted - assert!(!output.contains("should-not-appear")); - } - - #[test] - fn substitute_multiple_refs_same_line() { - let mut secrets = HashMap::new(); - secrets.insert("user".to_owned(), "admin".to_owned()); - secrets.insert("pass".to_owned(), "s3cret".to_owned()); - let input = r#"dsn = "postgres://${secrets.user}:${secrets.pass}@localhost""#; - let output = substitute(input, &secrets); - assert_eq!(output, r#"dsn = "postgres://admin:s3cret@localhost""#); - } - - #[test] - fn escape_toml_value_basic() { - assert_eq!(escape_toml_value("hello"), "hello"); - assert_eq!(escape_toml_value(r#"a"b"#), r#"a\"b"#); - assert_eq!(escape_toml_value("a\\b"), "a\\\\b"); - assert_eq!(escape_toml_value("line1\nline2"), "line1\\nline2"); - assert_eq!(escape_toml_value("tab\there"), "tab\\there"); - assert_eq!(escape_toml_value("cr\rhere"), "cr\\rhere"); - } - - #[test] - fn escape_toml_value_combined() { - let input = "key=\"val\"\nnext"; - let output = escape_toml_value(input); - assert_eq!(output, "key=\\\"val\\\"\\nnext"); - } - - #[tokio::test] - async fn resolve_exec_success() { - use crate::config::{AwsSecretsConfig, ExecSecretsConfig, 
SecretsConfig}; - let cfg = SecretsConfig { - aws: AwsSecretsConfig::default(), - exec: ExecSecretsConfig { timeout_seconds: 5 }, - refs: HashMap::new(), - }; - // Use echo as a script - let result = resolve_exec("test", "exec:///bin/echo hello world", &cfg).await; - assert_eq!(result.unwrap(), "hello world"); - } - - #[tokio::test] - async fn resolve_exec_not_found() { - use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; - let cfg = SecretsConfig { - aws: AwsSecretsConfig::default(), - exec: ExecSecretsConfig { timeout_seconds: 5 }, - refs: HashMap::new(), - }; - let result = - resolve_exec("test", "exec:///nonexistent/script arg1 arg2", &cfg).await; - let err = result.unwrap_err().to_string(); - assert!(err.contains("not found")); - assert!(err.contains("pre_boot")); - } - - #[tokio::test] - async fn resolve_exec_nonzero_exit() { - use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; - let cfg = SecretsConfig { - aws: AwsSecretsConfig::default(), - exec: ExecSecretsConfig { timeout_seconds: 5 }, - refs: HashMap::new(), - }; - let result = resolve_exec("test", "exec:///bin/false", &cfg).await; - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("exited with")); - } - - #[tokio::test] - async fn resolve_exec_strips_trailing_newline() { - use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; - let cfg = SecretsConfig { - aws: AwsSecretsConfig::default(), - exec: ExecSecretsConfig { timeout_seconds: 5 }, - refs: HashMap::new(), - }; - // printf adds no newline, echo does — test both - let result = resolve_exec("test", "exec:///bin/echo secret_value", &cfg).await; - assert_eq!(result.unwrap(), "secret_value"); - } - - #[tokio::test] - async fn resolve_empty_refs_returns_empty() { - use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; - let cfg = SecretsConfig { - aws: AwsSecretsConfig::default(), - exec: ExecSecretsConfig::default(), - refs: HashMap::new(), - }; 
- let result = resolve(&cfg).await.unwrap(); - assert!(result.is_empty()); - } - - #[tokio::test] - async fn resolve_unknown_scheme_fails() { - use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; - let mut refs = HashMap::new(); - refs.insert("bad".to_owned(), "ftp://something".to_owned()); - let cfg = SecretsConfig { - aws: AwsSecretsConfig::default(), - exec: ExecSecretsConfig::default(), - refs, - }; - let result = resolve(&cfg).await; - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("unrecognized URI scheme")); - } - - #[cfg(feature = "secrets-aws")] - #[test] - fn parse_aws_sm_uri_hash_in_secret_name() { - // rsplit_once('#') should split on the LAST # - let uri = "aws-sm://my#secret#api_key"; - let (id, key) = parse_aws_sm_uri(uri).unwrap(); - assert_eq!(id, "my#secret"); - assert_eq!(key, "api_key"); - } - - #[tokio::test] - async fn resolve_exec_sanitized_env() { - use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; - // Set a dummy env var that should NOT be visible to the exec script - std::env::set_var("OPENAB_TEST_LEAKED_SECRET", "should_not_leak"); - let cfg = SecretsConfig { - aws: AwsSecretsConfig::default(), - exec: ExecSecretsConfig { timeout_seconds: 5 }, - refs: HashMap::new(), - }; - // /usr/bin/env prints all env vars; grep for our dummy var - let result = resolve_exec("test", "exec:///usr/bin/env", &cfg).await.unwrap(); - assert!( - !result.contains("OPENAB_TEST_LEAKED_SECRET"), - "exec script should not see unrelated env vars" - ); - // HOME and PATH should still be present - assert!(result.contains("HOME="), "HOME should be in sanitized env"); - assert!(result.contains("PATH="), "PATH should be in sanitized env"); - std::env::remove_var("OPENAB_TEST_LEAKED_SECRET"); - } - - #[tokio::test] - async fn resolve_exec_timeout() { - use crate::config::{AwsSecretsConfig, ExecSecretsConfig, SecretsConfig}; - let cfg = SecretsConfig { - aws: AwsSecretsConfig::default(), - exec: 
ExecSecretsConfig { timeout_seconds: 1 }, - refs: HashMap::new(), - }; - // sleep 10 will be killed after 1s timeout - let result = resolve_exec("test", "exec:///bin/sleep 10", &cfg).await; - let err = result.unwrap_err().to_string(); - assert!(err.contains("timed out"), "expected timeout error, got: {err}"); - } - - #[test] - fn substitute_and_reparse_integration() { - let mut secrets = HashMap::new(); - secrets.insert("token".to_owned(), "xoxb-secret-value".to_owned()); - secrets.insert("key".to_owned(), "sk-with\"special\\chars".to_owned()); - - let raw = r#" -[discord] -bot_token = "${secrets.token}" - -[agent] -command = "echo" -args = ["--key", "${secrets.key}"] -"#; - let substituted = substitute(raw, &secrets); - // Verify the substituted text is valid TOML that parses correctly - let cfg: crate::config::Config = toml::from_str(&substituted) - .expect("substituted config should be valid TOML"); - assert_eq!(cfg.discord.unwrap().bot_token, "xoxb-secret-value"); - assert_eq!(cfg.agent.args[1], "sk-with\"special\\chars"); - } -} diff --git a/src/setup/config.rs b/src/setup/config.rs deleted file mode 100644 index ee1483650..000000000 --- a/src/setup/config.rs +++ /dev/null @@ -1,157 +0,0 @@ -//! Config generation and TOML serialization for the setup wizard. 
- -/// Mask bot token in config output for preview -pub fn mask_bot_token(config: &str) -> String { - config - .lines() - .map(|line| { - if line.trim_start().starts_with("bot_token") { - "bot_token = \"***\"".to_string() - } else { - line.to_string() - } - }) - .collect::>() - .join("\n") -} - -#[derive(serde::Serialize)] -pub(crate) struct ConfigToml { - discord: DiscordConfigToml, - agent: AgentConfigToml, - pool: PoolConfigToml, - reactions: ReactionsConfigToml, -} - -#[derive(serde::Serialize)] -struct DiscordConfigToml { - bot_token: String, - allowed_channels: Vec, -} - -#[derive(serde::Serialize)] -struct AgentConfigToml { - command: String, - args: Vec, - working_dir: String, -} - -#[derive(serde::Serialize)] -struct PoolConfigToml { - max_sessions: usize, - session_ttl_hours: u64, -} - -#[derive(serde::Serialize)] -struct ReactionsConfigToml { - enabled: bool, - remove_after_reply: bool, - emojis: EmojisToml, - timing: TimingToml, -} - -#[derive(serde::Serialize)] -struct EmojisToml { - queued: String, - thinking: String, - tool: String, - coding: String, - web: String, - done: String, - error: String, -} - -#[derive(serde::Serialize)] -struct TimingToml { - debounce_ms: u64, - stall_soft_ms: u64, - stall_hard_ms: u64, - done_hold_ms: u64, - error_hold_ms: u64, -} - -pub fn generate_config( - bot_token: &str, - agent_command: &str, - channel_ids: Vec, - working_dir: &str, - max_sessions: usize, - session_ttl_hours: u64, -) -> String { - let config = ConfigToml { - discord: DiscordConfigToml { - bot_token: bot_token.to_string(), - allowed_channels: channel_ids, - }, - agent: { - let (command, args): (&str, Vec) = match agent_command { - "kiro" => ("kiro-cli", vec!["acp".into(), "--trust-all-tools".into()]), - "claude" => ("claude-agent-acp", vec![]), - "codex" => ("codex-acp", vec![]), - "gemini" => ("gemini", vec!["--acp".into()]), - other => (other, vec![]), - }; - AgentConfigToml { - command: command.to_string(), - args, - working_dir: 
working_dir.to_string(), - } - }, - pool: PoolConfigToml { - max_sessions, - session_ttl_hours, - }, - reactions: ReactionsConfigToml { - enabled: true, - remove_after_reply: false, - emojis: EmojisToml { - queued: "👀".into(), - thinking: "🤔".into(), - tool: "🔥".into(), - coding: "👨💻".into(), - web: "⚡".into(), - done: "🆗".into(), - error: "😱".into(), - }, - timing: TimingToml { - debounce_ms: 700, - stall_soft_ms: 10_000, - stall_hard_ms: 30_000, - done_hold_ms: 1_500, - error_hold_ms: 2_500, - }, - }, - }; - toml::to_string_pretty(&config).expect("TOML serialization failed") -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn generate_config_contains_sections() { - let config = generate_config( - "my_token", - "claude", - vec!["123".to_string()], - "/home/agent", - 10, - 24, - ); - assert!(config.contains("[discord]")); - assert!(config.contains("[agent]")); - assert!(config.contains("[pool]")); - assert!(config.contains("[reactions]")); - assert!(config.contains("[reactions.emojis]")); - assert!(config.contains("[reactions.timing]")); - } - - #[test] - fn generate_config_kiro_working_dir() { - let config = generate_config("tok", "kiro", vec!["ch".to_string()], "/home/agent", 10, 24); - assert!(config.contains(r#"working_dir = "/home/agent""#)); - assert!(config.contains("acp")); - assert!(config.contains("--trust-all-tools")); - } -} diff --git a/src/setup/mod.rs b/src/setup/mod.rs deleted file mode 100644 index 96034f0ab..000000000 --- a/src/setup/mod.rs +++ /dev/null @@ -1,12 +0,0 @@ -//! OpenAB interactive setup wizard. -//! -//! Modules: -//! - `validate` — input validation (bot token, channel ID, agent command) -//! - `config` — TOML config generation and serialization -//! 
- `wizard` — interactive TUI, Discord API client, and wizard entry point - -mod config; -mod validate; -mod wizard; - -pub use wizard::run_setup; diff --git a/src/setup/validate.rs b/src/setup/validate.rs deleted file mode 100644 index b09401559..000000000 --- a/src/setup/validate.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Input validation functions for the setup wizard. - -/// Validate bot token format using allowlist (a-zA-Z0-9-./_) -pub fn validate_bot_token(token: &str) -> anyhow::Result<()> { - if token.is_empty() { - anyhow::bail!("Token cannot be empty"); - } - if !token.chars().all(|c| { - c.is_ascii_alphanumeric() - || c == '-' - || c == '.' - || c == '_' - || c == '/' - || c == '*' - || c == '=' - }) { - anyhow::bail!( - "Token must only contain ASCII letters, numbers, dash, period, underscore, slash, or equals" - ); - } - Ok(()) -} - -/// Validate agent command -#[cfg(test)] -pub fn validate_agent_command(cmd: &str) -> anyhow::Result<()> { - let valid = ["kiro", "claude", "codex", "gemini"]; - if !valid.contains(&cmd) { - anyhow::bail!("Agent must be one of: {}", valid.join(", ")); - } - Ok(()) -} - -/// Validate channel ID is numeric -pub fn validate_channel_id(id: &str) -> anyhow::Result<()> { - if id.is_empty() { - anyhow::bail!("Channel ID cannot be empty"); - } - if !id.chars().all(|c| c.is_ascii_digit()) { - anyhow::bail!("Channel ID must be numeric only"); - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn validate_bot_token_ok() { - assert!(validate_bot_token("simple_token").is_ok()); - assert!(validate_bot_token("token.with-dashes_123").is_ok()); - assert!(validate_bot_token("***/efgh").is_ok()); - } - - #[test] - fn validate_bot_token_reject_invalid() { - assert!(validate_bot_token("").is_err()); - assert!(validate_bot_token("token\nnewline").is_err()); - assert!(validate_bot_token("token\ttab").is_err()); - assert!(validate_bot_token("token with space").is_err()); - } - - #[test] - fn 
validate_agent_command_known_and_unknown() { - for agent in &["kiro", "claude", "codex", "gemini"] { - assert!(validate_agent_command(agent).is_ok()); - } - assert!(validate_agent_command("invalid").is_err()); - } - - #[test] - fn validate_channel_id_accepts_numeric_rejects_invalid() { - assert!(validate_channel_id("1492329565824094370").is_ok()); - assert!(validate_channel_id("").is_err()); - assert!(validate_channel_id("abc123").is_err()); - } -} diff --git a/src/setup/wizard.rs b/src/setup/wizard.rs deleted file mode 100644 index f5a789609..000000000 --- a/src/setup/wizard.rs +++ /dev/null @@ -1,667 +0,0 @@ -//! Interactive setup wizard TUI and Discord API client. - -use std::io::{self, IsTerminal, Write}; -use std::path::{Path, PathBuf}; - -use crate::setup::config::{generate_config, mask_bot_token}; -use crate::setup::validate::{validate_bot_token, validate_channel_id}; - -// --------------------------------------------------------------------------- -// Color codes (ANSI) -// --------------------------------------------------------------------------- - -const C: Colors = Colors { - reset: "\x1b[0m", - bold: "\x1b[1m", - cyan: "\x1b[36m", - green: "\x1b[32m", - red: "\x1b[31m", - yellow: "\x1b[33m", - magenta: "\x1b[35m", -}; - -struct Colors { - reset: &'static str, - bold: &'static str, - cyan: &'static str, - green: &'static str, - red: &'static str, - yellow: &'static str, - magenta: &'static str, -} - -const BORDER: char = '═'; - -macro_rules! 
cprintln { - ($color:expr, $fmt:expr) => {{ - println!("{}{}{}", $color, $fmt, C.reset); - }}; - ($color:expr, $fmt:expr, $($arg:tt)*) => {{ - println!("{}{}{}", $color, format!($fmt, $($arg)*), C.reset); - }}; -} - -// --------------------------------------------------------------------------- -// Input helpers -// --------------------------------------------------------------------------- - -fn is_interactive() -> bool { - std::io::stdin().is_terminal() && std::io::stdout().is_terminal() -} - -fn prompt(prompt_text: &str) -> String { - print!("{}{}: {}", C.yellow, prompt_text, C.reset); - io::stdout().flush().ok(); - let mut input = String::new(); - io::stdin().read_line(&mut input).ok(); - input.trim().to_string() -} - -fn prompt_default(prompt_text: &str, default: &str) -> String { - print!("{}{} [{}]: {}", C.yellow, prompt_text, default, C.reset); - io::stdout().flush().ok(); - let mut input = String::new(); - io::stdin().read_line(&mut input).ok(); - let input = input.trim(); - if input.is_empty() { - default.to_string() - } else { - input.to_string() - } -} - -fn prompt_password(prompt_text: &str) -> String { - print!("{}{}: ", C.yellow, prompt_text); - io::stdout().flush().ok(); - rpassword::read_password().unwrap_or_default() -} - -fn prompt_yes_no(prompt_text: &str, default: bool) -> bool { - let default_str = if default { "Y/n" } else { "y/N" }; - loop { - print!("{}{} [{}]: ", C.yellow, prompt_text, default_str,); - io::stdout().flush().ok(); - let mut input = String::new(); - io::stdin().read_line(&mut input).ok(); - let input = input.trim().to_lowercase(); - if input.is_empty() { - return default; - } - match input.as_str() { - "y" | "yes" => return true, - "n" | "no" => return false, - _ => cprintln!(C.red, "Please enter 'y' or 'n'"), - } - } -} - -fn prompt_choice(prompt_text: &str, choices: &[&str]) -> usize { - println!(); - cprintln!(C.cyan, "{}", prompt_text); - for (i, choice) in choices.iter().enumerate() { - println!(" {}. 
{}", i + 1, choice); - } - print!("{}Select [1-{}]: {}", C.yellow, choices.len(), C.reset); - io::stdout().flush().ok(); - loop { - let mut input = String::new(); - io::stdin().read_line(&mut input).ok(); - match input.trim().parse::() { - Ok(n) if n >= 1 && n <= choices.len() => return n - 1, - _ => { - print!("{}Select [1-{}]: {}", C.yellow, choices.len(), C.reset); - io::stdout().flush().ok(); - } - } - } -} - -fn prompt_checklist(prompt_text: &str, items: &[&str]) -> Vec { - println!(); - cprintln!(C.cyan, "{}", prompt_text); - for (i, item) in items.iter().enumerate() { - println!(" [{}] {}", i + 1, item); - } - println!(); - print!( - "{}Enter numbers separated by commas (e.g. 1,3,5) or press Enter for all: {}", - C.yellow, C.reset - ); - io::stdout().flush().ok(); - let mut input = String::new(); - io::stdin().read_line(&mut input).ok(); - let input = input.trim(); - if input.is_empty() { - return (0..items.len()).collect(); - } - input - .split(',') - .filter_map(|s| s.trim().parse::().ok()) - .filter(|n| *n >= 1 && *n <= items.len()) - .map(|n| n - 1) - .collect() -} - -// --------------------------------------------------------------------------- -// Box drawing helpers -// --------------------------------------------------------------------------- - -fn print_box(lines: &[&str]) { - let width = lines - .iter() - .map(|l| unicode_width::UnicodeWidthStr::width(&**l)) - .max() - .unwrap_or(60); - let width = width.clamp(60, 76); - println!(); - cprintln!( - C.cyan, - "{}", - "╔".to_string() + &BORDER.to_string().repeat(width + 2) + "╗" - ); - for line in lines { - let padded = format!(" {: Self { - Self { - token: token.to_string(), - http: reqwest::blocking::Client::builder() - .timeout(std::time::Duration::from_secs(10)) - .build() - .expect("static HTTP client must build"), - } - } - - /// Verify token by fetching bot info - fn verify_token(&self) -> anyhow::Result<(String, String)> { - let resp = self - .http - 
.get("https://discord.com/api/v10/users/@me") - .header("Authorization", format!("Bot {}", self.token)) - .header("User-Agent", "OpenAB setup wizard") - .send()?; - if !resp.status().is_success() { - anyhow::bail!("Token verification failed: HTTP {}", resp.status()); - } - #[derive(serde::Deserialize)] - struct MeResponse { - id: String, - username: String, - } - let me: MeResponse = resp.json()?; - Ok((me.id, me.username)) - } - - /// Fetch guilds the bot is in - fn fetch_guilds(&self) -> anyhow::Result> { - let resp = self - .http - .get("https://discord.com/api/v10/users/@me/guilds") - .header("Authorization", format!("Bot {}", self.token)) - .header("User-Agent", "OpenAB setup wizard") - .send()?; - if !resp.status().is_success() { - anyhow::bail!("Failed to fetch guilds: HTTP {}", resp.status()); - } - #[derive(serde::Deserialize)] - struct Guild { - id: String, - name: String, - } - let guilds: Vec = resp.json()?; - Ok(guilds.into_iter().map(|g| (g.id, g.name)).collect()) - } - - /// Fetch channels in a guild - fn fetch_channels(&self, guild_id: &str) -> anyhow::Result> { - let url = format!("https://discord.com/api/v10/guilds/{}/channels", guild_id); - let resp = self - .http - .get(&url) - .header("Authorization", format!("Bot {}", self.token)) - .header("User-Agent", "OpenAB setup wizard") - .send()?; - if !resp.status().is_success() { - anyhow::bail!("Failed to fetch channels: HTTP {}", resp.status()); - } - #[derive(serde::Deserialize)] - struct Channel { - id: String, - #[serde(rename = "type")] - kind: u8, - name: String, - } - let channels: Vec = resp.json()?; - // type 0 = text channel - Ok(channels - .into_iter() - .filter(|c| c.kind == 0) - .map(|c| (c.id, c.name, guild_id.to_string())) - .collect()) - } -} - -// --------------------------------------------------------------------------- -// Section 1: Discord Bot Setup Guide -// --------------------------------------------------------------------------- - -fn section_discord_guide() { - 
print_box(&[ - "Discord Bot Setup Guide", - "", - "1. Go to: https://discord.com/developers/applications", - "2. Click 'New Application' -> name it (e.g. OpenAB)", - "3. Bot -> Reset Token -> COPY the token", - "", - "4. Enable Privileged Gateway Intents:", - " - Message Content Intent", - " - Guild Members Intent", - "", - "5. OAuth2 -> URL Generator:", - " - SCOPES: bot", - " - BOT PERMISSIONS:", - " Send Messages | Embed Links | Attach Files", - " Read Message History | Add Reactions", - " Use Slash Commands", - "", - "6. Visit the generated URL -> add bot to your server", - ]); -} - -// --------------------------------------------------------------------------- -// Section 2: Channel Selection -// --------------------------------------------------------------------------- - -fn section_channels(client: &DiscordClient) -> anyhow::Result> { - println!(); - cprintln!(C.bold, "--- Step 2: Allowed Channels ---"); - println!(); - - print!(" Fetching servers... "); - io::stdout().flush().ok(); - let guilds = client.fetch_guilds()?; - cprintln!(C.green, "OK Found {} server(s)", guilds.len()); - println!(); - - if guilds.is_empty() { - cprintln!(C.yellow, " No servers found. Enter channel IDs manually."); - let input = prompt(" Channel ID(s), comma-separated"); - let ids: Vec = input - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect(); - for id in &ids { - validate_channel_id(id)?; - } - return Ok(ids); - } - - let guild_names: Vec<&str> = guilds.iter().map(|(_, n)| n.as_str()).collect(); - let guild_idx = prompt_choice(" Select server:", &guild_names); - let (guild_id, guild_name) = &guilds[guild_idx]; - - print!(" Fetching channels in '{}'... ", guild_name); - io::stdout().flush().ok(); - let channels = client.fetch_channels(guild_id)?; - cprintln!(C.green, "OK Found {} channel(s)", channels.len()); - println!(); - - if channels.is_empty() { - cprintln!( - C.yellow, - " No text channels found. Enter channel IDs manually." 
- ); - let input = prompt(" Channel ID(s), comma-separated"); - let ids: Vec = input - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect(); - for id in &ids { - validate_channel_id(id)?; - } - return Ok(ids); - } - - let channel_names: Vec = channels.iter().map(|(_, n, _)| format!("#{}", n)).collect(); - let channel_names_refs: Vec<&str> = channel_names.iter().map(|s| s.as_str()).collect(); - - let selected = prompt_checklist(" Select channels (by number):", &channel_names_refs); - let selected_ids: Vec = selected.iter().map(|&i| channels[i].0.clone()).collect(); - - println!(); - cprintln!(C.green, " Selected {} channel(s)", selected_ids.len()); - for id in &selected_ids { - if let Some((_, name, _)) = channels.iter().find(|(cid, _, _)| cid == id) { - println!(" * #{}", name); - } else { - println!(" * {}", id); - } - } - println!(); - - Ok(selected_ids) -} - -// --------------------------------------------------------------------------- -// Section 3: Agent Configuration -// --------------------------------------------------------------------------- - -fn section_agent() -> (String, String, bool) { - println!(); - cprintln!(C.bold, "--- Step 3: Agent Configuration ---"); - println!(); - - print_box(&[ - "Agent Installation Guide", - "", - "claude: npm install -g @anthropic-ai/claude-code", - "kiro: npm install -g @koryhutchison/kiro-cli", - "codex: npm install -g openai-codex (requires OpenAI API key)", - "gemini: npm install -g @google/gemini-cli", - "", - "Make sure the agent is in your PATH before continuing.", - ]); - println!(); - - let choices = ["claude", "kiro", "codex", "gemini"]; - let idx = prompt_choice(" Select agent:", &choices); - let agent = choices[idx]; - - let deploy_choices = ["Local (current directory)", "Docker / k8s"]; - let deploy_idx = prompt_choice(" Deployment target:", &deploy_choices); - let is_local = deploy_idx == 0; - let default_dir = match (is_local, agent) { - (true, _) => ".", - (false, 
"kiro") => "/home/agent", - (false, _) => "/home/node", - }; - - let working_dir = prompt_default(" Working directory", default_dir); - - cprintln!(C.green, " Agent: {} | Working dir: {}", agent, working_dir); - println!(); - - (agent.to_string(), working_dir, is_local) -} - -// --------------------------------------------------------------------------- -// Section 4: Pool Settings -// --------------------------------------------------------------------------- - -fn section_pool() -> (usize, u64) { - println!(); - cprintln!(C.bold, "--- Step 4: Session Pool ---"); - println!(); - - let max_sessions: usize = prompt_default(" Max sessions", "10").parse().unwrap_or(10); - let ttl_hours: u64 = prompt_default(" Session TTL (hours)", "24") - .parse() - .unwrap_or(24); - - cprintln!( - C.green, - " Max sessions: {} | TTL: {}h", - max_sessions, - ttl_hours - ); - println!(); - - (max_sessions, ttl_hours) -} - -// --------------------------------------------------------------------------- -// Preview & Save -// --------------------------------------------------------------------------- - -fn section_preview_and_save(config_content: &str, output_path: &PathBuf) -> anyhow::Result<()> { - println!(); - cprintln!(C.bold, "--- Preview ---"); - println!(); - println!("{}", mask_bot_token(config_content)); - println!(); - - if output_path.exists() && !prompt_yes_no(" File exists. 
Overwrite?", false) { - println!(" Saving cancelled."); - return Ok(()); - } - - std::fs::write(output_path, config_content)?; - cprintln!(C.green, "OK config.toml saved to {}", output_path.display()); - println!(); - - Ok(()) -} - -// --------------------------------------------------------------------------- -// Non-interactive guidance -// --------------------------------------------------------------------------- - -fn print_noninteractive_guide() { - print_box(&[ - "Non-Interactive Mode", - "", - "The interactive wizard requires a terminal.", - "Create config.toml manually, then run:", - "", - " openab run config.toml", - "", - "Config format reference:", - " [discord]", - " bot_token = \"YOUR_BOT_TOKEN\"", - " allowed_channels = [\"CHANNEL_ID\"]", - "", - " [agent]", - " command = \"kiro-cli\"", - " args = [\"acp\", \"--trust-all-tools\"]", - " working_dir = \"/home/agent\"", - "", - " [pool]", - " max_sessions = 10", - " session_ttl_hours = 24", - "", - " [reactions]", - " enabled = true", - " remove_after_reply = false", - " ...", - ]); -} - -// --------------------------------------------------------------------------- -// Next steps printer -// --------------------------------------------------------------------------- - -fn print_next_steps(agent: &str, output_path: &Path, is_local: bool) { - println!(); - cprintln!(C.bold, "--- Next Steps ---"); - println!(); - - if is_local { - match agent { - "kiro" => { - cprintln!( - C.cyan, - " 1. Install kiro-cli (see https://kiro.dev for installer)" - ); - cprintln!(C.cyan, " 2. Authenticate:"); - println!(" kiro-cli login --use-device-flow"); - } - "claude" => { - cprintln!(C.cyan, " 1. Install Claude Code + ACP adapter:"); - println!(" npm install -g @anthropic-ai/claude-code @agentclientprotocol/claude-agent-acp"); - cprintln!(C.cyan, " 2. Authenticate:"); - println!(" claude auth login"); - } - "codex" => { - cprintln!(C.cyan, " 1. 
Install Codex CLI + ACP adapter:"); - println!(" npm install -g @openai/codex @zed-industries/codex-acp"); - cprintln!(C.cyan, " 2. Authenticate:"); - println!(" codex login --device-auth"); - } - "gemini" => { - cprintln!(C.cyan, " 1. Install Gemini CLI:"); - println!(" npm install -g @google/gemini-cli"); - cprintln!( - C.cyan, - " 2. Authenticate via Google OAuth, or set GEMINI_API_KEY in config.toml" - ); - } - _ => {} - } - - println!(); - cprintln!(C.green, " 3. Run the bot:"); - println!(" cargo run -- run {}", output_path.display()); - } else { - cprintln!( - C.cyan, - " Docker image already bundles the agent CLI and ACP adapter." - ); - println!(); - cprintln!(C.cyan, " 1. Deploy with Helm (or your preferred method):"); - println!(" helm install openab openab/openab \\"); - println!( - " --set agents.{}.discord.botToken=\"$BOT_TOKEN\"", - agent - ); - println!(); - cprintln!( - C.cyan, - " 2. Authenticate inside the pod (first time only):" - ); - match agent { - "kiro" => println!( - " kubectl exec -it deployment/openab-kiro -- kiro-cli login --use-device-flow" - ), - "claude" => { - println!(" kubectl exec -it deployment/openab-claude -- claude auth login") - } - "codex" => println!( - " kubectl exec -it deployment/openab-codex -- codex login --device-auth" - ), - "gemini" => { - println!(" Set GEMINI_API_KEY via secret, or exec into the pod for OAuth") - } - _ => {} - } - println!(); - cprintln!(C.green, " See README for full Helm options."); - } - println!(); -} - -// --------------------------------------------------------------------------- -// Main wizard entry point -// --------------------------------------------------------------------------- - -pub fn run_setup(output_path: Option) -> anyhow::Result<()> { - if !is_interactive() { - print_noninteractive_guide(); - return Ok(()); - } - - println!(); - cprintln!( - C.magenta, - "============================================================" - ); - cprintln!( - C.magenta, - " OpenAB Interactive Setup 
Wizard " - ); - cprintln!( - C.magenta, - "============================================================" - ); - - // Step 1: Discord Guide + Token - section_discord_guide(); - println!(); - let bot_token = prompt_password(" Bot Token (or press Enter to skip)"); - if bot_token.is_empty() { - cprintln!(C.yellow, " Skipped. Set bot_token manually in config.toml"); - println!(); - cprintln!( - C.green, - " Setup complete! Edit config.toml to add your bot token." - ); - return Ok(()); - } - validate_bot_token(&bot_token)?; - - let client = DiscordClient::new(&bot_token); - print!(" Verifying token with Discord API... "); - io::stdout().flush().ok(); - let (_bot_id, bot_username) = client.verify_token()?; - cprintln!(C.green, "OK Logged in as {}", bot_username); - - // Step 2: Channels - let channel_ids = match section_channels(&client) { - Ok(ids) if !ids.is_empty() => ids, - Ok(_) => { - cprintln!(C.yellow, " No channels selected."); - vec![] - } - Err(e) => { - cprintln!(C.yellow, " Channel fetch failed: {}. 
Enter manually.", e); - let input = prompt(" Channel ID(s), comma-separated"); - let ids: Vec = input - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect(); - for id in &ids { - validate_channel_id(id).map_err(|e| anyhow::anyhow!("{}", e))?; - } - ids - } - }; - - // Step 3: Agent - let (agent, working_dir, is_local) = section_agent(); - - // Step 4: Pool - let (max_sessions, ttl_hours) = section_pool(); - - // Generate - let config_content = generate_config( - &bot_token, - &agent, - channel_ids, - &working_dir, - max_sessions, - ttl_hours, - ); - - // Output - let output_path = output_path.unwrap_or_else(|| PathBuf::from("config.toml")); - section_preview_and_save(&config_content, &output_path)?; - - print_next_steps(&agent, &output_path, is_local); - - Ok(()) -} diff --git a/src/slack.rs b/src/slack.rs deleted file mode 100644 index 94fdcf0a1..000000000 --- a/src/slack.rs +++ /dev/null @@ -1,2329 +0,0 @@ -use crate::acp::ContentBlock; -use crate::adapter::{ChannelRef, ChatAdapter, MessageRef, SenderContext}; -use crate::bot_turns::{BotTurnTracker, TurnAction, TurnSeverity}; -use crate::config::{AllowBots, AllowUsers, SttConfig}; -use crate::media; -use anyhow::{anyhow, Result}; -use async_trait::async_trait; -use futures_util::{SinkExt, StreamExt}; -use std::collections::{HashMap, HashSet}; -use std::sync::{Arc, LazyLock}; -use tokio::sync::watch; -use tokio_tungstenite::tungstenite; -use tracing::{debug, error, info, warn}; - -const SLACK_API: &str = "https://slack.com/api"; - -/// Map Unicode emoji to Slack short names for reactions API. -/// Only covers the default `[reactions.emojis]` set. Custom emoji configured -/// outside this map will fall back to `grey_question`. 
-fn unicode_to_slack_emoji(unicode: &str) -> &str { - match unicode { - "👀" => "eyes", - "🤔" => "thinking_face", - "🔥" => "fire", - "👨\u{200d}💻" => "technologist", - "⚡" => "zap", - "🆗" => "ok", - "😱" => "scream", - "🚫" => "no_entry_sign", - "😊" => "blush", - "😎" => "sunglasses", - "🫡" => "saluting_face", - "🤓" => "nerd_face", - "😏" => "smirk", - "✌\u{fe0f}" => "v", - "💪" => "muscle", - "🦾" => "mechanical_arm", - "🥱" => "yawning_face", - "😨" => "fearful", - "✅" => "white_check_mark", - "❌" => "x", - "🔧" => "wrench", - "🎤" => "microphone", - _ => "grey_question", - } -} - -// --- SlackAdapter: implements ChatAdapter for Slack --- - -/// TTL for cached user display names (5 minutes). -const USER_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(300); - -/// Maximum entries in the participation cache before eviction. -const PARTICIPATION_CACHE_MAX: usize = 1000; - -/// Maximum entries in the streams map before eviction (safety net for -/// aborted turns that begin a stream but never reach stream_finish). -const STREAM_CACHE_MAX: usize = 1024; - -#[derive(Default)] -struct StreamEntry { - active: bool, - degraded_buf: String, -} - -pub struct SlackAdapter { - client: reqwest::Client, - bot_token: String, - bot_user_id: tokio::sync::OnceCell, - user_cache: tokio::sync::Mutex>, - /// Cache: Bot ID (B...) → Bot User ID (U...) for trusted_bot_ids matching. - bot_id_cache: tokio::sync::Mutex>, - /// Positive-only cache: thread_ts → cached_at for threads where bot has participated. - participated_threads: tokio::sync::Mutex>, - /// Positive-only cache: thread_ts → cached_at for threads where other bots have posted. - /// Like participation, a thread becoming multi-bot is irreversible (bot messages don't disappear). - multibot_threads: tokio::sync::Mutex>, - /// Persistent disk cache for multibot thread detection (survives restarts). 
- multibot_cache: crate::multibot_cache::MultibotCache, - /// TTL for participation cache entries (matches session_ttl_hours from config). - session_ttl: std::time::Duration, - /// Assistant mode: stream via chat.startStream + assistant.threads.setStatus. - assistant_mode: bool, - /// streaming message ts → state. active=false = degraded (post+edit fallback). - /// Lifecycle: stream_begin inserts, stream_finish removes; insert_stream - /// bounds the map (STREAM_CACHE_MAX) as a safety net against aborted turns. - streams: tokio::sync::Mutex>, -} - -impl SlackAdapter { - pub fn new( - bot_token: String, - session_ttl: std::time::Duration, - _allow_bot_messages: AllowBots, - assistant_mode: bool, - multibot_cache: crate::multibot_cache::MultibotCache, - ) -> Self { - Self { - // Bound every Slack Web API call; an unbounded inline gating call in the - // read loop could otherwise stall the Socket Mode idle-timeout watchdog. - client: reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) - .build() - .unwrap_or_else(|_| reqwest::Client::new()), - bot_token, - bot_user_id: tokio::sync::OnceCell::new(), - user_cache: tokio::sync::Mutex::new(HashMap::new()), - bot_id_cache: tokio::sync::Mutex::new(HashMap::new()), - participated_threads: tokio::sync::Mutex::new(HashMap::new()), - multibot_threads: tokio::sync::Mutex::new(HashMap::new()), - multibot_cache, - session_ttl, - assistant_mode, - streams: tokio::sync::Mutex::new(HashMap::new()), - } - } - - /// Returns the bot token for use in API calls outside the adapter. - pub fn bot_token(&self) -> &str { - &self.bot_token - } - - /// Eagerly record that another bot has posted in a thread. Called from the - /// event loop when a bot message arrives, so multibot detection doesn't - /// depend on fetching thread history. Idempotent. 
- async fn note_other_bot_in_thread(&self, thread_ts: &str) { - { - let mut cache = self.multibot_threads.lock().await; - cache - .entry(thread_ts.to_string()) - .or_insert_with(tokio::time::Instant::now); - enforce_cache_bounds(&mut cache, self.session_ttl); - } - // Persist to disk — multibot is irreversible - self.multibot_cache.mark_multibot(thread_ts).await; - } - - - /// Insert a stream entry, bounding the map so aborted turns (begin without a - /// matching finish) can't leak unboundedly. Normal lifecycle: stream_begin - /// inserts, stream_finish removes. - async fn insert_stream(&self, ts: String, entry: StreamEntry) { - let mut map = self.streams.lock().await; - if map.len() >= STREAM_CACHE_MAX { - // Only evict inactive (degraded/stale) streams to avoid cutting off - // active streams mid-turn. If no inactive entries exist, fall through - // and allow the map to grow slightly beyond the soft cap. - let evict: Vec = map - .iter() - .filter(|(_, e)| !e.active) - .map(|(k, _)| k.clone()) - .collect(); - for k in evict { - map.remove(&k); - } - } - map.insert(ts, entry); - } - - /// Accumulate a delta into a degraded stream's buffer and return the new - /// cumulative text. Returns None if no (degraded) stream entry exists for - /// `ts` — never resurrects a removed/absent stream. No network I/O. - async fn accumulate_degraded(&self, ts: &str, delta: &str) -> Option { - let mut map = self.streams.lock().await; - let entry = map.get_mut(ts)?; - entry.degraded_buf.push_str(delta); - Some(entry.degraded_buf.clone()) - } - - /// Get the bot's own Slack user ID (cached after first call). 
- async fn get_bot_user_id(&self) -> Option<&str> { - self.bot_user_id - .get_or_try_init(|| async { - let resp = self - .api_post("auth.test", serde_json::json!({})) - .await - .map_err(|e| anyhow!("auth.test failed: {e}"))?; - resp["user_id"] - .as_str() - .map(|s| s.to_string()) - .ok_or_else(|| anyhow!("no user_id in auth.test response")) - }) - .await - .inspect_err(|e| warn!(error = %e, "bot user ID unavailable; mention detection may suppress bot messages under Mentions mode")) - .ok() - .map(|s| s.as_str()) - } - - async fn api_post(&self, method: &str, body: serde_json::Value) -> Result { - let resp = self - .client - .post(format!("{SLACK_API}/{method}")) - .header("Authorization", format!("Bearer {}", self.bot_token)) - .header("Content-Type", "application/json; charset=utf-8") - .json(&body) - .send() - .await?; - - let json: serde_json::Value = resp.json().await?; - if json["ok"].as_bool() != Some(true) { - let err = json["error"].as_str().unwrap_or("unknown error"); - return Err(anyhow!("Slack API {method}: {err}")); - } - Ok(json) - } - - /// Call a Slack API method using GET with query parameters. - /// Required for read methods like conversations.replies that don't accept JSON body. - async fn api_get(&self, method: &str, params: &[(&str, &str)]) -> Result { - let resp = self - .client - .get(format!("{SLACK_API}/{method}")) - .header("Authorization", format!("Bearer {}", self.bot_token)) - .query(params) - .send() - .await?; - - let json: serde_json::Value = resp.json().await?; - if json["ok"].as_bool() != Some(true) { - let err = json["error"].as_str().unwrap_or("unknown error"); - return Err(anyhow!("Slack API {method}: {err}")); - } - Ok(json) - } - - /// Resolve a Slack user ID to display name via users.info API. - /// Results are cached for 5 minutes to avoid hitting Slack rate limits. 
- async fn resolve_user_name(&self, user_id: &str) -> Option { - // Check cache first - { - let cache = self.user_cache.lock().await; - if let Some((name, ts)) = cache.get(user_id) { - if ts.elapsed() < USER_CACHE_TTL { - return Some(name.clone()); - } - } - } - - let resp = self - .api_post("users.info", serde_json::json!({ "user": user_id })) - .await - .ok()?; - let user = resp.get("user")?; - let profile = user.get("profile")?; - let display = profile - .get("display_name") - .and_then(|v| v.as_str()) - .filter(|s| !s.is_empty()); - let real = profile - .get("real_name") - .and_then(|v| v.as_str()) - .filter(|s| !s.is_empty()); - let name = user.get("name").and_then(|v| v.as_str()); - let resolved = display.or(real).or(name)?.to_string(); - - // Cache the result - self.user_cache.lock().await.insert( - user_id.to_string(), - (resolved.clone(), tokio::time::Instant::now()), - ); - - Some(resolved) - } - - /// Resolve a Bot ID (B...) to Bot User ID (U...) via bots.info API. - /// Cached permanently (bot IDs don't change). 
- async fn resolve_bot_user_id(&self, bot_id: &str) -> Option { - if bot_id.is_empty() { - return None; - } - - { - let cache = self.bot_id_cache.lock().await; - if let Some(user_id) = cache.get(bot_id) { - return Some(user_id.clone()); - } - } - - let resp = self - .api_post("bots.info", serde_json::json!({ "bot": bot_id })) - .await - .inspect_err(|e| { - warn!( - bot_id, - error = %e, - "failed to resolve Slack bot ID via bots.info" - ) - }) - .ok()?; - let user_id = resp.get("bot")?.get("user_id")?.as_str()?.to_string(); - - self.bot_id_cache - .lock() - .await - .insert(bot_id.to_string(), user_id.clone()); - - Some(user_id) - } - - async fn trusted_bot_ids_contains( - &self, - trusted_bot_ids: &HashSet, - event_bot_id: &str, - ) -> bool { - if trusted_bot_ids.is_empty() { - return true; - } - if bot_id_matches_trusted(trusted_bot_ids, event_bot_id, None) { - return true; - } - let resolved = self.resolve_bot_user_id(event_bot_id).await; - bot_id_matches_trusted(trusted_bot_ids, event_bot_id, resolved.as_deref()) - } - - /// Check whether the bot has participated in a Slack thread and whether - /// other bots have also posted in it. - /// Returns `(involved, other_bot_present)`. - /// Involved = parent message @mentions the bot OR any message in thread is from the bot. - /// Fail-closed: returns `(false, false)` on API error (consistent with Discord's approach). - /// Caches positive results only — both states are irreversible. 
- async fn bot_participated_in_thread(&self, channel: &str, thread_ts: &str) -> (bool, bool) { - let cached_involved = { - let cache = self.participated_threads.lock().await; - cache - .get(thread_ts) - .is_some_and(|ts| ts.elapsed() < self.session_ttl) - }; - let cached_multibot = { - let cache = self.multibot_threads.lock().await; - cache - .get(thread_ts) - .is_some_and(|ts| ts.elapsed() < self.session_ttl) - } || self.multibot_cache.is_multibot(thread_ts); - - // Eager multibot detection from message events populates the cache - // before this runs. When already involved and cached, skip the fetch. - if cached_involved { - return (true, cached_multibot); - } - - let bot_id = match self.get_bot_user_id().await { - Some(id) => id, - None => { - warn!("cannot resolve bot user ID, rejecting (fail-closed)"); - return (false, false); - } - }; - - let resp = self - .api_get( - "conversations.replies", - &[ - ("channel", channel), - ("ts", thread_ts), - ("limit", "200"), - ("inclusive", "true"), - ], - ) - .await; - - let json = match resp { - Ok(json) => json, - Err(e) => { - warn!(channel, thread_ts, error = %e, "failed to fetch thread replies, rejecting (fail-closed)"); - return (false, false); - } - }; - let Some(messages) = json["messages"].as_array() else { - return (false, false); - }; - - let parent_mentions_bot = messages - .first() - .and_then(|m| m["text"].as_str()) - .is_some_and(|text| text_mentions_uid(text, bot_id)); - - let bot_posted = messages.iter().any(|m| m["user"].as_str() == Some(bot_id)); - - let involved = parent_mentions_bot || bot_posted; - // other_bot_present relies solely on early detection + disk cache; - // no longer scanned from fetched messages (200-msg window was unreliable). - let other_bot_present = cached_multibot; - - if involved { - self.cache_participation(thread_ts).await; - } - - (involved, other_bot_present) - } - - /// Insert a positive participation entry, enforcing cache bounds. 
- async fn cache_participation(&self, thread_ts: &str) { - let mut cache = self.participated_threads.lock().await; - cache.insert(thread_ts.to_string(), tokio::time::Instant::now()); - enforce_cache_bounds(&mut cache, self.session_ttl); - } -} - -/// Shared eviction policy for positive-only caches. -/// First drops expired entries; if still over, drops the oldest half. -fn enforce_cache_bounds( - cache: &mut HashMap, - ttl: std::time::Duration, -) { - if cache.len() <= PARTICIPATION_CACHE_MAX { - return; - } - cache.retain(|_, ts| ts.elapsed() < ttl); - if cache.len() > PARTICIPATION_CACHE_MAX { - let mut entries: Vec<_> = cache.iter().map(|(k, v)| (k.clone(), *v)).collect(); - entries.sort_by_key(|(_, ts)| *ts); - let evict_count = entries.len() / 2; - for (key, _) in entries.into_iter().take(evict_count) { - cache.remove(&key); - } - } -} - -#[async_trait] -impl ChatAdapter for SlackAdapter { - fn platform(&self) -> &'static str { - "slack" - } - - fn message_limit(&self) -> usize { - // Match the Block Kit `markdown` block cap (12k) minus headroom. Messages - // are sent as markdown blocks, so the old 4000 mrkdwn-era limit would - // split long replies (and Markdown tables) across messages needlessly — - // a mid-table split renders as raw pipes. 11_900 keeps typical tables in - // one block and cuts message-spam on long replies. - MARKDOWN_BLOCK_LIMIT - } - - async fn send_message(&self, channel: &ChannelRef, content: &str) -> Result { - let thread_ts = channel.thread_id.as_deref(); - let body = build_post_message_body(&channel.channel_id, thread_ts, content); - let resp = match self.api_post("chat.postMessage", body).await { - Ok(r) => r, - // Graceful degradation: if the `blocks` payload is rejected (workspace - // lacks the markdown block, or content exceeds the cumulative block - // cap), retry text-only so the message still lands (mrkdwn fallback) - // instead of failing outright. 
- Err(e) if is_block_payload_rejected(&e) => { - warn!(error = %e, "markdown block rejected; retrying chat.postMessage text-only"); - let fallback = build_post_message_text_only(&channel.channel_id, thread_ts, content); - self.api_post("chat.postMessage", fallback).await? - } - Err(e) => return Err(e), - }; - let ts = resp["ts"] - .as_str() - .ok_or_else(|| anyhow!("no ts in chat.postMessage response"))?; - Ok(MessageRef { - channel: ChannelRef { - platform: "slack".into(), - channel_id: channel.channel_id.clone(), - thread_id: channel.thread_id.clone(), - parent_id: None, - origin_event_id: None, - }, - message_id: ts.to_string(), - }) - } - - async fn create_thread( - &self, - channel: &ChannelRef, - trigger_msg: &MessageRef, - _title: &str, - ) -> Result { - // Slack threads are implicit — posting with thread_ts creates/continues a thread. - Ok(ChannelRef { - platform: "slack".into(), - channel_id: channel.channel_id.clone(), - thread_id: Some(trigger_msg.message_id.clone()), - parent_id: None, - origin_event_id: None, - }) - } - - async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { - let name = unicode_to_slack_emoji(emoji); - match self - .api_post( - "reactions.add", - serde_json::json!({ - "channel": msg.channel.channel_id, - "timestamp": msg.message_id, - "name": name, - }), - ) - .await - { - Ok(_) => Ok(()), - Err(e) if e.to_string().contains("already_reacted") => Ok(()), - Err(e) => Err(e), - } - } - - async fn remove_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { - let name = unicode_to_slack_emoji(emoji); - match self - .api_post( - "reactions.remove", - serde_json::json!({ - "channel": msg.channel.channel_id, - "timestamp": msg.message_id, - "name": name, - }), - ) - .await - { - Ok(_) => Ok(()), - Err(e) if e.to_string().contains("no_reaction") => Ok(()), - Err(e) => Err(e), - } - } - - async fn edit_message(&self, msg: &MessageRef, content: &str) -> Result<()> { - let body = 
build_update_body(&msg.channel.channel_id, &msg.message_id, content); - match self.api_post("chat.update", body).await { - Ok(_) => Ok(()), - // See send_message: degrade to text-only if the blocks payload is rejected. - Err(e) if is_block_payload_rejected(&e) => { - warn!(error = %e, "markdown block rejected; retrying chat.update text-only"); - let fallback = - build_update_text_only(&msg.channel.channel_id, &msg.message_id, content); - self.api_post("chat.update", fallback).await?; - Ok(()) - } - Err(e) => Err(e), - } - } - - fn use_streaming(&self, other_bot_present: bool) -> bool { - !other_bot_present - } - - fn renders_native_tables(&self) -> bool { - true - } - - fn uses_assistant_status(&self) -> bool { - self.assistant_mode - } - - fn uses_native_streaming(&self, other_bot_present: bool) -> bool { - let native = self.assistant_mode && !other_bot_present; - debug!( - assistant_mode = self.assistant_mode, - other_bot_present, - native, - "slack assistant_mode decision (per turn)" - ); - native - } - - async fn stream_begin( - &self, - channel: &ChannelRef, - recipient: Option<(String, String)>, - ) -> Result { - let thread_ts = channel.thread_id.clone().unwrap_or_default(); - // recipient is bound to this turn (captured at message arrival, carried on - // BufferedMessage) — no shared thread cache, so no cross-turn race. 
- let make_ref = |ts: String| MessageRef { - channel: ChannelRef { - platform: "slack".into(), - channel_id: channel.channel_id.clone(), - thread_id: channel.thread_id.clone(), - parent_id: None, - origin_event_id: None, - }, - message_id: ts, - }; - - if let Some((user_id, team_id)) = recipient { - let body = build_start_stream_body(&channel.channel_id, &thread_ts, &user_id, &team_id); - match self.api_post("chat.startStream", body).await { - Ok(resp) => { - if let Some(ts) = resp["ts"].as_str() { - self.insert_stream( - ts.to_string(), - StreamEntry { active: true, degraded_buf: String::new() }, - ) - .await; - return Ok(make_ref(ts.to_string())); - } - error!("chat.startStream ok but no ts; falling back to post+edit"); - } - Err(e) => { - error!(error = %e, "chat.startStream failed; falling back to post+edit for this turn"); - } - } - } else { - // Expected for bot-authored turns (no recipient bound) and non-user - // triggers, so warn! rather than error! to avoid on-call noise. - warn!(thread_ts, "no recipient for turn; falling back to post+edit"); - } - - // Degraded fallback: plain placeholder via send_message; mark inactive. 
- let msg = self.send_message(channel, "…").await?; - self.insert_stream( - msg.message_id.clone(), - StreamEntry { active: false, degraded_buf: String::new() }, - ) - .await; - Ok(msg) - } - - async fn stream_append(&self, msg: &MessageRef, delta: &str) -> Result<()> { - let ts = &msg.message_id; - let active = { - let map = self.streams.lock().await; - map.get(ts).map(|e| e.active).unwrap_or(false) - }; - if active { - let body = build_append_stream_body(&msg.channel.channel_id, ts, delta); - if let Err(e) = self.api_post("chat.appendStream", body).await { - warn!(error = %e, "chat.appendStream failed (cosmetic; final replace will correct)"); - } - } else if let Some(cumulative) = self.accumulate_degraded(ts, delta).await { - let _ = self.edit_message(msg, &cumulative).await; // cosmetic mid-stream - } - Ok(()) - } - - async fn stream_finish(&self, msg: &MessageRef, final_content: &str) -> Result<()> { - let ts = &msg.message_id; - let active = { - let map = self.streams.lock().await; - map.get(ts).map(|e| e.active).unwrap_or(false) - }; - if active { - // Close the native stream WITHOUT re-sending content. The reply was - // already streamed live via chat.appendStream; stopStream's - // `markdown_text` *appends* (it does not replace), so passing the full - // content here duplicates the whole reply (#1055). Close only, then - // replace with the finalized content via chat.update below. - let close = serde_json::json!({ "channel": msg.channel.channel_id, "ts": ts }); - if let Err(e) = self.api_post("chat.stopStream", close).await { - warn!(error = %e, "chat.stopStream(close) failed; continuing to final replace"); - } - } - // Replace with the finalized content (Block Kit markdown). For the active - // path this overwrites the streamed preview with a single clean copy - // (rich rendering + native tables); for the degraded path it is the final - // post+edit update. chat.update replaces, so there is no duplication. 
- if let Err(e) = self.edit_message(msg, final_content).await { - if active { - // The native stream already delivered the reply (chat.appendStream), - // and stopStream left it in place. Do NOT postMessage a fallback - // here — that would post a duplicate copy. Keep the streamed - // content as the final message. - warn!(error = %e, "final chat.update failed; keeping streamed content (no duplicate post)"); - } else { - // Degraded path: no streamed content exists (post+edit placeholder), - // so post the final as a new message to avoid losing the reply. - warn!(error = %e, "final chat.update failed; trying postMessage"); - if let Err(e2) = self.send_message(&msg.channel, final_content).await { - error!(error = %e2, "final postMessage also failed; reply may be incomplete"); - } - } - } - self.streams.lock().await.remove(ts); - Ok(()) - } - - async fn set_status(&self, channel: &ChannelRef, status: &str) -> Result<()> { - let thread_ts = channel.thread_id.clone().unwrap_or_default(); - let body = build_set_status_body(&channel.channel_id, &thread_ts, status); - if let Err(e) = self.api_post("assistant.threads.setStatus", body).await { - warn!(error = %e, status, "assistant.threads.setStatus failed (cosmetic)"); - } - Ok(()) - } -} - -// --- Socket Mode event loop --- - -/// Hard cap on consecutive bot messages in a thread. Prevents runaway loops. -const MAX_CONSECUTIVE_BOT_TURNS: usize = 1000; - -/// Socket Mode keepalive. Slack's inbound WebSocket can go half-open (e.g. a NAT -/// idle-timeout silently drops inbound frames with no Close/FIN), which leaves -/// `read.next()` blocked forever, so the reconnect loop never fires and the bot -/// goes deaf while still showing as connected. We proactively ping and force a -/// reconnect when no inbound frame (including Slack's own pings) has arrived -/// within the idle window. Reconnect backoff mirrors the gateway adapter. 
-const PING_INTERVAL_SECS: u64 = 30; -const IDLE_TIMEOUT_SECS: u64 = 75; -const MAX_BACKOFF_SECS: u64 = 30; - -/// Next reconnect delay: double, capped. Reset to 1 on a successful connect. -fn next_backoff(cur: u64) -> u64 { - (cur * 2).min(MAX_BACKOFF_SECS) -} - -/// The socket is considered dead (half-open) when no inbound frame has arrived -/// within `timeout`; Slack sends periodic pings, so silence past the window -/// means the inbound path is gone. -fn socket_idle(since_last_inbound: std::time::Duration, timeout: std::time::Duration) -> bool { - since_last_inbound >= timeout -} - -/// Run the Slack adapter using Socket Mode (persistent WebSocket, no public URL needed). -/// Reconnects automatically on disconnect. -#[allow(clippy::too_many_arguments)] -pub async fn run_slack_adapter( - adapter: Arc, - app_token: String, - allow_all_channels: bool, - allow_all_users: bool, - allowed_channels: HashSet, - allowed_users: HashSet, - allow_bot_messages: AllowBots, - trusted_bot_ids: HashSet, - allow_user_messages: AllowUsers, - max_bot_turns: u32, - stt_config: SttConfig, - mut shutdown_rx: watch::Receiver, - dispatcher: Arc, -) -> Result<()> { - let bot_token = adapter.bot_token().to_string(); - let bot_turns = Arc::new(tokio::sync::Mutex::new(BotTurnTracker::new(max_bot_turns))); - // Warm the bot-user-id cache once so the per-message path never does the - // cold-cache `auth.test` inline in the read loop. - let _ = adapter.get_bot_user_id().await; - let mut backoff_secs = 1u64; - - loop { - // Check for shutdown before (re)connecting - if *shutdown_rx.borrow() { - info!("Slack adapter shutting down"); - return Ok(()); - } - - let ws_url = match get_socket_mode_url(&app_token).await { - Ok(url) => url, - Err(e) => { - error!(err = %e, backoff = backoff_secs, "failed to get Socket Mode URL, retrying"); - tokio::select! 
{ - _ = tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)) => {} - _ = shutdown_rx.changed() => { return Ok(()); } - } - backoff_secs = next_backoff(backoff_secs); - continue; - } - }; - info!(url = %ws_url, "connecting to Slack Socket Mode"); - - match tokio_tungstenite::connect_async(&ws_url).await { - Ok((ws_stream, _)) => { - info!("Slack Socket Mode connected"); - backoff_secs = 1; // reset on success - let (mut write, mut read) = ws_stream.split(); - let mut ping_interval = - tokio::time::interval(std::time::Duration::from_secs(PING_INTERVAL_SECS)); - ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - let mut last_inbound = std::time::Instant::now(); - - loop { - tokio::select! { - msg_result = read.next() => { - last_inbound = std::time::Instant::now(); - let Some(msg_result) = msg_result else { break }; - match msg_result { - Ok(tungstenite::Message::Text(text)) => { - let envelope: serde_json::Value = - match serde_json::from_str(&text) { - Ok(v) => v, - Err(_) => continue, - }; - - // Acknowledge the envelope immediately - if let Some(envelope_id) = envelope["envelope_id"].as_str() { - let ack = serde_json::json!({"envelope_id": envelope_id}); - let _ = write - .send(tungstenite::Message::Text(ack.to_string())) - .await; - } - - // Slash commands and interactive block_actions aren't - // handled on Slack: slash commands are blocked by Slack - // in thread composers, and the channel-level delivery - // lacks the thread_ts needed to route to a session. - // Ack only; ignore payload. 
- match envelope["type"].as_str() { - Some("slash_commands") | Some("interactive") => { - debug!( - envelope_type = envelope["type"].as_str().unwrap_or(""), - "ignoring Slack envelope type (not supported on this adapter)" - ); - continue; - } - _ => {} - } - - // Route events - if envelope["type"].as_str() == Some("events_api") { - let event = &envelope["payload"]["event"]; - let event_type = event["type"].as_str().unwrap_or(""); - match event_type { - "app_mention" => { - // Apply bot gating for app_mention events (same rules as message events) - let is_bot = event["bot_id"].is_string() - || event["subtype"].as_str() == Some("bot_message"); - if is_bot { - match allow_bot_messages { - AllowBots::Off => { continue; } - AllowBots::Mentions | AllowBots::All => { - if !trusted_bot_ids.is_empty() { - let event_bot_id = event["bot_id"].as_str().unwrap_or(""); - let is_trusted = adapter - .trusted_bot_ids_contains(&trusted_bot_ids, event_bot_id) - .await; - if !is_trusted { - debug!(event_bot_id, "bot not in trusted_bot_ids, ignoring app_mention"); - continue; - } - } - } - } - } - let event = event.clone(); - let adapter = adapter.clone(); - let bot_token = bot_token.clone(); - let allowed_channels = allowed_channels.clone(); - let allowed_users = allowed_users.clone(); - let stt_config = stt_config.clone(); - let dispatcher = dispatcher.clone(); - let team_id = envelope["payload"]["team_id"] - .as_str() - .unwrap_or("") - .to_string(); - tokio::spawn(async move { - handle_message( - &event, - &team_id, - &adapter, - &bot_token, - allow_all_channels, - allow_all_users, - &allowed_channels, - &allowed_users, - &stt_config, - &dispatcher, - ) - .await; - }); - } - "message" => { - let channel_id = event["channel"].as_str().unwrap_or(""); - let has_thread = event["thread_ts"].is_string(); - let is_bot = event["bot_id"].is_string() - || event["subtype"].as_str() == Some("bot_message"); - let subtype = event["subtype"].as_str().unwrap_or(""); - let msg_text = 
event["text"].as_str().unwrap_or(""); - let bot_uid_opt = adapter.get_bot_user_id().await.map(|s| s.to_string()); - let mentions_bot = bot_uid_opt - .as_ref() - .is_some_and(|bot_uid| text_mentions_uid(msg_text, bot_uid)); - let is_dm = channel_id.starts_with('D'); - let event_user_id = event["user"].as_str(); - let is_own_bot_msg = is_bot - && bot_uid_opt.as_deref().is_some() - && event_user_id == bot_uid_opt.as_deref(); - - debug!( - channel_id, - has_thread, - is_bot, - is_dm, - subtype, - mentions_bot, - text = msg_text, - "message event received" - ); - - // Skip non-message subtypes - let skip_subtype = matches!(subtype, - "message_changed" | "message_deleted" | - "channel_join" | "channel_leave" | - "channel_topic" | "channel_purpose" - ); - if skip_subtype { continue; } - - // --- Eager multibot detection --- - // Runs before self-check and bot gating so we always detect - // other bots even when allow_bot_messages=Off filters them out. - // Matches Discord #481 ordering. - if is_bot && !is_own_bot_msg { - if let Some(thread_ts) = event["thread_ts"].as_str() { - adapter.note_other_bot_in_thread(thread_ts).await; - } - } - - // --- Bot turn tracking --- - // Runs before self-check so ALL bot messages (including own) - // count toward the per-thread limit. Matches Discord #483. - // Keyed on thread_ts when in a thread, else channel:ts. - // Non-thread messages get a unique key per message, so the - // counter never accumulates — intentional, because bot-to-bot - // loops only happen inside threads. - let turn_key = if let Some(thread_ts) = event["thread_ts"].as_str() { - thread_ts.to_string() - } else { - format!("{}:{}", channel_id, event["ts"].as_str().unwrap_or("")) - }; - // Classify under the lock (order-sensitive, kept in the read - // loop), but run any warning send AFTER releasing it; holding - // the tracker mutex across `chat.postMessage` would stall turn - // tracking for every thread, not just this one. 
- let turn_action = { - let mut tracker = bot_turns.lock().await; - if is_bot { - tracker.classify_bot_message(&turn_key) - } else { - if is_plain_user_message(subtype, msg_text) { - tracker.on_human_message(&turn_key); - } - TurnAction::Continue - } - }; - match turn_action { - TurnAction::Continue => {} - TurnAction::SilentStop => continue, - TurnAction::WarnAndStop { severity, turns, user_message } => { - match severity { - TurnSeverity::Hard => warn!(channel_id, turns, "hard bot turn limit reached"), - TurnSeverity::Soft => info!(channel_id, turns, max = max_bot_turns, "soft bot turn limit reached"), - } - let channel_allowed = allow_all_channels - || allowed_channels.contains(channel_id); - if !is_own_bot_msg && channel_allowed { - let warn_channel = ChannelRef { - platform: "slack".into(), - channel_id: channel_id.to_string(), - thread_id: event["thread_ts"].as_str().map(|s| s.to_string()), - parent_id: None, - origin_event_id: None, - }; - let adapter = adapter.clone(); - tokio::spawn(async move { - if let Err(e) = adapter.send_message(&warn_channel, &user_message).await { - warn!(error = %e, "failed to send bot turn limit warning"); - } - }); - } - continue; - } - } - - // Ignore own bot messages (after counting toward turns) - if is_own_bot_msg { continue; } - - // Skip messages that @mention the bot — app_mention handles those - // (except in DMs where app_mention doesn't fire) - if mentions_bot && !is_dm { continue; } - - // --- Bot message gating --- - if is_bot { - let event_bot_id = event["bot_id"].as_str().unwrap_or(""); - match allow_bot_messages { - AllowBots::Off => { continue; } - AllowBots::Mentions => { - if !mentions_bot { continue; } - } - AllowBots::All => { - // Loop protection: count consecutive bot msgs (fail-closed) - if let Some(thread_ts) = event["thread_ts"].as_str() { - let cap = MAX_CONSECUTIVE_BOT_TURNS; - let limit_str = std::cmp::min(cap + 1, 1000).to_string(); - match adapter.api_get( - "conversations.replies", - &[ - 
("channel", channel_id), - ("ts", thread_ts), - ("limit", &limit_str), - ("inclusive", "true"), - ], - ).await { - Ok(resp) => { - if let Some(msgs) = resp["messages"].as_array() { - let consecutive = msgs.iter().rev() - .take_while(|m| { - m["bot_id"].is_string() - || m["subtype"].as_str() == Some("bot_message") - }) - .count(); - if consecutive >= cap { - warn!(channel_id, cap, "bot turn cap reached, ignoring"); - continue; - } - } - } - Err(e) => { - warn!(channel_id, thread_ts, error = %e, "failed to fetch thread for bot loop check, rejecting (fail-closed)"); - continue; - } - } - } - } - } - // Check trusted_bot_ids - if !trusted_bot_ids.is_empty() { - let is_trusted = adapter - .trusted_bot_ids_contains(&trusted_bot_ids, event_bot_id) - .await; - if !is_trusted { - debug!(event_bot_id, "bot not in trusted_bot_ids, ignoring"); - continue; - } - } - // Bot messages must be in a thread (no top-level bot processing) - if !has_thread { continue; } - } - - // --- User message gating --- - if !is_bot { - if is_dm { - // DM: implicit mention — always process - } else { - match allow_user_messages { - AllowUsers::Mentions => { - if !mentions_bot { continue; } - } - AllowUsers::Involved => { - if !has_thread { - continue; - } - let thread_ts = event["thread_ts"].as_str().unwrap_or(""); - let (involved, _) = adapter - .bot_participated_in_thread(channel_id, thread_ts) - .await; - if !involved { - debug!(channel_id, thread_ts, "bot not involved in thread, ignoring"); - continue; - } - } - AllowUsers::MultibotMentions => { - if !has_thread { - continue; - } - let thread_ts = event["thread_ts"].as_str().unwrap_or(""); - let (involved, other_bot) = adapter - .bot_participated_in_thread(channel_id, thread_ts) - .await; - if !involved { - debug!(channel_id, thread_ts, "bot not involved in thread, ignoring"); - continue; - } - // In multi-bot threads, require @mention — mirrors - // Discord's `should_process_user_message`. 
In practice - // mention-bearing message events are already deduped - // earlier (app_mention handles the @-path), so this - // branch rarely sees `mentions_bot == true`, but keep - // the explicit check so the logic is self-consistent - // and survives changes to the earlier dedup. - if other_bot && !mentions_bot { - debug!(channel_id, thread_ts, "multi-bot thread without @mention, ignoring"); - continue; - } - } - } - } - } - - // Dispatch to handle_message (per-thread serialization comes - // from Dispatcher consumer task in batched mode and from - // pool.with_connection in per-message mode). - let team_id = envelope["payload"]["team_id"] - .as_str() - .unwrap_or("") - .to_string(); - let event = event.clone(); - let adapter = adapter.clone(); - let bot_token = bot_token.clone(); - let allowed_channels = allowed_channels.clone(); - let allowed_users = allowed_users.clone(); - let stt_config = stt_config.clone(); - let dispatcher = dispatcher.clone(); - tokio::spawn(async move { - handle_message( - &event, - &team_id, - &adapter, - &bot_token, - allow_all_channels, - allow_all_users, - &allowed_channels, - &allowed_users, - &stt_config, - &dispatcher, - ) - .await; - }); - } - _ => {} - } - } - } - Ok(tungstenite::Message::Ping(data)) => { - let _ = write.send(tungstenite::Message::Pong(data)).await; - } - Ok(tungstenite::Message::Close(_)) => { - warn!("Slack Socket Mode connection closed by server"); - break; - } - Err(e) => { - error!("Socket Mode read error: {e}"); - break; - } - _ => {} - } - } - _ = ping_interval.tick() => { - if socket_idle( - last_inbound.elapsed(), - std::time::Duration::from_secs(IDLE_TIMEOUT_SECS), - ) { - warn!( - idle_secs = last_inbound.elapsed().as_secs(), - "Slack Socket Mode idle past timeout (likely half-open), forcing reconnect" - ); - break; - } - if let Err(e) = write.send(tungstenite::Message::Ping(Vec::new())).await { - warn!(error = %e, "Slack Socket Mode ping failed, reconnecting"); - break; - } - } - _ = 
shutdown_rx.changed() => { - info!("Slack adapter received shutdown signal"); - let _ = write.send(tungstenite::Message::Close(None)).await; - return Ok(()); - } - } - } - } - Err(e) => { - error!(err = %e, backoff = backoff_secs, "failed to connect to Slack Socket Mode, retrying"); - } - } - - warn!(backoff = backoff_secs, "reconnecting to Slack Socket Mode"); - tokio::select! { - _ = tokio::time::sleep(std::time::Duration::from_secs(backoff_secs)) => {} - _ = shutdown_rx.changed() => { return Ok(()); } - } - backoff_secs = next_backoff(backoff_secs); - } -} - -/// Call apps.connections.open to get a WebSocket URL for Socket Mode. -async fn get_socket_mode_url(app_token: &str) -> Result { - let client = reqwest::Client::new(); - let resp = client - .post(format!("{SLACK_API}/apps.connections.open")) - .header("Authorization", format!("Bearer {app_token}")) - .header("Content-Type", "application/x-www-form-urlencoded") - .send() - .await?; - let json: serde_json::Value = resp.json().await?; - if json["ok"].as_bool() != Some(true) { - let err = json["error"].as_str().unwrap_or("unknown"); - return Err(anyhow!("apps.connections.open: {err}")); - } - json["url"] - .as_str() - .map(|s| s.to_string()) - .ok_or_else(|| anyhow!("no url in apps.connections.open response")) -} - -#[allow(clippy::too_many_arguments)] -async fn handle_message( - event: &serde_json::Value, - team_id: &str, - adapter: &Arc, - bot_token: &str, - allow_all_channels: bool, - allow_all_users: bool, - allowed_channels: &HashSet, - allowed_users: &HashSet, - stt_config: &SttConfig, - dispatcher: &Arc, -) { - let channel_id = match event["channel"].as_str() { - Some(ch) => ch.to_string(), - None => return, - }; - // Bot messages may lack "user" field — fall back to "bot_id" as sender identifier - let user_id = match event["user"].as_str().or_else(|| event["bot_id"].as_str()) { - Some(u) => u.to_string(), - None => return, - }; - let is_bot_msg = - event["bot_id"].is_string() || 
event["subtype"].as_str() == Some("bot_message"); - let text = match event["text"].as_str() { - Some(t) => t.to_string(), - None => return, - }; - let ts = match event["ts"].as_str() { - Some(ts) => ts.to_string(), - None => return, - }; - let thread_ts = event["thread_ts"].as_str().map(|s| s.to_string()); - - // Check allowed channels - if !allow_all_channels && !allowed_channels.contains(&channel_id) { - return; - } - - // Check allowed users — skip for bot messages (they go through trusted_bot_ids instead) - if !is_bot_msg && !allow_all_users && !allowed_users.contains(&user_id) { - tracing::info!(user_id, "denied Slack user, ignoring"); - let msg_ref = MessageRef { - channel: ChannelRef { - platform: "slack".into(), - channel_id: channel_id.clone(), - thread_id: thread_ts.clone(), - parent_id: None, - origin_event_id: None, - }, - message_id: ts.clone(), - }; - let _ = adapter.add_reaction(&msg_ref, "🚫").await; - return; - } - - // Capture the native-streaming recipient for THIS turn, now that the sender has - // passed the channel + user allow-list checks above (so denied/unauthorized - // senders are never recorded). It rides on the per-turn BufferedMessage to - // stream_begin — no shared thread cache, no cross-turn race. Real users only: - // bot IDs (B...) are rejected by chat.startStream's recipient_user_id, and an - // empty team_id would silently degrade, so we surface that. - let stream_recipient = if is_bot_msg { - None - } else { - if team_id.is_empty() { - warn!("empty team_id; chat.startStream will degrade to post+edit"); - } - Some((user_id.clone(), team_id.to_string())) - }; - - // Resolve mentions: strip only this bot's own trigger mention so the LLM - // can still @-mention other users in its reply. 
- let bot_id = adapter.get_bot_user_id().await; - let prompt = resolve_slack_mentions(&text, bot_id); - - // Process file attachments (images, audio) - let files = event["files"].as_array(); - let has_files = files.is_some_and(|f| !f.is_empty()); - - if prompt.is_empty() && !has_files { - return; - } - - // Caps mirror Discord's text-file attachment flow (PR #291) so both - // adapters apply the same limits: 5 files or 1 MB of text per message. - const TEXT_TOTAL_CAP: u64 = 1024 * 1024; - const TEXT_FILE_COUNT_CAP: u32 = 5; - - let mut extra_blocks = Vec::new(); - let mut echo_entries: Vec = Vec::new(); - let mut text_file_bytes: u64 = 0; - let mut text_file_count: u32 = 0; - let mut failed_image_files: Vec = Vec::new(); - - if let Some(files) = files { - for file in files { - let mimetype_raw = file["mimetype"].as_str().unwrap_or(""); - let mimetype = strip_mime_params(mimetype_raw); - let filename = file["name"].as_str().unwrap_or("file"); - let size = file["size"].as_u64().unwrap_or(0); - // Slack private files require Bearer token to download - let url = slack_file_download_url(file); - - if url.is_empty() { - continue; - } - - if media::is_audio_mime(mimetype) { - if stt_config.enabled { - match media::download_and_transcribe( - url, - filename, - mimetype, - size, - stt_config, - Some(bot_token), - ) - .await - { - Some(transcript) => { - debug!( - filename, - chars = transcript.len(), - "voice transcript injected" - ); - extra_blocks.insert( - 0, - ContentBlock::Text { - text: format!("[Voice message transcript]: {transcript}"), - }, - ); - echo_entries.push(crate::stt::EchoEntry::Success(transcript)); - } - None => { - warn!(filename, "STT failed for voice attachment"); - echo_entries.push(crate::stt::EchoEntry::Failed); - } - } - } else { - debug!(filename, "skipping audio attachment (STT disabled)"); - let msg_ref = MessageRef { - channel: ChannelRef { - platform: "slack".into(), - channel_id: channel_id.clone(), - thread_id: thread_ts.clone(), - 
parent_id: None, - origin_event_id: None, - }, - message_id: ts.clone(), - }; - let _ = adapter.add_reaction(&msg_ref, "🎤").await; - } - } else if media::is_text_file(filename, Some(mimetype)) { - if text_file_count >= TEXT_FILE_COUNT_CAP { - debug!( - filename, - count = text_file_count, - "text file count cap reached, skipping" - ); - continue; - } - // Pre-check with Slack-reported size as a fast path when the - // field is populated. Slack can report `size == 0` for - // externally-backed files, so this is advisory only — the - // authoritative cap check happens after download using - // `actual_bytes`. - if size > 0 && text_file_bytes + size > TEXT_TOTAL_CAP { - debug!( - filename, - total = text_file_bytes, - "text attachments total exceeds 1MB cap, skipping remaining" - ); - continue; - } - if let Some((block, actual_bytes)) = - media::download_and_read_text_file(url, filename, size, Some(bot_token)).await - { - if text_file_bytes + actual_bytes > TEXT_TOTAL_CAP { - debug!( - filename, - running = text_file_bytes, - actual = actual_bytes, - "text attachments total exceeds 1MB cap after download, dropping file", - ); - continue; - } - text_file_bytes += actual_bytes; - text_file_count += 1; - debug!(filename, "adding text file attachment"); - extra_blocks.push(block); - } - } else { - match media::download_and_encode_image( - url, - Some(mimetype), - filename, - size, - Some(bot_token), - ) - .await - { - Ok(block) => { - debug!(filename, "adding image attachment"); - extra_blocks.push(block); - } - Err(media::MediaFetchError::NotAnImage) => {} - Err(media::MediaFetchError::SizeExceeded { actual, limit }) => { - warn!(filename, actual, limit, "image exceeds size limit"); - failed_image_files.push(filename.to_string()); - } - Err( - media::MediaFetchError::UnsupportedResponseType { .. } - | media::MediaFetchError::InvalidImageBody { .. 
}, - ) => { - warn!( - filename, - "image validation failed; server may have returned non-image content" - ); - failed_image_files.push(filename.to_string()); - } - Err(media::MediaFetchError::ProcessingFailed(ref e)) => { - warn!(filename, error = %e, "image post-processing failed"); - failed_image_files.push(filename.to_string()); - } - Err(media::MediaFetchError::HttpStatus(status)) - if status.is_client_error() => - { - warn!(filename, %status, "image download denied"); - failed_image_files.push(filename.to_string()); - } - Err(e) => { - warn!(filename, error = %e, "image download failed"); - failed_image_files.push(filename.to_string()); - } - } - } - } - } - - // Notify user if any images couldn't be processed. - if !failed_image_files.is_empty() { - let warn_channel = ChannelRef { - platform: "slack".into(), - channel_id: channel_id.clone(), - thread_id: thread_ts.clone().or_else(|| Some(ts.clone())), - parent_id: None, - origin_event_id: None, - }; - let file_list = failed_image_files - .iter() - .map(|n| sanitize_slack_filename(n)) - .collect::>() - .join("`, `"); - let msg = format!( - ":warning: I couldn't process the file(s) you shared (`{file_list}`). \ - This can happen when the bot lacks the `files:read` OAuth scope, \ - the file format isn't supported (PNG/JPEG/GIF/WebP only), \ - or the file is too large." 
- ); - if let Err(e) = adapter.send_message(&warn_channel, &msg).await { - warn!(error = %e, "failed to send image validation warning to user"); - } - } - - // Resolve Slack display name (best-effort, fallback to user_id) - let display_name = adapter - .resolve_user_name(&user_id) - .await - .unwrap_or_else(|| user_id.clone()); - - let sender = SenderContext { - schema: "openab.sender.v1".into(), - sender_id: user_id.clone(), - sender_name: display_name.clone(), - display_name, - channel: "slack".into(), - channel_id: channel_id.clone(), - thread_id: thread_ts.clone(), - is_bot: is_bot_msg, - timestamp: Some(crate::timestamp::slack_ts_to_iso8601(&ts)), - message_id: Some(ts.clone()), - receiver_id: bot_id.map(|id| id.to_string()), - }; - - let trigger_msg = MessageRef { - channel: ChannelRef { - platform: "slack".into(), - channel_id: channel_id.clone(), - thread_id: thread_ts.clone(), - parent_id: None, - origin_event_id: None, - }, - message_id: ts.clone(), - }; - - // Determine thread: if already in a thread, continue it; otherwise start a new thread - let thread_channel = ChannelRef { - platform: "slack".into(), - channel_id: channel_id.clone(), - thread_id: Some(thread_ts.unwrap_or(ts)), - parent_id: None, - origin_event_id: None, - }; - - // Serialize sender context with Slack-native key names so agents calling - // the Slack API directly see "thread_ts" rather than the generic "thread_id". 
- let sender_json = { - let mut v = serde_json::to_value(&sender).unwrap(); - if let Some(obj) = v.as_object_mut() { - if let Some(tid) = obj.remove("thread_id") { - obj.insert("thread_ts".to_string(), tid); - } - } - v.to_string() - }; - - let adapter_dyn: Arc = adapter.clone(); - let other_bot_present = { - let cache = adapter.multibot_threads.lock().await; - thread_channel.thread_id.as_deref().is_some_and(|ts| { - cache - .get(ts) - .is_some_and(|inst| inst.elapsed() < adapter.session_ttl) - }) - } || thread_channel - .thread_id - .as_deref() - .is_some_and(|ts| adapter.multibot_cache.is_multibot(ts)); - - // Best-effort echo before the agent reply so the user can verify STT. - crate::stt::post_echo( - &adapter_dyn, - &thread_channel, - &trigger_msg, - &echo_entries, - stt_config, - ) - .await; - - let thread_id = thread_channel - .thread_id - .as_deref() - .unwrap_or(&thread_channel.channel_id); - let thread_key = dispatcher.key("slack", thread_id, &sender.sender_id); - let estimated_tokens = crate::dispatch::estimate_tokens(&prompt, &extra_blocks); - let buf_msg = crate::dispatch::BufferedMessage { - sender_json, - sender_name: sender.sender_name.clone(), - prompt, - extra_blocks, - trigger_msg, - arrived_at: std::time::Instant::now(), - estimated_tokens, - other_bot_present, - recipient: stream_recipient, - }; - if let Err(e) = dispatcher - .submit(thread_key, thread_channel, adapter_dyn, buf_msg) - .await - { - error!("Slack dispatcher submit error: {e}"); - } -} - -/// Strip all occurrences of the bot's own `<@BOT_UID>` or `<@BOT_UID|handle>` mention. -/// Other users' mentions stay intact so the LLM can @-mention them back. -/// If the bot UID isn't known, fall back to returning the text trimmed — -/// safer than stripping all mentions and losing user addressability. 
-fn resolve_slack_mentions(text: &str, bot_id: Option<&str>) -> String { - let Some(id) = bot_id else { - return text.trim().to_string(); - }; - let prefix = format!("<@{id}"); - let mut out = String::with_capacity(text.len()); - let mut s = text; - while let Some(pos) = s.find(&prefix) { - let after = &s[pos + prefix.len()..]; - match after.as_bytes().first() { - Some(b'>') => { - out.push_str(&s[..pos]); - s = &after[1..]; - } - Some(b'|') => { - if let Some(close) = after.find('>') { - out.push_str(&s[..pos]); - s = &after[close + 1..]; - } else { - out.push_str(&s[..pos + prefix.len()]); - s = after; - } - } - _ => { - out.push_str(&s[..pos + prefix.len()]); - s = after; - } - } - } - out.push_str(s); - out.trim().to_string() -} - -/// Pick the best download URL for a Slack file object. `url_private_download` -/// streams the raw bytes; `url_private` is the fallback for older file shapes. -/// Returns `""` when neither is present (caller should skip the file). -fn slack_file_download_url(file: &serde_json::Value) -> &str { - file["url_private_download"] - .as_str() - .or_else(|| file["url_private"].as_str()) - .unwrap_or("") -} - -/// Strip MIME parameters so type-detection helpers see the bare media type. -/// Delegates to media::strip_mime_params (single source of truth). -/// Needed because Slack occasionally sends `text/plain; charset=utf-8` and -/// `media::is_text_file` expects the bare form. -fn strip_mime_params(mimetype: &str) -> &str { - media::strip_mime_params(mimetype) -} - -/// Sanitize a filename for safe embedding in a Slack mrkdwn message. -/// -/// Ampersands (`&`), backticks (`` ` ``), and angle brackets (`<`, `>`) are escaped. -/// `&` is encoded as `&` first because Slack decodes HTML entities before parsing -/// mrkdwn — a filename like `<@here>` would otherwise round-trip back to -/// `<@here>` and trigger a mention ping. 
Backticks and angle brackets are Slack -/// mrkdwn delimiters; without escaping, `` or `` `<@U123>` `` would render -/// as mentions or @-here pings. -pub(crate) fn sanitize_slack_filename(s: &str) -> String { - s.replace('&', "&").replace('`', "'").replace('<', "(").replace('>', ")") -} - -/// Returns `true` if `text` contains a Slack user mention for `uid`. -/// -/// Accepts both `<@U...>` (bare) and `<@U...|handle>` (labelled) wire forms. -/// Slack (and bots addressing peers) can emit the labelled form; `<@UID>` is -/// not a substring of `<@UID|handle>`, so a bare `contains("<@UID>")` silently -/// misses it. -fn text_mentions_uid(text: &str, uid: &str) -> bool { - let prefix = format!("<@{uid}"); - text.match_indices(&prefix) - .any(|(i, _)| matches!(text.as_bytes().get(i + prefix.len()), Some(b'>') | Some(b'|'))) -} - -fn bot_id_matches_trusted( - trusted_bot_ids: &HashSet, - event_bot_id: &str, - resolved_user_id: Option<&str>, -) -> bool { - if event_bot_id.is_empty() { - return false; - } - - trusted_bot_ids.contains(event_bot_id) - || resolved_user_id.is_some_and(|uid| trusted_bot_ids.contains(uid)) -} - -/// True only when a Slack non-bot event represents a real user message -/// that should reset the bot-turn counter. -/// -/// Many Slack subtypes (pinned_item, channel_name, channel_archive, -/// group_join / group_leave / group_topic / group_purpose, reminder_add, -/// tombstone, …) carry a `user` field so the event loop sees -/// `is_bot == false`, but they represent administrative/system actions, -/// not conversation. Resetting the counter on them would let runaway -/// bot-to-bot loops re-arm whenever any pin / rename / archive happens. -/// -/// Mirrors Discord's `MessageType::Regular | InlineReply` + non-empty -/// content gate in `src/discord.rs`. Regression parity for -/// openabdev/openab#497. 
-fn is_plain_user_message(subtype: &str, text: &str) -> bool { - if text.is_empty() { - return false; - } - matches!( - subtype, - "" | "me_message" | "thread_broadcast" | "file_share", - ) -} - -/// Slack caps a single Block Kit `markdown` block at 12,000 characters; we use -/// 11,900 to keep ~100 chars of headroom. Doubles as the Slack `message_limit` -/// so the router splits long replies into separate messages at the same bound -/// (one markdown block per message stays under the API cap). -const MARKDOWN_BLOCK_LIMIT: usize = 11_900; - -/// True if a Slack API error indicates the `blocks` payload was rejected, so the -/// caller should retry text-only: -/// - `invalid_blocks` — workspace can't render the Block Kit `markdown` block -/// (malformed/unsupported payload). -/// - `msg_blocks_too_long` — content exceeds Slack's cumulative ~12k cap across -/// all `markdown` blocks in one message. Reachable by direct `send_message` -/// callers that bypass the router's `message_limit` pre-split (e.g. STT echo). -/// -/// `invalid_arguments` is deliberately excluded — it's a Slack catch-all (bad -/// channel, missing/invalid `ts`, malformed `thread_ts`, …) and would trigger a -/// pointless text-only retry that fails identically. -/// -/// Matches the Slack error *code* exactly (the trailing token of `api_post`'s -/// `"Slack API : "` message), not a substring of the message — -/// so a future code like `invalid_blocks_field` does not falsely match. -fn is_block_payload_rejected(e: &anyhow::Error) -> bool { - let s = e.to_string(); - let code = s.rsplit(": ").next().unwrap_or(s.as_str()).trim(); - code == "invalid_blocks" || code == "msg_blocks_too_long" -} - -/// Build Block Kit `markdown` blocks from raw Markdown. Slack renders these -/// natively — real headings, lists, tables, blockquotes, and language-tagged -/// code fences — unlike the legacy `text` mrkdwn field, which flattens headings -/// to bold and cannot render tables. 
Long content is split at the block limit, -/// reusing `format::split_message` so code-fence balance is preserved. -/// -/// Follow-up (non-blocking): `split_message` is not table-aware — a single -/// Markdown table exceeding `MARKDOWN_BLOCK_LIMIT` (11,900 chars) splits at line -/// boundaries, so continuation blocks lack the header/separator rows and render -/// as raw pipes. The 4000→11,900 bump makes this rare; a future improvement is -/// to re-emit the table header at the top of each continuation chunk. -fn build_markdown_blocks(content: &str) -> Vec { - let chunks = if content.len() <= MARKDOWN_BLOCK_LIMIT { - vec![content.to_string()] - } else { - crate::format::split_message(content, MARKDOWN_BLOCK_LIMIT) - }; - chunks - .into_iter() - .map(|chunk| serde_json::json!({ "type": "markdown", "text": chunk })) - .collect() -} - -/// Body for `chat.postMessage`: Block Kit `markdown` blocks (rich rendering) -/// plus a `text` fallback used for notifications and accessibility. -fn build_post_message_body( - channel_id: &str, - thread_ts: Option<&str>, - content: &str, -) -> serde_json::Value { - let mut body = serde_json::json!({ - "channel": channel_id, - "blocks": build_markdown_blocks(content), - "text": markdown_to_mrkdwn(content), - }); - if let Some(ts) = thread_ts { - body["thread_ts"] = serde_json::Value::String(ts.to_string()); - } - body -} - -/// Body for `chat.update`: same Block Kit `markdown` blocks + `text` fallback. -fn build_update_body(channel_id: &str, ts: &str, content: &str) -> serde_json::Value { - serde_json::json!({ - "channel": channel_id, - "ts": ts, - "blocks": build_markdown_blocks(content), - "text": markdown_to_mrkdwn(content), - }) -} - -/// Text-only `chat.postMessage` body (no `blocks`) — degradation path when a -/// workspace rejects the Block Kit `markdown` block. 
-fn build_post_message_text_only( - channel_id: &str, - thread_ts: Option<&str>, - content: &str, -) -> serde_json::Value { - let mut body = serde_json::json!({ - "channel": channel_id, - "text": markdown_to_mrkdwn(content), - }); - if let Some(ts) = thread_ts { - body["thread_ts"] = serde_json::Value::String(ts.to_string()); - } - body -} - -/// Text-only `chat.update` body (no `blocks`) — see `build_post_message_text_only`. -fn build_update_text_only(channel_id: &str, ts: &str, content: &str) -> serde_json::Value { - serde_json::json!({ - "channel": channel_id, - "ts": ts, - "text": markdown_to_mrkdwn(content), - }) -} - -/// Convert Markdown (as output by Claude Code) to Slack mrkdwn format. -/// Used for the `text` fallback field that accompanies Block Kit blocks -/// (shown in notification previews and to assistive tech). -fn markdown_to_mrkdwn(text: &str) -> String { - static BOLD_RE: LazyLock = - LazyLock::new(|| regex::Regex::new(r"\*\*(.+?)\*\*").unwrap()); - static ITALIC_RE: LazyLock = - LazyLock::new(|| regex::Regex::new(r"\*([^*]+?)\*").unwrap()); - static LINK_RE: LazyLock = - LazyLock::new(|| regex::Regex::new(r"\[([^\]]+)\]\(([^)]+)\)").unwrap()); - static HEADING_RE: LazyLock = - LazyLock::new(|| regex::Regex::new(r"(?m)^#{1,6}\s+(.+)$").unwrap()); - static CODE_BLOCK_LANG_RE: LazyLock = - LazyLock::new(|| regex::Regex::new(r"```\w+\n").unwrap()); - - // Order: bold first (** → placeholder), then italic (* → _), then restore bold - let text = BOLD_RE.replace_all(text, "\x01$1\x02"); // **bold** → \x01bold\x02 - let text = ITALIC_RE.replace_all(&text, "_${1}_"); // *italic* → _italic_ - // Restore bold: \x01bold\x02 → *bold* - let text = text.replace(['\x01', '\x02'], "*"); - let text = LINK_RE.replace_all(&text, "<$2|$1>"); // [text](url) → - let text = HEADING_RE.replace_all(&text, "*$1*"); // # heading → *heading* - let text = CODE_BLOCK_LANG_RE.replace_all(&text, "```\n"); // ```rust → ``` - text.into_owned() -} - -fn 
build_start_stream_body(channel: &str, thread_ts: &str, user_id: &str, team_id: &str) -> serde_json::Value { - serde_json::json!({ - "channel": channel, - "thread_ts": thread_ts, - "recipient_user_id": user_id, - "recipient_team_id": team_id, - }) -} - -fn build_append_stream_body(channel: &str, ts: &str, delta: &str) -> serde_json::Value { - serde_json::json!({ - "channel": channel, - "ts": ts, - "markdown_text": delta, - }) -} - -fn build_set_status_body(channel_id: &str, thread_ts: &str, status: &str) -> serde_json::Value { - serde_json::json!({ - "channel_id": channel_id, - "thread_ts": thread_ts, - "status": status, - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - // --- builder tests --- - - #[test] - fn build_start_stream_body_has_recipient() { - let b = build_start_stream_body("C1", "1700.1", "U2", "T3"); - assert_eq!(b["channel"], "C1"); - assert_eq!(b["thread_ts"], "1700.1"); - assert_eq!(b["recipient_user_id"], "U2"); - assert_eq!(b["recipient_team_id"], "T3"); - } - - #[test] - fn build_append_stream_body_is_markdown_text_chunk() { - let b = build_append_stream_body("C1", "1700.9", "hello"); - assert_eq!(b["channel"], "C1"); - assert_eq!(b["ts"], "1700.9"); - assert_eq!(b["markdown_text"], "hello"); - } - - #[test] - fn build_set_status_body_shape() { - let b = build_set_status_body("C1", "1700.1", "Thinking\u{2026}"); - assert_eq!(b["channel_id"], "C1"); - assert_eq!(b["thread_ts"], "1700.1"); - assert_eq!(b["status"], "Thinking\u{2026}"); - } - - #[tokio::test] - async fn degraded_stream_append_accumulates() { - let adapter = SlackAdapter::new("xoxb-test".into(), std::time::Duration::from_secs(60), AllowBots::Off, true, crate::multibot_cache::MultibotCache::load("/dev/null".into())); - adapter.streams.lock().await.insert( - "TS".into(), - StreamEntry { active: false, degraded_buf: String::new() }, - ); - assert_eq!(adapter.accumulate_degraded("TS", "a").await.as_deref(), Some("a")); - assert_eq!(adapter.accumulate_degraded("TS", 
"b").await.as_deref(), Some("ab")); - // missing stream is not resurrected: - assert_eq!(adapter.accumulate_degraded("MISSING", "x").await, None); - } - use crate::adapter::ChatAdapter; - - /// Bot's own `<@UID>` trigger mention is stripped. - #[test] - fn resolve_mentions_strips_bot_mention() { - let out = resolve_slack_mentions("<@U1BOT> hello", Some("U1BOT")); - assert_eq!(out, "hello"); - } - - /// Other users' mentions are preserved so the LLM can address them back — - /// this is the core fix: the old `strip_slack_mention` wiped all `<@...>`. - #[test] - fn resolve_mentions_preserves_other_user_mentions() { - let out = resolve_slack_mentions("<@U1BOT> say hi to <@U2ALICE>", Some("U1BOT")); - assert_eq!(out, "say hi to <@U2ALICE>"); - } - - /// Multiple occurrences of the bot mention all get stripped. - #[test] - fn resolve_mentions_strips_repeated_bot_mentions() { - let out = resolve_slack_mentions("<@U1BOT> ping <@U1BOT>", Some("U1BOT")); - assert_eq!(out, "ping"); - } - - /// When the bot UID is unknown, fall back to preserving the text - /// (safer than stripping all user mentions). - #[test] - fn resolve_mentions_unknown_bot_preserves_all() { - let out = resolve_slack_mentions("<@U1BOT> hi <@U2ALICE>", None); - assert_eq!(out, "<@U1BOT> hi <@U2ALICE>"); - } - - /// Labelled form of another user's mention (`<@UID|handle>`) is preserved. - #[test] - fn resolve_mentions_preserves_labelled_other_user_mention() { - let out = resolve_slack_mentions("<@U1BOT> say hi to <@U2ALICE|alice>", Some("U1BOT")); - assert_eq!(out, "say hi to <@U2ALICE|alice>"); - } - - /// Labelled form `<@UID|handle>` is stripped the same as bare form. - #[test] - fn resolve_mentions_strips_labelled_bot_mention() { - let out = resolve_slack_mentions("<@U1BOT|my-bot> hello", Some("U1BOT")); - assert_eq!(out, "hello"); - } - - /// Labelled form mid-sentence is stripped and surrounding text preserved. 
- #[test] - fn resolve_mentions_strips_labelled_mid_sentence() { - let out = resolve_slack_mentions("please ask <@U1BOT|handle> to run", Some("U1BOT")); - assert_eq!(out, "please ask to run"); - } - - /// Mixed bare and labelled forms of the same UID in one string are both stripped. - #[test] - fn resolve_mentions_strips_mixed_bare_and_labelled() { - let out = resolve_slack_mentions("<@U1BOT> and <@U1BOT|handle> run", Some("U1BOT")); - assert_eq!(out, "and run"); - } - - /// Malformed unclosed `<@UID|label` (no closing `>`) is preserved verbatim. - #[test] - fn resolve_mentions_malformed_unclosed_label_preserved() { - let out = resolve_slack_mentions("ask <@U1BOT|nolabel to run", Some("U1BOT")); - assert!(out.contains("<@U1BOT")); - } - - #[test] - fn resolve_mentions_preserves_longer_uid_prefix() { - let out = resolve_slack_mentions("<@U1BOTX> hello", Some("U1BOT")); - assert_eq!(out, "<@U1BOTX> hello"); - } - - // --- text_mentions_uid tests --- - - #[test] - fn mentions_uid_bare_form() { - assert!(text_mentions_uid("<@U123BOT> hello", "U123BOT")); - } - - #[test] - fn mentions_uid_labelled_form() { - assert!(text_mentions_uid("<@U123BOT|my-bot> hello", "U123BOT")); - } - - #[test] - fn mentions_uid_labelled_form_mid_sentence() { - assert!(text_mentions_uid("please ask <@U123BOT|handle> to run", "U123BOT")); - } - - #[test] - fn mentions_uid_no_match() { - assert!(!text_mentions_uid("hello world", "U123BOT")); - } - - #[test] - fn mentions_uid_no_false_positive_on_uid_prefix() { - assert!(!text_mentions_uid("<@U123BOT> hello", "U123")); - } - - #[test] - fn mentions_uid_second_mention_matches() { - assert!(text_mentions_uid("<@U999OTHER> and <@U123BOT>", "U123BOT")); - } - - #[test] - fn mentions_uid_empty_label_form() { - assert!(text_mentions_uid("<@U123BOT|> hello", "U123BOT")); - } - - #[test] - fn mentions_uid_truncated_no_closing_delimiter() { - assert!(!text_mentions_uid("<@U123BOT", "U123BOT")); - } - - // --- is_plain_user_message tests (regression for 
openabdev/openab#497 parity) --- - - /// Empty message text never counts as a user message (regardless of subtype). - #[test] - fn empty_text_is_not_plain_user_message() { - assert!(!is_plain_user_message("", "")); - assert!(!is_plain_user_message("me_message", "")); - } - - /// No subtype + non-empty text = plain user message (the common case). - #[test] - fn no_subtype_nonempty_text_is_plain_user_message() { - assert!(is_plain_user_message("", "hello")); - } - - /// Whitelisted subtypes with non-empty text are user messages. - #[test] - fn whitelisted_subtypes_are_plain_user_messages() { - assert!(is_plain_user_message("me_message", "waves")); - assert!(is_plain_user_message("thread_broadcast", "see channel")); - assert!(is_plain_user_message("file_share", "caption")); - } - - /// System-ish subtypes (even from real users) are NOT user messages — - /// resetting the counter on them would let bot-to-bot loops re-arm. - #[test] - fn system_subtypes_are_not_plain_user_messages() { - for subtype in [ - "pinned_item", - "unpinned_item", - "channel_name", - "channel_archive", - "channel_unarchive", - "group_join", - "group_leave", - "group_topic", - "group_purpose", - "reminder_add", - "tombstone", - ] { - assert!( - !is_plain_user_message(subtype, "some text"), - "subtype {subtype} must not count as a user message", - ); - } - } - - // --- slack_file_download_url tests --- - - /// Prefers url_private_download when both fields are present — - /// that endpoint always streams raw bytes even for browser-previewed types. - #[test] - fn slack_file_url_prefers_download_variant() { - let file = serde_json::json!({ - "url_private_download": "https://files.slack.com/.../download/log.txt", - "url_private": "https://files.slack.com/.../preview/log.txt", - }); - assert_eq!( - slack_file_download_url(&file), - "https://files.slack.com/.../download/log.txt", - ); - } - - /// Falls back to url_private when url_private_download is absent. 
- #[test] - fn slack_file_url_falls_back_to_private() { - let file = serde_json::json!({ - "url_private": "https://files.slack.com/.../log.txt", - }); - assert_eq!( - slack_file_download_url(&file), - "https://files.slack.com/.../log.txt", - ); - } - - /// Externally-backed files with no private URL return empty — caller skips. - #[test] - fn slack_file_url_empty_for_external_only() { - let file = serde_json::json!({ - "external_type": "gdrive", - "permalink": "https://docs.google.com/...", - }); - assert_eq!(slack_file_download_url(&file), ""); - } - - // --- sanitize_slack_filename tests --- - - #[test] - fn sanitize_leaves_normal_filename_unchanged() { - assert_eq!(sanitize_slack_filename("photo.png"), "photo.png"); - assert_eq!(sanitize_slack_filename("my file (1).jpg"), "my file (1).jpg"); - } - - #[test] - fn sanitize_replaces_backtick() { - assert_eq!(sanitize_slack_filename("file`name.png"), "file'name.png"); - } - - #[test] - fn sanitize_replaces_angle_brackets() { - // Angle brackets are Slack mrkdwn delimiters; they must not pass through. - assert_eq!(sanitize_slack_filename("<@U123>"), "(@U123)"); - assert_eq!(sanitize_slack_filename(""), "(!here)"); - } - - #[test] - fn sanitize_combined_injection_attempt() { - // A filename constructed to inject a Slack @here ping. - assert_eq!( - sanitize_slack_filename("``"), - "'(!here)'" - ); - } - - #[test] - fn sanitize_escapes_ampersand_before_angle_brackets() { - // Slack mrkdwn decodes HTML entities before markup parsing. - // "<@here>" would round-trip back to "<@here>" and trigger a mention - // ping if & is not escaped. The & must be escaped first so downstream - // Slack entity decoding cannot reconstruct a mrkdwn delimiter. - assert_eq!(sanitize_slack_filename("<@here>"), "&lt;@here&gt;"); - assert_eq!(sanitize_slack_filename("file&name.png"), "file&name.png"); - } - - // --- strip_mime_params tests --- - - /// MIME with charset parameter strips to bare media type. 
- #[test] - fn strip_mime_params_removes_charset() { - assert_eq!(strip_mime_params("text/plain; charset=utf-8"), "text/plain"); - } - - /// Bare MIME is unchanged. - #[test] - fn strip_mime_params_bare_unchanged() { - assert_eq!(strip_mime_params("image/png"), "image/png"); - } - - /// Empty input is unchanged. - #[test] - fn strip_mime_params_empty() { - assert_eq!(strip_mime_params(""), ""); - } - - /// Surrounding whitespace is trimmed. - #[test] - fn strip_mime_params_trims_whitespace() { - assert_eq!(strip_mime_params(" text/plain "), "text/plain"); - } - - // --- bot_id_matches_trusted tests --- - - #[test] - fn trusted_bot_ids_accepts_raw_slack_bot_id() { - let trusted = HashSet::from(["B123BOT".to_string()]); - assert!(bot_id_matches_trusted(&trusted, "B123BOT", None)); - } - - #[test] - fn trusted_bot_ids_accepts_resolved_bot_user_id() { - let trusted = HashSet::from(["U123BOT".to_string()]); - assert!(bot_id_matches_trusted( - &trusted, - "B123BOT", - Some("U123BOT") - )); - } - - #[test] - fn trusted_bot_ids_rejects_unknown_bot_when_resolution_fails() { - let trusted = HashSet::from(["U123BOT".to_string()]); - assert!(!bot_id_matches_trusted(&trusted, "B999BOT", None)); - } - - #[test] - fn trusted_bot_ids_rejects_empty_event_bot_id() { - let trusted = HashSet::from(["".to_string()]); - assert!(!bot_id_matches_trusted(&trusted, "", None)); - } - - /// Per-thread streaming: ON by default, OFF when another bot is present (#534). 
- #[test] - fn streaming_per_thread() { - let ttl = std::time::Duration::from_secs(300); - let adapter = SlackAdapter::new("xoxb-test".into(), ttl, AllowBots::Mentions, false, crate::multibot_cache::MultibotCache::load("/dev/null".into())); - - assert!( - adapter.use_streaming(false), - "should stream when no other bot" - ); - assert!( - !adapter.use_streaming(true), - "should NOT stream when other bot present" - ); - } - - #[tokio::test] - async fn assistant_mode_gates_status_and_native_streaming() { - let ttl = std::time::Duration::from_secs(60); - // assistant_mode=true → status API on; native streaming on (no other bot), - // off when another bot is present; post+edit streaming on regardless. - let adapter = SlackAdapter::new("xoxb-test".into(), ttl, AllowBots::Off, true, crate::multibot_cache::MultibotCache::load("/dev/null".into())); - assert!(adapter.uses_assistant_status(), "assistant_mode enables status API"); - assert!(adapter.use_streaming(false), "post+edit streaming on when no other bot"); - assert!(adapter.uses_native_streaming(false), "native streaming on when no other bot"); - assert!(!adapter.uses_native_streaming(true), "other bot present disables native"); - // assistant_mode=false → no status API, no native streaming; post+edit still streams. - let adapter2 = SlackAdapter::new("xoxb-test".into(), ttl, AllowBots::Off, false, crate::multibot_cache::MultibotCache::load("/dev/null".into())); - assert!(!adapter2.uses_assistant_status()); - assert!(adapter2.use_streaming(false), "post+edit streaming independent of assistant_mode"); - assert!(!adapter2.uses_native_streaming(false), "native streaming requires assistant_mode"); - } - - /// chat.postMessage body carries Block Kit `markdown` blocks with the raw - /// Markdown preserved (NOT downgraded), plus a `text` fallback and thread_ts. 
- #[test] - fn post_message_body_uses_raw_markdown_blocks() { - let b = build_post_message_body("C1", Some("1700.1"), "## Heading\n- item"); - assert_eq!(b["channel"], "C1"); - assert_eq!(b["thread_ts"], "1700.1"); - assert_eq!(b["blocks"][0]["type"], "markdown"); - // Raw markdown preserved — heading is NOT flattened to `*Heading*`. - assert_eq!(b["blocks"][0]["text"], "## Heading\n- item"); - assert!(b["text"].is_string(), "text fallback present for a11y/notifs"); - } - - /// thread_ts is omitted (top-level post) when the channel has no thread. - #[test] - fn post_message_body_omits_thread_ts_when_none() { - let b = build_post_message_body("C1", None, "hi"); - assert!(b.get("thread_ts").is_none()); - } - - /// chat.update body also uses Block Kit `markdown` blocks with raw markdown. - #[test] - fn update_body_uses_raw_markdown_blocks() { - let b = build_update_body("C1", "1700.9", "**bold**"); - assert_eq!(b["channel"], "C1"); - assert_eq!(b["ts"], "1700.9"); - assert_eq!(b["blocks"][0]["type"], "markdown"); - assert_eq!(b["blocks"][0]["text"], "**bold**"); - } - - /// Content over the per-block cap (11,900) splits into multiple markdown - /// blocks, each within the limit. Assert on char count — `split_message` - /// enforces `chars().count() <= limit`, not byte length. - #[test] - fn long_content_splits_into_multiple_markdown_blocks() { - let big = "lorem ipsum dolor\n".repeat(1000); // > MARKDOWN_BLOCK_LIMIT - assert!(big.chars().count() > MARKDOWN_BLOCK_LIMIT); - let blocks = build_markdown_blocks(&big); - assert!(blocks.len() >= 2, "should split into multiple blocks"); - for blk in &blocks { - assert_eq!(blk["type"], "markdown"); - assert!(blk["text"].as_str().unwrap().chars().count() <= MARKDOWN_BLOCK_LIMIT); - } - } - - /// Regression for the long-table split: a Markdown table that overflows the - /// old 4000 limit but fits the new 11,900 message_limit must stay in a single - /// chunk, so it isn't split mid-table into raw pipe text. 
- #[test] - fn typical_long_table_stays_in_one_chunk() { - let ttl = std::time::Duration::from_secs(300); - let adapter = SlackAdapter::new("xoxb-test".into(), ttl, AllowBots::Mentions, true, crate::multibot_cache::MultibotCache::load("/dev/null".into())); - let limit = adapter.message_limit(); - assert_eq!(limit, MARKDOWN_BLOCK_LIMIT); - let mut table = String::from("| col a | col b | col c |\n|---|---|---|\n"); - for i in 0..150 { - table.push_str(&format!("| row {i} aaaa | bbbb {i} | cccc {i} |\n")); - } - assert!(table.chars().count() > 4000, "table must exceed old limit"); - assert!(table.chars().count() < limit, "but fit the new one"); - assert_eq!( - crate::format::split_message(&table, limit).len(), - 1, - "table within message_limit must not be split mid-table" - ); - } - - /// Text-only fallback bodies carry `text` and no `blocks` — used when a - /// workspace rejects the Block Kit markdown block. - #[test] - fn text_only_fallback_bodies_have_no_blocks() { - let post = build_post_message_text_only("C1", Some("1700.1"), "## H\n- x"); - assert!(post.get("blocks").is_none()); - assert!(post["text"].is_string()); - assert_eq!(post["thread_ts"], "1700.1"); - let upd = build_update_text_only("C1", "1700.9", "**b**"); - assert!(upd.get("blocks").is_none()); - assert!(upd["text"].is_string()); - } - - /// Error classifier matches `invalid_blocks` (malformed/unsupported blocks) - /// and `msg_blocks_too_long` (over the cumulative block cap) → degrade to - /// text. `invalid_arguments` is a Slack catch-all and must NOT trigger a - /// pointless text-only retry; unrelated errors are ignored too. 
- #[test] - fn detects_block_payload_rejected_errors() { - assert!(is_block_payload_rejected(&anyhow!( - "Slack API chat.postMessage: invalid_blocks" - ))); - assert!( - is_block_payload_rejected(&anyhow!("Slack API chat.postMessage: msg_blocks_too_long")), - "oversize block payload should degrade to text-only" - ); - assert!( - !is_block_payload_rejected(&anyhow!("Slack API chat.update: invalid_arguments")), - "invalid_arguments is a catch-all, not a block-rejection signal" - ); - assert!(!is_block_payload_rejected(&anyhow!( - "Slack API chat.postMessage: channel_not_found" - ))); - // Exact error-code match, not substring: a future code that merely - // contains `invalid_blocks` must NOT trigger a text-only retry. - assert!( - !is_block_payload_rejected(&anyhow!("Slack API chat.postMessage: invalid_blocks_field")), - "must match the error code exactly, not as a substring" - ); - } - - /// Slack opts into native table rendering (Block Kit markdown / markdown_text - /// stream chunks), so the router skips the table→code-block conversion. - #[test] - fn slack_renders_native_tables() { - let ttl = std::time::Duration::from_secs(300); - let adapter = SlackAdapter::new("xoxb-test".into(), ttl, AllowBots::Mentions, true, crate::multibot_cache::MultibotCache::load("/dev/null".into())); - assert!(adapter.renders_native_tables()); - } -} - -#[cfg(test)] -mod socket_keepalive_tests { - use super::{next_backoff, socket_idle, IDLE_TIMEOUT_SECS, MAX_BACKOFF_SECS}; - use std::time::Duration; - - /// Backoff doubles and caps, matching the gateway adapter (1,2,4,8,16,30,30…). 
- #[test] - fn backoff_doubles_then_caps() { - let mut b = 1u64; - let seq: Vec = (0..8) - .map(|_| { - let cur = b; - b = next_backoff(b); - cur - }) - .collect(); - assert_eq!(seq, vec![1, 2, 4, 8, 16, MAX_BACKOFF_SECS, MAX_BACKOFF_SECS, MAX_BACKOFF_SECS]); - assert_eq!(next_backoff(MAX_BACKOFF_SECS), MAX_BACKOFF_SECS); - } - - /// A half-open socket (no inbound past the window) is detected; an active one - /// (recent inbound, e.g. a Slack ping) is not. This is the deaf-socket guard. - #[test] - fn idle_detects_half_open_at_boundary() { - let timeout = Duration::from_secs(IDLE_TIMEOUT_SECS); - assert!(!socket_idle(Duration::from_secs(0), timeout)); - assert!(!socket_idle(Duration::from_secs(IDLE_TIMEOUT_SECS - 1), timeout)); - assert!(socket_idle(Duration::from_secs(IDLE_TIMEOUT_SECS), timeout)); - assert!(socket_idle(Duration::from_secs(IDLE_TIMEOUT_SECS + 10), timeout)); - } -} diff --git a/src/stt.rs b/src/stt.rs deleted file mode 100644 index d266e6117..000000000 --- a/src/stt.rs +++ /dev/null @@ -1,354 +0,0 @@ -use crate::adapter::{ChannelRef, ChatAdapter, MessageRef}; -use crate::config::SttConfig; -use reqwest::multipart; -use std::sync::Arc; -use tracing::{debug, error, warn}; - -/// Outcome of attempting STT on a single audio attachment. -/// Used by adapters to feed `post_echo`. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum EchoEntry { - Success(String), - Failed, -} - -/// Render a list of echo entries as a single multi-line quoted block. -/// Returns `None` for empty input so callers can short-circuit. -/// -/// Each entry produces one `> 🎤 …` line. Internal newlines inside a -/// transcript are flattened to spaces so each entry occupies exactly one -/// visual line — Discord and Slack both stop applying `>` at the next `\n`. 
-pub fn format_echo_message(entries: &[EchoEntry]) -> Option { - if entries.is_empty() { - return None; - } - let mut lines = Vec::with_capacity(entries.len()); - for e in entries { - match e { - EchoEntry::Success(text) => { - let flat = text.replace(['\n', '\r'], " "); - lines.push(format!("> 🎤 {flat}")); - } - EchoEntry::Failed => { - lines.push("> 🎤 (transcription failed)".to_string()); - } - } - } - Some(lines.join("\n")) -} - -/// Post a transcript echo to the thread and add a ⚠️ reaction for any failed -/// entries. No-op when the config disables echoing or when `entries` is empty. -/// -/// Errors from the adapter (send/reaction) are logged and swallowed — the -/// echo is best-effort and must never block the agent reply. -pub async fn post_echo( - adapter: &Arc, - thread: &ChannelRef, - trigger: &MessageRef, - entries: &[EchoEntry], - cfg: &SttConfig, -) { - if !cfg.echo_transcript { - return; - } - let Some(body) = format_echo_message(entries) else { - return; - }; - if let Err(e) = adapter.send_message(thread, &body).await { - warn!(error = %e, platform = adapter.platform(), "failed to send STT echo message"); - } - for entry in entries { - if matches!(entry, EchoEntry::Failed) { - if let Err(e) = adapter.add_reaction(trigger, "⚠️").await { - warn!(error = %e, platform = adapter.platform(), "failed to add STT failure reaction"); - } - // Add only one reaction even with multiple failures — emoji reactions - // are unique per (user, emoji, message), so additional calls are no-ops. - break; - } - } -} - -/// Transcribe audio bytes via an OpenAI-compatible `/audio/transcriptions` endpoint. 
-pub async fn transcribe( - client: &reqwest::Client, - cfg: &SttConfig, - audio_bytes: Vec, - filename: String, - mime_type: &str, -) -> Option { - let url = format!( - "{}/audio/transcriptions", - cfg.base_url.trim_end_matches('/') - ); - - let file_part = multipart::Part::bytes(audio_bytes) - .file_name(filename) - .mime_str(mime_type) - .ok()?; - - let form = multipart::Form::new() - .part("file", file_part) - .text("model", cfg.model.clone()) - .text("response_format", "json"); - - let resp = match client - .post(&url) - .bearer_auth(&cfg.api_key) - .multipart(form) - .send() - .await - { - Ok(r) => r, - Err(e) => { - error!(error = %e, "STT request failed"); - return None; - } - }; - - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - error!(status = %status, body = %body, "STT API error"); - return None; - } - - let json: serde_json::Value = match resp.json().await { - Ok(v) => v, - Err(e) => { - error!(error = %e, "STT response parse failed"); - return None; - } - }; - - let text = json.get("text")?.as_str()?.trim().to_string(); - if text.is_empty() { - return None; - } - - debug!(chars = text.len(), "STT transcription complete"); - Some(text) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn format_single_success_entry() { - let entries = vec![EchoEntry::Success("hello world".into())]; - let out = format_echo_message(&entries).expect("non-empty input → Some"); - assert_eq!(out, "> 🎤 hello world"); - } - - #[test] - fn format_single_failure_entry() { - let entries = vec![EchoEntry::Failed]; - let out = format_echo_message(&entries).expect("non-empty input → Some"); - assert_eq!(out, "> 🎤 (transcription failed)"); - } - - #[test] - fn format_multiple_mixed_entries() { - let entries = vec![ - EchoEntry::Success("first".into()), - EchoEntry::Failed, - EchoEntry::Success("third".into()), - ]; - let out = format_echo_message(&entries).expect("non-empty input → Some"); - 
assert_eq!(out, "> 🎤 first\n> 🎤 (transcription failed)\n> 🎤 third"); - } - - #[test] - fn format_empty_entries_returns_none() { - let entries: Vec = vec![]; - assert!(format_echo_message(&entries).is_none()); - } - - #[test] - fn format_strips_internal_newlines_in_transcript() { - // Multi-line transcripts must collapse to a single quoted line so the - // ">" prefix still applies to every visual line. - let entries = vec![EchoEntry::Success("line one\nline two".into())]; - let out = format_echo_message(&entries).expect("non-empty input → Some"); - assert_eq!(out, "> 🎤 line one line two"); - } - - use crate::adapter::{ChannelRef, ChatAdapter, MessageRef}; - use anyhow::Result; - use async_trait::async_trait; - use std::sync::{Arc, Mutex}; - - #[derive(Default)] - struct MockAdapter { - sent_messages: Mutex>, - reactions: Mutex>, - } - - #[async_trait] - impl ChatAdapter for MockAdapter { - fn platform(&self) -> &'static str { - "mock" - } - fn message_limit(&self) -> usize { - 4000 - } - async fn send_message(&self, channel: &ChannelRef, content: &str) -> Result { - self.sent_messages - .lock() - .unwrap() - .push((channel.clone(), content.to_string())); - Ok(MessageRef { - channel: channel.clone(), - message_id: "mock-msg".into(), - }) - } - async fn create_thread( - &self, - channel: &ChannelRef, - _trigger: &MessageRef, - _title: &str, - ) -> Result { - Ok(channel.clone()) - } - async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { - self.reactions - .lock() - .unwrap() - .push((msg.clone(), emoji.to_string())); - Ok(()) - } - async fn remove_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { - Ok(()) - } - fn use_streaming(&self, _other_bot_present: bool) -> bool { - false - } - } - - fn test_channel() -> ChannelRef { - ChannelRef { - platform: "mock".into(), - channel_id: "C1".into(), - thread_id: Some("T1".into()), - parent_id: None, - origin_event_id: None, - } - } - - fn test_trigger() -> MessageRef { - MessageRef { - 
channel: test_channel(), - message_id: "M1".into(), - } - } - - fn cfg(echo: bool) -> SttConfig { - SttConfig { - echo_transcript: echo, - ..SttConfig::default() - } - } - - #[tokio::test] - async fn post_echo_success_sends_one_message_no_reactions() { - let mock = Arc::new(MockAdapter::default()); - let adapter: Arc = mock.clone(); - let entries = vec![EchoEntry::Success("hello".into())]; - post_echo( - &adapter, - &test_channel(), - &test_trigger(), - &entries, - &cfg(true), - ) - .await; - - assert_eq!(mock.sent_messages.lock().unwrap().len(), 1); - assert_eq!(mock.sent_messages.lock().unwrap()[0].1, "> 🎤 hello"); - assert!(mock.reactions.lock().unwrap().is_empty()); - } - - #[tokio::test] - async fn post_echo_failure_adds_warning_reaction() { - let mock = Arc::new(MockAdapter::default()); - let adapter: Arc = mock.clone(); - let entries = vec![EchoEntry::Failed]; - post_echo( - &adapter, - &test_channel(), - &test_trigger(), - &entries, - &cfg(true), - ) - .await; - - assert_eq!(mock.sent_messages.lock().unwrap().len(), 1); - assert_eq!( - mock.sent_messages.lock().unwrap()[0].1, - "> 🎤 (transcription failed)" - ); - let reactions = mock.reactions.lock().unwrap(); - assert_eq!(reactions.len(), 1); - assert_eq!(reactions[0].1, "⚠️"); - } - - #[tokio::test] - async fn post_echo_mixed_one_message_one_reaction() { - let mock = Arc::new(MockAdapter::default()); - let adapter: Arc = mock.clone(); - let entries = vec![EchoEntry::Success("ok".into()), EchoEntry::Failed]; - post_echo( - &adapter, - &test_channel(), - &test_trigger(), - &entries, - &cfg(true), - ) - .await; - - assert_eq!(mock.sent_messages.lock().unwrap().len(), 1); - assert_eq!( - mock.sent_messages.lock().unwrap()[0].1, - "> 🎤 ok\n> 🎤 (transcription failed)" - ); - assert_eq!(mock.reactions.lock().unwrap().len(), 1); - } - - #[tokio::test] - async fn post_echo_disabled_is_noop() { - let mock = Arc::new(MockAdapter::default()); - let adapter: Arc = mock.clone(); - let entries = 
vec![EchoEntry::Success("hi".into()), EchoEntry::Failed]; - post_echo( - &adapter, - &test_channel(), - &test_trigger(), - &entries, - &cfg(false), - ) - .await; - - assert!(mock.sent_messages.lock().unwrap().is_empty()); - assert!(mock.reactions.lock().unwrap().is_empty()); - } - - #[tokio::test] - async fn post_echo_empty_entries_is_noop() { - let mock = Arc::new(MockAdapter::default()); - let adapter: Arc = mock.clone(); - let entries: Vec = vec![]; - post_echo( - &adapter, - &test_channel(), - &test_trigger(), - &entries, - &cfg(true), - ) - .await; - - assert!(mock.sent_messages.lock().unwrap().is_empty()); - assert!(mock.reactions.lock().unwrap().is_empty()); - } -} diff --git a/src/timestamp.rs b/src/timestamp.rs deleted file mode 100644 index aa7adce46..000000000 --- a/src/timestamp.rs +++ /dev/null @@ -1,114 +0,0 @@ -//! ISO 8601 UTC timestamp helpers — no external crate dependency. -//! -//! Centralizes the Gregorian date math used by Slack (`.` ts strings) -//! and Gateway (`SystemTime::now()`) so both adapters share one implementation. - -use std::time::{SystemTime, UNIX_EPOCH}; - -/// Convert days since the Unix epoch (1970-01-01) to a Gregorian (year, month, day). -/// Algorithm from . -fn days_to_ymd(days: u64) -> (u64, u64, u64) { - let z = days + 719468; - let era = z / 146097; - let doe = z % 146097; - let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; - let y = yoe + era * 400; - let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); - let mp = (5 * doy + 2) / 153; - let d = doy - (153 * mp + 2) / 5 + 1; - let m = if mp < 10 { mp + 3 } else { mp - 9 }; - let y = if m <= 2 { y + 1 } else { y }; - (y, m, d) -} - -/// Format a Unix timestamp (seconds + millis) as ISO 8601 UTC with millisecond precision. 
-fn unix_to_iso8601(secs: u64, ms: u64) -> String { - let days = secs / 86400; - let time_secs = secs % 86400; - let h = time_secs / 3600; - let m = (time_secs % 3600) / 60; - let s = time_secs % 60; - let (year, month, day) = days_to_ymd(days); - format!("{year:04}-{month:02}-{day:02}T{h:02}:{m:02}:{s:02}.{ms:03}Z") -} - -/// Convert a Slack `ts` string (".") to ISO 8601 UTC. -/// Best-effort; falls back to epoch on parse failure. -/// -/// Parses as `f64` so the fractional part carries decimal semantics directly — -/// ".12" maps to 120 ms, not 12 ms — without any string-padding gymnastics. -pub fn slack_ts_to_iso8601(ts: &str) -> String { - let total = ts.parse::().unwrap_or(0.0); - let secs = total.trunc() as u64; - let ms = (total.fract() * 1000.0).round() as u64; - unix_to_iso8601(secs, ms) -} - -/// Current wall-clock instant as ISO 8601 UTC with millisecond precision. -pub fn now_iso8601() -> String { - let dur = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default(); - unix_to_iso8601(dur.as_secs(), (dur.subsec_millis()) as u64) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn slack_ts_epoch_zero() { - assert_eq!(slack_ts_to_iso8601("0.000000"), "1970-01-01T00:00:00.000Z"); - } - - #[test] - fn slack_ts_keeps_milliseconds() { - // 1714204397 = 2024-04-27T07:53:17 UTC; .123456 → .123 ms - assert_eq!( - slack_ts_to_iso8601("1714204397.123456"), - "2024-04-27T07:53:17.123Z" - ); - } - - #[test] - fn slack_ts_missing_fraction_uses_zero() { - assert_eq!( - slack_ts_to_iso8601("1714204397"), - "2024-04-27T07:53:17.000Z" - ); - } - - #[test] - fn slack_ts_two_digit_fraction_is_120ms_not_12ms() { - // ".12" carries decimal semantics: 0.12 s = 120 ms. - assert_eq!( - slack_ts_to_iso8601("1714204397.12"), - "2024-04-27T07:53:17.120Z" - ); - } - - #[test] - fn slack_ts_one_digit_fraction_is_100ms_not_1ms() { - // ".1" carries decimal semantics: 0.1 s = 100 ms. 
- assert_eq!( - slack_ts_to_iso8601("1714204397.1"), - "2024-04-27T07:53:17.100Z" - ); - } - - #[test] - fn slack_ts_unparseable_falls_back_to_epoch() { - assert_eq!(slack_ts_to_iso8601("not-a-ts"), "1970-01-01T00:00:00.000Z"); - } - - #[test] - fn now_iso8601_has_expected_shape() { - let s = now_iso8601(); - // YYYY-MM-DDTHH:MM:SS.mmmZ = 24 chars - assert_eq!(s.len(), 24); - assert!(s.ends_with('Z')); - assert_eq!(&s[4..5], "-"); - assert_eq!(&s[10..11], "T"); - assert_eq!(&s[19..20], "."); - } -} From d25cbe6fbaba94f79654f41cb0e564d8e121941e Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: Thu, 18 Jun 2026 20:19:22 +0000 Subject: [PATCH 03/20] fix: address remaining review findings - Remove old gateway/ directory (eliminates duplicate package name conflict) Standalone gateway binary will be re-added as thin wrapper in follow-up PR - CI path filters: add crates/** to ci.yml and docker-smoke-test.yml triggers - Tighten openab-gateway lib.rs: media module is pub(crate) - Note: Cargo.lock needs regeneration (no cargo in this env) --- .github/workflows/ci.yml | 3 +- .github/workflows/docker-smoke-test.yml | 1 + crates/openab-gateway/src/lib.rs | 2 +- gateway/Cargo.lock | 2849 ----------------------- gateway/Cargo.toml | 33 - gateway/Dockerfile | 14 - gateway/README.md | 173 -- 7 files changed, 4 insertions(+), 3071 deletions(-) delete mode 100644 gateway/Cargo.lock delete mode 100644 gateway/Cargo.toml delete mode 100644 gateway/Dockerfile delete mode 100644 gateway/README.md diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4a2e000e3..fbd719ceb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,6 +4,7 @@ on: pull_request: paths: - "src/**" + - "crates/**" - "gateway/**" - "operator/**" - "Cargo.toml" @@ -29,7 +30,7 @@ jobs: BASE=${{ github.event.pull_request.base.sha }} HEAD=${{ github.event.pull_request.head.sha }} CHANGED=$(git diff --name-only "$BASE" "$HEAD") - echo "core=$(echo "$CHANGED" | grep -qE 
'^(src/|Cargo\.(toml|lock))' && echo true || echo false)" >> "$GITHUB_OUTPUT" + echo "core=$(echo "$CHANGED" | grep -qE '^(src/|crates/|Cargo\.(toml|lock))' && echo true || echo false)" >> "$GITHUB_OUTPUT" echo "gateway=$(echo "$CHANGED" | grep -q '^gateway/' && echo true || echo false)" >> "$GITHUB_OUTPUT" echo "operator=$(echo "$CHANGED" | grep -q '^operator/' && echo true || echo false)" >> "$GITHUB_OUTPUT" diff --git a/.github/workflows/docker-smoke-test.yml b/.github/workflows/docker-smoke-test.yml index 4f3e084d8..3c3d09d2f 100644 --- a/.github/workflows/docker-smoke-test.yml +++ b/.github/workflows/docker-smoke-test.yml @@ -5,6 +5,7 @@ on: paths: - 'Dockerfile*' - 'src/**' + - 'crates/**' - 'Cargo.*' jobs: diff --git a/crates/openab-gateway/src/lib.rs b/crates/openab-gateway/src/lib.rs index ad11f9db1..d81de34d0 100644 --- a/crates/openab-gateway/src/lib.rs +++ b/crates/openab-gateway/src/lib.rs @@ -1,4 +1,4 @@ pub mod adapters; -pub mod media; +pub(crate) mod media; pub mod schema; pub mod store; diff --git a/gateway/Cargo.lock b/gateway/Cargo.lock deleted file mode 100644 index c1567f997..000000000 --- a/gateway/Cargo.lock +++ /dev/null @@ -1,2849 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. 
-version = 4 - -[[package]] -name = "adler2" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" - -[[package]] -name = "aes" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" -dependencies = [ - "cfg-if", - "cipher", - "cpufeatures", -] - -[[package]] -name = "aho-corasick" -version = "1.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" -dependencies = [ - "memchr", -] - -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - -[[package]] -name = "anyhow" -version = "1.0.102" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" - -[[package]] -name = "assert-json-diff" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" -dependencies = [ - "serde", - "serde_json", -] - -[[package]] -name = "atomic-waker" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" - -[[package]] -name = "autocfg" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" - -[[package]] -name = "axum" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" 
-dependencies = [ - "axum-core", - "base64", - "bytes", - "form_urlencoded", - "futures-util", - "http", - "http-body", - "http-body-util", - "hyper", - "hyper-util", - "itoa", - "matchit", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "serde_core", - "serde_json", - "serde_path_to_error", - "serde_urlencoded", - "sha1", - "sync_wrapper", - "tokio", - "tokio-tungstenite 0.29.0", - "tower", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "axum-core" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" -dependencies = [ - "bytes", - "futures-core", - "http", - "http-body", - "http-body-util", - "mime", - "pin-project-lite", - "sync_wrapper", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "base64" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" - -[[package]] -name = "bitflags" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" - -[[package]] -name = "block-buffer" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" -dependencies = [ - "generic-array", -] - -[[package]] -name = "block-padding" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" -dependencies = [ - "generic-array", -] - -[[package]] -name = "bumpalo" -version = "3.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" - -[[package]] -name = "bytemuck" -version = "1.25.0" 
-source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" - -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - -[[package]] -name = "byteorder-lite" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" - -[[package]] -name = "bytes" -version = "1.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" - -[[package]] -name = "cbc" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" -dependencies = [ - "cipher", -] - -[[package]] -name = "cc" -version = "1.2.60" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20" -dependencies = [ - "find-msvc-tools", - "shlex", -] - -[[package]] -name = "cfg-if" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" - -[[package]] -name = "cfg_aliases" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" - -[[package]] -name = "chrono" -version = "0.4.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" -dependencies = [ - "iana-time-zone", - "js-sys", - "num-traits", - "serde", - "wasm-bindgen", - "windows-link", -] - -[[package]] -name = "cipher" -version = "0.4.4" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" -dependencies = [ - "crypto-common", - "inout", -] - -[[package]] -name = "color_quant" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" - -[[package]] -name = "core-foundation-sys" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" - -[[package]] -name = "cpufeatures" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" -dependencies = [ - "libc", -] - -[[package]] -name = "crc32fast" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "crypto-common" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" -dependencies = [ - "generic-array", - "typenum", -] - -[[package]] -name = "data-encoding" -version = "2.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" - -[[package]] -name = "deadpool" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b" -dependencies = [ - "deadpool-runtime", - "lazy_static", - "num_cpus", - "tokio", -] - -[[package]] -name = "deadpool-runtime" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" - -[[package]] -name = "deranged" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" -dependencies = [ - "powerfmt", -] - -[[package]] -name = "digest" -version = "0.10.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" -dependencies = [ - "block-buffer", - "crypto-common", - "subtle", -] - -[[package]] -name = "displaydoc" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "either" -version = "1.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" - -[[package]] -name = "equivalent" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" - -[[package]] -name = "errno" -version = "0.3.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" -dependencies = [ - "libc", - "windows-sys 0.61.2", -] - -[[package]] -name = "fdeflate" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" -dependencies = [ - "simd-adler32", -] - -[[package]] -name = "find-msvc-tools" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" - -[[package]] -name = "flate2" -version = "1.1.9" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" -dependencies = [ - "crc32fast", - "miniz_oxide", -] - -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - -[[package]] -name = "foldhash" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" - -[[package]] -name = "form_urlencoded" -version = "1.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" -dependencies = [ - "percent-encoding", -] - -[[package]] -name = "futures" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-channel" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" -dependencies = [ - "futures-core", - "futures-sink", -] - -[[package]] -name = "futures-core" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" - -[[package]] -name = "futures-executor" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-io" -version = "0.3.32" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" - -[[package]] -name = "futures-macro" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "futures-sink" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" - -[[package]] -name = "futures-task" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" - -[[package]] -name = "futures-util" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-macro", - "futures-sink", - "futures-task", - "memchr", - "pin-project-lite", - "slab", -] - -[[package]] -name = "generic-array" -version = "0.14.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" -dependencies = [ - "typenum", - "version_check", -] - -[[package]] -name = "getrandom" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" -dependencies = [ - "cfg-if", - "js-sys", - "libc", - "wasi", - "wasm-bindgen", -] - -[[package]] -name = "getrandom" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" -dependencies = [ - "cfg-if", - "js-sys", - "libc", - "r-efi 5.3.0", - 
"wasip2", - "wasm-bindgen", -] - -[[package]] -name = "getrandom" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" -dependencies = [ - "cfg-if", - "libc", - "r-efi 6.0.0", - "wasip2", - "wasip3", -] - -[[package]] -name = "gif" -version = "0.14.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee8cfcc411d9adbbaba82fb72661cc1bcca13e8bba98b364e62b2dba8f960159" -dependencies = [ - "color_quant", - "weezl", -] - -[[package]] -name = "h2" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" -dependencies = [ - "atomic-waker", - "bytes", - "fnv", - "futures-core", - "futures-sink", - "http", - "indexmap", - "slab", - "tokio", - "tokio-util", - "tracing", -] - -[[package]] -name = "hashbrown" -version = "0.15.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" -dependencies = [ - "foldhash", -] - -[[package]] -name = "hashbrown" -version = "0.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" - -[[package]] -name = "heck" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" - -[[package]] -name = "hermit-abi" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" - -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - -[[package]] -name = 
"http" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" -dependencies = [ - "bytes", - "itoa", -] - -[[package]] -name = "http-body" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" -dependencies = [ - "bytes", - "http", -] - -[[package]] -name = "http-body-util" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" -dependencies = [ - "bytes", - "futures-core", - "http", - "http-body", - "pin-project-lite", -] - -[[package]] -name = "httparse" -version = "1.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" - -[[package]] -name = "httpdate" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" - -[[package]] -name = "hyper" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" -dependencies = [ - "atomic-waker", - "bytes", - "futures-channel", - "futures-core", - "h2", - "http", - "http-body", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "smallvec", - "tokio", - "want", -] - -[[package]] -name = "hyper-rustls" -version = "0.27.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" -dependencies = [ - "http", - "hyper", - "hyper-util", - "rustls 0.23.39", - "tokio", - "tokio-rustls 0.26.4", - "tower-service", - "webpki-roots 1.0.7", -] - -[[package]] -name = "hyper-util" -version = "0.1.20" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" -dependencies = [ - "base64", - "bytes", - "futures-channel", - "futures-util", - "http", - "http-body", - "hyper", - "ipnet", - "libc", - "percent-encoding", - "pin-project-lite", - "socket2", - "tokio", - "tower-service", - "tracing", -] - -[[package]] -name = "iana-time-zone" -version = "0.1.65" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "log", - "wasm-bindgen", - "windows-core", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", -] - -[[package]] -name = "icu_collections" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" -dependencies = [ - "displaydoc", - "potential_utf", - "utf8_iter", - "yoke", - "zerofrom", - "zerovec", -] - -[[package]] -name = "icu_locale_core" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" -dependencies = [ - "displaydoc", - "litemap", - "tinystr", - "writeable", - "zerovec", -] - -[[package]] -name = "icu_normalizer" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" -dependencies = [ - "icu_collections", - "icu_normalizer_data", - "icu_properties", - "icu_provider", - "smallvec", - "zerovec", -] - -[[package]] -name = "icu_normalizer_data" -version = "2.2.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" - -[[package]] -name = "icu_properties" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" -dependencies = [ - "icu_collections", - "icu_locale_core", - "icu_properties_data", - "icu_provider", - "zerotrie", - "zerovec", -] - -[[package]] -name = "icu_properties_data" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" - -[[package]] -name = "icu_provider" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" -dependencies = [ - "displaydoc", - "icu_locale_core", - "writeable", - "yoke", - "zerofrom", - "zerotrie", - "zerovec", -] - -[[package]] -name = "id-arena" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" - -[[package]] -name = "idna" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" -dependencies = [ - "idna_adapter", - "smallvec", - "utf8_iter", -] - -[[package]] -name = "idna_adapter" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" -dependencies = [ - "icu_normalizer", - "icu_properties", -] - -[[package]] -name = "image" -version = "0.25.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85ab80394333c02fe689eaf900ab500fbd0c2213da414687ebf995a65d5a6104" -dependencies = [ - "bytemuck", - "byteorder-lite", - "color_quant", - 
"gif", - "image-webp", - "moxcms", - "num-traits", - "png", - "zune-core", - "zune-jpeg", -] - -[[package]] -name = "image-webp" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3" -dependencies = [ - "byteorder-lite", - "quick-error", -] - -[[package]] -name = "indexmap" -version = "2.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" -dependencies = [ - "equivalent", - "hashbrown 0.17.0", - "serde", - "serde_core", -] - -[[package]] -name = "inout" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" -dependencies = [ - "block-padding", - "generic-array", -] - -[[package]] -name = "ipnet" -version = "2.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" - -[[package]] -name = "iri-string" -version = "0.7.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" -dependencies = [ - "memchr", - "serde", -] - -[[package]] -name = "itertools" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" -dependencies = [ - "either", -] - -[[package]] -name = "itoa" -version = "1.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" - -[[package]] -name = "js-sys" -version = "0.3.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2964e92d1d9dc3364cae4d718d93f227e3abb088e747d92e0395bfdedf1c12ca" -dependencies = [ - "cfg-if", - 
"futures-util", - "once_cell", - "wasm-bindgen", -] - -[[package]] -name = "jsonwebtoken" -version = "9.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" -dependencies = [ - "base64", - "js-sys", - "pem", - "ring", - "serde", - "serde_json", - "simple_asn1", -] - -[[package]] -name = "lazy_static" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" - -[[package]] -name = "leb128fmt" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" - -[[package]] -name = "libc" -version = "0.2.186" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" - -[[package]] -name = "litemap" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" - -[[package]] -name = "lock_api" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" -dependencies = [ - "scopeguard", -] - -[[package]] -name = "log" -version = "0.4.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" - -[[package]] -name = "lru-slab" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" - -[[package]] -name = "matchers" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" -dependencies = [ 
- "regex-automata", -] - -[[package]] -name = "matchit" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" - -[[package]] -name = "memchr" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" - -[[package]] -name = "mime" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" - -[[package]] -name = "miniz_oxide" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" -dependencies = [ - "adler2", - "simd-adler32", -] - -[[package]] -name = "mio" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" -dependencies = [ - "libc", - "wasi", - "windows-sys 0.61.2", -] - -[[package]] -name = "moxcms" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb85c154ba489f01b25c0d36ae69a87e4a1c73a72631fc6c0eb6dde34a73e44b" -dependencies = [ - "num-traits", - "pxfm", -] - -[[package]] -name = "nu-ansi-term" -version = "0.50.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" -dependencies = [ - "windows-sys 0.61.2", -] - -[[package]] -name = "num-bigint" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" -dependencies = [ - "num-integer", - "num-traits", -] - -[[package]] -name = "num-conv" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" -dependencies = [ - "autocfg", -] - -[[package]] -name = "num_cpus" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" -dependencies = [ - "hermit-abi", - "libc", -] - -[[package]] -name = "once_cell" -version = "1.21.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" - -[[package]] -name = "openab-gateway" -version = "0.5.4" -dependencies = [ - "aes", - "anyhow", - "axum", - "base64", - "cbc", - "chrono", - "futures-util", - "hmac", - "image", - "jsonwebtoken", - "prost", - "quick-xml", - "reqwest", - "serde", - "serde_json", - "sha1", - "sha2", - "subtle", - "tokio", - "tokio-tungstenite 0.21.0", - "tracing", - "tracing-subscriber", - "urlencoding", - "uuid", - "wiremock", -] - -[[package]] -name = "parking_lot" -version = "0.12.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-link", -] - -[[package]] -name = "pem" 
-version = "3.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" -dependencies = [ - "base64", - "serde_core", -] - -[[package]] -name = "percent-encoding" -version = "2.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" - -[[package]] -name = "pin-project-lite" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" - -[[package]] -name = "png" -version = "0.18.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61" -dependencies = [ - "bitflags", - "crc32fast", - "fdeflate", - "flate2", - "miniz_oxide", -] - -[[package]] -name = "potential_utf" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" -dependencies = [ - "zerovec", -] - -[[package]] -name = "powerfmt" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" - -[[package]] -name = "ppv-lite86" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" -dependencies = [ - "zerocopy", -] - -[[package]] -name = "prettyplease" -version = "0.2.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" -dependencies = [ - "proc-macro2", - "syn", -] - -[[package]] -name = "proc-macro2" -version = "1.0.106" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "prost" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" -dependencies = [ - "bytes", - "prost-derive", -] - -[[package]] -name = "prost-derive" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" -dependencies = [ - "anyhow", - "itertools", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "pxfm" -version = "0.1.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f" - -[[package]] -name = "quick-error" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" - -[[package]] -name = "quick-xml" -version = "0.37.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "331e97a1af0bf59823e6eadffe373d7b27f485be8748f71471c662c1f269b7fb" -dependencies = [ - "memchr", -] - -[[package]] -name = "quinn" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" -dependencies = [ - "bytes", - "cfg_aliases", - "pin-project-lite", - "quinn-proto", - "quinn-udp", - "rustc-hash", - "rustls 0.23.39", - "socket2", - "thiserror 2.0.18", - "tokio", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-proto" -version = "0.11.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" -dependencies = [ - "bytes", - "getrandom 0.3.4", - "lru-slab", - "rand 0.9.4", - "ring", - "rustc-hash", - "rustls 
0.23.39", - "rustls-pki-types", - "slab", - "thiserror 2.0.18", - "tinyvec", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-udp" -version = "0.5.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" -dependencies = [ - "cfg_aliases", - "libc", - "once_cell", - "socket2", - "tracing", - "windows-sys 0.60.2", -] - -[[package]] -name = "quote" -version = "1.0.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "r-efi" -version = "5.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" - -[[package]] -name = "r-efi" -version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" - -[[package]] -name = "rand" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" -dependencies = [ - "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] - -[[package]] -name = "rand" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" -dependencies = [ - "rand_chacha 0.9.0", - "rand_core 0.9.5", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_chacha" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" -dependencies = [ - "ppv-lite86", - "rand_core 0.9.5", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom 0.2.17", -] - -[[package]] -name = "rand_core" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" -dependencies = [ - "getrandom 0.3.4", -] - -[[package]] -name = "redox_syscall" -version = "0.5.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" -dependencies = [ - "bitflags", -] - -[[package]] -name = "regex" -version = "1.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata", - "regex-syntax", -] - -[[package]] -name = "regex-automata" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", -] - -[[package]] -name = "regex-syntax" -version = "0.8.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" - -[[package]] -name = "reqwest" -version = "0.12.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" -dependencies = [ - "base64", - "bytes", - "futures-core", - "http", - "http-body", - "http-body-util", - "hyper", - "hyper-rustls", - "hyper-util", - "js-sys", - "log", - "percent-encoding", 
- "pin-project-lite", - "quinn", - "rustls 0.23.39", - "rustls-pki-types", - "serde", - "serde_json", - "serde_urlencoded", - "sync_wrapper", - "tokio", - "tokio-rustls 0.26.4", - "tower", - "tower-http", - "tower-service", - "url", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", - "webpki-roots 1.0.7", -] - -[[package]] -name = "ring" -version = "0.17.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" -dependencies = [ - "cc", - "cfg-if", - "getrandom 0.2.17", - "libc", - "untrusted", - "windows-sys 0.52.0", -] - -[[package]] -name = "rustc-hash" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" - -[[package]] -name = "rustls" -version = "0.22.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" -dependencies = [ - "log", - "ring", - "rustls-pki-types", - "rustls-webpki 0.102.8", - "subtle", - "zeroize", -] - -[[package]] -name = "rustls" -version = "0.23.39" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e" -dependencies = [ - "once_cell", - "ring", - "rustls-pki-types", - "rustls-webpki 0.103.13", - "subtle", - "zeroize", -] - -[[package]] -name = "rustls-pki-types" -version = "1.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" -dependencies = [ - "web-time", - "zeroize", -] - -[[package]] -name = "rustls-webpki" -version = "0.102.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" -dependencies = [ - "ring", - "rustls-pki-types", - "untrusted", -] - -[[package]] 
-name = "rustls-webpki" -version = "0.103.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" -dependencies = [ - "ring", - "rustls-pki-types", - "untrusted", -] - -[[package]] -name = "rustversion" -version = "1.0.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" - -[[package]] -name = "ryu" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "semver" -version = "1.0.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" - -[[package]] -name = "serde" -version = "1.0.228" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" -dependencies = [ - "serde_core", - "serde_derive", -] - -[[package]] -name = "serde_core" -version = "1.0.228" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.228" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.149" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" -dependencies = [ - "itoa", - "memchr", - "serde", - "serde_core", - "zmij", -] - -[[package]] -name = "serde_path_to_error" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" -dependencies = [ - "itoa", - "serde", - "serde_core", -] - -[[package]] -name = "serde_urlencoded" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" -dependencies = [ - "form_urlencoded", - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "sha1" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "sha2" -version = "0.10.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - -[[package]] -name = "shlex" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" - -[[package]] -name = "signal-hook-registry" -version = "1.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" -dependencies = [ - "errno", - "libc", -] - -[[package]] -name = "simd-adler32" -version = "0.3.9" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" - -[[package]] -name = "simple_asn1" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d585997b0ac10be3c5ee635f1bab02d512760d14b7c468801ac8a01d9ae5f1d" -dependencies = [ - "num-bigint", - "num-traits", - "thiserror 2.0.18", - "time", -] - -[[package]] -name = "slab" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" - -[[package]] -name = "smallvec" -version = "1.15.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" - -[[package]] -name = "socket2" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" -dependencies = [ - "libc", - "windows-sys 0.61.2", -] - -[[package]] -name = "stable_deref_trait" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" - -[[package]] -name = "subtle" -version = "2.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" - -[[package]] -name = "syn" -version = "2.0.117" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "sync_wrapper" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" -dependencies = [ - "futures-core", -] - -[[package]] -name = 
"synstructure" -version = "0.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "thiserror" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" -dependencies = [ - "thiserror-impl 1.0.69", -] - -[[package]] -name = "thiserror" -version = "2.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" -dependencies = [ - "thiserror-impl 2.0.18", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "thiserror-impl" -version = "2.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "thread_local" -version = "1.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "time" -version = "0.3.47" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" -dependencies = [ - "deranged", - "itoa", - "num-conv", - "powerfmt", - "serde_core", - "time-core", - "time-macros", -] - -[[package]] -name = "time-core" -version = "0.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" - -[[package]] -name = "time-macros" -version = "0.2.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" -dependencies = [ - "num-conv", - "time-core", -] - -[[package]] -name = "tinystr" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" -dependencies = [ - "displaydoc", - "zerovec", -] - -[[package]] -name = "tinyvec" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - -[[package]] -name = "tokio" -version = "1.52.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" -dependencies = [ - "bytes", - "libc", - "mio", - "parking_lot", - "pin-project-lite", - "signal-hook-registry", - "socket2", - "tokio-macros", - "windows-sys 0.61.2", -] - -[[package]] -name = "tokio-macros" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tokio-rustls" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" -dependencies = [ - "rustls 0.22.4", - "rustls-pki-types", - "tokio", -] - -[[package]] -name = "tokio-rustls" -version = "0.26.4" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" -dependencies = [ - "rustls 0.23.39", - "tokio", -] - -[[package]] -name = "tokio-tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" -dependencies = [ - "futures-util", - "log", - "rustls 0.22.4", - "rustls-pki-types", - "tokio", - "tokio-rustls 0.25.0", - "tungstenite 0.21.0", - "webpki-roots 0.26.11", -] - -[[package]] -name = "tokio-tungstenite" -version = "0.29.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f72a05e828585856dacd553fba484c242c46e391fb0e58917c942ee9202915c" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite 0.29.0", -] - -[[package]] -name = "tokio-util" -version = "0.7.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "tower" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" -dependencies = [ - "futures-core", - "futures-util", - "pin-project-lite", - "sync_wrapper", - "tokio", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "tower-http" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" -dependencies = [ - "bitflags", - "bytes", - "futures-util", - "http", - "http-body", - "iri-string", - "pin-project-lite", - "tower", - "tower-layer", - "tower-service", -] - -[[package]] -name = "tower-layer" -version = "0.3.3" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" - -[[package]] -name = "tower-service" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" - -[[package]] -name = "tracing" -version = "0.1.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" -dependencies = [ - "log", - "pin-project-lite", - "tracing-attributes", - "tracing-core", -] - -[[package]] -name = "tracing-attributes" -version = "0.1.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tracing-core" -version = "0.1.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" -dependencies = [ - "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" -dependencies = [ - "matchers", - "nu-ansi-term", - "once_cell", - "regex-automata", - "sharded-slab", - "smallvec", - "thread_local", - "tracing", - "tracing-core", - "tracing-log", -] - -[[package]] -name = "try-lock" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" - -[[package]] -name = "tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" -dependencies = [ - "byteorder", - "bytes", - "data-encoding", - "http", - "httparse", - "log", - "rand 0.8.6", - "rustls 0.22.4", - "rustls-pki-types", - "sha1", - "thiserror 1.0.69", - "url", - "utf-8", -] - -[[package]] -name = "tungstenite" -version = "0.29.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c01152af293afb9c7c2a57e4b559c5620b421f6d133261c60dd2d0cdb38e6b8" -dependencies = [ - "bytes", - "data-encoding", - "http", - "httparse", - "log", - "rand 0.9.4", - "sha1", - "thiserror 2.0.18", -] - -[[package]] -name = "typenum" -version = "1.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" - -[[package]] -name = "unicode-ident" -version = "1.0.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" - -[[package]] -name = "unicode-xid" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" - -[[package]] -name = "untrusted" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" - -[[package]] -name = "url" -version = "2.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" -dependencies = [ - "form_urlencoded", - "idna", - "percent-encoding", - "serde", -] - -[[package]] -name = "urlencoding" -version = "2.1.3" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" - -[[package]] -name = "utf-8" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" - -[[package]] -name = "utf8_iter" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" - -[[package]] -name = "uuid" -version = "1.23.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" -dependencies = [ - "getrandom 0.4.2", - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "valuable" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" - -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - -[[package]] -name = "want" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" -dependencies = [ - "try-lock", -] - -[[package]] -name = "wasi" -version = "0.11.1+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" - -[[package]] -name = "wasip2" -version = "1.0.3+wasi-0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" -dependencies = [ - "wit-bindgen 0.57.1", -] - -[[package]] -name = "wasip3" -version = "0.4.0+wasi-0.3.0-rc-2026-01-06" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" -dependencies = [ - "wit-bindgen 0.51.0", -] - -[[package]] -name = "wasm-bindgen" -version = "0.2.118" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" -dependencies = [ - "cfg-if", - "once_cell", - "rustversion", - "wasm-bindgen-macro", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-futures" -version = "0.4.68" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f371d383f2fb139252e0bfac3b81b265689bf45b6874af544ffa4c975ac1ebf8" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.118" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.118" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" -dependencies = [ - "bumpalo", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.118" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "wasm-encoder" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" -dependencies = [ - "leb128fmt", - "wasmparser", -] - -[[package]] -name = "wasm-metadata" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" -dependencies = [ - "anyhow", - "indexmap", - "wasm-encoder", - "wasmparser", -] - -[[package]] -name = "wasmparser" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" -dependencies = [ - "bitflags", - "hashbrown 0.15.5", - "indexmap", - "semver", -] - -[[package]] -name = "web-sys" -version = "0.3.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2dfbb17949fa2088e5d39408c48368947b86f7834484e87b73de55bc14d97d" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "web-time" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "webpki-roots" -version = "0.26.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" -dependencies = [ - "webpki-roots 1.0.7", -] - -[[package]] -name = "webpki-roots" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52f5ee44c96cf55f1b349600768e3ece3a8f26010c05265ab73f945bb1a2eb9d" -dependencies = [ - "rustls-pki-types", -] - -[[package]] -name = "weezl" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" - -[[package]] -name = "windows-core" -version = "0.62.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" -dependencies = [ - "windows-implement", - "windows-interface", - "windows-link", - "windows-result", - "windows-strings", -] - -[[package]] -name = "windows-implement" 
-version = "0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "windows-interface" -version = "0.59.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "windows-link" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" - -[[package]] -name = "windows-result" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" -dependencies = [ - "windows-link", -] - -[[package]] -name = "windows-strings" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" -dependencies = [ - "windows-link", -] - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-sys" -version = "0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" -dependencies = [ - "windows-targets 0.53.5", -] - -[[package]] -name = "windows-sys" -version = "0.61.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" -dependencies = [ - "windows-link", -] - -[[package]] -name = "windows-targets" -version = "0.52.6" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" -dependencies = [ - "windows_aarch64_gnullvm 0.52.6", - "windows_aarch64_msvc 0.52.6", - "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", - "windows_i686_msvc 0.52.6", - "windows_x86_64_gnu 0.52.6", - "windows_x86_64_gnullvm 0.52.6", - "windows_x86_64_msvc 0.52.6", -] - -[[package]] -name = "windows-targets" -version = "0.53.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" -dependencies = [ - "windows-link", - "windows_aarch64_gnullvm 0.53.1", - "windows_aarch64_msvc 0.53.1", - "windows_i686_gnu 0.53.1", - "windows_i686_gnullvm 0.53.1", - "windows_i686_msvc 0.53.1", - "windows_x86_64_gnu 0.53.1", - "windows_x86_64_gnullvm 0.53.1", - "windows_x86_64_msvc 0.53.1", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" - -[[package]] -name = 
"windows_i686_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" - -[[package]] -name = "windows_i686_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" - -[[package]] -name = "wiremock" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031" -dependencies = [ - "assert-json-diff", - "base64", - "deadpool", - "futures", - "http", - "http-body-util", - "hyper", - "hyper-util", - "log", - "once_cell", - "regex", - "serde", - "serde_json", - "tokio", - "url", -] - -[[package]] -name = "wit-bindgen" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" -dependencies = [ - "wit-bindgen-rust-macro", -] - -[[package]] -name = "wit-bindgen" -version = "0.57.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" - -[[package]] -name = "wit-bindgen-core" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" -dependencies = [ - "anyhow", - "heck", - "wit-parser", -] - -[[package]] -name = "wit-bindgen-rust" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" -dependencies = [ - "anyhow", - "heck", - "indexmap", - "prettyplease", - "syn", - "wasm-metadata", - "wit-bindgen-core", - "wit-component", -] - -[[package]] -name = "wit-bindgen-rust-macro" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" 
-dependencies = [ - "anyhow", - "prettyplease", - "proc-macro2", - "quote", - "syn", - "wit-bindgen-core", - "wit-bindgen-rust", -] - -[[package]] -name = "wit-component" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" -dependencies = [ - "anyhow", - "bitflags", - "indexmap", - "log", - "serde", - "serde_derive", - "serde_json", - "wasm-encoder", - "wasm-metadata", - "wasmparser", - "wit-parser", -] - -[[package]] -name = "wit-parser" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" -dependencies = [ - "anyhow", - "id-arena", - "indexmap", - "log", - "semver", - "serde", - "serde_derive", - "serde_json", - "unicode-xid", - "wasmparser", -] - -[[package]] -name = "writeable" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" - -[[package]] -name = "yoke" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" -dependencies = [ - "stable_deref_trait", - "yoke-derive", - "zerofrom", -] - -[[package]] -name = "yoke-derive" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "synstructure", -] - -[[package]] -name = "zerocopy" -version = "0.8.48" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" -dependencies = [ - "zerocopy-derive", -] - -[[package]] -name = "zerocopy-derive" -version = "0.8.48" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "zerofrom" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" -dependencies = [ - "zerofrom-derive", -] - -[[package]] -name = "zerofrom-derive" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "synstructure", -] - -[[package]] -name = "zeroize" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" - -[[package]] -name = "zerotrie" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" -dependencies = [ - "displaydoc", - "yoke", - "zerofrom", -] - -[[package]] -name = "zerovec" -version = "0.11.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" -dependencies = [ - "yoke", - "zerofrom", - "zerovec-derive", -] - -[[package]] -name = "zerovec-derive" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "zmij" -version = "1.0.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" - -[[package]] -name = "zune-core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9" - -[[package]] -name = "zune-jpeg" -version = "0.5.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296" -dependencies = [ - "zune-core", -] diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml deleted file mode 100644 index e07dffceb..000000000 --- a/gateway/Cargo.toml +++ /dev/null @@ -1,33 +0,0 @@ -[package] -name = "openab-gateway" -version = "0.5.4" -edition = "2021" - -[dependencies] -tokio = { version = "1", features = ["full"] } -axum = { version = "0.8", features = ["ws"] } -tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } -futures-util = "0.3" -serde = { version = "1", features = ["derive"] } -serde_json = "1" -reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] } -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -anyhow = "1" -uuid = { version = "1", features = ["v4"] } -chrono = { version = "0.4", features = ["serde"] } -hmac = "0.12" -sha2 = "0.10" -base64 = "0.22" -jsonwebtoken = "9" -aes = "0.8" -cbc = "0.1" -prost = "0.13" -subtle = "2" -sha1 = "0.10" -quick-xml = "0.37" -image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } -urlencoding = "2" - -[dev-dependencies] -wiremock = "0.6" diff --git a/gateway/Dockerfile b/gateway/Dockerfile deleted file mode 100644 index 8ee8172d0..000000000 --- a/gateway/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -# --- Build stage --- -FROM rust:1-bookworm AS builder -WORKDIR /build -COPY gateway/Cargo.toml gateway/Cargo.lock ./ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src -COPY gateway/src/ src/ -RUN touch src/main.rs && cargo build --release - -# --- Runtime stage --- -FROM debian:bookworm-slim -RUN apt-get update && apt-get install -y --no-install-recommends 
ca-certificates && rm -rf /var/lib/apt/lists/* -COPY --from=builder /build/target/release/openab-gateway /usr/local/bin/openab-gateway -EXPOSE 8080 -ENTRYPOINT ["openab-gateway"] diff --git a/gateway/README.md b/gateway/README.md deleted file mode 100644 index 79c492a28..000000000 --- a/gateway/README.md +++ /dev/null @@ -1,173 +0,0 @@ -# OpenAB Custom Gateway - -A standalone service that bridges webhook-based platforms and custom event sources to OAB via WebSocket. OAB connects outbound to the gateway — no inbound ports or TLS required on OAB. - -``` - External (HTTPS) Internal (cluster) - ──────────────── ────────────────── - -Telegram ──POST──▶┌─────────────────────┐ -LINE ──POST──▶│ │ -GitHub ──POST──▶│ Custom Gateway │◀──WebSocket── OAB Pod -CI/CD ──POST──▶│ :8080 │ (OAB connects out) -curl/cron ──POST──▶│ │ - └─────────────────────┘ - -Discord ◀──WebSocket── OAB Pod (unchanged, direct) -Slack ◀──WebSocket── OAB Pod (unchanged, direct) -``` - -The gateway normalizes all inbound events to a unified schema (`openab.gateway.event.v1`), forwards them to OAB over WebSocket, and routes OAB replies back to the originating platform API. - -For architecture details, see [ADR: Custom Gateway](../docs/adr/custom-gateway.md). - -> **Design note:** The gateway is intentionally NOT included in the OAB container image. It is a separate service with its own build, deployment, and scaling lifecycle. This follows the ADR principle that OAB remains outbound-only and platform-agnostic — all inbound webhook handling and platform credentials live in the gateway. 
- ---- - -## Quick Start - -```bash -cargo build --release -export TELEGRAM_BOT_TOKEN="your-bot-token" -./target/release/openab-gateway -``` - -### OAB Config - -```toml -[gateway] -url = "ws://gateway:8080/ws" -``` - -### Environment Variables - -| Variable | Default | Description | -|---|---|---| -| `TELEGRAM_BOT_TOKEN` | (required) | Telegram Bot API token | -| `GATEWAY_LISTEN` | `0.0.0.0:8080` | Listen address | -| `TELEGRAM_WEBHOOK_PATH` | `/webhook/telegram` | Webhook endpoint path | -| `LINE_CHANNEL_SECRET` | (optional) | LINE channel secret for webhook HMAC signature verification | -| `LINE_CHANNEL_ACCESS_TOKEN` | (optional) | LINE channel access token for Reply/Push API | -| `FEISHU_APP_ID` | (optional) | Feishu/Lark App ID — enables feishu adapter | -| `FEISHU_APP_SECRET` | (optional) | Feishu/Lark App Secret | -| `FEISHU_DOMAIN` | `feishu` | `feishu` (China) or `lark` (international) | -| `FEISHU_CONNECTION_MODE` | `websocket` | `websocket` (recommended) or `webhook` | -| `FEISHU_WEBHOOK_PATH` | `/webhook/feishu` | Webhook endpoint path | -| `FEISHU_VERIFICATION_TOKEN` | (optional) | Webhook verification token | -| `FEISHU_ENCRYPT_KEY` | (optional) | Webhook encrypt key for AES-256-CBC | -| `FEISHU_ALLOWED_GROUPS` | (optional) | Comma-separated chat_id allowlist | -| `FEISHU_ALLOWED_USERS` | (optional) | Comma-separated open_id allowlist | -| `FEISHU_REQUIRE_MENTION` | `true` | Require @mention in groups | -| `FEISHU_DEDUPE_TTL_SECS` | `300` | Event deduplication cache TTL (seconds) | -| `FEISHU_MESSAGE_LIMIT` | `4000` | Max message length before auto-splitting (bytes) | -| `GOOGLE_CHAT_ENABLED` | `false` | Set to `true` or `1` to enable the Google Chat adapter | -| `GOOGLE_CHAT_AUDIENCE` | (optional) | JWT audience for webhook verification (full webhook URL, e.g. 
`https://your-domain.com/webhook/googlechat`) | -| `GOOGLE_CHAT_SA_KEY_JSON` | (optional) | Service account key JSON string (enables token auto-refresh) | -| `GOOGLE_CHAT_SA_KEY_FILE` | (optional) | Path to service account key JSON file (alternative to `SA_KEY_JSON`) | -| `GOOGLE_CHAT_ACCESS_TOKEN` | (optional) | Static OAuth2 access token (fallback, expires in 1 hour) | -| `GOOGLE_CHAT_WEBHOOK_PATH` | `/webhook/googlechat` | Webhook endpoint path | -| `WECOM_CORP_ID` | (required*) | WeCom Corp ID — enables wecom adapter | -| `WECOM_AGENT_ID` | (required*) | WeCom App Agent ID | -| `WECOM_SECRET` | (required*) | WeCom App Secret | -| `WECOM_TOKEN` | (required*) | Callback verification Token | -| `WECOM_ENCODING_AES_KEY` | (required*) | Callback EncodingAESKey (43 chars) | -| `WECOM_WEBHOOK_PATH` | `/webhook/wecom` | Webhook endpoint path | -| `WECOM_STREAMING_ENABLED` | `false` | Enable thinking-placeholder + recall streaming (causes brief client flicker) | -| `WECOM_DEBOUNCE_SECS` | `3` | Debounce quiet-period seconds before flushing buffered streamed text | - -### Endpoints - -| Path | Description | -|---|---| -| `POST /webhook/telegram` | Telegram webhook receiver | -| `POST /webhook/line` | LINE webhook receiver | -| `POST /webhook/feishu` | Feishu webhook receiver (when `FEISHU_CONNECTION_MODE=webhook`) | -| `POST /webhook/googlechat` | Google Chat webhook receiver | -| `GET /webhook/wecom` | WeCom callback URL verification | -| `POST /webhook/wecom` | WeCom message callback receiver | -| `GET /ws` | WebSocket server (OAB connects here) | -| `GET /health` | Health check | - ---- - -## Platform Setup - -### Telegram - -1. Create a bot via [@BotFather](https://t.me/BotFather) and get the token. - -2. Start the gateway: - ```bash - export TELEGRAM_BOT_TOKEN="your-token" - ./target/release/openab-gateway - ``` - -3. Expose the gateway over HTTPS (Telegram requires it). 
Easiest option — Cloudflare Tunnel: - ```bash - cloudflared tunnel --url http://localhost:8080 - ``` - -4. Set the webhook: - ```bash - curl "https://api.telegram.org/bot${TELEGRAM_BOT_TOKEN}/setWebhook?url=https://your-host/webhook/telegram" - ``` - -5. For supergroup forum topics (thread isolation like Discord), give the bot **Manage Topics** permission in the group settings. - -### LINE - -See [docs/line.md](../docs/line.md) for the full setup guide. - -### Feishu/Lark - -See [docs/feishu.md](../docs/feishu.md) for the full setup guide. - -### Google Chat - -See [docs/google-chat.md](../docs/google-chat.md) for the full setup guide. - -### WeCom (企业微信) - -See [docs/wecom.md](../docs/wecom.md) for the full setup guide. - -### Other Platforms - -GitHub webhooks, CI/CD events, monitoring alerts — any HTTP event source can be added as a gateway adapter. See the ADR for the adapter interface. - ---- - -## Custom Event Source - -Any HTTP client can drive an OAB agent session by posting to the webhook endpoint. This turns OAB into an event-driven agent platform — no chat app required. 
- -### Example: trigger an agent from a cron job - -```bash -curl -X POST http://gateway:8080/webhook/telegram \ - -H "Content-Type: application/json" \ - -d '{ - "message": { - "message_id": 1, - "chat": {"id": 12345, "type": "private"}, - "from": {"id": 99, "first_name": "CronJob", "username": "scheduler", "is_bot": false}, - "text": "run daily security scan on staging" - } - }' -``` - -### Example: generic event (future `/webhook/custom` endpoint) - -Once a generic webhook adapter is added, any JSON payload can trigger an agent: - -```bash -curl -X POST http://gateway:8080/webhook/custom \ - -H "Content-Type: application/json" \ - -d '{ - "channel": "ops-alerts", - "sender": "cloudwatch", - "text": "CPU > 90% on prod-api-3 for 5 minutes, investigate and suggest fix" - }' -``` - -The agent response is delivered back through the gateway to whatever reply mechanism the adapter defines — Telegram message, GitHub comment, Slack DM, PagerDuty note, or simply logged. From 928815de1e8c7d36d355eb1682116ac7dce1ce1b Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: Thu, 18 Jun 2026 20:21:18 +0000 Subject: [PATCH 04/20] fix: add unified mode hook in main.rs, update gateway CI workflows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add #[cfg(any(feature = "telegram", ...))] block in main.rs that will host the embedded axum webhook server in unified mode (TODO: full wiring in follow-up PR per ADR Phase 1) - build-gateway.yml: use root Dockerfile with BUILD_MODE=unified instead of deleted gateway/Dockerfile - gateway-release-pr.yml: version bump path → crates/openab-gateway/ --- .github/workflows/build-gateway.yml | 4 +++- .github/workflows/gateway-release-pr.yml | 2 +- src/main.rs | 20 ++++++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-gateway.yml b/.github/workflows/build-gateway.yml index 516e68c3c..d5dd62ee7 100644 --- a/.github/workflows/build-gateway.yml +++ 
b/.github/workflows/build-gateway.yml @@ -58,7 +58,9 @@ jobs: uses: docker/build-push-action@v6 with: context: . - file: gateway/Dockerfile + file: Dockerfile + build-args: | + BUILD_MODE=unified platforms: ${{ matrix.platform.os }} outputs: type=image,name=${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=${{ inputs.dry_run != true }} cache-from: type=gha,scope=gateway-${{ matrix.platform.os }} diff --git a/.github/workflows/gateway-release-pr.yml b/.github/workflows/gateway-release-pr.yml index 34dbaa791..13412becf 100644 --- a/.github/workflows/gateway-release-pr.yml +++ b/.github/workflows/gateway-release-pr.yml @@ -29,7 +29,7 @@ jobs: - name: Update gateway version run: | VERSION="${{ inputs.version }}" - sed -i "s/^version = .*/version = \"${VERSION}\"/" gateway/Cargo.toml + sed -i "s/^version = .*/version = \"${VERSION}\"/" crates/openab-gateway/Cargo.toml echo "::notice::Gateway release version: ${VERSION}" - name: Create release PR diff --git a/src/main.rs b/src/main.rs index 7902a5316..f28a1f3bd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -371,6 +371,26 @@ async fn main() -> anyhow::Result<()> { }; // Spawn cron scheduler (background task) + // Spawn embedded webhook server when gateway adapters are compiled in (unified mode). + // In unified mode, platform webhooks hit this axum server directly → Dispatcher.submit(), + // bypassing the WebSocket hop of the two-process model. + #[cfg(any( + feature = "telegram", + feature = "line", + feature = "feishu", + feature = "googlechat", + feature = "wecom", + feature = "teams", + ))] + let _unified_handle = { + // TODO(Phase 1): Wire each compiled-in adapter's webhook handler to axum routes + // and call Dispatcher.submit() directly instead of going through WS gateway. + // For now, the feature compiles the gateway crate (making the code available) + // but the full runtime integration (axum server, route registration, direct dispatch) + // will be completed in a follow-up PR. 
+ None::<tokio::task::JoinHandle<()>> + }; + let usercron_path = if cfg.cron.usercron_enabled { cfg.cron.usercron_path.as_ref().map(|p| { let path = std::path::PathBuf::from(p); From 3194a88c0247b92ff9a7842dc63e2bd80873d76d Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: Thu, 18 Jun 2026 20:23:10 +0000 Subject: [PATCH 05/20] fix: update Cargo.lock + clean up CI gateway job - Regenerate Cargo.lock for workspace (adds openab-core, openab-gateway) - Remove stale gateway CI job (gateway/ no longer exists) - Update check job to use --workspace (covers all crates) - Remove gateway from path filters and change detection --- .github/workflows/ci.yml | 32 +- Cargo.lock | 922 +++++++++++++++++++++++---------------- 2 files changed, 560 insertions(+), 394 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fbd719ceb..9b7a7d337 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,7 +5,6 @@ on: paths: - "src/**" - "crates/**" - - "gateway/**" - "operator/**" - "Cargo.toml" - "Cargo.lock" @@ -19,7 +18,6 @@ jobs: runs-on: ubuntu-latest outputs: core: ${{ steps.filter.outputs.core }} - gateway: ${{ steps.filter.outputs.gateway }} operator: ${{ steps.filter.outputs.operator }} steps: - uses: actions/checkout@v6 @@ -31,7 +29,6 @@ jobs: HEAD=${{ github.event.pull_request.head.sha }} CHANGED=$(git diff --name-only "$BASE" "$HEAD") echo "core=$(echo "$CHANGED" | grep -qE '^(src/|crates/|Cargo\.(toml|lock))' && echo true || echo false)" >> "$GITHUB_OUTPUT" - echo "gateway=$(echo "$CHANGED" | grep -q '^gateway/' && echo true || echo false)" >> "$GITHUB_OUTPUT" echo "operator=$(echo "$CHANGED" | grep -q '^operator/' && echo true || echo false)" >> "$GITHUB_OUTPUT" check: @@ -45,33 +42,14 @@ jobs: components: clippy - uses: Swatinem/rust-cache@v2 - name: cargo check - run: cargo check + run: cargo check --workspace - name: cargo clippy - run: cargo clippy -- -D warnings + run: cargo clippy --workspace -- -D warnings - name: cargo test - run: cargo test + run:
cargo test --workspace - gateway: - needs: changes - if: needs.changes.outputs.gateway == 'true' - runs-on: ubuntu-latest - defaults: - run: - working-directory: gateway - steps: - - uses: actions/checkout@v6 - - uses: dtolnay/rust-toolchain@stable - with: - components: clippy - - uses: Swatinem/rust-cache@v2 - with: - workspaces: gateway - - name: cargo check - run: cargo check - - name: cargo clippy - run: cargo clippy -- -D warnings - - name: cargo test - run: cargo test + # gateway tests are now covered by `cargo test --workspace` in the check job above + # (openab-gateway is a workspace member in crates/openab-gateway/) operator: needs: changes diff --git a/Cargo.lock b/Cargo.lock index 9001c6e9c..626ada54e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,17 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures 0.2.17", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -100,6 +111,16 @@ dependencies = [ "serde", ] +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -119,9 +140,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" 
[[package]] name = "aws-config" @@ -145,7 +166,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.4.0", + "http 1.4.2", "sha1", "time", "tokio", @@ -190,9 +211,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.7.4" +version = "1.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ed8e8c52d2dc2390ad9f15647fe663f71e9780b4262c190fbb823a32721566" +checksum = "6c9b9de216a988dd54b754a82a7660cfe14cee4f6782ae4524470972fa0ccb39" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -205,7 +226,7 @@ dependencies = [ "bytes", "bytes-utils", "fastrand", - "http 1.4.0", + "http 1.4.2", "http-body 1.0.1", "percent-encoding", "pin-project-lite", @@ -215,9 +236,9 @@ dependencies = [ [[package]] name = "aws-sdk-secretsmanager" -version = "1.107.0" +version = "1.108.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63da8ec2dca98a68d8bcba971abae5f06e2c9c0017f43097d1ff92cff96adc54" +checksum = "cd24cfd47bda71881c399a7cc4850d5a61727246163d5cd1a7c0b48ef19983cb" dependencies = [ "arc-swap", "aws-credential-types", @@ -233,16 +254,16 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.2", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sso" -version = "1.101.0" +version = "1.102.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b647baea49ff551960b904f905681e9b4765a6c4ea08631e89dc52d8bd3f5896" +checksum = "8c82b3ac19f1431854f7ace3a7531674633e286bfdde21976893bfee36fd493b" dependencies = [ "arc-swap", "aws-credential-types", @@ -258,16 +279,16 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.2", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" -version = "1.103.0" +version = "1.104.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ae401c65ff288aa7873117fe535cd32b7b1bb0bc43751d28901a1d5f20636b9" +checksum = 
"321000d2b4c5519ee573f73167f612efd7329322d9b26969ad1979f0427f1913" dependencies = [ "arc-swap", "aws-credential-types", @@ -283,16 +304,16 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.2", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.106.0" +version = "1.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c80de7bb7d03e9ca8c9fd7b489f20f3948d3f3be91a7953591347d238115408" +checksum = "3d0d328ba962af23ecfa3c9f23b98d3d35e325fa218d7f13d17a6bf522f8a560" dependencies = [ "arc-swap", "aws-credential-types", @@ -309,7 +330,7 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.2", "regex-lite", "tracing", ] @@ -327,9 +348,9 @@ dependencies = [ "bytes", "form_urlencoded", "hex", - "hmac", + "hmac 0.13.0", "http 0.2.12", - "http 1.4.0", + "http 1.4.2", "percent-encoding", "sha2 0.11.0", "time", @@ -359,7 +380,7 @@ dependencies = [ "bytes-utils", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.2", "http-body 1.0.1", "http-body-util", "percent-encoding", @@ -378,18 +399,18 @@ dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", "h2 0.3.27", - "h2 0.4.14", + "h2 0.4.15", "http 0.2.12", - "http 1.4.0", + "http 1.4.2", "http-body 0.4.6", "hyper 0.14.32", - "hyper 1.9.0", + "hyper 1.10.1", "hyper-rustls 0.24.2", "hyper-rustls 0.27.9", "hyper-util", "pin-project-lite", "rustls 0.21.12", - "rustls 0.23.38", + "rustls 0.23.40", "rustls-native-certs", "rustls-pki-types", "tokio", @@ -444,7 +465,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.2", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -465,7 +486,7 @@ dependencies = [ "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.4.0", + "http 1.4.2", "pin-project-lite", "tokio", "tracing", @@ -491,21 +512,21 @@ checksum = "7442cb268338f0eb8278140a107c046756aa01093d8ef5e99628d34ae09c94f5" dependencies = [ "aws-smithy-runtime-api", 
"aws-smithy-types", - "http 1.4.0", + "http 1.4.2", ] [[package]] name = "aws-smithy-types" -version = "1.4.9" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53f93074121a1be41317b9aa607143ae17900631f7f59a99f2b905d519d6783b" +checksum = "32b42fcf341259d85ca10fac9a2f6448a8ec691c6955a18e45bc3b71a85fab85" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", "http 0.2.12", - "http 1.4.0", + "http 1.4.2", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -544,6 +565,61 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" +dependencies = [ + "axum-core", + "base64", + "bytes", + "form_urlencoded", + "futures-util", + "http 1.4.2", + "http-body 1.0.1", + "http-body-util", + "hyper 1.10.1", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sha1", + "sync_wrapper", + "tokio", + "tokio-tungstenite 0.29.0", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http 1.4.2", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "base64" version = "0.22.1" @@ -562,9 +638,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.11.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" +checksum = 
"b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" [[package]] name = "block-buffer" @@ -577,18 +653,27 @@ dependencies = [ [[package]] name = "block-buffer" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdd35008169921d80bc60d3d0ab416eecb028c4cd653352907921d95084790be" +checksum = "d2f6c7dbe95a6ed67ad9f18e57daf93a2f034c524b99fd2b76d18fdfeb6660aa" dependencies = [ "hybrid-array", ] +[[package]] +name = "block-padding" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" [[package]] name = "bytemuck" @@ -610,9 +695,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.11.1" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +checksum = "8ae3f5d315924270530207e2a68396c3cc547f6dca3fbdca317cfb1a51edb593" [[package]] name = "bytes-utils" @@ -624,11 +709,20 @@ dependencies = [ "either", ] +[[package]] +name = "cbc" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" +dependencies = [ + "cipher", +] + [[package]] name = "cc" -version = "1.2.60" +version = "1.2.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20" +checksum = 
"dad887fd958be91b5098c0248def011f4523ab786cd411be668777e55063501f" dependencies = [ "find-msvc-tools", "jobserver", @@ -650,9 +744,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.44" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +checksum = "1aa79e62e7697b8e29b513a68abacf485adcd1fe8284a4316c5ae868e6633327" dependencies = [ "iana-time-zone", "js-sys", @@ -672,11 +766,21 @@ dependencies = [ "phf 0.12.1", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common 0.1.7", + "inout", +] + [[package]] name = "clap" -version = "4.6.0" +version = "4.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" +checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51" dependencies = [ "clap_builder", "clap_derive", @@ -696,9 +800,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.6.0" +version = "4.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" +checksum = "f2ce8604710f6733aa641a2b3731eaa1e8b3d9973d5e3565da11800813f997a9" dependencies = [ "heck", "proc-macro2", @@ -844,9 +948,27 @@ dependencies = [ [[package]] name = "data-encoding" -version = "2.10.0" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4ae5f15dda3c708c0ade84bfee31ccab44a3da4f88015ed22f63732abe300c8" + +[[package]] +name = "deadpool" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b" +dependencies = [ + "deadpool-runtime", + "lazy_static", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" [[package]] name = "deranged" @@ -854,7 +976,6 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" dependencies = [ - "powerfmt", "serde_core", ] @@ -866,6 +987,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer 0.10.4", "crypto-common 0.1.7", + "subtle", ] [[package]] @@ -874,7 +996,7 @@ version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1dd6dbb5841937940781866fa1281a1ff7bd3bf827091440879f9994983d5c2" dependencies = [ - "block-buffer 0.12.0", + "block-buffer 0.12.1", "const-oid", "crypto-common 0.2.2", "ctutils", @@ -882,9 +1004,9 @@ dependencies = [ [[package]] name = "displaydoc" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +checksum = "1ac70aa55017e108007fbaf5aa0f54b021c98f92ff8af59d42eda9da96e3dd4f" dependencies = [ "proc-macro2", "quote", @@ -956,12 +1078,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foldhash" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" - [[package]] name = "form_urlencoded" version = "1.2.2" @@ -985,6 
+1101,7 @@ checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -1007,6 +1124,17 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" +[[package]] +name = "futures-executor" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.32" @@ -1092,15 +1220,13 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +checksum = "300e883d756b2e4ec94e02791f39b04b522276138852cfc41d9fb7e904106099" dependencies = [ "cfg-if", "libc", "r-efi 6.0.0", - "wasip2", - "wasip3", ] [[package]] @@ -1134,16 +1260,16 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "171fefbc92fe4a4de27e0698d6a5b392d6a0e333506bc49133760b3bcf948733" +checksum = "6cb093c84e8bd9b188d4c4a8cb6579fc016968d14c99882163cd3ff402a4f155" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.4.0", + "http 1.4.2", "indexmap", "slab", "tokio", @@ -1159,18 +1285,9 @@ checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" [[package]] name = "hashbrown" -version = "0.15.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" -dependencies = [ - "foldhash", -] - -[[package]] -name = "hashbrown" -version = 
"0.17.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" [[package]] name = "heck" @@ -1178,12 +1295,27 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "hex" version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest 0.10.7", +] + [[package]] name = "hmac" version = "0.13.0" @@ -1206,9 +1338,9 @@ dependencies = [ [[package]] name = "http" -version = "1.4.0" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +checksum = "6970f50e31d6fc17d3fa27329444bfa74e196cf62e95052a3f6fee181dba6425" dependencies = [ "bytes", "itoa", @@ -1232,7 +1364,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.2", ] [[package]] @@ -1243,7 +1375,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.4.0", + "http 1.4.2", "http-body 1.0.1", "pin-project-lite", ] @@ -1262,9 +1394,9 @@ checksum = 
"df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hybrid-array" -version = "0.4.10" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3944cf8cf766b40e2a1a333ee5e9b563f854d5fa49d6a8ca2764e97c6eddb214" +checksum = "9155a582abd142abc056962c29e3ce5ff2ad5469f4246b537ed42c5deba857da" dependencies = [ "typenum", ] @@ -1295,18 +1427,19 @@ dependencies = [ [[package]] name = "hyper" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +checksum = "55281c53a1894c864990125767da440a4e630446785086f52523b20033b74498" dependencies = [ "atomic-waker", "bytes", "futures-channel", "futures-core", - "h2 0.4.14", - "http 1.4.0", + "h2 0.4.15", + "http 1.4.2", "http-body 1.0.1", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -1335,15 +1468,15 @@ version = "0.27.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" dependencies = [ - "http 1.4.0", - "hyper 1.9.0", + "http 1.4.2", + "hyper 1.10.1", "hyper-util", - "rustls 0.23.38", + "rustls 0.23.40", "rustls-native-certs", "tokio", "tokio-rustls 0.26.4", "tower-service", - "webpki-roots 1.0.6", + "webpki-roots 1.0.8", ] [[package]] @@ -1356,14 +1489,14 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.4.0", + "http 1.4.2", "http-body 1.0.1", - "hyper 1.9.0", + "hyper 1.10.1", "ipnet", "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.3", + "socket2 0.6.4", "tokio", "tower-service", "tracing", @@ -1475,12 +1608,6 @@ dependencies = [ "zerovec", ] -[[package]] -name = "id-arena" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" - [[package]] name = "idna" version = 
"1.1.0" @@ -1494,9 +1621,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +checksum = "cb68373c0d6620ef8105e855e7745e18b0d00d3bdb07fb532e434244cdb9a714" dependencies = [ "icu_normalizer", "icu_properties", @@ -1537,26 +1664,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.17.0", - "serde", - "serde_core", + "hashbrown 0.17.1", ] [[package]] -name = "ipnet" -version = "2.12.0" +name = "inout" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "block-padding", + "generic-array", +] [[package]] -name = "iri-string" -version = "0.7.12" +name = "ipnet" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" -dependencies = [ - "memchr", - "serde", -] +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "is_terminal_polyfill" @@ -1564,6 +1689,15 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.18" @@ -1582,33 +1716,41 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.95" +version = 
"0.3.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2964e92d1d9dc3364cae4d718d93f227e3abb088e747d92e0395bfdedf1c12ca" +checksum = "03d04c30968dffe80775bd4d7fb676131cd04a1fb46d2686dbffbaec2d9dfd31" dependencies = [ "cfg-if", "futures-util", - "once_cell", "wasm-bindgen", ] [[package]] -name = "lazy_static" -version = "1.5.0" +name = "jsonwebtoken" +version = "9.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" +dependencies = [ + "base64", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] [[package]] -name = "leb128fmt" -version = "0.1.0" +name = "lazy_static" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.185" +version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" [[package]] name = "linux-raw-sys" @@ -1633,9 +1775,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "953f07c43838f8e6f9758cab68bf5bed85465e7587ebe0b823f1bcd81978ad3a" [[package]] name = "lru-slab" @@ -1652,11 +1794,17 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + 
[[package]] name = "memchr" -version = "2.8.0" +version = "2.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +checksum = "88904434abc2901f197fe8cc55f0445e7ded921dba5911dad2e2b39b48e663c4" [[package]] name = "mime" @@ -1686,9 +1834,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +checksum = "02bd0af71c67b473010cbbc60715ee815645a4dc942899111f494b4b737d6fda" dependencies = [ "libc", "wasi", @@ -1714,11 +1862,21 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-conv" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" +checksum = "521739c6d2bac4aa25192232afe6841231376b2b26d4d9fae5ecf8ca5772e441" [[package]] name = "num-integer" @@ -1738,6 +1896,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.21.4" @@ -1753,6 +1921,20 @@ checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" [[package]] name = "openab" version = "0.8.5" +dependencies = [ + "anyhow", + "clap", + "openab-core", + "openab-gateway", + "serenity", + "tokio", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "openab-core" +version 
= "0.8.5" dependencies = [ "anyhow", "async-trait", @@ -1767,11 +1949,11 @@ dependencies = [ "cron", "futures-util", "hex", - "http 1.4.0", + "http 1.4.2", "image", "libc", "pulldown-cmark", - "rand 0.8.5", + "rand 0.8.6", "regex", "reqwest", "rpassword", @@ -1783,7 +1965,7 @@ dependencies = [ "tempfile", "tokio", "tokio-rustls 0.25.0", - "tokio-tungstenite", + "tokio-tungstenite 0.21.0", "toml", "toml_edit", "tracing", @@ -1794,6 +1976,37 @@ dependencies = [ "webpki-roots 0.26.11", ] +[[package]] +name = "openab-gateway" +version = "0.5.4" +dependencies = [ + "aes", + "anyhow", + "axum", + "base64", + "cbc", + "chrono", + "futures-util", + "hmac 0.12.1", + "image", + "jsonwebtoken", + "prost", + "quick-xml", + "reqwest", + "serde", + "serde_json", + "sha1", + "sha2 0.10.9", + "subtle", + "tokio", + "tokio-tungstenite 0.21.0", + "tracing", + "tracing-subscriber", + "urlencoding", + "uuid", + "wiremock", +] + [[package]] name = "openssl-probe" version = "0.2.1" @@ -1829,6 +2042,16 @@ dependencies = [ "windows-link", ] +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64", + "serde_core", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -1861,7 +2084,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared 0.11.3", - "rand 0.8.5", + "rand 0.8.6", ] [[package]] @@ -1945,29 +2168,42 @@ dependencies = [ ] [[package]] -name = "prettyplease" -version = "0.2.37" +name = "proc-macro2" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ - "proc-macro2", - "syn", + 
"unicode-ident", ] [[package]] -name = "proc-macro2" -version = "1.0.106" +name = "prost" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ - "unicode-ident", + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", ] [[package]] name = "pulldown-cmark" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad" +checksum = "e9f068eba8e7071c5f9511831b44f32c740d5adf574e990f946ddb53db2f314e" dependencies = [ "bitflags", "memchr", @@ -1976,9 +2212,9 @@ dependencies = [ [[package]] name = "pxfm" -version = "0.1.28" +version = "0.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5a041e753da8b807c9255f28de81879c78c876392ff2469cde94799b2896b9d" +checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f" [[package]] name = "quick-error" @@ -1986,6 +2222,15 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" +[[package]] +name = "quick-xml" +version = "0.37.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "331e97a1af0bf59823e6eadffe373d7b27f485be8748f71471c662c1f269b7fb" +dependencies = [ + "memchr", +] + [[package]] name = "quinn" version = "0.11.9" @@ -1998,8 +2243,8 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.38", - "socket2 0.6.3", + "rustls 0.23.40", + 
"socket2 0.6.4", "thiserror 2.0.18", "tokio", "tracing", @@ -2018,7 +2263,7 @@ dependencies = [ "rand 0.9.4", "ring", "rustc-hash", - "rustls 0.23.38", + "rustls 0.23.40", "rustls-pki-types", "slab", "thiserror 2.0.18", @@ -2036,7 +2281,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.6.3", + "socket2 0.6.4", "tracing", "windows-sys 0.60.2", ] @@ -2064,9 +2309,9 @@ checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" [[package]] name = "rand" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" dependencies = [ "libc", "rand_chacha 0.3.1", @@ -2132,9 +2377,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.12.3" +version = "1.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +checksum = "f1292b7759ae1cb9ec195452d1390a074f0cd8541ab7a5a8c31cd6db45d4a6ba" dependencies = [ "aho-corasick", "memchr", @@ -2161,9 +2406,9 @@ checksum = "cab834c73d247e67f4fae452806d17d3c7501756d98c8808d7c9c7aa7d18f973" [[package]] name = "regex-syntax" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +checksum = "d6f6ff9a378485b298a5286656da665ba74413d36db0979633275d2e708145d4" [[package]] name = "reqwest" @@ -2176,10 +2421,10 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.2", "http-body 1.0.1", "http-body-util", - "hyper 1.9.0", + "hyper 1.10.1", "hyper-rustls 0.27.9", "hyper-util", "js-sys", @@ -2188,7 +2433,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.38", + "rustls 0.23.40", "rustls-pki-types", "serde", 
"serde_json", @@ -2205,7 +2450,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots 1.0.6", + "webpki-roots 1.0.8", ] [[package]] @@ -2224,20 +2469,20 @@ dependencies = [ [[package]] name = "rpassword" -version = "7.4.0" +version = "7.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66d4c8b64f049c6721ec8ccec37ddfc3d641c4a7fca57e8f2a89de509c73df39" +checksum = "2da316a15f47e3d053de9cb2c439650bd8fa4aaeb9365f2e5f27f492ff73c196" dependencies = [ "libc", "rtoolbox", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] name = "rtoolbox" -version = "0.0.4" +version = "0.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327b72899159dfae8060c51a1f6aebe955245bcd9cc4997eed0f623caea022e4" +checksum = "50a0e551c1e27e1731aba276dbeaeac73f53c7cd34d1bda485d02bd1e0f36844" dependencies = [ "libc", "windows-sys 0.59.0", @@ -2299,15 +2544,15 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.38" +version = "0.23.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69f9466fb2c14ea04357e91413efb882e2a6d4a406e625449bc0a5d360d53a21" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" dependencies = [ "aws-lc-rs", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.103.12", + "rustls-webpki 0.103.13", "subtle", "zeroize", ] @@ -2326,9 +2571,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.14.0" +version = "1.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" dependencies = [ "web-time", "zeroize", @@ -2357,9 +2602,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.12" +version = "0.103.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"8279bb85272c9f10811ae6a6c547ff594d6a7f3c6c6b02ee9726d1d0dcfcdd06" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" dependencies = [ "aws-lc-rs", "ring", @@ -2484,9 +2729,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", @@ -2495,6 +2740,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_spanned" version = "0.6.9" @@ -2541,7 +2797,7 @@ dependencies = [ "serde_json", "time", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.21.0", "tracing", "typemap_rev", "url", @@ -2591,9 +2847,9 @@ dependencies = [ [[package]] name = "shlex" -version = "1.3.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" [[package]] name = "signal-hook-registry" @@ -2611,11 +2867,23 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" +[[package]] +name = "simple_asn1" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d585997b0ac10be3c5ee635f1bab02d512760d14b7c468801ac8a01d9ae5f1d" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror 2.0.18", + "time", +] + [[package]] name = "siphasher" -version = "1.0.2" +version = "1.0.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" +checksum = "8ee5873ec9cce0195efcb7a4e9507a04cd49aec9c83d0389df45b1ef7ba2e649" [[package]] name = "slab" @@ -2625,9 +2893,9 @@ checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" [[package]] name = "smallvec" -version = "1.15.1" +version = "1.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +checksum = "8ed6a63f02c8539c91a8685a86f4099661ba3da017932f6ebbea6de3f0fa7c90" [[package]] name = "socket2" @@ -2641,9 +2909,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" dependencies = [ "libc", "windows-sys 0.61.2", @@ -2669,9 +2937,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.117" +version = "2.0.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +checksum = "1b9ae57f904213ebb649ce6895b8a66c66f0203b9319718f69a5612a065b1422" dependencies = [ "proc-macro2", "quote", @@ -2705,7 +2973,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.4.2", + "getrandom 0.4.3", "once_cell", "rustix", "windows-sys 0.61.2", @@ -2762,12 +3030,11 @@ dependencies = [ [[package]] name = "time" -version = "0.3.47" +version = "0.3.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +checksum = "711a53c2d47bbd818258c498c8dbfe186a2526c631495cfe7e078567f86b8469" dependencies = [ "deranged", - "itoa", "num-conv", "powerfmt", "serde_core", @@ -2777,15 +3044,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" +checksum = "9e1c906769ad99c88eaa54e728060edef082f8e358ff32030cb7c7d315e81109" [[package]] name = "time-macros" -version = "0.2.27" +version = "0.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +checksum = "71c652a3727a9cbb9a02f707f530b618ce00d0ccd762009c8c23bd191df3c17d" dependencies = [ "num-conv", "time-core", @@ -2818,9 +3085,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.52.0" +version = "1.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a91135f59b1cbf38c91e73cf3386fca9bb77915c45ce2771460c9d92f0f3d776" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" dependencies = [ "bytes", "libc", @@ -2828,7 +3095,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.6.3", + "socket2 0.6.4", "tokio-macros", "windows-sys 0.61.2", ] @@ -2871,7 +3138,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" dependencies = [ - "rustls 0.23.38", + "rustls 0.23.40", "tokio", ] @@ -2887,10 +3154,22 @@ dependencies = [ "rustls-pki-types", "tokio", "tokio-rustls 0.25.0", - "tungstenite", + "tungstenite 0.21.0", "webpki-roots 0.26.11", ] +[[package]] +name = "tokio-tungstenite" +version = "0.29.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f72a05e828585856dacd553fba484c242c46e391fb0e58917c942ee9202915c" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.29.0", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -2958,24 +3237,25 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] name = "tower-http" -version = "0.6.8" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +checksum = "4cfcf7e2740e6fc6d4d688b4ef00650406bb94adf4731e43c096c3a19fe40840" dependencies = [ "bitflags", "bytes", "futures-util", - "http 1.4.0", + "http 1.4.2", "http-body 1.0.1", - "iri-string", "pin-project-lite", "tower", "tower-layer", "tower-service", + "url", ] [[package]] @@ -3080,10 +3360,10 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.4.0", + "http 1.4.2", "httparse", "log", - "rand 0.8.5", + "rand 0.8.6", "rustls 0.22.4", "rustls-pki-types", "sha1", @@ -3092,6 +3372,22 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c01152af293afb9c7c2a57e4b559c5620b421f6d133261c60dd2d0cdb38e6b8" +dependencies = [ + "bytes", + "data-encoding", + "http 1.4.2", + "httparse", + "log", + "rand 0.9.4", + "sha1", + "thiserror 2.0.18", +] + [[package]] name = "typemap_rev" version = "0.3.0" @@ -3100,9 +3396,9 @@ checksum = "74b08b0c1257381af16a5c3605254d529d3e7e109f3c62befc5d168968192998" [[package]] name = "typenum" -version = "1.19.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" [[package]] name = "unicase" @@ -3122,12 +3418,6 @@ version = "0.2.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" -[[package]] -name = "unicode-xid" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" - [[package]] name = "untrusted" version = "0.9.0" @@ -3173,11 +3463,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.23.0" +version = "1.23.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9" +checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7" dependencies = [ - "getrandom 0.4.2", + "getrandom 0.4.3", "js-sys", "wasm-bindgen", ] @@ -3217,27 +3507,18 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasip2" -version = "1.0.2+wasi-0.2.9" +version = "1.0.4+wasi-0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" -dependencies = [ - "wit-bindgen", -] - -[[package]] -name = "wasip3" -version = "0.4.0+wasi-0.3.0-rc-2026-01-06" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +checksum = "b67efb37e106e55ce722a510d6b5f9c17f083e5fc79afc2badeb12cc313d9487" dependencies = [ "wit-bindgen", ] [[package]] name = "wasm-bindgen" -version = "0.2.118" +version = "0.2.125" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" +checksum = "8ddb3f79143bced6de84270411622a2699cee572fc0875aeaf1e7867cf9fca1a" dependencies = [ "cfg-if", "once_cell", @@ -3248,9 +3529,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = 
"0.4.68" +version = "0.4.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f371d383f2fb139252e0bfac3b81b265689bf45b6874af544ffa4c975ac1ebf8" +checksum = "503b14d284f2c8dac03b819967e155ea753f573586193b2b2c95990cb5d69280" dependencies = [ "js-sys", "wasm-bindgen", @@ -3258,9 +3539,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.118" +version = "0.2.125" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" +checksum = "4e21a184b13fb19e157296e2c46056aec9092264fab83e4ba59e68c61b323c3d" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3268,9 +3549,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.118" +version = "0.2.125" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" +checksum = "fecefd9c35bd935a20fc3fc344b5f29138961e4f47fb03297d88f2587afb5ebd" dependencies = [ "bumpalo", "proc-macro2", @@ -3281,35 +3562,13 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.118" +version = "0.2.125" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" +checksum = "23939e44bb9a5d7576fa2b563dc2e136628f1224e88a8deed09e04858b77871f" dependencies = [ "unicode-ident", ] -[[package]] -name = "wasm-encoder" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" -dependencies = [ - "leb128fmt", - "wasmparser", -] - -[[package]] -name = "wasm-metadata" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" -dependencies = [ - "anyhow", - "indexmap", - "wasm-encoder", - 
"wasmparser", -] - [[package]] name = "wasm-streams" version = "0.4.2" @@ -3323,23 +3582,11 @@ dependencies = [ "web-sys", ] -[[package]] -name = "wasmparser" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" -dependencies = [ - "bitflags", - "hashbrown 0.15.5", - "indexmap", - "semver", -] - [[package]] name = "web-sys" -version = "0.3.95" +version = "0.3.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2dfbb17949fa2088e5d39408c48368947b86f7834484e87b73de55bc14d97d" +checksum = "a6430a72df5eb332242960fe84b3002a241163998241eb596d4f739b9757061d" dependencies = [ "js-sys", "wasm-bindgen", @@ -3361,14 +3608,14 @@ version = "0.26.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" dependencies = [ - "webpki-roots 1.0.6", + "webpki-roots 1.0.8", ] [[package]] name = "webpki-roots" -version = "1.0.6" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cfaf3c063993ff62e73cb4311efde4db1efb31ab78a3e5c457939ad5cc0bed" +checksum = "bf85cb06032201fa7c6f829d7db5a7e5aa45bcc0655327713065f6f0576731bf" dependencies = [ "rustls-pki-types", ] @@ -3613,92 +3860,33 @@ dependencies = [ ] [[package]] -name = "wit-bindgen" -version = "0.51.0" +name = "wiremock" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031" dependencies = [ - "wit-bindgen-rust-macro", -] - -[[package]] -name = "wit-bindgen-core" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" -dependencies = [ - "anyhow", - "heck", - 
"wit-parser", -] - -[[package]] -name = "wit-bindgen-rust" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" -dependencies = [ - "anyhow", - "heck", - "indexmap", - "prettyplease", - "syn", - "wasm-metadata", - "wit-bindgen-core", - "wit-component", -] - -[[package]] -name = "wit-bindgen-rust-macro" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" -dependencies = [ - "anyhow", - "prettyplease", - "proc-macro2", - "quote", - "syn", - "wit-bindgen-core", - "wit-bindgen-rust", -] - -[[package]] -name = "wit-component" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" -dependencies = [ - "anyhow", - "bitflags", - "indexmap", + "assert-json-diff", + "base64", + "deadpool", + "futures", + "http 1.4.2", + "http-body-util", + "hyper 1.10.1", + "hyper-util", "log", + "once_cell", + "regex", "serde", - "serde_derive", "serde_json", - "wasm-encoder", - "wasm-metadata", - "wasmparser", - "wit-parser", + "tokio", + "url", ] [[package]] -name = "wit-parser" -version = "0.244.0" +name = "wit-bindgen" +version = "0.57.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" -dependencies = [ - "anyhow", - "id-arena", - "indexmap", - "log", - "semver", - "serde", - "serde_derive", - "serde_json", - "unicode-xid", - "wasmparser", -] +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" [[package]] name = "writeable" @@ -3714,9 +3902,9 @@ checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] name = "yoke" -version = "0.8.2" +version = "0.8.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +checksum = "709fe23a0424b6a435d82152b1bd3fdfb0833487d5fa90d05d42762a9891fef5" dependencies = [ "stable_deref_trait", "yoke-derive", @@ -3737,18 +3925,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.48" +version = "0.8.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +checksum = "ce1022995ff5ff5d841ad7d994facc23098cd40152f2c1d11cd607c6f530653f" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.48" +version = "0.8.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +checksum = "1ae7f38b72ec2a254e2b87ef277cf2cd4fb97cbebf944faa6f33354da0867930" dependencies = [ "proc-macro2", "quote", @@ -3757,9 +3945,9 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" +checksum = "0ec05a11813ea801ff6d75110ad09cd0824ddba17dfe17128ea0d5f68e6c5272" dependencies = [ "zerofrom-derive", ] @@ -3778,9 +3966,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.8.2" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +checksum = "e13c156562582aa81c60cb29407084cdb54c4164760106ab78e6c5b0858cf64e" [[package]] name = "zerotrie" From 1641b6f13682cb5d97be075de3fb65c662a463cc Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: Thu, 18 Jun 2026 20:25:18 +0000 Subject: [PATCH 06/20] fix: add COPY crates/ to all Dockerfiles for workspace build All variant Dockerfiles need crates/ copied before the build step, since 
src/main.rs now imports from openab-core (a workspace member). --- Dockerfile.agentcore | 1 + Dockerfile.antigravity | 1 + Dockerfile.claude | 1 + Dockerfile.codex | 1 + Dockerfile.copilot | 1 + Dockerfile.cursor | 1 + Dockerfile.gemini | 1 + Dockerfile.grok | 1 + Dockerfile.hermes | 1 + Dockerfile.mimocode | 1 + Dockerfile.native | 1 + Dockerfile.opencode | 1 + Dockerfile.pi | 1 + 13 files changed, 13 insertions(+) diff --git a/Dockerfile.agentcore b/Dockerfile.agentcore index 815d0ce88..89060f805 100644 --- a/Dockerfile.agentcore +++ b/Dockerfile.agentcore @@ -6,6 +6,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release --features agentcore && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release --features agentcore diff --git a/Dockerfile.antigravity b/Dockerfile.antigravity index bc16c1e7c..112da9abf 100644 --- a/Dockerfile.antigravity +++ b/Dockerfile.antigravity @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.claude b/Dockerfile.claude index c6657252d..cb646076d 100644 --- a/Dockerfile.claude +++ b/Dockerfile.claude @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.codex b/Dockerfile.codex index f6e2f2af6..49ee7fae0 100644 --- a/Dockerfile.codex +++ b/Dockerfile.codex @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo 
build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.copilot b/Dockerfile.copilot index 23e903898..67ceaf9f7 100644 --- a/Dockerfile.copilot +++ b/Dockerfile.copilot @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.cursor b/Dockerfile.cursor index 2c408b3b6..339d73f22 100644 --- a/Dockerfile.cursor +++ b/Dockerfile.cursor @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.gemini b/Dockerfile.gemini index c7221608e..6387ca1e2 100644 --- a/Dockerfile.gemini +++ b/Dockerfile.gemini @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.grok b/Dockerfile.grok index bdd8a6c1a..bd8dc48d4 100644 --- a/Dockerfile.grok +++ b/Dockerfile.grok @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.hermes b/Dockerfile.hermes index 66408f902..18cf277a1 100644 --- a/Dockerfile.hermes +++ b/Dockerfile.hermes @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo 
build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.mimocode b/Dockerfile.mimocode index 9aadc1de0..d637a51ff 100644 --- a/Dockerfile.mimocode +++ b/Dockerfile.mimocode @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.native b/Dockerfile.native index 058b91c76..1d1d6c562 100644 --- a/Dockerfile.native +++ b/Dockerfile.native @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ COPY openab-agent/ openab-agent/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ diff --git a/Dockerfile.opencode b/Dockerfile.opencode index 69d8a8909..244b5e3f1 100644 --- a/Dockerfile.opencode +++ b/Dockerfile.opencode @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.pi b/Dockerfile.pi index f87fc7e8b..697833c8c 100644 --- a/Dockerfile.pi +++ b/Dockerfile.pi @@ -2,6 +2,7 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/ crates/ RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release From cfd459fcff00562e55ddfbee819b174da70f8fd3 Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: Thu, 18 Jun 2026 20:27:59 +0000 Subject: [PATCH 07/20] fix: optimize Dockerfile layer caching for workspace Use targeted COPY of crate Cargo.toml manifests + dummy lib.rs for the dep-fetch layer. 
Full crate source is only COPYed after, so the dep cache layer only invalidates when Cargo.toml/Cargo.lock change. --- Dockerfile | 9 ++++++++- Dockerfile.antigravity | 9 ++++++++- Dockerfile.claude | 9 ++++++++- Dockerfile.codex | 9 ++++++++- Dockerfile.copilot | 9 ++++++++- Dockerfile.cursor | 9 ++++++++- Dockerfile.gemini | 9 ++++++++- Dockerfile.grok | 9 ++++++++- Dockerfile.hermes | 9 ++++++++- Dockerfile.mimocode | 9 ++++++++- Dockerfile.opencode | 9 ++++++++- Dockerfile.pi | 9 ++++++++- 12 files changed, 96 insertions(+), 12 deletions(-) diff --git a/Dockerfile b/Dockerfile index f3222aa72..7ffd48d5e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,8 +8,15 @@ ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && \ if [ "$BUILD_MODE" = "unified" ]; then \ diff --git a/Dockerfile.antigravity b/Dockerfile.antigravity index 112da9abf..4ed8341bc 100644 --- a/Dockerfile.antigravity +++ b/Dockerfile.antigravity @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf 
src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.claude b/Dockerfile.claude index cb646076d..710171bb2 100644 --- a/Dockerfile.claude +++ b/Dockerfile.claude @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.codex b/Dockerfile.codex index 49ee7fae0..1bcf9ab82 100644 --- a/Dockerfile.codex +++ b/Dockerfile.codex @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.copilot b/Dockerfile.copilot index 67ceaf9f7..3fc76809c 100644 --- 
a/Dockerfile.copilot +++ b/Dockerfile.copilot @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.cursor b/Dockerfile.cursor index 339d73f22..a65bcacda 100644 --- a/Dockerfile.cursor +++ b/Dockerfile.cursor @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.gemini b/Dockerfile.gemini index 6387ca1e2..6a4585588 100644 --- a/Dockerfile.gemini +++ b/Dockerfile.gemini @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src 
crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.grok b/Dockerfile.grok index bd8dc48d4..53ee03c3a 100644 --- a/Dockerfile.grok +++ b/Dockerfile.grok @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.hermes b/Dockerfile.hermes index 18cf277a1..9629ccd29 100644 --- a/Dockerfile.hermes +++ b/Dockerfile.hermes @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > 
src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.mimocode b/Dockerfile.mimocode index d637a51ff..7e5daae5d 100644 --- a/Dockerfile.mimocode +++ b/Dockerfile.mimocode @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.opencode b/Dockerfile.opencode index 244b5e3f1..994e44b59 100644 --- a/Dockerfile.opencode +++ b/Dockerfile.opencode @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release diff --git a/Dockerfile.pi b/Dockerfile.pi index 697833c8c..d089fbae2 100644 --- a/Dockerfile.pi +++ b/Dockerfile.pi @@ -2,8 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY 
Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release From b6df01c5ff265a119e57d48f028a6bd1099e632a Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: Thu, 18 Jun 2026 20:28:32 +0000 Subject: [PATCH 08/20] fix: apply layer cache optimization to agentcore and native Dockerfiles --- Dockerfile.agentcore | 9 ++++++++- Dockerfile.native | 11 +++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/Dockerfile.agentcore b/Dockerfile.agentcore index 89060f805..805538997 100644 --- a/Dockerfile.agentcore +++ b/Dockerfile.agentcore @@ -6,8 +6,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release --features agentcore \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release --features agentcore && rm -rf src COPY src/ src/ RUN touch src/main.rs && cargo build --release --features agentcore diff --git a/Dockerfile.native b/Dockerfile.native index 1d1d6c562..782302867 100644 --- 
a/Dockerfile.native +++ b/Dockerfile.native @@ -2,9 +2,16 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/ crates/ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml COPY openab-agent/ openab-agent/ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && cargo build --release && rm -rf src +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src +COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs && cargo build --release RUN cd openab-agent && cargo build --release From 504b5ea79223bb5ab98275ce94dee1acbe69171d Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: Thu, 18 Jun 2026 20:30:46 +0000 Subject: [PATCH 09/20] fix: add AppState and related types to gateway lib.rs These were previously in gateway/src/main.rs and accessed by adapters via crate::AppState. Now that gateway is a library crate, they need to be in lib.rs for the adapters to compile. --- crates/openab-gateway/src/lib.rs | 43 ++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/crates/openab-gateway/src/lib.rs b/crates/openab-gateway/src/lib.rs index d81de34d0..c67f9fc85 100644 --- a/crates/openab-gateway/src/lib.rs +++ b/crates/openab-gateway/src/lib.rs @@ -2,3 +2,46 @@ pub mod adapters; pub(crate) mod media; pub mod schema; pub mod store; + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Instant; +use tokio::sync::{broadcast, Mutex, Semaphore}; + +// --- Reply token cache for LINE hybrid Reply/Push dispatch --- + +/// Cache entry for LINE reply tokens: (replyToken, insertion_time). 
+pub type ReplyTokenCache = Arc<std::sync::Mutex<HashMap<String, (String, Instant)>>>; + +/// Maximum age (in seconds) before a cached reply token is considered expired. +pub const REPLY_TOKEN_TTL_SECS: u64 = 50; + +/// Maximum number of cached reply tokens. +pub const REPLY_TOKEN_CACHE_MAX: usize = 10_000; + +/// Maximum number of post-ack LINE webhook payloads processed concurrently. +pub const LINE_WEBHOOK_CONCURRENCY_MAX: usize = 8; + +// --- App state (shared across all adapters) --- + +pub struct AppState { + pub telegram_bot_token: Option<String>, + pub telegram_secret_token: Option<String>, + pub telegram_rich_messages: bool, + pub line_channel_secret: Option<String>, + pub line_access_token: Option<String>, + #[cfg(feature = "teams")] + pub teams: Option<adapters::teams::TeamsAdapter>, + pub teams_service_urls: Mutex<HashMap<String, (String, Instant)>>, + #[cfg(feature = "feishu")] + pub feishu: Option<adapters::feishu::FeishuAdapter>, + #[cfg(feature = "googlechat")] + pub google_chat: Option<adapters::googlechat::GoogleChatAdapter>, + #[cfg(feature = "wecom")] + pub wecom: Option<adapters::wecom::WecomAdapter>, + pub ws_token: Option<String>, + pub event_tx: broadcast::Sender<String>, + pub reply_token_cache: ReplyTokenCache, + pub line_webhook_semaphore: Arc<Semaphore>, + pub client: reqwest::Client, +} From 198e19c365489a35e704aa039e8bcb62ffe4cb42 Mon Sep 17 00:00:00 2001 From: chaodufashi Date: Thu, 18 Jun 2026 20:53:47 +0000 Subject: [PATCH 10/20] fix: resolve CI clippy failure + restore gateway standalone binary - Remove unused `use openab_core::stt` import (fixes clippy -D warnings) - Restore gateway standalone binary via [[bin]] in openab-gateway crate - Add Dockerfile.gateway for proper gateway image build - Fix build-gateway.yml to use Dockerfile.gateway instead of BUILD_MODE=unified - Add warn!() log when unified features compiled but runtime not wired --- .github/workflows/build-gateway.yml | 4 +- Dockerfile.gateway | 23 ++ crates/openab-gateway/Cargo.toml | 4 + crates/openab-gateway/src/main.rs | 410 ++++++++++++++++++++++++++++ src/main.rs | 3 +- 5 files changed, 439 insertions(+), 5 deletions(-) create mode 100644 Dockerfile.gateway create mode 100644 crates/openab-gateway/src/main.rs diff --git
a/.github/workflows/build-gateway.yml b/.github/workflows/build-gateway.yml index d5dd62ee7..9c46ffa50 100644 --- a/.github/workflows/build-gateway.yml +++ b/.github/workflows/build-gateway.yml @@ -58,9 +58,7 @@ jobs: uses: docker/build-push-action@v6 with: context: . - file: Dockerfile - build-args: | - BUILD_MODE=unified + file: Dockerfile.gateway platforms: ${{ matrix.platform.os }} outputs: type=image,name=${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=${{ inputs.dry_run != true }} cache-from: type=gha,scope=gateway-${{ matrix.platform.os }} diff --git a/Dockerfile.gateway b/Dockerfile.gateway new file mode 100644 index 000000000..28372f86a --- /dev/null +++ b/Dockerfile.gateway @@ -0,0 +1,23 @@ +# --- Build stage --- +FROM rust:1-bookworm AS builder +WORKDIR /build +COPY Cargo.toml Cargo.lock ./ +COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml +COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml +RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ + && echo 'fn main() {}' > src/main.rs \ + && echo '' > crates/openab-core/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ + && echo '' > crates/openab-gateway/src/lib.rs \ + && cargo build --release -p openab-gateway \ + && rm -rf src crates/openab-core/src crates/openab-gateway/src +COPY crates/ crates/ +COPY src/ src/ +RUN touch crates/openab-gateway/src/main.rs && cargo build --release -p openab-gateway + +# --- Runtime stage --- +FROM debian:bookworm-slim +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates && rm -rf /var/lib/apt/lists/* +COPY --from=builder /build/target/release/openab-gateway /usr/local/bin/openab-gateway +EXPOSE 8080 +ENTRYPOINT ["openab-gateway"] diff --git a/crates/openab-gateway/Cargo.toml b/crates/openab-gateway/Cargo.toml index 3236ffcd1..de8921166 100644 --- a/crates/openab-gateway/Cargo.toml +++ b/crates/openab-gateway/Cargo.toml @@ -4,6 +4,10 @@ version = 
"0.5.4" edition = "2021" license = "MIT" +[[bin]] +name = "openab-gateway" +path = "src/main.rs" + [dependencies] tokio = { version = "1", features = ["full"] } axum = { version = "0.8", features = ["ws"] } diff --git a/crates/openab-gateway/src/main.rs b/crates/openab-gateway/src/main.rs new file mode 100644 index 000000000..77e2f18fd --- /dev/null +++ b/crates/openab-gateway/src/main.rs @@ -0,0 +1,410 @@ +use anyhow::Result; +use axum::{ + extract::State, + response::IntoResponse, + routing::{get, post}, + Router, +}; +use futures_util::{SinkExt, StreamExt}; +use openab_gateway::schema::GatewayReply; +use openab_gateway::{ + adapters, AppState, ReplyTokenCache, LINE_WEBHOOK_CONCURRENCY_MAX, REPLY_TOKEN_TTL_SECS, +}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Instant; +use tokio::sync::{broadcast, Mutex, Semaphore}; +use tracing::{info, warn}; + +// --- WebSocket handler (OAB connects here) --- + +async fn ws_handler( + State(state): State>, + query: axum::extract::Query>, + ws: axum::extract::WebSocketUpgrade, +) -> axum::response::Response { + if let Some(ref expected) = state.ws_token { + let provided = query.get("token").map(|s| s.as_str()); + if provided != Some(expected.as_str()) { + warn!("WebSocket rejected: invalid or missing token"); + return axum::http::StatusCode::UNAUTHORIZED.into_response(); + } + } + ws.on_upgrade(move |socket| handle_oab_connection(state, socket)) +} + +async fn handle_oab_connection(state: Arc, socket: axum::extract::ws::WebSocket) { + use axum::extract::ws::Message; + + let (mut ws_tx, mut ws_rx) = socket.split(); + let mut event_rx = state.event_tx.subscribe(); + + info!("OAB client connected via WebSocket"); + + let send_task = tokio::spawn(async move { + loop { + tokio::select! 
{ + Ok(event_json) = event_rx.recv() => { + if ws_tx.send(Message::Text(event_json.into())).await.is_err() { + break; + } + } + } + } + }); + + let state_for_recv = state.clone(); + let reaction_state: Arc>>> = + Arc::new(Mutex::new(HashMap::new())); + let recv_task = tokio::spawn(async move { + let client = reqwest::Client::new(); + while let Some(Ok(msg)) = ws_rx.next().await { + if let Message::Text(text) = msg { + match serde_json::from_str::(&text) { + Ok(reply) => { + info!( + platform = %reply.platform, + channel = %reply.channel.id, + command = ?reply.command.as_deref(), + "OAB → gateway reply" + ); + match reply.platform.as_str() { + #[cfg(feature = "telegram")] + "telegram" => { + if let Some(ref token) = state_for_recv.telegram_bot_token { + adapters::telegram::handle_reply( + &reply, + token, + &client, + &state_for_recv.event_tx, + &reaction_state, + state_for_recv.telegram_rich_messages, + ) + .await; + } else { + warn!("reply for telegram but adapter not configured"); + } + } + #[cfg(feature = "line")] + "line" => { + if let Some(ref access_token) = state_for_recv.line_access_token { + adapters::line::dispatch_line_reply( + &client, + access_token, + &state_for_recv.reply_token_cache, + &reply, + adapters::line::LINE_API_BASE, + ) + .await; + } else { + warn!("reply for line but adapter not configured"); + } + } + #[cfg(feature = "teams")] + "teams" => { + if let Some(ref teams) = state_for_recv.teams { + adapters::teams::handle_reply( + &reply, + teams, + &state_for_recv.teams_service_urls, + ) + .await; + } else { + warn!("reply for teams but adapter not configured"); + } + } + #[cfg(feature = "feishu")] + "feishu" => { + if let Some(ref feishu) = state_for_recv.feishu { + adapters::feishu::handle_reply( + &reply, + feishu, + &state_for_recv.event_tx, + ) + .await; + } else { + warn!("reply for feishu but adapter not configured"); + } + } + #[cfg(feature = "googlechat")] + "googlechat" => { + if let Some(ref gc) = state_for_recv.google_chat { + 
gc.handle_reply(&reply, &state_for_recv.event_tx).await; + } else { + warn!("reply for googlechat but adapter not configured"); + } + } + #[cfg(feature = "wecom")] + "wecom" => { + if let Some(ref wecom) = state_for_recv.wecom { + wecom.handle_reply(&reply, &state_for_recv.event_tx).await; + } else { + warn!("reply for wecom but adapter not configured"); + } + } + other => warn!(platform = other, "unknown reply platform"), + } + } + Err(e) => warn!("invalid reply from OAB: {e}"), + } + } + } + }); + + tokio::select! { + _ = send_task => {}, + _ = recv_task => {}, + } + info!("OAB client disconnected"); +} + +async fn health() -> &'static str { + "ok" +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .init(); + + let listen_addr = std::env::var("GATEWAY_LISTEN").unwrap_or_else(|_| "0.0.0.0:8080".into()); + let ws_token = std::env::var("GATEWAY_WS_TOKEN").ok(); + + if ws_token.is_none() { + warn!("GATEWAY_WS_TOKEN not set — WebSocket connections are NOT authenticated (insecure)"); + } + + let (event_tx, _) = broadcast::channel::<String>(256); + let reply_token_cache: ReplyTokenCache = Arc::new(std::sync::Mutex::new(HashMap::new())); + + let mut app = Router::new() + .route("/ws", get(ws_handler)) + .route("/health", get(health)); + + // Telegram adapter + let telegram_bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok(); + let telegram_secret_token = std::env::var("TELEGRAM_SECRET_TOKEN").ok(); + let telegram_rich_messages = std::env::var("TELEGRAM_RICH_MESSAGES") + .map(|v| v != "0" && !v.eq_ignore_ascii_case("false")) + .unwrap_or(true); + #[cfg(feature = "telegram")] + if telegram_bot_token.is_some() { + let webhook_path = + std::env::var("TELEGRAM_WEBHOOK_PATH").unwrap_or_else(|_| "/webhook/telegram".into()); + if telegram_secret_token.is_none() { + warn!("TELEGRAM_SECRET_TOKEN not set — webhook requests are NOT validated
(insecure)"); + } + info!(path = %webhook_path, "telegram adapter enabled"); + app = app.route(&webhook_path, post(adapters::telegram::webhook)); + } + + // LINE adapter + let line_channel_secret = std::env::var("LINE_CHANNEL_SECRET").ok(); + let line_access_token = std::env::var("LINE_CHANNEL_ACCESS_TOKEN").ok(); + #[cfg(feature = "line")] + { + info!("line adapter enabled"); + app = app.route("/webhook/line", post(adapters::line::webhook)); + } + + // Teams adapter + #[cfg(feature = "teams")] + let teams = adapters::teams::TeamsConfig::from_env().map(|config| { + let webhook_path = + std::env::var("TEAMS_WEBHOOK_PATH").unwrap_or_else(|_| "/webhook/teams".into()); + info!(path = %webhook_path, "teams adapter enabled"); + adapters::teams::TeamsAdapter::new(config) + }); + #[cfg(not(feature = "teams"))] + let teams: Option<()> = None; + + #[cfg(feature = "teams")] + if teams.is_some() { + let webhook_path = + std::env::var("TEAMS_WEBHOOK_PATH").unwrap_or_else(|_| "/webhook/teams".into()); + app = app.route(&webhook_path, post(adapters::teams::webhook)); + } + + // Feishu adapter + #[cfg(feature = "feishu")] + let feishu_config = adapters::feishu::FeishuConfig::from_env(); + #[cfg(feature = "feishu")] + let feishu_ws_mode = feishu_config + .as_ref() + .map(|c| c.connection_mode == adapters::feishu::ConnectionMode::Websocket) + .unwrap_or(false); + #[cfg(feature = "feishu")] + if let Some(ref config) = feishu_config { + match config.connection_mode { + adapters::feishu::ConnectionMode::Websocket => { + info!("feishu adapter enabled (websocket) — will connect after state init"); + } + adapters::feishu::ConnectionMode::Webhook => { + let path = config.webhook_path.clone(); + info!(path = %path, "feishu adapter enabled (webhook)"); + app = app.route(&path, post(adapters::feishu::webhook)); + } + } + } + #[cfg(feature = "feishu")] + let feishu = feishu_config.map(adapters::feishu::FeishuAdapter::new); + #[cfg(feature = "feishu")] + if let Some(ref f) = feishu { + 
f.resolve_bot_identity().await; + } + + // Google Chat adapter + #[cfg(feature = "googlechat")] + let google_chat = { + let enabled = std::env::var("GOOGLE_CHAT_ENABLED") + .map(|v| v == "true" || v == "1") + .unwrap_or(false); + if enabled { + let token_cache = std::env::var("GOOGLE_CHAT_SA_KEY_JSON") + .ok() + .or_else(|| { + std::env::var("GOOGLE_CHAT_SA_KEY_FILE") + .ok() + .and_then(|path| std::fs::read_to_string(&path).ok()) + }) + .and_then(|json| { + adapters::googlechat::GoogleChatTokenCache::new(&json) + .map_err(|e| warn!("googlechat SA key error: {e}")) + .ok() + }); + let access_token = std::env::var("GOOGLE_CHAT_ACCESS_TOKEN").ok(); + let jwt_verifier = std::env::var("GOOGLE_CHAT_AUDIENCE").ok().map(|aud| { + info!("googlechat webhook JWT verification enabled (audience={aud})"); + adapters::googlechat::GoogleChatJwtVerifier::new(aud) + }); + let webhook_path = std::env::var("GOOGLE_CHAT_WEBHOOK_PATH") + .unwrap_or_else(|_| "/webhook/googlechat".into()); + info!(path = %webhook_path, "googlechat adapter enabled"); + app = app.route(&webhook_path, post(adapters::googlechat::webhook)); + Some(adapters::googlechat::GoogleChatAdapter::new( + token_cache, + access_token, + jwt_verifier, + )) + } else { + None + } + }; + + // WeCom adapter + #[cfg(feature = "wecom")] + let wecom = adapters::wecom::WecomConfig::from_env().map(|config| { + let path = config.webhook_path.clone(); + info!(path = %path, "wecom adapter enabled"); + adapters::wecom::WecomAdapter::new(config) + }); + #[cfg(feature = "wecom")] + if let Some(ref w) = wecom { + app = app + .route( + &w.config.webhook_path, + axum::routing::get(adapters::wecom::verify), + ) + .route(&w.config.webhook_path, post(adapters::wecom::webhook)); + } + + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .expect("HTTP client must build"); + + let state = Arc::new(AppState { + telegram_bot_token, + telegram_secret_token, + telegram_rich_messages, + 
line_channel_secret, + line_access_token, + #[cfg(feature = "teams")] + teams, + teams_service_urls: Mutex::new(HashMap::new()), + #[cfg(feature = "feishu")] + feishu, + #[cfg(feature = "googlechat")] + google_chat, + #[cfg(feature = "wecom")] + wecom, + ws_token, + event_tx, + reply_token_cache, + line_webhook_semaphore: Arc::new(Semaphore::new(LINE_WEBHOOK_CONCURRENCY_MAX)), + client, + }); + + // Background: sweep expired reply tokens + { + let cache_state = state.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(REPLY_TOKEN_TTL_SECS)).await; + let mut cache = cache_state + .reply_token_cache + .lock() + .unwrap_or_else(|e| e.into_inner()); + let before = cache.len(); + cache.retain(|_, (_, t)| t.elapsed().as_secs() < REPLY_TOKEN_TTL_SECS); + let after = cache.len(); + if before != after { + info!( + removed = before - after, + remaining = after, + "reply token cache sweep" + ); + } + } + }); + } + + // Background: cleanup stale Teams service_url entries (TTL: 4 hours) + { + let state_for_cleanup = state.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(300)).await; + let mut urls = state_for_cleanup.teams_service_urls.lock().await; + let before = urls.len(); + urls.retain(|_, (_, t)| t.elapsed().as_secs() < 4 * 3600); + let after = urls.len(); + if before != after { + info!( + removed = before - after, + remaining = after, + "teams service_url cache cleanup" + ); + } + } + }); + } + + let app = app.with_state(state.clone()); + + // Background: evict expired media files + tokio::spawn(openab_gateway::store::eviction_loop()); + + // Spawn feishu WebSocket long-connection if configured + #[cfg(feature = "feishu")] + let _feishu_shutdown_tx = { + let (tx, rx) = tokio::sync::watch::channel(false); + if feishu_ws_mode { + if let Some(ref feishu) = state.feishu { + match adapters::feishu::start_websocket(feishu, state.event_tx.clone(), rx).await { + Ok(_handle) => 
info!("feishu websocket task spawned"), + Err(e) => tracing::error!(err = %e, "feishu websocket startup failed"), + } + } + } + tx + }; + + info!(addr = %listen_addr, "gateway starting"); + let listener = tokio::net::TcpListener::bind(&listen_addr).await?; + axum::serve(listener, app).await?; + Ok(()) +} diff --git a/src/main.rs b/src/main.rs index f28a1f3bd..2e66c4206 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,8 +14,6 @@ use openab_core::secrets; use openab_core::setup; #[cfg(feature = "slack")] use openab_core::slack; -use openab_core::stt; - use clap::Parser; #[cfg(feature = "discord")] use serenity::gateway::GatewayError; @@ -388,6 +386,7 @@ async fn main() -> anyhow::Result<()> { // For now, the feature compiles the gateway crate (making the code available) // but the full runtime integration (axum server, route registration, direct dispatch) // will be completed in a follow-up PR. + warn!("unified gateway features compiled in but runtime integration not yet wired — gateway adapters are NOT active in this binary"); None::> }; From 140d826ce173c1ec499e561f3137c6df3c047f50 Mon Sep 17 00:00:00 2001 From: chaodufashi Date: Thu, 18 Jun 2026 21:02:15 +0000 Subject: [PATCH 11/20] fix: remove unused Instant import in gateway binary --- crates/openab-gateway/src/main.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/openab-gateway/src/main.rs b/crates/openab-gateway/src/main.rs index 77e2f18fd..f0dc8b3c7 100644 --- a/crates/openab-gateway/src/main.rs +++ b/crates/openab-gateway/src/main.rs @@ -12,7 +12,6 @@ use openab_gateway::{ }; use std::collections::HashMap; use std::sync::Arc; -use std::time::Instant; use tokio::sync::{broadcast, Mutex, Semaphore}; use tracing::{info, warn}; From bf8b92713645d47e2a3c8c774120dcf766e9fe0e Mon Sep 17 00:00:00 2001 From: chaodufashi Date: Thu, 18 Jun 2026 21:10:41 +0000 Subject: [PATCH 12/20] fix: add dummy gateway main.rs to Dockerfile dep-cache layers The [[bin]] entry in openab-gateway requires src/main.rs to exist 
during the dependency pre-fetch layer, otherwise cargo build fails. --- Dockerfile | 1 + Dockerfile.agentcore | 1 + Dockerfile.antigravity | 1 + Dockerfile.claude | 1 + Dockerfile.codex | 1 + Dockerfile.copilot | 1 + Dockerfile.cursor | 1 + Dockerfile.gateway | 1 + Dockerfile.gemini | 1 + Dockerfile.grok | 1 + Dockerfile.hermes | 1 + Dockerfile.mimocode | 1 + Dockerfile.native | 1 + Dockerfile.opencode | 1 + Dockerfile.pi | 1 + 15 files changed, 15 insertions(+) diff --git a/Dockerfile b/Dockerfile index 7ffd48d5e..d5f59532f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,6 +14,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.agentcore b/Dockerfile.agentcore index 805538997..10de5f572 100644 --- a/Dockerfile.agentcore +++ b/Dockerfile.agentcore @@ -12,6 +12,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release --features agentcore \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.antigravity b/Dockerfile.antigravity index 4ed8341bc..d484aa033 100644 --- a/Dockerfile.antigravity +++ b/Dockerfile.antigravity @@ -8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src 
crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.claude b/Dockerfile.claude index 710171bb2..d90814d84 100644 --- a/Dockerfile.claude +++ b/Dockerfile.claude @@ -8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.codex b/Dockerfile.codex index 1bcf9ab82..f876229b1 100644 --- a/Dockerfile.codex +++ b/Dockerfile.codex @@ -8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.copilot b/Dockerfile.copilot index 3fc76809c..0a3c1f3d8 100644 --- a/Dockerfile.copilot +++ b/Dockerfile.copilot @@ -8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.cursor b/Dockerfile.cursor index a65bcacda..48f77c427 100644 --- a/Dockerfile.cursor +++ b/Dockerfile.cursor @@ -8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 
'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.gateway b/Dockerfile.gateway index 28372f86a..e5a464791 100644 --- a/Dockerfile.gateway +++ b/Dockerfile.gateway @@ -9,6 +9,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo '' > crates/openab-core/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release -p openab-gateway \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.gemini b/Dockerfile.gemini index 6a4585588..6d09a8162 100644 --- a/Dockerfile.gemini +++ b/Dockerfile.gemini @@ -8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.grok b/Dockerfile.grok index 53ee03c3a..b74fb2874 100644 --- a/Dockerfile.grok +++ b/Dockerfile.grok @@ -8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.hermes b/Dockerfile.hermes index 9629ccd29..8de874019 100644 --- a/Dockerfile.hermes +++ b/Dockerfile.hermes @@ -8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn 
main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.mimocode b/Dockerfile.mimocode index 7e5daae5d..9ef899004 100644 --- a/Dockerfile.mimocode +++ b/Dockerfile.mimocode @@ -8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.native b/Dockerfile.native index 782302867..28b283375 100644 --- a/Dockerfile.native +++ b/Dockerfile.native @@ -9,6 +9,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.opencode b/Dockerfile.opencode index 994e44b59..6e16958bd 100644 --- a/Dockerfile.opencode +++ b/Dockerfile.opencode @@ -8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ diff --git a/Dockerfile.pi b/Dockerfile.pi index d089fbae2..662c2b335 100644 --- a/Dockerfile.pi +++ b/Dockerfile.pi @@ 
-8,6 +8,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ + && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ From 168b46eb1615cf4c70c06e637fde44d7afa28ebb Mon Sep 17 00:00:00 2001 From: chaodufashi Date: Thu, 18 Jun 2026 21:29:30 +0000 Subject: [PATCH 13/20] fix: touch all workspace member sources in Docker final build step Cargo's incremental build needs source files to be newer than cached artifacts. The dep-cache layer compiled dummy empty sources, so we must touch the real sources after COPY to trigger recompilation. --- Dockerfile | 2 +- Dockerfile.agentcore | 2 +- Dockerfile.antigravity | 2 +- Dockerfile.claude | 2 +- Dockerfile.codex | 2 +- Dockerfile.copilot | 2 +- Dockerfile.cursor | 2 +- Dockerfile.gateway | 2 +- Dockerfile.gemini | 2 +- Dockerfile.grok | 2 +- Dockerfile.hermes | 2 +- Dockerfile.mimocode | 2 +- Dockerfile.native | 2 +- Dockerfile.opencode | 2 +- Dockerfile.pi | 2 +- 15 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Dockerfile b/Dockerfile index d5f59532f..d688f2961 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,7 +19,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && \ +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ if [ "$BUILD_MODE" = "unified" ]; then \ cargo build --release --features unified; \ elif [ -n "$FEATURES" ]; then \ diff --git a/Dockerfile.agentcore b/Dockerfile.agentcore index 10de5f572..8b1c80727 100644 --- a/Dockerfile.agentcore +++ b/Dockerfile.agentcore @@ -17,7 +17,7 @@ RUN mkdir -p src crates/openab-core/src 
crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release --features agentcore +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release --features agentcore # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/Dockerfile.antigravity b/Dockerfile.antigravity index d484aa033..7be41c40e 100644 --- a/Dockerfile.antigravity +++ b/Dockerfile.antigravity @@ -13,7 +13,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Build agy-acp adapter --- FROM rust:1-bookworm AS adapter-builder diff --git a/Dockerfile.claude b/Dockerfile.claude index d90814d84..40e9dfa44 100644 --- a/Dockerfile.claude +++ b/Dockerfile.claude @@ -13,7 +13,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.codex b/Dockerfile.codex index f876229b1..90d03e7e3 100644 --- a/Dockerfile.codex +++ b/Dockerfile.codex @@ -13,7 +13,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs 
crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.copilot b/Dockerfile.copilot index 0a3c1f3d8..a5d958138 100644 --- a/Dockerfile.copilot +++ b/Dockerfile.copilot @@ -13,7 +13,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.cursor b/Dockerfile.cursor index 48f77c427..ebd810582 100644 --- a/Dockerfile.cursor +++ b/Dockerfile.cursor @@ -13,7 +13,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/Dockerfile.gateway b/Dockerfile.gateway index e5a464791..e55a82163 100644 --- a/Dockerfile.gateway +++ b/Dockerfile.gateway @@ -14,7 +14,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch crates/openab-gateway/src/main.rs && cargo build --release -p openab-gateway +RUN touch crates/openab-gateway/src/main.rs crates/openab-gateway/src/lib.rs && cargo build --release -p openab-gateway # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/Dockerfile.gemini b/Dockerfile.gemini index 6d09a8162..67d30e736 100644 --- a/Dockerfile.gemini +++ b/Dockerfile.gemini @@ -13,7 +13,7 
@@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.grok b/Dockerfile.grok index b74fb2874..945d06f7f 100644 --- a/Dockerfile.grok +++ b/Dockerfile.grok @@ -13,7 +13,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/Dockerfile.hermes b/Dockerfile.hermes index 8de874019..00dfcb461 100644 --- a/Dockerfile.hermes +++ b/Dockerfile.hermes @@ -13,7 +13,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM python:3.12-slim-bookworm diff --git a/Dockerfile.mimocode b/Dockerfile.mimocode index 9ef899004..d02a59120 100644 --- a/Dockerfile.mimocode +++ b/Dockerfile.mimocode @@ -13,7 +13,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs 
crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- # MiMo-Code (https://github.com/XiaomiMiMo/MiMo-Code) is a fork of OpenCode diff --git a/Dockerfile.native b/Dockerfile.native index 28b283375..4b69794c5 100644 --- a/Dockerfile.native +++ b/Dockerfile.native @@ -14,7 +14,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release RUN cd openab-agent && cargo build --release # --- Runtime stage --- diff --git a/Dockerfile.opencode b/Dockerfile.opencode index 6e16958bd..e52a605ee 100644 --- a/Dockerfile.opencode +++ b/Dockerfile.opencode @@ -13,7 +13,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- # node:22-bookworm-slim mirrors the base image used by Dockerfile.claude, diff --git a/Dockerfile.pi b/Dockerfile.pi index 662c2b335..206ec38e1 100644 --- a/Dockerfile.pi +++ b/Dockerfile.pi @@ -13,7 +13,7 @@ RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ && rm -rf src crates/openab-core/src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM node:22-bookworm-slim From 263d2e8e5a53e7ce2d5f9f073e366b9d1973a35a Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: 
Thu, 18 Jun 2026 22:55:30 +0000 Subject: [PATCH 14/20] =?UTF-8?q?refactor:=20simplify=20=E2=80=94=20keep?= =?UTF-8?q?=20core=20in=20src/,=20only=20extract=20gateway=20to=20crate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous approach of extracting core into a separate library crate (openab-core) required making all internal items pub across 30+ files. That extraction is deferred to Phase 2. Simplified approach (matches ADR intent): - src/ keeps all core modules as the binary's own code (unchanged from main) - crates/openab-gateway/ is the only workspace member (library crate) - Root Cargo.toml adds workspace + gateway feature flags on top of original deps - --features unified/telegram/etc. pulls in openab-gateway This means: - Default build: identical to main (zero behavior change) - cargo build --features unified: adds gateway adapters - cargo build --no-default-features --features telegram: minimal binary Removed crates/openab-core/ entirely. 
--- Cargo.lock | 15 +-- Cargo.toml | 58 ++++++--- Dockerfile | 6 +- Dockerfile.antigravity | 6 +- Dockerfile.claude | 6 +- Dockerfile.codex | 6 +- Dockerfile.copilot | 6 +- Dockerfile.cursor | 6 +- Dockerfile.gemini | 6 +- Dockerfile.grok | 6 +- Dockerfile.hermes | 6 +- Dockerfile.mimocode | 6 +- Dockerfile.opencode | 6 +- Dockerfile.pi | 6 +- crates/openab-core/Cargo.toml | 54 --------- crates/openab-core/src/lib.rs | 25 ---- .../openab-core/src => src}/acp/agentcore.rs | 0 .../openab-core/src => src}/acp/connection.rs | 0 {crates/openab-core/src => src}/acp/mod.rs | 0 {crates/openab-core/src => src}/acp/pool.rs | 0 .../openab-core/src => src}/acp/protocol.rs | 0 {crates/openab-core/src => src}/adapter.rs | 0 {crates/openab-core/src => src}/bot_turns.rs | 0 {crates/openab-core/src => src}/config.rs | 0 {crates/openab-core/src => src}/cron.rs | 0 {crates/openab-core/src => src}/directives.rs | 0 {crates/openab-core/src => src}/discord.rs | 0 {crates/openab-core/src => src}/dispatch.rs | 0 .../openab-core/src => src}/error_display.rs | 0 {crates/openab-core/src => src}/format.rs | 0 {crates/openab-core/src => src}/gateway.rs | 0 {crates/openab-core/src => src}/hooks.rs | 0 src/main.rs | 114 ++++++++---------- {crates/openab-core/src => src}/markdown.rs | 0 {crates/openab-core/src => src}/media.rs | 0 .../openab-core/src => src}/multibot_cache.rs | 0 {crates/openab-core/src => src}/reactions.rs | 0 {crates/openab-core/src => src}/remind.rs | 0 {crates/openab-core/src => src}/secrets.rs | 0 .../openab-core/src => src}/setup/config.rs | 0 {crates/openab-core/src => src}/setup/mod.rs | 0 .../openab-core/src => src}/setup/validate.rs | 0 .../openab-core/src => src}/setup/wizard.rs | 0 {crates/openab-core/src => src}/slack.rs | 0 {crates/openab-core/src => src}/stt.rs | 0 {crates/openab-core/src => src}/timestamp.rs | 0 46 files changed, 118 insertions(+), 220 deletions(-) delete mode 100644 crates/openab-core/Cargo.toml delete mode 100644 crates/openab-core/src/lib.rs 
rename {crates/openab-core/src => src}/acp/agentcore.rs (100%) rename {crates/openab-core/src => src}/acp/connection.rs (100%) rename {crates/openab-core/src => src}/acp/mod.rs (100%) rename {crates/openab-core/src => src}/acp/pool.rs (100%) rename {crates/openab-core/src => src}/acp/protocol.rs (100%) rename {crates/openab-core/src => src}/adapter.rs (100%) rename {crates/openab-core/src => src}/bot_turns.rs (100%) rename {crates/openab-core/src => src}/config.rs (100%) rename {crates/openab-core/src => src}/cron.rs (100%) rename {crates/openab-core/src => src}/directives.rs (100%) rename {crates/openab-core/src => src}/discord.rs (100%) rename {crates/openab-core/src => src}/dispatch.rs (100%) rename {crates/openab-core/src => src}/error_display.rs (100%) rename {crates/openab-core/src => src}/format.rs (100%) rename {crates/openab-core/src => src}/gateway.rs (100%) rename {crates/openab-core/src => src}/hooks.rs (100%) rename {crates/openab-core/src => src}/markdown.rs (100%) rename {crates/openab-core/src => src}/media.rs (100%) rename {crates/openab-core/src => src}/multibot_cache.rs (100%) rename {crates/openab-core/src => src}/reactions.rs (100%) rename {crates/openab-core/src => src}/remind.rs (100%) rename {crates/openab-core/src => src}/secrets.rs (100%) rename {crates/openab-core/src => src}/setup/config.rs (100%) rename {crates/openab-core/src => src}/setup/mod.rs (100%) rename {crates/openab-core/src => src}/setup/validate.rs (100%) rename {crates/openab-core/src => src}/setup/wizard.rs (100%) rename {crates/openab-core/src => src}/slack.rs (100%) rename {crates/openab-core/src => src}/stt.rs (100%) rename {crates/openab-core/src => src}/timestamp.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index 626ada54e..fd8de76cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1921,20 +1921,6 @@ checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" [[package]] name = "openab" version = "0.8.5" -dependencies = [ - "anyhow", - "clap", - 
"openab-core", - "openab-gateway", - "serenity", - "tokio", - "tracing", - "tracing-subscriber", -] - -[[package]] -name = "openab-core" -version = "0.8.5" dependencies = [ "anyhow", "async-trait", @@ -1952,6 +1938,7 @@ dependencies = [ "http 1.4.2", "image", "libc", + "openab-gateway", "pulldown-cmark", "rand 0.8.6", "regex", diff --git a/Cargo.toml b/Cargo.toml index abf9fecc0..f631630d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/openab-core", "crates/openab-gateway"] +members = ["crates/openab-gateway"] [package] name = "openab" @@ -8,30 +8,55 @@ edition = "2021" license = "MIT" [dependencies] -openab-core = { path = "crates/openab-core", default-features = false } -openab-gateway = { path = "crates/openab-gateway", default-features = false, optional = true } tokio = { version = "1", features = ["full"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +toml = "0.8" +toml_edit = "0.22" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } -clap = { version = "4", features = ["derive"] } -anyhow = "1" serenity = { version = "0.12", default-features = false, features = ["client", "gateway", "model", "rustls_backend", "cache"] } +uuid = { version = "1", features = ["v4"] } +regex = "1" +anyhow = "1" +async-trait = "0.1" +futures-util = "0.3" +rand = "0.8" +clap = { version = "4", features = ["derive"] } +rpassword = "7" +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "multipart", "json", "blocking"] } +base64 = "0.22" +image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } +unicode-width = "0.2" +pulldown-cmark = { version = "0.13", default-features = false } +tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } +rustls = { version = "0.22", optional = true } +tokio-rustls = { version = "0.25", optional = true } +webpki-roots = { version = "0.26", optional = 
true } +cron = "0.16.0" +chrono = { version = "0.4.44", features = ["serde"] } +chrono-tz = "0.10.4" +sha2 = "0.10" +tempfile = "3.27.0" +aws-sdk-secretsmanager = { version = "1", optional = true } +aws-config = { version = "1", optional = true } +aws-sigv4 = { version = "1", optional = true } +aws-credential-types = { version = "1", optional = true } +urlencoding = { version = "2", optional = true } +hex = { version = "0.4", optional = true } +http = { version = "1", optional = true } + +# Gateway crate (opt-in for unified binary) +openab-gateway = { path = "crates/openab-gateway", default-features = false, optional = true } [features] -# Default: core only (Discord + Slack). Gateway ships as separate binary. -default = ["discord", "slack", "secrets-aws", "agentcore"] +default = ["secrets-aws", "agentcore"] +secrets-aws = ["dep:aws-sdk-secretsmanager", "dep:aws-config"] +agentcore = ["dep:aws-config", "dep:aws-sigv4", "dep:aws-credential-types", "dep:urlencoding", "dep:hex", "dep:http", "dep:rustls", "dep:tokio-rustls", "dep:webpki-roots"] # Opt-in: compile all gateway adapters into a single unified binary unified = ["telegram", "line", "feishu", "googlechat", "wecom", "teams"] -# Core adapters -discord = ["openab-core/discord"] -slack = ["openab-core/slack"] - -# Core optional features -secrets-aws = ["openab-core/secrets-aws"] -agentcore = ["openab-core/agentcore"] - # Gateway adapters (each pulls in the gateway crate) telegram = ["dep:openab-gateway", "openab-gateway/telegram"] line = ["dep:openab-gateway", "openab-gateway/line"] @@ -39,3 +64,6 @@ feishu = ["dep:openab-gateway", "openab-gateway/feishu"] googlechat = ["dep:openab-gateway", "openab-gateway/googlechat"] wecom = ["dep:openab-gateway", "openab-gateway/wecom"] teams = ["dep:openab-gateway", "openab-gateway/teams"] + +[target.'cfg(unix)'.dependencies] +libc = "0.2" diff --git a/Dockerfile b/Dockerfile index d688f2961..ff831bac8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,15 +8,13 @@ ARG FEATURES 
WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ diff --git a/Dockerfile.antigravity b/Dockerfile.antigravity index 7be41c40e..aa388f4a9 100644 --- a/Dockerfile.antigravity +++ b/Dockerfile.antigravity @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/Dockerfile.claude b/Dockerfile.claude index 40e9dfa44..b55cd7fd0 100644 --- a/Dockerfile.claude +++ b/Dockerfile.claude @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY 
Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/Dockerfile.codex b/Dockerfile.codex index 90d03e7e3..079878859 100644 --- a/Dockerfile.codex +++ b/Dockerfile.codex @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/Dockerfile.copilot b/Dockerfile.copilot index a5d958138..110886c28 100644 --- a/Dockerfile.copilot +++ b/Dockerfile.copilot @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock 
./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/Dockerfile.cursor b/Dockerfile.cursor index ebd810582..1a920cfce 100644 --- a/Dockerfile.cursor +++ b/Dockerfile.cursor @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/Dockerfile.gemini b/Dockerfile.gemini index 67d30e736..2f76e8148 100644 --- a/Dockerfile.gemini +++ b/Dockerfile.gemini @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY 
crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/Dockerfile.grok b/Dockerfile.grok index 945d06f7f..3daf1d5d5 100644 --- a/Dockerfile.grok +++ b/Dockerfile.grok @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/Dockerfile.hermes b/Dockerfile.hermes index 00dfcb461..30c4da474 100644 --- a/Dockerfile.hermes +++ b/Dockerfile.hermes @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml 
crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/Dockerfile.mimocode b/Dockerfile.mimocode index d02a59120..67b94048e 100644 --- a/Dockerfile.mimocode +++ b/Dockerfile.mimocode @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/Dockerfile.opencode b/Dockerfile.opencode index e52a605ee..fb38a5bae 100644 --- a/Dockerfile.opencode +++ b/Dockerfile.opencode @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml 
crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/Dockerfile.pi b/Dockerfile.pi index 206ec38e1..43414429c 100644 --- a/Dockerfile.pi +++ b/Dockerfile.pi @@ -2,15 +2,13 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release diff --git a/crates/openab-core/Cargo.toml b/crates/openab-core/Cargo.toml deleted file mode 100644 index 5136e69a3..000000000 --- a/crates/openab-core/Cargo.toml +++ /dev/null @@ -1,54 +0,0 @@ -[package] -name = "openab-core" -version = "0.8.5" -edition = "2021" -license = "MIT" - -[dependencies] 
-tokio = { version = "1", features = ["full"] } -serde = { version = "1", features = ["derive"] } -serde_json = "1" -toml = "0.8" -toml_edit = "0.22" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } -serenity = { version = "0.12", default-features = false, features = ["client", "gateway", "model", "rustls_backend", "cache"] } -uuid = { version = "1", features = ["v4"] } -regex = "1" -anyhow = "1" -async-trait = "0.1" -futures-util = "0.3" -rand = "0.8" -clap = { version = "4", features = ["derive"] } -rpassword = "7" -reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "multipart", "json", "blocking"] } -base64 = "0.22" -image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } -unicode-width = "0.2" -pulldown-cmark = { version = "0.13", default-features = false } -tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } -rustls = { version = "0.22", optional = true } -tokio-rustls = { version = "0.25", optional = true } -webpki-roots = { version = "0.26", optional = true } -cron = "0.16.0" -chrono = { version = "0.4.44", features = ["serde"] } -chrono-tz = "0.10.4" -sha2 = "0.10" -tempfile = "3.27.0" -aws-sdk-secretsmanager = { version = "1", optional = true } -aws-config = { version = "1", optional = true } -aws-sigv4 = { version = "1", optional = true } -aws-credential-types = { version = "1", optional = true } -urlencoding = { version = "2", optional = true } -hex = { version = "0.4", optional = true } -http = { version = "1", optional = true } - -[target.'cfg(unix)'.dependencies] -libc = "0.2" - -[features] -default = ["discord", "slack", "secrets-aws", "agentcore"] -discord = [] -slack = [] -secrets-aws = ["dep:aws-sdk-secretsmanager", "dep:aws-config"] -agentcore = ["dep:aws-config", "dep:aws-sigv4", "dep:aws-credential-types", "dep:urlencoding", "dep:hex", "dep:http", "dep:rustls", "dep:tokio-rustls", "dep:webpki-roots"] 
diff --git a/crates/openab-core/src/lib.rs b/crates/openab-core/src/lib.rs deleted file mode 100644 index f61540657..000000000 --- a/crates/openab-core/src/lib.rs +++ /dev/null @@ -1,25 +0,0 @@ -pub mod acp; -pub mod adapter; -pub mod bot_turns; -pub mod config; -pub mod cron; -pub mod directives; -pub mod dispatch; -pub mod error_display; -pub mod format; -pub mod gateway; -pub mod hooks; -pub mod markdown; -pub mod media; -pub mod multibot_cache; -pub mod reactions; -pub mod remind; -pub mod secrets; -pub mod setup; -pub mod stt; -pub mod timestamp; - -#[cfg(feature = "discord")] -pub mod discord; -#[cfg(feature = "slack")] -pub mod slack; diff --git a/crates/openab-core/src/acp/agentcore.rs b/src/acp/agentcore.rs similarity index 100% rename from crates/openab-core/src/acp/agentcore.rs rename to src/acp/agentcore.rs diff --git a/crates/openab-core/src/acp/connection.rs b/src/acp/connection.rs similarity index 100% rename from crates/openab-core/src/acp/connection.rs rename to src/acp/connection.rs diff --git a/crates/openab-core/src/acp/mod.rs b/src/acp/mod.rs similarity index 100% rename from crates/openab-core/src/acp/mod.rs rename to src/acp/mod.rs diff --git a/crates/openab-core/src/acp/pool.rs b/src/acp/pool.rs similarity index 100% rename from crates/openab-core/src/acp/pool.rs rename to src/acp/pool.rs diff --git a/crates/openab-core/src/acp/protocol.rs b/src/acp/protocol.rs similarity index 100% rename from crates/openab-core/src/acp/protocol.rs rename to src/acp/protocol.rs diff --git a/crates/openab-core/src/adapter.rs b/src/adapter.rs similarity index 100% rename from crates/openab-core/src/adapter.rs rename to src/adapter.rs diff --git a/crates/openab-core/src/bot_turns.rs b/src/bot_turns.rs similarity index 100% rename from crates/openab-core/src/bot_turns.rs rename to src/bot_turns.rs diff --git a/crates/openab-core/src/config.rs b/src/config.rs similarity index 100% rename from crates/openab-core/src/config.rs rename to src/config.rs diff --git 
a/crates/openab-core/src/cron.rs b/src/cron.rs similarity index 100% rename from crates/openab-core/src/cron.rs rename to src/cron.rs diff --git a/crates/openab-core/src/directives.rs b/src/directives.rs similarity index 100% rename from crates/openab-core/src/directives.rs rename to src/directives.rs diff --git a/crates/openab-core/src/discord.rs b/src/discord.rs similarity index 100% rename from crates/openab-core/src/discord.rs rename to src/discord.rs diff --git a/crates/openab-core/src/dispatch.rs b/src/dispatch.rs similarity index 100% rename from crates/openab-core/src/dispatch.rs rename to src/dispatch.rs diff --git a/crates/openab-core/src/error_display.rs b/src/error_display.rs similarity index 100% rename from crates/openab-core/src/error_display.rs rename to src/error_display.rs diff --git a/crates/openab-core/src/format.rs b/src/format.rs similarity index 100% rename from crates/openab-core/src/format.rs rename to src/format.rs diff --git a/crates/openab-core/src/gateway.rs b/src/gateway.rs similarity index 100% rename from crates/openab-core/src/gateway.rs rename to src/gateway.rs diff --git a/crates/openab-core/src/hooks.rs b/src/hooks.rs similarity index 100% rename from crates/openab-core/src/hooks.rs rename to src/hooks.rs diff --git a/src/main.rs b/src/main.rs index 2e66c4206..600028368 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,30 +1,39 @@ -use openab_core::acp; -use openab_core::adapter::{self, AdapterRouter}; -use openab_core::bot_turns; -use openab_core::config; -use openab_core::cron; -#[cfg(feature = "discord")] -use openab_core::discord; -use openab_core::dispatch; -use openab_core::gateway; -use openab_core::hooks; -use openab_core::multibot_cache; -use openab_core::remind; -use openab_core::secrets; -use openab_core::setup; -#[cfg(feature = "slack")] -use openab_core::slack; +mod acp; +mod adapter; +mod bot_turns; +mod config; +mod cron; +mod directives; +mod discord; +mod dispatch; +mod error_display; +mod format; +mod gateway; 
+mod hooks; +mod markdown; +mod media; +mod multibot_cache; +mod reactions; +mod remind; +mod secrets; +mod setup; +mod slack; +mod stt; +mod timestamp; + +use adapter::AdapterRouter; use clap::Parser; -#[cfg(feature = "discord")] use serenity::gateway::GatewayError; -#[cfg(feature = "discord")] use serenity::prelude::*; use std::collections::HashSet; use std::path::PathBuf; use std::sync::{Arc, Mutex}; use tracing::{error, info, warn}; -/// Wait for SIGINT (ctrl_c) or, on unix, SIGTERM. +/// Wait for SIGINT (ctrl_c) or, on unix, SIGTERM. SIGTERM is what Kubernetes +/// sends during pod termination, so handling it lets us run the full cleanup +/// path (shard manager, ACP pool drain) instead of getting SIGKILL'd after the +/// grace period. async fn shutdown_signal() { #[cfg(unix)] { @@ -196,6 +205,11 @@ async fn main() -> anyhow::Result<()> { // Shutdown signal for Slack adapter let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); + // Dispatcher handles tracked here so SIGTERM cleanup can call shutdown() on each (ADR §6.8). + // Also shared with the cleanup task for periodic stale-entry sweeping. + // Arc>> because: outer Arc shared with cleanup task + shutdown, + // Mutex guards startup-time pushes, inner Arc shared with each adapter. + // All pushes happen at startup; runtime access is read-only (lock is uncontended). let dispatchers: Arc>>> = Arc::new(Mutex::new(Vec::new())); // Spawn cleanup task @@ -205,25 +219,22 @@ async fn main() -> anyhow::Result<()> { loop { tokio::time::sleep(std::time::Duration::from_secs(60)).await; cleanup_pool.cleanup_idle(ttl_secs).await; + // Sweep stale per-thread dispatcher entries (idle-exited consumers). 
for d in cleanup_dispatchers.lock().unwrap().iter() { d.sweep_stale(); } } }); - // Pre-build shared adapters for cron scheduler - #[cfg(feature = "discord")] + // Pre-build shared adapters for cron scheduler (avoids duplicate Http clients / rate-limit buckets) let shared_discord_adapter: Option> = cfg.discord.as_ref().map(|dc| { let http = Arc::new(serenity::http::Http::new(&dc.bot_token)); Arc::new(discord::DiscordAdapter::new(http)) as Arc }); - #[cfg(not(feature = "discord"))] - let shared_discord_adapter: Option> = None; - let session_ttl_dur = std::time::Duration::from_secs(ttl_secs); - // Initialize multibot cache + // Initialize multibot cache (persists to $HOME/.openab/cache/threads.json) let multibot_cache_path = std::env::var("HOME") .map(std::path::PathBuf::from) .unwrap_or_default() @@ -232,7 +243,6 @@ async fn main() -> anyhow::Result<()> { .join("threads.json"); let multibot_cache = multibot_cache::MultibotCache::load(multibot_cache_path); - #[cfg(feature = "slack")] let shared_slack_adapter: Option> = cfg.slack.as_ref().map(|s| { Arc::new(slack::SlackAdapter::new( s.bot_token.clone(), @@ -242,10 +252,8 @@ async fn main() -> anyhow::Result<()> { multibot_cache.clone(), )) }); - #[cfg(not(feature = "slack"))] - let shared_slack_adapter: Option> = None; - // Validate cronjob config at startup + // Validate cronjob config at startup (fail-fast on bad cron expressions or timezones) let mut configured_platforms: Vec<&str> = Vec::new(); if cfg.discord.is_some() { configured_platforms.push("discord"); @@ -256,7 +264,6 @@ async fn main() -> anyhow::Result<()> { cron::validate_cronjobs(&cfg.cron.jobs, &configured_platforms)?; // Spawn Slack adapter (background task) - #[cfg(feature = "slack")] let slack_handle = if let Some(slack_cfg) = cfg.slack { let allow_all_channels = config::resolve_allow_all(slack_cfg.allow_all_channels, &slack_cfg.allowed_channels); @@ -281,6 +288,9 @@ async fn main() -> anyhow::Result<()> { let adapter = shared_slack_adapter .clone() 
.expect("shared_slack_adapter must exist when slack config is present"); + // Dispatcher is the sole serialization path for all modes. Message = cap 1 + // (each message dispatches alone, FIFO). Thread / Lane = configured cap; + // grouping decides whether senders share a buffer or get their own lane. let (slack_cap, slack_grouping, slack_idle) = dispatch::dispatch_params( &slack_cfg.message_processing_mode, slack_cfg.max_buffered_messages, @@ -317,8 +327,6 @@ async fn main() -> anyhow::Result<()> { } else { None }; - #[cfg(not(feature = "slack"))] - let slack_handle: Option> = None; // Spawn Gateway adapter (background task) let gateway_handle = if let Some(gw_cfg) = cfg.gateway { @@ -368,34 +376,14 @@ async fn main() -> anyhow::Result<()> { None }; - // Spawn cron scheduler (background task) - // Spawn embedded webhook server when gateway adapters are compiled in (unified mode). - // In unified mode, platform webhooks hit this axum server directly → Dispatcher.submit(), - // bypassing the WebSocket hop of the two-process model. - #[cfg(any( - feature = "telegram", - feature = "line", - feature = "feishu", - feature = "googlechat", - feature = "wecom", - feature = "teams", - ))] - let _unified_handle = { - // TODO(Phase 1): Wire each compiled-in adapter's webhook handler to axum routes - // and call Dispatcher.submit() directly instead of going through WS gateway. - // For now, the feature compiles the gateway crate (making the code available) - // but the full runtime integration (axum server, route registration, direct dispatch) - // will be completed in a follow-up PR. 
- warn!("unified gateway features compiled in but runtime integration not yet wired — gateway adapters are NOT active in this binary"); - None::> - }; - + // Spawn cron scheduler (background task) — reuses shared adapters let usercron_path = if cfg.cron.usercron_enabled { cfg.cron.usercron_path.as_ref().map(|p| { let path = std::path::PathBuf::from(p); if path.is_absolute() { path } else { + // Relative paths resolve from $HOME/.openab/ (e.g. "cronjob.toml" → "$HOME/.openab/cronjob.toml") std::env::var("HOME") .map(std::path::PathBuf::from) .unwrap_or_default() @@ -416,7 +404,6 @@ async fn main() -> anyhow::Result<()> { if let Some(ref a) = shared_discord_adapter { cron_adapters.insert("discord".into(), a.clone()); } - #[cfg(feature = "slack")] if let Some(ref a) = shared_slack_adapter { cron_adapters.insert("slack".into(), a.clone() as Arc); } @@ -439,7 +426,6 @@ async fn main() -> anyhow::Result<()> { }; // Run Discord adapter (foreground, blocking) or wait for ctrl_c - #[cfg(feature = "discord")] if let Some(discord_cfg) = cfg.discord { let allow_all_channels = config::resolve_allow_all( discord_cfg.allow_all_channels, @@ -483,7 +469,7 @@ async fn main() -> anyhow::Result<()> { )); dispatchers.lock().unwrap().push(discord_dispatcher.clone()); - // Initialize reminder store + // Initialize reminder store (persists to $HOME/.openab/reminders.json) let reminder_path = std::env::var("HOME") .map(std::path::PathBuf::from) .unwrap_or_default() @@ -526,6 +512,7 @@ async fn main() -> anyhow::Result<()> { .event_handler(handler) .await?; + // Graceful Discord shutdown on ctrl_c let shard_manager = client.shard_manager.clone(); tokio::spawn(async move { shutdown_signal().await; @@ -554,12 +541,7 @@ async fn main() -> anyhow::Result<()> { Ok(_) => {} } } else { - info!("running without discord, press ctrl+c to stop"); - shutdown_signal().await; - info!("shutdown signal received"); - } - #[cfg(not(feature = "discord"))] - { + // No Discord — wait for SIGINT or SIGTERM 
info!("running without discord, press ctrl+c to stop"); shutdown_signal().await; info!("shutdown signal received"); @@ -567,6 +549,7 @@ async fn main() -> anyhow::Result<()> { // Cleanup cleanup_handle.abort(); + // Signal Slack adapter to shut down gracefully let _ = shutdown_tx.send(true); if let Some(handle) = slack_handle { let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await; @@ -575,13 +558,16 @@ async fn main() -> anyhow::Result<()> { let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await; } if let Some(handle) = cron_handle { + // cron.rs drains in-flight tasks for up to 30s, so wait slightly longer let _ = tokio::time::timeout(std::time::Duration::from_secs(35), handle).await; } + // Drain per-thread dispatchers and log buffered_lost counts before pool shutdown (ADR §6.8). for d in dispatchers.lock().unwrap().iter() { d.shutdown(); } let shutdown_pool = pool; shutdown_pool.shutdown().await; + // Run pre_shutdown hook after pool shutdown to guarantee no active sessions are writing. 
if let Some(ref hook) = shutdown_hook { if let Err(e) = hooks::run_hook("pre_shutdown", hook).await { error!(error = %e, "pre_shutdown hook failed"); @@ -618,7 +604,7 @@ mod tests { #[test] fn cli_no_args_defaults_to_run() { let cli = Cli::try_parse_from(["openab"]).unwrap(); - assert!(cli.command.is_none()); + assert!(cli.command.is_none()); // None → unwrap_or(Run { config: None }) } #[test] diff --git a/crates/openab-core/src/markdown.rs b/src/markdown.rs similarity index 100% rename from crates/openab-core/src/markdown.rs rename to src/markdown.rs diff --git a/crates/openab-core/src/media.rs b/src/media.rs similarity index 100% rename from crates/openab-core/src/media.rs rename to src/media.rs diff --git a/crates/openab-core/src/multibot_cache.rs b/src/multibot_cache.rs similarity index 100% rename from crates/openab-core/src/multibot_cache.rs rename to src/multibot_cache.rs diff --git a/crates/openab-core/src/reactions.rs b/src/reactions.rs similarity index 100% rename from crates/openab-core/src/reactions.rs rename to src/reactions.rs diff --git a/crates/openab-core/src/remind.rs b/src/remind.rs similarity index 100% rename from crates/openab-core/src/remind.rs rename to src/remind.rs diff --git a/crates/openab-core/src/secrets.rs b/src/secrets.rs similarity index 100% rename from crates/openab-core/src/secrets.rs rename to src/secrets.rs diff --git a/crates/openab-core/src/setup/config.rs b/src/setup/config.rs similarity index 100% rename from crates/openab-core/src/setup/config.rs rename to src/setup/config.rs diff --git a/crates/openab-core/src/setup/mod.rs b/src/setup/mod.rs similarity index 100% rename from crates/openab-core/src/setup/mod.rs rename to src/setup/mod.rs diff --git a/crates/openab-core/src/setup/validate.rs b/src/setup/validate.rs similarity index 100% rename from crates/openab-core/src/setup/validate.rs rename to src/setup/validate.rs diff --git a/crates/openab-core/src/setup/wizard.rs b/src/setup/wizard.rs similarity index 100% rename 
from crates/openab-core/src/setup/wizard.rs rename to src/setup/wizard.rs diff --git a/crates/openab-core/src/slack.rs b/src/slack.rs similarity index 100% rename from crates/openab-core/src/slack.rs rename to src/slack.rs diff --git a/crates/openab-core/src/stt.rs b/src/stt.rs similarity index 100% rename from crates/openab-core/src/stt.rs rename to src/stt.rs diff --git a/crates/openab-core/src/timestamp.rs b/src/timestamp.rs similarity index 100% rename from crates/openab-core/src/timestamp.rs rename to src/timestamp.rs From b651b13a98c77401fb78282624e8cfd1303e65ed Mon Sep 17 00:00:00 2001 From: chaodu-agent Date: Thu, 18 Jun 2026 23:09:24 +0000 Subject: [PATCH 15/20] feat: add discord/slack feature gates (Phase 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Discord and Slack are now behind feature flags: - default = ["discord", "slack", "secrets-aws", "agentcore"] - serenity is optional, only pulled in by discord feature This enables: cargo build --no-default-features --features agentcore,telegram → OAB + Telegram only (no Discord, no Slack, no serenity dep) cargo build --features telegram → Discord + Slack + Telegram (default + telegram) cargo build --no-default-features --features discord,agentcore,telegram → Discord + Telegram (no Slack) --- Cargo.toml | 6 ++++-- src/main.rs | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f631630d2..f8009ca35 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ toml = "0.8" toml_edit = "0.22" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } -serenity = { version = "0.12", default-features = false, features = ["client", "gateway", "model", "rustls_backend", "cache"] } +serenity = { version = "0.12", default-features = false, features = ["client", "gateway", "model", "rustls_backend", "cache"], optional = true } uuid = { version = "1", features = ["v4"] } regex 
= "1" anyhow = "1" @@ -50,7 +50,9 @@ http = { version = "1", optional = true } openab-gateway = { path = "crates/openab-gateway", default-features = false, optional = true } [features] -default = ["secrets-aws", "agentcore"] +default = ["discord", "slack", "secrets-aws", "agentcore"] +discord = ["dep:serenity"] +slack = [] secrets-aws = ["dep:aws-sdk-secretsmanager", "dep:aws-config"] agentcore = ["dep:aws-config", "dep:aws-sigv4", "dep:aws-credential-types", "dep:urlencoding", "dep:hex", "dep:http", "dep:rustls", "dep:tokio-rustls", "dep:webpki-roots"] diff --git a/src/main.rs b/src/main.rs index 600028368..3bc00bad4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ mod bot_turns; mod config; mod cron; mod directives; +#[cfg(feature = "discord")] mod discord; mod dispatch; mod error_display; @@ -17,13 +18,16 @@ mod reactions; mod remind; mod secrets; mod setup; +#[cfg(feature = "slack")] mod slack; mod stt; mod timestamp; use adapter::AdapterRouter; use clap::Parser; +#[cfg(feature = "discord")] use serenity::gateway::GatewayError; +#[cfg(feature = "discord")] use serenity::prelude::*; use std::collections::HashSet; use std::path::PathBuf; @@ -227,11 +231,14 @@ async fn main() -> anyhow::Result<()> { }); // Pre-build shared adapters for cron scheduler (avoids duplicate Http clients / rate-limit buckets) + #[cfg(feature = "discord")] let shared_discord_adapter: Option> = cfg.discord.as_ref().map(|dc| { let http = Arc::new(serenity::http::Http::new(&dc.bot_token)); Arc::new(discord::DiscordAdapter::new(http)) as Arc }); + #[cfg(not(feature = "discord"))] + let shared_discord_adapter: Option> = None; let session_ttl_dur = std::time::Duration::from_secs(ttl_secs); // Initialize multibot cache (persists to $HOME/.openab/cache/threads.json) @@ -243,6 +250,7 @@ async fn main() -> anyhow::Result<()> { .join("threads.json"); let multibot_cache = multibot_cache::MultibotCache::load(multibot_cache_path); + #[cfg(feature = "slack")] let shared_slack_adapter: Option> 
= cfg.slack.as_ref().map(|s| { Arc::new(slack::SlackAdapter::new( s.bot_token.clone(), @@ -252,6 +260,8 @@ async fn main() -> anyhow::Result<()> { multibot_cache.clone(), )) }); + #[cfg(not(feature = "slack"))] + let shared_slack_adapter: Option> = None; // Validate cronjob config at startup (fail-fast on bad cron expressions or timezones) let mut configured_platforms: Vec<&str> = Vec::new(); @@ -264,6 +274,7 @@ async fn main() -> anyhow::Result<()> { cron::validate_cronjobs(&cfg.cron.jobs, &configured_platforms)?; // Spawn Slack adapter (background task) + #[cfg(feature = "slack")] let slack_handle = if let Some(slack_cfg) = cfg.slack { let allow_all_channels = config::resolve_allow_all(slack_cfg.allow_all_channels, &slack_cfg.allowed_channels); @@ -327,6 +338,8 @@ async fn main() -> anyhow::Result<()> { } else { None }; + #[cfg(not(feature = "slack"))] + let slack_handle: Option> = None; // Spawn Gateway adapter (background task) let gateway_handle = if let Some(gw_cfg) = cfg.gateway { @@ -401,9 +414,11 @@ async fn main() -> anyhow::Result<()> { let cron_router = router.clone(); let mut cron_adapters: std::collections::HashMap> = std::collections::HashMap::new(); + #[cfg(feature = "discord")] if let Some(ref a) = shared_discord_adapter { cron_adapters.insert("discord".into(), a.clone()); } + #[cfg(feature = "slack")] if let Some(ref a) = shared_slack_adapter { cron_adapters.insert("slack".into(), a.clone() as Arc); } @@ -426,6 +441,7 @@ async fn main() -> anyhow::Result<()> { }; // Run Discord adapter (foreground, blocking) or wait for ctrl_c + #[cfg(feature = "discord")] if let Some(discord_cfg) = cfg.discord { let allow_all_channels = config::resolve_allow_all( discord_cfg.allow_all_channels, @@ -546,6 +562,12 @@ async fn main() -> anyhow::Result<()> { shutdown_signal().await; info!("shutdown signal received"); } + #[cfg(not(feature = "discord"))] + { + info!("running without discord, press ctrl+c to stop"); + shutdown_signal().await; + info!("shutdown signal 
received"); + } // Cleanup cleanup_handle.abort(); From f2ea6cf68e96f0142af168ea5fd857728ba1f9b7 Mon Sep 17 00:00:00 2001 From: chaodufashi Date: Fri, 19 Jun 2026 00:19:40 +0000 Subject: [PATCH 16/20] =?UTF-8?q?fix:=20address=20review=20findings=20?= =?UTF-8?q?=E2=80=94=20fix=20Docker=20builds,=20add=20CI=20coverage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove stale crates/openab-core references from all Dockerfiles (fixes Docker smoke test failures after workspace simplification) - Add `cargo clippy --features unified` to CI for feature matrix coverage - Fix stale gateway/README.md link in build-gateway.yml - Document per-adapter feature intended usage in Cargo.toml --- .github/workflows/build-gateway.yml | 2 +- .github/workflows/ci.yml | 2 ++ Cargo.toml | 6 +++++- Dockerfile | 2 +- Dockerfile.agentcore | 9 ++++----- Dockerfile.antigravity | 2 +- Dockerfile.claude | 2 +- Dockerfile.codex | 2 +- Dockerfile.copilot | 2 +- Dockerfile.cursor | 2 +- Dockerfile.gateway | 7 +++---- Dockerfile.gemini | 2 +- Dockerfile.grok | 2 +- Dockerfile.hermes | 2 +- Dockerfile.mimocode | 2 +- Dockerfile.native | 9 ++++----- Dockerfile.opencode | 2 +- Dockerfile.pi | 2 +- 18 files changed, 31 insertions(+), 28 deletions(-) diff --git a/.github/workflows/build-gateway.yml b/.github/workflows/build-gateway.yml index 9c46ffa50..49f230abf 100644 --- a/.github/workflows/build-gateway.yml +++ b/.github/workflows/build-gateway.yml @@ -160,7 +160,7 @@ jobs: ### Links - - [Gateway README](https://github.com/openabdev/openab/blob/main/gateway/README.md) + - [Gateway Crate](https://github.com/openabdev/openab/tree/main/crates/openab-gateway) - [ADR: Custom Gateway](https://github.com/openabdev/openab/blob/main/docs/adr/custom-gateway.md) EOF sed -i 's/^ //' /tmp/release-notes.md diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9b7a7d337..fd5a9e13e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ 
-45,6 +45,8 @@ jobs: run: cargo check --workspace - name: cargo clippy run: cargo clippy --workspace -- -D warnings + - name: cargo clippy (unified) + run: cargo clippy --workspace --features unified -- -D warnings - name: cargo test run: cargo test --workspace diff --git a/Cargo.toml b/Cargo.toml index f8009ca35..5bdbdd60f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,7 +59,11 @@ agentcore = ["dep:aws-config", "dep:aws-sigv4", "dep:aws-credential-types", "dep # Opt-in: compile all gateway adapters into a single unified binary unified = ["telegram", "line", "feishu", "googlechat", "wecom", "teams"] -# Gateway adapters (each pulls in the gateway crate) +# Gateway adapters (each pulls in the gateway crate). +# These are meant to be used WITH default features (e.g. `--features telegram`) +# or via the `unified` shortcut. Using `--no-default-features --features telegram` +# gives you ONLY the gateway adapter without Discord/Slack — intentional for +# single-adapter deployments but not the common case. 
telegram = ["dep:openab-gateway", "openab-gateway/telegram"] line = ["dep:openab-gateway", "openab-gateway/line"] feishu = ["dep:openab-gateway", "openab-gateway/feishu"] diff --git a/Dockerfile b/Dockerfile index ff831bac8..cee571bb7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,7 +17,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ if [ "$BUILD_MODE" = "unified" ]; then \ cargo build --release --features unified; \ elif [ -n "$FEATURES" ]; then \ diff --git a/Dockerfile.agentcore b/Dockerfile.agentcore index 8b1c80727..f5cecad37 100644 --- a/Dockerfile.agentcore +++ b/Dockerfile.agentcore @@ -6,18 +6,17 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release --features agentcore \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release --features agentcore +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release --features agentcore # --- Runtime stage --- FROM debian:bookworm-slim 
diff --git a/Dockerfile.antigravity b/Dockerfile.antigravity index aa388f4a9..89d0d11e4 100644 --- a/Dockerfile.antigravity +++ b/Dockerfile.antigravity @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Build agy-acp adapter --- FROM rust:1-bookworm AS adapter-builder diff --git a/Dockerfile.claude b/Dockerfile.claude index b55cd7fd0..1010987aa 100644 --- a/Dockerfile.claude +++ b/Dockerfile.claude @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.codex b/Dockerfile.codex index 079878859..821e55fb4 100644 --- a/Dockerfile.codex +++ b/Dockerfile.codex @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.copilot b/Dockerfile.copilot index 110886c28..166a759a2 100644 --- a/Dockerfile.copilot +++ b/Dockerfile.copilot @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src 
crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.cursor b/Dockerfile.cursor index 1a920cfce..4aea2b2a0 100644 --- a/Dockerfile.cursor +++ b/Dockerfile.cursor @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/Dockerfile.gateway b/Dockerfile.gateway index e55a82163..bd5891327 100644 --- a/Dockerfile.gateway +++ b/Dockerfile.gateway @@ -2,16 +2,15 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release -p openab-gateway \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ RUN touch crates/openab-gateway/src/main.rs crates/openab-gateway/src/lib.rs && cargo build 
--release -p openab-gateway diff --git a/Dockerfile.gemini b/Dockerfile.gemini index 2f76e8148..ec9778cdd 100644 --- a/Dockerfile.gemini +++ b/Dockerfile.gemini @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.grok b/Dockerfile.grok index 3daf1d5d5..511fcaf34 100644 --- a/Dockerfile.grok +++ b/Dockerfile.grok @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/Dockerfile.hermes b/Dockerfile.hermes index 30c4da474..ae3025e65 100644 --- a/Dockerfile.hermes +++ b/Dockerfile.hermes @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM python:3.12-slim-bookworm diff --git a/Dockerfile.mimocode b/Dockerfile.mimocode index 67b94048e..79a21e388 100644 --- a/Dockerfile.mimocode +++ b/Dockerfile.mimocode @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src 
crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- # MiMo-Code (https://github.com/XiaomiMiMo/MiMo-Code) is a fork of OpenCode diff --git a/Dockerfile.native b/Dockerfile.native index 4b69794c5..809d14379 100644 --- a/Dockerfile.native +++ b/Dockerfile.native @@ -2,19 +2,18 @@ FROM rust:1-bookworm AS builder WORKDIR /build COPY Cargo.toml Cargo.lock ./ -COPY crates/openab-core/Cargo.toml crates/openab-core/Cargo.toml COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml COPY openab-agent/ openab-agent/ -RUN mkdir -p src crates/openab-core/src crates/openab-gateway/src \ +RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > crates/openab-core/src/lib.rs \ + && echo '' > \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ - && rm -rf src crates/openab-core/src crates/openab-gateway/src + && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release RUN cd openab-agent && cargo build --release # --- Runtime stage --- diff --git a/Dockerfile.opencode b/Dockerfile.opencode index fb38a5bae..05f3cf8c2 100644 --- a/Dockerfile.opencode +++ b/Dockerfile.opencode @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs 
crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- # node:22-bookworm-slim mirrors the base image used by Dockerfile.claude, diff --git a/Dockerfile.pi b/Dockerfile.pi index 43414429c..697fce4b4 100644 --- a/Dockerfile.pi +++ b/Dockerfile.pi @@ -11,7 +11,7 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-core/src/lib.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release # --- Runtime stage --- FROM node:22-bookworm-slim From b4a102ad3d1aa5d24f3d1370d722e4f7a5548fc0 Mon Sep 17 00:00:00 2001 From: chaodufashi Date: Fri, 19 Jun 2026 01:55:50 +0000 Subject: [PATCH 17/20] fix: remove broken 'echo > \' lines in Dockerfiles MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The dependency-cache RUN steps had stray 'echo '' > \' lines with no target path, which would fail docker build. Remove the broken lines — the subsequent echo already writes the empty lib.rs stub correctly. 
--- Dockerfile.agentcore | 1 - Dockerfile.gateway | 2 -- Dockerfile.native | 1 - 3 files changed, 4 deletions(-) diff --git a/Dockerfile.agentcore b/Dockerfile.agentcore index f5cecad37..4ff90ce2c 100644 --- a/Dockerfile.agentcore +++ b/Dockerfile.agentcore @@ -9,7 +9,6 @@ COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release --features agentcore \ diff --git a/Dockerfile.gateway b/Dockerfile.gateway index bd5891327..015bf2780 100644 --- a/Dockerfile.gateway +++ b/Dockerfile.gateway @@ -5,8 +5,6 @@ COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > \ - && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release -p openab-gateway \ diff --git a/Dockerfile.native b/Dockerfile.native index 809d14379..a1d5e02c9 100644 --- a/Dockerfile.native +++ b/Dockerfile.native @@ -6,7 +6,6 @@ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml COPY openab-agent/ openab-agent/ RUN mkdir -p src crates/openab-gateway/src \ && echo 'fn main() {}' > src/main.rs \ - && echo '' > \ && echo '' > crates/openab-gateway/src/lib.rs \ && echo 'fn main() {}' > crates/openab-gateway/src/main.rs \ && cargo build --release \ From 673ab3e543538b76f708c1f632a144e22e243830 Mon Sep 17 00:00:00 2001 From: chaodufashi Date: Fri, 19 Jun 2026 02:04:39 +0000 Subject: [PATCH 18/20] fix: gate AppState fields with #[cfg] and add gateway README - Add #[cfg(feature = "telegram")] to telegram_* fields and env reads - Add #[cfg(feature = 
"line")] to line_* fields, reply_token_cache, line_webhook_semaphore, and the token sweep background task - Add #[cfg(feature = "teams")] to teams_service_urls and its cleanup task - Remove unused cfg(not) teams variable - Add crates/openab-gateway/README.md (replaces deleted gateway/README.md) --- crates/openab-gateway/README.md | 142 ++++++++++++++++++++++++++++++ crates/openab-gateway/src/lib.rs | 8 ++ crates/openab-gateway/src/main.rs | 18 +++- 3 files changed, 166 insertions(+), 2 deletions(-) create mode 100644 crates/openab-gateway/README.md diff --git a/crates/openab-gateway/README.md b/crates/openab-gateway/README.md new file mode 100644 index 000000000..8fd4465ea --- /dev/null +++ b/crates/openab-gateway/README.md @@ -0,0 +1,142 @@ +# OpenAB Gateway + +A standalone service that bridges webhook-based platforms and custom event sources to OAB via WebSocket. OAB connects outbound to the gateway — no inbound ports or TLS required on OAB. + +``` + External (HTTPS) Internal (cluster) + ──────────────── ────────────────── + +Telegram ──POST──▶┌─────────────────────┐ +LINE ──POST──▶│ │ +Feishu ──POST──▶│ OpenAB Gateway │◀──WebSocket── OAB Pod +Google ──POST──▶│ :8080 │ (OAB connects out) +WeCom ──POST──▶│ │ +Teams ──POST──▶│ │ + └─────────────────────┘ + +Discord ◀──WebSocket── OAB Pod (unchanged, direct) +Slack ◀──WebSocket── OAB Pod (unchanged, direct) +``` + +The gateway normalizes all inbound events to a unified schema (`openab.gateway.event.v1`), forwards them to OAB over WebSocket, and routes OAB replies back to the originating platform API. + +For architecture details, see [ADR: Custom Gateway](../../docs/adr/custom-gateway.md). + +--- + +## Build Modes + +This crate can run as a **standalone binary** or be compiled into the **unified OAB binary**. 
+ +### Standalone + +```bash +cargo build --release -p openab-gateway +``` + +### Unified (all adapters in one binary) + +```bash +cargo build --release --features unified +``` + +### Single adapter only + +```bash +cargo build --release --no-default-features --features telegram +``` + +### Docker + +```bash +# Standalone gateway +docker build -f Dockerfile.gateway -t openab-gateway . + +# Unified binary +docker build --build-arg BUILD_MODE=unified -t openab:unified . + +# Custom adapter selection +docker build --build-arg FEATURES=telegram,line -t openab:custom . +``` + +--- + +## Environment Variables + +| Variable | Default | Description | +|---|---|---| +| `GATEWAY_LISTEN` | `0.0.0.0:8080` | Listen address | +| `GATEWAY_WS_TOKEN` | (optional) | Token for WebSocket authentication | +| `TELEGRAM_BOT_TOKEN` | (optional) | Telegram Bot API token | +| `TELEGRAM_SECRET_TOKEN` | (optional) | Webhook secret for request validation | +| `TELEGRAM_WEBHOOK_PATH` | `/webhook/telegram` | Webhook endpoint path | +| `TELEGRAM_RICH_MESSAGES` | `true` | Enable Markdown formatting | +| `LINE_CHANNEL_SECRET` | (optional) | LINE channel secret for HMAC verification | +| `LINE_CHANNEL_ACCESS_TOKEN` | (optional) | LINE channel access token | +| `FEISHU_APP_ID` | (optional) | Feishu/Lark App ID | +| `FEISHU_APP_SECRET` | (optional) | Feishu/Lark App Secret | +| `FEISHU_DOMAIN` | `feishu` | `feishu` (China) or `lark` (international) | +| `FEISHU_CONNECTION_MODE` | `websocket` | `websocket` or `webhook` | +| `FEISHU_WEBHOOK_PATH` | `/webhook/feishu` | Webhook endpoint path | +| `GOOGLE_CHAT_ENABLED` | `false` | Set to `true` to enable Google Chat | +| `GOOGLE_CHAT_AUDIENCE` | (optional) | JWT audience for webhook verification | +| `GOOGLE_CHAT_SA_KEY_JSON` | (optional) | Service account key JSON | +| `GOOGLE_CHAT_WEBHOOK_PATH` | `/webhook/googlechat` | Webhook endpoint path | +| `WECOM_CORP_ID` | (required*) | WeCom Corp ID | +| `WECOM_AGENT_ID` | (required*) | WeCom App Agent ID | +| 
`WECOM_SECRET` | (required*) | WeCom App Secret | +| `WECOM_TOKEN` | (required*) | Callback verification Token | +| `WECOM_ENCODING_AES_KEY` | (required*) | Callback EncodingAESKey | +| `WECOM_WEBHOOK_PATH` | `/webhook/wecom` | Webhook endpoint path | +| `TEAMS_APP_ID` | (optional) | Microsoft Teams App ID | +| `TEAMS_APP_PASSWORD` | (optional) | Microsoft Teams App Password | +| `TEAMS_WEBHOOK_PATH` | `/webhook/teams` | Webhook endpoint path | + +--- + +## Endpoints + +| Path | Description | +|---|---| +| `GET /ws` | WebSocket server (OAB connects here) | +| `GET /health` | Health check | +| `POST /webhook/telegram` | Telegram webhook receiver | +| `POST /webhook/line` | LINE webhook receiver | +| `POST /webhook/feishu` | Feishu webhook receiver | +| `POST /webhook/googlechat` | Google Chat webhook receiver | +| `GET /webhook/wecom` | WeCom callback URL verification | +| `POST /webhook/wecom` | WeCom message callback receiver | +| `POST /webhook/teams` | Microsoft Teams webhook receiver | + +--- + +## OAB Config + +```toml +[gateway] +url = "ws://gateway:8080/ws" +``` + +--- + +## Platform Setup + +- [Telegram](../../docs/telegram.md) +- [LINE](../../docs/line.md) +- [Feishu/Lark](../../docs/feishu.md) +- [Google Chat](../../docs/google-chat.md) +- [WeCom](../../docs/wecom.md) + +--- + +## Feature Flags + +| Flag | Effect | +|------|--------| +| `default` | All adapters enabled | +| `telegram` | Telegram adapter only | +| `line` | LINE adapter only | +| `feishu` | Feishu/Lark adapter only | +| `googlechat` | Google Chat adapter only | +| `wecom` | WeCom adapter only | +| `teams` | Microsoft Teams adapter only | diff --git a/crates/openab-gateway/src/lib.rs b/crates/openab-gateway/src/lib.rs index c67f9fc85..91d2cda53 100644 --- a/crates/openab-gateway/src/lib.rs +++ b/crates/openab-gateway/src/lib.rs @@ -25,13 +25,19 @@ pub const LINE_WEBHOOK_CONCURRENCY_MAX: usize = 8; // --- App state (shared across all adapters) --- pub struct AppState { + #[cfg(feature = 
"telegram")] pub telegram_bot_token: Option, + #[cfg(feature = "telegram")] pub telegram_secret_token: Option, + #[cfg(feature = "telegram")] pub telegram_rich_messages: bool, + #[cfg(feature = "line")] pub line_channel_secret: Option, + #[cfg(feature = "line")] pub line_access_token: Option, #[cfg(feature = "teams")] pub teams: Option, + #[cfg(feature = "teams")] pub teams_service_urls: Mutex>, #[cfg(feature = "feishu")] pub feishu: Option, @@ -41,7 +47,9 @@ pub struct AppState { pub wecom: Option, pub ws_token: Option, pub event_tx: broadcast::Sender, + #[cfg(feature = "line")] pub reply_token_cache: ReplyTokenCache, + #[cfg(feature = "line")] pub line_webhook_semaphore: Arc, pub client: reqwest::Client, } diff --git a/crates/openab-gateway/src/main.rs b/crates/openab-gateway/src/main.rs index f0dc8b3c7..94e6554e1 100644 --- a/crates/openab-gateway/src/main.rs +++ b/crates/openab-gateway/src/main.rs @@ -177,6 +177,7 @@ async fn main() -> Result<()> { } let (event_tx, _) = broadcast::channel::(256); + #[cfg(feature = "line")] let reply_token_cache: ReplyTokenCache = Arc::new(std::sync::Mutex::new(HashMap::new())); let mut app = Router::new() @@ -184,8 +185,11 @@ async fn main() -> Result<()> { .route("/health", get(health)); // Telegram adapter + #[cfg(feature = "telegram")] let telegram_bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok(); + #[cfg(feature = "telegram")] let telegram_secret_token = std::env::var("TELEGRAM_SECRET_TOKEN").ok(); + #[cfg(feature = "telegram")] let telegram_rich_messages = std::env::var("TELEGRAM_RICH_MESSAGES") .map(|v| v != "0" && !v.eq_ignore_ascii_case("false")) .unwrap_or(true); @@ -201,7 +205,9 @@ async fn main() -> Result<()> { } // LINE adapter + #[cfg(feature = "line")] let line_channel_secret = std::env::var("LINE_CHANNEL_SECRET").ok(); + #[cfg(feature = "line")] let line_access_token = std::env::var("LINE_CHANNEL_ACCESS_TOKEN").ok(); #[cfg(feature = "line")] { @@ -217,8 +223,6 @@ async fn main() -> Result<()> { info!(path = 
%webhook_path, "teams adapter enabled"); adapters::teams::TeamsAdapter::new(config) }); - #[cfg(not(feature = "teams"))] - let teams: Option<()> = None; #[cfg(feature = "teams")] if teams.is_some() { @@ -316,13 +320,19 @@ async fn main() -> Result<()> { .expect("HTTP client must build"); let state = Arc::new(AppState { + #[cfg(feature = "telegram")] telegram_bot_token, + #[cfg(feature = "telegram")] telegram_secret_token, + #[cfg(feature = "telegram")] telegram_rich_messages, + #[cfg(feature = "line")] line_channel_secret, + #[cfg(feature = "line")] line_access_token, #[cfg(feature = "teams")] teams, + #[cfg(feature = "teams")] teams_service_urls: Mutex::new(HashMap::new()), #[cfg(feature = "feishu")] feishu, @@ -332,12 +342,15 @@ async fn main() -> Result<()> { wecom, ws_token, event_tx, + #[cfg(feature = "line")] reply_token_cache, + #[cfg(feature = "line")] line_webhook_semaphore: Arc::new(Semaphore::new(LINE_WEBHOOK_CONCURRENCY_MAX)), client, }); // Background: sweep expired reply tokens + #[cfg(feature = "line")] { let cache_state = state.clone(); tokio::spawn(async move { @@ -362,6 +375,7 @@ async fn main() -> Result<()> { } // Background: cleanup stale Teams service_url entries (TTL: 4 hours) + #[cfg(feature = "teams")] { let state_for_cleanup = state.clone(); tokio::spawn(async move { From 88c907c6a4d9cbdb055b4ee464b0622f9cb7ed2a Mon Sep 17 00:00:00 2001 From: chaodufashi Date: Fri, 19 Jun 2026 02:28:29 +0000 Subject: [PATCH 19/20] feat: add FEATURES build-arg to all Dockerfiles All Dockerfile.* now accept an optional FEATURES build-arg to enable additional adapters at build time: docker build -f Dockerfile.claude --build-arg FEATURES=telegram -t openab-claude:tg . For the standard images (claude, codex, etc.) FEATURES is additive to defaults. For Dockerfile.gateway, FEATURES replaces the default (all adapters). For Dockerfile.agentcore, FEATURES is combined with the required agentcore feature. 
--- Dockerfile.agentcore | 9 ++++++++- Dockerfile.antigravity | 9 ++++++++- Dockerfile.claude | 9 ++++++++- Dockerfile.codex | 9 ++++++++- Dockerfile.copilot | 9 ++++++++- Dockerfile.cursor | 9 ++++++++- Dockerfile.gateway | 9 ++++++++- Dockerfile.gemini | 9 ++++++++- Dockerfile.grok | 9 ++++++++- Dockerfile.hermes | 9 ++++++++- Dockerfile.mimocode | 9 ++++++++- Dockerfile.native | 9 ++++++++- Dockerfile.opencode | 9 ++++++++- Dockerfile.pi | 9 ++++++++- 14 files changed, 112 insertions(+), 14 deletions(-) diff --git a/Dockerfile.agentcore b/Dockerfile.agentcore index 4ff90ce2c..8d4f265f6 100644 --- a/Dockerfile.agentcore +++ b/Dockerfile.agentcore @@ -3,7 +3,9 @@ # Result: ~20MB image — single Rust binary, no Python/pip needed. # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -15,7 +17,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release --features agentcore +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "agentcore,$FEATURES"; \ + else \ + cargo build --release --features agentcore; \ + fi # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/Dockerfile.antigravity b/Dockerfile.antigravity index 89d0d11e4..75e00bf74 100644 --- a/Dockerfile.antigravity +++ b/Dockerfile.antigravity @@ -1,5 +1,7 @@ # --- Build openab --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY 
crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Build agy-acp adapter --- FROM rust:1-bookworm AS adapter-builder diff --git a/Dockerfile.claude b/Dockerfile.claude index 1010987aa..52393d064 100644 --- a/Dockerfile.claude +++ b/Dockerfile.claude @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.codex b/Dockerfile.codex index 821e55fb4..9e516b047 100644 --- a/Dockerfile.codex +++ b/Dockerfile.codex @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs 
crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.copilot b/Dockerfile.copilot index 166a759a2..36e53d3d5 100644 --- a/Dockerfile.copilot +++ b/Dockerfile.copilot @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.cursor b/Dockerfile.cursor index 4aea2b2a0..d22463590 100644 --- a/Dockerfile.cursor +++ b/Dockerfile.cursor @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/Dockerfile.gateway 
b/Dockerfile.gateway index 015bf2780..b7015e253 100644 --- a/Dockerfile.gateway +++ b/Dockerfile.gateway @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch crates/openab-gateway/src/main.rs crates/openab-gateway/src/lib.rs && cargo build --release -p openab-gateway +RUN touch crates/openab-gateway/src/main.rs crates/openab-gateway/src/lib.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release -p openab-gateway --no-default-features --features "$FEATURES"; \ + else \ + cargo build --release -p openab-gateway; \ + fi # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/Dockerfile.gemini b/Dockerfile.gemini index ec9778cdd..76b265ee4 100644 --- a/Dockerfile.gemini +++ b/Dockerfile.gemini @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- FROM node:22-bookworm-slim diff --git a/Dockerfile.grok b/Dockerfile.grok index 511fcaf34..18d72c7f1 100644 --- a/Dockerfile.grok +++ b/Dockerfile.grok @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES 
WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- FROM debian:bookworm-slim diff --git a/Dockerfile.hermes b/Dockerfile.hermes index ae3025e65..d664e833a 100644 --- a/Dockerfile.hermes +++ b/Dockerfile.hermes @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- FROM python:3.12-slim-bookworm diff --git a/Dockerfile.mimocode b/Dockerfile.mimocode index 79a21e388..7e8f28468 100644 --- a/Dockerfile.mimocode +++ b/Dockerfile.mimocode @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ 
COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- # MiMo-Code (https://github.com/XiaomiMiMo/MiMo-Code) is a fork of OpenCode diff --git a/Dockerfile.native b/Dockerfile.native index a1d5e02c9..2671ca584 100644 --- a/Dockerfile.native +++ b/Dockerfile.native @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -12,7 +14,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi RUN cd openab-agent && cargo build --release # --- Runtime stage --- diff --git a/Dockerfile.opencode b/Dockerfile.opencode index 05f3cf8c2..8dd243f01 100644 --- a/Dockerfile.opencode +++ b/Dockerfile.opencode @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs 
crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- # node:22-bookworm-slim mirrors the base image used by Dockerfile.claude, diff --git a/Dockerfile.pi b/Dockerfile.pi index 697fce4b4..7f0ec87f5 100644 --- a/Dockerfile.pi +++ b/Dockerfile.pi @@ -1,5 +1,7 @@ # --- Build stage --- +ARG FEATURES="" FROM rust:1-bookworm AS builder +ARG FEATURES WORKDIR /build COPY Cargo.toml Cargo.lock ./ COPY crates/openab-gateway/Cargo.toml crates/openab-gateway/Cargo.toml @@ -11,7 +13,12 @@ RUN mkdir -p src crates/openab-gateway/src \ && rm -rf src crates/openab-gateway/src COPY crates/ crates/ COPY src/ src/ -RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && cargo build --release +RUN touch src/main.rs crates/openab-gateway/src/lib.rs crates/openab-gateway/src/main.rs && \ + if [ -n "$FEATURES" ]; then \ + cargo build --release --features "$FEATURES"; \ + else \ + cargo build --release; \ + fi # --- Runtime stage --- FROM node:22-bookworm-slim From be765da61e5c77d95bf0495272cd6d2f4b3563de Mon Sep 17 00:00:00 2001 From: chaodufashi Date: Fri, 19 Jun 2026 02:36:34 +0000 Subject: [PATCH 20/20] ci: add unified- variants to docker smoke-test matrix Adds 12 unified build variants (one per agent Dockerfile) to validate that all agent images work correctly with --features unified. Total matrix: 25 variants (13 default + 12 unified). 
--- .github/workflows/docker-smoke-test.yml | 40 ++++++++++++++++--------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/.github/workflows/docker-smoke-test.yml b/.github/workflows/docker-smoke-test.yml index 3c3d09d2f..df340d276 100644 --- a/.github/workflows/docker-smoke-test.yml +++ b/.github/workflows/docker-smoke-test.yml @@ -14,25 +14,37 @@ jobs: fail-fast: false matrix: variant: - - { dockerfile: Dockerfile, suffix: "", agent: "kiro-cli", agent_args: "acp --trust-all-tools" } - - { dockerfile: Dockerfile.claude, suffix: "-claude", agent: "claude-agent-acp", agent_args: "" } - - { dockerfile: Dockerfile.codex, suffix: "-codex", agent: "codex-acp", agent_args: "" } - - { dockerfile: Dockerfile.gemini, suffix: "-gemini", agent: "gemini", agent_args: "--acp" } - - { dockerfile: Dockerfile.copilot, suffix: "-copilot", agent: "copilot", agent_args: "--acp" } - - { dockerfile: Dockerfile.opencode, suffix: "-opencode", agent: "opencode", agent_args: "acp" } - - { dockerfile: Dockerfile.cursor, suffix: "-cursor", agent: "cursor-agent", agent_args: "acp" } - - { dockerfile: Dockerfile.mimocode, suffix: "-mimocode", agent: "mimo", agent_args: "acp" } - - { dockerfile: Dockerfile.hermes, suffix: "-hermes", agent: "hermes-acp", agent_args: "" } - - { dockerfile: Dockerfile.grok, suffix: "-grok", agent: "grok", agent_args: "agent stdio" } - - { dockerfile: Dockerfile.antigravity, suffix: "-antigravity", agent: "agy-acp", agent_args: "" } - - { dockerfile: Dockerfile.pi, suffix: "-pi", agent: "pi-acp", agent_args: "" } - - { dockerfile: openshell/Dockerfile, suffix: "-native-sandbox", agent: "openab-agent", agent_args: "" } + - { dockerfile: Dockerfile, suffix: "", agent: "kiro-cli", agent_args: "acp --trust-all-tools", build_args: "" } + - { dockerfile: Dockerfile, suffix: "-unified", agent: "kiro-cli", agent_args: "acp --trust-all-tools", build_args: "--build-arg BUILD_MODE=unified" } + - { dockerfile: Dockerfile.claude, suffix: "-claude", agent: 
"claude-agent-acp", agent_args: "", build_args: "" } + - { dockerfile: Dockerfile.claude, suffix: "-unified-claude", agent: "claude-agent-acp", agent_args: "", build_args: "--build-arg FEATURES=unified" } + - { dockerfile: Dockerfile.codex, suffix: "-codex", agent: "codex-acp", agent_args: "", build_args: "" } + - { dockerfile: Dockerfile.codex, suffix: "-unified-codex", agent: "codex-acp", agent_args: "", build_args: "--build-arg FEATURES=unified" } + - { dockerfile: Dockerfile.gemini, suffix: "-gemini", agent: "gemini", agent_args: "--acp", build_args: "" } + - { dockerfile: Dockerfile.gemini, suffix: "-unified-gemini", agent: "gemini", agent_args: "--acp", build_args: "--build-arg FEATURES=unified" } + - { dockerfile: Dockerfile.copilot, suffix: "-copilot", agent: "copilot", agent_args: "--acp", build_args: "" } + - { dockerfile: Dockerfile.copilot, suffix: "-unified-copilot", agent: "copilot", agent_args: "--acp", build_args: "--build-arg FEATURES=unified" } + - { dockerfile: Dockerfile.opencode, suffix: "-opencode", agent: "opencode", agent_args: "acp", build_args: "" } + - { dockerfile: Dockerfile.opencode, suffix: "-unified-opencode", agent: "opencode", agent_args: "acp", build_args: "--build-arg FEATURES=unified" } + - { dockerfile: Dockerfile.cursor, suffix: "-cursor", agent: "cursor-agent", agent_args: "acp", build_args: "" } + - { dockerfile: Dockerfile.cursor, suffix: "-unified-cursor", agent: "cursor-agent", agent_args: "acp", build_args: "--build-arg FEATURES=unified" } + - { dockerfile: Dockerfile.mimocode, suffix: "-mimocode", agent: "mimo", agent_args: "acp", build_args: "" } + - { dockerfile: Dockerfile.mimocode, suffix: "-unified-mimocode", agent: "mimo", agent_args: "acp", build_args: "--build-arg FEATURES=unified" } + - { dockerfile: Dockerfile.hermes, suffix: "-hermes", agent: "hermes-acp", agent_args: "", build_args: "" } + - { dockerfile: Dockerfile.hermes, suffix: "-unified-hermes", agent: "hermes-acp", agent_args: "", build_args: 
"--build-arg FEATURES=unified" } + - { dockerfile: Dockerfile.grok, suffix: "-grok", agent: "grok", agent_args: "agent stdio", build_args: "" } + - { dockerfile: Dockerfile.grok, suffix: "-unified-grok", agent: "grok", agent_args: "agent stdio", build_args: "--build-arg FEATURES=unified" } + - { dockerfile: Dockerfile.antigravity, suffix: "-antigravity", agent: "agy-acp", agent_args: "", build_args: "" } + - { dockerfile: Dockerfile.antigravity, suffix: "-unified-antigravity", agent: "agy-acp", agent_args: "", build_args: "--build-arg FEATURES=unified" } + - { dockerfile: Dockerfile.pi, suffix: "-pi", agent: "pi-acp", agent_args: "", build_args: "" } + - { dockerfile: Dockerfile.pi, suffix: "-unified-pi", agent: "pi-acp", agent_args: "", build_args: "--build-arg FEATURES=unified" } + - { dockerfile: openshell/Dockerfile, suffix: "-native-sandbox", agent: "openab-agent", agent_args: "", build_args: "" } runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - name: Build image - run: docker build -t openab-test${{ matrix.variant.suffix }} -f ${{ matrix.variant.dockerfile }} . + run: docker build -t openab-test${{ matrix.variant.suffix }} -f ${{ matrix.variant.dockerfile }} ${{ matrix.variant.build_args }} . - name: Verify openab CMD does not crash run: |