Add fail-fast on init for missing provider credentials
This commit is contained in:
@@ -172,9 +172,32 @@ func Load(path string) (*Config, error) {
|
||||
|
||||
func (cfg *Config) validate() error {
|
||||
for _, m := range cfg.Models {
|
||||
if _, ok := cfg.Providers[m.Provider]; !ok {
|
||||
providerEntry, ok := cfg.Providers[m.Provider]
|
||||
if !ok {
|
||||
return fmt.Errorf("model %q references unknown provider %q", m.Name, m.Provider)
|
||||
}
|
||||
|
||||
switch providerEntry.Type {
|
||||
case "openai", "anthropic", "google", "azureopenai", "azureanthropic":
|
||||
if providerEntry.APIKey == "" {
|
||||
return fmt.Errorf("model %q references provider %q (%s) without api_key", m.Name, m.Provider, providerEntry.Type)
|
||||
}
|
||||
}
|
||||
|
||||
switch providerEntry.Type {
|
||||
case "azureopenai", "azureanthropic":
|
||||
if providerEntry.Endpoint == "" {
|
||||
return fmt.Errorf("model %q references provider %q (%s) without endpoint", m.Name, m.Provider, providerEntry.Type)
|
||||
}
|
||||
case "vertexai":
|
||||
if providerEntry.Project == "" || providerEntry.Location == "" {
|
||||
return fmt.Errorf("model %q references provider %q (vertexai) without project/location", m.Name, m.Provider)
|
||||
}
|
||||
case "openai", "anthropic", "google":
|
||||
// No additional required fields.
|
||||
default:
|
||||
return fmt.Errorf("model %q references provider %q with unknown type %q", m.Name, m.Provider, providerEntry.Type)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ server:
|
||||
address: ":8080"
|
||||
providers:
|
||||
azure:
|
||||
type: azure_openai
|
||||
type: azureopenai
|
||||
api_key: azure-key
|
||||
endpoint: https://my-resource.openai.azure.com
|
||||
api_version: "2024-02-15-preview"
|
||||
@@ -113,7 +113,7 @@ models:
|
||||
provider_model_id: gpt-4-deployment
|
||||
`,
|
||||
validate: func(t *testing.T, cfg *Config) {
|
||||
assert.Equal(t, "azure_openai", cfg.Providers["azure"].Type)
|
||||
assert.Equal(t, "azureopenai", cfg.Providers["azure"].Type)
|
||||
assert.Equal(t, "azure-key", cfg.Providers["azure"].APIKey)
|
||||
assert.Equal(t, "https://my-resource.openai.azure.com", cfg.Providers["azure"].Endpoint)
|
||||
assert.Equal(t, "2024-02-15-preview", cfg.Providers["azure"].APIVersion)
|
||||
@@ -126,7 +126,7 @@ server:
|
||||
address: ":8080"
|
||||
providers:
|
||||
vertex:
|
||||
type: vertex_ai
|
||||
type: vertexai
|
||||
project: my-gcp-project
|
||||
location: us-central1
|
||||
models:
|
||||
@@ -135,7 +135,7 @@ models:
|
||||
provider_model_id: gemini-1.5-pro
|
||||
`,
|
||||
validate: func(t *testing.T, cfg *Config) {
|
||||
assert.Equal(t, "vertex_ai", cfg.Providers["vertex"].Type)
|
||||
assert.Equal(t, "vertexai", cfg.Providers["vertex"].Type)
|
||||
assert.Equal(t, "my-gcp-project", cfg.Providers["vertex"].Project)
|
||||
assert.Equal(t, "us-central1", cfg.Providers["vertex"].Location)
|
||||
},
|
||||
@@ -208,6 +208,20 @@ models:
|
||||
configYAML: `invalid: yaml: content: [unclosed`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "model references provider without required API key",
|
||||
configYAML: `
|
||||
server:
|
||||
address: ":8080"
|
||||
providers:
|
||||
openai:
|
||||
type: openai
|
||||
models:
|
||||
- name: gpt-4
|
||||
provider: openai
|
||||
`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "multiple models same provider",
|
||||
configYAML: `
|
||||
@@ -283,7 +297,7 @@ func TestConfigValidate(t *testing.T) {
|
||||
name: "valid config",
|
||||
config: Config{
|
||||
Providers: map[string]ProviderEntry{
|
||||
"openai": {Type: "openai"},
|
||||
"openai": {Type: "openai", APIKey: "test-key"},
|
||||
},
|
||||
Models: []ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
@@ -295,7 +309,7 @@ func TestConfigValidate(t *testing.T) {
|
||||
name: "model references unknown provider",
|
||||
config: Config{
|
||||
Providers: map[string]ProviderEntry{
|
||||
"openai": {Type: "openai"},
|
||||
"openai": {Type: "openai", APIKey: "test-key"},
|
||||
},
|
||||
Models: []ModelEntry{
|
||||
{Name: "gpt-4", Provider: "unknown"},
|
||||
@@ -303,6 +317,18 @@ func TestConfigValidate(t *testing.T) {
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "model references provider without api key",
|
||||
config: Config{
|
||||
Providers: map[string]ProviderEntry{
|
||||
"openai": {Type: "openai"},
|
||||
},
|
||||
Models: []ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "no models",
|
||||
config: Config{
|
||||
@@ -317,8 +343,8 @@ func TestConfigValidate(t *testing.T) {
|
||||
name: "multiple models multiple providers",
|
||||
config: Config{
|
||||
Providers: map[string]ProviderEntry{
|
||||
"openai": {Type: "openai"},
|
||||
"anthropic": {Type: "anthropic"},
|
||||
"openai": {Type: "openai", APIKey: "test-key"},
|
||||
"anthropic": {Type: "anthropic", APIKey: "ant-key"},
|
||||
},
|
||||
Models: []ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
|
||||
@@ -48,15 +48,30 @@ type metricsResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
bytesWritten int
|
||||
wroteHeader bool
|
||||
}
|
||||
|
||||
func (w *metricsResponseWriter) WriteHeader(statusCode int) {
|
||||
if w.wroteHeader {
|
||||
return
|
||||
}
|
||||
w.wroteHeader = true
|
||||
w.statusCode = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *metricsResponseWriter) Write(b []byte) (int, error) {
|
||||
if !w.wroteHeader {
|
||||
w.wroteHeader = true
|
||||
w.statusCode = http.StatusOK
|
||||
}
|
||||
n, err := w.ResponseWriter.Write(b)
|
||||
w.bytesWritten += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *metricsResponseWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
65
internal/observability/middleware_response_writer_test.go
Normal file
65
internal/observability/middleware_response_writer_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package observability
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var _ http.Flusher = (*metricsResponseWriter)(nil)
|
||||
var _ http.Flusher = (*statusResponseWriter)(nil)
|
||||
|
||||
type testFlusherRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
flushCount int
|
||||
}
|
||||
|
||||
func newTestFlusherRecorder() *testFlusherRecorder {
|
||||
return &testFlusherRecorder{ResponseRecorder: httptest.NewRecorder()}
|
||||
}
|
||||
|
||||
func (r *testFlusherRecorder) Flush() {
|
||||
r.flushCount++
|
||||
}
|
||||
|
||||
func TestMetricsResponseWriterWriteHeaderOnlyOnce(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := &metricsResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
assert.Equal(t, http.StatusAccepted, rec.Code)
|
||||
assert.Equal(t, http.StatusAccepted, rw.statusCode)
|
||||
}
|
||||
|
||||
func TestMetricsResponseWriterFlushDelegates(t *testing.T) {
|
||||
rec := newTestFlusherRecorder()
|
||||
rw := &metricsResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||
|
||||
rw.Flush()
|
||||
|
||||
assert.Equal(t, 1, rec.flushCount)
|
||||
}
|
||||
|
||||
func TestStatusResponseWriterWriteHeaderOnlyOnce(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := &statusResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, rec.Code)
|
||||
assert.Equal(t, http.StatusNoContent, rw.statusCode)
|
||||
}
|
||||
|
||||
func TestStatusResponseWriterFlushDelegates(t *testing.T) {
|
||||
rec := newTestFlusherRecorder()
|
||||
rw := &statusResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||
|
||||
rw.Flush()
|
||||
|
||||
assert.Equal(t, 1, rec.flushCount)
|
||||
}
|
||||
@@ -72,14 +72,29 @@ func TracingMiddleware(next http.Handler, tp *sdktrace.TracerProvider) http.Hand
|
||||
// statusResponseWriter wraps http.ResponseWriter to capture the status code.
|
||||
type statusResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
statusCode int
|
||||
wroteHeader bool
|
||||
}
|
||||
|
||||
func (w *statusResponseWriter) WriteHeader(statusCode int) {
|
||||
if w.wroteHeader {
|
||||
return
|
||||
}
|
||||
w.wroteHeader = true
|
||||
w.statusCode = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *statusResponseWriter) Write(b []byte) (int, error) {
|
||||
if !w.wroteHeader {
|
||||
w.wroteHeader = true
|
||||
w.statusCode = http.StatusOK
|
||||
}
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (w *statusResponseWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,6 +136,9 @@ func (r *Registry) Get(name string) (Provider, bool) {
|
||||
func (r *Registry) Models() []struct{ Provider, Model string } {
|
||||
var out []struct{ Provider, Model string }
|
||||
for _, m := range r.modelList {
|
||||
if _, ok := r.providers[m.Provider]; !ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, struct{ Provider, Model string }{Provider: m.Provider, Model: m.Name})
|
||||
}
|
||||
return out
|
||||
@@ -156,7 +159,9 @@ func (r *Registry) Default(model string) (Provider, error) {
|
||||
if p, ok := r.providers[providerName]; ok {
|
||||
return p, nil
|
||||
}
|
||||
return nil, fmt.Errorf("model %q is mapped to provider %q, but that provider is not available", model, providerName)
|
||||
}
|
||||
return nil, fmt.Errorf("model %q not configured", model)
|
||||
}
|
||||
|
||||
for _, p := range r.providers {
|
||||
|
||||
@@ -475,7 +475,7 @@ func TestRegistry_Default(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "returns first provider for unknown model",
|
||||
name: "returns error for unknown model",
|
||||
setupReg: func() *Registry {
|
||||
reg, _ := NewRegistry(
|
||||
map[string]config.ProviderEntry{
|
||||
@@ -490,11 +490,34 @@ func TestRegistry_Default(t *testing.T) {
|
||||
)
|
||||
return reg
|
||||
},
|
||||
modelName: "unknown-model",
|
||||
validate: func(t *testing.T, p Provider) {
|
||||
assert.NotNil(t, p)
|
||||
// Should return first available provider
|
||||
modelName: "unknown-model",
|
||||
expectError: true,
|
||||
errorMsg: "not configured",
|
||||
},
|
||||
{
|
||||
name: "returns error for model whose provider is unavailable",
|
||||
setupReg: func() *Registry {
|
||||
reg, _ := NewRegistry(
|
||||
map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "", // unavailable provider
|
||||
},
|
||||
"google": {
|
||||
Type: "google",
|
||||
APIKey: "test-key",
|
||||
},
|
||||
},
|
||||
[]config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
{Name: "gemini-pro", Provider: "google"},
|
||||
},
|
||||
)
|
||||
return reg
|
||||
},
|
||||
modelName: "gpt-4",
|
||||
expectError: true,
|
||||
errorMsg: "not available",
|
||||
},
|
||||
{
|
||||
name: "returns first provider for empty model name",
|
||||
@@ -542,6 +565,31 @@ func TestRegistry_Default(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Models_FiltersUnavailableProviders(t *testing.T) {
|
||||
reg, err := NewRegistry(
|
||||
map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "", // unavailable provider
|
||||
},
|
||||
"google": {
|
||||
Type: "google",
|
||||
APIKey: "test-key",
|
||||
},
|
||||
},
|
||||
[]config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
{Name: "gemini-pro", Provider: "google"},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
models := reg.Models()
|
||||
require.Len(t, models, 1)
|
||||
assert.Equal(t, "gemini-pro", models[0].Model)
|
||||
assert.Equal(t, "google", models[0].Provider)
|
||||
}
|
||||
|
||||
func TestBuildProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -239,17 +239,17 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
responseID := generateID("resp_")
|
||||
itemID := generateID("msg_")
|
||||
seq := 0
|
||||
|
||||
53
internal/server/streaming_writer_test.go
Normal file
53
internal/server/streaming_writer_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type nonFlusherRecorder struct {
|
||||
recorder *httptest.ResponseRecorder
|
||||
writeHeaderCalls int
|
||||
}
|
||||
|
||||
func newNonFlusherRecorder() *nonFlusherRecorder {
|
||||
return &nonFlusherRecorder{recorder: httptest.NewRecorder()}
|
||||
}
|
||||
|
||||
func (w *nonFlusherRecorder) Header() http.Header {
|
||||
return w.recorder.Header()
|
||||
}
|
||||
|
||||
func (w *nonFlusherRecorder) Write(b []byte) (int, error) {
|
||||
return w.recorder.Write(b)
|
||||
}
|
||||
|
||||
func (w *nonFlusherRecorder) WriteHeader(statusCode int) {
|
||||
w.writeHeaderCalls++
|
||||
w.recorder.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *nonFlusherRecorder) StatusCode() int {
|
||||
return w.recorder.Code
|
||||
}
|
||||
|
||||
func (w *nonFlusherRecorder) BodyString() string {
|
||||
return w.recorder.Body.String()
|
||||
}
|
||||
|
||||
func TestHandleStreamingResponseWithoutFlusherWritesSingleErrorHeader(t *testing.T) {
|
||||
s := New(nil, nil, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
w := newNonFlusherRecorder()
|
||||
|
||||
s.handleStreamingResponse(w, req, nil, nil, nil, nil, nil)
|
||||
|
||||
assert.Equal(t, 1, w.writeHeaderCalls)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.StatusCode())
|
||||
assert.Contains(t, w.BodyString(), "streaming not supported")
|
||||
}
|
||||
Reference in New Issue
Block a user