From 4439567ccd5d276b35930e88fd842f0c49605ee8 Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Sat, 28 Feb 2026 21:15:15 +0000 Subject: [PATCH] Add OAuth --- README.md | 53 ++++++- cmd/gateway/main.go | 23 ++- config.example-with-auth.yaml | 19 +++ go.mod | 1 + go.sum | 2 + internal/auth/auth.go | 260 ++++++++++++++++++++++++++++++++++ internal/config/config.go | 8 ++ 7 files changed, 361 insertions(+), 5 deletions(-) create mode 100644 config.example-with-auth.yaml create mode 100644 internal/auth/auth.go diff --git a/README.md b/README.md index 252a04b..38676ab 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,8 @@ Go LLM Gateway (unified API) ✅ **Provider auto-selection** (gpt→OpenAI, claude→Anthropic, gemini→Google) ✅ **Configuration system** (YAML with env var support) ✅ **Streaming support** (Server-Sent Events for all providers) +✅ **OAuth2/OIDC authentication** (Google, Auth0, any OIDC provider) +✅ **Terminal chat client** (Python with Rich UI, PEP 723) ## Quick Start @@ -168,9 +170,52 @@ For full specification details, see: **https://www.openresponses.org** - `internal/server`: HTTP handlers that expose `/v1/responses`. - `internal/providers`: Provider abstractions plus provider-specific scaffolding in `google`, `anthropic`, and `openai` subpackages. +## Chat Client + +Interactive terminal chat interface with beautiful Rich UI: + +```bash +# Basic usage +uv run chat.py + +# With authentication +uv run chat.py --token "$(gcloud auth print-identity-token)" + +# Switch models on the fly +You> /model claude +You> /models # List all available models +``` + +See **[CHAT_CLIENT.md](./CHAT_CLIENT.md)** for full documentation. + +## Authentication + +The gateway supports OAuth2/OIDC authentication. See **[AUTH.md](./AUTH.md)** for setup instructions. 
+ +**Quick example with Google OAuth:** + +```yaml +auth: + enabled: true + issuer: "https://accounts.google.com" + audience: "YOUR-CLIENT-ID.apps.googleusercontent.com" +``` + +```bash +# Get token +TOKEN=$(gcloud auth print-identity-token) + +# Make authenticated request +curl -X POST http://localhost:8080/v1/responses \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"model": "gemini-2.0-flash-exp", ...}' +``` + ## Next Steps -- Implement the actual SDK calls inside each provider using the official Go clients. -- Support streaming responses and tool invocation per the broader Open Responses spec. -- Add structured logging, tracing, and request-level metrics. -- Expand configuration to support routing policies (cost, latency, failover, etc.). +- ✅ ~~Implement streaming responses~~ +- ✅ ~~Add OAuth2/OIDC authentication~~ +- ⬜ Add structured logging, tracing, and request-level metrics +- ⬜ Support tool/function calling +- ⬜ Expand configuration to support routing policies (cost, latency, failover) diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 929a57e..d404696 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -7,6 +7,7 @@ import ( "os" "time" + "github.com/yourusername/go-llm-gateway/internal/auth" "github.com/yourusername/go-llm-gateway/internal/config" "github.com/yourusername/go-llm-gateway/internal/providers" "github.com/yourusername/go-llm-gateway/internal/server" @@ -29,6 +30,23 @@ func main() { logger := log.New(os.Stdout, "gateway ", log.LstdFlags|log.Lshortfile) + // Initialize authentication middleware + authConfig := auth.Config{ + Enabled: cfg.Auth.Enabled, + Issuer: cfg.Auth.Issuer, + Audience: cfg.Auth.Audience, + } + authMiddleware, err := auth.New(authConfig) + if err != nil { + log.Fatalf("init auth: %v", err) + } + + if cfg.Auth.Enabled { + logger.Printf("Authentication enabled (issuer: %s)", cfg.Auth.Issuer) + } else { + logger.Printf("Authentication disabled - WARNING: API is 
publicly accessible") + } + gatewayServer := server.New(registry, logger) mux := http.NewServeMux() gatewayServer.RegisterRoutes(mux) @@ -38,9 +56,12 @@ func main() { addr = ":8080" } + // Build handler chain: logging -> auth -> routes + handler := loggingMiddleware(authMiddleware.Handler(mux), logger) + srv := &http.Server{ Addr: addr, - Handler: loggingMiddleware(mux, logger), + Handler: handler, ReadTimeout: 15 * time.Second, WriteTimeout: 60 * time.Second, IdleTimeout: 120 * time.Second, diff --git a/config.example-with-auth.yaml b/config.example-with-auth.yaml new file mode 100644 index 0000000..fd79d59 --- /dev/null +++ b/config.example-with-auth.yaml @@ -0,0 +1,19 @@ +# Example configuration with Google OAuth2 authentication + +auth: + enabled: true + issuer: "https://accounts.google.com" + audience: "YOUR-CLIENT-ID.apps.googleusercontent.com" + +providers: + openai: + api_key: "${OPENAI_API_KEY}" + model: "gpt-4o-mini" + + anthropic: + api_key: "${ANTHROPIC_API_KEY}" + model: "claude-3-5-sonnet-20241022" + + google: + api_key: "${GOOGLE_API_KEY}" + model: "gemini-2.0-flash-exp" diff --git a/go.mod b/go.mod index dce794c..160dbd4 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.9.3 // indirect cloud.google.com/go/compute/metadata v0.5.0 // indirect + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/s2a-go v0.1.8 // indirect diff --git a/go.sum b/go.sum index ac057ae..18f27ce 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
+github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
+github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
 github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
new file mode 100644
index 0000000..0aa9d52
--- /dev/null
+++ b/internal/auth/auth.go
@@ -0,0 +1,364 @@
+package auth
+
+import (
+	"context"
+	"crypto/rsa"
+	"encoding/base64"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"math/big"
+	"net/http"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/golang-jwt/jwt/v5"
+)
+
+const (
+	// jwksRefreshInterval is how often the background goroutine reloads the
+	// issuer's signing keys.
+	jwksRefreshInterval = 1 * time.Hour
+
+	// minOnDemandRefreshInterval limits how often a token with an unknown
+	// "kid" may trigger an immediate JWKS reload. Token headers are
+	// attacker-controlled, so without a limit every forged token would cost
+	// two outbound HTTP requests.
+	minOnDemandRefreshInterval = 1 * time.Minute
+
+	// maxMetadataBytes caps how much of a discovery or JWKS response is read.
+	maxMetadataBytes = 1 << 20
+)
+
+// Config holds OIDC authentication configuration.
+//
+// When Audience is empty the token's "aud" claim is not checked.
+type Config struct {
+	Enabled  bool   `yaml:"enabled"`
+	Issuer   string `yaml:"issuer"`   // e.g., "https://accounts.google.com"
+	Audience string `yaml:"audience"` // e.g., your client ID
+}
+
+// Middleware provides JWT validation middleware.
+type Middleware struct {
+	cfg    Config
+	client *http.Client
+
+	mu   sync.RWMutex // guards keys
+	keys map[string]*rsa.PublicKey
+
+	refreshMu    sync.Mutex // serializes on-demand refreshes; guards lastOnDemand
+	lastOnDemand time.Time
+
+	stop     chan struct{} // closed by Close to end periodicRefresh
+	stopOnce sync.Once
+}
+
+// New creates an authentication middleware.
+//
+// When cfg.Enabled is false the returned middleware passes every request
+// through and performs no network I/O. Otherwise cfg.Issuer is required, the
+// issuer's signing keys are fetched once (New fails if they cannot be) and a
+// background goroutine reloads them periodically until Close is called.
+func New(cfg Config) (*Middleware, error) {
+	if !cfg.Enabled {
+		return &Middleware{cfg: cfg}, nil
+	}
+
+	if cfg.Issuer == "" {
+		return nil, fmt.Errorf("auth enabled but issuer not configured")
+	}
+
+	m := &Middleware{
+		cfg:    cfg,
+		keys:   make(map[string]*rsa.PublicKey),
+		client: &http.Client{Timeout: 10 * time.Second},
+		stop:   make(chan struct{}),
+	}
+
+	// Fetch JWKS on startup
+	if err := m.refreshJWKS(); err != nil {
+		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
+	}
+
+	// Refresh JWKS periodically
+	go m.periodicRefresh()
+
+	return m, nil
+}
+
+// Close stops the background JWKS refresh. It is safe to call more than once
+// and on a middleware created with authentication disabled.
+func (m *Middleware) Close() {
+	m.stopOnce.Do(func() {
+		if m.stop != nil {
+			close(m.stop)
+		}
+	})
+}
+
+// Handler wraps an HTTP handler with authentication.
+//
+// With authentication disabled, next is called directly. Otherwise the
+// request must carry an "Authorization: Bearer <token>" header with a token
+// that passes validation; any failure is answered with 401 and next is not
+// called. On success the token's claims are attached to the request context
+// (retrieve them with GetClaims).
+func (m *Middleware) Handler(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if !m.cfg.Enabled {
+			next.ServeHTTP(w, r)
+			return
+		}
+
+		// Extract token from Authorization header
+		authHeader := r.Header.Get("Authorization")
+		if authHeader == "" {
+			http.Error(w, "missing authorization header", http.StatusUnauthorized)
+			return
+		}
+
+		// The scheme is matched case-insensitively ("Bearer", "bearer", ...).
+		parts := strings.SplitN(authHeader, " ", 2)
+		if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
+			http.Error(w, "invalid authorization header format", http.StatusUnauthorized)
+			return
+		}
+
+		claims, err := m.validateToken(parts[1])
+		if err != nil {
+			http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized)
+			return
+		}
+
+		// Add claims to context
+		ctx := context.WithValue(r.Context(), claimsKey, claims)
+		next.ServeHTTP(w, r.WithContext(ctx))
+	})
+}
+
+// contextKey is a package-private type so claimsKey cannot collide with
+// context keys defined elsewhere.
+type contextKey string
+
+const claimsKey contextKey = "jwt_claims"
+
+// GetClaims extracts JWT claims from request context. The boolean is false
+// when the context carries no claims.
+func GetClaims(ctx context.Context) (jwt.MapClaims, bool) {
+	claims, ok := ctx.Value(claimsKey).(jwt.MapClaims)
+	return claims, ok
+}
+
+// validateToken parses tokenString, verifies its RSA signature against the
+// cached issuer keys, and checks the issuer and (when configured) audience
+// claims. Tokens without an "exp" claim are rejected.
+func (m *Middleware) validateToken(tokenString string) (jwt.MapClaims, error) {
+	token, err := jwt.Parse(tokenString, m.keyFunc, jwt.WithExpirationRequired())
+	if err != nil {
+		return nil, err
+	}
+
+	claims, ok := token.Claims.(jwt.MapClaims)
+	if !ok || !token.Valid {
+		return nil, fmt.Errorf("invalid token claims")
+	}
+
+	// Validate issuer
+	if iss, ok := claims["iss"].(string); !ok || iss != m.cfg.Issuer {
+		return nil, fmt.Errorf("invalid issuer: %s", iss)
+	}
+
+	// Validate audience if configured
+	if m.cfg.Audience != "" && !audienceMatches(claims["aud"], m.cfg.Audience) {
+		return nil, fmt.Errorf("audience not matched")
+	}
+
+	return claims, nil
+}
+
+// audienceMatches reports whether want appears in an "aud" claim, which may
+// be a single string or an array of strings.
+func audienceMatches(claim interface{}, want string) bool {
+	switch aud := claim.(type) {
+	case string:
+		return aud == want
+	case []interface{}:
+		for _, a := range aud {
+			if s, ok := a.(string); ok && s == want {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+// keyFunc is the jwt.Keyfunc used by validateToken. It accepts only RSA
+// signing methods and resolves the verification key from the "kid" header.
+func (m *Middleware) keyFunc(token *jwt.Token) (interface{}, error) {
+	// Verify signing method
+	if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
+		return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+	}
+
+	// Get key ID from token header
+	kid, ok := token.Header["kid"].(string)
+	if !ok {
+		return nil, fmt.Errorf("missing kid in token header")
+	}
+
+	key, err := m.keyForID(kid)
+	if err != nil {
+		return nil, err
+	}
+	return key, nil
+}
+
+// lookupKey returns the cached public key for kid, if any.
+func (m *Middleware) lookupKey(kid string) (*rsa.PublicKey, bool) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	key, ok := m.keys[kid]
+	return key, ok
+}
+
+// keyForID returns the public key for kid. An unknown kid may mean the issuer
+// rotated its keys, so it triggers one JWKS reload, but at most once per
+// minOnDemandRefreshInterval across all requests.
+func (m *Middleware) keyForID(kid string) (*rsa.PublicKey, error) {
+	if key, ok := m.lookupKey(kid); ok {
+		return key, nil
+	}
+
+	m.refreshMu.Lock()
+	defer m.refreshMu.Unlock()
+
+	// Another request may have reloaded the keys while we waited for the lock.
+	if key, ok := m.lookupKey(kid); ok {
+		return key, nil
+	}
+
+	if time.Since(m.lastOnDemand) < minOnDemandRefreshInterval {
+		return nil, fmt.Errorf("unknown key ID: %s", kid)
+	}
+
+	// Record the attempt before fetching so failures are rate-limited too.
+	m.lastOnDemand = time.Now()
+	if err := m.refreshJWKS(); err != nil {
+		return nil, fmt.Errorf("failed to refresh JWKS: %w", err)
+	}
+
+	if key, ok := m.lookupKey(kid); ok {
+		return key, nil
+	}
+	return nil, fmt.Errorf("unknown key ID: %s", kid)
+}
+
+// fetchJSON GETs url and decodes a 200 response body as JSON into dst.
+func (m *Middleware) fetchJSON(url string, dst interface{}) error {
+	resp, err := m.client.Get(url)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("GET %s: unexpected status %s", url, resp.Status)
+	}
+
+	return json.NewDecoder(io.LimitReader(resp.Body, maxMetadataBytes)).Decode(dst)
+}
+
+// refreshJWKS reloads the issuer's RSA signing keys via OIDC discovery
+// (<issuer>/.well-known/openid-configuration, then its jwks_uri). The cached
+// key set is replaced only when at least one usable key was parsed, so a bad
+// response never empties a working cache.
+func (m *Middleware) refreshJWKS() error {
+	discoveryURL := strings.TrimSuffix(m.cfg.Issuer, "/") + "/.well-known/openid-configuration"
+
+	// Fetch OIDC discovery document
+	var oidcConfig struct {
+		JwksURI string `json:"jwks_uri"`
+	}
+	if err := m.fetchJSON(discoveryURL, &oidcConfig); err != nil {
+		return fmt.Errorf("fetch OIDC discovery document: %w", err)
+	}
+	if oidcConfig.JwksURI == "" {
+		return errors.New("OIDC discovery document has no jwks_uri")
+	}
+
+	// Fetch JWKS
+	var jwks struct {
+		Keys []struct {
+			Kid string `json:"kid"`
+			Kty string `json:"kty"`
+			Use string `json:"use"`
+			N   string `json:"n"`
+			E   string `json:"e"`
+		} `json:"keys"`
+	}
+	if err := m.fetchJSON(oidcConfig.JwksURI, &jwks); err != nil {
+		return fmt.Errorf("fetch JWKS: %w", err)
+	}
+
+	// Parse keys
+	newKeys := make(map[string]*rsa.PublicKey, len(jwks.Keys))
+	for _, key := range jwks.Keys {
+		// "use" is optional in a JWK; skip only keys explicitly marked for
+		// something other than signatures.
+		if key.Kty != "RSA" || (key.Use != "" && key.Use != "sig") {
+			continue
+		}
+
+		nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
+		if err != nil || len(nBytes) == 0 {
+			continue
+		}
+
+		eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
+		if err != nil {
+			continue
+		}
+
+		// Reject exponents that are missing or do not fit rsa.PublicKey.E.
+		e := new(big.Int).SetBytes(eBytes)
+		if !e.IsInt64() || e.Int64() < 2 || e.Int64() > 1<<31-1 {
+			continue
+		}
+
+		newKeys[key.Kid] = &rsa.PublicKey{
+			N: new(big.Int).SetBytes(nBytes),
+			E: int(e.Int64()),
+		}
+	}
+
+	if len(newKeys) == 0 {
+		return errors.New("JWKS contains no usable RSA signing keys")
+	}
+
+	m.mu.Lock()
+	m.keys = newKeys
+	m.mu.Unlock()
+
+	return nil
+}
+
+// periodicRefresh reloads the JWKS every jwksRefreshInterval until Close is
+// called. Failures are ignored on purpose: the previously cached keys stay
+// in use and the next tick retries.
+func (m *Middleware) periodicRefresh() {
+	ticker := time.NewTicker(jwksRefreshInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			_ = m.refreshJWKS()
+		case <-m.stop:
+			return
+		}
+	}
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 9205cf8..c6c6ef0 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -11,6 +11,14 @@ import (
 type Config struct {
 	Server    ServerConfig    `yaml:"server"`
 	Providers ProvidersConfig `yaml:"providers"`
+	Auth      AuthConfig      `yaml:"auth"`
+}
+
+// AuthConfig holds OIDC authentication settings.
+type AuthConfig struct {
+	Enabled  bool   `yaml:"enabled"`  // copied to auth.Config.Enabled by cmd/gateway; false leaves the API unauthenticated
+	Issuer   string `yaml:"issuer"`   // OIDC issuer URL, e.g. "https://accounts.google.com"; required when Enabled
+	Audience string `yaml:"audience"` // expected "aud" claim, e.g. the OAuth client ID; empty skips the audience check
 }
 
 // ServerConfig controls HTTP server values.