Fix context background and silent JWT

This commit is contained in:
2026-03-05 06:55:44 +00:00
parent 214e63b0c5
commit ae2e1b7a80
11 changed files with 99 additions and 92 deletions

View File

@@ -1,6 +1,7 @@
package conversation
import (
"context"
"sync"
"time"
@@ -9,10 +10,10 @@ import (
// Store defines the interface for conversation storage backends.
type Store interface {
Get(id string) (*Conversation, error)
Create(id string, model string, messages []api.Message) (*Conversation, error)
Append(id string, messages ...api.Message) (*Conversation, error)
Delete(id string) error
Get(ctx context.Context, id string) (*Conversation, error)
Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error)
Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error)
Delete(ctx context.Context, id string) error
Size() int
Close() error
}
@@ -51,7 +52,7 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore {
}
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
func (s *MemoryStore) Get(id string) (*Conversation, error) {
func (s *MemoryStore) Get(ctx context.Context, id string) (*Conversation, error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -74,7 +75,7 @@ func (s *MemoryStore) Get(id string) (*Conversation, error) {
}
// Create creates a new conversation with the given messages.
func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
func (s *MemoryStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -105,7 +106,7 @@ func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*
}
// Append adds new messages to an existing conversation.
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) {
func (s *MemoryStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -131,7 +132,7 @@ func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation,
}
// Delete removes a conversation from the store.
func (s *MemoryStore) Delete(id string) error {
func (s *MemoryStore) Delete(ctx context.Context, id string) error {
s.mu.Lock()
defer s.mu.Unlock()

View File

@@ -1,6 +1,7 @@
package conversation
import (
"context"
"testing"
"time"
@@ -21,7 +22,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
},
}
conv, err := store.Create("test-id", "gpt-4", messages)
conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err)
require.NotNil(t, conv)
assert.Equal(t, "test-id", conv.ID)
@@ -29,7 +30,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
assert.Len(t, conv.Messages, 1)
assert.Equal(t, "Hello", conv.Messages[0].Content[0].Text)
retrieved, err := store.Get("test-id")
retrieved, err := store.Get(context.Background(),"test-id")
require.NoError(t, err)
require.NotNil(t, retrieved)
assert.Equal(t, conv.ID, retrieved.ID)
@@ -40,7 +41,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
func TestMemoryStore_GetNonExistent(t *testing.T) {
store := NewMemoryStore(1 * time.Hour)
conv, err := store.Get("nonexistent")
conv, err := store.Get(context.Background(),"nonexistent")
require.NoError(t, err)
assert.Nil(t, conv, "should return nil for nonexistent conversation")
}
@@ -57,7 +58,7 @@ func TestMemoryStore_Append(t *testing.T) {
},
}
_, err := store.Create("test-id", "gpt-4", initialMessages)
_, err := store.Create(context.Background(),"test-id", "gpt-4", initialMessages)
require.NoError(t, err)
newMessages := []api.Message{
@@ -75,7 +76,7 @@ func TestMemoryStore_Append(t *testing.T) {
},
}
conv, err := store.Append("test-id", newMessages...)
conv, err := store.Append(context.Background(),"test-id", newMessages...)
require.NoError(t, err)
require.NotNil(t, conv)
assert.Len(t, conv.Messages, 3, "should have all messages")
@@ -94,7 +95,7 @@ func TestMemoryStore_AppendNonExistent(t *testing.T) {
},
}
conv, err := store.Append("nonexistent", newMessage)
conv, err := store.Append(context.Background(),"nonexistent", newMessage)
require.NoError(t, err)
assert.Nil(t, conv, "should return nil when appending to nonexistent conversation")
}
@@ -111,20 +112,20 @@ func TestMemoryStore_Delete(t *testing.T) {
},
}
_, err := store.Create("test-id", "gpt-4", messages)
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err)
// Verify it exists
conv, err := store.Get("test-id")
conv, err := store.Get(context.Background(),"test-id")
require.NoError(t, err)
assert.NotNil(t, conv)
// Delete it
err = store.Delete("test-id")
err = store.Delete(context.Background(),"test-id")
require.NoError(t, err)
// Verify it's gone
conv, err = store.Get("test-id")
conv, err = store.Get(context.Background(),"test-id")
require.NoError(t, err)
assert.Nil(t, conv, "conversation should be deleted")
}
@@ -138,15 +139,15 @@ func TestMemoryStore_Size(t *testing.T) {
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
}
_, err := store.Create("conv-1", "gpt-4", messages)
_, err := store.Create(context.Background(),"conv-1", "gpt-4", messages)
require.NoError(t, err)
assert.Equal(t, 1, store.Size())
_, err = store.Create("conv-2", "gpt-4", messages)
_, err = store.Create(context.Background(),"conv-2", "gpt-4", messages)
require.NoError(t, err)
assert.Equal(t, 2, store.Size())
err = store.Delete("conv-1")
err = store.Delete(context.Background(),"conv-1")
require.NoError(t, err)
assert.Equal(t, 1, store.Size())
}
@@ -159,14 +160,14 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
}
// Create initial conversation
_, err := store.Create("test-id", "gpt-4", messages)
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err)
// Simulate concurrent reads and writes
done := make(chan bool, 10)
for i := 0; i < 5; i++ {
go func() {
_, _ = store.Get("test-id")
_, _ = store.Get(context.Background(),"test-id")
done <- true
}()
}
@@ -176,7 +177,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
Role: "assistant",
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
}
_, _ = store.Append("test-id", newMsg)
_, _ = store.Append(context.Background(),"test-id", newMsg)
done <- true
}()
}
@@ -187,7 +188,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
}
// Verify final state
conv, err := store.Get("test-id")
conv, err := store.Get(context.Background(),"test-id")
require.NoError(t, err)
assert.NotNil(t, conv)
assert.GreaterOrEqual(t, len(conv.Messages), 1)
@@ -205,11 +206,11 @@ func TestMemoryStore_DeepCopy(t *testing.T) {
},
}
_, err := store.Create("test-id", "gpt-4", messages)
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err)
// Get conversation
conv1, err := store.Get("test-id")
conv1, err := store.Get(context.Background(),"test-id")
require.NoError(t, err)
// Note: Current implementation copies the Messages slice but not the Content blocks
@@ -225,7 +226,7 @@ func TestMemoryStore_DeepCopy(t *testing.T) {
assert.Equal(t, originalLen+1, len(conv1.Messages), "can modify returned message slice")
// Verify original is unchanged
conv2, err := store.Get("test-id")
conv2, err := store.Get(context.Background(),"test-id")
require.NoError(t, err)
assert.Equal(t, originalLen, len(conv2.Messages), "original conversation unaffected by slice modification")
}
@@ -238,11 +239,11 @@ func TestMemoryStore_TTLCleanup(t *testing.T) {
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
}
_, err := store.Create("test-id", "gpt-4", messages)
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err)
// Verify it exists
conv, err := store.Get("test-id")
conv, err := store.Get(context.Background(),"test-id")
require.NoError(t, err)
assert.NotNil(t, conv)
assert.Equal(t, 1, store.Size())
@@ -265,12 +266,12 @@ func TestMemoryStore_NoTTL(t *testing.T) {
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
}
_, err := store.Create("test-id", "gpt-4", messages)
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err)
assert.Equal(t, 1, store.Size())
// Without TTL, conversation should persist indefinitely
conv, err := store.Get("test-id")
conv, err := store.Get(context.Background(),"test-id")
require.NoError(t, err)
assert.NotNil(t, conv)
}
@@ -282,7 +283,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) {
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
}
conv, err := store.Create("test-id", "gpt-4", messages)
conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
require.NoError(t, err)
createdAt := conv.CreatedAt
updatedAt := conv.UpdatedAt
@@ -296,7 +297,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) {
Role: "assistant",
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
}
conv, err = store.Append("test-id", newMsg)
conv, err = store.Append(context.Background(),"test-id", newMsg)
require.NoError(t, err)
assert.Equal(t, createdAt, conv.CreatedAt, "created time should not change")
@@ -313,7 +314,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) {
messages := []api.Message{
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello " + id}}},
}
_, err := store.Create(id, model, messages)
_, err := store.Create(context.Background(),id, model, messages)
require.NoError(t, err)
}
@@ -322,7 +323,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) {
// Verify each conversation is independent
for i := 0; i < 10; i++ {
id := "conv-" + string(rune('0'+i))
conv, err := store.Get(id)
conv, err := store.Get(context.Background(),id)
require.NoError(t, err)
require.NotNil(t, conv)
assert.Equal(t, id, conv.ID)

View File

@@ -13,7 +13,6 @@ import (
type RedisStore struct {
client *redis.Client
ttl time.Duration
ctx context.Context
}
// NewRedisStore creates a Redis-backed conversation store.
@@ -21,7 +20,6 @@ func NewRedisStore(client *redis.Client, ttl time.Duration) *RedisStore {
return &RedisStore{
client: client,
ttl: ttl,
ctx: context.Background(),
}
}
@@ -31,8 +29,8 @@ func (s *RedisStore) key(id string) string {
}
// Get retrieves a conversation by ID from Redis.
func (s *RedisStore) Get(id string) (*Conversation, error) {
data, err := s.client.Get(s.ctx, s.key(id)).Bytes()
func (s *RedisStore) Get(ctx context.Context, id string) (*Conversation, error) {
data, err := s.client.Get(ctx, s.key(id)).Bytes()
if err == redis.Nil {
return nil, nil
}
@@ -49,7 +47,7 @@ func (s *RedisStore) Get(id string) (*Conversation, error) {
}
// Create creates a new conversation with the given messages.
func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
func (s *RedisStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
now := time.Now()
conv := &Conversation{
ID: id,
@@ -64,7 +62,7 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C
return nil, err
}
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil {
return nil, err
}
@@ -72,8 +70,8 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C
}
// Append adds new messages to an existing conversation.
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) {
conv, err := s.Get(id)
func (s *RedisStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
conv, err := s.Get(ctx, id)
if err != nil {
return nil, err
}
@@ -89,7 +87,7 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation,
return nil, err
}
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil {
return nil, err
}
@@ -97,17 +95,18 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation,
}
// Delete removes a conversation from Redis.
func (s *RedisStore) Delete(id string) error {
return s.client.Del(s.ctx, s.key(id)).Err()
func (s *RedisStore) Delete(ctx context.Context, id string) error {
return s.client.Del(ctx, s.key(id)).Err()
}
// Size returns the number of active conversations in Redis.
func (s *RedisStore) Size() int {
var count int
var cursor uint64
ctx := context.Background()
for {
keys, nextCursor, err := s.client.Scan(s.ctx, cursor, "conv:*", 100).Result()
keys, nextCursor, err := s.client.Scan(ctx, cursor, "conv:*", 100).Result()
if err != nil {
return 0
}

View File

@@ -1,6 +1,7 @@
package conversation
import (
"context"
"database/sql"
"encoding/json"
"time"
@@ -71,8 +72,8 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error
return s, nil
}
func (s *SQLStore) Get(id string) (*Conversation, error) {
row := s.db.QueryRow(s.dialect.getByID, id)
func (s *SQLStore) Get(ctx context.Context, id string) (*Conversation, error) {
row := s.db.QueryRowContext(ctx, s.dialect.getByID, id)
var conv Conversation
var msgJSON string
@@ -91,14 +92,14 @@ func (s *SQLStore) Get(id string) (*Conversation, error) {
return &conv, nil
}
func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
func (s *SQLStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
now := time.Now()
msgJSON, err := json.Marshal(messages)
if err != nil {
return nil, err
}
if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
if _, err := s.db.ExecContext(ctx, s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
return nil, err
}
@@ -111,8 +112,8 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Con
}, nil
}
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) {
conv, err := s.Get(id)
func (s *SQLStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
conv, err := s.Get(ctx, id)
if err != nil {
return nil, err
}
@@ -128,15 +129,15 @@ func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, er
return nil, err
}
if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
if _, err := s.db.ExecContext(ctx, s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
return nil, err
}
return conv, nil
}
func (s *SQLStore) Delete(id string) error {
_, err := s.db.Exec(s.dialect.deleteByID, id)
func (s *SQLStore) Delete(ctx context.Context, id string) error {
_, err := s.db.ExecContext(ctx, s.dialect.deleteByID, id)
return err
}