diff --git a/go.mod b/go.mod index b423b8b..5cbad9b 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.8.0 github.com/mattn/go-sqlite3 v1.14.34 + github.com/openai/openai-go v1.12.0 github.com/openai/openai-go/v3 v3.2.0 github.com/redis/go-redis/v9 v9.18.0 google.golang.org/genai v1.48.0 diff --git a/go.sum b/go.sum index f71fd69..0d5eb0f 100644 --- a/go.sum +++ b/go.sum @@ -91,6 +91,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= +github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/openai/openai-go/v3 v3.2.0 h1:2AbqFUCsoW2pm/2pUtPRuwK89dnoGHaQokzWsfoQO/U= github.com/openai/openai-go/v3 v3.2.0/go.mod h1:UOpNxkqC9OdNXNUfpNByKOtB4jAL0EssQXq5p8gO0Xs= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= diff --git a/internal/conversation/conversation.go b/internal/conversation/conversation.go index eec1e2b..ff757c8 100644 --- a/internal/conversation/conversation.go +++ b/internal/conversation/conversation.go @@ -9,10 +9,10 @@ import ( // Store defines the interface for conversation storage backends. type Store interface { - Get(id string) (*Conversation, bool) - Create(id string, model string, messages []api.Message) *Conversation - Append(id string, messages ...api.Message) (*Conversation, bool) - Delete(id string) + Get(id string) (*Conversation, error) + Create(id string, model string, messages []api.Message) (*Conversation, error) + Append(id string, messages ...api.Message) (*Conversation, error) + Delete(id string) error Size() int } @@ -47,55 +47,93 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore { return s } -// Get retrieves a conversation by ID. -func (s *MemoryStore) Get(id string) (*Conversation, bool) { +// Get retrieves a conversation by ID. Returns a deep copy to prevent data races. +func (s *MemoryStore) Get(id string) (*Conversation, error) { s.mu.RLock() defer s.mu.RUnlock() - + conv, ok := s.conversations[id] - return conv, ok + if !ok { + return nil, nil + } + + // Return a deep copy to prevent data races + msgsCopy := make([]api.Message, len(conv.Messages)) + copy(msgsCopy, conv.Messages) + + return &Conversation{ + ID: conv.ID, + Messages: msgsCopy, + Model: conv.Model, + CreatedAt: conv.CreatedAt, + UpdatedAt: conv.UpdatedAt, + }, nil } // Create creates a new conversation with the given messages. -func (s *MemoryStore) Create(id string, model string, messages []api.Message) *Conversation { +func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) { s.mu.Lock() defer s.mu.Unlock() - + now := time.Now() + + // Store a copy to prevent external modifications + msgsCopy := make([]api.Message, len(messages)) + copy(msgsCopy, messages) + conv := &Conversation{ + ID: id, + Messages: msgsCopy, + Model: model, + CreatedAt: now, + UpdatedAt: now, + } + + s.conversations[id] = conv + + // Return a copy + return &Conversation{ ID: id, Messages: messages, Model: model, CreatedAt: now, UpdatedAt: now, - } - - s.conversations[id] = conv - return conv + }, nil } // Append adds new messages to an existing conversation. -func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, bool) { +func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) { s.mu.Lock() defer s.mu.Unlock() - + conv, ok := s.conversations[id] if !ok { - return nil, false + return nil, nil } - + conv.Messages = append(conv.Messages, messages...) conv.UpdatedAt = time.Now() - - return conv, true + + // Return a deep copy + msgsCopy := make([]api.Message, len(conv.Messages)) + copy(msgsCopy, conv.Messages) + + return &Conversation{ + ID: conv.ID, + Messages: msgsCopy, + Model: conv.Model, + CreatedAt: conv.CreatedAt, + UpdatedAt: conv.UpdatedAt, + }, nil } // Delete removes a conversation from the store. -func (s *MemoryStore) Delete(id string) { +func (s *MemoryStore) Delete(id string) error { s.mu.Lock() defer s.mu.Unlock() - + delete(s.conversations, id) + return nil } // cleanup periodically removes expired conversations. diff --git a/internal/conversation/redis_store.go b/internal/conversation/redis_store.go index 73b0cc1..5c96ba2 100644 --- a/internal/conversation/redis_store.go +++ b/internal/conversation/redis_store.go @@ -31,22 +31,25 @@ func (s *RedisStore) key(id string) string { } // Get retrieves a conversation by ID from Redis. -func (s *RedisStore) Get(id string) (*Conversation, bool) { +func (s *RedisStore) Get(id string) (*Conversation, error) { data, err := s.client.Get(s.ctx, s.key(id)).Bytes() + if err == redis.Nil { + return nil, nil + } if err != nil { - return nil, false + return nil, err } var conv Conversation if err := json.Unmarshal(data, &conv); err != nil { - return nil, false + return nil, err } - return &conv, true + return &conv, nil } // Create creates a new conversation with the given messages. -func (s *RedisStore) Create(id string, model string, messages []api.Message) *Conversation { +func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) { now := time.Now() conv := &Conversation{ ID: id, @@ -56,31 +59,46 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) *Co UpdatedAt: now, } - data, _ := json.Marshal(conv) - _ = s.client.Set(s.ctx, s.key(id), data, s.ttl).Err() + data, err := json.Marshal(conv) + if err != nil { + return nil, err + } - return conv + if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil { + return nil, err + } + + return conv, nil } // Append adds new messages to an existing conversation. -func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, bool) { - conv, ok := s.Get(id) - if !ok { - return nil, false +func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) { + conv, err := s.Get(id) + if err != nil { + return nil, err + } + if conv == nil { + return nil, nil } conv.Messages = append(conv.Messages, messages...) conv.UpdatedAt = time.Now() - data, _ := json.Marshal(conv) - _ = s.client.Set(s.ctx, s.key(id), data, s.ttl).Err() + data, err := json.Marshal(conv) + if err != nil { + return nil, err + } - return conv, true + if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil { + return nil, err + } + + return conv, nil } // Delete removes a conversation from Redis. -func (s *RedisStore) Delete(id string) { - _ = s.client.Del(s.ctx, s.key(id)).Err() +func (s *RedisStore) Delete(id string) error { + return s.client.Del(s.ctx, s.key(id)).Err() } // Size returns the number of active conversations in Redis. diff --git a/internal/conversation/sql_store.go b/internal/conversation/sql_store.go index 4862b79..d1a7e84 100644 --- a/internal/conversation/sql_store.go +++ b/internal/conversation/sql_store.go @@ -65,28 +65,36 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error return s, nil } -func (s *SQLStore) Get(id string) (*Conversation, bool) { +func (s *SQLStore) Get(id string) (*Conversation, error) { row := s.db.QueryRow(s.dialect.getByID, id) var conv Conversation var msgJSON string err := row.Scan(&conv.ID, &conv.Model, &msgJSON, &conv.CreatedAt, &conv.UpdatedAt) + if err == sql.ErrNoRows { + return nil, nil + } if err != nil { - return nil, false + return nil, err } if err := json.Unmarshal([]byte(msgJSON), &conv.Messages); err != nil { - return nil, false + return nil, err } - return &conv, true + return &conv, nil } -func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conversation { +func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) { now := time.Now() - msgJSON, _ := json.Marshal(messages) + msgJSON, err := json.Marshal(messages) + if err != nil { + return nil, err + } - _, _ = s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now) + if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil { + return nil, err + } return &Conversation{ ID: id, @@ -94,26 +102,36 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conv Model: model, CreatedAt: now, UpdatedAt: now, - } + }, nil } -func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, bool) { - conv, ok := s.Get(id) - if !ok { - return nil, false +func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) { + conv, err := s.Get(id) + if err != nil { + return nil, err + } + if conv == nil { + return nil, nil } conv.Messages = append(conv.Messages, messages...) conv.UpdatedAt = time.Now() - msgJSON, _ := json.Marshal(conv.Messages) - _, _ = s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id) + msgJSON, err := json.Marshal(conv.Messages) + if err != nil { + return nil, err + } - return conv, true + if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil { + return nil, err + } + + return conv, nil } -func (s *SQLStore) Delete(id string) { - _, _ = s.db.Exec(s.dialect.deleteByID, id) +func (s *SQLStore) Delete(id string) error { + _, err := s.db.Exec(s.dialect.deleteByID, id) + return err } func (s *SQLStore) Size() int { diff --git a/internal/server/server.go b/internal/server/server.go index b581201..1eff7c8 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -84,8 +84,13 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) // Build full message history from previous conversation var historyMsgs []api.Message if req.PreviousResponseID != nil && *req.PreviousResponseID != "" { - conv, ok := s.convs.Get(*req.PreviousResponseID) - if !ok { + conv, err := s.convs.Get(*req.PreviousResponseID) + if err != nil { + s.logger.Printf("error retrieving conversation: %v", err) + http.Error(w, "error retrieving conversation", http.StatusInternalServerError) + return + } + if conv == nil { http.Error(w, "conversation not found", http.StatusNotFound) return } @@ -140,7 +145,10 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}}, } allMsgs := append(storeMsgs, assistantMsg) - s.convs.Create(responseID, result.Model, allMsgs) + if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil { + s.logger.Printf("error storing conversation: %v", err) + // Don't fail the response if storage fails + } // Build spec-compliant response resp := s.buildResponse(origReq, result, provider.Name(), responseID) @@ -458,7 +466,10 @@ loop: Content: []api.ContentBlock{{Type: "output_text", Text: fullText}}, } allMsgs := append(storeMsgs, assistantMsg) - s.convs.Create(responseID, model, allMsgs) + if _, err := s.convs.Create(responseID, model, allMsgs); err != nil { + s.logger.Printf("error storing conversation: %v", err) + // Don't fail the response if storage fails + } } }