Update config structure
This commit is contained in:
10
README.md
10
README.md
@@ -212,20 +212,24 @@ The gateway supports Azure OpenAI with the same interface as standard OpenAI:
|
||||
```yaml
|
||||
providers:
|
||||
azureopenai:
|
||||
type: "azureopenai"
|
||||
api_key: "${AZURE_OPENAI_API_KEY}"
|
||||
endpoint: "https://your-resource.openai.azure.com"
|
||||
deployment_id: "your-deployment-name"
|
||||
|
||||
models:
|
||||
- name: "gpt-4o"
|
||||
provider: "azureopenai"
|
||||
provider_model_id: "my-gpt4o-deployment" # optional: defaults to name
|
||||
```
|
||||
|
||||
```bash
|
||||
export AZURE_OPENAI_API_KEY="..."
|
||||
export AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com"
|
||||
export AZURE_OPENAI_DEPLOYMENT_ID="gpt-4o"
|
||||
|
||||
./gateway
|
||||
```
|
||||
|
||||
The gateway prefers Azure OpenAI for `gpt-*` models if configured. See **[AZURE_OPENAI.md](./AZURE_OPENAI.md)** for the complete setup guide.
|
||||
The `provider_model_id` field lets you map a friendly model name to the actual provider identifier (e.g., an Azure deployment name). If omitted, the model `name` is used directly. See **[AZURE_OPENAI.md](./AZURE_OPENAI.md)** for the complete setup guide.
|
||||
|
||||
## Authentication
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ func main() {
|
||||
log.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
registry, err := providers.NewRegistry(cfg.Providers)
|
||||
registry, err := providers.NewRegistry(cfg.Providers, cfg.Models)
|
||||
if err != nil {
|
||||
log.Fatalf("init providers: %v", err)
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Example configuration with Google OAuth2 authentication
|
||||
|
||||
auth:
|
||||
enabled: true
|
||||
issuer: "https://accounts.google.com"
|
||||
audience: "YOUR-CLIENT-ID.apps.googleusercontent.com"
|
||||
|
||||
providers:
|
||||
openai:
|
||||
api_key: "${OPENAI_API_KEY}"
|
||||
model: "gpt-4o-mini"
|
||||
|
||||
anthropic:
|
||||
api_key: "${ANTHROPIC_API_KEY}"
|
||||
model: "claude-3-5-sonnet-20241022"
|
||||
|
||||
google:
|
||||
api_key: "${GOOGLE_API_KEY}"
|
||||
model: "gemini-2.0-flash-exp"
|
||||
@@ -3,19 +3,38 @@ server:
|
||||
|
||||
providers:
|
||||
google:
|
||||
type: "google"
|
||||
api_key: "YOUR_GOOGLE_API_KEY"
|
||||
model: "gemini-1.5-flash"
|
||||
endpoint: "https://generativelanguage.googleapis.com"
|
||||
anthropic:
|
||||
type: "anthropic"
|
||||
api_key: "YOUR_ANTHROPIC_API_KEY"
|
||||
model: "claude-3-5-sonnet"
|
||||
endpoint: "https://api.anthropic.com"
|
||||
openai:
|
||||
type: "openai"
|
||||
api_key: "YOUR_OPENAI_API_KEY"
|
||||
model: "gpt-4o-mini"
|
||||
endpoint: "https://api.openai.com"
|
||||
# Azure-hosted Anthropic (Microsoft Foundry) - optional, overrides anthropic if set
|
||||
# Azure OpenAI - optional
|
||||
# azureopenai:
|
||||
# type: "azureopenai"
|
||||
# api_key: "YOUR_AZURE_OPENAI_API_KEY"
|
||||
# endpoint: "https://your-resource.openai.azure.com"
|
||||
# api_version: "2024-12-01-preview"
|
||||
# Azure-hosted Anthropic (Microsoft Foundry) - optional
|
||||
# azureanthropic:
|
||||
# type: "azureanthropic"
|
||||
# api_key: "YOUR_AZURE_ANTHROPIC_API_KEY"
|
||||
# endpoint: "https://your-resource.services.ai.azure.com/anthropic"
|
||||
# model: "claude-sonnet-4-5-20250514"
|
||||
|
||||
models:
|
||||
- name: "gemini-1.5-flash"
|
||||
provider: "google"
|
||||
- name: "claude-3-5-sonnet"
|
||||
provider: "anthropic"
|
||||
- name: "gpt-4o-mini"
|
||||
provider: "openai"
|
||||
# - name: "gpt-4o"
|
||||
# provider: "azureopenai"
|
||||
# provider_model_id: "my-gpt4o-deployment" # optional: defaults to name
|
||||
# - name: "claude-sonnet-4-5-20250514"
|
||||
# provider: "azureanthropic"
|
||||
|
||||
@@ -64,6 +64,18 @@ type StreamDelta struct {
|
||||
Content []ContentBlock `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
// ModelInfo describes a single model available through the gateway.
|
||||
type ModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
Provider string `json:"provider"`
|
||||
}
|
||||
|
||||
// ModelsResponse is returned by the GET /v1/models endpoint.
|
||||
type ModelsResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []ModelInfo `json:"data"`
|
||||
}
|
||||
|
||||
// Validate performs basic structural validation.
|
||||
func (r *ResponseRequest) Validate() error {
|
||||
if r == nil {
|
||||
|
||||
@@ -9,9 +9,10 @@ import (
|
||||
|
||||
// Config describes the full gateway configuration file.
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Providers ProvidersConfig `yaml:"providers"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Providers map[string]ProviderEntry `yaml:"providers"`
|
||||
Models []ModelEntry `yaml:"models"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
}
|
||||
|
||||
// AuthConfig holds OIDC authentication settings.
|
||||
@@ -26,89 +27,68 @@ type ServerConfig struct {
|
||||
Address string `yaml:"address"`
|
||||
}
|
||||
|
||||
// ProvidersConfig wraps supported provider settings.
|
||||
type ProvidersConfig struct {
|
||||
Google ProviderConfig `yaml:"google"`
|
||||
Anthropic ProviderConfig `yaml:"anthropic"`
|
||||
OpenAI ProviderConfig `yaml:"openai"`
|
||||
AzureOpenAI AzureOpenAIConfig `yaml:"azureopenai"`
|
||||
AzureAnthropic AzureAnthropicConfig `yaml:"azureanthropic"`
|
||||
// ProviderEntry defines a named provider instance in the config file.
|
||||
type ProviderEntry struct {
|
||||
Type string `yaml:"type"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
APIVersion string `yaml:"api_version"`
|
||||
}
|
||||
|
||||
// AzureAnthropicConfig contains Azure-specific settings for Anthropic (Microsoft Foundry).
|
||||
type AzureAnthropicConfig struct {
|
||||
APIKey string `yaml:"api_key"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Model string `yaml:"model"`
|
||||
// ModelEntry maps a model name to a provider entry.
|
||||
type ModelEntry struct {
|
||||
Name string `yaml:"name"`
|
||||
Provider string `yaml:"provider"`
|
||||
ProviderModelID string `yaml:"provider_model_id"`
|
||||
}
|
||||
|
||||
// ProviderConfig contains shared provider configuration fields.
|
||||
// ProviderConfig contains shared provider configuration fields used internally by providers.
|
||||
type ProviderConfig struct {
|
||||
APIKey string `yaml:"api_key"`
|
||||
Model string `yaml:"model"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
}
|
||||
|
||||
// AzureOpenAIConfig contains Azure-specific settings.
|
||||
// AzureOpenAIConfig contains Azure-specific settings used internally by the OpenAI provider.
|
||||
type AzureOpenAIConfig struct {
|
||||
APIKey string `yaml:"api_key"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
DeploymentID string `yaml:"deployment_id"`
|
||||
APIVersion string `yaml:"api_version"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
APIVersion string `yaml:"api_version"`
|
||||
}
|
||||
|
||||
// Load reads and parses a YAML configuration file and applies env overrides.
|
||||
// AzureAnthropicConfig contains Azure-specific settings for Anthropic used internally.
|
||||
type AzureAnthropicConfig struct {
|
||||
APIKey string `yaml:"api_key"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Model string `yaml:"model"`
|
||||
}
|
||||
|
||||
// Load reads and parses a YAML configuration file, expanding ${VAR} env references.
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
|
||||
expanded := os.Expand(string(data), os.Getenv)
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
cfg.applyEnvOverrides()
|
||||
if err := cfg.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (cfg *Config) applyEnvOverrides() {
|
||||
overrideAPIKey(&cfg.Providers.Google, "GOOGLE_API_KEY")
|
||||
overrideAPIKey(&cfg.Providers.Anthropic, "ANTHROPIC_API_KEY")
|
||||
overrideAPIKey(&cfg.Providers.OpenAI, "OPENAI_API_KEY")
|
||||
|
||||
// Azure OpenAI overrides
|
||||
if v := os.Getenv("AZURE_OPENAI_API_KEY"); v != "" {
|
||||
cfg.Providers.AzureOpenAI.APIKey = v
|
||||
}
|
||||
if v := os.Getenv("AZURE_OPENAI_ENDPOINT"); v != "" {
|
||||
cfg.Providers.AzureOpenAI.Endpoint = v
|
||||
}
|
||||
if v := os.Getenv("AZURE_OPENAI_DEPLOYMENT_ID"); v != "" {
|
||||
cfg.Providers.AzureOpenAI.DeploymentID = v
|
||||
}
|
||||
if v := os.Getenv("AZURE_OPENAI_API_VERSION"); v != "" {
|
||||
cfg.Providers.AzureOpenAI.APIVersion = v
|
||||
}
|
||||
|
||||
// Azure Anthropic (Microsoft Foundry) overrides
|
||||
if v := os.Getenv("AZURE_ANTHROPIC_API_KEY"); v != "" {
|
||||
cfg.Providers.AzureAnthropic.APIKey = v
|
||||
}
|
||||
if v := os.Getenv("AZURE_ANTHROPIC_ENDPOINT"); v != "" {
|
||||
cfg.Providers.AzureAnthropic.Endpoint = v
|
||||
}
|
||||
if v := os.Getenv("AZURE_ANTHROPIC_MODEL"); v != "" {
|
||||
cfg.Providers.AzureAnthropic.Model = v
|
||||
}
|
||||
}
|
||||
|
||||
func overrideAPIKey(cfg *ProviderConfig, envKey string) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if v := os.Getenv(envKey); v != "" {
|
||||
cfg.APIKey = v
|
||||
func (cfg *Config) validate() error {
|
||||
for _, m := range cfg.Models {
|
||||
if _, ok := cfg.Providers[m.Provider]; !ok {
|
||||
return fmt.Errorf("model %q references unknown provider %q", m.Name, m.Provider)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -55,7 +55,6 @@ func NewAzure(azureCfg config.AzureOpenAIConfig) *Provider {
|
||||
return &Provider{
|
||||
cfg: config.ProviderConfig{
|
||||
APIKey: azureCfg.APIKey,
|
||||
Model: azureCfg.DeploymentID,
|
||||
},
|
||||
client: client,
|
||||
azure: true,
|
||||
|
||||
@@ -3,7 +3,6 @@ package providers
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||
"github.com/yourusername/go-llm-gateway/internal/config"
|
||||
@@ -19,27 +18,38 @@ type Provider interface {
|
||||
GenerateStream(ctx context.Context, req *api.ResponseRequest) (<-chan *api.StreamChunk, <-chan error)
|
||||
}
|
||||
|
||||
// Registry keeps track of registered providers by key (e.g. "openai").
|
||||
// Registry keeps track of registered providers and model-to-provider mappings.
|
||||
type Registry struct {
|
||||
providers map[string]Provider
|
||||
providers map[string]Provider
|
||||
models map[string]string // model name -> provider entry name
|
||||
providerModelIDs map[string]string // model name -> provider model ID
|
||||
modelList []config.ModelEntry
|
||||
}
|
||||
|
||||
// NewRegistry constructs provider implementations from configuration.
|
||||
func NewRegistry(cfg config.ProvidersConfig) (*Registry, error) {
|
||||
reg := &Registry{providers: make(map[string]Provider)}
|
||||
func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelEntry) (*Registry, error) {
|
||||
reg := &Registry{
|
||||
providers: make(map[string]Provider),
|
||||
models: make(map[string]string),
|
||||
providerModelIDs: make(map[string]string),
|
||||
modelList: models,
|
||||
}
|
||||
|
||||
if cfg.Google.APIKey != "" {
|
||||
reg.providers[googleprovider.Name] = googleprovider.New(cfg.Google)
|
||||
for name, entry := range entries {
|
||||
p, err := buildProvider(entry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("provider %q: %w", name, err)
|
||||
}
|
||||
if p != nil {
|
||||
reg.providers[name] = p
|
||||
}
|
||||
}
|
||||
if cfg.AzureAnthropic.APIKey != "" && cfg.AzureAnthropic.Endpoint != "" {
|
||||
reg.providers[anthropicprovider.Name] = anthropicprovider.NewAzure(cfg.AzureAnthropic)
|
||||
} else if cfg.Anthropic.APIKey != "" {
|
||||
reg.providers[anthropicprovider.Name] = anthropicprovider.New(cfg.Anthropic)
|
||||
}
|
||||
if cfg.AzureOpenAI.APIKey != "" && cfg.AzureOpenAI.Endpoint != "" {
|
||||
reg.providers[openaiprovider.Name] = openaiprovider.NewAzure(cfg.AzureOpenAI)
|
||||
} else if cfg.OpenAI.APIKey != "" {
|
||||
reg.providers[openaiprovider.Name] = openaiprovider.New(cfg.OpenAI)
|
||||
|
||||
for _, m := range models {
|
||||
reg.models[m.Name] = m.Provider
|
||||
if m.ProviderModelID != "" {
|
||||
reg.providerModelIDs[m.Name] = m.ProviderModelID
|
||||
}
|
||||
}
|
||||
|
||||
if len(reg.providers) == 0 {
|
||||
@@ -49,26 +59,77 @@ func NewRegistry(cfg config.ProvidersConfig) (*Registry, error) {
|
||||
return reg, nil
|
||||
}
|
||||
|
||||
// Get returns provider by key.
|
||||
func buildProvider(entry config.ProviderEntry) (Provider, error) {
|
||||
if entry.APIKey == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch entry.Type {
|
||||
case "openai":
|
||||
return openaiprovider.New(config.ProviderConfig{
|
||||
APIKey: entry.APIKey,
|
||||
Endpoint: entry.Endpoint,
|
||||
}), nil
|
||||
case "azureopenai":
|
||||
if entry.Endpoint == "" {
|
||||
return nil, fmt.Errorf("endpoint is required for azureopenai")
|
||||
}
|
||||
return openaiprovider.NewAzure(config.AzureOpenAIConfig{
|
||||
APIKey: entry.APIKey,
|
||||
Endpoint: entry.Endpoint,
|
||||
APIVersion: entry.APIVersion,
|
||||
}), nil
|
||||
case "anthropic":
|
||||
return anthropicprovider.New(config.ProviderConfig{
|
||||
APIKey: entry.APIKey,
|
||||
Endpoint: entry.Endpoint,
|
||||
}), nil
|
||||
case "azureanthropic":
|
||||
if entry.Endpoint == "" {
|
||||
return nil, fmt.Errorf("endpoint is required for azureanthropic")
|
||||
}
|
||||
return anthropicprovider.NewAzure(config.AzureAnthropicConfig{
|
||||
APIKey: entry.APIKey,
|
||||
Endpoint: entry.Endpoint,
|
||||
}), nil
|
||||
case "google":
|
||||
return googleprovider.New(config.ProviderConfig{
|
||||
APIKey: entry.APIKey,
|
||||
Endpoint: entry.Endpoint,
|
||||
}), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider type %q", entry.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns provider by entry name.
|
||||
func (r *Registry) Get(name string) (Provider, bool) {
|
||||
p, ok := r.providers[name]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// Default returns provider based on inferred name.
|
||||
// Models returns the list of configured models and their provider entry names.
|
||||
func (r *Registry) Models() []struct{ Provider, Model string } {
|
||||
var out []struct{ Provider, Model string }
|
||||
for _, m := range r.modelList {
|
||||
out = append(out, struct{ Provider, Model string }{Provider: m.Provider, Model: m.Name})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ResolveModelID returns the provider_model_id for a model, falling back to the model name itself.
|
||||
func (r *Registry) ResolveModelID(model string) string {
|
||||
if id, ok := r.providerModelIDs[model]; ok {
|
||||
return id
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// Default returns the provider for the given model name.
|
||||
func (r *Registry) Default(model string) (Provider, error) {
|
||||
if model != "" {
|
||||
switch {
|
||||
case strings.HasPrefix(model, "gpt") || strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3"):
|
||||
if p, ok := r.providers[openaiprovider.Name]; ok {
|
||||
return p, nil
|
||||
}
|
||||
case strings.HasPrefix(model, "claude"):
|
||||
if p, ok := r.providers[anthropicprovider.Name]; ok {
|
||||
return p, nil
|
||||
}
|
||||
case strings.HasPrefix(model, "gemini"):
|
||||
if p, ok := r.providers[googleprovider.Name]; ok {
|
||||
if providerName, ok := r.models[model]; ok {
|
||||
if p, ok := r.providers[providerName]; ok {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,31 @@ func New(registry *providers.Registry, convs *conversation.Store, logger *log.Lo
|
||||
// RegisterRoutes wires the HTTP handlers onto the provided mux.
|
||||
func (s *GatewayServer) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/v1/responses", s.handleResponses)
|
||||
mux.HandleFunc("/v1/models", s.handleModels)
|
||||
}
|
||||
|
||||
func (s *GatewayServer) handleModels(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
models := s.registry.Models()
|
||||
var data []api.ModelInfo
|
||||
for _, m := range models {
|
||||
data = append(data, api.ModelInfo{
|
||||
ID: m.Model,
|
||||
Provider: m.Provider,
|
||||
})
|
||||
}
|
||||
|
||||
resp := api.ModelsResponse{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -66,6 +91,9 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve provider_model_id (e.g., Azure deployment name) before sending to provider
|
||||
fullReq.Model = s.registry.ResolveModelID(req.Model)
|
||||
|
||||
// Handle streaming vs non-streaming
|
||||
if req.Stream {
|
||||
s.handleStreamingResponse(w, r, provider, &fullReq, &req)
|
||||
|
||||
@@ -139,22 +139,25 @@ class ChatClient:
|
||||
self.messages = []
|
||||
|
||||
|
||||
def print_models_table():
|
||||
"""Print available models table."""
|
||||
def print_models_table(base_url: str, headers: dict):
|
||||
"""Fetch and print available models from the gateway."""
|
||||
console = Console()
|
||||
try:
|
||||
resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data = resp.json().get("data", [])
|
||||
except Exception as e:
|
||||
console.print(f"[red]Failed to fetch models: {e}[/red]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
|
||||
table.add_column("Provider", style="cyan")
|
||||
table.add_column("Model ID", style="green")
|
||||
table.add_column("Alias", style="yellow")
|
||||
|
||||
table.add_row("OpenAI", "gpt-4o", "gpt4")
|
||||
table.add_row("OpenAI", "gpt-4o-mini", "gpt4-mini")
|
||||
table.add_row("OpenAI", "o1", "o1")
|
||||
table.add_row("Anthropic", "claude-3-5-sonnet-20241022", "claude")
|
||||
table.add_row("Anthropic", "claude-3-5-haiku-20241022", "haiku")
|
||||
table.add_row("Google", "gemini-2.0-flash-exp", "gemini")
|
||||
table.add_row("Google", "gemini-1.5-pro", "gemini-pro")
|
||||
for model in data:
|
||||
table.add_row(model.get("provider", ""), model.get("id", ""))
|
||||
|
||||
Console().print(table)
|
||||
console.print(table)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -227,7 +230,7 @@ def main():
|
||||
))
|
||||
|
||||
elif cmd == "/models":
|
||||
print_models_table()
|
||||
print_models_table(args.url, client._headers())
|
||||
|
||||
elif cmd == "/model":
|
||||
if len(cmd_parts) < 2:
|
||||
|
||||
Reference in New Issue
Block a user