diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 94d0fef..2bc134f 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -155,6 +155,11 @@ func main() { // Register admin endpoints if enabled if cfg.Admin.Enabled { + // Check if frontend dist exists + if _, err := os.Stat("internal/admin/dist"); os.IsNotExist(err) { + log.Fatalf("admin UI enabled but frontend dist not found") + } + buildInfo := admin.BuildInfo{ Version: "dev", BuildTime: time.Now().Format(time.RFC3339), @@ -348,23 +353,39 @@ func initConversationStore(cfg config.ConversationConfig, logger *slog.Logger) ( return conversation.NewMemoryStore(ttl), "memory", nil } } + type responseWriter struct { http.ResponseWriter statusCode int bytesWritten int + wroteHeader bool } func (rw *responseWriter) WriteHeader(code int) { + if rw.wroteHeader { + return + } + rw.wroteHeader = true rw.statusCode = code rw.ResponseWriter.WriteHeader(code) } func (rw *responseWriter) Write(b []byte) (int, error) { + if !rw.wroteHeader { + rw.wroteHeader = true + rw.statusCode = http.StatusOK + } n, err := rw.ResponseWriter.Write(b) rw.bytesWritten += n return n, err } +func (rw *responseWriter) Flush() { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + func loggingMiddleware(next http.Handler, logger *slog.Logger) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/cmd/gateway/main_test.go b/cmd/gateway/main_test.go new file mode 100644 index 0000000..c08cf50 --- /dev/null +++ b/cmd/gateway/main_test.go @@ -0,0 +1,57 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +var _ http.Flusher = (*responseWriter)(nil) + +type countingFlusherRecorder struct { + *httptest.ResponseRecorder + flushCount int +} + +func newCountingFlusherRecorder() *countingFlusherRecorder { + return &countingFlusherRecorder{ResponseRecorder: httptest.NewRecorder()} +} + +func (r *countingFlusherRecorder) Flush() { + r.flushCount++ +} + +func TestResponseWriterWriteHeaderOnlyOnce(t *testing.T) { + rec := httptest.NewRecorder() + rw := &responseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.WriteHeader(http.StatusCreated) + rw.WriteHeader(http.StatusInternalServerError) + + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, http.StatusCreated, rw.statusCode) +} + +func TestResponseWriterWriteSetsImplicitStatus(t *testing.T) { + rec := httptest.NewRecorder() + rw := &responseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + n, err := rw.Write([]byte("ok")) + + assert.NoError(t, err) + assert.Equal(t, 2, n) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, http.StatusOK, rw.statusCode) + assert.Equal(t, 2, rw.bytesWritten) +} + +func TestResponseWriterFlushDelegates(t *testing.T) { + rec := newCountingFlusherRecorder() + rw := &responseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.Flush() + + assert.Equal(t, 1, rec.flushCount) +} diff --git a/frontend/admin/vite.config.ts b/frontend/admin/vite.config.ts index 4c37cb7..c5182bd 100644 --- a/frontend/admin/vite.config.ts +++ b/frontend/admin/vite.config.ts @@ -6,6 +6,7 @@ export default defineConfig({ base: '/admin/', server: { port: 5173, + allowedHosts: ['.coder.ia-innovacion.work', 'localhost'], proxy: { '/admin/api': { target: 'http://localhost:8080', diff --git a/internal/config/config.go b/internal/config/config.go index d32c46e..89d6334 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -172,9 +172,32 @@ func Load(path string) (*Config, error) { func (cfg *Config) validate() error { for _, m := range cfg.Models { - if _, ok := cfg.Providers[m.Provider]; !ok { + providerEntry, ok := cfg.Providers[m.Provider] + if !ok { return fmt.Errorf("model %q references unknown provider %q", m.Name, m.Provider) } + + switch providerEntry.Type { + case "openai", "anthropic", "google", "azureopenai", "azureanthropic": + if providerEntry.APIKey == "" { + return fmt.Errorf("model %q references provider %q (%s) without api_key", m.Name, m.Provider, providerEntry.Type) + } + } + + switch providerEntry.Type { + case "azureopenai", "azureanthropic": + if providerEntry.Endpoint == "" { + return fmt.Errorf("model %q references provider %q (%s) without endpoint", m.Name, m.Provider, providerEntry.Type) + } + case "vertexai": + if providerEntry.Project == "" || providerEntry.Location == "" { + return fmt.Errorf("model %q references provider %q (vertexai) without project/location", m.Name, m.Provider) + } + case "openai", "anthropic", "google": + // No additional required fields. + default: + return fmt.Errorf("model %q references provider %q with unknown type %q", m.Name, m.Provider, providerEntry.Type) + } } return nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 867b4b2..2615f29 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -103,7 +103,7 @@ server: address: ":8080" providers: azure: - type: azure_openai + type: azureopenai api_key: azure-key endpoint: https://my-resource.openai.azure.com api_version: "2024-02-15-preview" @@ -113,7 +113,7 @@ models: provider_model_id: gpt-4-deployment `, validate: func(t *testing.T, cfg *Config) { - assert.Equal(t, "azure_openai", cfg.Providers["azure"].Type) + assert.Equal(t, "azureopenai", cfg.Providers["azure"].Type) assert.Equal(t, "azure-key", cfg.Providers["azure"].APIKey) assert.Equal(t, "https://my-resource.openai.azure.com", cfg.Providers["azure"].Endpoint) assert.Equal(t, "2024-02-15-preview", cfg.Providers["azure"].APIVersion) @@ -126,7 +126,7 @@ server: address: ":8080" providers: vertex: - type: vertex_ai + type: vertexai project: my-gcp-project location: us-central1 models: @@ -135,7 +135,7 @@ models: provider_model_id: gemini-1.5-pro `, validate: func(t *testing.T, cfg *Config) { - assert.Equal(t, "vertex_ai", cfg.Providers["vertex"].Type) + assert.Equal(t, "vertexai", cfg.Providers["vertex"].Type) assert.Equal(t, "my-gcp-project", cfg.Providers["vertex"].Project) assert.Equal(t, "us-central1", cfg.Providers["vertex"].Location) }, @@ -208,6 +208,20 @@ models: configYAML: `invalid: yaml: content: [unclosed`, expectError: true, }, + { + name: "model references provider without required API key", + configYAML: ` +server: + address: ":8080" +providers: + openai: + type: openai +models: + - name: gpt-4 + provider: openai +`, + expectError: true, + }, { name: "multiple models same provider", configYAML: ` @@ -283,7 +297,7 @@ func TestConfigValidate(t *testing.T) { name: "valid config", config: Config{ Providers: map[string]ProviderEntry{ - "openai": {Type: "openai"}, + "openai": {Type: "openai", APIKey: "test-key"}, }, Models: []ModelEntry{ {Name: "gpt-4", Provider: "openai"}, @@ -295,7 +309,7 @@ func TestConfigValidate(t *testing.T) { name: "model references unknown provider", config: Config{ Providers: map[string]ProviderEntry{ - "openai": {Type: "openai"}, + "openai": {Type: "openai", APIKey: "test-key"}, }, Models: []ModelEntry{ {Name: "gpt-4", Provider: "unknown"}, @@ -303,6 +317,18 @@ func TestConfigValidate(t *testing.T) { }, expectError: true, }, + { + name: "model references provider without api key", + config: Config{ + Providers: map[string]ProviderEntry{ + "openai": {Type: "openai"}, + }, + Models: []ModelEntry{ + {Name: "gpt-4", Provider: "openai"}, + }, + }, + expectError: true, + }, { name: "no models", config: Config{ @@ -317,8 +343,8 @@ func TestConfigValidate(t *testing.T) { name: "multiple models multiple providers", config: Config{ Providers: map[string]ProviderEntry{ - "openai": {Type: "openai"}, - "anthropic": {Type: "anthropic"}, + "openai": {Type: "openai", APIKey: "test-key"}, + "anthropic": {Type: "anthropic", APIKey: "ant-key"}, }, Models: []ModelEntry{ {Name: "gpt-4", Provider: "openai"}, diff --git a/internal/observability/metrics_middleware.go b/internal/observability/metrics_middleware.go index 8537935..fdb98f4 100644 --- a/internal/observability/metrics_middleware.go +++ b/internal/observability/metrics_middleware.go @@ -48,15 +48,30 @@ type metricsResponseWriter struct { http.ResponseWriter statusCode int bytesWritten int + wroteHeader bool } func (w *metricsResponseWriter) WriteHeader(statusCode int) { + if w.wroteHeader { + return + } + w.wroteHeader = true w.statusCode = statusCode w.ResponseWriter.WriteHeader(statusCode) } func (w *metricsResponseWriter) Write(b []byte) (int, error) { + if !w.wroteHeader { + w.wroteHeader = true + w.statusCode = http.StatusOK + } n, err := w.ResponseWriter.Write(b) w.bytesWritten += n return n, err } + +func (w *metricsResponseWriter) Flush() { + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/internal/observability/middleware_response_writer_test.go b/internal/observability/middleware_response_writer_test.go new file mode 100644 index 0000000..14d0cb3 --- /dev/null +++ b/internal/observability/middleware_response_writer_test.go @@ -0,0 +1,65 @@ +package observability + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +var _ http.Flusher = (*metricsResponseWriter)(nil) +var _ http.Flusher = (*statusResponseWriter)(nil) + +type testFlusherRecorder struct { + *httptest.ResponseRecorder + flushCount int +} + +func newTestFlusherRecorder() *testFlusherRecorder { + return &testFlusherRecorder{ResponseRecorder: httptest.NewRecorder()} +} + +func (r *testFlusherRecorder) Flush() { + r.flushCount++ +} + +func TestMetricsResponseWriterWriteHeaderOnlyOnce(t *testing.T) { + rec := httptest.NewRecorder() + rw := &metricsResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.WriteHeader(http.StatusAccepted) + rw.WriteHeader(http.StatusInternalServerError) + + assert.Equal(t, http.StatusAccepted, rec.Code) + assert.Equal(t, http.StatusAccepted, rw.statusCode) +} + +func TestMetricsResponseWriterFlushDelegates(t *testing.T) { + rec := newTestFlusherRecorder() + rw := &metricsResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.Flush() + + assert.Equal(t, 1, rec.flushCount) +} + +func TestStatusResponseWriterWriteHeaderOnlyOnce(t *testing.T) { + rec := httptest.NewRecorder() + rw := &statusResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.WriteHeader(http.StatusNoContent) + rw.WriteHeader(http.StatusInternalServerError) + + assert.Equal(t, http.StatusNoContent, rec.Code) + assert.Equal(t, http.StatusNoContent, rw.statusCode) +} + +func TestStatusResponseWriterFlushDelegates(t *testing.T) { + rec := newTestFlusherRecorder() + rw := &statusResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.Flush() + + assert.Equal(t, 1, rec.flushCount) +} diff --git a/internal/observability/tracing_middleware.go b/internal/observability/tracing_middleware.go index c1b426e..9feae16 100644 --- a/internal/observability/tracing_middleware.go +++ b/internal/observability/tracing_middleware.go @@ -72,14 +72,29 @@ func TracingMiddleware(next http.Handler, tp *sdktrace.TracerProvider) http.Hand // statusResponseWriter wraps http.ResponseWriter to capture the status code. type statusResponseWriter struct { http.ResponseWriter - statusCode int + statusCode int + wroteHeader bool } func (w *statusResponseWriter) WriteHeader(statusCode int) { + if w.wroteHeader { + return + } + w.wroteHeader = true w.statusCode = statusCode w.ResponseWriter.WriteHeader(statusCode) } func (w *statusResponseWriter) Write(b []byte) (int, error) { + if !w.wroteHeader { + w.wroteHeader = true + w.statusCode = http.StatusOK + } return w.ResponseWriter.Write(b) } + +func (w *statusResponseWriter) Flush() { + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/internal/providers/providers.go b/internal/providers/providers.go index bd807bc..639fcda 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -136,6 +136,9 @@ func (r *Registry) Get(name string) (Provider, bool) { func (r *Registry) Models() []struct{ Provider, Model string } { var out []struct{ Provider, Model string } for _, m := range r.modelList { + if _, ok := r.providers[m.Provider]; !ok { + continue + } out = append(out, struct{ Provider, Model string }{Provider: m.Provider, Model: m.Name}) } return out @@ -156,7 +159,9 @@ func (r *Registry) Default(model string) (Provider, error) { if p, ok := r.providers[providerName]; ok { return p, nil } + return nil, fmt.Errorf("model %q is mapped to provider %q, but that provider is not available", model, providerName) } + return nil, fmt.Errorf("model %q not configured", model) } for _, p := range r.providers { diff --git a/internal/providers/providers_test.go b/internal/providers/providers_test.go index 49b8595..367b6f0 100644 --- a/internal/providers/providers_test.go +++ b/internal/providers/providers_test.go @@ -475,7 +475,7 @@ func TestRegistry_Default(t *testing.T) { }, }, { - name: "returns first provider for unknown model", + name: "returns error for unknown model", setupReg: func() *Registry { reg, _ := NewRegistry( map[string]config.ProviderEntry{ @@ -490,11 +490,34 @@ func TestRegistry_Default(t *testing.T) { ) return reg }, - modelName: "unknown-model", - validate: func(t *testing.T, p Provider) { - assert.NotNil(t, p) - // Should return first available provider + modelName: "unknown-model", + expectError: true, + errorMsg: "not configured", + }, + { + name: "returns error for model whose provider is unavailable", + setupReg: func() *Registry { + reg, _ := NewRegistry( + map[string]config.ProviderEntry{ + "openai": { + Type: "openai", + APIKey: "", // unavailable provider + }, + "google": { + Type: "google", + APIKey: "test-key", + }, + }, + []config.ModelEntry{ + {Name: "gpt-4", Provider: "openai"}, + {Name: "gemini-pro", Provider: "google"}, + }, + ) + return reg }, + modelName: "gpt-4", + expectError: true, + errorMsg: "not available", }, { name: "returns first provider for empty model name", @@ -542,6 +565,31 @@ func TestRegistry_Default(t *testing.T) { } } +func TestRegistry_Models_FiltersUnavailableProviders(t *testing.T) { + reg, err := NewRegistry( + map[string]config.ProviderEntry{ + "openai": { + Type: "openai", + APIKey: "", // unavailable provider + }, + "google": { + Type: "google", + APIKey: "test-key", + }, + }, + []config.ModelEntry{ + {Name: "gpt-4", Provider: "openai"}, + {Name: "gemini-pro", Provider: "google"}, + }, + ) + require.NoError(t, err) + + models := reg.Models() + require.Len(t, models, 1) + assert.Equal(t, "gemini-pro", models[0].Model) + assert.Equal(t, "google", models[0].Provider) +} + func TestBuildProvider(t *testing.T) { tests := []struct { name string diff --git a/internal/server/server.go b/internal/server/server.go index 0dcb490..5190944 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -239,17 +239,17 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques } func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusOK) - flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "streaming not supported", http.StatusInternalServerError) return } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + responseID := generateID("resp_") itemID := generateID("msg_") seq := 0 diff --git a/internal/server/streaming_writer_test.go b/internal/server/streaming_writer_test.go new file mode 100644 index 0000000..95dc3b2 --- /dev/null +++ b/internal/server/streaming_writer_test.go @@ -0,0 +1,53 @@ +package server + +import ( + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +type nonFlusherRecorder struct { + recorder *httptest.ResponseRecorder + writeHeaderCalls int +} + +func newNonFlusherRecorder() *nonFlusherRecorder { + return &nonFlusherRecorder{recorder: httptest.NewRecorder()} +} + +func (w *nonFlusherRecorder) Header() http.Header { + return w.recorder.Header() +} + +func (w *nonFlusherRecorder) Write(b []byte) (int, error) { + return w.recorder.Write(b) +} + +func (w *nonFlusherRecorder) WriteHeader(statusCode int) { + w.writeHeaderCalls++ + w.recorder.WriteHeader(statusCode) +} + +func (w *nonFlusherRecorder) StatusCode() int { + return w.recorder.Code +} + +func (w *nonFlusherRecorder) BodyString() string { + return w.recorder.Body.String() +} + +func TestHandleStreamingResponseWithoutFlusherWritesSingleErrorHeader(t *testing.T) { + s := New(nil, nil, slog.New(slog.NewTextHandler(io.Discard, nil))) + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + w := newNonFlusherRecorder() + + s.handleStreamingResponse(w, req, nil, nil, nil, nil, nil) + + assert.Equal(t, 1, w.writeHeaderCalls) + assert.Equal(t, http.StatusInternalServerError, w.StatusCode()) + assert.Contains(t, w.BodyString(), "streaming not supported") +} diff --git a/scripts/__pycache__/chat.cpython-312.pyc b/scripts/__pycache__/chat.cpython-312.pyc new file mode 100644 index 0000000..066ed18 Binary files /dev/null and b/scripts/__pycache__/chat.cpython-312.pyc differ diff --git a/scripts/chat.py b/scripts/chat.py index 545faeb..83cf362 100755 --- a/scripts/chat.py +++ b/scripts/chat.py @@ -135,6 +135,41 @@ class ChatClient: return self._stream_response(model) else: return self._sync_response(model) + + @staticmethod + def _get_attr(obj: Any, key: str, default: Any = None) -> Any: + """Access object attributes safely for both SDK objects and dicts.""" + if obj is None: + return default + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + def _extract_stream_error(self, event: Any) -> str: + """Extract error message from a response.failed event.""" + response = self._get_attr(event, "response") + error = self._get_attr(response, "error") + message = self._get_attr(error, "message") + if message: + return str(message) + return "streaming request failed" + + def _extract_completed_text(self, event: Any) -> str: + """Extract assistant output text from a response.completed event.""" + response = self._get_attr(event, "response") + output_items = self._get_attr(response, "output", []) or [] + + text_parts = [] + for item in output_items: + if self._get_attr(item, "type") != "message": + continue + for part in self._get_attr(item, "content", []) or []: + if self._get_attr(part, "type") == "output_text": + text = self._get_attr(part, "text", "") + if text: + text_parts.append(str(text)) + + return "".join(text_parts) def _sync_response(self, model: str) -> str: """Non-streaming response with tool support.""" @@ -225,6 +260,7 @@ class ChatClient: while iteration < max_iterations: iteration += 1 assistant_text = "" + stream_error = None tool_calls = {} # Dict to track tool calls by item_id tool_calls_list = [] # Final list of completed tool calls assistant_content = [] @@ -244,6 +280,15 @@ class ChatClient: if event.type == "response.output_text.delta": assistant_text += event.delta live.update(Markdown(assistant_text)) + elif event.type == "response.completed": + # Some providers may emit final text only in response.completed. + if not assistant_text: + completed_text = self._extract_completed_text(event) + if completed_text: + assistant_text = completed_text + live.update(Markdown(assistant_text)) + elif event.type == "response.failed": + stream_error = self._extract_stream_error(event) elif event.type == "response.output_item.added": if hasattr(event, 'item') and event.item.type == "function_call": # Start tracking a new tool call @@ -270,6 +315,10 @@ class ChatClient: except json.JSONDecodeError: self.console.print(f"[red]Error parsing tool arguments JSON[/red]") + if stream_error: + self.console.print(f"[bold red]Error:[/bold red] {stream_error}") + return "" + # Build assistant content if assistant_text: assistant_content.append({"type": "output_text", "text": assistant_text})