Fix context background and silent JWT
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -9,10 +10,10 @@ import (
|
||||
|
||||
// Store defines the interface for conversation storage backends.
|
||||
type Store interface {
|
||||
Get(id string) (*Conversation, error)
|
||||
Create(id string, model string, messages []api.Message) (*Conversation, error)
|
||||
Append(id string, messages ...api.Message) (*Conversation, error)
|
||||
Delete(id string) error
|
||||
Get(ctx context.Context, id string) (*Conversation, error)
|
||||
Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error)
|
||||
Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
Size() int
|
||||
Close() error
|
||||
}
|
||||
@@ -51,7 +52,7 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore {
|
||||
}
|
||||
|
||||
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
|
||||
func (s *MemoryStore) Get(id string) (*Conversation, error) {
|
||||
func (s *MemoryStore) Get(ctx context.Context, id string) (*Conversation, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
@@ -74,7 +75,7 @@ func (s *MemoryStore) Get(id string) (*Conversation, error) {
|
||||
}
|
||||
|
||||
// Create creates a new conversation with the given messages.
|
||||
func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
func (s *MemoryStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -105,7 +106,7 @@ func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*
|
||||
}
|
||||
|
||||
// Append adds new messages to an existing conversation.
|
||||
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
func (s *MemoryStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -131,7 +132,7 @@ func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation,
|
||||
}
|
||||
|
||||
// Delete removes a conversation from the store.
|
||||
func (s *MemoryStore) Delete(id string) error {
|
||||
func (s *MemoryStore) Delete(ctx context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -21,7 +22,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
conv, err := store.Create("test-id", "gpt-4", messages)
|
||||
conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conv)
|
||||
assert.Equal(t, "test-id", conv.ID)
|
||||
@@ -29,7 +30,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
|
||||
assert.Len(t, conv.Messages, 1)
|
||||
assert.Equal(t, "Hello", conv.Messages[0].Content[0].Text)
|
||||
|
||||
retrieved, err := store.Get("test-id")
|
||||
retrieved, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
assert.Equal(t, conv.ID, retrieved.ID)
|
||||
@@ -40,7 +41,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
|
||||
func TestMemoryStore_GetNonExistent(t *testing.T) {
|
||||
store := NewMemoryStore(1 * time.Hour)
|
||||
|
||||
conv, err := store.Get("nonexistent")
|
||||
conv, err := store.Get(context.Background(),"nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, conv, "should return nil for nonexistent conversation")
|
||||
}
|
||||
@@ -57,7 +58,7 @@ func TestMemoryStore_Append(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, err := store.Create("test-id", "gpt-4", initialMessages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", initialMessages)
|
||||
require.NoError(t, err)
|
||||
|
||||
newMessages := []api.Message{
|
||||
@@ -75,7 +76,7 @@ func TestMemoryStore_Append(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
conv, err := store.Append("test-id", newMessages...)
|
||||
conv, err := store.Append(context.Background(),"test-id", newMessages...)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conv)
|
||||
assert.Len(t, conv.Messages, 3, "should have all messages")
|
||||
@@ -94,7 +95,7 @@ func TestMemoryStore_AppendNonExistent(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
conv, err := store.Append("nonexistent", newMessage)
|
||||
conv, err := store.Append(context.Background(),"nonexistent", newMessage)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, conv, "should return nil when appending to nonexistent conversation")
|
||||
}
|
||||
@@ -111,20 +112,20 @@ func TestMemoryStore_Delete(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, err := store.Create("test-id", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it exists
|
||||
conv, err := store.Get("test-id")
|
||||
conv, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, conv)
|
||||
|
||||
// Delete it
|
||||
err = store.Delete("test-id")
|
||||
err = store.Delete(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's gone
|
||||
conv, err = store.Get("test-id")
|
||||
conv, err = store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, conv, "conversation should be deleted")
|
||||
}
|
||||
@@ -138,15 +139,15 @@ func TestMemoryStore_Size(t *testing.T) {
|
||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||
}
|
||||
|
||||
_, err := store.Create("conv-1", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"conv-1", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, store.Size())
|
||||
|
||||
_, err = store.Create("conv-2", "gpt-4", messages)
|
||||
_, err = store.Create(context.Background(),"conv-2", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, store.Size())
|
||||
|
||||
err = store.Delete("conv-1")
|
||||
err = store.Delete(context.Background(),"conv-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, store.Size())
|
||||
}
|
||||
@@ -159,14 +160,14 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create initial conversation
|
||||
_, err := store.Create("test-id", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate concurrent reads and writes
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
_, _ = store.Get("test-id")
|
||||
_, _ = store.Get(context.Background(),"test-id")
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
@@ -176,7 +177,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
|
||||
}
|
||||
_, _ = store.Append("test-id", newMsg)
|
||||
_, _ = store.Append(context.Background(),"test-id", newMsg)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
@@ -187,7 +188,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
conv, err := store.Get("test-id")
|
||||
conv, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, conv)
|
||||
assert.GreaterOrEqual(t, len(conv.Messages), 1)
|
||||
@@ -205,11 +206,11 @@ func TestMemoryStore_DeepCopy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, err := store.Create("test-id", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get conversation
|
||||
conv1, err := store.Get("test-id")
|
||||
conv1, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Note: Current implementation copies the Messages slice but not the Content blocks
|
||||
@@ -225,7 +226,7 @@ func TestMemoryStore_DeepCopy(t *testing.T) {
|
||||
assert.Equal(t, originalLen+1, len(conv1.Messages), "can modify returned message slice")
|
||||
|
||||
// Verify original is unchanged
|
||||
conv2, err := store.Get("test-id")
|
||||
conv2, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalLen, len(conv2.Messages), "original conversation unaffected by slice modification")
|
||||
}
|
||||
@@ -238,11 +239,11 @@ func TestMemoryStore_TTLCleanup(t *testing.T) {
|
||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||
}
|
||||
|
||||
_, err := store.Create("test-id", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it exists
|
||||
conv, err := store.Get("test-id")
|
||||
conv, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, conv)
|
||||
assert.Equal(t, 1, store.Size())
|
||||
@@ -265,12 +266,12 @@ func TestMemoryStore_NoTTL(t *testing.T) {
|
||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||
}
|
||||
|
||||
_, err := store.Create("test-id", "gpt-4", messages)
|
||||
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, store.Size())
|
||||
|
||||
// Without TTL, conversation should persist indefinitely
|
||||
conv, err := store.Get("test-id")
|
||||
conv, err := store.Get(context.Background(),"test-id")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, conv)
|
||||
}
|
||||
@@ -282,7 +283,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) {
|
||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||
}
|
||||
|
||||
conv, err := store.Create("test-id", "gpt-4", messages)
|
||||
conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||
require.NoError(t, err)
|
||||
createdAt := conv.CreatedAt
|
||||
updatedAt := conv.UpdatedAt
|
||||
@@ -296,7 +297,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) {
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
|
||||
}
|
||||
conv, err = store.Append("test-id", newMsg)
|
||||
conv, err = store.Append(context.Background(),"test-id", newMsg)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, createdAt, conv.CreatedAt, "created time should not change")
|
||||
@@ -313,7 +314,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) {
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello " + id}}},
|
||||
}
|
||||
_, err := store.Create(id, model, messages)
|
||||
_, err := store.Create(context.Background(),id, model, messages)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -322,7 +323,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) {
|
||||
// Verify each conversation is independent
|
||||
for i := 0; i < 10; i++ {
|
||||
id := "conv-" + string(rune('0'+i))
|
||||
conv, err := store.Get(id)
|
||||
conv, err := store.Get(context.Background(),id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conv)
|
||||
assert.Equal(t, id, conv.ID)
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
type RedisStore struct {
|
||||
client *redis.Client
|
||||
ttl time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewRedisStore creates a Redis-backed conversation store.
|
||||
@@ -21,7 +20,6 @@ func NewRedisStore(client *redis.Client, ttl time.Duration) *RedisStore {
|
||||
return &RedisStore{
|
||||
client: client,
|
||||
ttl: ttl,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,8 +29,8 @@ func (s *RedisStore) key(id string) string {
|
||||
}
|
||||
|
||||
// Get retrieves a conversation by ID from Redis.
|
||||
func (s *RedisStore) Get(id string) (*Conversation, error) {
|
||||
data, err := s.client.Get(s.ctx, s.key(id)).Bytes()
|
||||
func (s *RedisStore) Get(ctx context.Context, id string) (*Conversation, error) {
|
||||
data, err := s.client.Get(ctx, s.key(id)).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -49,7 +47,7 @@ func (s *RedisStore) Get(id string) (*Conversation, error) {
|
||||
}
|
||||
|
||||
// Create creates a new conversation with the given messages.
|
||||
func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
func (s *RedisStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
now := time.Now()
|
||||
conv := &Conversation{
|
||||
ID: id,
|
||||
@@ -64,7 +62,7 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -72,8 +70,8 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C
|
||||
}
|
||||
|
||||
// Append adds new messages to an existing conversation.
|
||||
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(id)
|
||||
func (s *RedisStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -89,7 +87,7 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -97,17 +95,18 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation,
|
||||
}
|
||||
|
||||
// Delete removes a conversation from Redis.
|
||||
func (s *RedisStore) Delete(id string) error {
|
||||
return s.client.Del(s.ctx, s.key(id)).Err()
|
||||
func (s *RedisStore) Delete(ctx context.Context, id string) error {
|
||||
return s.client.Del(ctx, s.key(id)).Err()
|
||||
}
|
||||
|
||||
// Size returns the number of active conversations in Redis.
|
||||
func (s *RedisStore) Size() int {
|
||||
var count int
|
||||
var cursor uint64
|
||||
ctx := context.Background()
|
||||
|
||||
for {
|
||||
keys, nextCursor, err := s.client.Scan(s.ctx, cursor, "conv:*", 100).Result()
|
||||
keys, nextCursor, err := s.client.Scan(ctx, cursor, "conv:*", 100).Result()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
@@ -71,8 +72,8 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Get(id string) (*Conversation, error) {
|
||||
row := s.db.QueryRow(s.dialect.getByID, id)
|
||||
func (s *SQLStore) Get(ctx context.Context, id string) (*Conversation, error) {
|
||||
row := s.db.QueryRowContext(ctx, s.dialect.getByID, id)
|
||||
|
||||
var conv Conversation
|
||||
var msgJSON string
|
||||
@@ -91,14 +92,14 @@ func (s *SQLStore) Get(id string) (*Conversation, error) {
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
func (s *SQLStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
now := time.Now()
|
||||
msgJSON, err := json.Marshal(messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
|
||||
if _, err := s.db.ExecContext(ctx, s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -111,8 +112,8 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Con
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(id)
|
||||
func (s *SQLStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -128,15 +129,15 @@ func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
|
||||
if _, err := s.db.ExecContext(ctx, s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Delete(id string) error {
|
||||
_, err := s.db.Exec(s.dialect.deleteByID, id)
|
||||
func (s *SQLStore) Delete(ctx context.Context, id string) error {
|
||||
_, err := s.db.ExecContext(ctx, s.dialect.deleteByID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user