From 24e068661d76792718f162922abf63d287a486a8 Mon Sep 17 00:00:00 2001 From: MasterPtato Date: Tue, 23 Jun 2026 17:28:30 -0700 Subject: [PATCH] [SLOP(claude-opus-4-8)] feat(util): add rate limiter primitive and ingress throttles for actor create and gateway websocket --- Cargo.lock | 1 + .../errors/actor.creation_rate_limit.json | 5 + engine/packages/config/src/config/pegboard.rs | 30 ++ engine/packages/gasoline/src/ctx/message.rs | 2 +- engine/packages/gasoline/src/error.rs | 2 +- .../packages/guard-core/src/proxy_service.rs | 16 +- engine/packages/guard-core/src/utils.rs | 44 +- engine/packages/pegboard-gateway2/src/lib.rs | 1 + .../pegboard-gateway2/src/shared_state.rs | 2 +- .../src/ws_to_tunnel_task.rs | 16 + engine/packages/pegboard/Cargo.toml | 1 + engine/packages/pegboard/src/errors.rs | 6 + .../packages/pegboard/src/ops/actor/create.rs | 34 ++ .../pegboard/src/workflows/actor/runtime.rs | 4 +- .../pegboard/src/workflows/actor2/runtime.rs | 2 +- .../workflows/runner_pool_metadata_poller.rs | 2 +- .../pegboard/src/workflows/serverless/conn.rs | 4 +- .../src/driver/postgres/mod.rs | 2 +- engine/packages/universalpubsub/src/pubsub.rs | 2 +- engine/packages/util/src/backoff.rs | 109 ---- engine/packages/util/src/lib.rs | 2 +- engine/packages/util/src/throttle.rs | 487 ++++++++++++++++++ 22 files changed, 606 insertions(+), 168 deletions(-) create mode 100644 engine/artifacts/errors/actor.creation_rate_limit.json delete mode 100644 engine/packages/util/src/backoff.rs create mode 100644 engine/packages/util/src/throttle.rs diff --git a/Cargo.lock b/Cargo.lock index 3df4e1323d..bea893b3b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3928,6 +3928,7 @@ dependencies = [ "futures-util", "gasoline", "lazy_static", + "moka", "namespace", "nix 0.30.1", "portpicker", diff --git a/engine/artifacts/errors/actor.creation_rate_limit.json b/engine/artifacts/errors/actor.creation_rate_limit.json new file mode 100644 index 0000000000..ff4ee74181 --- /dev/null +++ b/engine/artifacts/errors/actor.creation_rate_limit.json @@ -0,0 +1,5 @@ +{ + "code": "creation_rate_limit", + "group": "actor", + "message": "Too many actors created at once. Try again later." +} \ No newline at end of file diff --git a/engine/packages/config/src/config/pegboard.rs b/engine/packages/config/src/config/pegboard.rs index 616af06a36..f05a98319f 100644 --- a/engine/packages/config/src/config/pegboard.rs +++ b/engine/packages/config/src/config/pegboard.rs @@ -117,6 +117,12 @@ pub struct Pegboard { pub gateway_hws_max_pending_size: Option, /// Max HTTP request body size in bytes for requests to actors. pub gateway_http_max_request_body_size: Option, + /// Max burst of inbound WebSocket messages on a single connection before throttling. + pub gateway_websocket_rate_limit_requests: Option, + /// Time to regain one inbound WebSocket message token on a single connection. + /// + /// Unit is in milliseconds. + pub gateway_websocket_rate_limit_drip_rate_ms: Option, // === Envoy Settings === /// How long to wait before considering an envoy lost and evicting all of its actors. @@ -162,6 +168,14 @@ pub struct Pegboard { /// /// Unit is in bytes. Default: 1,048,576 (1 MiB). pub preload_max_total_bytes: Option, + + // === Rate Limiting === + /// Max burst of actor creations per namespace before throttling. + pub actor_create_rate_limit_requests: Option, + /// Time to regain one actor creation token per namespace. + /// + /// Unit is in milliseconds. + pub actor_create_rate_limit_drip_rate_ms: Option, } impl Pegboard { @@ -369,6 +383,22 @@ impl Pegboard { self.serverless_drain_grace_period.unwrap_or(10_000) } + pub fn gateway_websocket_rate_limit_requests(&self) -> u64 { + self.gateway_websocket_rate_limit_requests.unwrap_or(2_000) + } + + pub fn gateway_websocket_rate_limit_drip_rate_ms(&self) -> u64 { + self.gateway_websocket_rate_limit_drip_rate_ms.unwrap_or(10) + } + + pub fn actor_create_rate_limit_requests(&self) -> u64 { + self.actor_create_rate_limit_requests.unwrap_or(500) + } + + pub fn actor_create_rate_limit_drip_rate_ms(&self) -> u64 { + self.actor_create_rate_limit_drip_rate_ms.unwrap_or(10) + } + pub fn preload_max_total_bytes(&self) -> u64 { self.preload_max_total_bytes.unwrap_or(1_048_576) } diff --git a/engine/packages/gasoline/src/ctx/message.rs b/engine/packages/gasoline/src/ctx/message.rs index a6edb2fe32..4c7fcf5e50 100644 --- a/engine/packages/gasoline/src/ctx/message.rs +++ b/engine/packages/gasoline/src/ctx/message.rs @@ -139,7 +139,7 @@ impl MessageCtx { M: Message, { // Infinite backoff since we want to wait until the service reboots. - let mut backoff = rivet_util::backoff::Backoff::default_infinite(); + let mut backoff = rivet_util::throttle::Backoff::default_infinite(); loop { // Ignore for infinite backoff backoff.tick().await; diff --git a/engine/packages/gasoline/src/error.rs b/engine/packages/gasoline/src/error.rs index 2fa7c686b9..1df67e17ad 100644 --- a/engine/packages/gasoline/src/error.rs +++ b/engine/packages/gasoline/src/error.rs @@ -193,7 +193,7 @@ impl WorkflowError { | WorkflowError::ActivityTimeout(_, error_count) | WorkflowError::OperationTimeout(_, error_count) => { // NOTE: Max retry is handled in `WorkflowCtx::activity` - let mut backoff = rivet_util::backoff::Backoff::new_at( + let mut backoff = rivet_util::throttle::Backoff::new_at( 8, None, RETRY_TIMEOUT_MS, diff --git a/engine/packages/guard-core/src/proxy_service.rs b/engine/packages/guard-core/src/proxy_service.rs index f54bfe8f64..f37287c2ac 100644 --- a/engine/packages/guard-core/src/proxy_service.rs +++ b/engine/packages/guard-core/src/proxy_service.rs @@ -32,7 +32,7 @@ use crate::RouteTarget; use crate::request_context::RequestContext; use crate::response_body::ResponseBody; use crate::route::{CacheKeyFn, ResolveRouteOutput, RouteCache, RoutingFn, RoutingOutput}; -use crate::utils::{InFlightCounter, RateLimiter}; +use crate::utils::InFlightCounter; use crate::{ WebSocketHandle, custom_serve::HibernationResult, errors, metrics, task_group::TaskGroup, utils, }; @@ -56,7 +56,7 @@ pub struct ProxyState { >, route_cache: RouteCache, // We use moka::Cache instead of scc::HashMap because it automatically handles TTL and capacity - rate_limiters: Cache>>, + rate_limiters: Cache>>, in_flight_counters: Cache>>, in_flight_requests: Cache, @@ -98,11 +98,11 @@ impl ProxyState { route_cache: RouteCache::new(route_cache_ttl), rate_limiters: Cache::builder() .max_capacity(10_000) - .time_to_live(PROXY_STATE_CACHE_TTL) + .time_to_idle(PROXY_STATE_CACHE_TTL) .build(), in_flight_counters: Cache::builder() .max_capacity(10_000) - .time_to_live(PROXY_STATE_CACHE_TTL) + .time_to_idle(PROXY_STATE_CACHE_TTL) .build(), in_flight_requests: Cache::builder().max_capacity(10_000_000).build(), tasks: TaskGroup::new(), @@ -217,9 +217,11 @@ impl ProxyState { if let Some(existing_limiter) = self.rate_limiters.get(&req_ctx.client_ip).await { existing_limiter } else { - let new_limiter = Arc::new(Mutex::new(RateLimiter::new( - req_ctx.rate_limit.requests, - req_ctx.rate_limit.period, + let new_limiter = Arc::new(Mutex::new(rivet_util::throttle::RateLimiter::new( + rivet_util::throttle::RateLimitMethod::FixedWindow { + requests: req_ctx.rate_limit.requests, + period: Duration::from_secs(req_ctx.rate_limit.period), + }, ))); self.rate_limiters .insert(req_ctx.client_ip, new_limiter.clone()) diff --git a/engine/packages/guard-core/src/utils.rs b/engine/packages/guard-core/src/utils.rs index 5503611dcc..7c5b3af382 100644 --- a/engine/packages/guard-core/src/utils.rs +++ b/engine/packages/guard-core/src/utils.rs @@ -7,7 +7,7 @@ use hyper::header::HeaderName; use rivet_api_builder::{ErrorResponse, RawErrorResponse}; use rivet_error::{INTERNAL_ERROR, RivetError}; use rivet_util::Id; -use std::time::{Duration, Instant}; +use std::time::Duration; use tokio_tungstenite::tungstenite::protocol::{CloseFrame, frame::coding::CloseCode}; use url::Url; @@ -19,7 +19,7 @@ const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-target"); const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor"); const X_RIVET_TOKEN: HeaderName = HeaderName::from_static("x-rivet-token"); -// In-flight requests counter +// In-flight requests counter (semaphore) pub(crate) struct InFlightCounter { count: usize, max: usize, @@ -44,43 +44,6 @@ impl InFlightCounter { } } -// Rate limiter -pub(crate) struct RateLimiter { - requests_remaining: u64, - reset_time: Instant, - requests_limit: u64, - period: Duration, -} - -impl RateLimiter { - pub(crate) fn new(requests: u64, period_seconds: u64) -> Self { - Self { - requests_remaining: requests, - reset_time: Instant::now() + Duration::from_secs(period_seconds), - requests_limit: requests, - period: Duration::from_secs(period_seconds), - } - } - - pub(crate) fn try_acquire(&mut self) -> bool { - let now = Instant::now(); - - // Check if we need to reset the counter - if now >= self.reset_time { - self.requests_remaining = self.requests_limit; - self.reset_time = now + self.period; - } - - // Try to consume a request - if self.requests_remaining > 0 { - self.requests_remaining -= 1; - true - } else { - false - } - } -} - // Calculate backoff duration for a given retry attempt pub(crate) fn calculate_backoff(attempt: u32, initial_interval: u64) -> Duration { Duration::from_millis(initial_interval * 2u64.pow(attempt - 1)) @@ -177,7 +140,6 @@ pub(crate) fn err_into_response(err: anyhow::Error) -> Result StatusCode::BAD_GATEWAY, ("guard", "request_timeout") => StatusCode::GATEWAY_TIMEOUT, ("guard", "retry_attempts_exceeded") => StatusCode::BAD_GATEWAY, - ("actor", "not_found") => StatusCode::NOT_FOUND, ("guard", "service_unavailable") => StatusCode::SERVICE_UNAVAILABLE, ("guard", "actor_stopped_while_waiting") => StatusCode::SERVICE_UNAVAILABLE, ("guard", "tunnel_request_aborted") => StatusCode::SERVICE_UNAVAILABLE, @@ -188,6 +150,8 @@ pub(crate) fn err_into_response(err: anyhow::Error) -> Result StatusCode::NOT_FOUND, ("guard", "invalid_request_body") => StatusCode::PAYLOAD_TOO_LARGE, ("guard", "invalid_response_body") => StatusCode::BAD_GATEWAY, + ("actor", "creation_rate_limit") => StatusCode::TOO_MANY_REQUESTS, + ("actor", "not_found") => StatusCode::NOT_FOUND, _ => StatusCode::BAD_REQUEST, }; diff --git a/engine/packages/pegboard-gateway2/src/lib.rs b/engine/packages/pegboard-gateway2/src/lib.rs index 4bd33f15cd..e370270334 100644 --- a/engine/packages/pegboard-gateway2/src/lib.rs +++ b/engine/packages/pegboard-gateway2/src/lib.rs @@ -674,6 +674,7 @@ impl PegboardGateway2 { ); let ws_to_tunnel = tokio::spawn( ws_to_tunnel_task::task( + ctx.clone(), in_flight_req.clone(), ws_rx, ingress_bytes.clone(), diff --git a/engine/packages/pegboard-gateway2/src/shared_state.rs b/engine/packages/pegboard-gateway2/src/shared_state.rs index ff870d330d..67ad84a7b8 100644 --- a/engine/packages/pegboard-gateway2/src/shared_state.rs +++ b/engine/packages/pegboard-gateway2/src/shared_state.rs @@ -661,7 +661,7 @@ impl InFlightRequestHandle { // Cap retries so a permanently-gone receiver fails fast instead of pinning the // request forever. Worst-case backoff total is ~19s, which stays under the default // tunnel ping timeout (30s) so the ping path can take over if the receiver is truly lost. - let mut backoff = rivet_util::backoff::Backoff::new(6, Some(8), 100, 5); + let mut backoff = rivet_util::throttle::Backoff::new(6, Some(8), 100, 5); let first_attempt_at = Instant::now(); let mut attempt = 0; loop { diff --git a/engine/packages/pegboard-gateway2/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-gateway2/src/ws_to_tunnel_task.rs index 1ebcff0e94..66ac4836d3 100644 --- a/engine/packages/pegboard-gateway2/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-gateway2/src/ws_to_tunnel_task.rs @@ -1,11 +1,13 @@ use anyhow::Result; use futures_util::TryStreamExt; +use gas::prelude::*; use rivet_envoy_protocol as protocol; use rivet_guard_core::websocket_handle::WebSocketReceiver; use std::sync::{ Arc, atomic::{AtomicU64, Ordering}, }; +use std::time::Duration; use tokio::sync::{Mutex, watch}; use tokio_tungstenite::tungstenite::Message; @@ -14,6 +16,7 @@ use crate::shared_state::{InFlightRequestHandle, display_id}; #[tracing::instrument(name = "ws_to_tunnel_task", skip_all)] pub async fn task( + ctx: StandaloneCtx, in_flight_req: InFlightRequestHandle, ws_rx: Arc>, ingress_bytes: Arc, @@ -21,7 +24,20 @@ pub async fn task( ) -> Result { let mut ws_rx = ws_rx.lock().await; + // Leaky bucket rate limit on consuming ws messages + let pegboard_config = ctx.config().pegboard(); + let mut rate_limit = rivet_util::throttle::RateLimiter::new( + rivet_util::throttle::RateLimitMethod::LeakyBucket { + requests: pegboard_config.gateway_websocket_rate_limit_requests(), + drip_rate: Duration::from_millis( + pegboard_config.gateway_websocket_rate_limit_drip_rate_ms(), + ), + }, + ); + loop { + rate_limit.acquire().await; + tokio::select! { res = ws_rx.try_next() => { if let Some(msg) = res? { diff --git a/engine/packages/pegboard/Cargo.toml b/engine/packages/pegboard/Cargo.toml index 3d878d6b1d..6e9701246f 100644 --- a/engine/packages/pegboard/Cargo.toml +++ b/engine/packages/pegboard/Cargo.toml @@ -17,6 +17,7 @@ foundationdb-tuple.workspace = true futures-util.workspace = true gas.workspace = true lazy_static.workspace = true +moka.workspace = true namespace.workspace = true nix.workspace = true rand.workspace = true diff --git a/engine/packages/pegboard/src/errors.rs b/engine/packages/pegboard/src/errors.rs index 13e21b55cb..45fb31fbd8 100644 --- a/engine/packages/pegboard/src/errors.rs +++ b/engine/packages/pegboard/src/errors.rs @@ -13,6 +13,12 @@ pub enum Actor { #[error("namespace_not_found", "The namespace does not exist.")] NamespaceNotFound, + #[error( + "creation_rate_limit", + "Too many actors created at once. Try again later." + )] + CreationRateLimit, + #[error( "input_too_large", "Actor input too large.", diff --git a/engine/packages/pegboard/src/ops/actor/create.rs b/engine/packages/pegboard/src/ops/actor/create.rs index c21e878e80..a0cb51bca0 100644 --- a/engine/packages/pegboard/src/ops/actor/create.rs +++ b/engine/packages/pegboard/src/ops/actor/create.rs @@ -1,7 +1,15 @@ use anyhow::{Context, Result}; use gas::prelude::*; +use moka::future::Cache; use rivet_api_util::{Method, request_remote_datacenter}; use rivet_types::actors::{Actor, CrashPolicy}; +use std::sync::{Arc, OnceLock}; +use std::time::Duration; +use tokio::sync::Mutex; + +const RATE_LIMITER_CACHE_TTL: Duration = Duration::from_secs(60 * 60); +static RATE_LIMITERS: OnceLock>>> = + OnceLock::new(); #[derive(Debug)] pub struct Input { @@ -29,6 +37,32 @@ pub struct Output { #[operation] pub async fn pegboard_actor_create(ctx: &OperationCtx, input: &Input) -> Result { + let rate_limiter = RATE_LIMITERS + .get_or_init(|| { + Cache::builder() + .max_capacity(10_000) + .time_to_idle(RATE_LIMITER_CACHE_TTL) + .build() + }) + .entry(input.namespace_id) + .or_insert_with(async { + let pegboard_config = ctx.config().pegboard(); + Arc::new(Mutex::new(rivet_util::throttle::RateLimiter::new( + rivet_util::throttle::RateLimitMethod::LeakyBucket { + requests: pegboard_config.actor_create_rate_limit_requests(), + drip_rate: Duration::from_millis( + pegboard_config.actor_create_rate_limit_drip_rate_ms(), + ), + }, + ))) + }) + .await; + + // Limit actor creation per namespace id + if !rate_limiter.value().lock().await.try_acquire() { + return Err(crate::errors::Actor::CreationRateLimit.build()); + } + // Set up subscriptions before dispatching workflow let ( mut create_sub, diff --git a/engine/packages/pegboard/src/workflows/actor/runtime.rs b/engine/packages/pegboard/src/workflows/actor/runtime.rs index 9965b59d5d..505f0afa81 100644 --- a/engine/packages/pegboard/src/workflows/actor/runtime.rs +++ b/engine/packages/pegboard/src/workflows/actor/runtime.rs @@ -1307,8 +1307,8 @@ fn reschedule_backoff( retry_count: usize, base_retry_timeout: usize, max_exponent: usize, -) -> util::backoff::Backoff { - util::backoff::Backoff::new_at(max_exponent, None, base_retry_timeout, 500, retry_count) +) -> util::throttle::Backoff { + util::throttle::Backoff::new_at(max_exponent, None, base_retry_timeout, 500, retry_count) } #[derive(Debug, Serialize, Deserialize)] diff --git a/engine/packages/pegboard/src/workflows/actor2/runtime.rs b/engine/packages/pegboard/src/workflows/actor2/runtime.rs index e8e1d063e1..3eb8dd9a04 100644 --- a/engine/packages/pegboard/src/workflows/actor2/runtime.rs +++ b/engine/packages/pegboard/src/workflows/actor2/runtime.rs @@ -830,7 +830,7 @@ async fn compare_retry( if reset { state.reschedule_ts = None; } else { - let backoff = util::backoff::Backoff::new_at( + let backoff = util::throttle::Backoff::new_at( ctx.config().pegboard().reschedule_backoff_max_exponent(), None, ctx.config().pegboard().base_retry_timeout(), diff --git a/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs b/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs index bcb600f784..47f5e6e262 100644 --- a/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs +++ b/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs @@ -164,7 +164,7 @@ async fn poll_metadata(ctx: &ActivityCtx, input: &PollMetadataInput) -> Result

util::backoff::Backoff { - util::backoff::Backoff::new_at(max_exponent, None, base_retry_timeout, 500, retry_count) +) -> util::throttle::Backoff { + util::throttle::Backoff::new_at(max_exponent, None, base_retry_timeout, 500, retry_count) } /// Report an error to the error tracker workflow. diff --git a/engine/packages/universalpubsub/src/driver/postgres/mod.rs b/engine/packages/universalpubsub/src/driver/postgres/mod.rs index 3ef8f02beb..52bc219b2a 100644 --- a/engine/packages/universalpubsub/src/driver/postgres/mod.rs +++ b/engine/packages/universalpubsub/src/driver/postgres/mod.rs @@ -5,7 +5,7 @@ use base64::engine::general_purpose::STANDARD_NO_PAD as BASE64; use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime}; use futures_util::future::poll_fn; use rivet_postgres_util::build_tls_config; -use rivet_util::backoff::Backoff; +use rivet_util::throttle::Backoff; use scc::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::path::PathBuf; diff --git a/engine/packages/universalpubsub/src/pubsub.rs b/engine/packages/universalpubsub/src/pubsub.rs index 1640307e91..d1bc921af0 100644 --- a/engine/packages/universalpubsub/src/pubsub.rs +++ b/engine/packages/universalpubsub/src/pubsub.rs @@ -8,7 +8,7 @@ use scc::HashMap; use tokio::sync::broadcast; use uuid::Uuid; -use rivet_util::backoff::Backoff; +use rivet_util::throttle::Backoff; use crate::chunking::{ChunkTracker, FastPath, encode_chunk, split_payload_into_chunks}; use crate::driver::{PubSubDriverHandle, PublishOpts, SubscriberDriverHandle}; diff --git a/engine/packages/util/src/backoff.rs b/engine/packages/util/src/backoff.rs deleted file mode 100644 index 183f25042e..0000000000 --- a/engine/packages/util/src/backoff.rs +++ /dev/null @@ -1,109 +0,0 @@ -use rand::Rng; -use tokio::time::{Duration, Instant}; - -pub struct Backoff { - /// Maximum exponent for the backoff. - max_exponent: usize, - - /// Maximum amount of retries. - max_retries: Option, - - /// Base wait time in ms. - wait: usize, - - /// Maximum randomness. - randomness: usize, - - /// Iteration of the backoff. - i: usize, - - /// Timestamp to sleep until in ms. - sleep_until: Instant, -} - -impl Backoff { - pub fn new( - max_exponent: usize, - max_retries: Option, - wait: usize, - randomness: usize, - ) -> Backoff { - Backoff { - max_exponent, - max_retries, - wait, - randomness, - i: 0, - sleep_until: Instant::now(), - } - } - - pub fn new_at( - max_exponent: usize, - max_retries: Option, - wait: usize, - randomness: usize, - i: usize, - ) -> Backoff { - Backoff { - max_exponent, - max_retries, - wait, - randomness, - i, - sleep_until: Instant::now(), - } - } - - pub fn tick_index(&self) -> usize { - self.i - } - - /// Waits for the next backoff tick. - /// - /// Returns false if the index is greater than `max_retries`. - pub async fn tick(&mut self) -> bool { - if self.max_retries.map_or(false, |x| self.i > x) { - return false; - } - - tokio::time::sleep_until(self.sleep_until).await; - - let next_wait = self.current_duration() + rand::thread_rng().gen_range(0..self.randomness); - self.sleep_until += Duration::from_millis(next_wait as u64); - - self.i += 1; - - true - } - - /// Returns the instant of the next backoff tick. Does not wait. - /// - /// Returns None if the index is greater than `max_retries`. - pub fn step(&mut self) -> Option { - if self.max_retries.map_or(false, |x| self.i > x) { - return None; - } - - let next_wait = self.current_duration() + rand::thread_rng().gen_range(0..self.randomness); - self.sleep_until += Duration::from_millis(next_wait as u64); - - self.i += 1; - - Some(self.sleep_until) - } - - pub fn current_duration(&self) -> usize { - self.wait * 2usize.pow(self.i.min(self.max_exponent) as u32) - } - - pub fn default_infinite() -> Backoff { - Backoff::new(8, None, 1_000, 1_000) - } -} - -impl Default for Backoff { - fn default() -> Backoff { - Backoff::new(5, Some(16), 1_000, 1_000) - } -} diff --git a/engine/packages/util/src/lib.rs b/engine/packages/util/src/lib.rs index 9b77b4b487..0c088b9df8 100644 --- a/engine/packages/util/src/lib.rs +++ b/engine/packages/util/src/lib.rs @@ -4,7 +4,6 @@ pub use id::Id; pub use rivet_util_id as id; pub mod async_counter; -pub mod backoff; pub mod billing; pub mod build_meta; pub mod check; @@ -19,6 +18,7 @@ pub mod req; pub mod serde; pub mod size; pub mod sort; +pub mod throttle; pub mod timestamp; pub mod url; diff --git a/engine/packages/util/src/throttle.rs b/engine/packages/util/src/throttle.rs new file mode 100644 index 0000000000..38295f99f6 --- /dev/null +++ b/engine/packages/util/src/throttle.rs @@ -0,0 +1,487 @@ +use rand::Rng; +use tokio::time::{Duration, Instant}; + +pub struct Backoff { + /// Maximum exponent for the backoff. + max_exponent: usize, + + /// Maximum amount of retries. + max_retries: Option, + + /// Base wait time in ms. + wait: usize, + + /// Maximum randomness. + randomness: usize, + + /// Iteration of the backoff. + i: usize, + + /// Timestamp to sleep until in ms. + sleep_until: Instant, +} + +impl Backoff { + pub fn new( + max_exponent: usize, + max_retries: Option, + wait: usize, + randomness: usize, + ) -> Backoff { + Backoff { + max_exponent, + max_retries, + wait, + randomness, + i: 0, + sleep_until: Instant::now(), + } + } + + pub fn new_at( + max_exponent: usize, + max_retries: Option, + wait: usize, + randomness: usize, + i: usize, + ) -> Backoff { + Backoff { + max_exponent, + max_retries, + wait, + randomness, + i, + sleep_until: Instant::now(), + } + } + + pub fn tick_index(&self) -> usize { + self.i + } + + /// Waits for the next backoff tick. + /// + /// Returns false if the index is greater than `max_retries`. + pub async fn tick(&mut self) -> bool { + if self.max_retries.map_or(false, |x| self.i > x) { + return false; + } + + tokio::time::sleep_until(self.sleep_until).await; + + let next_wait = self.current_duration() + rand::thread_rng().gen_range(0..self.randomness); + self.sleep_until += Duration::from_millis(next_wait as u64); + + self.i += 1; + + true + } + + /// Returns the instant of the next backoff tick. Does not wait. + /// + /// Returns None if the index is greater than `max_retries`. + pub fn step(&mut self) -> Option { + if self.max_retries.map_or(false, |x| self.i > x) { + return None; + } + + let next_wait = self.current_duration() + rand::thread_rng().gen_range(0..self.randomness); + self.sleep_until += Duration::from_millis(next_wait as u64); + + self.i += 1; + + Some(self.sleep_until) + } + + pub fn current_duration(&self) -> usize { + self.wait * 2usize.pow(self.i.min(self.max_exponent) as u32) + } + + pub fn default_infinite() -> Backoff { + Backoff::new(8, None, 1_000, 1_000) + } +} + +impl Default for Backoff { + fn default() -> Backoff { + Backoff::new(5, Some(16), 1_000, 1_000) + } +} + +pub enum RateLimitMethod { + FixedWindow { + requests: u64, + period: Duration, + }, + LeakyBucket { + requests: u64, + /// How quickly to regain requests. 1 / drip_rate + drip_rate: Duration, + }, +} + +enum RateLimitState { + FixedWindow { + requests_remaining: u64, + requests_limit: u64, + reset_time: Instant, + period: Duration, + }, + LeakyBucket { + requests_remaining: u64, + requests_limit: u64, + last_acquire: Instant, + drip_rate: Duration, + accum_drip: f32, + }, +} + +pub struct RateLimiter { + state: RateLimitState, +} + +impl RateLimiter { + pub fn new(method: RateLimitMethod) -> Self { + Self { + state: match method { + RateLimitMethod::FixedWindow { requests, period } => RateLimitState::FixedWindow { + requests_remaining: requests, + requests_limit: requests, + reset_time: Instant::now() + period, + period, + }, + RateLimitMethod::LeakyBucket { + requests, + drip_rate, + } => RateLimitState::LeakyBucket { + requests_remaining: requests, + requests_limit: requests, + last_acquire: Instant::now(), + drip_rate: drip_rate, + accum_drip: 0.0, + }, + }, + } + } + + pub fn try_acquire(&mut self) -> bool { + match &mut self.state { + RateLimitState::FixedWindow { + requests_remaining, + requests_limit, + reset_time, + period, + } => { + let now = Instant::now(); + // Check if we need to reset the counter + if now >= *reset_time { + *requests_remaining = *requests_limit; + *reset_time = now + *period; + } + + // Try to consume a request + if *requests_remaining > 0 { + *requests_remaining -= 1; + true + } else { + false + } + } + RateLimitState::LeakyBucket { + requests_remaining, + requests_limit, + last_acquire, + drip_rate, + accum_drip, + } => { + let now = Instant::now(); + let dt = now - *last_acquire; + *last_acquire = now; + + // Drip bucket + if requests_remaining < requests_limit { + *accum_drip += dt.div_duration_f32(*drip_rate); + + *requests_remaining += + (*accum_drip as u64).min(*requests_limit - *requests_remaining); + + if *accum_drip >= 1.0 { + *accum_drip = accum_drip.fract(); + } + } + + if *requests_remaining > 0 { + *requests_remaining -= 1; + true + } else { + false + } + } + } + } + + pub async fn acquire(&mut self) { + match &mut self.state { + RateLimitState::FixedWindow { + requests_remaining, + requests_limit, + reset_time, + period, + } => { + let now = Instant::now(); + // Check if we need to reset the counter + if now >= *reset_time { + *requests_remaining = *requests_limit; + *reset_time = now + *period; + } + + // Try to consume a request + if *requests_remaining > 0 { + *requests_remaining -= 1; + } else { + tokio::time::sleep(*period).await; + + *requests_remaining = *requests_limit; + *reset_time = Instant::now() + *period; + } + } + RateLimitState::LeakyBucket { + requests_remaining, + requests_limit, + last_acquire, + drip_rate, + accum_drip, + } => { + let now = Instant::now(); + let dt = now - *last_acquire; + *last_acquire = now; + + // Drip bucket + if requests_remaining < requests_limit { + *accum_drip += dt.div_duration_f32(*drip_rate); + + *requests_remaining += + (*accum_drip as u64).min(*requests_limit - *requests_remaining); + + if *accum_drip >= 1.0 { + *accum_drip = accum_drip.fract(); + } + } + + if *requests_remaining > 0 { + *requests_remaining -= 1; + } else { + let deficit = 1.0 - *accum_drip; + tokio::time::sleep(drip_rate.mul_f32(deficit)).await; + + *last_acquire = Instant::now(); + *accum_drip = 0.0; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::{RateLimitMethod, RateLimiter}; + use tokio::time::{Duration, Instant}; + + // MARK: FixedWindow / try_acquire + + #[tokio::test(start_paused = true)] + async fn fixed_window_allows_full_burst_then_blocks() { + let mut rl = RateLimiter::new(RateLimitMethod::FixedWindow { + requests: 3, + period: Duration::from_millis(100), + }); + + assert!(rl.try_acquire()); + assert!(rl.try_acquire()); + assert!(rl.try_acquire()); + // Limit reached within the window. + assert!(!rl.try_acquire()); + } + + #[tokio::test(start_paused = true)] + async fn fixed_window_does_not_refill_before_period() { + let mut rl = RateLimiter::new(RateLimitMethod::FixedWindow { + requests: 2, + period: Duration::from_millis(100), + }); + + assert!(rl.try_acquire()); + assert!(rl.try_acquire()); + assert!(!rl.try_acquire()); + + // Just shy of a full period: still no refill. The window is + // all-or-nothing, it does not drip partial credit. + tokio::time::advance(Duration::from_millis(99)).await; + assert!(!rl.try_acquire()); + } + + #[tokio::test(start_paused = true)] + async fn fixed_window_resets_to_full_after_period() { + let mut rl = RateLimiter::new(RateLimitMethod::FixedWindow { + requests: 2, + period: Duration::from_millis(100), + }); + + assert!(rl.try_acquire()); + assert!(rl.try_acquire()); + assert!(!rl.try_acquire()); + + // After a full period the window resets to its full allowance. + tokio::time::advance(Duration::from_millis(100)).await; + assert!(rl.try_acquire()); + assert!(rl.try_acquire()); + assert!(!rl.try_acquire()); + } + + // MARK: LeakyBucket / try_acquire + + #[tokio::test(start_paused = true)] + async fn leaky_bucket_allows_full_burst_then_blocks() { + let mut rl = RateLimiter::new(RateLimitMethod::LeakyBucket { + requests: 3, + drip_rate: Duration::from_millis(10), + }); + + assert!(rl.try_acquire()); + assert!(rl.try_acquire()); + assert!(rl.try_acquire()); + assert!(!rl.try_acquire()); + } + + #[tokio::test(start_paused = true)] + async fn leaky_bucket_drips_exactly_one_token_per_rate() { + let mut rl = RateLimiter::new(RateLimitMethod::LeakyBucket { + requests: 3, + drip_rate: Duration::from_millis(10), + }); + + // Drain the bucket. + for _ in 0..3 { + assert!(rl.try_acquire()); + } + assert!(!rl.try_acquire()); + + // Exactly one drip period yields exactly one token, no more. + tokio::time::advance(Duration::from_millis(10)).await; + assert!(rl.try_acquire()); + assert!(!rl.try_acquire()); + } + + #[tokio::test(start_paused = true)] + async fn leaky_bucket_refill_is_capped_at_capacity() { + let mut rl = RateLimiter::new(RateLimitMethod::LeakyBucket { + requests: 3, + drip_rate: Duration::from_millis(10), + }); + + for _ in 0..3 { + assert!(rl.try_acquire()); + } + assert!(!rl.try_acquire()); + + // Idle far longer than it takes to refill the whole bucket. Credit must + // not accumulate past capacity, so only `requests` tokens are available. + tokio::time::advance(Duration::from_millis(1_000)).await; + assert!(rl.try_acquire()); + assert!(rl.try_acquire()); + assert!(rl.try_acquire()); + assert!(!rl.try_acquire()); + } + + #[tokio::test(start_paused = true)] + async fn leaky_bucket_accumulates_fractional_drip_across_calls() { + let mut rl = RateLimiter::new(RateLimitMethod::LeakyBucket { + requests: 1, + drip_rate: Duration::from_millis(10), + }); + + // Consume the only token. + assert!(rl.try_acquire()); + assert!(!rl.try_acquire()); + + // Half a drip period: less than one whole token, still blocked. + tokio::time::advance(Duration::from_millis(5)).await; + assert!(!rl.try_acquire()); + + // Another half period: the fractional credit from the previous interval + // must carry over and complete one whole token. + tokio::time::advance(Duration::from_millis(5)).await; + assert!(rl.try_acquire()); + assert!(!rl.try_acquire()); + } + + // MARK: acquire (blocking) + + #[tokio::test(start_paused = true)] + async fn acquire_returns_immediately_while_tokens_remain() { + let mut rl = RateLimiter::new(RateLimitMethod::LeakyBucket { + requests: 3, + drip_rate: Duration::from_millis(10), + }); + + let start = Instant::now(); + rl.acquire().await; + rl.acquire().await; + rl.acquire().await; + // Burst is served without waiting. + assert_eq!(start.elapsed(), Duration::ZERO); + } + + #[tokio::test(start_paused = true)] + async fn acquire_blocks_until_a_token_is_available() { + let mut rl = RateLimiter::new(RateLimitMethod::LeakyBucket { + requests: 1, + drip_rate: Duration::from_millis(10), + }); + + // Drain the single token. + rl.acquire().await; + + // The next acquire must wait one full drip period for a token. + let start = Instant::now(); + rl.acquire().await; + assert!(start.elapsed() >= Duration::from_millis(10)); + } + + #[tokio::test(start_paused = true)] + async fn acquire_sustains_the_drip_rate_without_doubling() { + let mut rl = RateLimiter::new(RateLimitMethod::LeakyBucket { + requests: 1, + drip_rate: Duration::from_millis(10), + }); + + // Drain the initial burst token so every subsequent acquire starts empty. + rl.acquire().await; + + let start = Instant::now(); + // Five acquires, each starting from an empty bucket, must each cost one + // drip period, so the total is at least 5 * drip_rate. A limiter that + // admits the post-sleep request without debiting a token finishes in + // ~3 periods, effectively doubling the sustained rate. + for _ in 0..5 { + rl.acquire().await; + } + assert!(start.elapsed() >= Duration::from_millis(50)); + } + + #[tokio::test(start_paused = true)] + async fn fixed_window_acquire_blocks_until_window_resets() { + let mut rl = RateLimiter::new(RateLimitMethod::FixedWindow { + requests: 2, + period: Duration::from_millis(100), + }); + + rl.acquire().await; + rl.acquire().await; + + // The window is exhausted, so the next acquire must wait for the reset. + let start = Instant::now(); + rl.acquire().await; + assert!(start.elapsed() >= Duration::from_millis(100)); + } +}