Files
latticelm/internal/auth/auth_test.go

1009 lines
25 KiB
Go

package auth
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"log/slog"
"math/big"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Key-pair fixtures shared by every test in this file; both are assigned
// by this file's init function.
var (
	// testPrivateKey signs the tokens minted by the tests.
	testPrivateKey *rsa.PrivateKey
	// testPublicKey is the public half of testPrivateKey.
	testPublicKey *rsa.PublicKey
)

// Identifier fixtures.
var (
	testKID      = "test-key-id-1"                   // kid published by the mock JWKS and stamped on test tokens
	testIssuer   = "https://test-issuer.example.com" // fixed issuer string (most tests use the mock server URL instead)
	testAudience = "test-client-id"                  // audience configured on the middleware under test
)
// init creates the 2048-bit RSA key pair used throughout the tests. The
// fixtures are useless without it, so a generation failure panics.
func init() {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(fmt.Sprintf("failed to generate test key: %v", err))
	}
	testPrivateKey = key
	testPublicKey = &key.PublicKey
}
// mockJWKSServer is an in-process stand-in for an OIDC provider. It serves
// a discovery document and a JWKS over an httptest server. It counts the
// requests it receives and can be armed to fail the next one.
type mockJWKSServer struct {
	// server is the running HTTP server; its URL acts as the issuer.
	server *httptest.Server

	// mu is taken by the helper methods and handlers when they touch
	// requestCount and failNext, and by updateJWKS when replacing
	// jwksResponse.
	mu           sync.Mutex
	jwksResponse []byte // JSON body returned by /jwks
	oidcResponse []byte // not read by the handlers in this file
	requestCount int    // hits across both endpoints
	failNext     bool   // when set, the next request gets a 503 and clears it
}
// newMockJWKSServer starts an httptest server that imitates an OIDC
// provider:
//
//   - /.well-known/openid-configuration returns {"jwks_uri": "<server URL>/jwks"}
//   - /jwks returns a JWKS holding one RSA signing key (kid, publicKey)
//
// Every request to either endpoint increments requestCount. After
// setFailNext, exactly one subsequent request answers 503 Service
// Unavailable instead, whichever endpoint it hits. The caller must call
// close when done.
func newMockJWKSServer(publicKey *rsa.PublicKey, kid string) *mockJWKSServer {
	m := &mockJWKSServer{}

	jwks := map[string]interface{}{
		"keys": []map[string]interface{}{
			{
				"kid": kid,
				"kty": "RSA",
				"use": "sig",
				// JWK encodes modulus and exponent as unpadded base64url big-endian bytes.
				"n": base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes()),
				"e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(publicKey.E)).Bytes()),
			},
		},
	}
	// The document contains only strings, so Marshal cannot fail here.
	m.jwksResponse, _ = json.Marshal(jwks)

	// beginRequest records a hit and consumes the fail-next flag. It also
	// snapshots the JWKS body. All of this happens under mu, so the read of
	// jwksResponse cannot race with updateJWKS.
	beginRequest := func() (fail bool, jwksBody []byte) {
		m.mu.Lock()
		defer m.mu.Unlock()
		m.requestCount++
		fail = m.failNext
		m.failNext = false
		return fail, m.jwksResponse
	}

	mux := http.NewServeMux()

	// OIDC discovery endpoint
	mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
		if fail, _ := beginRequest(); fail {
			http.Error(w, "service unavailable", http.StatusServiceUnavailable)
			return
		}
		oidcConfig := map[string]string{
			"jwks_uri": m.server.URL + "/jwks",
		}
		w.Header().Set("Content-Type", "application/json")
		// An encode/write error only means the client went away; nothing to do.
		_ = json.NewEncoder(w).Encode(oidcConfig)
	})

	// JWKS endpoint
	mux.HandleFunc("/jwks", func(w http.ResponseWriter, r *http.Request) {
		fail, body := beginRequest()
		if fail {
			http.Error(w, "service unavailable", http.StatusServiceUnavailable)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write(body)
	})

	m.server = httptest.NewServer(mux)
	return m
}
// close stops the mock provider. httptest.Server.Close blocks until
// outstanding requests have completed.
func (m *mockJWKSServer) close() {
	srv := m.server
	srv.Close()
}
// getRequestCount reports how many requests, across both endpoints, the
// mock has received so far.
func (m *mockJWKSServer) getRequestCount() int {
	m.mu.Lock()
	count := m.requestCount
	m.mu.Unlock()
	return count
}
// setFailNext arms the mock so the next request it receives, on either
// endpoint, is answered with 503 Service Unavailable. That request clears
// the flag again.
func (m *mockJWKSServer) setFailNext() {
	m.mu.Lock()
	m.failNext = true
	m.mu.Unlock()
}
// updateJWKS swaps the JSON document served as the JWKS, e.g. to simulate
// key rotation. The slice is stored as given, not copied.
func (m *mockJWKSServer) updateJWKS(newResponse []byte) {
	m.mu.Lock()
	m.jwksResponse = newResponse
	m.mu.Unlock()
}
// generateTestJWT returns a compact RS256 JWT carrying claims. It is
// signed with privateKey, and its header names the signing key via kid.
func generateTestJWT(privateKey *rsa.PrivateKey, claims jwt.MapClaims, kid string) (string, error) {
	tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	tok.Header["kid"] = kid
	signed, err := tok.SignedString(privateKey)
	return signed, err
}
// TestNew covers the constructor:
//   - disabled auth yields a Middleware with no keys or client;
//   - enabled auth without an issuer is rejected;
//   - enabled auth against a mock provider loads its key, and fails when
//     the provider errors.
func TestNew(t *testing.T) {
	type testCase struct {
		name        string
		config      Config
		setupServer func() *mockJWKSServer
		expectError bool
		validate    func(t *testing.T, m *Middleware)
	}

	cases := []testCase{
		{
			name:   "disabled auth returns empty middleware",
			config: Config{Enabled: false},
			validate: func(t *testing.T, m *Middleware) {
				assert.False(t, m.cfg.Enabled)
				assert.Nil(t, m.keys)
				assert.Nil(t, m.client)
			},
		},
		{
			name:        "enabled without issuer returns error",
			config:      Config{Enabled: true, Issuer: ""},
			expectError: true,
		},
		{
			name: "enabled with valid config fetches JWKS",
			setupServer: func() *mockJWKSServer {
				return newMockJWKSServer(testPublicKey, testKID)
			},
			validate: func(t *testing.T, m *Middleware) {
				assert.True(t, m.cfg.Enabled)
				assert.NotNil(t, m.keys)
				assert.NotNil(t, m.client)
				assert.Len(t, m.keys, 1)
				assert.Contains(t, m.keys, testKID)
			},
		},
		{
			name: "JWKS fetch failure returns error",
			// The mock fails whichever provider request reaches it first.
			setupServer: func() *mockJWKSServer {
				srv := newMockJWKSServer(testPublicKey, testKID)
				srv.setFailNext()
				return srv
			},
			expectError: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := tc.config
			if tc.setupServer != nil {
				srv := tc.setupServer()
				defer srv.close()
				// Server-backed cases always run with auth enabled against the mock.
				cfg = Config{
					Enabled:  true,
					Issuer:   srv.server.URL,
					Audience: testAudience,
				}
			}

			m, err := New(cfg, slog.Default())
			if tc.expectError {
				assert.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.NotNil(t, m)
			if tc.validate != nil {
				tc.validate(t, m)
			}
		})
	}
}
// TestMiddleware_Handler drives the HTTP middleware against a mock OIDC
// provider:
//   - requests without a well-formed bearer token get 401;
//   - tokens failing expiry/issuer/audience/kid checks get 401;
//   - a valid token reaches the wrapped handler with its claims in the
//     request context.
func TestMiddleware_Handler(t *testing.T) {
	server := newMockJWKSServer(testPublicKey, testKID)
	defer server.close()

	m, err := New(Config{
		Enabled:  true,
		Issuer:   server.server.URL,
		Audience: testAudience,
	}, slog.Default())
	require.NoError(t, err)

	// The wrapped handler echoes the "sub" claim. Each case can then confirm
	// that validated claims reached the request context.
	handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		if claims, ok := GetClaims(r.Context()); ok {
			fmt.Fprintf(w, "sub:%s", claims["sub"])
			return
		}
		w.Write([]byte("no-claims"))
	}))

	// validClaims returns a fresh claim set that passes validation. Cases
	// override single fields to isolate one failure mode.
	validClaims := func() jwt.MapClaims {
		now := time.Now()
		return jwt.MapClaims{
			"sub": "user123",
			"iss": server.server.URL,
			"aud": testAudience,
			"exp": now.Add(time.Hour).Unix(),
			"iat": now.Unix(),
		}
	}

	// newRequest builds GET /test. It sets the Authorization header only
	// when authz is non-empty.
	newRequest := func(authz string) *http.Request {
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		if authz != "" {
			req.Header.Set("Authorization", authz)
		}
		return req
	}

	// signedRequest signs claims with the test key under testKID and sends
	// the token as a bearer credential. It takes the subtest's t: calling
	// require on the parent t inside t.Run would run FailNow on the wrong
	// goroutine.
	signedRequest := func(t *testing.T, claims jwt.MapClaims) *http.Request {
		token, err := generateTestJWT(testPrivateKey, claims, testKID)
		require.NoError(t, err)
		return newRequest("Bearer " + token)
	}

	tests := []struct {
		name         string
		setupRequest func(t *testing.T) *http.Request
		expectStatus int
		expectBody   string
	}{
		{
			name:         "missing authorization header",
			setupRequest: func(*testing.T) *http.Request { return newRequest("") },
			expectStatus: http.StatusUnauthorized,
			expectBody:   "missing authorization header",
		},
		{
			name:         "malformed authorization header - no bearer",
			setupRequest: func(*testing.T) *http.Request { return newRequest("invalid-token") },
			expectStatus: http.StatusUnauthorized,
			expectBody:   "invalid authorization header format",
		},
		{
			name:         "malformed authorization header - wrong scheme",
			setupRequest: func(*testing.T) *http.Request { return newRequest("Basic dGVzdDp0ZXN0") },
			expectStatus: http.StatusUnauthorized,
			expectBody:   "invalid authorization header format",
		},
		{
			name: "valid token with correct claims",
			setupRequest: func(t *testing.T) *http.Request {
				return signedRequest(t, validClaims())
			},
			expectStatus: http.StatusOK,
			expectBody:   "sub:user123",
		},
		{
			name: "expired token",
			setupRequest: func(t *testing.T) *http.Request {
				claims := validClaims()
				claims["exp"] = time.Now().Add(-time.Hour).Unix()
				claims["iat"] = time.Now().Add(-2 * time.Hour).Unix()
				return signedRequest(t, claims)
			},
			expectStatus: http.StatusUnauthorized,
			expectBody:   "invalid token",
		},
		{
			name: "token with wrong issuer",
			setupRequest: func(t *testing.T) *http.Request {
				claims := validClaims()
				claims["iss"] = "https://wrong-issuer.example.com"
				return signedRequest(t, claims)
			},
			expectStatus: http.StatusUnauthorized,
			expectBody:   "invalid token",
		},
		{
			name: "token with wrong audience",
			setupRequest: func(t *testing.T) *http.Request {
				claims := validClaims()
				claims["aud"] = "wrong-audience"
				return signedRequest(t, claims)
			},
			expectStatus: http.StatusUnauthorized,
			expectBody:   "invalid token",
		},
		{
			name: "token with missing kid",
			setupRequest: func(t *testing.T) *http.Request {
				token := jwt.NewWithClaims(jwt.SigningMethodRS256, validClaims())
				// Deliberately leave the kid header unset.
				tokenString, err := token.SignedString(testPrivateKey)
				require.NoError(t, err)
				return newRequest("Bearer " + tokenString)
			},
			expectStatus: http.StatusUnauthorized,
			expectBody:   "invalid token",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, tt.setupRequest(t))
			assert.Equal(t, tt.expectStatus, rec.Code)
			if tt.expectBody != "" {
				assert.Contains(t, rec.Body.String(), tt.expectBody)
			}
		})
	}
}
func TestMiddleware_Handler_DisabledAuth(t *testing.T) {
cfg := Config{
Enabled: false,
}
m, err := New(cfg, slog.Default())
require.NoError(t, err)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
})
handler := m.Handler(testHandler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "success", rec.Body.String())
}
// TestValidateToken calls validateToken directly. It covers:
//   - audience matching (string and array forms, and an invalid type);
//   - signature verification;
//   - rejection of a non-RSA signing method;
//   - malformed input;
//   - JWKS refresh when a token names a kid the middleware has not loaded.
//
// All cases share one mock server and one Middleware and run in table
// order. "token with unknown kid triggers JWKS refresh" replaces the mock's
// JWKS with testKID plus a new key, so later cases run against that
// expanded set.
func TestValidateToken(t *testing.T) {
	server := newMockJWKSServer(testPublicKey, testKID)
	defer server.close()

	m, err := New(Config{
		Enabled:  true,
		Issuer:   server.server.URL,
		Audience: testAudience,
	}, slog.Default())
	require.NoError(t, err)

	// validClaims returns a fresh claim set that passes validation.
	validClaims := func() jwt.MapClaims {
		now := time.Now()
		return jwt.MapClaims{
			"sub": "user123",
			"iss": server.server.URL,
			"aud": testAudience,
			"exp": now.Add(time.Hour).Unix(),
			"iat": now.Unix(),
		}
	}

	// sign produces an RS256 token. It takes the subtest's t so a failure
	// aborts only that subtest; require on the parent t from inside t.Run
	// would call FailNow on the wrong goroutine.
	sign := func(t *testing.T, key *rsa.PrivateKey, claims jwt.MapClaims, kid string) string {
		token, err := generateTestJWT(key, claims, kid)
		require.NoError(t, err)
		return token
	}

	// rsaJWK encodes pub as an RSA signing key in JWK form.
	rsaJWK := func(kid string, pub *rsa.PublicKey) map[string]interface{} {
		return map[string]interface{}{
			"kid": kid,
			"kty": "RSA",
			"use": "sig",
			"n":   base64.RawURLEncoding.EncodeToString(pub.N.Bytes()),
			"e":   base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()),
		}
	}

	tests := []struct {
		name        string
		setupToken  func(t *testing.T) string
		expectError bool
		validate    func(t *testing.T, claims jwt.MapClaims)
	}{
		{
			name: "valid token with all required claims",
			setupToken: func(t *testing.T) string {
				claims := validClaims()
				claims["email"] = "user@example.com"
				return sign(t, testPrivateKey, claims, testKID)
			},
			validate: func(t *testing.T, claims jwt.MapClaims) {
				assert.Equal(t, "user123", claims["sub"])
				assert.Equal(t, "user@example.com", claims["email"])
			},
		},
		{
			name: "token with audience as array",
			setupToken: func(t *testing.T) string {
				claims := validClaims()
				claims["aud"] = []interface{}{testAudience, "other-audience"}
				return sign(t, testPrivateKey, claims, testKID)
			},
		},
		{
			name: "token with audience array not matching",
			setupToken: func(t *testing.T) string {
				claims := validClaims()
				claims["aud"] = []interface{}{"wrong-audience", "other-audience"}
				return sign(t, testPrivateKey, claims, testKID)
			},
			expectError: true,
		},
		{
			name: "token with invalid audience format",
			setupToken: func(t *testing.T) string {
				claims := validClaims()
				claims["aud"] = 12345 // Invalid type
				return sign(t, testPrivateKey, claims, testKID)
			},
			expectError: true,
		},
		{
			name: "token signed with wrong key",
			setupToken: func(t *testing.T) string {
				wrongKey, err := rsa.GenerateKey(rand.Reader, 2048)
				require.NoError(t, err)
				// Claims the published kid but is signed by an unpublished key.
				return sign(t, wrongKey, validClaims(), testKID)
			},
			expectError: true,
		},
		{
			// The Middleware was built when the mock published only testKID.
			// Validation therefore succeeds only if it re-fetches the JWKS
			// after this case publishes newKID.
			name: "token with unknown kid triggers JWKS refresh",
			setupToken: func(t *testing.T) string {
				newKey, err := rsa.GenerateKey(rand.Reader, 2048)
				require.NoError(t, err)
				const newKID = "new-key-id"

				body, err := json.Marshal(map[string]interface{}{
					"keys": []map[string]interface{}{
						rsaJWK(testKID, testPublicKey),
						rsaJWK(newKID, &newKey.PublicKey),
					},
				})
				require.NoError(t, err)
				server.updateJWKS(body)

				return sign(t, newKey, validClaims(), newKID)
			},
			validate: func(t *testing.T, claims jwt.MapClaims) {
				assert.Equal(t, "user123", claims["sub"])
			},
		},
		{
			name: "token with completely unknown kid after refresh",
			setupToken: func(t *testing.T) string {
				unknownKey, err := rsa.GenerateKey(rand.Reader, 2048)
				require.NoError(t, err)
				return sign(t, unknownKey, validClaims(), "completely-unknown-kid")
			},
			expectError: true,
		},
		{
			name: "malformed token",
			setupToken: func(*testing.T) string {
				return "not.a.valid.jwt.token"
			},
			expectError: true,
		},
		{
			name: "token with non-RSA signing method",
			setupToken: func(t *testing.T) string {
				token := jwt.NewWithClaims(jwt.SigningMethodHS256, validClaims())
				token.Header["kid"] = testKID
				tokenString, err := token.SignedString([]byte("secret"))
				require.NoError(t, err)
				return tokenString
			},
			expectError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			claims, err := m.validateToken(tt.setupToken(t))
			if tt.expectError {
				assert.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.NotNil(t, claims)
			if tt.validate != nil {
				tt.validate(t, claims)
			}
		})
	}
}
// TestValidateToken_NoAudienceConfigured checks that when Config.Audience
// is empty, a token with no "aud" claim is accepted.
func TestValidateToken_NoAudienceConfigured(t *testing.T) {
	srv := newMockJWKSServer(testPublicKey, testKID)
	defer srv.close()

	m, err := New(Config{
		Enabled:  true,
		Issuer:   srv.server.URL,
		Audience: "", // No audience required
	}, slog.Default())
	require.NoError(t, err)

	now := time.Now()
	token, err := generateTestJWT(testPrivateKey, jwt.MapClaims{
		"sub": "user123",
		"iss": srv.server.URL,
		"exp": now.Add(time.Hour).Unix(),
		"iat": now.Unix(),
	}, testKID)
	require.NoError(t, err)

	got, err := m.validateToken(token)
	require.NoError(t, err)
	assert.Equal(t, "user123", got["sub"])
}
// TestRefreshJWKS calls refreshJWKS on a hand-built Middleware against mock
// providers. It checks that valid RSA signing keys are loaded, and that a
// provider failure surfaces as an error. It also checks that non-RSA keys,
// keys whose "use" is not "sig", and keys with undecodable parameters are
// skipped.
func TestRefreshJWKS(t *testing.T) {
	// rsaJWK encodes pub as an RSA JWK with the given kid and "use" value.
	rsaJWK := func(kid, use string, pub *rsa.PublicKey) map[string]interface{} {
		return map[string]interface{}{
			"kid": kid,
			"kty": "RSA",
			"use": use,
			"n":   base64.RawURLEncoding.EncodeToString(pub.N.Bytes()),
			"e":   base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()),
		}
	}

	// serveKeys starts a mock provider whose JWKS contains exactly keys.
	// The document is encoded before the server starts, so a failure here
	// cannot leak a running server.
	serveKeys := func(t *testing.T, keys ...map[string]interface{}) *mockJWKSServer {
		body, err := json.Marshal(map[string]interface{}{"keys": keys})
		require.NoError(t, err)
		server := newMockJWKSServer(testPublicKey, testKID)
		server.updateJWKS(body)
		return server
	}

	tests := []struct {
		name        string
		setupServer func(t *testing.T) *mockJWKSServer
		expectError bool
		validate    func(t *testing.T, m *Middleware)
	}{
		{
			name: "successful JWKS fetch and parse",
			setupServer: func(*testing.T) *mockJWKSServer {
				return newMockJWKSServer(testPublicKey, testKID)
			},
			validate: func(t *testing.T, m *Middleware) {
				assert.Len(t, m.keys, 1)
				assert.Contains(t, m.keys, testKID)
			},
		},
		{
			name: "OIDC discovery failure",
			setupServer: func(*testing.T) *mockJWKSServer {
				server := newMockJWKSServer(testPublicKey, testKID)
				server.setFailNext()
				return server
			},
			expectError: true,
		},
		{
			name: "JWKS with multiple keys",
			setupServer: func(t *testing.T) *mockJWKSServer {
				key2, err := rsa.GenerateKey(rand.Reader, 2048)
				require.NoError(t, err)
				return serveKeys(t,
					rsaJWK(testKID, "sig", testPublicKey),
					rsaJWK("test-key-id-2", "sig", &key2.PublicKey),
				)
			},
			validate: func(t *testing.T, m *Middleware) {
				assert.Len(t, m.keys, 2)
				assert.Contains(t, m.keys, testKID)
				assert.Contains(t, m.keys, "test-key-id-2")
			},
		},
		{
			name: "JWKS with non-RSA keys skipped",
			setupServer: func(t *testing.T) *mockJWKSServer {
				return serveKeys(t,
					rsaJWK(testKID, "sig", testPublicKey),
					map[string]interface{}{
						"kid": "ec-key",
						"kty": "EC", // Non-RSA key
						"use": "sig",
						"crv": "P-256",
					},
				)
			},
			validate: func(t *testing.T, m *Middleware) {
				// Only RSA key should be loaded
				assert.Len(t, m.keys, 1)
				assert.Contains(t, m.keys, testKID)
			},
		},
		{
			name: "JWKS with wrong use field skipped",
			setupServer: func(t *testing.T) *mockJWKSServer {
				return serveKeys(t,
					rsaJWK(testKID, "sig", testPublicKey),
					rsaJWK("enc-key", "enc", testPublicKey), // Wrong use
				)
			},
			validate: func(t *testing.T, m *Middleware) {
				// Only key with use=sig should be loaded
				assert.Len(t, m.keys, 1)
				assert.Contains(t, m.keys, testKID)
			},
		},
		{
			name: "JWKS with invalid base64 encoding skipped",
			setupServer: func(t *testing.T) *mockJWKSServer {
				return serveKeys(t,
					rsaJWK(testKID, "sig", testPublicKey),
					map[string]interface{}{
						"kid": "bad-key",
						"kty": "RSA",
						"use": "sig",
						"n":   "!!!invalid-base64!!!",
						"e":   "AQAB",
					},
				)
			},
			validate: func(t *testing.T, m *Middleware) {
				// Only valid key should be loaded
				assert.Len(t, m.keys, 1)
				assert.Contains(t, m.keys, testKID)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			server := tt.setupServer(t)
			defer server.close()

			m := &Middleware{
				cfg: Config{
					Enabled:  true,
					Issuer:   server.server.URL,
					Audience: testAudience,
				},
				keys:   make(map[string]*rsa.PublicKey),
				client: &http.Client{Timeout: 10 * time.Second},
			}

			err := m.refreshJWKS()
			if tt.expectError {
				assert.Error(t, err)
				return
			}
			require.NoError(t, err)
			if tt.validate != nil {
				tt.validate(t, m)
			}
		})
	}
}
// TestRefreshJWKS_Concurrency runs several refreshJWKS calls in parallel on
// one Middleware. It checks that every call succeeds and that the key set
// is intact afterwards. Run with -race to catch unsynchronized access to
// m.keys.
func TestRefreshJWKS_Concurrency(t *testing.T) {
	server := newMockJWKSServer(testPublicKey, testKID)
	defer server.close()

	m, err := New(Config{
		Enabled:  true,
		Issuer:   server.server.URL,
		Audience: testAudience,
	}, slog.Default())
	require.NoError(t, err)

	const workers = 10
	// One slot per goroutine: each writes only its own index, so the slice
	// needs no lock. After wg.Wait the main goroutine may read it safely.
	errs := make([]error, workers)

	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			errs[i] = m.refreshJWKS()
		}(i)
	}
	wg.Wait()

	for i, refreshErr := range errs {
		assert.NoError(t, refreshErr, "concurrent refresh %d", i)
	}

	// Verify keys are still valid
	m.mu.RLock()
	defer m.mu.RUnlock()
	assert.Len(t, m.keys, 1)
	assert.Contains(t, m.keys, testKID)
}
func TestGetClaims(t *testing.T) {
tests := []struct {
name string
setupContext func() context.Context
expectFound bool
validateSubject string
}{
{
name: "context with claims",
setupContext: func() context.Context {
claims := jwt.MapClaims{
"sub": "user123",
"email": "user@example.com",
}
return context.WithValue(context.Background(), claimsKey, claims)
},
expectFound: true,
validateSubject: "user123",
},
{
name: "context without claims",
setupContext: func() context.Context {
return context.Background()
},
expectFound: false,
},
{
name: "context with wrong type",
setupContext: func() context.Context {
return context.WithValue(context.Background(), claimsKey, "not-claims")
},
expectFound: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := tt.setupContext()
claims, ok := GetClaims(ctx)
if tt.expectFound {
assert.True(t, ok)
assert.NotNil(t, claims)
if tt.validateSubject != "" {
assert.Equal(t, tt.validateSubject, claims["sub"])
}
} else {
assert.False(t, ok)
}
})
}
}
// TestMiddleware_IssuerWithTrailingSlash checks that New accepts an issuer
// configured with a trailing slash and still loads the provider's key.
//
// NOTE(review): before validating, the test trims the slash from
// m.cfg.Issuer itself (a no-op if New already normalized it). It therefore
// does not show that validateToken tolerates a trailing slash on its own.
// If that is the intended contract, drop the manual trim, after confirming
// against the implementation.
func TestMiddleware_IssuerWithTrailingSlash(t *testing.T) {
	srv := newMockJWKSServer(testPublicKey, testKID)
	defer srv.close()

	m, err := New(Config{
		Enabled:  true,
		Issuer:   srv.server.URL + "/", // Trailing slash
		Audience: testAudience,
	}, slog.Default())
	require.NoError(t, err)
	require.NotNil(t, m)
	assert.Len(t, m.keys, 1)

	now := time.Now()
	token, err := generateTestJWT(testPrivateKey, jwt.MapClaims{
		"sub": "user123",
		"iss": strings.TrimSuffix(srv.server.URL, "/"),
		"aud": testAudience,
		"exp": now.Add(time.Hour).Unix(),
		"iat": now.Unix(),
	}, testKID)
	require.NoError(t, err)

	m.cfg.Issuer = strings.TrimSuffix(m.cfg.Issuer, "/")

	got, err := m.validateToken(token)
	require.NoError(t, err)
	assert.Equal(t, "user123", got["sub"])
}