Add fail-fast on init for missing provider credentials

This commit is contained in:
2026-03-06 21:31:51 +00:00
parent 610b6c3367
commit 89c7e3ac85
14 changed files with 398 additions and 20 deletions

View File

@@ -172,9 +172,32 @@ func Load(path string) (*Config, error) {
func (cfg *Config) validate() error {
for _, m := range cfg.Models {
if _, ok := cfg.Providers[m.Provider]; !ok {
providerEntry, ok := cfg.Providers[m.Provider]
if !ok {
return fmt.Errorf("model %q references unknown provider %q", m.Name, m.Provider)
}
switch providerEntry.Type {
case "openai", "anthropic", "google", "azureopenai", "azureanthropic":
if providerEntry.APIKey == "" {
return fmt.Errorf("model %q references provider %q (%s) without api_key", m.Name, m.Provider, providerEntry.Type)
}
}
switch providerEntry.Type {
case "azureopenai", "azureanthropic":
if providerEntry.Endpoint == "" {
return fmt.Errorf("model %q references provider %q (%s) without endpoint", m.Name, m.Provider, providerEntry.Type)
}
case "vertexai":
if providerEntry.Project == "" || providerEntry.Location == "" {
return fmt.Errorf("model %q references provider %q (vertexai) without project/location", m.Name, m.Provider)
}
case "openai", "anthropic", "google":
// No additional required fields.
default:
return fmt.Errorf("model %q references provider %q with unknown type %q", m.Name, m.Provider, providerEntry.Type)
}
}
return nil
}

View File

@@ -103,7 +103,7 @@ server:
address: ":8080"
providers:
azure:
type: azure_openai
type: azureopenai
api_key: azure-key
endpoint: https://my-resource.openai.azure.com
api_version: "2024-02-15-preview"
@@ -113,7 +113,7 @@ models:
provider_model_id: gpt-4-deployment
`,
validate: func(t *testing.T, cfg *Config) {
assert.Equal(t, "azure_openai", cfg.Providers["azure"].Type)
assert.Equal(t, "azureopenai", cfg.Providers["azure"].Type)
assert.Equal(t, "azure-key", cfg.Providers["azure"].APIKey)
assert.Equal(t, "https://my-resource.openai.azure.com", cfg.Providers["azure"].Endpoint)
assert.Equal(t, "2024-02-15-preview", cfg.Providers["azure"].APIVersion)
@@ -126,7 +126,7 @@ server:
address: ":8080"
providers:
vertex:
type: vertex_ai
type: vertexai
project: my-gcp-project
location: us-central1
models:
@@ -135,7 +135,7 @@ models:
provider_model_id: gemini-1.5-pro
`,
validate: func(t *testing.T, cfg *Config) {
assert.Equal(t, "vertex_ai", cfg.Providers["vertex"].Type)
assert.Equal(t, "vertexai", cfg.Providers["vertex"].Type)
assert.Equal(t, "my-gcp-project", cfg.Providers["vertex"].Project)
assert.Equal(t, "us-central1", cfg.Providers["vertex"].Location)
},
@@ -208,6 +208,20 @@ models:
configYAML: `invalid: yaml: content: [unclosed`,
expectError: true,
},
{
name: "model references provider without required API key",
configYAML: `
server:
address: ":8080"
providers:
openai:
type: openai
models:
- name: gpt-4
provider: openai
`,
expectError: true,
},
{
name: "multiple models same provider",
configYAML: `
@@ -283,7 +297,7 @@ func TestConfigValidate(t *testing.T) {
name: "valid config",
config: Config{
Providers: map[string]ProviderEntry{
"openai": {Type: "openai"},
"openai": {Type: "openai", APIKey: "test-key"},
},
Models: []ModelEntry{
{Name: "gpt-4", Provider: "openai"},
@@ -295,7 +309,7 @@ func TestConfigValidate(t *testing.T) {
name: "model references unknown provider",
config: Config{
Providers: map[string]ProviderEntry{
"openai": {Type: "openai"},
"openai": {Type: "openai", APIKey: "test-key"},
},
Models: []ModelEntry{
{Name: "gpt-4", Provider: "unknown"},
@@ -303,6 +317,18 @@ func TestConfigValidate(t *testing.T) {
},
expectError: true,
},
{
name: "model references provider without api key",
config: Config{
Providers: map[string]ProviderEntry{
"openai": {Type: "openai"},
},
Models: []ModelEntry{
{Name: "gpt-4", Provider: "openai"},
},
},
expectError: true,
},
{
name: "no models",
config: Config{
@@ -317,8 +343,8 @@ func TestConfigValidate(t *testing.T) {
name: "multiple models multiple providers",
config: Config{
Providers: map[string]ProviderEntry{
"openai": {Type: "openai"},
"anthropic": {Type: "anthropic"},
"openai": {Type: "openai", APIKey: "test-key"},
"anthropic": {Type: "anthropic", APIKey: "ant-key"},
},
Models: []ModelEntry{
{Name: "gpt-4", Provider: "openai"},

View File

@@ -48,15 +48,30 @@ type metricsResponseWriter struct {
http.ResponseWriter
statusCode int
bytesWritten int
wroteHeader bool
}
func (w *metricsResponseWriter) WriteHeader(statusCode int) {
if w.wroteHeader {
return
}
w.wroteHeader = true
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *metricsResponseWriter) Write(b []byte) (int, error) {
if !w.wroteHeader {
w.wroteHeader = true
w.statusCode = http.StatusOK
}
n, err := w.ResponseWriter.Write(b)
w.bytesWritten += n
return n, err
}
func (w *metricsResponseWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

View File

@@ -0,0 +1,65 @@
package observability
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
var _ http.Flusher = (*metricsResponseWriter)(nil)
var _ http.Flusher = (*statusResponseWriter)(nil)
type testFlusherRecorder struct {
*httptest.ResponseRecorder
flushCount int
}
func newTestFlusherRecorder() *testFlusherRecorder {
return &testFlusherRecorder{ResponseRecorder: httptest.NewRecorder()}
}
func (r *testFlusherRecorder) Flush() {
r.flushCount++
}
func TestMetricsResponseWriterWriteHeaderOnlyOnce(t *testing.T) {
rec := httptest.NewRecorder()
rw := &metricsResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
rw.WriteHeader(http.StatusAccepted)
rw.WriteHeader(http.StatusInternalServerError)
assert.Equal(t, http.StatusAccepted, rec.Code)
assert.Equal(t, http.StatusAccepted, rw.statusCode)
}
func TestMetricsResponseWriterFlushDelegates(t *testing.T) {
rec := newTestFlusherRecorder()
rw := &metricsResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
rw.Flush()
assert.Equal(t, 1, rec.flushCount)
}
func TestStatusResponseWriterWriteHeaderOnlyOnce(t *testing.T) {
rec := httptest.NewRecorder()
rw := &statusResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
rw.WriteHeader(http.StatusNoContent)
rw.WriteHeader(http.StatusInternalServerError)
assert.Equal(t, http.StatusNoContent, rec.Code)
assert.Equal(t, http.StatusNoContent, rw.statusCode)
}
func TestStatusResponseWriterFlushDelegates(t *testing.T) {
rec := newTestFlusherRecorder()
rw := &statusResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
rw.Flush()
assert.Equal(t, 1, rec.flushCount)
}

View File

@@ -72,14 +72,29 @@ func TracingMiddleware(next http.Handler, tp *sdktrace.TracerProvider) http.Hand
// statusResponseWriter wraps http.ResponseWriter to capture the status code.
type statusResponseWriter struct {
http.ResponseWriter
statusCode int
statusCode int
wroteHeader bool
}
func (w *statusResponseWriter) WriteHeader(statusCode int) {
if w.wroteHeader {
return
}
w.wroteHeader = true
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *statusResponseWriter) Write(b []byte) (int, error) {
if !w.wroteHeader {
w.wroteHeader = true
w.statusCode = http.StatusOK
}
return w.ResponseWriter.Write(b)
}
func (w *statusResponseWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

View File

@@ -136,6 +136,9 @@ func (r *Registry) Get(name string) (Provider, bool) {
func (r *Registry) Models() []struct{ Provider, Model string } {
var out []struct{ Provider, Model string }
for _, m := range r.modelList {
if _, ok := r.providers[m.Provider]; !ok {
continue
}
out = append(out, struct{ Provider, Model string }{Provider: m.Provider, Model: m.Name})
}
return out
@@ -156,7 +159,9 @@ func (r *Registry) Default(model string) (Provider, error) {
if p, ok := r.providers[providerName]; ok {
return p, nil
}
return nil, fmt.Errorf("model %q is mapped to provider %q, but that provider is not available", model, providerName)
}
return nil, fmt.Errorf("model %q not configured", model)
}
for _, p := range r.providers {

View File

@@ -475,7 +475,7 @@ func TestRegistry_Default(t *testing.T) {
},
},
{
name: "returns first provider for unknown model",
name: "returns error for unknown model",
setupReg: func() *Registry {
reg, _ := NewRegistry(
map[string]config.ProviderEntry{
@@ -490,11 +490,34 @@ func TestRegistry_Default(t *testing.T) {
)
return reg
},
modelName: "unknown-model",
validate: func(t *testing.T, p Provider) {
assert.NotNil(t, p)
// Should return first available provider
modelName: "unknown-model",
expectError: true,
errorMsg: "not configured",
},
{
name: "returns error for model whose provider is unavailable",
setupReg: func() *Registry {
reg, _ := NewRegistry(
map[string]config.ProviderEntry{
"openai": {
Type: "openai",
APIKey: "", // unavailable provider
},
"google": {
Type: "google",
APIKey: "test-key",
},
},
[]config.ModelEntry{
{Name: "gpt-4", Provider: "openai"},
{Name: "gemini-pro", Provider: "google"},
},
)
return reg
},
modelName: "gpt-4",
expectError: true,
errorMsg: "not available",
},
{
name: "returns first provider for empty model name",
@@ -542,6 +565,31 @@ func TestRegistry_Default(t *testing.T) {
}
}
func TestRegistry_Models_FiltersUnavailableProviders(t *testing.T) {
reg, err := NewRegistry(
map[string]config.ProviderEntry{
"openai": {
Type: "openai",
APIKey: "", // unavailable provider
},
"google": {
Type: "google",
APIKey: "test-key",
},
},
[]config.ModelEntry{
{Name: "gpt-4", Provider: "openai"},
{Name: "gemini-pro", Provider: "google"},
},
)
require.NoError(t, err)
models := reg.Models()
require.Len(t, models, 1)
assert.Equal(t, "gemini-pro", models[0].Model)
assert.Equal(t, "google", models[0].Provider)
}
func TestBuildProvider(t *testing.T) {
tests := []struct {
name string

View File

@@ -239,17 +239,17 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
}
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
responseID := generateID("resp_")
itemID := generateID("msg_")
seq := 0

View File

@@ -0,0 +1,53 @@
package server
import (
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
type nonFlusherRecorder struct {
recorder *httptest.ResponseRecorder
writeHeaderCalls int
}
func newNonFlusherRecorder() *nonFlusherRecorder {
return &nonFlusherRecorder{recorder: httptest.NewRecorder()}
}
func (w *nonFlusherRecorder) Header() http.Header {
return w.recorder.Header()
}
func (w *nonFlusherRecorder) Write(b []byte) (int, error) {
return w.recorder.Write(b)
}
func (w *nonFlusherRecorder) WriteHeader(statusCode int) {
w.writeHeaderCalls++
w.recorder.WriteHeader(statusCode)
}
func (w *nonFlusherRecorder) StatusCode() int {
return w.recorder.Code
}
func (w *nonFlusherRecorder) BodyString() string {
return w.recorder.Body.String()
}
func TestHandleStreamingResponseWithoutFlusherWritesSingleErrorHeader(t *testing.T) {
s := New(nil, nil, slog.New(slog.NewTextHandler(io.Discard, nil)))
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
w := newNonFlusherRecorder()
s.handleStreamingResponse(w, req, nil, nil, nil, nil, nil)
assert.Equal(t, 1, w.writeHeaderCalls)
assert.Equal(t, http.StatusInternalServerError, w.StatusCode())
assert.Contains(t, w.BodyString(), "streaming not supported")
}