Improve Stores
This commit is contained in:
1
go.mod
1
go.mod
@@ -9,6 +9,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/mattn/go-sqlite3 v1.14.34
|
||||
github.com/openai/openai-go v1.12.0
|
||||
github.com/openai/openai-go/v3 v3.2.0
|
||||
github.com/redis/go-redis/v9 v9.18.0
|
||||
google.golang.org/genai v1.48.0
|
||||
|
||||
2
go.sum
2
go.sum
@@ -91,6 +91,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
|
||||
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
|
||||
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||
github.com/openai/openai-go/v3 v3.2.0 h1:2AbqFUCsoW2pm/2pUtPRuwK89dnoGHaQokzWsfoQO/U=
|
||||
github.com/openai/openai-go/v3 v3.2.0/go.mod h1:UOpNxkqC9OdNXNUfpNByKOtB4jAL0EssQXq5p8gO0Xs=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
|
||||
@@ -9,10 +9,10 @@ import (
|
||||
|
||||
// Store defines the interface for conversation storage backends.
|
||||
type Store interface {
|
||||
Get(id string) (*Conversation, bool)
|
||||
Create(id string, model string, messages []api.Message) *Conversation
|
||||
Append(id string, messages ...api.Message) (*Conversation, bool)
|
||||
Delete(id string)
|
||||
Get(id string) (*Conversation, error)
|
||||
Create(id string, model string, messages []api.Message) (*Conversation, error)
|
||||
Append(id string, messages ...api.Message) (*Conversation, error)
|
||||
Delete(id string) error
|
||||
Size() int
|
||||
}
|
||||
|
||||
@@ -47,55 +47,93 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore {
|
||||
return s
|
||||
}
|
||||
|
||||
// Get retrieves a conversation by ID.
|
||||
func (s *MemoryStore) Get(id string) (*Conversation, bool) {
|
||||
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
|
||||
func (s *MemoryStore) Get(id string) (*Conversation, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
|
||||
conv, ok := s.conversations[id]
|
||||
return conv, ok
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Return a deep copy to prevent data races
|
||||
msgsCopy := make([]api.Message, len(conv.Messages))
|
||||
copy(msgsCopy, conv.Messages)
|
||||
|
||||
return &Conversation{
|
||||
ID: conv.ID,
|
||||
Messages: msgsCopy,
|
||||
Model: conv.Model,
|
||||
CreatedAt: conv.CreatedAt,
|
||||
UpdatedAt: conv.UpdatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create creates a new conversation with the given messages.
|
||||
func (s *MemoryStore) Create(id string, model string, messages []api.Message) *Conversation {
|
||||
func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Store a copy to prevent external modifications
|
||||
msgsCopy := make([]api.Message, len(messages))
|
||||
copy(msgsCopy, messages)
|
||||
|
||||
conv := &Conversation{
|
||||
ID: id,
|
||||
Messages: msgsCopy,
|
||||
Model: model,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
s.conversations[id] = conv
|
||||
|
||||
// Return a copy
|
||||
return &Conversation{
|
||||
ID: id,
|
||||
Messages: messages,
|
||||
Model: model,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
s.conversations[id] = conv
|
||||
return conv
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Append adds new messages to an existing conversation.
|
||||
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
|
||||
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
||||
conv, ok := s.conversations[id]
|
||||
if !ok {
|
||||
return nil, false
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
conv.Messages = append(conv.Messages, messages...)
|
||||
conv.UpdatedAt = time.Now()
|
||||
|
||||
return conv, true
|
||||
|
||||
// Return a deep copy
|
||||
msgsCopy := make([]api.Message, len(conv.Messages))
|
||||
copy(msgsCopy, conv.Messages)
|
||||
|
||||
return &Conversation{
|
||||
ID: conv.ID,
|
||||
Messages: msgsCopy,
|
||||
Model: conv.Model,
|
||||
CreatedAt: conv.CreatedAt,
|
||||
UpdatedAt: conv.UpdatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Delete removes a conversation from the store.
|
||||
func (s *MemoryStore) Delete(id string) {
|
||||
func (s *MemoryStore) Delete(id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
||||
delete(s.conversations, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanup periodically removes expired conversations.
|
||||
|
||||
@@ -31,22 +31,25 @@ func (s *RedisStore) key(id string) string {
|
||||
}
|
||||
|
||||
// Get retrieves a conversation by ID from Redis.
|
||||
func (s *RedisStore) Get(id string) (*Conversation, bool) {
|
||||
func (s *RedisStore) Get(id string) (*Conversation, error) {
|
||||
data, err := s.client.Get(s.ctx, s.key(id)).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var conv Conversation
|
||||
if err := json.Unmarshal(data, &conv); err != nil {
|
||||
return nil, false
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &conv, true
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// Create creates a new conversation with the given messages.
|
||||
func (s *RedisStore) Create(id string, model string, messages []api.Message) *Conversation {
|
||||
func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
now := time.Now()
|
||||
conv := &Conversation{
|
||||
ID: id,
|
||||
@@ -56,31 +59,46 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) *Co
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(conv)
|
||||
_ = s.client.Set(s.ctx, s.key(id), data, s.ttl).Err()
|
||||
data, err := json.Marshal(conv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv
|
||||
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// Append adds new messages to an existing conversation.
|
||||
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
|
||||
conv, ok := s.Get(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conv == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conv.Messages = append(conv.Messages, messages...)
|
||||
conv.UpdatedAt = time.Now()
|
||||
|
||||
data, _ := json.Marshal(conv)
|
||||
_ = s.client.Set(s.ctx, s.key(id), data, s.ttl).Err()
|
||||
data, err := json.Marshal(conv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, true
|
||||
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// Delete removes a conversation from Redis.
|
||||
func (s *RedisStore) Delete(id string) {
|
||||
_ = s.client.Del(s.ctx, s.key(id)).Err()
|
||||
func (s *RedisStore) Delete(id string) error {
|
||||
return s.client.Del(s.ctx, s.key(id)).Err()
|
||||
}
|
||||
|
||||
// Size returns the number of active conversations in Redis.
|
||||
|
||||
@@ -65,28 +65,36 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Get(id string) (*Conversation, bool) {
|
||||
func (s *SQLStore) Get(id string) (*Conversation, error) {
|
||||
row := s.db.QueryRow(s.dialect.getByID, id)
|
||||
|
||||
var conv Conversation
|
||||
var msgJSON string
|
||||
err := row.Scan(&conv.ID, &conv.Model, &msgJSON, &conv.CreatedAt, &conv.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(msgJSON), &conv.Messages); err != nil {
|
||||
return nil, false
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &conv, true
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conversation {
|
||||
func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
now := time.Now()
|
||||
msgJSON, _ := json.Marshal(messages)
|
||||
msgJSON, err := json.Marshal(messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _ = s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now)
|
||||
if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Conversation{
|
||||
ID: id,
|
||||
@@ -94,26 +102,36 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conv
|
||||
Model: model,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
|
||||
conv, ok := s.Get(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conv == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conv.Messages = append(conv.Messages, messages...)
|
||||
conv.UpdatedAt = time.Now()
|
||||
|
||||
msgJSON, _ := json.Marshal(conv.Messages)
|
||||
_, _ = s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id)
|
||||
msgJSON, err := json.Marshal(conv.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, true
|
||||
if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Delete(id string) {
|
||||
_, _ = s.db.Exec(s.dialect.deleteByID, id)
|
||||
func (s *SQLStore) Delete(id string) error {
|
||||
_, err := s.db.Exec(s.dialect.deleteByID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLStore) Size() int {
|
||||
|
||||
@@ -84,8 +84,13 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
|
||||
// Build full message history from previous conversation
|
||||
var historyMsgs []api.Message
|
||||
if req.PreviousResponseID != nil && *req.PreviousResponseID != "" {
|
||||
conv, ok := s.convs.Get(*req.PreviousResponseID)
|
||||
if !ok {
|
||||
conv, err := s.convs.Get(*req.PreviousResponseID)
|
||||
if err != nil {
|
||||
s.logger.Printf("error retrieving conversation: %v", err)
|
||||
http.Error(w, "error retrieving conversation", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if conv == nil {
|
||||
http.Error(w, "conversation not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
@@ -140,7 +145,10 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
|
||||
}
|
||||
allMsgs := append(storeMsgs, assistantMsg)
|
||||
s.convs.Create(responseID, result.Model, allMsgs)
|
||||
if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil {
|
||||
s.logger.Printf("error storing conversation: %v", err)
|
||||
// Don't fail the response if storage fails
|
||||
}
|
||||
|
||||
// Build spec-compliant response
|
||||
resp := s.buildResponse(origReq, result, provider.Name(), responseID)
|
||||
@@ -458,7 +466,10 @@ loop:
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
|
||||
}
|
||||
allMsgs := append(storeMsgs, assistantMsg)
|
||||
s.convs.Create(responseID, model, allMsgs)
|
||||
if _, err := s.convs.Create(responseID, model, allMsgs); err != nil {
|
||||
s.logger.Printf("error storing conversation: %v", err)
|
||||
// Don't fail the response if storage fails
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user