Add rate limiting
This commit is contained in:
135
internal/ratelimit/ratelimit.go
Normal file
135
internal/ratelimit/ratelimit.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Config defines rate limiting configuration.
|
||||
type Config struct {
|
||||
// RequestsPerSecond is the number of requests allowed per second per IP.
|
||||
RequestsPerSecond float64
|
||||
// Burst is the maximum burst size allowed.
|
||||
Burst int
|
||||
// Enabled controls whether rate limiting is active.
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// Middleware provides per-IP rate limiting using token bucket algorithm.
|
||||
type Middleware struct {
|
||||
limiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
config Config
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a new rate limiting middleware.
|
||||
func New(config Config, logger *slog.Logger) *Middleware {
|
||||
m := &Middleware{
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine to remove old limiters
|
||||
if config.Enabled {
|
||||
go m.cleanupLimiters()
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Handler wraps an http.Handler with rate limiting.
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !m.config.Enabled {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract client IP (handle X-Forwarded-For for proxies)
|
||||
ip := m.getClientIP(r)
|
||||
|
||||
limiter := m.getLimiter(ip)
|
||||
if !limiter.Allow() {
|
||||
m.logger.Warn("rate limit exceeded",
|
||||
slog.String("ip", ip),
|
||||
slog.String("path", r.URL.Path),
|
||||
)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Retry-After", "1")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte(`{"error":"rate limit exceeded","message":"too many requests"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// getLimiter returns the rate limiter for a given IP, creating one if needed.
|
||||
func (m *Middleware) getLimiter(ip string) *rate.Limiter {
|
||||
m.mu.RLock()
|
||||
limiter, exists := m.limiters[ip]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
limiter, exists = m.limiters[ip]
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
limiter = rate.NewLimiter(rate.Limit(m.config.RequestsPerSecond), m.config.Burst)
|
||||
m.limiters[ip] = limiter
|
||||
return limiter
|
||||
}
|
||||
|
||||
// getClientIP extracts the client IP from the request.
|
||||
func (m *Middleware) getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (for proxies/load balancers)
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff != "" {
|
||||
// X-Forwarded-For can be a comma-separated list, use the first IP
|
||||
for idx := 0; idx < len(xff); idx++ {
|
||||
if xff[idx] == ',' {
|
||||
return xff[:idx]
|
||||
}
|
||||
}
|
||||
return xff
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// cleanupLimiters periodically removes unused limiters to prevent memory leaks.
|
||||
func (m *Middleware) cleanupLimiters() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
m.mu.Lock()
|
||||
// Clear all limiters periodically
|
||||
// In production, you might want a more sophisticated LRU cache
|
||||
m.limiters = make(map[string]*rate.Limiter)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.logger.Debug("cleaned up rate limiters")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user