Add fail-fast on init for missing provider credentials

This commit is contained in:
2026-03-06 21:31:51 +00:00
parent 610b6c3367
commit 89c7e3ac85
14 changed files with 398 additions and 20 deletions

View File

@@ -155,6 +155,11 @@ func main() {
// Register admin endpoints if enabled
if cfg.Admin.Enabled {
// Check if frontend dist exists
if _, err := os.Stat("internal/admin/dist"); os.IsNotExist(err) {
log.Fatalf("admin UI enabled but frontend dist not found")
}
buildInfo := admin.BuildInfo{
Version: "dev",
BuildTime: time.Now().Format(time.RFC3339),
@@ -348,23 +353,39 @@ func initConversationStore(cfg config.ConversationConfig, logger *slog.Logger) (
return conversation.NewMemoryStore(ttl), "memory", nil
}
}
// responseWriter wraps http.ResponseWriter to record the status code and
// body size of a response so the logging middleware can report them.
type responseWriter struct {
	http.ResponseWriter
	statusCode   int  // status sent via WriteHeader, or implicit 200 on first Write
	bytesWritten int  // cumulative body bytes successfully written
	wroteHeader  bool // guards against forwarding WriteHeader more than once
}
// WriteHeader records the status code and forwards it to the wrapped
// writer; any call after the first is ignored so the header is sent once.
func (rw *responseWriter) WriteHeader(code int) {
	if !rw.wroteHeader {
		rw.wroteHeader = true
		rw.statusCode = code
		rw.ResponseWriter.WriteHeader(code)
	}
}
// Write forwards body bytes to the wrapped writer, recording an implicit
// 200 OK when no explicit header was sent, and tallies the bytes written.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.wroteHeader = true
		rw.statusCode = http.StatusOK
	}
	written, writeErr := rw.ResponseWriter.Write(b)
	rw.bytesWritten += written
	return written, writeErr
}
// Flush passes the flush through to the underlying writer when it supports
// streaming; otherwise it is a no-op.
func (rw *responseWriter) Flush() {
	f, ok := rw.ResponseWriter.(http.Flusher)
	if ok {
		f.Flush()
	}
}
func loggingMiddleware(next http.Handler, logger *slog.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()

57
cmd/gateway/main_test.go Normal file
View File

@@ -0,0 +1,57 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
// Compile-time check: responseWriter must implement http.Flusher so SSE
// streaming works through the logging middleware.
var _ http.Flusher = (*responseWriter)(nil)
// countingFlusherRecorder is a ResponseRecorder that also implements
// http.Flusher, counting how many times Flush is invoked.
type countingFlusherRecorder struct {
	*httptest.ResponseRecorder
	flushCount int // number of Flush calls observed
}
// newCountingFlusherRecorder builds a flush-counting recorder around a
// fresh httptest.ResponseRecorder.
func newCountingFlusherRecorder() *countingFlusherRecorder {
	rec := httptest.NewRecorder()
	return &countingFlusherRecorder{ResponseRecorder: rec}
}
// Flush only records that a flush was requested; no I/O is performed.
func (r *countingFlusherRecorder) Flush() {
	r.flushCount = r.flushCount + 1
}
// TestResponseWriterWriteHeaderOnlyOnce verifies that only the first
// WriteHeader call takes effect and later calls are silently dropped.
func TestResponseWriterWriteHeaderOnlyOnce(t *testing.T) {
	recorder := httptest.NewRecorder()
	wrapped := &responseWriter{ResponseWriter: recorder, statusCode: http.StatusOK}
	wrapped.WriteHeader(http.StatusCreated)
	wrapped.WriteHeader(http.StatusInternalServerError)
	assert.Equal(t, http.StatusCreated, recorder.Code)
	assert.Equal(t, http.StatusCreated, wrapped.statusCode)
}
// TestResponseWriterWriteSetsImplicitStatus verifies that a Write without a
// prior WriteHeader records an implicit 200 OK and counts the body bytes.
func TestResponseWriterWriteSetsImplicitStatus(t *testing.T) {
	recorder := httptest.NewRecorder()
	wrapped := &responseWriter{ResponseWriter: recorder, statusCode: http.StatusOK}
	written, err := wrapped.Write([]byte("ok"))
	assert.NoError(t, err)
	assert.Equal(t, 2, written)
	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, http.StatusOK, wrapped.statusCode)
	assert.Equal(t, 2, wrapped.bytesWritten)
}
// TestResponseWriterFlushDelegates verifies Flush reaches the wrapped writer.
func TestResponseWriterFlushDelegates(t *testing.T) {
	recorder := newCountingFlusherRecorder()
	wrapped := &responseWriter{ResponseWriter: recorder, statusCode: http.StatusOK}
	wrapped.Flush()
	assert.Equal(t, 1, recorder.flushCount)
}

View File

@@ -6,6 +6,7 @@ export default defineConfig({
base: '/admin/',
server: {
port: 5173,
allowedHosts: ['.coder.ia-innovacion.work', 'localhost'],
proxy: {
'/admin/api': {
target: 'http://localhost:8080',

View File

@@ -172,9 +172,32 @@ func Load(path string) (*Config, error) {
func (cfg *Config) validate() error {
for _, m := range cfg.Models {
if _, ok := cfg.Providers[m.Provider]; !ok {
providerEntry, ok := cfg.Providers[m.Provider]
if !ok {
return fmt.Errorf("model %q references unknown provider %q", m.Name, m.Provider)
}
switch providerEntry.Type {
case "openai", "anthropic", "google", "azureopenai", "azureanthropic":
if providerEntry.APIKey == "" {
return fmt.Errorf("model %q references provider %q (%s) without api_key", m.Name, m.Provider, providerEntry.Type)
}
}
switch providerEntry.Type {
case "azureopenai", "azureanthropic":
if providerEntry.Endpoint == "" {
return fmt.Errorf("model %q references provider %q (%s) without endpoint", m.Name, m.Provider, providerEntry.Type)
}
case "vertexai":
if providerEntry.Project == "" || providerEntry.Location == "" {
return fmt.Errorf("model %q references provider %q (vertexai) without project/location", m.Name, m.Provider)
}
case "openai", "anthropic", "google":
// No additional required fields.
default:
return fmt.Errorf("model %q references provider %q with unknown type %q", m.Name, m.Provider, providerEntry.Type)
}
}
return nil
}

View File

@@ -103,7 +103,7 @@ server:
address: ":8080"
providers:
azure:
type: azure_openai
type: azureopenai
api_key: azure-key
endpoint: https://my-resource.openai.azure.com
api_version: "2024-02-15-preview"
@@ -113,7 +113,7 @@ models:
provider_model_id: gpt-4-deployment
`,
validate: func(t *testing.T, cfg *Config) {
assert.Equal(t, "azure_openai", cfg.Providers["azure"].Type)
assert.Equal(t, "azureopenai", cfg.Providers["azure"].Type)
assert.Equal(t, "azure-key", cfg.Providers["azure"].APIKey)
assert.Equal(t, "https://my-resource.openai.azure.com", cfg.Providers["azure"].Endpoint)
assert.Equal(t, "2024-02-15-preview", cfg.Providers["azure"].APIVersion)
@@ -126,7 +126,7 @@ server:
address: ":8080"
providers:
vertex:
type: vertex_ai
type: vertexai
project: my-gcp-project
location: us-central1
models:
@@ -135,7 +135,7 @@ models:
provider_model_id: gemini-1.5-pro
`,
validate: func(t *testing.T, cfg *Config) {
assert.Equal(t, "vertex_ai", cfg.Providers["vertex"].Type)
assert.Equal(t, "vertexai", cfg.Providers["vertex"].Type)
assert.Equal(t, "my-gcp-project", cfg.Providers["vertex"].Project)
assert.Equal(t, "us-central1", cfg.Providers["vertex"].Location)
},
@@ -208,6 +208,20 @@ models:
configYAML: `invalid: yaml: content: [unclosed`,
expectError: true,
},
{
name: "model references provider without required API key",
configYAML: `
server:
address: ":8080"
providers:
openai:
type: openai
models:
- name: gpt-4
provider: openai
`,
expectError: true,
},
{
name: "multiple models same provider",
configYAML: `
@@ -283,7 +297,7 @@ func TestConfigValidate(t *testing.T) {
name: "valid config",
config: Config{
Providers: map[string]ProviderEntry{
"openai": {Type: "openai"},
"openai": {Type: "openai", APIKey: "test-key"},
},
Models: []ModelEntry{
{Name: "gpt-4", Provider: "openai"},
@@ -295,7 +309,7 @@ func TestConfigValidate(t *testing.T) {
name: "model references unknown provider",
config: Config{
Providers: map[string]ProviderEntry{
"openai": {Type: "openai"},
"openai": {Type: "openai", APIKey: "test-key"},
},
Models: []ModelEntry{
{Name: "gpt-4", Provider: "unknown"},
@@ -303,6 +317,18 @@ func TestConfigValidate(t *testing.T) {
},
expectError: true,
},
{
name: "model references provider without api key",
config: Config{
Providers: map[string]ProviderEntry{
"openai": {Type: "openai"},
},
Models: []ModelEntry{
{Name: "gpt-4", Provider: "openai"},
},
},
expectError: true,
},
{
name: "no models",
config: Config{
@@ -317,8 +343,8 @@ func TestConfigValidate(t *testing.T) {
name: "multiple models multiple providers",
config: Config{
Providers: map[string]ProviderEntry{
"openai": {Type: "openai"},
"anthropic": {Type: "anthropic"},
"openai": {Type: "openai", APIKey: "test-key"},
"anthropic": {Type: "anthropic", APIKey: "ant-key"},
},
Models: []ModelEntry{
{Name: "gpt-4", Provider: "openai"},

View File

@@ -48,15 +48,30 @@ type metricsResponseWriter struct {
http.ResponseWriter
statusCode int
bytesWritten int
wroteHeader bool
}
// WriteHeader captures the first status code and forwards it; repeated
// calls are dropped so the header is written at most once.
func (w *metricsResponseWriter) WriteHeader(statusCode int) {
	if !w.wroteHeader {
		w.wroteHeader = true
		w.statusCode = statusCode
		w.ResponseWriter.WriteHeader(statusCode)
	}
}
// Write records an implicit 200 OK on the first write if no explicit
// header was sent, then delegates and accumulates the bytes written.
func (w *metricsResponseWriter) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.wroteHeader = true
		w.statusCode = http.StatusOK
	}
	written, writeErr := w.ResponseWriter.Write(b)
	w.bytesWritten += written
	return written, writeErr
}
// Flush delegates to the wrapped writer when it supports streaming.
func (w *metricsResponseWriter) Flush() {
	f, ok := w.ResponseWriter.(http.Flusher)
	if ok {
		f.Flush()
	}
}

View File

@@ -0,0 +1,65 @@
package observability
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
// Compile-time checks: both middleware wrappers must implement
// http.Flusher so streaming responses can be flushed through them.
var _ http.Flusher = (*metricsResponseWriter)(nil)
var _ http.Flusher = (*statusResponseWriter)(nil)
// testFlusherRecorder is a ResponseRecorder that also implements
// http.Flusher, counting how many times Flush is invoked.
type testFlusherRecorder struct {
	*httptest.ResponseRecorder
	flushCount int // number of Flush calls observed
}
// newTestFlusherRecorder builds a flush-counting recorder around a fresh
// httptest.ResponseRecorder.
func newTestFlusherRecorder() *testFlusherRecorder {
	rec := httptest.NewRecorder()
	return &testFlusherRecorder{ResponseRecorder: rec}
}
// Flush only records that a flush was requested; no I/O is performed.
func (r *testFlusherRecorder) Flush() {
	r.flushCount = r.flushCount + 1
}
// TestMetricsResponseWriterWriteHeaderOnlyOnce verifies that only the
// first WriteHeader call takes effect on the metrics wrapper.
func TestMetricsResponseWriterWriteHeaderOnlyOnce(t *testing.T) {
	recorder := httptest.NewRecorder()
	wrapped := &metricsResponseWriter{ResponseWriter: recorder, statusCode: http.StatusOK}
	wrapped.WriteHeader(http.StatusAccepted)
	wrapped.WriteHeader(http.StatusInternalServerError)
	assert.Equal(t, http.StatusAccepted, recorder.Code)
	assert.Equal(t, http.StatusAccepted, wrapped.statusCode)
}
// TestMetricsResponseWriterFlushDelegates verifies Flush reaches the
// wrapped flusher exactly once.
func TestMetricsResponseWriterFlushDelegates(t *testing.T) {
	recorder := newTestFlusherRecorder()
	wrapped := &metricsResponseWriter{ResponseWriter: recorder, statusCode: http.StatusOK}
	wrapped.Flush()
	assert.Equal(t, 1, recorder.flushCount)
}
// TestStatusResponseWriterWriteHeaderOnlyOnce verifies that only the
// first WriteHeader call takes effect on the tracing wrapper.
func TestStatusResponseWriterWriteHeaderOnlyOnce(t *testing.T) {
	recorder := httptest.NewRecorder()
	wrapped := &statusResponseWriter{ResponseWriter: recorder, statusCode: http.StatusOK}
	wrapped.WriteHeader(http.StatusNoContent)
	wrapped.WriteHeader(http.StatusInternalServerError)
	assert.Equal(t, http.StatusNoContent, recorder.Code)
	assert.Equal(t, http.StatusNoContent, wrapped.statusCode)
}
// TestStatusResponseWriterFlushDelegates verifies Flush reaches the
// wrapped flusher exactly once.
func TestStatusResponseWriterFlushDelegates(t *testing.T) {
	recorder := newTestFlusherRecorder()
	wrapped := &statusResponseWriter{ResponseWriter: recorder, statusCode: http.StatusOK}
	wrapped.Flush()
	assert.Equal(t, 1, recorder.flushCount)
}

View File

@@ -72,14 +72,29 @@ func TracingMiddleware(next http.Handler, tp *sdktrace.TracerProvider) http.Hand
// statusResponseWriter wraps http.ResponseWriter to capture the status
// code for the tracing middleware. The diff residue that left both the
// pre- and post-change `statusCode int` declarations is collapsed to one
// field, since a struct cannot declare the same field twice.
type statusResponseWriter struct {
	http.ResponseWriter
	statusCode  int  // captured status (implicit 200 if only Write is called)
	wroteHeader bool // ensures the header is forwarded at most once
}
// WriteHeader captures the first status code and forwards it; repeated
// calls are dropped so the header is written at most once.
func (w *statusResponseWriter) WriteHeader(statusCode int) {
	if !w.wroteHeader {
		w.wroteHeader = true
		w.statusCode = statusCode
		w.ResponseWriter.WriteHeader(statusCode)
	}
}
// Write records an implicit 200 OK when no explicit header was sent,
// then delegates the body bytes to the wrapped writer.
func (w *statusResponseWriter) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.wroteHeader = true
		w.statusCode = http.StatusOK
	}
	written, writeErr := w.ResponseWriter.Write(b)
	return written, writeErr
}
// Flush delegates to the wrapped writer when it supports streaming.
func (w *statusResponseWriter) Flush() {
	f, ok := w.ResponseWriter.(http.Flusher)
	if ok {
		f.Flush()
	}
}

View File

@@ -136,6 +136,9 @@ func (r *Registry) Get(name string) (Provider, bool) {
func (r *Registry) Models() []struct{ Provider, Model string } {
var out []struct{ Provider, Model string }
for _, m := range r.modelList {
if _, ok := r.providers[m.Provider]; !ok {
continue
}
out = append(out, struct{ Provider, Model string }{Provider: m.Provider, Model: m.Name})
}
return out
@@ -156,7 +159,9 @@ func (r *Registry) Default(model string) (Provider, error) {
if p, ok := r.providers[providerName]; ok {
return p, nil
}
return nil, fmt.Errorf("model %q is mapped to provider %q, but that provider is not available", model, providerName)
}
return nil, fmt.Errorf("model %q not configured", model)
}
for _, p := range r.providers {

View File

@@ -475,7 +475,7 @@ func TestRegistry_Default(t *testing.T) {
},
},
{
name: "returns first provider for unknown model",
name: "returns error for unknown model",
setupReg: func() *Registry {
reg, _ := NewRegistry(
map[string]config.ProviderEntry{
@@ -490,11 +490,34 @@ func TestRegistry_Default(t *testing.T) {
)
return reg
},
modelName: "unknown-model",
validate: func(t *testing.T, p Provider) {
assert.NotNil(t, p)
// Should return first available provider
modelName: "unknown-model",
expectError: true,
errorMsg: "not configured",
},
{
name: "returns error for model whose provider is unavailable",
setupReg: func() *Registry {
reg, _ := NewRegistry(
map[string]config.ProviderEntry{
"openai": {
Type: "openai",
APIKey: "", // unavailable provider
},
"google": {
Type: "google",
APIKey: "test-key",
},
},
[]config.ModelEntry{
{Name: "gpt-4", Provider: "openai"},
{Name: "gemini-pro", Provider: "google"},
},
)
return reg
},
modelName: "gpt-4",
expectError: true,
errorMsg: "not available",
},
{
name: "returns first provider for empty model name",
@@ -542,6 +565,31 @@ func TestRegistry_Default(t *testing.T) {
}
}
func TestRegistry_Models_FiltersUnavailableProviders(t *testing.T) {
reg, err := NewRegistry(
map[string]config.ProviderEntry{
"openai": {
Type: "openai",
APIKey: "", // unavailable provider
},
"google": {
Type: "google",
APIKey: "test-key",
},
},
[]config.ModelEntry{
{Name: "gpt-4", Provider: "openai"},
{Name: "gemini-pro", Provider: "google"},
},
)
require.NoError(t, err)
models := reg.Models()
require.Len(t, models, 1)
assert.Equal(t, "gemini-pro", models[0].Model)
assert.Equal(t, "google", models[0].Provider)
}
func TestBuildProvider(t *testing.T) {
tests := []struct {
name string

View File

@@ -239,17 +239,17 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
}
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
responseID := generateID("resp_")
itemID := generateID("msg_")
seq := 0

View File

@@ -0,0 +1,53 @@
package server
import (
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
// nonFlusherRecorder wraps a ResponseRecorder while deliberately NOT
// implementing http.Flusher, and counts WriteHeader invocations so tests
// can assert the header was written exactly once.
type nonFlusherRecorder struct {
	recorder         *httptest.ResponseRecorder
	writeHeaderCalls int // number of WriteHeader invocations observed
}
// newNonFlusherRecorder builds a recorder backed by a fresh ResponseRecorder.
func newNonFlusherRecorder() *nonFlusherRecorder {
	rec := httptest.NewRecorder()
	return &nonFlusherRecorder{recorder: rec}
}
// Header exposes the underlying recorder's header map.
func (w *nonFlusherRecorder) Header() http.Header {
	h := w.recorder.Header()
	return h
}
// Write forwards body bytes to the underlying recorder.
func (w *nonFlusherRecorder) Write(b []byte) (int, error) {
	n, err := w.recorder.Write(b)
	return n, err
}
// WriteHeader counts the call before delegating to the recorder, so the
// test can detect duplicate header writes.
func (w *nonFlusherRecorder) WriteHeader(statusCode int) {
	w.writeHeaderCalls = w.writeHeaderCalls + 1
	w.recorder.WriteHeader(statusCode)
}
// StatusCode reports the status recorded by the underlying recorder.
func (w *nonFlusherRecorder) StatusCode() int {
	code := w.recorder.Code
	return code
}
// BodyString returns the accumulated response body as a string.
func (w *nonFlusherRecorder) BodyString() string {
	body := w.recorder.Body
	return body.String()
}
// TestHandleStreamingResponseWithoutFlusherWritesSingleErrorHeader checks
// that a writer without http.Flusher yields exactly one 500 header and the
// streaming-unsupported error message.
func TestHandleStreamingResponseWithoutFlusherWritesSingleErrorHeader(t *testing.T) {
	srv := New(nil, nil, slog.New(slog.NewTextHandler(io.Discard, nil)))
	request := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	writer := newNonFlusherRecorder()
	srv.handleStreamingResponse(writer, request, nil, nil, nil, nil, nil)
	assert.Equal(t, 1, writer.writeHeaderCalls)
	assert.Equal(t, http.StatusInternalServerError, writer.StatusCode())
	assert.Contains(t, writer.BodyString(), "streaming not supported")
}

Binary file not shown.

View File

@@ -135,6 +135,41 @@ class ChatClient:
return self._stream_response(model)
else:
return self._sync_response(model)
@staticmethod
def _get_attr(obj: Any, key: str, default: Any = None) -> Any:
"""Access object attributes safely for both SDK objects and dicts."""
if obj is None:
return default
if isinstance(obj, dict):
return obj.get(key, default)
return getattr(obj, key, default)
def _extract_stream_error(self, event: Any) -> str:
    """Pull a human-readable message out of a ``response.failed`` event.

    Falls back to a generic message when no error text is present.
    """
    error_obj = self._get_attr(self._get_attr(event, "response"), "error")
    message = self._get_attr(error_obj, "message")
    return str(message) if message else "streaming request failed"
def _extract_completed_text(self, event: Any) -> str:
    """Collect assistant output text from a ``response.completed`` event.

    Walks the response's output items, keeping only ``message`` items, and
    concatenates every non-empty ``output_text`` part in order.
    """
    response = self._get_attr(event, "response")
    items = self._get_attr(response, "output", []) or []
    pieces = []
    for item in items:
        if self._get_attr(item, "type") != "message":
            continue
        for part in self._get_attr(item, "content", []) or []:
            if self._get_attr(part, "type") != "output_text":
                continue
            text = self._get_attr(part, "text", "")
            if text:
                pieces.append(str(text))
    return "".join(pieces)
def _sync_response(self, model: str) -> str:
"""Non-streaming response with tool support."""
@@ -225,6 +260,7 @@ class ChatClient:
while iteration < max_iterations:
iteration += 1
assistant_text = ""
stream_error = None
tool_calls = {} # Dict to track tool calls by item_id
tool_calls_list = [] # Final list of completed tool calls
assistant_content = []
@@ -244,6 +280,15 @@ class ChatClient:
if event.type == "response.output_text.delta":
assistant_text += event.delta
live.update(Markdown(assistant_text))
elif event.type == "response.completed":
# Some providers may emit final text only in response.completed.
if not assistant_text:
completed_text = self._extract_completed_text(event)
if completed_text:
assistant_text = completed_text
live.update(Markdown(assistant_text))
elif event.type == "response.failed":
stream_error = self._extract_stream_error(event)
elif event.type == "response.output_item.added":
if hasattr(event, 'item') and event.item.type == "function_call":
# Start tracking a new tool call
@@ -270,6 +315,10 @@ class ChatClient:
except json.JSONDecodeError:
self.console.print(f"[red]Error parsing tool arguments JSON[/red]")
if stream_error:
self.console.print(f"[bold red]Error:[/bold red] {stream_error}")
return ""
# Build assistant content
if assistant_text:
assistant_content.append({"type": "output_text", "text": assistant_text})