Add CI and production grade improvements #3
@@ -91,7 +91,19 @@ func main() {
|
|||||||
logger.Info("metrics initialized", slog.String("path", metricsPath))
|
logger.Info("metrics initialized", slog.String("path", metricsPath))
|
||||||
}
|
}
|
||||||
|
|
||||||
baseRegistry, err := providers.NewRegistry(cfg.Providers, cfg.Models)
|
// Create provider registry with circuit breaker support
|
||||||
|
var baseRegistry *providers.Registry
|
||||||
|
if cfg.Observability.Enabled && cfg.Observability.Metrics.Enabled {
|
||||||
|
// Pass observability callback for circuit breaker state changes
|
||||||
|
baseRegistry, err = providers.NewRegistryWithCircuitBreaker(
|
||||||
|
cfg.Providers,
|
||||||
|
cfg.Models,
|
||||||
|
observability.RecordCircuitBreakerStateChange,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// No observability, use default registry
|
||||||
|
baseRegistry, err = providers.NewRegistry(cfg.Providers, cfg.Models)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("failed to initialize providers", slog.String("error", err.Error()))
|
logger.Error("failed to initialize providers", slog.String("error", err.Error()))
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -12,6 +12,7 @@ require (
|
|||||||
github.com/openai/openai-go/v3 v3.2.0
|
github.com/openai/openai-go/v3 v3.2.0
|
||||||
github.com/prometheus/client_golang v1.19.0
|
github.com/prometheus/client_golang v1.19.0
|
||||||
github.com/redis/go-redis/v9 v9.18.0
|
github.com/redis/go-redis/v9 v9.18.0
|
||||||
|
github.com/sony/gobreaker v1.0.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
go.opentelemetry.io/otel v1.29.0
|
go.opentelemetry.io/otel v1.29.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -121,6 +121,8 @@ github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfS
|
|||||||
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
|
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
|
||||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||||
|
github.com/sony/gobreaker v1.0.0 h1:feX5fGGXSl3dYd4aHZItw+FpHLvvoaqkawKjVNiFMNQ=
|
||||||
|
github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
|||||||
@@ -118,6 +118,23 @@ var (
|
|||||||
},
|
},
|
||||||
[]string{"backend"},
|
[]string{"backend"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Circuit Breaker Metrics
|
||||||
|
circuitBreakerState = prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "circuit_breaker_state",
|
||||||
|
Help: "Circuit breaker state (0=closed, 1=open, 2=half-open)",
|
||||||
|
},
|
||||||
|
[]string{"provider"},
|
||||||
|
)
|
||||||
|
|
||||||
|
circuitBreakerStateTransitions = prometheus.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: "circuit_breaker_state_transitions_total",
|
||||||
|
Help: "Total number of circuit breaker state transitions",
|
||||||
|
},
|
||||||
|
[]string{"provider", "from", "to"},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
// InitMetrics registers all metrics with a new Prometheus registry.
|
// InitMetrics registers all metrics with a new Prometheus registry.
|
||||||
@@ -143,5 +160,27 @@ func InitMetrics() *prometheus.Registry {
|
|||||||
registry.MustRegister(conversationOperationDuration)
|
registry.MustRegister(conversationOperationDuration)
|
||||||
registry.MustRegister(conversationActiveCount)
|
registry.MustRegister(conversationActiveCount)
|
||||||
|
|
||||||
|
// Register circuit breaker metrics
|
||||||
|
registry.MustRegister(circuitBreakerState)
|
||||||
|
registry.MustRegister(circuitBreakerStateTransitions)
|
||||||
|
|
||||||
return registry
|
return registry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordCircuitBreakerStateChange records a circuit breaker state transition.
|
||||||
|
func RecordCircuitBreakerStateChange(provider, from, to string) {
|
||||||
|
// Record the transition
|
||||||
|
circuitBreakerStateTransitions.WithLabelValues(provider, from, to).Inc()
|
||||||
|
|
||||||
|
// Update the current state gauge
|
||||||
|
var stateValue float64
|
||||||
|
switch to {
|
||||||
|
case "closed":
|
||||||
|
stateValue = 0
|
||||||
|
case "open":
|
||||||
|
stateValue = 1
|
||||||
|
case "half-open":
|
||||||
|
stateValue = 2
|
||||||
|
}
|
||||||
|
circuitBreakerState.WithLabelValues(provider).Set(stateValue)
|
||||||
|
}
|
||||||
|
|||||||
145
internal/providers/circuitbreaker.go
Normal file
145
internal/providers/circuitbreaker.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/sony/gobreaker"

	"github.com/ajac-zero/latticelm/internal/api"
)
|
||||||
|
|
||||||
|
// CircuitBreakerProvider wraps a Provider with circuit breaker functionality.
|
||||||
|
type CircuitBreakerProvider struct {
|
||||||
|
provider Provider
|
||||||
|
cb *gobreaker.CircuitBreaker
|
||||||
|
}
|
||||||
|
|
||||||
|
// CircuitBreakerConfig collects the tunables used when wrapping a provider in
// a circuit breaker. The "default" values quoted below are the ones returned
// by DefaultCircuitBreakerConfig.
type CircuitBreakerConfig struct {
	// MaxRequests caps how many requests may pass through while the breaker
	// is half-open. DefaultCircuitBreakerConfig uses 3.
	MaxRequests uint32

	// Interval is the cyclic period, while the breaker is closed, after which
	// its internal counts are cleared. DefaultCircuitBreakerConfig uses 30s.
	Interval time.Duration

	// Timeout is how long the breaker stays open before it moves to
	// half-open. DefaultCircuitBreakerConfig uses 60s.
	Timeout time.Duration

	// MinRequests is the number of requests that must have been observed
	// before the failure ratio is evaluated at all.
	// DefaultCircuitBreakerConfig uses 5.
	MinRequests uint32

	// FailureRatio is the fraction of failed requests at or above which the
	// breaker trips. DefaultCircuitBreakerConfig uses 0.5 (50%).
	FailureRatio float64

	// OnStateChange, when non-nil, is called on every breaker state change
	// with the provider name, the previous state and the new state.
	OnStateChange func(provider, from, to string)
}
|
||||||
|
|
||||||
|
// DefaultCircuitBreakerConfig returns a sensible default configuration.
|
||||||
|
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
|
||||||
|
return CircuitBreakerConfig{
|
||||||
|
MaxRequests: 3,
|
||||||
|
Interval: 30 * time.Second,
|
||||||
|
Timeout: 60 * time.Second,
|
||||||
|
MinRequests: 5,
|
||||||
|
FailureRatio: 0.5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCircuitBreakerProvider wraps a provider with circuit breaker functionality.
|
||||||
|
func NewCircuitBreakerProvider(provider Provider, cfg CircuitBreakerConfig) *CircuitBreakerProvider {
|
||||||
|
providerName := provider.Name()
|
||||||
|
|
||||||
|
settings := gobreaker.Settings{
|
||||||
|
Name: fmt.Sprintf("%s-circuit-breaker", providerName),
|
||||||
|
MaxRequests: cfg.MaxRequests,
|
||||||
|
Interval: cfg.Interval,
|
||||||
|
Timeout: cfg.Timeout,
|
||||||
|
ReadyToTrip: func(counts gobreaker.Counts) bool {
|
||||||
|
// Only trip if we have enough requests to be statistically meaningful
|
||||||
|
if counts.Requests < cfg.MinRequests {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
failureRatio := float64(counts.TotalFailures) / float64(counts.Requests)
|
||||||
|
return failureRatio >= cfg.FailureRatio
|
||||||
|
},
|
||||||
|
OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) {
|
||||||
|
// Call the callback if provided
|
||||||
|
if cfg.OnStateChange != nil {
|
||||||
|
cfg.OnStateChange(providerName, from.String(), to.String())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CircuitBreakerProvider{
|
||||||
|
provider: provider,
|
||||||
|
cb: gobreaker.NewCircuitBreaker(settings),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the underlying provider name.
|
||||||
|
func (p *CircuitBreakerProvider) Name() string {
|
||||||
|
return p.provider.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate wraps the provider's Generate method with circuit breaker protection.
|
||||||
|
func (p *CircuitBreakerProvider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
||||||
|
result, err := p.cb.Execute(func() (interface{}, error) {
|
||||||
|
return p.provider.Generate(ctx, messages, req)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.(*api.ProviderResult), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateStream wraps the provider's GenerateStream method with circuit breaker protection.
|
||||||
|
func (p *CircuitBreakerProvider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) {
|
||||||
|
// For streaming, we check the circuit breaker state before initiating the stream
|
||||||
|
// If the circuit is open, we return an error immediately
|
||||||
|
state := p.cb.State()
|
||||||
|
if state == gobreaker.StateOpen {
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
deltaChan := make(chan *api.ProviderStreamDelta)
|
||||||
|
errChan <- gobreaker.ErrOpenState
|
||||||
|
close(deltaChan)
|
||||||
|
close(errChan)
|
||||||
|
return deltaChan, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// If circuit is closed or half-open, attempt the stream
|
||||||
|
deltaChan, errChan := p.provider.GenerateStream(ctx, messages, req)
|
||||||
|
|
||||||
|
// Wrap the error channel to report successes/failures to circuit breaker
|
||||||
|
wrappedErrChan := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(wrappedErrChan)
|
||||||
|
|
||||||
|
// Wait for the error channel to signal completion
|
||||||
|
if err := <-errChan; err != nil {
|
||||||
|
// Record failure in circuit breaker
|
||||||
|
p.cb.Execute(func() (interface{}, error) {
|
||||||
|
return nil, err
|
||||||
|
})
|
||||||
|
wrappedErrChan <- err
|
||||||
|
} else {
|
||||||
|
// Record success in circuit breaker
|
||||||
|
p.cb.Execute(func() (interface{}, error) {
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return deltaChan, wrappedErrChan
|
||||||
|
}
|
||||||
@@ -28,6 +28,16 @@ type Registry struct {
|
|||||||
|
|
||||||
// NewRegistry constructs provider implementations from configuration.
|
// NewRegistry constructs provider implementations from configuration.
|
||||||
func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelEntry) (*Registry, error) {
|
func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelEntry) (*Registry, error) {
|
||||||
|
return NewRegistryWithCircuitBreaker(entries, models, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRegistryWithCircuitBreaker constructs provider implementations with circuit breaker support.
|
||||||
|
// The onStateChange callback is invoked when circuit breaker state changes.
|
||||||
|
func NewRegistryWithCircuitBreaker(
|
||||||
|
entries map[string]config.ProviderEntry,
|
||||||
|
models []config.ModelEntry,
|
||||||
|
onStateChange func(provider, from, to string),
|
||||||
|
) (*Registry, error) {
|
||||||
reg := &Registry{
|
reg := &Registry{
|
||||||
providers: make(map[string]Provider),
|
providers: make(map[string]Provider),
|
||||||
models: make(map[string]string),
|
models: make(map[string]string),
|
||||||
@@ -35,13 +45,18 @@ func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelE
|
|||||||
modelList: models,
|
modelList: models,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use default circuit breaker configuration
|
||||||
|
cbConfig := DefaultCircuitBreakerConfig()
|
||||||
|
cbConfig.OnStateChange = onStateChange
|
||||||
|
|
||||||
for name, entry := range entries {
|
for name, entry := range entries {
|
||||||
p, err := buildProvider(entry)
|
p, err := buildProvider(entry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("provider %q: %w", name, err)
|
return nil, fmt.Errorf("provider %q: %w", name, err)
|
||||||
}
|
}
|
||||||
if p != nil {
|
if p != nil {
|
||||||
reg.providers[name] = p
|
// Wrap provider with circuit breaker
|
||||||
|
reg.providers[name] = NewCircuitBreakerProvider(p, cbConfig)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -9,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/sony/gobreaker"
|
||||||
|
|
||||||
"github.com/ajac-zero/latticelm/internal/api"
|
"github.com/ajac-zero/latticelm/internal/api"
|
||||||
"github.com/ajac-zero/latticelm/internal/conversation"
|
"github.com/ajac-zero/latticelm/internal/conversation"
|
||||||
@@ -40,6 +42,11 @@ func New(registry ProviderRegistry, convs conversation.Store, logger *slog.Logge
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isCircuitBreakerError checks if the error is from a circuit breaker.
|
||||||
|
func isCircuitBreakerError(err error) bool {
|
||||||
|
return errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests)
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterRoutes wires the HTTP handlers onto the provided mux.
|
// RegisterRoutes wires the HTTP handlers onto the provided mux.
|
||||||
func (s *GatewayServer) RegisterRoutes(mux *http.ServeMux) {
|
func (s *GatewayServer) RegisterRoutes(mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/v1/responses", s.handleResponses)
|
mux.HandleFunc("/v1/responses", s.handleResponses)
|
||||||
@@ -173,7 +180,13 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
|||||||
slog.String("error", err.Error()),
|
slog.String("error", err.Error()),
|
||||||
)...,
|
)...,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Check if error is from circuit breaker
|
||||||
|
if isCircuitBreakerError(err) {
|
||||||
|
http.Error(w, "service temporarily unavailable - circuit breaker open", http.StatusServiceUnavailable)
|
||||||
|
} else {
|
||||||
http.Error(w, "provider error", http.StatusBadGateway)
|
http.Error(w, "provider error", http.StatusBadGateway)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -409,6 +422,15 @@ loop:
|
|||||||
slog.String("error", streamErr.Error()),
|
slog.String("error", streamErr.Error()),
|
||||||
)...,
|
)...,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Determine error type based on circuit breaker state
|
||||||
|
errorType := "server_error"
|
||||||
|
errorMessage := streamErr.Error()
|
||||||
|
if isCircuitBreakerError(streamErr) {
|
||||||
|
errorType = "circuit_breaker_open"
|
||||||
|
errorMessage = "service temporarily unavailable - circuit breaker open"
|
||||||
|
}
|
||||||
|
|
||||||
failedResp := s.buildResponse(origReq, &api.ProviderResult{
|
failedResp := s.buildResponse(origReq, &api.ProviderResult{
|
||||||
Model: origReq.Model,
|
Model: origReq.Model,
|
||||||
}, provider.Name(), responseID)
|
}, provider.Name(), responseID)
|
||||||
@@ -416,8 +438,8 @@ loop:
|
|||||||
failedResp.CompletedAt = nil
|
failedResp.CompletedAt = nil
|
||||||
failedResp.Output = []api.OutputItem{}
|
failedResp.Output = []api.OutputItem{}
|
||||||
failedResp.Error = &api.ResponseError{
|
failedResp.Error = &api.ResponseError{
|
||||||
Type: "server_error",
|
Type: errorType,
|
||||||
Message: streamErr.Error(),
|
Message: errorMessage,
|
||||||
}
|
}
|
||||||
s.sendSSE(w, flusher, &seq, "response.failed", &api.StreamEvent{
|
s.sendSSE(w, flusher, &seq, "response.failed", &api.StreamEvent{
|
||||||
Type: "response.failed",
|
Type: "response.failed",
|
||||||
|
|||||||
Reference in New Issue
Block a user