From 27e68f8e4c12cdd6a33600ede79782de91e38f16 Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Sun, 1 Mar 2026 19:21:58 +0000 Subject: [PATCH] Update config structure --- README.md | 10 ++- cmd/gateway/main.go | 2 +- config.example-with-auth.yaml | 19 ----- config.example.yaml | 29 +++++-- internal/api/types.go | 12 +++ internal/config/config.go | 102 ++++++++++-------------- internal/providers/openai/openai.go | 1 - internal/providers/providers.go | 119 +++++++++++++++++++++------- internal/server/server.go | 28 +++++++ scripts/chat.py | 31 ++++---- 10 files changed, 220 insertions(+), 133 deletions(-) delete mode 100644 config.example-with-auth.yaml diff --git a/README.md b/README.md index d2f5113..bc87426 100644 --- a/README.md +++ b/README.md @@ -212,20 +212,24 @@ The gateway supports Azure OpenAI with the same interface as standard OpenAI: ```yaml providers: azureopenai: + type: "azureopenai" api_key: "${AZURE_OPENAI_API_KEY}" endpoint: "https://your-resource.openai.azure.com" - deployment_id: "your-deployment-name" + +models: + - name: "gpt-4o" + provider: "azureopenai" + provider_model_id: "my-gpt4o-deployment" # optional: defaults to name ``` ```bash export AZURE_OPENAI_API_KEY="..." export AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com" -export AZURE_OPENAI_DEPLOYMENT_ID="gpt-4o" ./gateway ``` -The gateway prefers Azure OpenAI for `gpt-*` models if configured. See **[AZURE_OPENAI.md](./AZURE_OPENAI.md)** for complete setup guide. +The `provider_model_id` field lets you map a friendly model name to the actual provider identifier (e.g., an Azure deployment name). If omitted, the model `name` is used directly. See **[AZURE_OPENAI.md](./AZURE_OPENAI.md)** for complete setup guide. ## Authentication diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 97d5260..3f37124 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -24,7 +24,7 @@ func main() { log.Fatalf("load config: %v", err) } - registry, err := providers.NewRegistry(cfg.Providers) + registry, err := providers.NewRegistry(cfg.Providers, cfg.Models) if err != nil { log.Fatalf("init providers: %v", err) } diff --git a/config.example-with-auth.yaml b/config.example-with-auth.yaml deleted file mode 100644 index fd79d59..0000000 --- a/config.example-with-auth.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# Example configuration with Google OAuth2 authentication - -auth: - enabled: true - issuer: "https://accounts.google.com" - audience: "YOUR-CLIENT-ID.apps.googleusercontent.com" - -providers: - openai: - api_key: "${OPENAI_API_KEY}" - model: "gpt-4o-mini" - - anthropic: - api_key: "${ANTHROPIC_API_KEY}" - model: "claude-3-5-sonnet-20241022" - - google: - api_key: "${GOOGLE_API_KEY}" - model: "gemini-2.0-flash-exp" diff --git a/config.example.yaml b/config.example.yaml index f16ae20..5a12e1e 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -3,19 +3,38 @@ server: providers: google: + type: "google" api_key: "YOUR_GOOGLE_API_KEY" - model: "gemini-1.5-flash" endpoint: "https://generativelanguage.googleapis.com" anthropic: + type: "anthropic" api_key: "YOUR_ANTHROPIC_API_KEY" - model: "claude-3-5-sonnet" endpoint: "https://api.anthropic.com" openai: + type: "openai" api_key: "YOUR_OPENAI_API_KEY" - model: "gpt-4o-mini" endpoint: "https://api.openai.com" - # Azure-hosted Anthropic (Microsoft Foundry) - optional, overrides anthropic if set + # Azure OpenAI - optional + # azureopenai: + # type: "azureopenai" + # api_key: "YOUR_AZURE_OPENAI_API_KEY" + # endpoint: "https://your-resource.openai.azure.com" + # api_version: "2024-12-01-preview" + # Azure-hosted Anthropic (Microsoft Foundry) - optional # azureanthropic: + # type: "azureanthropic" # api_key: "YOUR_AZURE_ANTHROPIC_API_KEY" # endpoint: "https://your-resource.services.ai.azure.com/anthropic" - # model: "claude-sonnet-4-5-20250514" + +models: + - name: "gemini-1.5-flash" + provider: "google" + - name: "claude-3-5-sonnet" + provider: "anthropic" + - name: "gpt-4o-mini" + provider: "openai" + # - name: "gpt-4o" + # provider: "azureopenai" + # provider_model_id: "my-gpt4o-deployment" # optional: defaults to name + # - name: "claude-sonnet-4-5-20250514" + # provider: "azureanthropic" diff --git a/internal/api/types.go b/internal/api/types.go index 5b7c94b..a96ddd6 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -64,6 +64,18 @@ type StreamDelta struct { Content []ContentBlock `json:"content,omitempty"` } +// ModelInfo describes a single model available through the gateway. +type ModelInfo struct { + ID string `json:"id"` + Provider string `json:"provider"` +} + +// ModelsResponse is returned by the GET /v1/models endpoint. +type ModelsResponse struct { + Object string `json:"object"` + Data []ModelInfo `json:"data"` +} + // Validate performs basic structural validation. func (r *ResponseRequest) Validate() error { if r == nil { diff --git a/internal/config/config.go b/internal/config/config.go index 470f6eb..2d30632 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,9 +9,10 @@ import ( // Config describes the full gateway configuration file. type Config struct { - Server ServerConfig `yaml:"server"` - Providers ProvidersConfig `yaml:"providers"` - Auth AuthConfig `yaml:"auth"` + Server ServerConfig `yaml:"server"` + Providers map[string]ProviderEntry `yaml:"providers"` + Models []ModelEntry `yaml:"models"` + Auth AuthConfig `yaml:"auth"` } // AuthConfig holds OIDC authentication settings. @@ -26,89 +27,68 @@ type ServerConfig struct { Address string `yaml:"address"` } -// ProvidersConfig wraps supported provider settings. -type ProvidersConfig struct { - Google ProviderConfig `yaml:"google"` - Anthropic ProviderConfig `yaml:"anthropic"` - OpenAI ProviderConfig `yaml:"openai"` - AzureOpenAI AzureOpenAIConfig `yaml:"azureopenai"` - AzureAnthropic AzureAnthropicConfig `yaml:"azureanthropic"` +// ProviderEntry defines a named provider instance in the config file. +type ProviderEntry struct { + Type string `yaml:"type"` + APIKey string `yaml:"api_key"` + Endpoint string `yaml:"endpoint"` + APIVersion string `yaml:"api_version"` } -// AzureAnthropicConfig contains Azure-specific settings for Anthropic (Microsoft Foundry). -type AzureAnthropicConfig struct { - APIKey string `yaml:"api_key"` - Endpoint string `yaml:"endpoint"` - Model string `yaml:"model"` +// ModelEntry maps a model name to a provider entry. +type ModelEntry struct { + Name string `yaml:"name"` + Provider string `yaml:"provider"` + ProviderModelID string `yaml:"provider_model_id"` } -// ProviderConfig contains shared provider configuration fields. +// ProviderConfig contains shared provider configuration fields used internally by providers. type ProviderConfig struct { APIKey string `yaml:"api_key"` Model string `yaml:"model"` Endpoint string `yaml:"endpoint"` } -// AzureOpenAIConfig contains Azure-specific settings. +// AzureOpenAIConfig contains Azure-specific settings used internally by the OpenAI provider. type AzureOpenAIConfig struct { - APIKey string `yaml:"api_key"` - Endpoint string `yaml:"endpoint"` - DeploymentID string `yaml:"deployment_id"` - APIVersion string `yaml:"api_version"` + APIKey string `yaml:"api_key"` + Endpoint string `yaml:"endpoint"` + APIVersion string `yaml:"api_version"` } -// Load reads and parses a YAML configuration file and applies env overrides. +// AzureAnthropicConfig contains Azure-specific settings for Anthropic used internally. +type AzureAnthropicConfig struct { + APIKey string `yaml:"api_key"` + Endpoint string `yaml:"endpoint"` + Model string `yaml:"model"` +} + +// Load reads and parses a YAML configuration file, expanding ${VAR} env references. func Load(path string) (*Config, error) { data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read config: %w", err) } + expanded := os.Expand(string(data), os.Getenv) + var cfg Config - if err := yaml.Unmarshal(data, &cfg); err != nil { + if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil { return nil, fmt.Errorf("parse config: %w", err) } - cfg.applyEnvOverrides() + if err := cfg.validate(); err != nil { + return nil, err + } + return &cfg, nil } -func (cfg *Config) applyEnvOverrides() { - overrideAPIKey(&cfg.Providers.Google, "GOOGLE_API_KEY") - overrideAPIKey(&cfg.Providers.Anthropic, "ANTHROPIC_API_KEY") - overrideAPIKey(&cfg.Providers.OpenAI, "OPENAI_API_KEY") - - // Azure OpenAI overrides - if v := os.Getenv("AZURE_OPENAI_API_KEY"); v != "" { - cfg.Providers.AzureOpenAI.APIKey = v - } - if v := os.Getenv("AZURE_OPENAI_ENDPOINT"); v != "" { - cfg.Providers.AzureOpenAI.Endpoint = v - } - if v := os.Getenv("AZURE_OPENAI_DEPLOYMENT_ID"); v != "" { - cfg.Providers.AzureOpenAI.DeploymentID = v - } - if v := os.Getenv("AZURE_OPENAI_API_VERSION"); v != "" { - cfg.Providers.AzureOpenAI.APIVersion = v - } - - // Azure Anthropic (Microsoft Foundry) overrides - if v := os.Getenv("AZURE_ANTHROPIC_API_KEY"); v != "" { - cfg.Providers.AzureAnthropic.APIKey = v - } - if v := os.Getenv("AZURE_ANTHROPIC_ENDPOINT"); v != "" { - cfg.Providers.AzureAnthropic.Endpoint = v - } - if v := os.Getenv("AZURE_ANTHROPIC_MODEL"); v != "" { - cfg.Providers.AzureAnthropic.Model = v - } -} - -func overrideAPIKey(cfg *ProviderConfig, envKey string) { - if cfg == nil { - return - } - if v := os.Getenv(envKey); v != "" { - cfg.APIKey = v +func (cfg *Config) validate() error { + for _, m := range cfg.Models { + if _, ok := cfg.Providers[m.Provider]; !ok { + return fmt.Errorf("model %q references unknown provider %q", m.Name, m.Provider) + } } + return nil } diff --git a/internal/providers/openai/openai.go b/internal/providers/openai/openai.go index ac55a01..ac9bc03 100644 --- a/internal/providers/openai/openai.go +++ b/internal/providers/openai/openai.go @@ -55,7 +55,6 @@ func NewAzure(azureCfg config.AzureOpenAIConfig) *Provider { return &Provider{ cfg: config.ProviderConfig{ APIKey: azureCfg.APIKey, - Model: azureCfg.DeploymentID, }, client: client, azure: true, diff --git a/internal/providers/providers.go b/internal/providers/providers.go index 3cf2a68..f32c79f 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -3,7 +3,6 @@ package providers import ( "context" "fmt" - "strings" "github.com/yourusername/go-llm-gateway/internal/api" "github.com/yourusername/go-llm-gateway/internal/config" @@ -19,27 +18,38 @@ type Provider interface { GenerateStream(ctx context.Context, req *api.ResponseRequest) (<-chan *api.StreamChunk, <-chan error) } -// Registry keeps track of registered providers by key (e.g. "openai"). +// Registry keeps track of registered providers and model-to-provider mappings. type Registry struct { - providers map[string]Provider + providers map[string]Provider + models map[string]string // model name -> provider entry name + providerModelIDs map[string]string // model name -> provider model ID + modelList []config.ModelEntry } // NewRegistry constructs provider implementations from configuration. -func NewRegistry(cfg config.ProvidersConfig) (*Registry, error) { - reg := &Registry{providers: make(map[string]Provider)} +func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelEntry) (*Registry, error) { + reg := &Registry{ + providers: make(map[string]Provider), + models: make(map[string]string), + providerModelIDs: make(map[string]string), + modelList: models, + } - if cfg.Google.APIKey != "" { - reg.providers[googleprovider.Name] = googleprovider.New(cfg.Google) + for name, entry := range entries { + p, err := buildProvider(entry) + if err != nil { + return nil, fmt.Errorf("provider %q: %w", name, err) + } + if p != nil { + reg.providers[name] = p + } } - if cfg.AzureAnthropic.APIKey != "" && cfg.AzureAnthropic.Endpoint != "" { - reg.providers[anthropicprovider.Name] = anthropicprovider.NewAzure(cfg.AzureAnthropic) - } else if cfg.Anthropic.APIKey != "" { - reg.providers[anthropicprovider.Name] = anthropicprovider.New(cfg.Anthropic) - } - if cfg.AzureOpenAI.APIKey != "" && cfg.AzureOpenAI.Endpoint != "" { - reg.providers[openaiprovider.Name] = openaiprovider.NewAzure(cfg.AzureOpenAI) - } else if cfg.OpenAI.APIKey != "" { - reg.providers[openaiprovider.Name] = openaiprovider.New(cfg.OpenAI) + + for _, m := range models { + reg.models[m.Name] = m.Provider + if m.ProviderModelID != "" { + reg.providerModelIDs[m.Name] = m.ProviderModelID + } } if len(reg.providers) == 0 { @@ -49,26 +59,77 @@ func NewRegistry(cfg config.ProvidersConfig) (*Registry, error) { return reg, nil } -// Get returns provider by key. +func buildProvider(entry config.ProviderEntry) (Provider, error) { + if entry.APIKey == "" { + return nil, nil + } + + switch entry.Type { + case "openai": + return openaiprovider.New(config.ProviderConfig{ + APIKey: entry.APIKey, + Endpoint: entry.Endpoint, + }), nil + case "azureopenai": + if entry.Endpoint == "" { + return nil, fmt.Errorf("endpoint is required for azureopenai") + } + return openaiprovider.NewAzure(config.AzureOpenAIConfig{ + APIKey: entry.APIKey, + Endpoint: entry.Endpoint, + APIVersion: entry.APIVersion, + }), nil + case "anthropic": + return anthropicprovider.New(config.ProviderConfig{ + APIKey: entry.APIKey, + Endpoint: entry.Endpoint, + }), nil + case "azureanthropic": + if entry.Endpoint == "" { + return nil, fmt.Errorf("endpoint is required for azureanthropic") + } + return anthropicprovider.NewAzure(config.AzureAnthropicConfig{ + APIKey: entry.APIKey, + Endpoint: entry.Endpoint, + }), nil + case "google": + return googleprovider.New(config.ProviderConfig{ + APIKey: entry.APIKey, + Endpoint: entry.Endpoint, + }), nil + default: + return nil, fmt.Errorf("unknown provider type %q", entry.Type) + } +} + +// Get returns provider by entry name. func (r *Registry) Get(name string) (Provider, bool) { p, ok := r.providers[name] return p, ok } -// Default returns provider based on inferred name. +// Models returns the list of configured models and their provider entry names. +func (r *Registry) Models() []struct{ Provider, Model string } { + var out []struct{ Provider, Model string } + for _, m := range r.modelList { + out = append(out, struct{ Provider, Model string }{Provider: m.Provider, Model: m.Name}) + } + return out +} + +// ResolveModelID returns the provider_model_id for a model, falling back to the model name itself. +func (r *Registry) ResolveModelID(model string) string { + if id, ok := r.providerModelIDs[model]; ok { + return id + } + return model +} + +// Default returns the provider for the given model name. func (r *Registry) Default(model string) (Provider, error) { if model != "" { - switch { - case strings.HasPrefix(model, "gpt") || strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3"): - if p, ok := r.providers[openaiprovider.Name]; ok { - return p, nil - } - case strings.HasPrefix(model, "claude"): - if p, ok := r.providers[anthropicprovider.Name]; ok { - return p, nil - } - case strings.HasPrefix(model, "gemini"): - if p, ok := r.providers[googleprovider.Name]; ok { + if providerName, ok := r.models[model]; ok { + if p, ok := r.providers[providerName]; ok { return p, nil } } diff --git a/internal/server/server.go b/internal/server/server.go index fce079f..aada7d0 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -30,6 +30,31 @@ func New(registry *providers.Registry, convs *conversation.Store, logger *log.Lo // RegisterRoutes wires the HTTP handlers onto the provided mux. func (s *GatewayServer) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("/v1/responses", s.handleResponses) + mux.HandleFunc("/v1/models", s.handleModels) +} + +func (s *GatewayServer) handleModels(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + models := s.registry.Models() + var data []api.ModelInfo + for _, m := range models { + data = append(data, api.ModelInfo{ + ID: m.Model, + Provider: m.Provider, + }) + } + + resp := api.ModelsResponse{ + Object: "list", + Data: data, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) } func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) { @@ -66,6 +91,9 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) return } + // Resolve provider_model_id (e.g., Azure deployment name) before sending to provider + fullReq.Model = s.registry.ResolveModelID(req.Model) + // Handle streaming vs non-streaming if req.Stream { s.handleStreamingResponse(w, r, provider, &fullReq, &req) diff --git a/scripts/chat.py b/scripts/chat.py index c647e19..a1d5245 100755 --- a/scripts/chat.py +++ b/scripts/chat.py @@ -139,22 +139,25 @@ class ChatClient: self.messages = [] -def print_models_table(): - """Print available models table.""" +def print_models_table(base_url: str, headers: dict): + """Fetch and print available models from the gateway.""" + console = Console() + try: + resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=10) + resp.raise_for_status() + data = resp.json().get("data", []) + except Exception as e: + console.print(f"[red]Failed to fetch models: {e}[/red]") + return + table = Table(title="Available Models", show_header=True, header_style="bold magenta") table.add_column("Provider", style="cyan") table.add_column("Model ID", style="green") - table.add_column("Alias", style="yellow") - - table.add_row("OpenAI", "gpt-4o", "gpt4") - table.add_row("OpenAI", "gpt-4o-mini", "gpt4-mini") - table.add_row("OpenAI", "o1", "o1") - table.add_row("Anthropic", "claude-3-5-sonnet-20241022", "claude") - table.add_row("Anthropic", "claude-3-5-haiku-20241022", "haiku") - table.add_row("Google", "gemini-2.0-flash-exp", "gemini") - table.add_row("Google", "gemini-1.5-pro", "gemini-pro") - - Console().print(table) + + for model in data: + table.add_row(model.get("provider", ""), model.get("id", "")) + + console.print(table) def main(): @@ -227,7 +230,7 @@ def main(): )) elif cmd == "/models": - print_models_table() + print_models_table(args.url, client._headers()) elif cmd == "/model": if len(cmd_parts) < 2: