From d782204c683cce7ca8031703b81b2ce9c7759fdd Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Thu, 5 Mar 2026 07:21:04 +0000 Subject: [PATCH] Add circuit breaker --- cmd/gateway/main.go | 14 ++- go.mod | 1 + go.sum | 2 + internal/observability/metrics.go | 39 +++++++ internal/providers/circuitbreaker.go | 145 +++++++++++++++++++++++++++ internal/providers/providers.go | 17 +++- internal/server/server.go | 28 +++++- 7 files changed, 241 insertions(+), 5 deletions(-) create mode 100644 internal/providers/circuitbreaker.go diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 259183c..247c656 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -91,7 +91,19 @@ func main() { logger.Info("metrics initialized", slog.String("path", metricsPath)) } - baseRegistry, err := providers.NewRegistry(cfg.Providers, cfg.Models) + // Create provider registry with circuit breaker support + var baseRegistry *providers.Registry + if cfg.Observability.Enabled && cfg.Observability.Metrics.Enabled { + // Pass observability callback for circuit breaker state changes + baseRegistry, err = providers.NewRegistryWithCircuitBreaker( + cfg.Providers, + cfg.Models, + observability.RecordCircuitBreakerStateChange, + ) + } else { + // No observability, use default registry + baseRegistry, err = providers.NewRegistry(cfg.Providers, cfg.Models) + } if err != nil { logger.Error("failed to initialize providers", slog.String("error", err.Error())) os.Exit(1) diff --git a/go.mod b/go.mod index b7088d0..a12df8f 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/openai/openai-go/v3 v3.2.0 github.com/prometheus/client_golang v1.19.0 github.com/redis/go-redis/v9 v9.18.0 + github.com/sony/gobreaker v1.0.0 github.com/stretchr/testify v1.11.1 go.opentelemetry.io/otel v1.29.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0 diff --git a/go.sum b/go.sum index 659bb8c..fa9a1fc 100644 --- a/go.sum +++ b/go.sum @@ -121,6 +121,8 @@ github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfS github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/sony/gobreaker v1.0.0 h1:feX5fGGXSl3dYd4aHZItw+FpHLvvoaqkawKjVNiFMNQ= +github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/internal/observability/metrics.go b/internal/observability/metrics.go index 1c33c8e..82b4879 100644 --- a/internal/observability/metrics.go +++ b/internal/observability/metrics.go @@ -118,6 +118,23 @@ var ( }, []string{"backend"}, ) + + // Circuit Breaker Metrics + circuitBreakerState = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "circuit_breaker_state", + Help: "Circuit breaker state (0=closed, 1=open, 2=half-open)", + }, + []string{"provider"}, + ) + + circuitBreakerStateTransitions = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "circuit_breaker_state_transitions_total", + Help: "Total number of circuit breaker state transitions", + }, + []string{"provider", "from", "to"}, + ) ) // InitMetrics registers all metrics with a new Prometheus registry. @@ -143,5 +160,27 @@ func InitMetrics() *prometheus.Registry { registry.MustRegister(conversationOperationDuration) registry.MustRegister(conversationActiveCount) + // Register circuit breaker metrics + registry.MustRegister(circuitBreakerState) + registry.MustRegister(circuitBreakerStateTransitions) + return registry } + +// RecordCircuitBreakerStateChange records a circuit breaker state transition. +func RecordCircuitBreakerStateChange(provider, from, to string) { + // Record the transition + circuitBreakerStateTransitions.WithLabelValues(provider, from, to).Inc() + + // Update the current state gauge + var stateValue float64 + switch to { + case "closed": + stateValue = 0 + case "open": + stateValue = 1 + case "half-open": + stateValue = 2 + } + circuitBreakerState.WithLabelValues(provider).Set(stateValue) +} diff --git a/internal/providers/circuitbreaker.go b/internal/providers/circuitbreaker.go new file mode 100644 index 0000000..1112509 --- /dev/null +++ b/internal/providers/circuitbreaker.go @@ -0,0 +1,145 @@ +package providers + +import ( + "context" + "fmt" + "time" + + "github.com/sony/gobreaker" + + "github.com/ajac-zero/latticelm/internal/api" +) + +// CircuitBreakerProvider wraps a Provider with circuit breaker functionality. +type CircuitBreakerProvider struct { + provider Provider + cb *gobreaker.CircuitBreaker +} + +// CircuitBreakerConfig holds configuration for the circuit breaker. +type CircuitBreakerConfig struct { + // MaxRequests is the maximum number of requests allowed to pass through + // when the circuit breaker is half-open. Default: 3 + MaxRequests uint32 + + // Interval is the cyclic period of the closed state for the circuit breaker + // to clear the internal Counts. Default: 30s + Interval time.Duration + + // Timeout is the period of the open state, after which the state becomes half-open. + // Default: 60s + Timeout time.Duration + + // MinRequests is the minimum number of requests needed before evaluating failure ratio. + // Default: 5 + MinRequests uint32 + + // FailureRatio is the ratio of failures that will trip the circuit breaker. + // Default: 0.5 (50%) + FailureRatio float64 + + // OnStateChange is an optional callback invoked when circuit breaker state changes. + // Parameters: provider name, from state, to state + OnStateChange func(provider, from, to string) +} + +// DefaultCircuitBreakerConfig returns a sensible default configuration. +func DefaultCircuitBreakerConfig() CircuitBreakerConfig { + return CircuitBreakerConfig{ + MaxRequests: 3, + Interval: 30 * time.Second, + Timeout: 60 * time.Second, + MinRequests: 5, + FailureRatio: 0.5, + } +} + +// NewCircuitBreakerProvider wraps a provider with circuit breaker functionality. +func NewCircuitBreakerProvider(provider Provider, cfg CircuitBreakerConfig) *CircuitBreakerProvider { + providerName := provider.Name() + + settings := gobreaker.Settings{ + Name: fmt.Sprintf("%s-circuit-breaker", providerName), + MaxRequests: cfg.MaxRequests, + Interval: cfg.Interval, + Timeout: cfg.Timeout, + ReadyToTrip: func(counts gobreaker.Counts) bool { + // Only trip if we have enough requests to be statistically meaningful + if counts.Requests < cfg.MinRequests { + return false + } + failureRatio := float64(counts.TotalFailures) / float64(counts.Requests) + return failureRatio >= cfg.FailureRatio + }, + OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) { + // Call the callback if provided + if cfg.OnStateChange != nil { + cfg.OnStateChange(providerName, from.String(), to.String()) + } + }, + } + + return &CircuitBreakerProvider{ + provider: provider, + cb: gobreaker.NewCircuitBreaker(settings), + } +} + +// Name returns the underlying provider name. +func (p *CircuitBreakerProvider) Name() string { + return p.provider.Name() +} + +// Generate wraps the provider's Generate method with circuit breaker protection. +func (p *CircuitBreakerProvider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) { + result, err := p.cb.Execute(func() (interface{}, error) { + return p.provider.Generate(ctx, messages, req) + }) + + if err != nil { + return nil, err + } + + return result.(*api.ProviderResult), nil +} + +// GenerateStream wraps the provider's GenerateStream method with circuit breaker protection. +func (p *CircuitBreakerProvider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) { + // For streaming, we check the circuit breaker state before initiating the stream + // If the circuit is open, we return an error immediately + state := p.cb.State() + if state == gobreaker.StateOpen { + errChan := make(chan error, 1) + deltaChan := make(chan *api.ProviderStreamDelta) + errChan <- gobreaker.ErrOpenState + close(deltaChan) + close(errChan) + return deltaChan, errChan + } + + // If circuit is closed or half-open, attempt the stream + deltaChan, errChan := p.provider.GenerateStream(ctx, messages, req) + + // Wrap the error channel to report successes/failures to circuit breaker + wrappedErrChan := make(chan error, 1) + + go func() { + defer close(wrappedErrChan) + + // Wait for the error channel to signal completion + if err := <-errChan; err != nil { + // Record failure in circuit breaker + p.cb.Execute(func() (interface{}, error) { + return nil, err + }) + wrappedErrChan <- err + } else { + // Record success in circuit breaker + p.cb.Execute(func() (interface{}, error) { + return nil, nil + }) + } + }() + + return deltaChan, wrappedErrChan +} diff --git a/internal/providers/providers.go b/internal/providers/providers.go index 245fdfc..bd807bc 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -28,6 +28,16 @@ type Registry struct { // NewRegistry constructs provider implementations from configuration. func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelEntry) (*Registry, error) { + return NewRegistryWithCircuitBreaker(entries, models, nil) +} + +// NewRegistryWithCircuitBreaker constructs provider implementations with circuit breaker support. +// The onStateChange callback is invoked when circuit breaker state changes. +func NewRegistryWithCircuitBreaker( + entries map[string]config.ProviderEntry, + models []config.ModelEntry, + onStateChange func(provider, from, to string), +) (*Registry, error) { reg := &Registry{ providers: make(map[string]Provider), models: make(map[string]string), @@ -35,13 +45,18 @@ func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelE modelList: models, } + // Use default circuit breaker configuration + cbConfig := DefaultCircuitBreakerConfig() + cbConfig.OnStateChange = onStateChange + for name, entry := range entries { p, err := buildProvider(entry) if err != nil { return nil, fmt.Errorf("provider %q: %w", name, err) } if p != nil { - reg.providers[name] = p + // Wrap provider with circuit breaker + reg.providers[name] = NewCircuitBreakerProvider(p, cbConfig) } } diff --git a/internal/server/server.go b/internal/server/server.go index 9125b3b..0dcb490 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "errors" "fmt" "log/slog" "net/http" @@ -9,6 +10,7 @@ import ( "time" "github.com/google/uuid" + "github.com/sony/gobreaker" "github.com/ajac-zero/latticelm/internal/api" "github.com/ajac-zero/latticelm/internal/conversation" @@ -40,6 +42,11 @@ func New(registry ProviderRegistry, convs conversation.Store, logger *slog.Logge } } +// isCircuitBreakerError checks if the error is from a circuit breaker. +func isCircuitBreakerError(err error) bool { + return errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) +} + // RegisterRoutes wires the HTTP handlers onto the provided mux. func (s *GatewayServer) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("/v1/responses", s.handleResponses) @@ -173,7 +180,13 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques slog.String("error", err.Error()), )..., ) - http.Error(w, "provider error", http.StatusBadGateway) + + // Check if error is from circuit breaker + if isCircuitBreakerError(err) { + http.Error(w, "service temporarily unavailable - circuit breaker open", http.StatusServiceUnavailable) + } else { + http.Error(w, "provider error", http.StatusBadGateway) + } return } @@ -409,6 +422,15 @@ loop: slog.String("error", streamErr.Error()), )..., ) + + // Determine error type based on circuit breaker state + errorType := "server_error" + errorMessage := streamErr.Error() + if isCircuitBreakerError(streamErr) { + errorType = "circuit_breaker_open" + errorMessage = "service temporarily unavailable - circuit breaker open" + } + failedResp := s.buildResponse(origReq, &api.ProviderResult{ Model: origReq.Model, }, provider.Name(), responseID) @@ -416,8 +438,8 @@ loop: failedResp.CompletedAt = nil failedResp.Output = []api.OutputItem{} failedResp.Error = &api.ResponseError{ - Type: "server_error", - Message: streamErr.Error(), + Type: errorType, + Message: errorMessage, } s.sendSSE(w, flusher, &seq, "response.failed", &api.StreamEvent{ Type: "response.failed",