diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 3f37124..4dfb56e 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -1,7 +1,9 @@ package main import ( + "database/sql" "flag" + "fmt" "log" "net/http" "os" @@ -49,7 +51,7 @@ func main() { } // Initialize conversation store (1 hour TTL) - convStore := conversation.NewStore(1 * time.Hour) + convStore := conversation.NewMemoryStore(1 * time.Hour) logger.Printf("Conversation store initialized (TTL: 1h)") gatewayServer := server.New(registry, convStore, logger) @@ -78,6 +80,38 @@ func main() { } } +func initConversationStore(cfg config.ConversationConfig, logger *log.Logger) (conversation.Store, error) { + var ttl time.Duration + if cfg.TTL != "" { + parsed, err := time.ParseDuration(cfg.TTL) + if err != nil { + return nil, fmt.Errorf("invalid conversation ttl %q: %w", cfg.TTL, err) + } + ttl = parsed + } + + switch cfg.Store { + case "sql": + driver := cfg.Driver + if driver == "" { + driver = "sqlite3" + } + db, err := sql.Open(driver, cfg.DSN) + if err != nil { + return nil, fmt.Errorf("open database: %w", err) + } + store, err := conversation.NewSQLStore(db, driver, ttl) + if err != nil { + return nil, fmt.Errorf("init sql store: %w", err) + } + logger.Printf("Conversation store initialized (sql/%s, TTL: %s)", driver, ttl) + return store, nil + default: + logger.Printf("Conversation store initialized (memory, TTL: %s)", ttl) + return conversation.NewMemoryStore(ttl), nil + } +} + func loggingMiddleware(next http.Handler, logger *log.Logger) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/internal/conversation/conversation.go b/internal/conversation/conversation.go index d65be22..2013517 100644 --- a/internal/conversation/conversation.go +++ b/internal/conversation/conversation.go @@ -7,8 +7,17 @@ import ( "github.com/yourusername/go-llm-gateway/internal/api" ) -// Store manages conversation history with automatic expiration. -type Store struct { +// Store defines the interface for conversation storage backends. +type Store interface { + Get(id string) (*Conversation, bool) + Create(id string, model string, messages []api.Message) *Conversation + Append(id string, messages ...api.Message) (*Conversation, bool) + Delete(id string) + Size() int +} + +// MemoryStore manages conversation history in-memory with automatic expiration. +type MemoryStore struct { conversations map[string]*Conversation mu sync.RWMutex ttl time.Duration @@ -23,21 +32,23 @@ type Conversation struct { UpdatedAt time.Time } -// NewStore creates a conversation store with the given TTL. -func NewStore(ttl time.Duration) *Store { - s := &Store{ +// NewMemoryStore creates an in-memory conversation store with the given TTL. +func NewMemoryStore(ttl time.Duration) *MemoryStore { + s := &MemoryStore{ conversations: make(map[string]*Conversation), ttl: ttl, } - // Start cleanup goroutine - go s.cleanup() + // Start cleanup goroutine if TTL is set + if ttl > 0 { + go s.cleanup() + } return s } // Get retrieves a conversation by ID. -func (s *Store) Get(id string) (*Conversation, bool) { +func (s *MemoryStore) Get(id string) (*Conversation, bool) { s.mu.RLock() defer s.mu.RUnlock() @@ -46,7 +57,7 @@ func (s *Store) Get(id string) (*Conversation, bool) { } // Create creates a new conversation with the given messages. -func (s *Store) Create(id string, model string, messages []api.Message) *Conversation { +func (s *MemoryStore) Create(id string, model string, messages []api.Message) *Conversation { s.mu.Lock() defer s.mu.Unlock() @@ -64,7 +75,7 @@ func (s *Store) Create(id string, model string, messages []api.Message) *Convers } // Append adds new messages to an existing conversation. -func (s *Store) Append(id string, messages ...api.Message) (*Conversation, bool) { +func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, bool) { s.mu.Lock() defer s.mu.Unlock() @@ -80,7 +91,7 @@ func (s *Store) Append(id string, messages ...api.Message) (*Conversation, bool) } // Delete removes a conversation from the store. -func (s *Store) Delete(id string) { +func (s *MemoryStore) Delete(id string) { s.mu.Lock() defer s.mu.Unlock() @@ -88,7 +99,7 @@ func (s *Store) Delete(id string) { } // cleanup periodically removes expired conversations. -func (s *Store) cleanup() { +func (s *MemoryStore) cleanup() { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() @@ -105,7 +116,7 @@ func (s *Store) cleanup() { } // Size returns the number of active conversations. -func (s *Store) Size() int { +func (s *MemoryStore) Size() int { s.mu.RLock() defer s.mu.RUnlock() return len(s.conversations) diff --git a/internal/conversation/sql_store.go b/internal/conversation/sql_store.go new file mode 100644 index 0000000..3080bc7 --- /dev/null +++ b/internal/conversation/sql_store.go @@ -0,0 +1,133 @@ +package conversation + +import ( + "database/sql" + "encoding/json" + "time" + + "github.com/yourusername/go-llm-gateway/internal/api" +) + +// sqlDialect holds driver-specific SQL statements. +type sqlDialect struct { + getByID string + upsert string + update string + deleteByID string + cleanup string +} + +func newDialect(driver string) sqlDialect { + if driver == "pgx" || driver == "postgres" { + return sqlDialect{ + getByID: `SELECT id, model, messages, created_at, updated_at FROM conversations WHERE id = $1`, + upsert: `INSERT INTO conversations (id, model, messages, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (id) DO UPDATE SET model = EXCLUDED.model, messages = EXCLUDED.messages, updated_at = EXCLUDED.updated_at`, + update: `UPDATE conversations SET messages = $1, updated_at = $2 WHERE id = $3`, + deleteByID: `DELETE FROM conversations WHERE id = $1`, + cleanup: `DELETE FROM conversations WHERE updated_at < $1`, + } + } + return sqlDialect{ + getByID: `SELECT id, model, messages, created_at, updated_at FROM conversations WHERE id = ?`, + upsert: `REPLACE INTO conversations (id, model, messages, created_at, updated_at) VALUES (?, ?, ?, ?, ?)`, + update: `UPDATE conversations SET messages = ?, updated_at = ? WHERE id = ?`, + deleteByID: `DELETE FROM conversations WHERE id = ?`, + cleanup: `DELETE FROM conversations WHERE updated_at < ?`, + } +} + +// SQLStore manages conversation history in a SQL database with automatic expiration. +type SQLStore struct { + db *sql.DB + ttl time.Duration + dialect sqlDialect +} + +// NewSQLStore creates a SQL-backed conversation store. It creates the +// conversations table if it does not already exist and starts a background +// goroutine to remove expired rows. +func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error) { + _, err := db.Exec(`CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + model TEXT NOT NULL, + messages TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL + )`) + if err != nil { + return nil, err + } + + s := &SQLStore{db: db, ttl: ttl, dialect: newDialect(driver)} + if ttl > 0 { + go s.cleanup() + } + return s, nil +} + +func (s *SQLStore) Get(id string) (*Conversation, bool) { + row := s.db.QueryRow(s.dialect.getByID, id) + + var conv Conversation + var msgJSON string + err := row.Scan(&conv.ID, &conv.Model, &msgJSON, &conv.CreatedAt, &conv.UpdatedAt) + if err != nil { + return nil, false + } + + if err := json.Unmarshal([]byte(msgJSON), &conv.Messages); err != nil { + return nil, false + } + + return &conv, true +} + +func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conversation { + now := time.Now() + msgJSON, _ := json.Marshal(messages) + + _, _ = s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now) + + return &Conversation{ + ID: id, + Messages: messages, + Model: model, + CreatedAt: now, + UpdatedAt: now, + } +} + +func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, bool) { + conv, ok := s.Get(id) + if !ok { + return nil, false + } + + conv.Messages = append(conv.Messages, messages...) + conv.UpdatedAt = time.Now() + + msgJSON, _ := json.Marshal(conv.Messages) + _, _ = s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id) + + return conv, true +} + +func (s *SQLStore) Delete(id string) { + _, _ = s.db.Exec(s.dialect.deleteByID, id) +} + +func (s *SQLStore) Size() int { + var count int + _ = s.db.QueryRow(`SELECT COUNT(*) FROM conversations`).Scan(&count) + return count +} + +func (s *SQLStore) cleanup() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + cutoff := time.Now().Add(-s.ttl) + _, _ = s.db.Exec(s.dialect.cleanup, cutoff) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index aada7d0..b846415 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,12 +14,12 @@ import ( // GatewayServer hosts the Open Responses API for the gateway. type GatewayServer struct { registry *providers.Registry - convs *conversation.Store + convs conversation.Store logger *log.Logger } // New creates a GatewayServer bound to the provider registry. -func New(registry *providers.Registry, convs *conversation.Store, logger *log.Logger) *GatewayServer { +func New(registry *providers.Registry, convs conversation.Store, logger *log.Logger) *GatewayServer { return &GatewayServer{ registry: registry, convs: convs,