Add circuit breaker
This commit is contained in:
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sony/gobreaker"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/ajac-zero/latticelm/internal/conversation"
|
||||
@@ -40,6 +42,11 @@ func New(registry ProviderRegistry, convs conversation.Store, logger *slog.Logge
|
||||
}
|
||||
}
|
||||
|
||||
// isCircuitBreakerError checks if the error is from a circuit breaker.
|
||||
func isCircuitBreakerError(err error) bool {
|
||||
return errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests)
|
||||
}
|
||||
|
||||
// RegisterRoutes wires the HTTP handlers onto the provided mux.
|
||||
func (s *GatewayServer) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/v1/responses", s.handleResponses)
|
||||
@@ -173,7 +180,13 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
||||
slog.String("error", err.Error()),
|
||||
)...,
|
||||
)
|
||||
http.Error(w, "provider error", http.StatusBadGateway)
|
||||
|
||||
// Check if error is from circuit breaker
|
||||
if isCircuitBreakerError(err) {
|
||||
http.Error(w, "service temporarily unavailable - circuit breaker open", http.StatusServiceUnavailable)
|
||||
} else {
|
||||
http.Error(w, "provider error", http.StatusBadGateway)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -409,6 +422,15 @@ loop:
|
||||
slog.String("error", streamErr.Error()),
|
||||
)...,
|
||||
)
|
||||
|
||||
// Determine error type based on circuit breaker state
|
||||
errorType := "server_error"
|
||||
errorMessage := streamErr.Error()
|
||||
if isCircuitBreakerError(streamErr) {
|
||||
errorType = "circuit_breaker_open"
|
||||
errorMessage = "service temporarily unavailable - circuit breaker open"
|
||||
}
|
||||
|
||||
failedResp := s.buildResponse(origReq, &api.ProviderResult{
|
||||
Model: origReq.Model,
|
||||
}, provider.Name(), responseID)
|
||||
@@ -416,8 +438,8 @@ loop:
|
||||
failedResp.CompletedAt = nil
|
||||
failedResp.Output = []api.OutputItem{}
|
||||
failedResp.Error = &api.ResponseError{
|
||||
Type: "server_error",
|
||||
Message: streamErr.Error(),
|
||||
Type: errorType,
|
||||
Message: errorMessage,
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.failed", &api.StreamEvent{
|
||||
Type: "response.failed",
|
||||
|
||||
Reference in New Issue
Block a user