Add OAuth
This commit is contained in:
53
README.md
53
README.md
@@ -52,6 +52,8 @@ Go LLM Gateway (unified API)
|
|||||||
✅ **Provider auto-selection** (gpt→OpenAI, claude→Anthropic, gemini→Google)
|
✅ **Provider auto-selection** (gpt→OpenAI, claude→Anthropic, gemini→Google)
|
||||||
✅ **Configuration system** (YAML with env var support)
|
✅ **Configuration system** (YAML with env var support)
|
||||||
✅ **Streaming support** (Server-Sent Events for all providers)
|
✅ **Streaming support** (Server-Sent Events for all providers)
|
||||||
|
✅ **OAuth2/OIDC authentication** (Google, Auth0, any OIDC provider)
|
||||||
|
✅ **Terminal chat client** (Python with Rich UI, PEP 723)
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
@@ -168,9 +170,52 @@ For full specification details, see: **https://www.openresponses.org**
|
|||||||
- `internal/server`: HTTP handlers that expose `/v1/responses`.
|
- `internal/server`: HTTP handlers that expose `/v1/responses`.
|
||||||
- `internal/providers`: Provider abstractions plus provider-specific scaffolding in `google`, `anthropic`, and `openai` subpackages.
|
- `internal/providers`: Provider abstractions plus provider-specific scaffolding in `google`, `anthropic`, and `openai` subpackages.
|
||||||
|
|
||||||
|
## Chat Client
|
||||||
|
|
||||||
|
Interactive terminal chat interface with beautiful Rich UI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Basic usage
|
||||||
|
uv run chat.py
|
||||||
|
|
||||||
|
# With authentication
|
||||||
|
uv run chat.py --token "$(gcloud auth print-identity-token)"
|
||||||
|
|
||||||
|
# Switch models on the fly
|
||||||
|
You> /model claude
|
||||||
|
You> /models # List all available models
|
||||||
|
```
|
||||||
|
|
||||||
|
See **[CHAT_CLIENT.md](./CHAT_CLIENT.md)** for full documentation.
|
||||||
|
|
||||||
|
## Authentication
|
||||||
|
|
||||||
|
The gateway supports OAuth2/OIDC authentication. See **[AUTH.md](./AUTH.md)** for setup instructions.
|
||||||
|
|
||||||
|
**Quick example with Google OAuth:**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
auth:
|
||||||
|
enabled: true
|
||||||
|
issuer: "https://accounts.google.com"
|
||||||
|
audience: "YOUR-CLIENT-ID.apps.googleusercontent.com"
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Get token
|
||||||
|
TOKEN=$(gcloud auth print-identity-token)
|
||||||
|
|
||||||
|
# Make authenticated request
|
||||||
|
curl -X POST http://localhost:8080/v1/responses \
|
||||||
|
-H "Authorization: Bearer $TOKEN" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"model": "gemini-2.0-flash-exp", ...}'
|
||||||
|
```
|
||||||
|
|
||||||
## Next Steps
|
## Next Steps
|
||||||
|
|
||||||
- Implement the actual SDK calls inside each provider using the official Go clients.
|
- ✅ ~~Implement streaming responses~~
|
||||||
- Support streaming responses and tool invocation per the broader Open Responses spec.
|
- ✅ ~~Add OAuth2/OIDC authentication~~
|
||||||
- Add structured logging, tracing, and request-level metrics.
|
- ⬜ Add structured logging, tracing, and request-level metrics
|
||||||
- Expand configuration to support routing policies (cost, latency, failover, etc.).
|
- ⬜ Support tool/function calling
|
||||||
|
- ⬜ Expand configuration to support routing policies (cost, latency, failover)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/yourusername/go-llm-gateway/internal/auth"
|
||||||
"github.com/yourusername/go-llm-gateway/internal/config"
|
"github.com/yourusername/go-llm-gateway/internal/config"
|
||||||
"github.com/yourusername/go-llm-gateway/internal/providers"
|
"github.com/yourusername/go-llm-gateway/internal/providers"
|
||||||
"github.com/yourusername/go-llm-gateway/internal/server"
|
"github.com/yourusername/go-llm-gateway/internal/server"
|
||||||
@@ -29,6 +30,23 @@ func main() {
|
|||||||
|
|
||||||
logger := log.New(os.Stdout, "gateway ", log.LstdFlags|log.Lshortfile)
|
logger := log.New(os.Stdout, "gateway ", log.LstdFlags|log.Lshortfile)
|
||||||
|
|
||||||
|
// Initialize authentication middleware
|
||||||
|
authConfig := auth.Config{
|
||||||
|
Enabled: cfg.Auth.Enabled,
|
||||||
|
Issuer: cfg.Auth.Issuer,
|
||||||
|
Audience: cfg.Auth.Audience,
|
||||||
|
}
|
||||||
|
authMiddleware, err := auth.New(authConfig)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("init auth: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Auth.Enabled {
|
||||||
|
logger.Printf("Authentication enabled (issuer: %s)", cfg.Auth.Issuer)
|
||||||
|
} else {
|
||||||
|
logger.Printf("Authentication disabled - WARNING: API is publicly accessible")
|
||||||
|
}
|
||||||
|
|
||||||
gatewayServer := server.New(registry, logger)
|
gatewayServer := server.New(registry, logger)
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
gatewayServer.RegisterRoutes(mux)
|
gatewayServer.RegisterRoutes(mux)
|
||||||
@@ -38,9 +56,12 @@ func main() {
|
|||||||
addr = ":8080"
|
addr = ":8080"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build handler chain: logging -> auth -> routes
|
||||||
|
handler := loggingMiddleware(authMiddleware.Handler(mux), logger)
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
Handler: loggingMiddleware(mux, logger),
|
Handler: handler,
|
||||||
ReadTimeout: 15 * time.Second,
|
ReadTimeout: 15 * time.Second,
|
||||||
WriteTimeout: 60 * time.Second,
|
WriteTimeout: 60 * time.Second,
|
||||||
IdleTimeout: 120 * time.Second,
|
IdleTimeout: 120 * time.Second,
|
||||||
|
|||||||
19
config.example-with-auth.yaml
Normal file
19
config.example-with-auth.yaml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# Example configuration with Google OAuth2 authentication
|
||||||
|
|
||||||
|
auth:
|
||||||
|
enabled: true
|
||||||
|
issuer: "https://accounts.google.com"
|
||||||
|
audience: "YOUR-CLIENT-ID.apps.googleusercontent.com"
|
||||||
|
|
||||||
|
providers:
|
||||||
|
openai:
|
||||||
|
api_key: "${OPENAI_API_KEY}"
|
||||||
|
model: "gpt-4o-mini"
|
||||||
|
|
||||||
|
anthropic:
|
||||||
|
api_key: "${ANTHROPIC_API_KEY}"
|
||||||
|
model: "claude-3-5-sonnet-20241022"
|
||||||
|
|
||||||
|
google:
|
||||||
|
api_key: "${GOOGLE_API_KEY}"
|
||||||
|
model: "gemini-2.0-flash-exp"
|
||||||
1
go.mod
1
go.mod
@@ -14,6 +14,7 @@ require (
|
|||||||
cloud.google.com/go v0.116.0 // indirect
|
cloud.google.com/go v0.116.0 // indirect
|
||||||
cloud.google.com/go/auth v0.9.3 // indirect
|
cloud.google.com/go/auth v0.9.3 // indirect
|
||||||
cloud.google.com/go/compute/metadata v0.5.0 // indirect
|
cloud.google.com/go/compute/metadata v0.5.0 // indirect
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
||||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||||
github.com/google/go-cmp v0.6.0 // indirect
|
github.com/google/go-cmp v0.6.0 // indirect
|
||||||
github.com/google/s2a-go v0.1.8 // indirect
|
github.com/google/s2a-go v0.1.8 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -20,6 +20,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF
|
|||||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
|
||||||
|
|||||||
260
internal/auth/auth.go
Normal file
260
internal/auth/auth.go
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config describes how incoming requests are authenticated against an OIDC
// provider.
type Config struct {
	// Enabled switches token validation on. When false the middleware
	// forwards every request untouched.
	Enabled bool `yaml:"enabled"`
	// Issuer is the OIDC issuer URL, e.g. "https://accounts.google.com".
	Issuer string `yaml:"issuer"`
	// Audience is the expected "aud" claim, e.g. your client ID. When empty
	// the audience claim is not checked.
	Audience string `yaml:"audience"`
}
|
||||||
|
|
||||||
|
// Middleware provides JWT validation middleware.
|
||||||
|
type Middleware struct {
|
||||||
|
cfg Config
|
||||||
|
keys map[string]*rsa.PublicKey
|
||||||
|
mu sync.RWMutex
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates an authentication middleware.
|
||||||
|
func New(cfg Config) (*Middleware, error) {
|
||||||
|
if !cfg.Enabled {
|
||||||
|
return &Middleware{cfg: cfg}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Issuer == "" {
|
||||||
|
return nil, fmt.Errorf("auth enabled but issuer not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &Middleware{
|
||||||
|
cfg: cfg,
|
||||||
|
keys: make(map[string]*rsa.PublicKey),
|
||||||
|
client: &http.Client{Timeout: 10 * time.Second},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch JWKS on startup
|
||||||
|
if err := m.refreshJWKS(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh JWKS periodically
|
||||||
|
go m.periodicRefresh()
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler wraps an HTTP handler with authentication.
|
||||||
|
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !m.cfg.Enabled {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract token from Authorization header
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
if authHeader == "" {
|
||||||
|
http.Error(w, "missing authorization header", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(authHeader, " ", 2)
|
||||||
|
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||||
|
http.Error(w, "invalid authorization header format", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenString := parts[1]
|
||||||
|
|
||||||
|
// Validate token
|
||||||
|
claims, err := m.validateToken(tokenString)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add claims to context
|
||||||
|
ctx := context.WithValue(r.Context(), claimsKey, claims)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextKey is a private type for context keys, so values stored by this
// package cannot collide with keys defined elsewhere.
type contextKey string

// claimsKey is the context key under which validated JWT claims are stored.
const claimsKey = contextKey("jwt_claims")
|
||||||
|
|
||||||
|
// GetClaims extracts JWT claims from request context.
|
||||||
|
func GetClaims(ctx context.Context) (jwt.MapClaims, bool) {
|
||||||
|
claims, ok := ctx.Value(claimsKey).(jwt.MapClaims)
|
||||||
|
return claims, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) validateToken(tokenString string) (jwt.MapClaims, error) {
|
||||||
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
// Verify signing method
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get key ID from token header
|
||||||
|
kid, ok := token.Header["kid"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("missing kid in token header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get public key
|
||||||
|
m.mu.RLock()
|
||||||
|
key, exists := m.keys[kid]
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
// Try refreshing JWKS
|
||||||
|
if err := m.refreshJWKS(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to refresh JWKS: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
key, exists = m.keys[kid]
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("unknown key ID: %s", kid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok || !token.Valid {
|
||||||
|
return nil, fmt.Errorf("invalid token claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate issuer
|
||||||
|
if iss, ok := claims["iss"].(string); !ok || iss != m.cfg.Issuer {
|
||||||
|
return nil, fmt.Errorf("invalid issuer: %s", iss)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate audience if configured
|
||||||
|
if m.cfg.Audience != "" {
|
||||||
|
aud, ok := claims["aud"].(string)
|
||||||
|
if !ok {
|
||||||
|
// aud might be an array
|
||||||
|
audArray, ok := claims["aud"].([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid audience format")
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, a := range audArray {
|
||||||
|
if audStr, ok := a.(string); ok && audStr == m.cfg.Audience {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("audience not matched")
|
||||||
|
}
|
||||||
|
} else if aud != m.cfg.Audience {
|
||||||
|
return nil, fmt.Errorf("invalid audience: %s", aud)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) refreshJWKS() error {
|
||||||
|
jwksURL := strings.TrimSuffix(m.cfg.Issuer, "/") + "/.well-known/openid-configuration"
|
||||||
|
|
||||||
|
// Fetch OIDC discovery document
|
||||||
|
resp, err := m.client.Get(jwksURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var oidcConfig struct {
|
||||||
|
JwksURI string `json:"jwks_uri"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&oidcConfig); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch JWKS
|
||||||
|
resp, err = m.client.Get(oidcConfig.JwksURI)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var jwks struct {
|
||||||
|
Keys []struct {
|
||||||
|
Kid string `json:"kid"`
|
||||||
|
Kty string `json:"kty"`
|
||||||
|
Use string `json:"use"`
|
||||||
|
N string `json:"n"`
|
||||||
|
E string `json:"e"`
|
||||||
|
} `json:"keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse keys
|
||||||
|
newKeys := make(map[string]*rsa.PublicKey)
|
||||||
|
for _, key := range jwks.Keys {
|
||||||
|
if key.Kty != "RSA" || key.Use != "sig" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey := &rsa.PublicKey{
|
||||||
|
N: new(big.Int).SetBytes(nBytes),
|
||||||
|
E: int(new(big.Int).SetBytes(eBytes).Int64()),
|
||||||
|
}
|
||||||
|
|
||||||
|
newKeys[key.Kid] = pubKey
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.keys = newKeys
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) periodicRefresh() {
|
||||||
|
ticker := time.NewTicker(1 * time.Hour)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
_ = m.refreshJWKS()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,6 +11,14 @@ import (
|
|||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `yaml:"server"`
|
Server ServerConfig `yaml:"server"`
|
||||||
Providers ProvidersConfig `yaml:"providers"`
|
Providers ProvidersConfig `yaml:"providers"`
|
||||||
|
Auth AuthConfig `yaml:"auth"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthConfig holds OIDC authentication settings.
|
||||||
|
type AuthConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
Issuer string `yaml:"issuer"`
|
||||||
|
Audience string `yaml:"audience"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig controls HTTP server values.
|
// ServerConfig controls HTTP server values.
|
||||||
|
|||||||
Reference in New Issue
Block a user