Fix context background and silent JWT

This commit is contained in:
2026-03-05 06:55:44 +00:00
parent 214e63b0c5
commit ae2e1b7a80
11 changed files with 99 additions and 92 deletions

View File

@@ -110,7 +110,7 @@ func main() {
Issuer: cfg.Auth.Issuer, Issuer: cfg.Auth.Issuer,
Audience: cfg.Auth.Audience, Audience: cfg.Auth.Audience,
} }
authMiddleware, err := auth.New(authConfig) authMiddleware, err := auth.New(authConfig, logger)
if err != nil { if err != nil {
logger.Error("failed to initialize auth", slog.String("error", err.Error())) logger.Error("failed to initialize auth", slog.String("error", err.Error()))
os.Exit(1) os.Exit(1)

View File

@@ -6,6 +6,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"math/big" "math/big"
"net/http" "net/http"
"strings" "strings"
@@ -28,12 +29,13 @@ type Middleware struct {
keys map[string]*rsa.PublicKey keys map[string]*rsa.PublicKey
mu sync.RWMutex mu sync.RWMutex
client *http.Client client *http.Client
logger *slog.Logger
} }
// New creates an authentication middleware. // New creates an authentication middleware.
func New(cfg Config) (*Middleware, error) { func New(cfg Config, logger *slog.Logger) (*Middleware, error) {
if !cfg.Enabled { if !cfg.Enabled {
return &Middleware{cfg: cfg}, nil return &Middleware{cfg: cfg, logger: logger}, nil
} }
if cfg.Issuer == "" { if cfg.Issuer == "" {
@@ -44,6 +46,7 @@ func New(cfg Config) (*Middleware, error) {
cfg: cfg, cfg: cfg,
keys: make(map[string]*rsa.PublicKey), keys: make(map[string]*rsa.PublicKey),
client: &http.Client{Timeout: 10 * time.Second}, client: &http.Client{Timeout: 10 * time.Second},
logger: logger,
} }
// Fetch JWKS on startup // Fetch JWKS on startup
@@ -255,6 +258,15 @@ func (m *Middleware) periodicRefresh() {
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
_ = m.refreshJWKS() if err := m.refreshJWKS(); err != nil {
m.logger.Error("failed to refresh JWKS",
slog.String("issuer", m.cfg.Issuer),
slog.String("error", err.Error()),
)
} else {
m.logger.Debug("successfully refreshed JWKS",
slog.String("issuer", m.cfg.Issuer),
)
}
} }
} }

View File

@@ -7,6 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"math/big" "math/big"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -213,7 +214,7 @@ func TestNew(t *testing.T) {
} }
} }
m, err := New(tt.config) m, err := New(tt.config, slog.Default())
if tt.expectError { if tt.expectError {
assert.Error(t, err) assert.Error(t, err)
@@ -239,7 +240,7 @@ func TestMiddleware_Handler(t *testing.T) {
Issuer: server.server.URL, Issuer: server.server.URL,
Audience: testAudience, Audience: testAudience,
} }
m, err := New(cfg) m, err := New(cfg, slog.Default())
require.NoError(t, err) require.NoError(t, err)
// Create a test handler that echoes back claims // Create a test handler that echoes back claims
@@ -415,7 +416,7 @@ func TestMiddleware_Handler_DisabledAuth(t *testing.T) {
cfg := Config{ cfg := Config{
Enabled: false, Enabled: false,
} }
m, err := New(cfg) m, err := New(cfg, slog.Default())
require.NoError(t, err) require.NoError(t, err)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -442,7 +443,7 @@ func TestValidateToken(t *testing.T) {
Issuer: server.server.URL, Issuer: server.server.URL,
Audience: testAudience, Audience: testAudience,
} }
m, err := New(cfg) m, err := New(cfg, slog.Default())
require.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
@@ -665,7 +666,7 @@ func TestValidateToken_NoAudienceConfigured(t *testing.T) {
Issuer: server.server.URL, Issuer: server.server.URL,
Audience: "", // No audience required Audience: "", // No audience required
} }
m, err := New(cfg) m, err := New(cfg, slog.Default())
require.NoError(t, err) require.NoError(t, err)
// Token without audience should be valid // Token without audience should be valid
@@ -897,7 +898,7 @@ func TestRefreshJWKS_Concurrency(t *testing.T) {
Issuer: server.server.URL, Issuer: server.server.URL,
Audience: testAudience, Audience: testAudience,
} }
m, err := New(cfg) m, err := New(cfg, slog.Default())
require.NoError(t, err) require.NoError(t, err)
// Trigger concurrent refreshes // Trigger concurrent refreshes
@@ -982,7 +983,7 @@ func TestMiddleware_IssuerWithTrailingSlash(t *testing.T) {
Issuer: server.server.URL + "/", // Trailing slash Issuer: server.server.URL + "/", // Trailing slash
Audience: testAudience, Audience: testAudience,
} }
m, err := New(cfg) m, err := New(cfg, slog.Default())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, m) require.NotNil(t, m)
assert.Len(t, m.keys, 1) assert.Len(t, m.keys, 1)

View File

@@ -1,6 +1,7 @@
package conversation package conversation
import ( import (
"context"
"sync" "sync"
"time" "time"
@@ -9,10 +10,10 @@ import (
// Store defines the interface for conversation storage backends. // Store defines the interface for conversation storage backends.
type Store interface { type Store interface {
Get(id string) (*Conversation, error) Get(ctx context.Context, id string) (*Conversation, error)
Create(id string, model string, messages []api.Message) (*Conversation, error) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error)
Append(id string, messages ...api.Message) (*Conversation, error) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error)
Delete(id string) error Delete(ctx context.Context, id string) error
Size() int Size() int
Close() error Close() error
} }
@@ -51,7 +52,7 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore {
} }
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races. // Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
func (s *MemoryStore) Get(id string) (*Conversation, error) { func (s *MemoryStore) Get(ctx context.Context, id string) (*Conversation, error) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
@@ -74,7 +75,7 @@ func (s *MemoryStore) Get(id string) (*Conversation, error) {
} }
// Create creates a new conversation with the given messages. // Create creates a new conversation with the given messages.
func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) { func (s *MemoryStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -105,7 +106,7 @@ func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*
} }
// Append adds new messages to an existing conversation. // Append adds new messages to an existing conversation.
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) { func (s *MemoryStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -131,7 +132,7 @@ func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation,
} }
// Delete removes a conversation from the store. // Delete removes a conversation from the store.
func (s *MemoryStore) Delete(id string) error { func (s *MemoryStore) Delete(ctx context.Context, id string) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()

View File

@@ -1,6 +1,7 @@
package conversation package conversation
import ( import (
"context"
"testing" "testing"
"time" "time"
@@ -21,7 +22,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
}, },
} }
conv, err := store.Create("test-id", "gpt-4", messages) conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, conv) require.NotNil(t, conv)
assert.Equal(t, "test-id", conv.ID) assert.Equal(t, "test-id", conv.ID)
@@ -29,7 +30,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
assert.Len(t, conv.Messages, 1) assert.Len(t, conv.Messages, 1)
assert.Equal(t, "Hello", conv.Messages[0].Content[0].Text) assert.Equal(t, "Hello", conv.Messages[0].Content[0].Text)
retrieved, err := store.Get("test-id") retrieved, err := store.Get(context.Background(),"test-id")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, retrieved) require.NotNil(t, retrieved)
assert.Equal(t, conv.ID, retrieved.ID) assert.Equal(t, conv.ID, retrieved.ID)
@@ -40,7 +41,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
func TestMemoryStore_GetNonExistent(t *testing.T) { func TestMemoryStore_GetNonExistent(t *testing.T) {
store := NewMemoryStore(1 * time.Hour) store := NewMemoryStore(1 * time.Hour)
conv, err := store.Get("nonexistent") conv, err := store.Get(context.Background(),"nonexistent")
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, conv, "should return nil for nonexistent conversation") assert.Nil(t, conv, "should return nil for nonexistent conversation")
} }
@@ -57,7 +58,7 @@ func TestMemoryStore_Append(t *testing.T) {
}, },
} }
_, err := store.Create("test-id", "gpt-4", initialMessages) _, err := store.Create(context.Background(),"test-id", "gpt-4", initialMessages)
require.NoError(t, err) require.NoError(t, err)
newMessages := []api.Message{ newMessages := []api.Message{
@@ -75,7 +76,7 @@ func TestMemoryStore_Append(t *testing.T) {
}, },
} }
conv, err := store.Append("test-id", newMessages...) conv, err := store.Append(context.Background(),"test-id", newMessages...)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, conv) require.NotNil(t, conv)
assert.Len(t, conv.Messages, 3, "should have all messages") assert.Len(t, conv.Messages, 3, "should have all messages")
@@ -94,7 +95,7 @@ func TestMemoryStore_AppendNonExistent(t *testing.T) {
}, },
} }
conv, err := store.Append("nonexistent", newMessage) conv, err := store.Append(context.Background(),"nonexistent", newMessage)
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, conv, "should return nil when appending to nonexistent conversation") assert.Nil(t, conv, "should return nil when appending to nonexistent conversation")
} }
@@ -111,20 +112,20 @@ func TestMemoryStore_Delete(t *testing.T) {
}, },
} }
_, err := store.Create("test-id", "gpt-4", messages) _, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err) require.NoError(t, err)
// Verify it exists // Verify it exists
conv, err := store.Get("test-id") conv, err := store.Get(context.Background(),"test-id")
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, conv) assert.NotNil(t, conv)
// Delete it // Delete it
err = store.Delete("test-id") err = store.Delete(context.Background(),"test-id")
require.NoError(t, err) require.NoError(t, err)
// Verify it's gone // Verify it's gone
conv, err = store.Get("test-id") conv, err = store.Get(context.Background(),"test-id")
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, conv, "conversation should be deleted") assert.Nil(t, conv, "conversation should be deleted")
} }
@@ -138,15 +139,15 @@ func TestMemoryStore_Size(t *testing.T) {
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}}, {Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
} }
_, err := store.Create("conv-1", "gpt-4", messages) _, err := store.Create(context.Background(),"conv-1", "gpt-4", messages)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, store.Size()) assert.Equal(t, 1, store.Size())
_, err = store.Create("conv-2", "gpt-4", messages) _, err = store.Create(context.Background(),"conv-2", "gpt-4", messages)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, store.Size()) assert.Equal(t, 2, store.Size())
err = store.Delete("conv-1") err = store.Delete(context.Background(),"conv-1")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, store.Size()) assert.Equal(t, 1, store.Size())
} }
@@ -159,14 +160,14 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
} }
// Create initial conversation // Create initial conversation
_, err := store.Create("test-id", "gpt-4", messages) _, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err) require.NoError(t, err)
// Simulate concurrent reads and writes // Simulate concurrent reads and writes
done := make(chan bool, 10) done := make(chan bool, 10)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
go func() { go func() {
_, _ = store.Get("test-id") _, _ = store.Get(context.Background(),"test-id")
done <- true done <- true
}() }()
} }
@@ -176,7 +177,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
Role: "assistant", Role: "assistant",
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}}, Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
} }
_, _ = store.Append("test-id", newMsg) _, _ = store.Append(context.Background(),"test-id", newMsg)
done <- true done <- true
}() }()
} }
@@ -187,7 +188,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
} }
// Verify final state // Verify final state
conv, err := store.Get("test-id") conv, err := store.Get(context.Background(),"test-id")
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, conv) assert.NotNil(t, conv)
assert.GreaterOrEqual(t, len(conv.Messages), 1) assert.GreaterOrEqual(t, len(conv.Messages), 1)
@@ -205,11 +206,11 @@ func TestMemoryStore_DeepCopy(t *testing.T) {
}, },
} }
_, err := store.Create("test-id", "gpt-4", messages) _, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err) require.NoError(t, err)
// Get conversation // Get conversation
conv1, err := store.Get("test-id") conv1, err := store.Get(context.Background(),"test-id")
require.NoError(t, err) require.NoError(t, err)
// Note: Current implementation copies the Messages slice but not the Content blocks // Note: Current implementation copies the Messages slice but not the Content blocks
@@ -225,7 +226,7 @@ func TestMemoryStore_DeepCopy(t *testing.T) {
assert.Equal(t, originalLen+1, len(conv1.Messages), "can modify returned message slice") assert.Equal(t, originalLen+1, len(conv1.Messages), "can modify returned message slice")
// Verify original is unchanged // Verify original is unchanged
conv2, err := store.Get("test-id") conv2, err := store.Get(context.Background(),"test-id")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, originalLen, len(conv2.Messages), "original conversation unaffected by slice modification") assert.Equal(t, originalLen, len(conv2.Messages), "original conversation unaffected by slice modification")
} }
@@ -238,11 +239,11 @@ func TestMemoryStore_TTLCleanup(t *testing.T) {
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}}, {Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
} }
_, err := store.Create("test-id", "gpt-4", messages) _, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err) require.NoError(t, err)
// Verify it exists // Verify it exists
conv, err := store.Get("test-id") conv, err := store.Get(context.Background(),"test-id")
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, conv) assert.NotNil(t, conv)
assert.Equal(t, 1, store.Size()) assert.Equal(t, 1, store.Size())
@@ -265,12 +266,12 @@ func TestMemoryStore_NoTTL(t *testing.T) {
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}}, {Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
} }
_, err := store.Create("test-id", "gpt-4", messages) _, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, store.Size()) assert.Equal(t, 1, store.Size())
// Without TTL, conversation should persist indefinitely // Without TTL, conversation should persist indefinitely
conv, err := store.Get("test-id") conv, err := store.Get(context.Background(),"test-id")
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, conv) assert.NotNil(t, conv)
} }
@@ -282,7 +283,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) {
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}}, {Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
} }
conv, err := store.Create("test-id", "gpt-4", messages) conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err) require.NoError(t, err)
createdAt := conv.CreatedAt createdAt := conv.CreatedAt
updatedAt := conv.UpdatedAt updatedAt := conv.UpdatedAt
@@ -296,7 +297,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) {
Role: "assistant", Role: "assistant",
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}}, Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
} }
conv, err = store.Append("test-id", newMsg) conv, err = store.Append(context.Background(),"test-id", newMsg)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, createdAt, conv.CreatedAt, "created time should not change") assert.Equal(t, createdAt, conv.CreatedAt, "created time should not change")
@@ -313,7 +314,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) {
messages := []api.Message{ messages := []api.Message{
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello " + id}}}, {Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello " + id}}},
} }
_, err := store.Create(id, model, messages) _, err := store.Create(context.Background(),id, model, messages)
require.NoError(t, err) require.NoError(t, err)
} }
@@ -322,7 +323,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) {
// Verify each conversation is independent // Verify each conversation is independent
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
id := "conv-" + string(rune('0'+i)) id := "conv-" + string(rune('0'+i))
conv, err := store.Get(id) conv, err := store.Get(context.Background(),id)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, conv) require.NotNil(t, conv)
assert.Equal(t, id, conv.ID) assert.Equal(t, id, conv.ID)

View File

@@ -13,7 +13,6 @@ import (
type RedisStore struct { type RedisStore struct {
client *redis.Client client *redis.Client
ttl time.Duration ttl time.Duration
ctx context.Context
} }
// NewRedisStore creates a Redis-backed conversation store. // NewRedisStore creates a Redis-backed conversation store.
@@ -21,7 +20,6 @@ func NewRedisStore(client *redis.Client, ttl time.Duration) *RedisStore {
return &RedisStore{ return &RedisStore{
client: client, client: client,
ttl: ttl, ttl: ttl,
ctx: context.Background(),
} }
} }
@@ -31,8 +29,8 @@ func (s *RedisStore) key(id string) string {
} }
// Get retrieves a conversation by ID from Redis. // Get retrieves a conversation by ID from Redis.
func (s *RedisStore) Get(id string) (*Conversation, error) { func (s *RedisStore) Get(ctx context.Context, id string) (*Conversation, error) {
data, err := s.client.Get(s.ctx, s.key(id)).Bytes() data, err := s.client.Get(ctx, s.key(id)).Bytes()
if err == redis.Nil { if err == redis.Nil {
return nil, nil return nil, nil
} }
@@ -49,7 +47,7 @@ func (s *RedisStore) Get(id string) (*Conversation, error) {
} }
// Create creates a new conversation with the given messages. // Create creates a new conversation with the given messages.
func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) { func (s *RedisStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
now := time.Now() now := time.Now()
conv := &Conversation{ conv := &Conversation{
ID: id, ID: id,
@@ -64,7 +62,7 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C
return nil, err return nil, err
} }
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil { if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil {
return nil, err return nil, err
} }
@@ -72,8 +70,8 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C
} }
// Append adds new messages to an existing conversation. // Append adds new messages to an existing conversation.
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) { func (s *RedisStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
conv, err := s.Get(id) conv, err := s.Get(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -89,7 +87,7 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation,
return nil, err return nil, err
} }
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil { if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil {
return nil, err return nil, err
} }
@@ -97,17 +95,18 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation,
} }
// Delete removes a conversation from Redis. // Delete removes a conversation from Redis.
func (s *RedisStore) Delete(id string) error { func (s *RedisStore) Delete(ctx context.Context, id string) error {
return s.client.Del(s.ctx, s.key(id)).Err() return s.client.Del(ctx, s.key(id)).Err()
} }
// Size returns the number of active conversations in Redis. // Size returns the number of active conversations in Redis.
func (s *RedisStore) Size() int { func (s *RedisStore) Size() int {
var count int var count int
var cursor uint64 var cursor uint64
ctx := context.Background()
for { for {
keys, nextCursor, err := s.client.Scan(s.ctx, cursor, "conv:*", 100).Result() keys, nextCursor, err := s.client.Scan(ctx, cursor, "conv:*", 100).Result()
if err != nil { if err != nil {
return 0 return 0
} }

View File

@@ -1,6 +1,7 @@
package conversation package conversation
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"time" "time"
@@ -71,8 +72,8 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error
return s, nil return s, nil
} }
func (s *SQLStore) Get(id string) (*Conversation, error) { func (s *SQLStore) Get(ctx context.Context, id string) (*Conversation, error) {
row := s.db.QueryRow(s.dialect.getByID, id) row := s.db.QueryRowContext(ctx, s.dialect.getByID, id)
var conv Conversation var conv Conversation
var msgJSON string var msgJSON string
@@ -91,14 +92,14 @@ func (s *SQLStore) Get(id string) (*Conversation, error) {
return &conv, nil return &conv, nil
} }
func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) { func (s *SQLStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
now := time.Now() now := time.Now()
msgJSON, err := json.Marshal(messages) msgJSON, err := json.Marshal(messages)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil { if _, err := s.db.ExecContext(ctx, s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
return nil, err return nil, err
} }
@@ -111,8 +112,8 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Con
}, nil }, nil
} }
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) { func (s *SQLStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
conv, err := s.Get(id) conv, err := s.Get(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -128,15 +129,15 @@ func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, er
return nil, err return nil, err
} }
if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil { if _, err := s.db.ExecContext(ctx, s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
return nil, err return nil, err
} }
return conv, nil return conv, nil
} }
func (s *SQLStore) Delete(id string) error { func (s *SQLStore) Delete(ctx context.Context, id string) error {
_, err := s.db.Exec(s.dialect.deleteByID, id) _, err := s.db.ExecContext(ctx, s.dialect.deleteByID, id)
return err return err
} }

View File

@@ -42,9 +42,7 @@ func NewInstrumentedStore(s conversation.Store, backend string, registry *promet
} }
// Get wraps the store's Get method with metrics and tracing. // Get wraps the store's Get method with metrics and tracing.
func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) { func (s *InstrumentedStore) Get(ctx context.Context, id string) (*conversation.Conversation, error) {
ctx := context.Background()
// Start span if tracing is enabled // Start span if tracing is enabled
if s.tracer != nil { if s.tracer != nil {
var span trace.Span var span trace.Span
@@ -61,7 +59,7 @@ func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) {
start := time.Now() start := time.Now()
// Call underlying store // Call underlying store
conv, err := s.base.Get(id) conv, err := s.base.Get(ctx, id)
// Record metrics // Record metrics
duration := time.Since(start).Seconds() duration := time.Since(start).Seconds()
@@ -95,9 +93,7 @@ func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) {
} }
// Create wraps the store's Create method with metrics and tracing. // Create wraps the store's Create method with metrics and tracing.
func (s *InstrumentedStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) { func (s *InstrumentedStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*conversation.Conversation, error) {
ctx := context.Background()
// Start span if tracing is enabled // Start span if tracing is enabled
if s.tracer != nil { if s.tracer != nil {
var span trace.Span var span trace.Span
@@ -116,7 +112,7 @@ func (s *InstrumentedStore) Create(id string, model string, messages []api.Messa
start := time.Now() start := time.Now()
// Call underlying store // Call underlying store
conv, err := s.base.Create(id, model, messages) conv, err := s.base.Create(ctx, id, model, messages)
// Record metrics // Record metrics
duration := time.Since(start).Seconds() duration := time.Since(start).Seconds()
@@ -146,9 +142,7 @@ func (s *InstrumentedStore) Create(id string, model string, messages []api.Messa
} }
// Append wraps the store's Append method with metrics and tracing. // Append wraps the store's Append method with metrics and tracing.
func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) { func (s *InstrumentedStore) Append(ctx context.Context, id string, messages ...api.Message) (*conversation.Conversation, error) {
ctx := context.Background()
// Start span if tracing is enabled // Start span if tracing is enabled
if s.tracer != nil { if s.tracer != nil {
var span trace.Span var span trace.Span
@@ -166,7 +160,7 @@ func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*convers
start := time.Now() start := time.Now()
// Call underlying store // Call underlying store
conv, err := s.base.Append(id, messages...) conv, err := s.base.Append(ctx, id, messages...)
// Record metrics // Record metrics
duration := time.Since(start).Seconds() duration := time.Since(start).Seconds()
@@ -199,9 +193,7 @@ func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*convers
} }
// Delete wraps the store's Delete method with metrics and tracing. // Delete wraps the store's Delete method with metrics and tracing.
func (s *InstrumentedStore) Delete(id string) error { func (s *InstrumentedStore) Delete(ctx context.Context, id string) error {
ctx := context.Background()
// Start span if tracing is enabled // Start span if tracing is enabled
if s.tracer != nil { if s.tracer != nil {
var span trace.Span var span trace.Span
@@ -218,7 +210,7 @@ func (s *InstrumentedStore) Delete(id string) error {
start := time.Now() start := time.Now()
// Call underlying store // Call underlying store
err := s.base.Delete(id) err := s.base.Delete(ctx, id)
// Record metrics // Record metrics
duration := time.Since(start).Seconds() duration := time.Since(start).Seconds()

View File

@@ -51,7 +51,7 @@ func (s *GatewayServer) handleReady(w http.ResponseWriter, r *http.Request) {
// Test conversation store by attempting a simple operation // Test conversation store by attempting a simple operation
testID := "health_check_test" testID := "health_check_test"
_, err := s.convs.Get(testID) _, err := s.convs.Get(ctx, testID)
if err != nil { if err != nil {
checks["conversation_store"] = "unhealthy: " + err.Error() checks["conversation_store"] = "unhealthy: " + err.Error()
allHealthy = false allHealthy = false

View File

@@ -156,7 +156,7 @@ func newMockConversationStore() *mockConversationStore {
} }
} }
func (m *mockConversationStore) Get(id string) (*conversation.Conversation, error) { func (m *mockConversationStore) Get(ctx context.Context, id string) (*conversation.Conversation, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -170,7 +170,7 @@ func (m *mockConversationStore) Get(id string) (*conversation.Conversation, erro
return conv, nil return conv, nil
} }
func (m *mockConversationStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) { func (m *mockConversationStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*conversation.Conversation, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -187,7 +187,7 @@ func (m *mockConversationStore) Create(id string, model string, messages []api.M
return conv, nil return conv, nil
} }
func (m *mockConversationStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) { func (m *mockConversationStore) Append(ctx context.Context, id string, messages ...api.Message) (*conversation.Conversation, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -203,7 +203,7 @@ func (m *mockConversationStore) Append(id string, messages ...api.Message) (*con
return conv, nil return conv, nil
} }
func (m *mockConversationStore) Delete(id string) error { func (m *mockConversationStore) Delete(ctx context.Context, id string) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()

View File

@@ -107,7 +107,7 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
// Build full message history from previous conversation // Build full message history from previous conversation
var historyMsgs []api.Message var historyMsgs []api.Message
if req.PreviousResponseID != nil && *req.PreviousResponseID != "" { if req.PreviousResponseID != nil && *req.PreviousResponseID != "" {
conv, err := s.convs.Get(*req.PreviousResponseID) conv, err := s.convs.Get(r.Context(), *req.PreviousResponseID)
if err != nil { if err != nil {
s.logger.ErrorContext(r.Context(), "failed to retrieve conversation", s.logger.ErrorContext(r.Context(), "failed to retrieve conversation",
logger.LogAttrsWithTrace(r.Context(), logger.LogAttrsWithTrace(r.Context(),
@@ -186,7 +186,7 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
ToolCalls: result.ToolCalls, ToolCalls: result.ToolCalls,
} }
allMsgs := append(storeMsgs, assistantMsg) allMsgs := append(storeMsgs, assistantMsg)
if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil { if _, err := s.convs.Create(r.Context(), responseID, result.Model, allMsgs); err != nil {
s.logger.ErrorContext(r.Context(), "failed to store conversation", s.logger.ErrorContext(r.Context(), "failed to store conversation",
logger.LogAttrsWithTrace(r.Context(), logger.LogAttrsWithTrace(r.Context(),
slog.String("request_id", logger.FromContext(r.Context())), slog.String("request_id", logger.FromContext(r.Context())),
@@ -543,7 +543,7 @@ loop:
ToolCalls: toolCalls, ToolCalls: toolCalls,
} }
allMsgs := append(storeMsgs, assistantMsg) allMsgs := append(storeMsgs, assistantMsg)
if _, err := s.convs.Create(responseID, model, allMsgs); err != nil { if _, err := s.convs.Create(r.Context(), responseID, model, allMsgs); err != nil {
s.logger.ErrorContext(r.Context(), "failed to store conversation", s.logger.ErrorContext(r.Context(), "failed to store conversation",
slog.String("request_id", logger.FromContext(r.Context())), slog.String("request_id", logger.FromContext(r.Context())),
slog.String("response_id", responseID), slog.String("response_id", responseID),