From 89c7e3ac85d09bebc2a4c0f46cc2a01675fe5260 Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Fri, 6 Mar 2026 21:31:51 +0000 Subject: [PATCH] Add fail-fast on init for missing provider credentials --- cmd/gateway/main.go | 21 ++++++ cmd/gateway/main_test.go | 57 +++++++++++++++ frontend/admin/vite.config.ts | 1 + internal/config/config.go | 25 ++++++- internal/config/config_test.go | 42 ++++++++--- internal/observability/metrics_middleware.go | 15 ++++ .../middleware_response_writer_test.go | 65 ++++++++++++++++++ internal/observability/tracing_middleware.go | 17 ++++- internal/providers/providers.go | 5 ++ internal/providers/providers_test.go | 58 ++++++++++++++-- internal/server/server.go | 10 +-- internal/server/streaming_writer_test.go | 53 ++++++++++++++ scripts/__pycache__/chat.cpython-312.pyc | Bin 0 -> 20974 bytes scripts/chat.py | 49 +++++++++++++ 14 files changed, 398 insertions(+), 20 deletions(-) create mode 100644 cmd/gateway/main_test.go create mode 100644 internal/observability/middleware_response_writer_test.go create mode 100644 internal/server/streaming_writer_test.go create mode 100644 scripts/__pycache__/chat.cpython-312.pyc diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 94d0fef..2bc134f 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -155,6 +155,11 @@ func main() { // Register admin endpoints if enabled if cfg.Admin.Enabled { + // Check if frontend dist exists + if _, err := os.Stat("internal/admin/dist"); os.IsNotExist(err) { + log.Fatalf("admin UI enabled but frontend dist not found") + } + buildInfo := admin.BuildInfo{ Version: "dev", BuildTime: time.Now().Format(time.RFC3339), @@ -348,23 +353,39 @@ func initConversationStore(cfg config.ConversationConfig, logger *slog.Logger) ( return conversation.NewMemoryStore(ttl), "memory", nil } } + type responseWriter struct { http.ResponseWriter statusCode int bytesWritten int + wroteHeader bool } func (rw *responseWriter) WriteHeader(code int) { + if rw.wroteHeader { + return + } + rw.wroteHeader = true rw.statusCode = code rw.ResponseWriter.WriteHeader(code) } func (rw *responseWriter) Write(b []byte) (int, error) { + if !rw.wroteHeader { + rw.wroteHeader = true + rw.statusCode = http.StatusOK + } n, err := rw.ResponseWriter.Write(b) rw.bytesWritten += n return n, err } +func (rw *responseWriter) Flush() { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + func loggingMiddleware(next http.Handler, logger *slog.Logger) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/cmd/gateway/main_test.go b/cmd/gateway/main_test.go new file mode 100644 index 0000000..c08cf50 --- /dev/null +++ b/cmd/gateway/main_test.go @@ -0,0 +1,57 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +var _ http.Flusher = (*responseWriter)(nil) + +type countingFlusherRecorder struct { + *httptest.ResponseRecorder + flushCount int +} + +func newCountingFlusherRecorder() *countingFlusherRecorder { + return &countingFlusherRecorder{ResponseRecorder: httptest.NewRecorder()} +} + +func (r *countingFlusherRecorder) Flush() { + r.flushCount++ +} + +func TestResponseWriterWriteHeaderOnlyOnce(t *testing.T) { + rec := httptest.NewRecorder() + rw := &responseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.WriteHeader(http.StatusCreated) + rw.WriteHeader(http.StatusInternalServerError) + + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, http.StatusCreated, rw.statusCode) +} + +func TestResponseWriterWriteSetsImplicitStatus(t *testing.T) { + rec := httptest.NewRecorder() + rw := &responseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + n, err := rw.Write([]byte("ok")) + + assert.NoError(t, err) + assert.Equal(t, 2, n) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, http.StatusOK, rw.statusCode) + assert.Equal(t, 2, rw.bytesWritten) +} + +func TestResponseWriterFlushDelegates(t *testing.T) { + rec := newCountingFlusherRecorder() + rw := &responseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.Flush() + + assert.Equal(t, 1, rec.flushCount) +} diff --git a/frontend/admin/vite.config.ts b/frontend/admin/vite.config.ts index 4c37cb7..c5182bd 100644 --- a/frontend/admin/vite.config.ts +++ b/frontend/admin/vite.config.ts @@ -6,6 +6,7 @@ export default defineConfig({ base: '/admin/', server: { port: 5173, + allowedHosts: ['.coder.ia-innovacion.work', 'localhost'], proxy: { '/admin/api': { target: 'http://localhost:8080', diff --git a/internal/config/config.go b/internal/config/config.go index d32c46e..89d6334 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -172,9 +172,32 @@ func Load(path string) (*Config, error) { func (cfg *Config) validate() error { for _, m := range cfg.Models { - if _, ok := cfg.Providers[m.Provider]; !ok { + providerEntry, ok := cfg.Providers[m.Provider] + if !ok { return fmt.Errorf("model %q references unknown provider %q", m.Name, m.Provider) } + + switch providerEntry.Type { + case "openai", "anthropic", "google", "azureopenai", "azureanthropic": + if providerEntry.APIKey == "" { + return fmt.Errorf("model %q references provider %q (%s) without api_key", m.Name, m.Provider, providerEntry.Type) + } + } + + switch providerEntry.Type { + case "azureopenai", "azureanthropic": + if providerEntry.Endpoint == "" { + return fmt.Errorf("model %q references provider %q (%s) without endpoint", m.Name, m.Provider, providerEntry.Type) + } + case "vertexai": + if providerEntry.Project == "" || providerEntry.Location == "" { + return fmt.Errorf("model %q references provider %q (vertexai) without project/location", m.Name, m.Provider) + } + case "openai", "anthropic", "google": + // No additional required fields. + default: + return fmt.Errorf("model %q references provider %q with unknown type %q", m.Name, m.Provider, providerEntry.Type) + } } return nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 867b4b2..2615f29 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -103,7 +103,7 @@ server: address: ":8080" providers: azure: - type: azure_openai + type: azureopenai api_key: azure-key endpoint: https://my-resource.openai.azure.com api_version: "2024-02-15-preview" @@ -113,7 +113,7 @@ models: provider_model_id: gpt-4-deployment `, validate: func(t *testing.T, cfg *Config) { - assert.Equal(t, "azure_openai", cfg.Providers["azure"].Type) + assert.Equal(t, "azureopenai", cfg.Providers["azure"].Type) assert.Equal(t, "azure-key", cfg.Providers["azure"].APIKey) assert.Equal(t, "https://my-resource.openai.azure.com", cfg.Providers["azure"].Endpoint) assert.Equal(t, "2024-02-15-preview", cfg.Providers["azure"].APIVersion) @@ -126,7 +126,7 @@ server: address: ":8080" providers: vertex: - type: vertex_ai + type: vertexai project: my-gcp-project location: us-central1 models: @@ -135,7 +135,7 @@ models: provider_model_id: gemini-1.5-pro `, validate: func(t *testing.T, cfg *Config) { - assert.Equal(t, "vertex_ai", cfg.Providers["vertex"].Type) + assert.Equal(t, "vertexai", cfg.Providers["vertex"].Type) assert.Equal(t, "my-gcp-project", cfg.Providers["vertex"].Project) assert.Equal(t, "us-central1", cfg.Providers["vertex"].Location) }, @@ -208,6 +208,20 @@ models: configYAML: `invalid: yaml: content: [unclosed`, expectError: true, }, + { + name: "model references provider without required API key", + configYAML: ` +server: + address: ":8080" +providers: + openai: + type: openai +models: + - name: gpt-4 + provider: openai +`, + expectError: true, + }, { name: "multiple models same provider", configYAML: ` @@ -283,7 +297,7 @@ func TestConfigValidate(t *testing.T) { name: "valid config", config: Config{ Providers: map[string]ProviderEntry{ - "openai": {Type: "openai"}, + "openai": {Type: "openai", APIKey: "test-key"}, }, Models: []ModelEntry{ {Name: "gpt-4", Provider: "openai"}, @@ -295,7 +309,7 @@ func TestConfigValidate(t *testing.T) { name: "model references unknown provider", config: Config{ Providers: map[string]ProviderEntry{ - "openai": {Type: "openai"}, + "openai": {Type: "openai", APIKey: "test-key"}, }, Models: []ModelEntry{ {Name: "gpt-4", Provider: "unknown"}, @@ -303,6 +317,18 @@ func TestConfigValidate(t *testing.T) { }, expectError: true, }, + { + name: "model references provider without api key", + config: Config{ + Providers: map[string]ProviderEntry{ + "openai": {Type: "openai"}, + }, + Models: []ModelEntry{ + {Name: "gpt-4", Provider: "openai"}, + }, + }, + expectError: true, + }, { name: "no models", config: Config{ @@ -317,8 +343,8 @@ func TestConfigValidate(t *testing.T) { name: "multiple models multiple providers", config: Config{ Providers: map[string]ProviderEntry{ - "openai": {Type: "openai"}, - "anthropic": {Type: "anthropic"}, + "openai": {Type: "openai", APIKey: "test-key"}, + "anthropic": {Type: "anthropic", APIKey: "ant-key"}, }, Models: []ModelEntry{ {Name: "gpt-4", Provider: "openai"}, diff --git a/internal/observability/metrics_middleware.go b/internal/observability/metrics_middleware.go index 8537935..fdb98f4 100644 --- a/internal/observability/metrics_middleware.go +++ b/internal/observability/metrics_middleware.go @@ -48,15 +48,30 @@ type metricsResponseWriter struct { http.ResponseWriter statusCode int bytesWritten int + wroteHeader bool } func (w *metricsResponseWriter) WriteHeader(statusCode int) { + if w.wroteHeader { + return + } + w.wroteHeader = true w.statusCode = statusCode w.ResponseWriter.WriteHeader(statusCode) } func (w *metricsResponseWriter) Write(b []byte) (int, error) { + if !w.wroteHeader { + w.wroteHeader = true + w.statusCode = http.StatusOK + } n, err := w.ResponseWriter.Write(b) w.bytesWritten += n return n, err } + +func (w *metricsResponseWriter) Flush() { + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/internal/observability/middleware_response_writer_test.go b/internal/observability/middleware_response_writer_test.go new file mode 100644 index 0000000..14d0cb3 --- /dev/null +++ b/internal/observability/middleware_response_writer_test.go @@ -0,0 +1,65 @@ +package observability + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +var _ http.Flusher = (*metricsResponseWriter)(nil) +var _ http.Flusher = (*statusResponseWriter)(nil) + +type testFlusherRecorder struct { + *httptest.ResponseRecorder + flushCount int +} + +func newTestFlusherRecorder() *testFlusherRecorder { + return &testFlusherRecorder{ResponseRecorder: httptest.NewRecorder()} +} + +func (r *testFlusherRecorder) Flush() { + r.flushCount++ +} + +func TestMetricsResponseWriterWriteHeaderOnlyOnce(t *testing.T) { + rec := httptest.NewRecorder() + rw := &metricsResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.WriteHeader(http.StatusAccepted) + rw.WriteHeader(http.StatusInternalServerError) + + assert.Equal(t, http.StatusAccepted, rec.Code) + assert.Equal(t, http.StatusAccepted, rw.statusCode) +} + +func TestMetricsResponseWriterFlushDelegates(t *testing.T) { + rec := newTestFlusherRecorder() + rw := &metricsResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.Flush() + + assert.Equal(t, 1, rec.flushCount) +} + +func TestStatusResponseWriterWriteHeaderOnlyOnce(t *testing.T) { + rec := httptest.NewRecorder() + rw := &statusResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.WriteHeader(http.StatusNoContent) + rw.WriteHeader(http.StatusInternalServerError) + + assert.Equal(t, http.StatusNoContent, rec.Code) + assert.Equal(t, http.StatusNoContent, rw.statusCode) +} + +func TestStatusResponseWriterFlushDelegates(t *testing.T) { + rec := newTestFlusherRecorder() + rw := &statusResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK} + + rw.Flush() + + assert.Equal(t, 1, rec.flushCount) +} diff --git a/internal/observability/tracing_middleware.go b/internal/observability/tracing_middleware.go index c1b426e..9feae16 100644 --- a/internal/observability/tracing_middleware.go +++ b/internal/observability/tracing_middleware.go @@ -72,14 +72,29 @@ func TracingMiddleware(next http.Handler, tp *sdktrace.TracerProvider) http.Hand // statusResponseWriter wraps http.ResponseWriter to capture the status code. type statusResponseWriter struct { http.ResponseWriter - statusCode int + statusCode int + wroteHeader bool } func (w *statusResponseWriter) WriteHeader(statusCode int) { + if w.wroteHeader { + return + } + w.wroteHeader = true w.statusCode = statusCode w.ResponseWriter.WriteHeader(statusCode) } func (w *statusResponseWriter) Write(b []byte) (int, error) { + if !w.wroteHeader { + w.wroteHeader = true + w.statusCode = http.StatusOK + } return w.ResponseWriter.Write(b) } + +func (w *statusResponseWriter) Flush() { + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/internal/providers/providers.go b/internal/providers/providers.go index bd807bc..639fcda 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -136,6 +136,9 @@ func (r *Registry) Get(name string) (Provider, bool) { func (r *Registry) Models() []struct{ Provider, Model string } { var out []struct{ Provider, Model string } for _, m := range r.modelList { + if _, ok := r.providers[m.Provider]; !ok { + continue + } out = append(out, struct{ Provider, Model string }{Provider: m.Provider, Model: m.Name}) } return out @@ -156,7 +159,9 @@ func (r *Registry) Default(model string) (Provider, error) { if p, ok := r.providers[providerName]; ok { return p, nil } + return nil, fmt.Errorf("model %q is mapped to provider %q, but that provider is not available", model, providerName) } + return nil, fmt.Errorf("model %q not configured", model) } for _, p := range r.providers { diff --git a/internal/providers/providers_test.go b/internal/providers/providers_test.go index 49b8595..367b6f0 100644 --- a/internal/providers/providers_test.go +++ b/internal/providers/providers_test.go @@ -475,7 +475,7 @@ func TestRegistry_Default(t *testing.T) { }, }, { - name: "returns first provider for unknown model", + name: "returns error for unknown model", setupReg: func() *Registry { reg, _ := NewRegistry( map[string]config.ProviderEntry{ @@ -490,11 +490,34 @@ func TestRegistry_Default(t *testing.T) { ) return reg }, - modelName: "unknown-model", - validate: func(t *testing.T, p Provider) { - assert.NotNil(t, p) - // Should return first available provider + modelName: "unknown-model", + expectError: true, + errorMsg: "not configured", + }, + { + name: "returns error for model whose provider is unavailable", + setupReg: func() *Registry { + reg, _ := NewRegistry( + map[string]config.ProviderEntry{ + "openai": { + Type: "openai", + APIKey: "", // unavailable provider + }, + "google": { + Type: "google", + APIKey: "test-key", + }, + }, + []config.ModelEntry{ + {Name: "gpt-4", Provider: "openai"}, + {Name: "gemini-pro", Provider: "google"}, + }, + ) + return reg }, + modelName: "gpt-4", + expectError: true, + errorMsg: "not available", }, { name: "returns first provider for empty model name", @@ -542,6 +565,31 @@ func TestRegistry_Default(t *testing.T) { } } +func TestRegistry_Models_FiltersUnavailableProviders(t *testing.T) { + reg, err := NewRegistry( + map[string]config.ProviderEntry{ + "openai": { + Type: "openai", + APIKey: "", // unavailable provider + }, + "google": { + Type: "google", + APIKey: "test-key", + }, + }, + []config.ModelEntry{ + {Name: "gpt-4", Provider: "openai"}, + {Name: "gemini-pro", Provider: "google"}, + }, + ) + require.NoError(t, err) + + models := reg.Models() + require.Len(t, models, 1) + assert.Equal(t, "gemini-pro", models[0].Model) + assert.Equal(t, "google", models[0].Provider) +} + func TestBuildProvider(t *testing.T) { tests := []struct { name string diff --git a/internal/server/server.go b/internal/server/server.go index 0dcb490..5190944 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -239,17 +239,17 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques } func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusOK) - flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "streaming not supported", http.StatusInternalServerError) return } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + responseID := generateID("resp_") itemID := generateID("msg_") seq := 0 diff --git a/internal/server/streaming_writer_test.go b/internal/server/streaming_writer_test.go new file mode 100644 index 0000000..95dc3b2 --- /dev/null +++ b/internal/server/streaming_writer_test.go @@ -0,0 +1,53 @@ +package server + +import ( + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +type nonFlusherRecorder struct { + recorder *httptest.ResponseRecorder + writeHeaderCalls int +} + +func newNonFlusherRecorder() *nonFlusherRecorder { + return &nonFlusherRecorder{recorder: httptest.NewRecorder()} +} + +func (w *nonFlusherRecorder) Header() http.Header { + return w.recorder.Header() +} + +func (w *nonFlusherRecorder) Write(b []byte) (int, error) { + return w.recorder.Write(b) +} + +func (w *nonFlusherRecorder) WriteHeader(statusCode int) { + w.writeHeaderCalls++ + w.recorder.WriteHeader(statusCode) +} + +func (w *nonFlusherRecorder) StatusCode() int { + return w.recorder.Code +} + +func (w *nonFlusherRecorder) BodyString() string { + return w.recorder.Body.String() +} + +func TestHandleStreamingResponseWithoutFlusherWritesSingleErrorHeader(t *testing.T) { + s := New(nil, nil, slog.New(slog.NewTextHandler(io.Discard, nil))) + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + w := newNonFlusherRecorder() + + s.handleStreamingResponse(w, req, nil, nil, nil, nil, nil) + + assert.Equal(t, 1, w.writeHeaderCalls) + assert.Equal(t, http.StatusInternalServerError, w.StatusCode()) + assert.Contains(t, w.BodyString(), "streaming not supported") +} diff --git a/scripts/__pycache__/chat.cpython-312.pyc b/scripts/__pycache__/chat.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..066ed18f78cad3ca2e37f54c65f99fa54e6ab4b9 GIT binary patch literal 20974 zcmch9ZEzdOmDu3>`-=n#kOavQzabHzzQ4aHQlBf4+!f_rQKyAL%!mXD0`vfs#4)JV zo%d|YtILwMt3<`x5_Q?DrLxK`x9pVZN~KQTIA_^O&VWD}@K~G3=dE)2M-p-+$1ZK9 zF7NeV0FVOh^(B>&*wfSfy8Cta>(}qSevSWRGU+LJ{>Sn6gS%d)s9)iS^yspYr#&)? zxtxok6|9rhT##45ATV{T^MabGe^c34DW{H7Ov9TL)A+tz9G|`_ zris;EkTcDY?_yfsr248EH-t58Ewg^Og`zmEoMLO<&km=I8kdjD-5dB{>U&vkBpCLG z=)j;qMhC+&mb>5&u=IrpM~D2eSTMkbM!b6cnW%q&-L0qL86A%eM#7|kcXYf+67BJf zaUpsz78~8&))tBc{Gq`}G`4%^#+@6BgO5ZQHbf7wfJ)G_*}KtmA>@w^df3aO#o=R- zVKz)Rvd`MF^x zaw!aPRgXW+h9IKu;UXiWF-+?9_lHFMWD!bTq=nqO_l3f|%AyN5}eOoS#(KN{@_%V!_c+FhB~l z(*7_*Gr^0&D1dqeb)GzLR5;8qF>s7KIBRS4Q_`Efywr5Kj4?;T`|B16X_ zG$4j3yA^^mHa^Mi`?`1;#iIXuyYa&2Kr%#^@L{n@y!L708_S)6r2j5WEm%8Is0F!YmCmf2OxX z(2}B%z`xRqm0%V&NVl-w0dFha3_#8BY6l!(4%)idOTM!aZWz-Jk0IJ^C!vp#aI@P> zcd{4!FkMj`Nhs%)Mx+`>M9{?|k&qAT7Ic1YU~B{sj|yszjg4{PKz4NlAVq25PXT{V zpA$|P2OdY(^;>b^hsYr&v+IjEcaitjZRY+AAqyqni)RT~^ zDM(e3)OJYKlGM7mlF^ZvBL+B#&rl(b*)&Ga7=p6*voNPA(z1cnR{2Lu8sqYBQ||yj z%$NVeW#X#199v*ax0LxZFg|92>#4YG2zs*mL7pvf&k1!5N0&T9S}B%)OI&k8^G7hz z3_&kfKT_$qnz1n@w-k9C|8Q+EiluRC^;o#={HDW~*}zx~)+>-yBt#a!NHB!ZIW`L9 zM*D$z&>{oyauz^}tHhuR0~ZD*7?fgQ$Dj-YoJ>K*BJ1O=omqxL_m9&&%SHIoGvVPd zu)RF1+D!v<0cyn-SJ%G|E3Cr`6+rfU=jp-7CE!Lh*$(LFcr?b2h^w(hT5fc1DbwLd zjJ_}y3XRjjkjG9kupndM@)qk?0x-*)CG2oe05P?o9Et+>RWV~DqftQttbxM? zBq*Q{f-V@1AXD?l+*+<4QEk9r)fmCXDfWSDgIE-412;f?p095kjEt~t0ob{?wtZ>=oJf_vlrX;Z ziKX;y{u)2~g$4J6b%|F_Co6kXmNN-L=sd)N(mrc zN)IwC{C+wQ!3_%7SB}|AARrFr*khdX@WMi-}wG3IYmjQkyWTz**=m4ooAxTlm@ zAk>n>VSs`!rPNj|$I_&6fn2qe%c$Th$2`X=x4KJ^wegC!i<@|LIN}**LAv6$1Gu2^ zj|P44-Rsr@%RtUFN-{XCOi%~N*nnulMp4p?a=6$84af>8eF-L9IZ+=QM%ju%Q6$QS zE{GNE0LUn!56hCjt=sr;`02)0}#16D-PjS^Npuj5YzmkUXbk%*=vB2(!laZVP;sbppm+& z#EuL4V0d&4WH0t|?5-&KfzXH_S*cFR;cVix#$YLcIJXZF5Y+xrkiQwh>I1rff$~Z9 z3ua$*JRHbsqeq;21<2ik0!UoYFrv`MH9~vESN9;~;H+PT^>~`fivn>fmJ_L@)yb8t zF{d(^)hsmBk(<*va1F@G(#~nhE#nWg2LiCNXi+twL9ye4{UAg{>8Sq#3)?K}d;Jkm zqfU3el!XEbg6I{2Skw#qydKnxa5Ux*2S5d2f`J(4g+2%xpa-lDRVW1Cj?}AjF!ngnbOM1uEz%03U6&)G|-vS@*6v*cg$>=)uc*W z62_Ltmh!3SOvkJvWpO98ZgCD&>_wof6_PU&2M)#4uR}l zrVEmAZn!5DYrwz;l&eP>S8|3}b{Z%uULc>2E8Qx7%i+sJr=y9)MN5n90%(%7KflH< z_=CXnXwow;>O$h$BC}whC!}^ zC^8IOK-m>k=7S}R+NQJ~MUjv$oVa2T--2qQxHdtm0f}P&^#0k#xyn@O)=6!~YQJHg zHqTt1bEK@BCRLe|@*Cc1?`+weE>*I5QuA{|$>Y-UsY~D7n=sa9EaG&om`5aUGLM&G z{{C<0ae$JRhqxIIP!kIHQ^pmY)Ok>oCSYSLN&~9={4&a6*m^kG^@=rNRWG&!^&f-! z)oYYd7t~L5wCLLPs~HWWmHIKEiE9A1Ev_uKX%vVC8fhbt-ZJX?2_VV!-ujWc2iE>J5NdBn0ZYM8Jgfz(zreh{ge=VeS+JfQR+ZQ+prj)Q3 zk!95|r-r>~eqXBeV5X|}_Wqmu=NjhK09>_as#9=OWt??sr#tC%->aWv-giH8Ze7+< z6^CR?8mesFQVC@(dHd3}OH+KxQkT%y5#G6Cv7wAj7TZB!4PbbcO#t%)rl*L@hhfK= zl5@jxV0F9+NEulnfUkavks+gu%RmZMy^~Xew^1Oq$e;#f)2kjveM=5Yk{Zg%n|T0E z7efh=JfQq*_CidKQWRpo>KQVwN`TKH^NVRrZ4j9ByRv^Fn@}-Y{3o$;83oKi4SyQ= z)4rh^vc6av7WV&y(a^a4RP(0&6%0tYKu0J{hxkP3Wi?S{BGNgyw}+50tN2 zl?0y>h7*Rkx)_E~s_w;7klIir)c^*d9902L74Nh8Pps3DH!8tcEGwa2uC_R}7V80{ zW2|uvV_SilI$9Rja$pP^!VP%U6W2TJgWfGg>}cxOm*R?UJt_D6y+jAs1mMv;>J8S5>MmgQ}XM><2CivCRl5i zKtnzYyh%n)$*%v?WyM+QlB}MJp@E1zLv%J$@M@qqwTuc<6Na-`YWNZ*mUGj*yDJj* z6iR8L3`buA`@d+$i;6aX?{AQS5lh@}P)_B`&hv2E)TB<+boix7Uld zDXWX5%^nYd4hgf%bM+G_G?I z%p&|7gW{?-711dOKt@vGmU zNv(pcg*5zJ;})R=`w9eoSwD%?b&k^kbV7u zK@c>+ypMgNv5;tTjEtVYOSvn!{ZJgWaXjH6swZ~c0r-qNIXCiYSSkz!=iLpbAC&OjI08%Zbz7U7c{k!LS5H4n)c9SOy3zIiwO4AvVl; zFbi#gf*M2>P=DPPQKTVLY98@l_T~Eq21}`*f_nH8oE1bxK}t_%4|^IFY`K~8Wfg3W z#YQe*Bc^PUuOBS1!y>vyQW!dn2Nc=E2&{^O?S_h4ARtOXJ|Myp*8>@RMWJA{LP`Hy zEc`EFC4)Q!wi0v3UN)(FVs^}UQ|1;hY0lWdwt4)Ct?Y(t+Lg97ByA0e=DiDN7LMNE znrJ-s$aefmbsbn$ZlAw-ep#icESXY+)y3{imo`2uZG7CcVd}`tj_K|vRXb+3-rjR_ zPr7PjvT9?xYDW_Pw=Xn7Fs1#SjVd`ITPmk2U6Us=bX~$&odIl4JVBGsEsH2h?V8>- zt9w}Le&Y6g;JW9UF=eV+(p8>hm1n*>Rkc6kYD&92Ntb7CB<($z^d4MvwP#$Ov}(`RSfQe~T#l!~(E#|T`)70jR92g|Hzw_knVP!Wp_`!$-SD37jxN*a1>-|=IT#((#pRNnI6QYrv!rYc&Nw^P=2sAZX^O=)XG(%O)5HD+ARnc9s|Wm5@MX)9ee zQ}&vlZ`?fgJK3kb`v?i>F?#m$nqln4&_Q`7})5e;lv1YdRyQTAe->FS(Y)`bm zlBhksXza~YRZpJ$Rdw@}4%;zK8`IXhq_uAL+t3lsaQ|rLFZzYkkJG zIaAv*XTSMIrltYVZY=pNq771OP0NxU5j=J6RN6=e24Qh=xCeyJBI2-u6Y}CPqnuDOstFZr1UZqh(7;p7G|R~D zsU=}^8w!zYiNv7s-BN&{_{W7}J)`@7A{4iPP^>3Hu_12g1kwLI>x(Rz0ztXq&f`4o?fF2EwzzjZ~8CQRPc%)bmowh)UV$@|LW7N4~ zj7oB2B!`kcP_ilCD=8LNU!GC56$TGpW6fFdBzV)@lIu(XrW z$JTDFz%)ObnD;kO`cIorEOmd4Qboo-*Jt5RnE@P!2{R~kITd4Qy)=sXH;CV*2q~>k z{KioMwF~8%RZrY3!GSm}zh0-LbYKL)f70B6!lojMRaOBgte}_x&y>eics8D|cO0#) zc2JOuw9h!eMBpq`YvUSFf!1M9*C41s)WUGsb*^8 zTJZK-p{z#baV1gqK5zf&7#=r~C!VwDfoisHm9l3cJw;{D47~C4)?O#|8fvUBpoFbb zoB21+d?6)3iG~%*9v;t;N2)VdB8$z0HEt<7YmERY&ViL_DjcsrpzQ6ETFbxVmOPad zD0_NbSUeRX9&+TRz0B7)dFpP{*f!ow*^X|;`)rF1Qt;%`K)^A=69YZ=U$=# zzfd&{fy)`_HDPcvz)U=lA*zj}EpkJEFwqvdVT|=d&;@ECQFKZ;_5uj^gFf)1@`-*| zy>Lu}`UQuo9XUMXMj*ylN?MP%zyj+9r@WY-hibB?eL1@)Z^@;J2MsS)fHqez_?@u~ z(LcB_Gy}HWymCpDPYr^BV#hb-E6n=M5pC6%^NJFY;ieF8EXr5pN)#A=z2J4l@~!8< zSL{5|z|f}z&aDQ-VawGxo9velsMvmAF1~K<_%+(C~qj=ONj(lDRCCHNzM9r1^a}2sL_zMiag29^@Alo9U zDfgEUbKAIsm_X>}s~Gzl1cEuQZKADRFpB04A2}=Ht|5peublcqa8|&59aEJ^)Z8Ri z)QJI_@C9>@wnXFqUt!L53~+r$Z4l(u&}>Z9M!O-4uUesvt~^BkC#>MVD*lK*nyIdx z>;hZ*Sk5C z!ket{rYp83@&ADZH3U=YPpa#u4rd%LbjwLJwm)>VXPixQjUQ~dyCLb^He<#^0K-kg zyT|5UPc`g)lDXv)ksH9#w6c zXTLY};n4RYA4VS3rnbKZ+F^6+dqa1I(oH*)O*=C!o)7eQ^&eR7S~8xt%!Z90?7qAE zgMD}Rfl|7wX6D4~*n9jP{=JDi6DilWB_mZ?D`MXLu%i2!nsU`cEAF3FZCxn)yXx;( zKUkmI)eE{dl!6jT*S7iLl2LcwQf-M$-ef#>-&D+16|^k^T}hc zr}p_~kIxM!8+K{=8@4BQvFU~Z(BmQF`R~B{38p7$JM=t?JIsfNTs>8a|6ne~u&%uVUO(K6;eRY| z-vO^5Sz#%zkjd3E__`F#h8$t=&z^pVTTK{?jIJk5iMdV?5!IVcbdL zRjQsw4Ir9-P?`ZILp?>s;RYNSfooJQ7o5Q4z@$Q~pQFJL2mB8hGVo}v&t?-&Z{S;qIb=@-65!r0V5=A0&9Z$u`QonqJ zBpX69UG?ORK5#un0fIS>AbC&-hg0UFr5IUU0V;>QQ0+MjzKs)#D~fSAp~m+mnIwd} zP5kB~Y%DNHBxz!dr~MbfzXsh3Y2vIG6+Ia6&Op|k*UMKEdGd(phXa267f2N%_-^{S zS`;L^cx!vEu2XoUCEEMkNJPmIaA^$3{C5>_MiY#I2b>`a?$EwL)(?WcU=-irY%>m7 zs_si+NK^*K{o%WCorI+FT2di>yi-sOa4Z{k>qQSqwTJ}@RneHBJA667k}DaaG>60> z=s^mH1A)-kNEo~@AS;p5tP9$2lT9h)0`1xjAt54Y)C)LmN zl&K`)T)&`vuz%6mclB^asY@&ENu_=2)!8EnrF~JkA)nMYdx<2q{mSM5-ytpW-kV}? zgr~!^?0bWE1|MnLVAm{h!Y-&Xe(ltspPFiXq;&rNnS(NxeM-pz1FZ!;{u0h@^ZxnM z3+hDqfko}X$6E7eODf2I9!2Bh^~ysP)JGM%4w>Sky0Sxi6(8+YK|BCixhJ=-;a*tr zbxq}6zC>S5@Uulu0MCSMjq3>FI>s#o9`uMP=EF(Wu_nKk@+$;ZY8%uJzJ$eA7NeJx zQid2y!I3f2{B z0Z}h+$*+!HOzGwR6|~L03jwdg6uJwxC+;pL{%Z&XLudYq5qA%hQTO4@7=R60yj9fY z){vFJA!;0wo~Si>G4>$_Ux5JBUP9a7g*bSmV)Hm9K++;T+eCUM8Ci1vKIDHD$@vOI za^(EjXiGTV^RWlsMdR5*sy#QS7D+Tuwdd!WNz&$DTgxVu8H45Ro!53w9Z4CgX84rB zoiUQ*jG2a+i#J>6lym2kEqfjr_b#a@TREH%70|Khp95&*9@K}h15fX3S5jdy6Av6N zw4^4D2VL0}edv?-}y6x<;$y0_8o+idHou28Cg6 z_mjnPJ4Y_9U`GV1HWro>n^1tU6UUQ0aYeEHtC-KuP_B$r7R;g9sI=ZD zly4}bO|U9)F|B%-5(x@CB65cY#k^-2TOJbnQnRWe_~cB*(krP=2JCTDTOsj;UknnFgj242YTutqu&&c`uov)>TcyC(|s| zKs>pCtFC^eIR-A;0!K)jR_8+S-hOl*jUh+@fnF-+K6e zL#d6Vlp(fD>W=F#F>M z=Wt}2khD(aPer#yX^Oag&6z z^d{b0&q0?!!GS!WOKb&=rZ|4q)rU%-pD%;X0h-bk6&!Cp7h`#?1KRH_Xn)(9?RS9o zbVP#8zhR6!A?H{=MT)@@Sk5^c{nf;luYXy>U`4-c)@=9i>VD^Hnw4sjG$`oH38{4$ zOFPI3#mDkUeh$u3*MAs0Sx`!WxQ1 zyy2*L>yke6%1Jo+=!Fu3a*z#;@*0naC@B1TQF#K*4cy%5mh9~uawS{5te(>W-B^fR z%HvDh@mB@JUm9pR1Qr;Ua|1^I4ly!#JaD^BG5)^nWWnr|61E>MJ0MCZ-F)fiXD2 zpuxsesLPVGBL1*g6{?u$bHh!pchJYBpMc2LPGIyMB7#s~EV7$EM^w=Byz^X}7}Jvh zhKzAG6pCCr&zGNT%SQF5rK2s#CQ`hh0K5d*dER)g?Yy3-0+8iL)e#e*t&tHp=!)X6 zO%OcKw~=Y2_o9(%zlhg)+C!fP?@3VQN#-ipXddb$D8>cK2!W2~!NrNngZD-T2EZ1O zo4-6HQ4U@J#8k=CFb|8`_$v5Sc#zzD0bm z7{2l)a5YFJVA#O)05lPd(wHu2US&gY_#!G^uz>ys7Z8{&I516!CJA0WFdFl0jqnEX z?ZKbv5af|fd}ScyA7fb07SA>hT(%CgG0*0Wo40P-xOp>QQII$24-Svzv+Tuw>4P_C zQ#X0HdEmz%+!{d%?QP|C7?L)*j}SMwD~gQ@DpISU7K`!vY;h>cSLsCpCtkB>BV%Ic zMJHMgB@+0Z4lQ_CM&U;^Vqlq7wc((`3rBGb1#F3uk47R)|2VsjUb;Q~R*t9vl5m($ z#KF%|Y<7@g(vGO{748C_<5KgnIIMtB9~IM}99hAkkeLp*YAd zN@`?0AQS0GP#z=PG>bE@&sWc@#l+vBoq%s7W-{@YVsZmOI^@Z&Ts%_cE3@oFT)u>* zh--K^&1)r&!fUbvz}qBniI32NCVLMUe$+)|w+tXcI>wu_(7mYQ(~wNM$#2Le10Q?A z_D&NJmdQ0;F7_k)YX`74kR9!&xxdCvIf+xG6Xe7pAaTV)!H0K|vRe zWN!j;1_E>evQenyP}hT>mx&Ieof58e<7S9=uF$U?l5lyUgY+dh6h`wP_k9fB#DLgm z$U(u&_(@K%QS%4WC~?mQizP51l!gS`OYC@m#LqEgZ{o&AV}kZ@_YpFhZnK~!0|FLP z(pzHY!JY}$+^(LrQ*z(IWRvs@5`;1Zv$)QrtL1`;B>ViKAY9Uidt`X68E%uq{Wd`t z7-6zk$`lZVVSyihA>hOpJBmgI(MT$VqH#$%;tz(oUqkM<(4O?F675NkO=W3QL(uBx3?zh}B* zO4D1D^wu=JD@pH4(R)%=dlOpcvYs-!BtvTXpDk+7Kh_%F)?U+2*{&O&XbfND|ANO; zYyD0AHPaJk%bYFg+>my9lTPoXDN|8>yXIz1qOAF9SI$cNt;QSfY4?&`rlaxHT6;%3 z!!J>?^_7{*=9yZExg9^R+cm_y?zJHb1P{JiTKd`fdvvqb5>Alta4GwC+MiIK&bDXV(6&Q_{^M%h%FrS+7pB2!kEF_&aY zyct_-rlciP-n4AdmZ`6vf-_2kWl7#6)0I4~5KpZe4y7tOz;V2^I$c_yEUjNGZG<~| z&5$FwWeG<;n26dc2+hWN|(1J%Ue?A>!%b?O3M?K zt;y2XjJ@i{sp(URx|dBz zh&*^LQTK&M_E)iDY&zrKk!aj8)s=(V_JjQ&1{1IP6WjU|{X>bm;YapR7K=>Vwsc!Z zvaKVrp=0XEqP-RO4=*a_Lii*Wzi15 z!jK)U?nm|>82B<589X?r-?GrSa4}H_H6Q&Y&WXWvNS3iW&AhnIa&7Tz%Eg| z^Q8yv54I%AyBD=RIl_P2@O1+@La=}KETo|9b#P*B>dWJ>ao+a5>JO{uEejJ5UjFIX zAD?~jQljRywDoM#diJx&r7i?7KzZ-N%L^TenuBR;d(zte8Ju!M?q}dKVOnNaRu8JGH~t|1~|(3 zqisRE7d8PnP2bOb3K_uS;qDBCFeNhCvRx}vKDARa>n|w7FDU)TRP}$O=wDOrk16_N zs`6v1?qkaJF;xR0_{rF+;3`DM@#QCW7m#^%qu@H6sXqojMl~gnTiuwc-v&2g%1hvC z2!3oQAIoLtD)Xl`RK>>S`VxnJS?5rxmo}5S>N1X#P}dV{WyaNzslNp9)g?;`Rh4;J zrA1H^fNp8l$*-!GdS&Zm9kLlX1jqMW=X~jWVBUGROZ;AlJ|*eDJ0`oLkX0>H7%m;s tT4fb87oSq_`n1Uaf$OAPR Any: + """Access object attributes safely for both SDK objects and dicts.""" + if obj is None: + return default + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + def _extract_stream_error(self, event: Any) -> str: + """Extract error message from a response.failed event.""" + response = self._get_attr(event, "response") + error = self._get_attr(response, "error") + message = self._get_attr(error, "message") + if message: + return str(message) + return "streaming request failed" + + def _extract_completed_text(self, event: Any) -> str: + """Extract assistant output text from a response.completed event.""" + response = self._get_attr(event, "response") + output_items = self._get_attr(response, "output", []) or [] + + text_parts = [] + for item in output_items: + if self._get_attr(item, "type") != "message": + continue + for part in self._get_attr(item, "content", []) or []: + if self._get_attr(part, "type") == "output_text": + text = self._get_attr(part, "text", "") + if text: + text_parts.append(str(text)) + + return "".join(text_parts) def _sync_response(self, model: str) -> str: """Non-streaming response with tool support.""" @@ -225,6 +260,7 @@ class ChatClient: while iteration < max_iterations: iteration += 1 assistant_text = "" + stream_error = None tool_calls = {} # Dict to track tool calls by item_id tool_calls_list = [] # Final list of completed tool calls assistant_content = [] @@ -244,6 +280,15 @@ class ChatClient: if event.type == "response.output_text.delta": assistant_text += event.delta live.update(Markdown(assistant_text)) + elif event.type == "response.completed": + # Some providers may emit final text only in response.completed. + if not assistant_text: + completed_text = self._extract_completed_text(event) + if completed_text: + assistant_text = completed_text + live.update(Markdown(assistant_text)) + elif event.type == "response.failed": + stream_error = self._extract_stream_error(event) elif event.type == "response.output_item.added": if hasattr(event, 'item') and event.item.type == "function_call": # Start tracking a new tool call @@ -270,6 +315,10 @@ class ChatClient: except json.JSONDecodeError: self.console.print(f"[red]Error parsing tool arguments JSON[/red]") + if stream_error: + self.console.print(f"[bold red]Error:[/bold red] {stream_error}") + return "" + # Build assistant content if assistant_text: assistant_content.append({"type": "output_text", "text": assistant_text})