Fix context background and silent JWT
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -28,12 +29,13 @@ type Middleware struct {
|
||||
keys map[string]*rsa.PublicKey
|
||||
mu sync.RWMutex
|
||||
client *http.Client
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates an authentication middleware.
|
||||
func New(cfg Config) (*Middleware, error) {
|
||||
func New(cfg Config, logger *slog.Logger) (*Middleware, error) {
|
||||
if !cfg.Enabled {
|
||||
return &Middleware{cfg: cfg}, nil
|
||||
return &Middleware{cfg: cfg, logger: logger}, nil
|
||||
}
|
||||
|
||||
if cfg.Issuer == "" {
|
||||
@@ -44,6 +46,7 @@ func New(cfg Config) (*Middleware, error) {
|
||||
cfg: cfg,
|
||||
keys: make(map[string]*rsa.PublicKey),
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Fetch JWKS on startup
|
||||
@@ -255,6 +258,15 @@ func (m *Middleware) periodicRefresh() {
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
_ = m.refreshJWKS()
|
||||
if err := m.refreshJWKS(); err != nil {
|
||||
m.logger.Error("failed to refresh JWKS",
|
||||
slog.String("issuer", m.cfg.Issuer),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
} else {
|
||||
m.logger.Debug("successfully refreshed JWKS",
|
||||
slog.String("issuer", m.cfg.Issuer),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -213,7 +214,7 @@ func TestNew(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
m, err := New(tt.config)
|
||||
m, err := New(tt.config, slog.Default())
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
@@ -239,7 +240,7 @@ func TestMiddleware_Handler(t *testing.T) {
|
||||
Issuer: server.server.URL,
|
||||
Audience: testAudience,
|
||||
}
|
||||
m, err := New(cfg)
|
||||
m, err := New(cfg, slog.Default())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a test handler that echoes back claims
|
||||
@@ -415,7 +416,7 @@ func TestMiddleware_Handler_DisabledAuth(t *testing.T) {
|
||||
cfg := Config{
|
||||
Enabled: false,
|
||||
}
|
||||
m, err := New(cfg)
|
||||
m, err := New(cfg, slog.Default())
|
||||
require.NoError(t, err)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -442,7 +443,7 @@ func TestValidateToken(t *testing.T) {
|
||||
Issuer: server.server.URL,
|
||||
Audience: testAudience,
|
||||
}
|
||||
m, err := New(cfg)
|
||||
m, err := New(cfg, slog.Default())
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
@@ -665,7 +666,7 @@ func TestValidateToken_NoAudienceConfigured(t *testing.T) {
|
||||
Issuer: server.server.URL,
|
||||
Audience: "", // No audience required
|
||||
}
|
||||
m, err := New(cfg)
|
||||
m, err := New(cfg, slog.Default())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token without audience should be valid
|
||||
@@ -897,7 +898,7 @@ func TestRefreshJWKS_Concurrency(t *testing.T) {
|
||||
Issuer: server.server.URL,
|
||||
Audience: testAudience,
|
||||
}
|
||||
m, err := New(cfg)
|
||||
m, err := New(cfg, slog.Default())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Trigger concurrent refreshes
|
||||
@@ -982,7 +983,7 @@ func TestMiddleware_IssuerWithTrailingSlash(t *testing.T) {
|
||||
Issuer: server.server.URL + "/", // Trailing slash
|
||||
Audience: testAudience,
|
||||
}
|
||||
m, err := New(cfg)
|
||||
m, err := New(cfg, slog.Default())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, m)
|
||||
assert.Len(t, m.keys, 1)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -9,10 +10,10 @@ import (
|
||||
|
||||
// Store defines the interface for conversation storage backends.
|
||||
type Store interface {
|
||||
Get(id string) (*Conversation, error)
|
||||
Create(id string, model string, messages []api.Message) (*Conversation, error)
|
||||
Append(id string, messages ...api.Message) (*Conversation, error)
|
||||
Delete(id string) error
|
||||
Get(ctx context.Context, id string) (*Conversation, error)
|
||||
Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error)
|
||||
Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
Size() int
|
||||
Close() error
|
||||
}
|
||||
@@ -51,7 +52,7 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore {
|
||||
}
|
||||
|
||||
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
|
||||
func (s *MemoryStore) Get(id string) (*Conversation, error) {
|
||||
func (s *MemoryStore) Get(ctx context.Context, id string) (*Conversation, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
@@ -74,7 +75,7 @@ func (s *MemoryStore) Get(id string) (*Conversation, error) {
|
||||
}
|
||||
|
||||
// Create creates a new conversation with the given messages.
|
||||
func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
func (s *MemoryStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -105,7 +106,7 @@ func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*
|
||||
}
|
||||
|
||||
// Append adds new messages to an existing conversation.
|
||||
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
func (s *MemoryStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -131,7 +132,7 @@ func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation,
|
||||
}
|
||||
|
||||
// Delete removes a conversation from the store.
|
||||
func (s *MemoryStore) Delete(id string) error {
|
||||
func (s *MemoryStore) Delete(ctx context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -21,7 +22,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
conv, err := store.Create("test-id", "gpt-4", messages)
|
||||
conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conv)
|
||||
assert.Equal(t, "test-id", conv.ID)
|
||||
@@ -29,7 +30,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
|
||||
assert.Len(t, conv.Messages, 1)
|
||||
assert.Equal(t, "Hello", conv.Messages[0].Content[0].Text)
|
||||
|
||||
retrieved, err := store.Get("test-id")
|
||||
retrieved, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
assert.Equal(t, conv.ID, retrieved.ID)
|
||||
@@ -40,7 +41,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
|
||||
func TestMemoryStore_GetNonExistent(t *testing.T) {
|
||||
store := NewMemoryStore(1 * time.Hour)
|
||||
|
||||
conv, err := store.Get("nonexistent")
|
||||
conv, err := store.Get(context.Background(),"nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, conv, "should return nil for nonexistent conversation")
|
||||
}
|
||||
@@ -57,7 +58,7 @@ func TestMemoryStore_Append(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, err := store.Create("test-id", "gpt-4", initialMessages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", initialMessages)
|
||||
require.NoError(t, err)
|
||||
|
||||
newMessages := []api.Message{
|
||||
@@ -75,7 +76,7 @@ func TestMemoryStore_Append(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
conv, err := store.Append("test-id", newMessages...)
|
||||
conv, err := store.Append(context.Background(),"test-id", newMessages...)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conv)
|
||||
assert.Len(t, conv.Messages, 3, "should have all messages")
|
||||
@@ -94,7 +95,7 @@ func TestMemoryStore_AppendNonExistent(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
conv, err := store.Append("nonexistent", newMessage)
|
||||
conv, err := store.Append(context.Background(),"nonexistent", newMessage)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, conv, "should return nil when appending to nonexistent conversation")
|
||||
}
|
||||
@@ -111,20 +112,20 @@ func TestMemoryStore_Delete(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, err := store.Create("test-id", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it exists
|
||||
conv, err := store.Get("test-id")
|
||||
conv, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, conv)
|
||||
|
||||
// Delete it
|
||||
err = store.Delete("test-id")
|
||||
err = store.Delete(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's gone
|
||||
conv, err = store.Get("test-id")
|
||||
conv, err = store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, conv, "conversation should be deleted")
|
||||
}
|
||||
@@ -138,15 +139,15 @@ func TestMemoryStore_Size(t *testing.T) {
|
||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||
}
|
||||
|
||||
_, err := store.Create("conv-1", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"conv-1", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, store.Size())
|
||||
|
||||
_, err = store.Create("conv-2", "gpt-4", messages)
|
||||
_, err = store.Create(context.Background(),"conv-2", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, store.Size())
|
||||
|
||||
err = store.Delete("conv-1")
|
||||
err = store.Delete(context.Background(),"conv-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, store.Size())
|
||||
}
|
||||
@@ -159,14 +160,14 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create initial conversation
|
||||
_, err := store.Create("test-id", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate concurrent reads and writes
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
_, _ = store.Get("test-id")
|
||||
_, _ = store.Get(context.Background(),"test-id")
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
@@ -176,7 +177,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
|
||||
}
|
||||
_, _ = store.Append("test-id", newMsg)
|
||||
_, _ = store.Append(context.Background(),"test-id", newMsg)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
@@ -187,7 +188,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
conv, err := store.Get("test-id")
|
||||
conv, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, conv)
|
||||
assert.GreaterOrEqual(t, len(conv.Messages), 1)
|
||||
@@ -205,11 +206,11 @@ func TestMemoryStore_DeepCopy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, err := store.Create("test-id", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get conversation
|
||||
conv1, err := store.Get("test-id")
|
||||
conv1, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Note: Current implementation copies the Messages slice but not the Content blocks
|
||||
@@ -225,7 +226,7 @@ func TestMemoryStore_DeepCopy(t *testing.T) {
|
||||
assert.Equal(t, originalLen+1, len(conv1.Messages), "can modify returned message slice")
|
||||
|
||||
// Verify original is unchanged
|
||||
conv2, err := store.Get("test-id")
|
||||
conv2, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalLen, len(conv2.Messages), "original conversation unaffected by slice modification")
|
||||
}
|
||||
@@ -238,11 +239,11 @@ func TestMemoryStore_TTLCleanup(t *testing.T) {
|
||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||
}
|
||||
|
||||
_, err := store.Create("test-id", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it exists
|
||||
conv, err := store.Get("test-id")
|
||||
conv, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, conv)
|
||||
assert.Equal(t, 1, store.Size())
|
||||
@@ -265,12 +266,12 @@ func TestMemoryStore_NoTTL(t *testing.T) {
|
||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||
}
|
||||
|
||||
_, err := store.Create("test-id", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, store.Size())
|
||||
|
||||
// Without TTL, conversation should persist indefinitely
|
||||
conv, err := store.Get("test-id")
|
||||
conv, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, conv)
|
||||
}
|
||||
@@ -282,7 +283,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) {
|
||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||
}
|
||||
|
||||
conv, err := store.Create("test-id", "gpt-4", messages)
|
||||
conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
createdAt := conv.CreatedAt
|
||||
updatedAt := conv.UpdatedAt
|
||||
@@ -296,7 +297,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) {
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
|
||||
}
|
||||
conv, err = store.Append("test-id", newMsg)
|
||||
conv, err = store.Append(context.Background(),"test-id", newMsg)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, createdAt, conv.CreatedAt, "created time should not change")
|
||||
@@ -313,7 +314,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) {
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello " + id}}},
|
||||
}
|
||||
_, err := store.Create(id, model, messages)
|
||||
_, err := store.Create(context.Background(),id, model, messages)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -322,7 +323,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) {
|
||||
// Verify each conversation is independent
|
||||
for i := 0; i < 10; i++ {
|
||||
id := "conv-" + string(rune('0'+i))
|
||||
conv, err := store.Get(id)
|
||||
conv, err := store.Get(context.Background(),id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conv)
|
||||
assert.Equal(t, id, conv.ID)
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
type RedisStore struct {
|
||||
client *redis.Client
|
||||
ttl time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewRedisStore creates a Redis-backed conversation store.
|
||||
@@ -21,7 +20,6 @@ func NewRedisStore(client *redis.Client, ttl time.Duration) *RedisStore {
|
||||
return &RedisStore{
|
||||
client: client,
|
||||
ttl: ttl,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,8 +29,8 @@ func (s *RedisStore) key(id string) string {
|
||||
}
|
||||
|
||||
// Get retrieves a conversation by ID from Redis.
|
||||
func (s *RedisStore) Get(id string) (*Conversation, error) {
|
||||
data, err := s.client.Get(s.ctx, s.key(id)).Bytes()
|
||||
func (s *RedisStore) Get(ctx context.Context, id string) (*Conversation, error) {
|
||||
data, err := s.client.Get(ctx, s.key(id)).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -49,7 +47,7 @@ func (s *RedisStore) Get(id string) (*Conversation, error) {
|
||||
}
|
||||
|
||||
// Create creates a new conversation with the given messages.
|
||||
func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
func (s *RedisStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
now := time.Now()
|
||||
conv := &Conversation{
|
||||
ID: id,
|
||||
@@ -64,7 +62,7 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -72,8 +70,8 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C
|
||||
}
|
||||
|
||||
// Append adds new messages to an existing conversation.
|
||||
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(id)
|
||||
func (s *RedisStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -89,7 +87,7 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -97,17 +95,18 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation,
|
||||
}
|
||||
|
||||
// Delete removes a conversation from Redis.
|
||||
func (s *RedisStore) Delete(id string) error {
|
||||
return s.client.Del(s.ctx, s.key(id)).Err()
|
||||
func (s *RedisStore) Delete(ctx context.Context, id string) error {
|
||||
return s.client.Del(ctx, s.key(id)).Err()
|
||||
}
|
||||
|
||||
// Size returns the number of active conversations in Redis.
|
||||
func (s *RedisStore) Size() int {
|
||||
var count int
|
||||
var cursor uint64
|
||||
ctx := context.Background()
|
||||
|
||||
for {
|
||||
keys, nextCursor, err := s.client.Scan(s.ctx, cursor, "conv:*", 100).Result()
|
||||
keys, nextCursor, err := s.client.Scan(ctx, cursor, "conv:*", 100).Result()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
@@ -71,8 +72,8 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Get(id string) (*Conversation, error) {
|
||||
row := s.db.QueryRow(s.dialect.getByID, id)
|
||||
func (s *SQLStore) Get(ctx context.Context, id string) (*Conversation, error) {
|
||||
row := s.db.QueryRowContext(ctx, s.dialect.getByID, id)
|
||||
|
||||
var conv Conversation
|
||||
var msgJSON string
|
||||
@@ -91,14 +92,14 @@ func (s *SQLStore) Get(id string) (*Conversation, error) {
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
func (s *SQLStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
now := time.Now()
|
||||
msgJSON, err := json.Marshal(messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
|
||||
if _, err := s.db.ExecContext(ctx, s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -111,8 +112,8 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Con
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(id)
|
||||
func (s *SQLStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -128,15 +129,15 @@ func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
|
||||
if _, err := s.db.ExecContext(ctx, s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Delete(id string) error {
|
||||
_, err := s.db.Exec(s.dialect.deleteByID, id)
|
||||
func (s *SQLStore) Delete(ctx context.Context, id string) error {
|
||||
_, err := s.db.ExecContext(ctx, s.dialect.deleteByID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -42,9 +42,7 @@ func NewInstrumentedStore(s conversation.Store, backend string, registry *promet
|
||||
}
|
||||
|
||||
// Get wraps the store's Get method with metrics and tracing.
|
||||
func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (s *InstrumentedStore) Get(ctx context.Context, id string) (*conversation.Conversation, error) {
|
||||
// Start span if tracing is enabled
|
||||
if s.tracer != nil {
|
||||
var span trace.Span
|
||||
@@ -61,7 +59,7 @@ func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) {
|
||||
start := time.Now()
|
||||
|
||||
// Call underlying store
|
||||
conv, err := s.base.Get(id)
|
||||
conv, err := s.base.Get(ctx, id)
|
||||
|
||||
// Record metrics
|
||||
duration := time.Since(start).Seconds()
|
||||
@@ -95,9 +93,7 @@ func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) {
|
||||
}
|
||||
|
||||
// Create wraps the store's Create method with metrics and tracing.
|
||||
func (s *InstrumentedStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (s *InstrumentedStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*conversation.Conversation, error) {
|
||||
// Start span if tracing is enabled
|
||||
if s.tracer != nil {
|
||||
var span trace.Span
|
||||
@@ -116,7 +112,7 @@ func (s *InstrumentedStore) Create(id string, model string, messages []api.Messa
|
||||
start := time.Now()
|
||||
|
||||
// Call underlying store
|
||||
conv, err := s.base.Create(id, model, messages)
|
||||
conv, err := s.base.Create(ctx, id, model, messages)
|
||||
|
||||
// Record metrics
|
||||
duration := time.Since(start).Seconds()
|
||||
@@ -146,9 +142,7 @@ func (s *InstrumentedStore) Create(id string, model string, messages []api.Messa
|
||||
}
|
||||
|
||||
// Append wraps the store's Append method with metrics and tracing.
|
||||
func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (s *InstrumentedStore) Append(ctx context.Context, id string, messages ...api.Message) (*conversation.Conversation, error) {
|
||||
// Start span if tracing is enabled
|
||||
if s.tracer != nil {
|
||||
var span trace.Span
|
||||
@@ -166,7 +160,7 @@ func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*convers
|
||||
start := time.Now()
|
||||
|
||||
// Call underlying store
|
||||
conv, err := s.base.Append(id, messages...)
|
||||
conv, err := s.base.Append(ctx, id, messages...)
|
||||
|
||||
// Record metrics
|
||||
duration := time.Since(start).Seconds()
|
||||
@@ -199,9 +193,7 @@ func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*convers
|
||||
}
|
||||
|
||||
// Delete wraps the store's Delete method with metrics and tracing.
|
||||
func (s *InstrumentedStore) Delete(id string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
func (s *InstrumentedStore) Delete(ctx context.Context, id string) error {
|
||||
// Start span if tracing is enabled
|
||||
if s.tracer != nil {
|
||||
var span trace.Span
|
||||
@@ -218,7 +210,7 @@ func (s *InstrumentedStore) Delete(id string) error {
|
||||
start := time.Now()
|
||||
|
||||
// Call underlying store
|
||||
err := s.base.Delete(id)
|
||||
err := s.base.Delete(ctx, id)
|
||||
|
||||
// Record metrics
|
||||
duration := time.Since(start).Seconds()
|
||||
|
||||
@@ -51,7 +51,7 @@ func (s *GatewayServer) handleReady(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Test conversation store by attempting a simple operation
|
||||
testID := "health_check_test"
|
||||
_, err := s.convs.Get(testID)
|
||||
_, err := s.convs.Get(ctx, testID)
|
||||
if err != nil {
|
||||
checks["conversation_store"] = "unhealthy: " + err.Error()
|
||||
allHealthy = false
|
||||
|
||||
@@ -156,7 +156,7 @@ func newMockConversationStore() *mockConversationStore {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockConversationStore) Get(id string) (*conversation.Conversation, error) {
|
||||
func (m *mockConversationStore) Get(ctx context.Context, id string) (*conversation.Conversation, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -170,7 +170,7 @@ func (m *mockConversationStore) Get(id string) (*conversation.Conversation, erro
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
func (m *mockConversationStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) {
|
||||
func (m *mockConversationStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*conversation.Conversation, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -187,7 +187,7 @@ func (m *mockConversationStore) Create(id string, model string, messages []api.M
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
func (m *mockConversationStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) {
|
||||
func (m *mockConversationStore) Append(ctx context.Context, id string, messages ...api.Message) (*conversation.Conversation, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -203,7 +203,7 @@ func (m *mockConversationStore) Append(id string, messages ...api.Message) (*con
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
func (m *mockConversationStore) Delete(id string) error {
|
||||
func (m *mockConversationStore) Delete(ctx context.Context, id string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
|
||||
// Build full message history from previous conversation
|
||||
var historyMsgs []api.Message
|
||||
if req.PreviousResponseID != nil && *req.PreviousResponseID != "" {
|
||||
conv, err := s.convs.Get(*req.PreviousResponseID)
|
||||
conv, err := s.convs.Get(r.Context(), *req.PreviousResponseID)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(r.Context(), "failed to retrieve conversation",
|
||||
logger.LogAttrsWithTrace(r.Context(),
|
||||
@@ -186,7 +186,7 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
||||
ToolCalls: result.ToolCalls,
|
||||
}
|
||||
allMsgs := append(storeMsgs, assistantMsg)
|
||||
if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil {
|
||||
if _, err := s.convs.Create(r.Context(), responseID, result.Model, allMsgs); err != nil {
|
||||
s.logger.ErrorContext(r.Context(), "failed to store conversation",
|
||||
logger.LogAttrsWithTrace(r.Context(),
|
||||
slog.String("request_id", logger.FromContext(r.Context())),
|
||||
@@ -543,7 +543,7 @@ loop:
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
allMsgs := append(storeMsgs, assistantMsg)
|
||||
if _, err := s.convs.Create(responseID, model, allMsgs); err != nil {
|
||||
if _, err := s.convs.Create(r.Context(), responseID, model, allMsgs); err != nil {
|
||||
s.logger.ErrorContext(r.Context(), "failed to store conversation",
|
||||
slog.String("request_id", logger.FromContext(r.Context())),
|
||||
slog.String("response_id", responseID),
|
||||
|
||||
Reference in New Issue
Block a user