Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions internal/ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,44 @@ package ratelimit
import (
"net/http"
"sync"
"time"

"github.com/Ribbit-Network/api/internal/auth"
"golang.org/x/time/rate"
)

// Limiter holds per-key token-bucket limiters.
type entry struct {
limiter *rate.Limiter
lastSeen time.Time
}

// Limiter holds per-key token-bucket limiters. Idle entries are evicted
// lazily — a sweep runs inline with get() at most once per ttl, so there is
// no background goroutine to manage.
type Limiter struct {
mu sync.Mutex
entries map[string]*rate.Limiter
r rate.Limit
b int
mu sync.Mutex
entries map[string]*entry
r rate.Limit
b int
ttl time.Duration
lastSweep time.Time
now func() time.Time
}

// New creates a Limiter allowing r tokens per second with a burst of b.
func New(r rate.Limit, b int) *Limiter {
// New creates a Limiter allowing r tokens per second with a burst of b. Entries
// untouched for ttl are evicted; ttl must be positive. Choose ttl >= b/r so an
// evicted bucket would already have refilled — otherwise eviction effectively
// grants a fresh burst.
func New(r rate.Limit, b int, ttl time.Duration) *Limiter {
if ttl <= 0 {
panic("ratelimit: ttl must be > 0")
}
return &Limiter{
entries: make(map[string]*rate.Limiter),
entries: make(map[string]*entry),
r: r,
b: b,
ttl: ttl,
now: time.Now,
}
}

Expand All @@ -43,10 +62,22 @@ func (l *Limiter) Middleware(next http.Handler) http.Handler {
func (l *Limiter) get(key string) *rate.Limiter {
l.mu.Lock()
defer l.mu.Unlock()
lim, ok := l.entries[key]

now := l.now()
if now.Sub(l.lastSweep) >= l.ttl {
for k, e := range l.entries {
if now.Sub(e.lastSeen) >= l.ttl {
delete(l.entries, k)
Comment thread
keenanjohnson marked this conversation as resolved.
}
}
l.lastSweep = now
}

e, ok := l.entries[key]
if !ok {
lim = rate.NewLimiter(l.r, l.b)
l.entries[key] = lim
e = &entry{limiter: rate.NewLimiter(l.r, l.b)}
l.entries[key] = e
}
return lim
e.lastSeen = now
return e.limiter
}
54 changes: 49 additions & 5 deletions internal/ratelimit/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/Ribbit-Network/api/internal/auth"
"github.com/stretchr/testify/require"
Expand All @@ -23,6 +24,12 @@ func okHandler() http.Handler {
})
}

// newLim builds a Limiter with a ttl far longer than any test runs, so the
// rate-limit tests below aren't affected by eviction.
func newLim(r rate.Limit, b int) *Limiter {
return New(r, b, time.Hour)
}

func limited(l *Limiter) http.Handler {
return auth.Require(allowAll{})(l.Middleware(okHandler()))
}
Expand All @@ -40,7 +47,7 @@ func reqWithBearer(key string) *http.Request {
}

func TestRateLimit_AllowsWithinBurst(t *testing.T) {
h := limited(New(rate.Limit(1), 5))
h := limited(newLim(rate.Limit(1), 5))

for i := 0; i < 5; i++ {
rec := httptest.NewRecorder()
Expand All @@ -50,7 +57,7 @@ func TestRateLimit_AllowsWithinBurst(t *testing.T) {
}

func TestRateLimit_BlocksAfterBurst(t *testing.T) {
h := limited(New(rate.Limit(1), 3))
h := limited(newLim(rate.Limit(1), 3))

for i := 0; i < 3; i++ {
rec := httptest.NewRecorder()
Expand All @@ -67,7 +74,7 @@ func TestRateLimit_BlocksAfterBurst(t *testing.T) {
// "Authorization: Bearer ..." instead of X-API-Key — i.e. the auth+ratelimit
// chain works end-to-end regardless of which header the client uses.
func TestRateLimit_BlocksAfterBurst_Bearer(t *testing.T) {
h := limited(New(rate.Limit(1), 2))
h := limited(newLim(rate.Limit(1), 2))

for i := 0; i < 2; i++ {
rec := httptest.NewRecorder()
Expand All @@ -81,7 +88,7 @@ func TestRateLimit_BlocksAfterBurst_Bearer(t *testing.T) {
}

func TestRateLimit_IndependentPerKey(t *testing.T) {
h := limited(New(rate.Limit(1), 1))
h := limited(newLim(rate.Limit(1), 1))

for _, key := range []string{"key-a", "key-b", "key-c"} {
rec := httptest.NewRecorder()
Expand All @@ -93,9 +100,46 @@ func TestRateLimit_IndependentPerKey(t *testing.T) {
// If something mounts Limiter.Middleware without auth in front, a request with
// no context key should pass through rather than gate on an empty string.
func TestRateLimit_NoKeyInContext_PassesThrough(t *testing.T) {
h := New(rate.Limit(1), 1).Middleware(okHandler())
h := newLim(rate.Limit(1), 1).Middleware(okHandler())

rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/data", nil))
require.Equal(t, http.StatusOK, rec.Code)
}

// Idle entries should be evicted once ttl elapses; active ones should remain.
// Clock is injected so the test doesn't sleep.
func TestRateLimit_EvictsIdleEntries(t *testing.T) {
clock := time.Unix(1_000_000, 0)
l := New(rate.Limit(1), 1, time.Minute)
l.now = func() time.Time { return clock }

l.get("idle-a")
l.get("idle-b")
require.Len(t, l.entries, 2)

// Jump past the ttl and touch a new key — triggers the lazy sweep.
clock = clock.Add(2 * time.Minute)
l.get("fresh")

require.Len(t, l.entries, 1)
_, ok := l.entries["fresh"]
require.True(t, ok, "freshly-touched key should remain")
}

func TestRateLimit_KeepsActiveEntries(t *testing.T) {
clock := time.Unix(1_000_000, 0)
l := New(rate.Limit(1), 1, time.Minute)
l.now = func() time.Time { return clock }

// Touch the same key every 30s for several minutes — well past one ttl
// of cumulative time, but never idle for a full ttl.
for i := 0; i < 10; i++ {
l.get("active")
clock = clock.Add(30 * time.Second)
}

require.Len(t, l.entries, 1)
_, ok := l.entries["active"]
require.True(t, ok)
}
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ func runServer() {
}

requireKey := auth.Require(store)
// 60 requests/minute per key with a burst of 30.
limiter := ratelimit.New(rate.Every(time.Second), 60)
// 1 request/sec per key with a burst of 60; lazily evict keys after about 10-20 minutes of idleness.
limiter := ratelimit.New(rate.Every(time.Second), 60, 10*time.Minute)

mux := http.NewServeMux()
mux.HandleFunc("/", handleRoot)
Expand Down
Loading