diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go index cefc467..5298fff 100644 --- a/internal/ratelimit/ratelimit.go +++ b/internal/ratelimit/ratelimit.go @@ -4,25 +4,44 @@ package ratelimit import ( "net/http" "sync" + "time" "github.com/Ribbit-Network/api/internal/auth" "golang.org/x/time/rate" ) -// Limiter holds per-key token-bucket limiters. +type entry struct { + limiter *rate.Limiter + lastSeen time.Time +} + +// Limiter holds per-key token-bucket limiters. Idle entries are evicted +// lazily — a sweep runs inline with get() at most once per ttl, so there is +// no background goroutine to manage. type Limiter struct { - mu sync.Mutex - entries map[string]*rate.Limiter - r rate.Limit - b int + mu sync.Mutex + entries map[string]*entry + r rate.Limit + b int + ttl time.Duration + lastSweep time.Time + now func() time.Time } -// New creates a Limiter allowing r tokens per second with a burst of b. -func New(r rate.Limit, b int) *Limiter { +// New creates a Limiter allowing r tokens per second with a burst of b. Entries +// untouched for ttl are evicted; ttl must be positive. Choose ttl >= b/r so an +// evicted bucket would already have refilled — otherwise eviction effectively +// grants a fresh burst. +func New(r rate.Limit, b int, ttl time.Duration) *Limiter { + if ttl <= 0 { + panic("ratelimit: ttl must be > 0") + } return &Limiter{ - entries: make(map[string]*rate.Limiter), + entries: make(map[string]*entry), r: r, b: b, + ttl: ttl, + now: time.Now, } } @@ -43,10 +62,22 @@ func (l *Limiter) Middleware(next http.Handler) http.Handler { func (l *Limiter) get(key string) *rate.Limiter { l.mu.Lock() defer l.mu.Unlock() - lim, ok := l.entries[key] + + now := l.now() + if now.Sub(l.lastSweep) >= l.ttl { + for k, e := range l.entries { + if now.Sub(e.lastSeen) >= l.ttl { + delete(l.entries, k) + } + } + l.lastSweep = now + } + + e, ok := l.entries[key] if !ok { - lim = rate.NewLimiter(l.r, l.b) - l.entries[key] = lim + e = &entry{limiter: rate.NewLimiter(l.r, l.b)} + l.entries[key] = e } - return lim + e.lastSeen = now + return e.limiter } diff --git a/internal/ratelimit/ratelimit_test.go b/internal/ratelimit/ratelimit_test.go index f73e6ac..936b2a9 100644 --- a/internal/ratelimit/ratelimit_test.go +++ b/internal/ratelimit/ratelimit_test.go @@ -4,6 +4,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/Ribbit-Network/api/internal/auth" "github.com/stretchr/testify/require" @@ -23,6 +24,12 @@ func okHandler() http.Handler { }) } +// newLim builds a Limiter with a ttl far longer than any test runs, so the +// rate-limit tests below aren't affected by eviction. +func newLim(r rate.Limit, b int) *Limiter { + return New(r, b, time.Hour) +} + func limited(l *Limiter) http.Handler { return auth.Require(allowAll{})(l.Middleware(okHandler())) } @@ -40,7 +47,7 @@ func reqWithBearer(key string) *http.Request { } func TestRateLimit_AllowsWithinBurst(t *testing.T) { - h := limited(New(rate.Limit(1), 5)) + h := limited(newLim(rate.Limit(1), 5)) for i := 0; i < 5; i++ { rec := httptest.NewRecorder() @@ -50,7 +57,7 @@ func TestRateLimit_AllowsWithinBurst(t *testing.T) { } func TestRateLimit_BlocksAfterBurst(t *testing.T) { - h := limited(New(rate.Limit(1), 3)) + h := limited(newLim(rate.Limit(1), 3)) for i := 0; i < 3; i++ { rec := httptest.NewRecorder() @@ -67,7 +74,7 @@ func TestRateLimit_BlocksAfterBurst(t *testing.T) { // "Authorization: Bearer ..." instead of X-API-Key — i.e. the auth+ratelimit // chain works end-to-end regardless of which header the client uses. func TestRateLimit_BlocksAfterBurst_Bearer(t *testing.T) { - h := limited(New(rate.Limit(1), 2)) + h := limited(newLim(rate.Limit(1), 2)) for i := 0; i < 2; i++ { rec := httptest.NewRecorder() @@ -81,7 +88,7 @@ func TestRateLimit_BlocksAfterBurst_Bearer(t *testing.T) { } func TestRateLimit_IndependentPerKey(t *testing.T) { - h := limited(New(rate.Limit(1), 1)) + h := limited(newLim(rate.Limit(1), 1)) for _, key := range []string{"key-a", "key-b", "key-c"} { rec := httptest.NewRecorder() @@ -93,9 +100,46 @@ func TestRateLimit_IndependentPerKey(t *testing.T) { // If something mounts Limiter.Middleware without auth in front, a request with // no context key should pass through rather than gate on an empty string. func TestRateLimit_NoKeyInContext_PassesThrough(t *testing.T) { - h := New(rate.Limit(1), 1).Middleware(okHandler()) + h := newLim(rate.Limit(1), 1).Middleware(okHandler()) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/data", nil)) require.Equal(t, http.StatusOK, rec.Code) } + +// Idle entries should be evicted once ttl elapses; active ones should remain. +// Clock is injected so the test doesn't sleep. +func TestRateLimit_EvictsIdleEntries(t *testing.T) { + clock := time.Unix(1_000_000, 0) + l := New(rate.Limit(1), 1, time.Minute) + l.now = func() time.Time { return clock } + + l.get("idle-a") + l.get("idle-b") + require.Len(t, l.entries, 2) + + // Jump past the ttl and touch a new key — triggers the lazy sweep. + clock = clock.Add(2 * time.Minute) + l.get("fresh") + + require.Len(t, l.entries, 1) + _, ok := l.entries["fresh"] + require.True(t, ok, "freshly-touched key should remain") +} + +func TestRateLimit_KeepsActiveEntries(t *testing.T) { + clock := time.Unix(1_000_000, 0) + l := New(rate.Limit(1), 1, time.Minute) + l.now = func() time.Time { return clock } + + // Touch the same key every 30s for several minutes — well past one ttl + // of cumulative time, but never idle for a full ttl. + for i := 0; i < 10; i++ { + l.get("active") + clock = clock.Add(30 * time.Second) + } + + require.Len(t, l.entries, 1) + _, ok := l.entries["active"] + require.True(t, ok) +} diff --git a/main.go b/main.go index 4d2f03f..f1a9a3b 100644 --- a/main.go +++ b/main.go @@ -39,8 +39,8 @@ func runServer() { } requireKey := auth.Require(store) - // 60 requests/minute per key with a burst of 30. - limiter := ratelimit.New(rate.Every(time.Second), 60) + // 1 request/sec per key with a burst of 60; lazily evict keys after about 10-20 minutes of idleness. + limiter := ratelimit.New(rate.Every(time.Second), 60, 10*time.Minute) mux := http.NewServeMux() mux.HandleFunc("/", handleRoot)