diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 4f53e31..259183c 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -110,7 +110,7 @@ func main() { Issuer: cfg.Auth.Issuer, Audience: cfg.Auth.Audience, } - authMiddleware, err := auth.New(authConfig) + authMiddleware, err := auth.New(authConfig, logger) if err != nil { logger.Error("failed to initialize auth", slog.String("error", err.Error())) os.Exit(1) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 0aa9d52..b36b768 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "log/slog" "math/big" "net/http" "strings" @@ -28,12 +29,13 @@ type Middleware struct { keys map[string]*rsa.PublicKey mu sync.RWMutex client *http.Client + logger *slog.Logger } // New creates an authentication middleware. -func New(cfg Config) (*Middleware, error) { +func New(cfg Config, logger *slog.Logger) (*Middleware, error) { if !cfg.Enabled { - return &Middleware{cfg: cfg}, nil + return &Middleware{cfg: cfg, logger: logger}, nil } if cfg.Issuer == "" { @@ -44,6 +46,7 @@ func New(cfg Config) (*Middleware, error) { cfg: cfg, keys: make(map[string]*rsa.PublicKey), client: &http.Client{Timeout: 10 * time.Second}, + logger: logger, } // Fetch JWKS on startup @@ -255,6 +258,15 @@ func (m *Middleware) periodicRefresh() { defer ticker.Stop() for range ticker.C { - _ = m.refreshJWKS() + if err := m.refreshJWKS(); err != nil { + m.logger.Error("failed to refresh JWKS", + slog.String("issuer", m.cfg.Issuer), + slog.String("error", err.Error()), + ) + } else { + m.logger.Debug("successfully refreshed JWKS", + slog.String("issuer", m.cfg.Issuer), + ) + } } } diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 3622d63..bf3b14a 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "log/slog" "math/big" "net/http" "net/http/httptest" @@ -213,7 +214,7 @@ func TestNew(t *testing.T) { } } - m, err := New(tt.config) + m, err := New(tt.config, slog.Default()) if tt.expectError { assert.Error(t, err) @@ -239,7 +240,7 @@ func TestMiddleware_Handler(t *testing.T) { Issuer: server.server.URL, Audience: testAudience, } - m, err := New(cfg) + m, err := New(cfg, slog.Default()) require.NoError(t, err) // Create a test handler that echoes back claims @@ -415,7 +416,7 @@ func TestMiddleware_Handler_DisabledAuth(t *testing.T) { cfg := Config{ Enabled: false, } - m, err := New(cfg) + m, err := New(cfg, slog.Default()) require.NoError(t, err) testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -442,7 +443,7 @@ func TestValidateToken(t *testing.T) { Issuer: server.server.URL, Audience: testAudience, } - m, err := New(cfg) + m, err := New(cfg, slog.Default()) require.NoError(t, err) tests := []struct { @@ -665,7 +666,7 @@ func TestValidateToken_NoAudienceConfigured(t *testing.T) { Issuer: server.server.URL, Audience: "", // No audience required } - m, err := New(cfg) + m, err := New(cfg, slog.Default()) require.NoError(t, err) // Token without audience should be valid @@ -897,7 +898,7 @@ func TestRefreshJWKS_Concurrency(t *testing.T) { Issuer: server.server.URL, Audience: testAudience, } - m, err := New(cfg) + m, err := New(cfg, slog.Default()) require.NoError(t, err) // Trigger concurrent refreshes @@ -982,7 +983,7 @@ func TestMiddleware_IssuerWithTrailingSlash(t *testing.T) { Issuer: server.server.URL + "/", // Trailing slash Audience: testAudience, } - m, err := New(cfg) + m, err := New(cfg, slog.Default()) require.NoError(t, err) require.NotNil(t, m) assert.Len(t, m.keys, 1) diff --git a/internal/conversation/conversation.go b/internal/conversation/conversation.go index b00b193..9a1beb4 100644 --- a/internal/conversation/conversation.go +++ b/internal/conversation/conversation.go @@ -1,6 +1,7 @@ package conversation import ( + "context" "sync" "time" @@ -9,10 +10,10 @@ import ( // Store defines the interface for conversation storage backends. type Store interface { - Get(id string) (*Conversation, error) - Create(id string, model string, messages []api.Message) (*Conversation, error) - Append(id string, messages ...api.Message) (*Conversation, error) - Delete(id string) error + Get(ctx context.Context, id string) (*Conversation, error) + Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) + Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) + Delete(ctx context.Context, id string) error Size() int Close() error } @@ -51,7 +52,7 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore { } // Get retrieves a conversation by ID. Returns a deep copy to prevent data races. -func (s *MemoryStore) Get(id string) (*Conversation, error) { +func (s *MemoryStore) Get(ctx context.Context, id string) (*Conversation, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -74,7 +75,7 @@ func (s *MemoryStore) Get(id string) (*Conversation, error) { } // Create creates a new conversation with the given messages. -func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) { +func (s *MemoryStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) { s.mu.Lock() defer s.mu.Unlock() @@ -105,7 +106,7 @@ func (s *MemoryStore) Create(id string, model string, messages []api.Message) (* } // Append adds new messages to an existing conversation. -func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) { +func (s *MemoryStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) { s.mu.Lock() defer s.mu.Unlock() @@ -131,7 +132,7 @@ func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, } // Delete removes a conversation from the store. -func (s *MemoryStore) Delete(id string) error { +func (s *MemoryStore) Delete(ctx context.Context, id string) error { s.mu.Lock() defer s.mu.Unlock() diff --git a/internal/conversation/conversation_test.go b/internal/conversation/conversation_test.go index 6dc747d..b217973 100644 --- a/internal/conversation/conversation_test.go +++ b/internal/conversation/conversation_test.go @@ -1,6 +1,7 @@ package conversation import ( + "context" "testing" "time" @@ -21,7 +22,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) { }, } - conv, err := store.Create("test-id", "gpt-4", messages) + conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages) require.NoError(t, err) require.NotNil(t, conv) assert.Equal(t, "test-id", conv.ID) @@ -29,7 +30,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) { assert.Len(t, conv.Messages, 1) assert.Equal(t, "Hello", conv.Messages[0].Content[0].Text) - retrieved, err := store.Get("test-id") + retrieved, err := store.Get(context.Background(),"test-id") require.NoError(t, err) require.NotNil(t, retrieved) assert.Equal(t, conv.ID, retrieved.ID) @@ -40,7 +41,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) { func TestMemoryStore_GetNonExistent(t *testing.T) { store := NewMemoryStore(1 * time.Hour) - conv, err := store.Get("nonexistent") + conv, err := store.Get(context.Background(),"nonexistent") require.NoError(t, err) assert.Nil(t, conv, "should return nil for nonexistent conversation") } @@ -57,7 +58,7 @@ func TestMemoryStore_Append(t *testing.T) { }, } - _, err := store.Create("test-id", "gpt-4", initialMessages) + _, err := store.Create(context.Background(),"test-id", "gpt-4", initialMessages) require.NoError(t, err) newMessages := []api.Message{ @@ -75,7 +76,7 @@ func TestMemoryStore_Append(t *testing.T) { }, } - conv, err := store.Append("test-id", newMessages...) + conv, err := store.Append(context.Background(),"test-id", newMessages...) require.NoError(t, err) require.NotNil(t, conv) assert.Len(t, conv.Messages, 3, "should have all messages") @@ -94,7 +95,7 @@ func TestMemoryStore_AppendNonExistent(t *testing.T) { }, } - conv, err := store.Append("nonexistent", newMessage) + conv, err := store.Append(context.Background(),"nonexistent", newMessage) require.NoError(t, err) assert.Nil(t, conv, "should return nil when appending to nonexistent conversation") } @@ -111,20 +112,20 @@ func TestMemoryStore_Delete(t *testing.T) { }, } - _, err := store.Create("test-id", "gpt-4", messages) + _, err := store.Create(context.Background(),"test-id", "gpt-4", messages) require.NoError(t, err) // Verify it exists - conv, err := store.Get("test-id") + conv, err := store.Get(context.Background(),"test-id") require.NoError(t, err) assert.NotNil(t, conv) // Delete it - err = store.Delete("test-id") + err = store.Delete(context.Background(),"test-id") require.NoError(t, err) // Verify it's gone - conv, err = store.Get("test-id") + conv, err = store.Get(context.Background(),"test-id") require.NoError(t, err) assert.Nil(t, conv, "conversation should be deleted") } @@ -138,15 +139,15 @@ func TestMemoryStore_Size(t *testing.T) { {Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}}, } - _, err := store.Create("conv-1", "gpt-4", messages) + _, err := store.Create(context.Background(),"conv-1", "gpt-4", messages) require.NoError(t, err) assert.Equal(t, 1, store.Size()) - _, err = store.Create("conv-2", "gpt-4", messages) + _, err = store.Create(context.Background(),"conv-2", "gpt-4", messages) require.NoError(t, err) assert.Equal(t, 2, store.Size()) - err = store.Delete("conv-1") + err = store.Delete(context.Background(),"conv-1") require.NoError(t, err) assert.Equal(t, 1, store.Size()) } @@ -159,14 +160,14 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) { } // Create initial conversation - _, err := store.Create("test-id", "gpt-4", messages) + _, err := store.Create(context.Background(),"test-id", "gpt-4", messages) require.NoError(t, err) // Simulate concurrent reads and writes done := make(chan bool, 10) for i := 0; i < 5; i++ { go func() { - _, _ = store.Get("test-id") + _, _ = store.Get(context.Background(),"test-id") done <- true }() } @@ -176,7 +177,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) { Role: "assistant", Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}}, } - _, _ = store.Append("test-id", newMsg) + _, _ = store.Append(context.Background(),"test-id", newMsg) done <- true }() } @@ -187,7 +188,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) { } // Verify final state - conv, err := store.Get("test-id") + conv, err := store.Get(context.Background(),"test-id") require.NoError(t, err) assert.NotNil(t, conv) assert.GreaterOrEqual(t, len(conv.Messages), 1) @@ -205,11 +206,11 @@ func TestMemoryStore_DeepCopy(t *testing.T) { }, } - _, err := store.Create("test-id", "gpt-4", messages) + _, err := store.Create(context.Background(),"test-id", "gpt-4", messages) require.NoError(t, err) // Get conversation - conv1, err := store.Get("test-id") + conv1, err := store.Get(context.Background(),"test-id") require.NoError(t, err) // Note: Current implementation copies the Messages slice but not the Content blocks @@ -225,7 +226,7 @@ func TestMemoryStore_DeepCopy(t *testing.T) { assert.Equal(t, originalLen+1, len(conv1.Messages), "can modify returned message slice") // Verify original is unchanged - conv2, err := store.Get("test-id") + conv2, err := store.Get(context.Background(),"test-id") require.NoError(t, err) assert.Equal(t, originalLen, len(conv2.Messages), "original conversation unaffected by slice modification") } @@ -238,11 +239,11 @@ func TestMemoryStore_TTLCleanup(t *testing.T) { {Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}}, } - _, err := store.Create("test-id", "gpt-4", messages) + _, err := store.Create(context.Background(),"test-id", "gpt-4", messages) require.NoError(t, err) // Verify it exists - conv, err := store.Get("test-id") + conv, err := store.Get(context.Background(),"test-id") require.NoError(t, err) assert.NotNil(t, conv) assert.Equal(t, 1, store.Size()) @@ -265,12 +266,12 @@ func TestMemoryStore_NoTTL(t *testing.T) { {Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}}, } - _, err := store.Create("test-id", "gpt-4", messages) + _, err := store.Create(context.Background(),"test-id", "gpt-4", messages) require.NoError(t, err) assert.Equal(t, 1, store.Size()) // Without TTL, conversation should persist indefinitely - conv, err := store.Get("test-id") + conv, err := store.Get(context.Background(),"test-id") require.NoError(t, err) assert.NotNil(t, conv) } @@ -282,7 +283,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) { {Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}}, } - conv, err := store.Create("test-id", "gpt-4", messages) + conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages) require.NoError(t, err) createdAt := conv.CreatedAt updatedAt := conv.UpdatedAt @@ -296,7 +297,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) { Role: "assistant", Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}}, } - conv, err = store.Append("test-id", newMsg) + conv, err = store.Append(context.Background(),"test-id", newMsg) require.NoError(t, err) assert.Equal(t, createdAt, conv.CreatedAt, "created time should not change") @@ -313,7 +314,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) { messages := []api.Message{ {Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello " + id}}}, } - _, err := store.Create(id, model, messages) + _, err := store.Create(context.Background(),id, model, messages) require.NoError(t, err) } @@ -322,7 +323,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) { // Verify each conversation is independent for i := 0; i < 10; i++ { id := "conv-" + string(rune('0'+i)) - conv, err := store.Get(id) + conv, err := store.Get(context.Background(),id) require.NoError(t, err) require.NotNil(t, conv) assert.Equal(t, id, conv.ID) diff --git a/internal/conversation/redis_store.go b/internal/conversation/redis_store.go index 146a32d..5428bba 100644 --- a/internal/conversation/redis_store.go +++ b/internal/conversation/redis_store.go @@ -13,7 +13,6 @@ import ( type RedisStore struct { client *redis.Client ttl time.Duration - ctx context.Context } // NewRedisStore creates a Redis-backed conversation store. @@ -21,7 +20,6 @@ func NewRedisStore(client *redis.Client, ttl time.Duration) *RedisStore { return &RedisStore{ client: client, ttl: ttl, - ctx: context.Background(), } } @@ -31,8 +29,8 @@ func (s *RedisStore) key(id string) string { } // Get retrieves a conversation by ID from Redis. -func (s *RedisStore) Get(id string) (*Conversation, error) { - data, err := s.client.Get(s.ctx, s.key(id)).Bytes() +func (s *RedisStore) Get(ctx context.Context, id string) (*Conversation, error) { + data, err := s.client.Get(ctx, s.key(id)).Bytes() if err == redis.Nil { return nil, nil } @@ -49,7 +47,7 @@ func (s *RedisStore) Get(id string) (*Conversation, error) { } // Create creates a new conversation with the given messages. -func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) { +func (s *RedisStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) { now := time.Now() conv := &Conversation{ ID: id, @@ -64,7 +62,7 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C return nil, err } - if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil { + if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil { return nil, err } @@ -72,8 +70,8 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C } // Append adds new messages to an existing conversation. -func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) { - conv, err := s.Get(id) +func (s *RedisStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) { + conv, err := s.Get(ctx, id) if err != nil { return nil, err } @@ -89,7 +87,7 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, return nil, err } - if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil { + if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil { return nil, err } @@ -97,17 +95,18 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, } // Delete removes a conversation from Redis. -func (s *RedisStore) Delete(id string) error { - return s.client.Del(s.ctx, s.key(id)).Err() +func (s *RedisStore) Delete(ctx context.Context, id string) error { + return s.client.Del(ctx, s.key(id)).Err() } // Size returns the number of active conversations in Redis. func (s *RedisStore) Size() int { var count int var cursor uint64 + ctx := context.Background() for { - keys, nextCursor, err := s.client.Scan(s.ctx, cursor, "conv:*", 100).Result() + keys, nextCursor, err := s.client.Scan(ctx, cursor, "conv:*", 100).Result() if err != nil { return 0 } diff --git a/internal/conversation/sql_store.go b/internal/conversation/sql_store.go index bcfd503..14ccd4f 100644 --- a/internal/conversation/sql_store.go +++ b/internal/conversation/sql_store.go @@ -1,6 +1,7 @@ package conversation import ( + "context" "database/sql" "encoding/json" "time" @@ -71,8 +72,8 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error return s, nil } -func (s *SQLStore) Get(id string) (*Conversation, error) { - row := s.db.QueryRow(s.dialect.getByID, id) +func (s *SQLStore) Get(ctx context.Context, id string) (*Conversation, error) { + row := s.db.QueryRowContext(ctx, s.dialect.getByID, id) var conv Conversation var msgJSON string @@ -91,14 +92,14 @@ func (s *SQLStore) Get(id string) (*Conversation, error) { return &conv, nil } -func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) { +func (s *SQLStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) { now := time.Now() msgJSON, err := json.Marshal(messages) if err != nil { return nil, err } - if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil { + if _, err := s.db.ExecContext(ctx, s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil { return nil, err } @@ -111,8 +112,8 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Con }, nil } -func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) { - conv, err := s.Get(id) +func (s *SQLStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) { + conv, err := s.Get(ctx, id) if err != nil { return nil, err } @@ -128,15 +129,15 @@ func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, er return nil, err } - if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil { + if _, err := s.db.ExecContext(ctx, s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil { return nil, err } return conv, nil } -func (s *SQLStore) Delete(id string) error { - _, err := s.db.Exec(s.dialect.deleteByID, id) +func (s *SQLStore) Delete(ctx context.Context, id string) error { + _, err := s.db.ExecContext(ctx, s.dialect.deleteByID, id) return err } diff --git a/internal/observability/store_wrapper.go b/internal/observability/store_wrapper.go index 52d8216..2064041 100644 --- a/internal/observability/store_wrapper.go +++ b/internal/observability/store_wrapper.go @@ -42,9 +42,7 @@ func NewInstrumentedStore(s conversation.Store, backend string, registry *promet } // Get wraps the store's Get method with metrics and tracing. -func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) { - ctx := context.Background() - +func (s *InstrumentedStore) Get(ctx context.Context, id string) (*conversation.Conversation, error) { // Start span if tracing is enabled if s.tracer != nil { var span trace.Span @@ -61,7 +59,7 @@ func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) { start := time.Now() // Call underlying store - conv, err := s.base.Get(id) + conv, err := s.base.Get(ctx, id) // Record metrics duration := time.Since(start).Seconds() @@ -95,9 +93,7 @@ func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) { } // Create wraps the store's Create method with metrics and tracing. -func (s *InstrumentedStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) { - ctx := context.Background() - +func (s *InstrumentedStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*conversation.Conversation, error) { // Start span if tracing is enabled if s.tracer != nil { var span trace.Span @@ -116,7 +112,7 @@ func (s *InstrumentedStore) Create(id string, model string, messages []api.Messa start := time.Now() // Call underlying store - conv, err := s.base.Create(id, model, messages) + conv, err := s.base.Create(ctx, id, model, messages) // Record metrics duration := time.Since(start).Seconds() @@ -146,9 +142,7 @@ func (s *InstrumentedStore) Create(id string, model string, messages []api.Messa } // Append wraps the store's Append method with metrics and tracing. -func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) { - ctx := context.Background() - +func (s *InstrumentedStore) Append(ctx context.Context, id string, messages ...api.Message) (*conversation.Conversation, error) { // Start span if tracing is enabled if s.tracer != nil { var span trace.Span @@ -166,7 +160,7 @@ func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*convers start := time.Now() // Call underlying store - conv, err := s.base.Append(id, messages...) + conv, err := s.base.Append(ctx, id, messages...) // Record metrics duration := time.Since(start).Seconds() @@ -199,9 +193,7 @@ func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*convers } // Delete wraps the store's Delete method with metrics and tracing. -func (s *InstrumentedStore) Delete(id string) error { - ctx := context.Background() - +func (s *InstrumentedStore) Delete(ctx context.Context, id string) error { // Start span if tracing is enabled if s.tracer != nil { var span trace.Span @@ -218,7 +210,7 @@ func (s *InstrumentedStore) Delete(id string) error { start := time.Now() // Call underlying store - err := s.base.Delete(id) + err := s.base.Delete(ctx, id) // Record metrics duration := time.Since(start).Seconds() diff --git a/internal/server/health.go b/internal/server/health.go index b95ebaf..4765a18 100644 --- a/internal/server/health.go +++ b/internal/server/health.go @@ -51,7 +51,7 @@ func (s *GatewayServer) handleReady(w http.ResponseWriter, r *http.Request) { // Test conversation store by attempting a simple operation testID := "health_check_test" - _, err := s.convs.Get(testID) + _, err := s.convs.Get(ctx, testID) if err != nil { checks["conversation_store"] = "unhealthy: " + err.Error() allHealthy = false diff --git a/internal/server/mocks_test.go b/internal/server/mocks_test.go index cbc8ccd..bfdc3cd 100644 --- a/internal/server/mocks_test.go +++ b/internal/server/mocks_test.go @@ -156,7 +156,7 @@ func newMockConversationStore() *mockConversationStore { } } -func (m *mockConversationStore) Get(id string) (*conversation.Conversation, error) { +func (m *mockConversationStore) Get(ctx context.Context, id string) (*conversation.Conversation, error) { m.mu.Lock() defer m.mu.Unlock() @@ -170,7 +170,7 @@ func (m *mockConversationStore) Get(id string) (*conversation.Conversation, erro return conv, nil } -func (m *mockConversationStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) { +func (m *mockConversationStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*conversation.Conversation, error) { m.mu.Lock() defer m.mu.Unlock() @@ -187,7 +187,7 @@ func (m *mockConversationStore) Create(id string, model string, messages []api.M return conv, nil } -func (m *mockConversationStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) { +func (m *mockConversationStore) Append(ctx context.Context, id string, messages ...api.Message) (*conversation.Conversation, error) { m.mu.Lock() defer m.mu.Unlock() @@ -203,7 +203,7 @@ func (m *mockConversationStore) Append(id string, messages ...api.Message) (*con return conv, nil } -func (m *mockConversationStore) Delete(id string) error { +func (m *mockConversationStore) Delete(ctx context.Context, id string) error { m.mu.Lock() defer m.mu.Unlock() diff --git a/internal/server/server.go b/internal/server/server.go index f0b2e7d..9125b3b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -107,7 +107,7 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) // Build full message history from previous conversation var historyMsgs []api.Message if req.PreviousResponseID != nil && *req.PreviousResponseID != "" { - conv, err := s.convs.Get(*req.PreviousResponseID) + conv, err := s.convs.Get(r.Context(), *req.PreviousResponseID) if err != nil { s.logger.ErrorContext(r.Context(), "failed to retrieve conversation", logger.LogAttrsWithTrace(r.Context(), @@ -186,7 +186,7 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques ToolCalls: result.ToolCalls, } allMsgs := append(storeMsgs, assistantMsg) - if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil { + if _, err := s.convs.Create(r.Context(), responseID, result.Model, allMsgs); err != nil { s.logger.ErrorContext(r.Context(), "failed to store conversation", logger.LogAttrsWithTrace(r.Context(), slog.String("request_id", logger.FromContext(r.Context())), @@ -543,7 +543,7 @@ loop: ToolCalls: toolCalls, } allMsgs := append(storeMsgs, assistantMsg) - if _, err := s.convs.Create(responseID, model, allMsgs); err != nil { + if _, err := s.convs.Create(r.Context(), responseID, model, allMsgs); err != nil { s.logger.ErrorContext(r.Context(), "failed to store conversation", slog.String("request_id", logger.FromContext(r.Context())), slog.String("response_id", responseID),