Add fail-fast on init for missing provider credentials
This commit is contained in:
@@ -155,6 +155,11 @@ func main() {
|
|||||||
|
|
||||||
// Register admin endpoints if enabled
|
// Register admin endpoints if enabled
|
||||||
if cfg.Admin.Enabled {
|
if cfg.Admin.Enabled {
|
||||||
|
// Check if frontend dist exists
|
||||||
|
if _, err := os.Stat("internal/admin/dist"); os.IsNotExist(err) {
|
||||||
|
log.Fatalf("admin UI enabled but frontend dist not found")
|
||||||
|
}
|
||||||
|
|
||||||
buildInfo := admin.BuildInfo{
|
buildInfo := admin.BuildInfo{
|
||||||
Version: "dev",
|
Version: "dev",
|
||||||
BuildTime: time.Now().Format(time.RFC3339),
|
BuildTime: time.Now().Format(time.RFC3339),
|
||||||
@@ -348,23 +353,39 @@ func initConversationStore(cfg config.ConversationConfig, logger *slog.Logger) (
|
|||||||
return conversation.NewMemoryStore(ttl), "memory", nil
|
return conversation.NewMemoryStore(ttl), "memory", nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type responseWriter struct {
|
type responseWriter struct {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
statusCode int
|
statusCode int
|
||||||
bytesWritten int
|
bytesWritten int
|
||||||
|
wroteHeader bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *responseWriter) WriteHeader(code int) {
|
func (rw *responseWriter) WriteHeader(code int) {
|
||||||
|
if rw.wroteHeader {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.wroteHeader = true
|
||||||
rw.statusCode = code
|
rw.statusCode = code
|
||||||
rw.ResponseWriter.WriteHeader(code)
|
rw.ResponseWriter.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||||
|
if !rw.wroteHeader {
|
||||||
|
rw.wroteHeader = true
|
||||||
|
rw.statusCode = http.StatusOK
|
||||||
|
}
|
||||||
n, err := rw.ResponseWriter.Write(b)
|
n, err := rw.ResponseWriter.Write(b)
|
||||||
rw.bytesWritten += n
|
rw.bytesWritten += n
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rw *responseWriter) Flush() {
|
||||||
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func loggingMiddleware(next http.Handler, logger *slog.Logger) http.Handler {
|
func loggingMiddleware(next http.Handler, logger *slog.Logger) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|||||||
57
cmd/gateway/main_test.go
Normal file
57
cmd/gateway/main_test.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ http.Flusher = (*responseWriter)(nil)
|
||||||
|
|
||||||
|
type countingFlusherRecorder struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
flushCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCountingFlusherRecorder() *countingFlusherRecorder {
|
||||||
|
return &countingFlusherRecorder{ResponseRecorder: httptest.NewRecorder()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *countingFlusherRecorder) Flush() {
|
||||||
|
r.flushCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWriterWriteHeaderOnlyOnce(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rw := &responseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||||
|
|
||||||
|
rw.WriteHeader(http.StatusCreated)
|
||||||
|
rw.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusCreated, rec.Code)
|
||||||
|
assert.Equal(t, http.StatusCreated, rw.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWriterWriteSetsImplicitStatus(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rw := &responseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||||
|
|
||||||
|
n, err := rw.Write([]byte("ok"))
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, n)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, http.StatusOK, rw.statusCode)
|
||||||
|
assert.Equal(t, 2, rw.bytesWritten)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWriterFlushDelegates(t *testing.T) {
|
||||||
|
rec := newCountingFlusherRecorder()
|
||||||
|
rw := &responseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||||
|
|
||||||
|
rw.Flush()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, rec.flushCount)
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ export default defineConfig({
|
|||||||
base: '/admin/',
|
base: '/admin/',
|
||||||
server: {
|
server: {
|
||||||
port: 5173,
|
port: 5173,
|
||||||
|
allowedHosts: ['.coder.ia-innovacion.work', 'localhost'],
|
||||||
proxy: {
|
proxy: {
|
||||||
'/admin/api': {
|
'/admin/api': {
|
||||||
target: 'http://localhost:8080',
|
target: 'http://localhost:8080',
|
||||||
|
|||||||
@@ -172,9 +172,32 @@ func Load(path string) (*Config, error) {
|
|||||||
|
|
||||||
func (cfg *Config) validate() error {
|
func (cfg *Config) validate() error {
|
||||||
for _, m := range cfg.Models {
|
for _, m := range cfg.Models {
|
||||||
if _, ok := cfg.Providers[m.Provider]; !ok {
|
providerEntry, ok := cfg.Providers[m.Provider]
|
||||||
|
if !ok {
|
||||||
return fmt.Errorf("model %q references unknown provider %q", m.Name, m.Provider)
|
return fmt.Errorf("model %q references unknown provider %q", m.Name, m.Provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch providerEntry.Type {
|
||||||
|
case "openai", "anthropic", "google", "azureopenai", "azureanthropic":
|
||||||
|
if providerEntry.APIKey == "" {
|
||||||
|
return fmt.Errorf("model %q references provider %q (%s) without api_key", m.Name, m.Provider, providerEntry.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch providerEntry.Type {
|
||||||
|
case "azureopenai", "azureanthropic":
|
||||||
|
if providerEntry.Endpoint == "" {
|
||||||
|
return fmt.Errorf("model %q references provider %q (%s) without endpoint", m.Name, m.Provider, providerEntry.Type)
|
||||||
|
}
|
||||||
|
case "vertexai":
|
||||||
|
if providerEntry.Project == "" || providerEntry.Location == "" {
|
||||||
|
return fmt.Errorf("model %q references provider %q (vertexai) without project/location", m.Name, m.Provider)
|
||||||
|
}
|
||||||
|
case "openai", "anthropic", "google":
|
||||||
|
// No additional required fields.
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("model %q references provider %q with unknown type %q", m.Name, m.Provider, providerEntry.Type)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ server:
|
|||||||
address: ":8080"
|
address: ":8080"
|
||||||
providers:
|
providers:
|
||||||
azure:
|
azure:
|
||||||
type: azure_openai
|
type: azureopenai
|
||||||
api_key: azure-key
|
api_key: azure-key
|
||||||
endpoint: https://my-resource.openai.azure.com
|
endpoint: https://my-resource.openai.azure.com
|
||||||
api_version: "2024-02-15-preview"
|
api_version: "2024-02-15-preview"
|
||||||
@@ -113,7 +113,7 @@ models:
|
|||||||
provider_model_id: gpt-4-deployment
|
provider_model_id: gpt-4-deployment
|
||||||
`,
|
`,
|
||||||
validate: func(t *testing.T, cfg *Config) {
|
validate: func(t *testing.T, cfg *Config) {
|
||||||
assert.Equal(t, "azure_openai", cfg.Providers["azure"].Type)
|
assert.Equal(t, "azureopenai", cfg.Providers["azure"].Type)
|
||||||
assert.Equal(t, "azure-key", cfg.Providers["azure"].APIKey)
|
assert.Equal(t, "azure-key", cfg.Providers["azure"].APIKey)
|
||||||
assert.Equal(t, "https://my-resource.openai.azure.com", cfg.Providers["azure"].Endpoint)
|
assert.Equal(t, "https://my-resource.openai.azure.com", cfg.Providers["azure"].Endpoint)
|
||||||
assert.Equal(t, "2024-02-15-preview", cfg.Providers["azure"].APIVersion)
|
assert.Equal(t, "2024-02-15-preview", cfg.Providers["azure"].APIVersion)
|
||||||
@@ -126,7 +126,7 @@ server:
|
|||||||
address: ":8080"
|
address: ":8080"
|
||||||
providers:
|
providers:
|
||||||
vertex:
|
vertex:
|
||||||
type: vertex_ai
|
type: vertexai
|
||||||
project: my-gcp-project
|
project: my-gcp-project
|
||||||
location: us-central1
|
location: us-central1
|
||||||
models:
|
models:
|
||||||
@@ -135,7 +135,7 @@ models:
|
|||||||
provider_model_id: gemini-1.5-pro
|
provider_model_id: gemini-1.5-pro
|
||||||
`,
|
`,
|
||||||
validate: func(t *testing.T, cfg *Config) {
|
validate: func(t *testing.T, cfg *Config) {
|
||||||
assert.Equal(t, "vertex_ai", cfg.Providers["vertex"].Type)
|
assert.Equal(t, "vertexai", cfg.Providers["vertex"].Type)
|
||||||
assert.Equal(t, "my-gcp-project", cfg.Providers["vertex"].Project)
|
assert.Equal(t, "my-gcp-project", cfg.Providers["vertex"].Project)
|
||||||
assert.Equal(t, "us-central1", cfg.Providers["vertex"].Location)
|
assert.Equal(t, "us-central1", cfg.Providers["vertex"].Location)
|
||||||
},
|
},
|
||||||
@@ -208,6 +208,20 @@ models:
|
|||||||
configYAML: `invalid: yaml: content: [unclosed`,
|
configYAML: `invalid: yaml: content: [unclosed`,
|
||||||
expectError: true,
|
expectError: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "model references provider without required API key",
|
||||||
|
configYAML: `
|
||||||
|
server:
|
||||||
|
address: ":8080"
|
||||||
|
providers:
|
||||||
|
openai:
|
||||||
|
type: openai
|
||||||
|
models:
|
||||||
|
- name: gpt-4
|
||||||
|
provider: openai
|
||||||
|
`,
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "multiple models same provider",
|
name: "multiple models same provider",
|
||||||
configYAML: `
|
configYAML: `
|
||||||
@@ -283,7 +297,7 @@ func TestConfigValidate(t *testing.T) {
|
|||||||
name: "valid config",
|
name: "valid config",
|
||||||
config: Config{
|
config: Config{
|
||||||
Providers: map[string]ProviderEntry{
|
Providers: map[string]ProviderEntry{
|
||||||
"openai": {Type: "openai"},
|
"openai": {Type: "openai", APIKey: "test-key"},
|
||||||
},
|
},
|
||||||
Models: []ModelEntry{
|
Models: []ModelEntry{
|
||||||
{Name: "gpt-4", Provider: "openai"},
|
{Name: "gpt-4", Provider: "openai"},
|
||||||
@@ -295,7 +309,7 @@ func TestConfigValidate(t *testing.T) {
|
|||||||
name: "model references unknown provider",
|
name: "model references unknown provider",
|
||||||
config: Config{
|
config: Config{
|
||||||
Providers: map[string]ProviderEntry{
|
Providers: map[string]ProviderEntry{
|
||||||
"openai": {Type: "openai"},
|
"openai": {Type: "openai", APIKey: "test-key"},
|
||||||
},
|
},
|
||||||
Models: []ModelEntry{
|
Models: []ModelEntry{
|
||||||
{Name: "gpt-4", Provider: "unknown"},
|
{Name: "gpt-4", Provider: "unknown"},
|
||||||
@@ -303,6 +317,18 @@ func TestConfigValidate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectError: true,
|
expectError: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "model references provider without api key",
|
||||||
|
config: Config{
|
||||||
|
Providers: map[string]ProviderEntry{
|
||||||
|
"openai": {Type: "openai"},
|
||||||
|
},
|
||||||
|
Models: []ModelEntry{
|
||||||
|
{Name: "gpt-4", Provider: "openai"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "no models",
|
name: "no models",
|
||||||
config: Config{
|
config: Config{
|
||||||
@@ -317,8 +343,8 @@ func TestConfigValidate(t *testing.T) {
|
|||||||
name: "multiple models multiple providers",
|
name: "multiple models multiple providers",
|
||||||
config: Config{
|
config: Config{
|
||||||
Providers: map[string]ProviderEntry{
|
Providers: map[string]ProviderEntry{
|
||||||
"openai": {Type: "openai"},
|
"openai": {Type: "openai", APIKey: "test-key"},
|
||||||
"anthropic": {Type: "anthropic"},
|
"anthropic": {Type: "anthropic", APIKey: "ant-key"},
|
||||||
},
|
},
|
||||||
Models: []ModelEntry{
|
Models: []ModelEntry{
|
||||||
{Name: "gpt-4", Provider: "openai"},
|
{Name: "gpt-4", Provider: "openai"},
|
||||||
|
|||||||
@@ -48,15 +48,30 @@ type metricsResponseWriter struct {
|
|||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
statusCode int
|
statusCode int
|
||||||
bytesWritten int
|
bytesWritten int
|
||||||
|
wroteHeader bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *metricsResponseWriter) WriteHeader(statusCode int) {
|
func (w *metricsResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
if w.wroteHeader {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.wroteHeader = true
|
||||||
w.statusCode = statusCode
|
w.statusCode = statusCode
|
||||||
w.ResponseWriter.WriteHeader(statusCode)
|
w.ResponseWriter.WriteHeader(statusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *metricsResponseWriter) Write(b []byte) (int, error) {
|
func (w *metricsResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
if !w.wroteHeader {
|
||||||
|
w.wroteHeader = true
|
||||||
|
w.statusCode = http.StatusOK
|
||||||
|
}
|
||||||
n, err := w.ResponseWriter.Write(b)
|
n, err := w.ResponseWriter.Write(b)
|
||||||
w.bytesWritten += n
|
w.bytesWritten += n
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *metricsResponseWriter) Flush() {
|
||||||
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
65
internal/observability/middleware_response_writer_test.go
Normal file
65
internal/observability/middleware_response_writer_test.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package observability
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ http.Flusher = (*metricsResponseWriter)(nil)
|
||||||
|
var _ http.Flusher = (*statusResponseWriter)(nil)
|
||||||
|
|
||||||
|
type testFlusherRecorder struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
flushCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestFlusherRecorder() *testFlusherRecorder {
|
||||||
|
return &testFlusherRecorder{ResponseRecorder: httptest.NewRecorder()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *testFlusherRecorder) Flush() {
|
||||||
|
r.flushCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsResponseWriterWriteHeaderOnlyOnce(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rw := &metricsResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||||
|
|
||||||
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
|
rw.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusAccepted, rec.Code)
|
||||||
|
assert.Equal(t, http.StatusAccepted, rw.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsResponseWriterFlushDelegates(t *testing.T) {
|
||||||
|
rec := newTestFlusherRecorder()
|
||||||
|
rw := &metricsResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||||
|
|
||||||
|
rw.Flush()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, rec.flushCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatusResponseWriterWriteHeaderOnlyOnce(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rw := &statusResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||||
|
|
||||||
|
rw.WriteHeader(http.StatusNoContent)
|
||||||
|
rw.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNoContent, rec.Code)
|
||||||
|
assert.Equal(t, http.StatusNoContent, rw.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatusResponseWriterFlushDelegates(t *testing.T) {
|
||||||
|
rec := newTestFlusherRecorder()
|
||||||
|
rw := &statusResponseWriter{ResponseWriter: rec, statusCode: http.StatusOK}
|
||||||
|
|
||||||
|
rw.Flush()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, rec.flushCount)
|
||||||
|
}
|
||||||
@@ -72,14 +72,29 @@ func TracingMiddleware(next http.Handler, tp *sdktrace.TracerProvider) http.Hand
|
|||||||
// statusResponseWriter wraps http.ResponseWriter to capture the status code.
|
// statusResponseWriter wraps http.ResponseWriter to capture the status code.
|
||||||
type statusResponseWriter struct {
|
type statusResponseWriter struct {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
statusCode int
|
statusCode int
|
||||||
|
wroteHeader bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *statusResponseWriter) WriteHeader(statusCode int) {
|
func (w *statusResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
if w.wroteHeader {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.wroteHeader = true
|
||||||
w.statusCode = statusCode
|
w.statusCode = statusCode
|
||||||
w.ResponseWriter.WriteHeader(statusCode)
|
w.ResponseWriter.WriteHeader(statusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *statusResponseWriter) Write(b []byte) (int, error) {
|
func (w *statusResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
if !w.wroteHeader {
|
||||||
|
w.wroteHeader = true
|
||||||
|
w.statusCode = http.StatusOK
|
||||||
|
}
|
||||||
return w.ResponseWriter.Write(b)
|
return w.ResponseWriter.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *statusResponseWriter) Flush() {
|
||||||
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -136,6 +136,9 @@ func (r *Registry) Get(name string) (Provider, bool) {
|
|||||||
func (r *Registry) Models() []struct{ Provider, Model string } {
|
func (r *Registry) Models() []struct{ Provider, Model string } {
|
||||||
var out []struct{ Provider, Model string }
|
var out []struct{ Provider, Model string }
|
||||||
for _, m := range r.modelList {
|
for _, m := range r.modelList {
|
||||||
|
if _, ok := r.providers[m.Provider]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, struct{ Provider, Model string }{Provider: m.Provider, Model: m.Name})
|
out = append(out, struct{ Provider, Model string }{Provider: m.Provider, Model: m.Name})
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
@@ -156,7 +159,9 @@ func (r *Registry) Default(model string) (Provider, error) {
|
|||||||
if p, ok := r.providers[providerName]; ok {
|
if p, ok := r.providers[providerName]; ok {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
return nil, fmt.Errorf("model %q is mapped to provider %q, but that provider is not available", model, providerName)
|
||||||
}
|
}
|
||||||
|
return nil, fmt.Errorf("model %q not configured", model)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range r.providers {
|
for _, p := range r.providers {
|
||||||
|
|||||||
@@ -475,7 +475,7 @@ func TestRegistry_Default(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "returns first provider for unknown model",
|
name: "returns error for unknown model",
|
||||||
setupReg: func() *Registry {
|
setupReg: func() *Registry {
|
||||||
reg, _ := NewRegistry(
|
reg, _ := NewRegistry(
|
||||||
map[string]config.ProviderEntry{
|
map[string]config.ProviderEntry{
|
||||||
@@ -490,11 +490,34 @@ func TestRegistry_Default(t *testing.T) {
|
|||||||
)
|
)
|
||||||
return reg
|
return reg
|
||||||
},
|
},
|
||||||
modelName: "unknown-model",
|
modelName: "unknown-model",
|
||||||
validate: func(t *testing.T, p Provider) {
|
expectError: true,
|
||||||
assert.NotNil(t, p)
|
errorMsg: "not configured",
|
||||||
// Should return first available provider
|
},
|
||||||
|
{
|
||||||
|
name: "returns error for model whose provider is unavailable",
|
||||||
|
setupReg: func() *Registry {
|
||||||
|
reg, _ := NewRegistry(
|
||||||
|
map[string]config.ProviderEntry{
|
||||||
|
"openai": {
|
||||||
|
Type: "openai",
|
||||||
|
APIKey: "", // unavailable provider
|
||||||
|
},
|
||||||
|
"google": {
|
||||||
|
Type: "google",
|
||||||
|
APIKey: "test-key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
[]config.ModelEntry{
|
||||||
|
{Name: "gpt-4", Provider: "openai"},
|
||||||
|
{Name: "gemini-pro", Provider: "google"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return reg
|
||||||
},
|
},
|
||||||
|
modelName: "gpt-4",
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "not available",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "returns first provider for empty model name",
|
name: "returns first provider for empty model name",
|
||||||
@@ -542,6 +565,31 @@ func TestRegistry_Default(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRegistry_Models_FiltersUnavailableProviders(t *testing.T) {
|
||||||
|
reg, err := NewRegistry(
|
||||||
|
map[string]config.ProviderEntry{
|
||||||
|
"openai": {
|
||||||
|
Type: "openai",
|
||||||
|
APIKey: "", // unavailable provider
|
||||||
|
},
|
||||||
|
"google": {
|
||||||
|
Type: "google",
|
||||||
|
APIKey: "test-key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
[]config.ModelEntry{
|
||||||
|
{Name: "gpt-4", Provider: "openai"},
|
||||||
|
{Name: "gemini-pro", Provider: "google"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
models := reg.Models()
|
||||||
|
require.Len(t, models, 1)
|
||||||
|
assert.Equal(t, "gemini-pro", models[0].Model)
|
||||||
|
assert.Equal(t, "google", models[0].Provider)
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildProvider(t *testing.T) {
|
func TestBuildProvider(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -239,17 +239,17 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) {
|
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
w.Header().Set("Connection", "keep-alive")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
|
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
responseID := generateID("resp_")
|
responseID := generateID("resp_")
|
||||||
itemID := generateID("msg_")
|
itemID := generateID("msg_")
|
||||||
seq := 0
|
seq := 0
|
||||||
|
|||||||
53
internal/server/streaming_writer_test.go
Normal file
53
internal/server/streaming_writer_test.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type nonFlusherRecorder struct {
|
||||||
|
recorder *httptest.ResponseRecorder
|
||||||
|
writeHeaderCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNonFlusherRecorder() *nonFlusherRecorder {
|
||||||
|
return &nonFlusherRecorder{recorder: httptest.NewRecorder()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *nonFlusherRecorder) Header() http.Header {
|
||||||
|
return w.recorder.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *nonFlusherRecorder) Write(b []byte) (int, error) {
|
||||||
|
return w.recorder.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *nonFlusherRecorder) WriteHeader(statusCode int) {
|
||||||
|
w.writeHeaderCalls++
|
||||||
|
w.recorder.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *nonFlusherRecorder) StatusCode() int {
|
||||||
|
return w.recorder.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *nonFlusherRecorder) BodyString() string {
|
||||||
|
return w.recorder.Body.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleStreamingResponseWithoutFlusherWritesSingleErrorHeader(t *testing.T) {
|
||||||
|
s := New(nil, nil, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
w := newNonFlusherRecorder()
|
||||||
|
|
||||||
|
s.handleStreamingResponse(w, req, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, w.writeHeaderCalls)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, w.StatusCode())
|
||||||
|
assert.Contains(t, w.BodyString(), "streaming not supported")
|
||||||
|
}
|
||||||
BIN
scripts/__pycache__/chat.cpython-312.pyc
Normal file
BIN
scripts/__pycache__/chat.cpython-312.pyc
Normal file
Binary file not shown.
@@ -136,6 +136,41 @@ class ChatClient:
|
|||||||
else:
|
else:
|
||||||
return self._sync_response(model)
|
return self._sync_response(model)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_attr(obj: Any, key: str, default: Any = None) -> Any:
|
||||||
|
"""Access object attributes safely for both SDK objects and dicts."""
|
||||||
|
if obj is None:
|
||||||
|
return default
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return obj.get(key, default)
|
||||||
|
return getattr(obj, key, default)
|
||||||
|
|
||||||
|
def _extract_stream_error(self, event: Any) -> str:
|
||||||
|
"""Extract error message from a response.failed event."""
|
||||||
|
response = self._get_attr(event, "response")
|
||||||
|
error = self._get_attr(response, "error")
|
||||||
|
message = self._get_attr(error, "message")
|
||||||
|
if message:
|
||||||
|
return str(message)
|
||||||
|
return "streaming request failed"
|
||||||
|
|
||||||
|
def _extract_completed_text(self, event: Any) -> str:
|
||||||
|
"""Extract assistant output text from a response.completed event."""
|
||||||
|
response = self._get_attr(event, "response")
|
||||||
|
output_items = self._get_attr(response, "output", []) or []
|
||||||
|
|
||||||
|
text_parts = []
|
||||||
|
for item in output_items:
|
||||||
|
if self._get_attr(item, "type") != "message":
|
||||||
|
continue
|
||||||
|
for part in self._get_attr(item, "content", []) or []:
|
||||||
|
if self._get_attr(part, "type") == "output_text":
|
||||||
|
text = self._get_attr(part, "text", "")
|
||||||
|
if text:
|
||||||
|
text_parts.append(str(text))
|
||||||
|
|
||||||
|
return "".join(text_parts)
|
||||||
|
|
||||||
def _sync_response(self, model: str) -> str:
|
def _sync_response(self, model: str) -> str:
|
||||||
"""Non-streaming response with tool support."""
|
"""Non-streaming response with tool support."""
|
||||||
max_iterations = 10 # Prevent infinite loops
|
max_iterations = 10 # Prevent infinite loops
|
||||||
@@ -225,6 +260,7 @@ class ChatClient:
|
|||||||
while iteration < max_iterations:
|
while iteration < max_iterations:
|
||||||
iteration += 1
|
iteration += 1
|
||||||
assistant_text = ""
|
assistant_text = ""
|
||||||
|
stream_error = None
|
||||||
tool_calls = {} # Dict to track tool calls by item_id
|
tool_calls = {} # Dict to track tool calls by item_id
|
||||||
tool_calls_list = [] # Final list of completed tool calls
|
tool_calls_list = [] # Final list of completed tool calls
|
||||||
assistant_content = []
|
assistant_content = []
|
||||||
@@ -244,6 +280,15 @@ class ChatClient:
|
|||||||
if event.type == "response.output_text.delta":
|
if event.type == "response.output_text.delta":
|
||||||
assistant_text += event.delta
|
assistant_text += event.delta
|
||||||
live.update(Markdown(assistant_text))
|
live.update(Markdown(assistant_text))
|
||||||
|
elif event.type == "response.completed":
|
||||||
|
# Some providers may emit final text only in response.completed.
|
||||||
|
if not assistant_text:
|
||||||
|
completed_text = self._extract_completed_text(event)
|
||||||
|
if completed_text:
|
||||||
|
assistant_text = completed_text
|
||||||
|
live.update(Markdown(assistant_text))
|
||||||
|
elif event.type == "response.failed":
|
||||||
|
stream_error = self._extract_stream_error(event)
|
||||||
elif event.type == "response.output_item.added":
|
elif event.type == "response.output_item.added":
|
||||||
if hasattr(event, 'item') and event.item.type == "function_call":
|
if hasattr(event, 'item') and event.item.type == "function_call":
|
||||||
# Start tracking a new tool call
|
# Start tracking a new tool call
|
||||||
@@ -270,6 +315,10 @@ class ChatClient:
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
self.console.print(f"[red]Error parsing tool arguments JSON[/red]")
|
self.console.print(f"[red]Error parsing tool arguments JSON[/red]")
|
||||||
|
|
||||||
|
if stream_error:
|
||||||
|
self.console.print(f"[bold red]Error:[/bold red] {stream_error}")
|
||||||
|
return ""
|
||||||
|
|
||||||
# Build assistant content
|
# Build assistant content
|
||||||
if assistant_text:
|
if assistant_text:
|
||||||
assistant_content.append({"type": "output_text", "text": assistant_text})
|
assistant_content.append({"type": "output_text", "text": assistant_text})
|
||||||
|
|||||||
Reference in New Issue
Block a user