Improve Stores

This commit is contained in:
2026-03-02 15:47:36 +00:00
parent 259d02d140
commit 830a87afa1
6 changed files with 148 additions and 60 deletions

View File

@@ -9,10 +9,10 @@ import (
// Store defines the interface for conversation storage backends.
type Store interface {
Get(id string) (*Conversation, bool)
Create(id string, model string, messages []api.Message) *Conversation
Append(id string, messages ...api.Message) (*Conversation, bool)
Delete(id string)
Get(id string) (*Conversation, error)
Create(id string, model string, messages []api.Message) (*Conversation, error)
Append(id string, messages ...api.Message) (*Conversation, error)
Delete(id string) error
Size() int
}
@@ -47,55 +47,93 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore {
return s
}
// Get retrieves a conversation by ID.
func (s *MemoryStore) Get(id string) (*Conversation, bool) {
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
func (s *MemoryStore) Get(id string) (*Conversation, error) {
s.mu.RLock()
defer s.mu.RUnlock()
conv, ok := s.conversations[id]
return conv, ok
if !ok {
return nil, nil
}
// Return a deep copy to prevent data races
msgsCopy := make([]api.Message, len(conv.Messages))
copy(msgsCopy, conv.Messages)
return &Conversation{
ID: conv.ID,
Messages: msgsCopy,
Model: conv.Model,
CreatedAt: conv.CreatedAt,
UpdatedAt: conv.UpdatedAt,
}, nil
}
// Create creates a new conversation with the given messages.
func (s *MemoryStore) Create(id string, model string, messages []api.Message) *Conversation {
func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
// Store a copy to prevent external modifications
msgsCopy := make([]api.Message, len(messages))
copy(msgsCopy, messages)
conv := &Conversation{
ID: id,
Messages: msgsCopy,
Model: model,
CreatedAt: now,
UpdatedAt: now,
}
s.conversations[id] = conv
// Return a copy
return &Conversation{
ID: id,
Messages: messages,
Model: model,
CreatedAt: now,
UpdatedAt: now,
}
s.conversations[id] = conv
return conv
}, nil
}
// Append adds new messages to an existing conversation.
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) {
s.mu.Lock()
defer s.mu.Unlock()
conv, ok := s.conversations[id]
if !ok {
return nil, false
return nil, nil
}
conv.Messages = append(conv.Messages, messages...)
conv.UpdatedAt = time.Now()
return conv, true
// Return a deep copy
msgsCopy := make([]api.Message, len(conv.Messages))
copy(msgsCopy, conv.Messages)
return &Conversation{
ID: conv.ID,
Messages: msgsCopy,
Model: conv.Model,
CreatedAt: conv.CreatedAt,
UpdatedAt: conv.UpdatedAt,
}, nil
}
// Delete removes a conversation from the store.
func (s *MemoryStore) Delete(id string) {
func (s *MemoryStore) Delete(id string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.conversations, id)
return nil
}
// cleanup periodically removes expired conversations.

View File

@@ -31,22 +31,25 @@ func (s *RedisStore) key(id string) string {
}
// Get retrieves a conversation by ID from Redis.
func (s *RedisStore) Get(id string) (*Conversation, bool) {
func (s *RedisStore) Get(id string) (*Conversation, error) {
data, err := s.client.Get(s.ctx, s.key(id)).Bytes()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, false
return nil, err
}
var conv Conversation
if err := json.Unmarshal(data, &conv); err != nil {
return nil, false
return nil, err
}
return &conv, true
return &conv, nil
}
// Create creates a new conversation with the given messages.
func (s *RedisStore) Create(id string, model string, messages []api.Message) *Conversation {
func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
now := time.Now()
conv := &Conversation{
ID: id,
@@ -56,31 +59,46 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) *Co
UpdatedAt: now,
}
data, _ := json.Marshal(conv)
_ = s.client.Set(s.ctx, s.key(id), data, s.ttl).Err()
data, err := json.Marshal(conv)
if err != nil {
return nil, err
}
return conv
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
return nil, err
}
return conv, nil
}
// Append adds new messages to an existing conversation.
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
conv, ok := s.Get(id)
if !ok {
return nil, false
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) {
conv, err := s.Get(id)
if err != nil {
return nil, err
}
if conv == nil {
return nil, nil
}
conv.Messages = append(conv.Messages, messages...)
conv.UpdatedAt = time.Now()
data, _ := json.Marshal(conv)
_ = s.client.Set(s.ctx, s.key(id), data, s.ttl).Err()
data, err := json.Marshal(conv)
if err != nil {
return nil, err
}
return conv, true
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
return nil, err
}
return conv, nil
}
// Delete removes a conversation from Redis.
func (s *RedisStore) Delete(id string) {
_ = s.client.Del(s.ctx, s.key(id)).Err()
func (s *RedisStore) Delete(id string) error {
return s.client.Del(s.ctx, s.key(id)).Err()
}
// Size returns the number of active conversations in Redis.

View File

@@ -65,28 +65,36 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error
return s, nil
}
func (s *SQLStore) Get(id string) (*Conversation, bool) {
func (s *SQLStore) Get(id string) (*Conversation, error) {
row := s.db.QueryRow(s.dialect.getByID, id)
var conv Conversation
var msgJSON string
err := row.Scan(&conv.ID, &conv.Model, &msgJSON, &conv.CreatedAt, &conv.UpdatedAt)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, false
return nil, err
}
if err := json.Unmarshal([]byte(msgJSON), &conv.Messages); err != nil {
return nil, false
return nil, err
}
return &conv, true
return &conv, nil
}
func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conversation {
func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
now := time.Now()
msgJSON, _ := json.Marshal(messages)
msgJSON, err := json.Marshal(messages)
if err != nil {
return nil, err
}
_, _ = s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now)
if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
return nil, err
}
return &Conversation{
ID: id,
@@ -94,26 +102,36 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conv
Model: model,
CreatedAt: now,
UpdatedAt: now,
}
}, nil
}
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
conv, ok := s.Get(id)
if !ok {
return nil, false
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) {
conv, err := s.Get(id)
if err != nil {
return nil, err
}
if conv == nil {
return nil, nil
}
conv.Messages = append(conv.Messages, messages...)
conv.UpdatedAt = time.Now()
msgJSON, _ := json.Marshal(conv.Messages)
_, _ = s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id)
msgJSON, err := json.Marshal(conv.Messages)
if err != nil {
return nil, err
}
return conv, true
if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
return nil, err
}
return conv, nil
}
func (s *SQLStore) Delete(id string) {
_, _ = s.db.Exec(s.dialect.deleteByID, id)
func (s *SQLStore) Delete(id string) error {
_, err := s.db.Exec(s.dialect.deleteByID, id)
return err
}
func (s *SQLStore) Size() int {