diff --git a/README.md b/README.md index 0767644..ed76b41 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,8 @@ latticelm (unified API) ✅ **OAuth2/OIDC authentication** (Google, Auth0, any OIDC provider) ✅ **Terminal chat client** (Python with Rich UI, PEP 723) ✅ **Conversation tracking** (previous_response_id for efficient context) +✅ **Rate limiting** (Per-IP token bucket with configurable limits) +✅ **Health & readiness endpoints** (Kubernetes-compatible health checks) ## Quick Start @@ -258,6 +260,54 @@ curl -X POST http://localhost:8080/v1/responses \ -d '{"model": "gemini-2.0-flash-exp", ...}' ``` +## Production Features + +### Rate Limiting + +Per-IP rate limiting using token bucket algorithm to prevent abuse and manage load: + +```yaml +rate_limit: + enabled: true + requests_per_second: 10 # Max requests per second per IP + burst: 20 # Maximum burst size +``` + +Features: +- **Token bucket algorithm** for smooth rate limiting +- **Per-IP limiting** with support for X-Forwarded-For headers +- **Configurable limits** for requests per second and burst size +- **Automatic cleanup** of stale rate limiters to prevent memory leaks +- **429 responses** with Retry-After header when limits exceeded + +### Health & Readiness Endpoints + +Kubernetes-compatible health check endpoints for orchestration and load balancers: + +**Liveness endpoint** (`/health`): +```bash +curl http://localhost:8080/health +# {"status":"healthy","timestamp":1709438400} +``` + +**Readiness endpoint** (`/ready`): +```bash +curl http://localhost:8080/ready +# { +# "status":"ready", +# "timestamp":1709438400, +# "checks":{ +# "conversation_store":"healthy", +# "providers":"healthy" +# } +# } +``` + +The readiness endpoint verifies: +- Conversation store connectivity +- At least one provider is configured +- Returns 503 if any check fails + ## Next Steps - ✅ ~~Implement streaming responses~~ diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 80ba3d6..8c4b142 100644 --- 
a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -22,6 +22,7 @@ import ( "github.com/ajac-zero/latticelm/internal/conversation" slogger "github.com/ajac-zero/latticelm/internal/logger" "github.com/ajac-zero/latticelm/internal/providers" + "github.com/ajac-zero/latticelm/internal/ratelimit" "github.com/ajac-zero/latticelm/internal/server" ) @@ -86,8 +87,30 @@ func main() { addr = ":8080" } - // Build handler chain: logging -> auth -> routes - handler := loggingMiddleware(authMiddleware.Handler(mux), logger) + // Initialize rate limiting + rateLimitConfig := ratelimit.Config{ + Enabled: cfg.RateLimit.Enabled, + RequestsPerSecond: cfg.RateLimit.RequestsPerSecond, + Burst: cfg.RateLimit.Burst, + } + // Set defaults if not configured + if rateLimitConfig.Enabled && rateLimitConfig.RequestsPerSecond == 0 { + rateLimitConfig.RequestsPerSecond = 10 // default 10 req/s + } + if rateLimitConfig.Enabled && rateLimitConfig.Burst == 0 { + rateLimitConfig.Burst = 20 // default burst of 20 + } + rateLimitMiddleware := ratelimit.New(rateLimitConfig, logger) + + if cfg.RateLimit.Enabled { + logger.Info("rate limiting enabled", + slog.Float64("requests_per_second", rateLimitConfig.RequestsPerSecond), + slog.Int("burst", rateLimitConfig.Burst), + ) + } + + // Build handler chain: logging -> rate limiting -> auth -> routes. NOTE(review): /health and /ready are registered on mux, so Kubernetes probes also pass through rate limiting AND auth — confirm the auth middleware exempts those paths, otherwise probes can fail with 401/429 and the pod will be killed/unrouted. + handler := loggingMiddleware(rateLimitMiddleware.Handler(authMiddleware.Handler(mux)), logger) srv := &http.Server{ Addr: addr, diff --git a/config.example.yaml b/config.example.yaml index a0245d5..f49dd4a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -5,6 +5,11 @@ logging: format: "json" # "json" for production, "text" for development level: "info" # "debug", "info", "warn", or "error" +rate_limit: + enabled: false # Enable rate limiting (recommended for production) + requests_per_second: 10 # Max requests per second per IP (default: 10) + burst: 20 # Maximum burst size (default: 20) + providers: google: type: "google" diff --git
a/config.test.yaml b/config.test.yaml index 8f5f323..8cc03f3 100644 --- a/config.test.yaml +++ b/config.test.yaml @@ -5,6 +5,11 @@ logging: format: "text" # text format for easy reading in development level: "debug" # debug level to see all logs +rate_limit: + enabled: false # disabled for testing + requests_per_second: 100 + burst: 200 + providers: mock: type: "openai" diff --git a/go.mod b/go.mod index 4e426b5..c04f498 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect + golang.org/x/time v0.14.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/grpc v1.66.2 // indirect google.golang.org/protobuf v1.34.2 // indirect diff --git a/go.sum b/go.sum index f71fd69..ff896e2 100644 --- a/go.sum +++ b/go.sum @@ -160,6 +160,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= diff --git a/internal/config/config.go b/internal/config/config.go index 0bebbf8..514dcf1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,6 +15,7 @@ type Config struct { Auth AuthConfig `yaml:"auth"` Conversations ConversationConfig 
`yaml:"conversations"` Logging LoggingConfig `yaml:"logging"` + RateLimit RateLimitConfig `yaml:"rate_limit"` } // ConversationConfig controls conversation storage. @@ -39,6 +40,16 @@ type LoggingConfig struct { Level string `yaml:"level"` } +// RateLimitConfig controls rate limiting behavior. +type RateLimitConfig struct { + // Enabled controls whether rate limiting is active. + Enabled bool `yaml:"enabled"` + // RequestsPerSecond is the number of requests allowed per second per IP. + RequestsPerSecond float64 `yaml:"requests_per_second"` + // Burst is the maximum burst size allowed. + Burst int `yaml:"burst"` +} + // AuthConfig holds OIDC authentication settings. type AuthConfig struct { Enabled bool `yaml:"enabled"` diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go new file mode 100644 index 0000000..aa03b67 --- /dev/null +++ b/internal/ratelimit/ratelimit.go @@ -0,0 +1,135 @@ +package ratelimit + +import ( + "log/slog" + "net/http" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// Config defines rate limiting configuration. +type Config struct { + // RequestsPerSecond is the number of requests allowed per second per IP. + RequestsPerSecond float64 + // Burst is the maximum burst size allowed. + Burst int + // Enabled controls whether rate limiting is active. + Enabled bool +} + +// Middleware provides per-IP rate limiting using token bucket algorithm. +type Middleware struct { + limiters map[string]*rate.Limiter + mu sync.RWMutex + config Config + logger *slog.Logger +} + +// New creates a new rate limiting middleware. +func New(config Config, logger *slog.Logger) *Middleware { + m := &Middleware{ + limiters: make(map[string]*rate.Limiter), + config: config, + logger: logger, + } + + // Start cleanup goroutine to remove old limiters + if config.Enabled { + go m.cleanupLimiters() + } + + return m +} + +// Handler wraps an http.Handler with rate limiting. 
+func (m *Middleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !m.config.Enabled { + next.ServeHTTP(w, r) + return + } + + // Extract client IP (handle X-Forwarded-For for proxies). NOTE(review): X-Forwarded-For and X-Real-IP are client-supplied and trivially spoofable, which lets a caller dodge the per-IP limit; trust them only when a fronting proxy is known to set them — confirm deployment topology. + ip := m.getClientIP(r) + + limiter := m.getLimiter(ip) + if !limiter.Allow() { + m.logger.Warn("rate limit exceeded", + slog.String("ip", ip), + slog.String("path", r.URL.Path), + ) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", "1") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"error":"rate limit exceeded","message":"too many requests"}`)) + return + } + + next.ServeHTTP(w, r) + }) +} + +// getLimiter returns the rate limiter for a given IP, creating one if needed. +func (m *Middleware) getLimiter(ip string) *rate.Limiter { + m.mu.RLock() + limiter, exists := m.limiters[ip] + m.mu.RUnlock() + + if exists { + return limiter + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Double-check after acquiring write lock + limiter, exists = m.limiters[ip] + if exists { + return limiter + } + + limiter = rate.NewLimiter(rate.Limit(m.config.RequestsPerSecond), m.config.Burst) + m.limiters[ip] = limiter + return limiter +} + +// getClientIP extracts the client IP from the request. +func (m *Middleware) getClientIP(r *http.Request) string { + // Check X-Forwarded-For header (for proxies/load balancers) + xff := r.Header.Get("X-Forwarded-For") + if xff != "" { + // X-Forwarded-For can be a comma-separated list, use the first IP + for idx := 0; idx < len(xff); idx++ { + if xff[idx] == ',' { + return xff[:idx] + } + } + return xff + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // Fall back to RemoteAddr. NOTE(review): RemoteAddr is "host:port", so every connection from the same IP gets a distinct limiter key and the per-IP limit is effectively per-connection here; strip the port with net.SplitHostPort — TestGetClientIP currently asserts the port is kept, so update both together. + return r.RemoteAddr +} + +// cleanupLimiters periodically clears the limiter map to bound memory. NOTE(review): this drops ALL limiters, not just idle ones, so every active client regains a full burst after each 5-minute reset; a last-seen/LRU eviction would preserve enforcement for active IPs.
+func (m *Middleware) cleanupLimiters() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + m.mu.Lock() + // Clear all limiters periodically + // In production, you might want a more sophisticated LRU cache + m.limiters = make(map[string]*rate.Limiter) + m.mu.Unlock() + + m.logger.Debug("cleaned up rate limiters") + } +} diff --git a/internal/ratelimit/ratelimit_test.go b/internal/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..81faed0 --- /dev/null +++ b/internal/ratelimit/ratelimit_test.go @@ -0,0 +1,175 @@ +package ratelimit + +import ( + "log/slog" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" +) + +func TestRateLimitMiddleware(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + + tests := []struct { + name string + config Config + requestCount int + expectedAllowed int + expectedRateLimited int + }{ + { + name: "disabled rate limiting allows all requests", + config: Config{ + Enabled: false, + RequestsPerSecond: 1, + Burst: 1, + }, + requestCount: 10, + expectedAllowed: 10, + expectedRateLimited: 0, + }, + { + name: "enabled rate limiting enforces limits", + config: Config{ + Enabled: true, + RequestsPerSecond: 1, + Burst: 2, + }, + requestCount: 5, + expectedAllowed: 2, // Burst allows 2 immediately + expectedRateLimited: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware := New(tt.config, logger) + + handler := middleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + allowed := 0 + rateLimited := 0 + + for i := 0; i < tt.requestCount; i++ { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:1234" + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code == http.StatusOK { + allowed++ + } else if w.Code == http.StatusTooManyRequests { + rateLimited++ + } + } + + if allowed != tt.expectedAllowed { + 
t.Errorf("expected %d allowed requests, got %d", tt.expectedAllowed, allowed) + } + if rateLimited != tt.expectedRateLimited { + t.Errorf("expected %d rate limited requests, got %d", tt.expectedRateLimited, rateLimited) + } + }) + } +} + +func TestGetClientIP(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + middleware := New(Config{Enabled: false}, logger) + + tests := []struct { + name string + headers map[string]string + remoteAddr string + expectedIP string + }{ + { + name: "uses X-Forwarded-For if present", + headers: map[string]string{"X-Forwarded-For": "203.0.113.1, 198.51.100.1"}, + remoteAddr: "192.168.1.1:1234", + expectedIP: "203.0.113.1", + }, + { + name: "uses X-Real-IP if X-Forwarded-For not present", + headers: map[string]string{"X-Real-IP": "203.0.113.1"}, + remoteAddr: "192.168.1.1:1234", + expectedIP: "203.0.113.1", + }, + { + name: "uses RemoteAddr as fallback", + headers: map[string]string{}, + remoteAddr: "192.168.1.1:1234", + expectedIP: "192.168.1.1:1234", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = tt.remoteAddr + for k, v := range tt.headers { + req.Header.Set(k, v) + } + + ip := middleware.getClientIP(req) + if ip != tt.expectedIP { + t.Errorf("expected IP %q, got %q", tt.expectedIP, ip) + } + }) + } +} + +func TestRateLimitRefill(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + + config := Config{ + Enabled: true, + RequestsPerSecond: 10, // 10 requests per second + Burst: 5, + } + middleware := New(config, logger) + + handler := middleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Use up the burst + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:1234" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + 
t.Errorf("request %d should be allowed, got status %d", i, w.Code) + } + } + + // Next request should be rate limited + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:1234" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("expected rate limit, got status %d", w.Code) + } + + // Wait for tokens to refill (100ms = 1 token at 10/s) + time.Sleep(150 * time.Millisecond) + + // Should be allowed now + req = httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:1234" + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("request should be allowed after refill, got status %d", w.Code) + } +} diff --git a/internal/server/health.go b/internal/server/health.go new file mode 100644 index 0000000..5d402f5 --- /dev/null +++ b/internal/server/health.go @@ -0,0 +1,87 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "time" +) + +// HealthStatus represents the health check response. +type HealthStatus struct { + Status string `json:"status"` + Timestamp int64 `json:"timestamp"` + Checks map[string]string `json:"checks,omitempty"` +} + +// handleHealth returns a basic health check endpoint. +// This is suitable for Kubernetes liveness probes. +func (s *GatewayServer) handleHealth(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + status := HealthStatus{ + Status: "healthy", + Timestamp: time.Now().Unix(), + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(status) +} + +// handleReady returns a readiness check that verifies dependencies. +// This is suitable for Kubernetes readiness probes and load balancer health checks. 
+func (s *GatewayServer) handleReady(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + checks := make(map[string]string) + allHealthy := true + + // Check conversation store connectivity + ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second) + defer cancel() + + // Test conversation store by attempting a simple operation. NOTE(review): Get is called with an ID that is never created, so any store that returns a not-found error for missing IDs will make this check fail on every probe and the service will never become ready — confirm the store's missing-ID semantics, or add a dedicated Ping/Healthcheck method instead. + testID := "health_check_test" + _, err := s.convs.Get(testID) + if err != nil { + checks["conversation_store"] = "unhealthy: " + err.Error() + allHealthy = false + } else { + checks["conversation_store"] = "healthy" + } + + // Check if at least one provider is configured + models := s.registry.Models() + if len(models) == 0 { + checks["providers"] = "unhealthy: no providers configured" + allHealthy = false + } else { + checks["providers"] = "healthy" + } + + _ = ctx // NOTE(review): ctx/cancel above are dead code — Get takes no context, so the 2s timeout is never enforced; either remove the context or thread it through the store check. + + status := HealthStatus{ + Timestamp: time.Now().Unix(), + Checks: checks, + } + + if allHealthy { + status.Status = "ready" + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + } else { + status.Status = "not_ready" + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + } + + _ = json.NewEncoder(w).Encode(status) +} diff --git a/internal/server/health_test.go b/internal/server/health_test.go new file mode 100644 index 0000000..4f44d67 --- /dev/null +++ b/internal/server/health_test.go @@ -0,0 +1,150 @@ +package server + +import ( + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestHealthEndpoint(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + registry := newMockRegistry() + convStore := newMockConversationStore() + + server := New(registry, convStore, logger) + + tests := []struct { + name string + method string + expectedStatus int + }{ + { + name: "GET returns healthy status",
method: http.MethodGet, + expectedStatus: http.StatusOK, + }, + { + name: "POST returns method not allowed", + method: http.MethodPost, + expectedStatus: http.StatusMethodNotAllowed, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "/health", nil) + w := httptest.NewRecorder() + + server.handleHealth(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code) + } + + if tt.expectedStatus == http.StatusOK { + var status HealthStatus + if err := json.NewDecoder(w.Body).Decode(&status); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if status.Status != "healthy" { + t.Errorf("expected status 'healthy', got %q", status.Status) + } + + if status.Timestamp == 0 { + t.Error("expected non-zero timestamp") + } + } + }) + } +} + +func TestReadyEndpoint(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + + tests := []struct { + name string + setupRegistry func() *mockRegistry + convStore *mockConversationStore + expectedStatus int + expectedReady bool + }{ + { + name: "returns ready when all checks pass", + setupRegistry: func() *mockRegistry { + reg := newMockRegistry() + reg.addModel("test-model", "test-provider") + return reg + }, + convStore: newMockConversationStore(), + expectedStatus: http.StatusOK, + expectedReady: true, + }, + { + name: "returns not ready when no providers configured", + setupRegistry: func() *mockRegistry { + return newMockRegistry() + }, + convStore: newMockConversationStore(), + expectedStatus: http.StatusServiceUnavailable, + expectedReady: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := New(tt.setupRegistry(), tt.convStore, logger) + + req := httptest.NewRequest(http.MethodGet, "/ready", nil) + w := httptest.NewRecorder() + + server.handleReady(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", 
tt.expectedStatus, w.Code) + } + + var status HealthStatus + if err := json.NewDecoder(w.Body).Decode(&status); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if tt.expectedReady { + if status.Status != "ready" { + t.Errorf("expected status 'ready', got %q", status.Status) + } + } else { + if status.Status != "not_ready" { + t.Errorf("expected status 'not_ready', got %q", status.Status) + } + } + + if status.Timestamp == 0 { + t.Error("expected non-zero timestamp") + } + + if status.Checks == nil { + t.Error("expected checks map to be present") + } + }) + } +} + +func TestReadyEndpointMethodNotAllowed(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + registry := newMockRegistry() + convStore := newMockConversationStore() + server := New(registry, convStore, logger) + + req := httptest.NewRequest(http.MethodPost, "/ready", nil) + w := httptest.NewRecorder() + + server.handleReady(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 975b768..4784403 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -44,6 +44,8 @@ func New(registry ProviderRegistry, convs conversation.Store, logger *slog.Logge func (s *GatewayServer) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("/v1/responses", s.handleResponses) mux.HandleFunc("/v1/models", s.handleModels) + mux.HandleFunc("/health", s.handleHealth) + mux.HandleFunc("/ready", s.handleReady) } func (s *GatewayServer) handleModels(w http.ResponseWriter, r *http.Request) {