Add Store interface

This commit is contained in:
2026-03-02 00:01:39 +00:00
parent 27e68f8e4c
commit c45d6cc89b
4 changed files with 194 additions and 16 deletions

View File

@@ -1,7 +1,9 @@
package main
import (
"database/sql"
"flag"
"fmt"
"log"
"net/http"
"os"
@@ -49,7 +51,7 @@ func main() {
}
// Initialize conversation store (1 hour TTL)
convStore := conversation.NewStore(1 * time.Hour)
convStore := conversation.NewMemoryStore(1 * time.Hour)
logger.Printf("Conversation store initialized (TTL: 1h)")
gatewayServer := server.New(registry, convStore, logger)
@@ -78,6 +80,38 @@ func main() {
}
}
func initConversationStore(cfg config.ConversationConfig, logger *log.Logger) (conversation.Store, error) {
var ttl time.Duration
if cfg.TTL != "" {
parsed, err := time.ParseDuration(cfg.TTL)
if err != nil {
return nil, fmt.Errorf("invalid conversation ttl %q: %w", cfg.TTL, err)
}
ttl = parsed
}
switch cfg.Store {
case "sql":
driver := cfg.Driver
if driver == "" {
driver = "sqlite3"
}
db, err := sql.Open(driver, cfg.DSN)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
store, err := conversation.NewSQLStore(db, driver, ttl)
if err != nil {
return nil, fmt.Errorf("init sql store: %w", err)
}
logger.Printf("Conversation store initialized (sql/%s, TTL: %s)", driver, ttl)
return store, nil
default:
logger.Printf("Conversation store initialized (memory, TTL: %s)", ttl)
return conversation.NewMemoryStore(ttl), nil
}
}
func loggingMiddleware(next http.Handler, logger *log.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()

View File

@@ -7,8 +7,17 @@ import (
"github.com/yourusername/go-llm-gateway/internal/api"
)
// Store manages conversation history with automatic expiration.
type Store struct {
// Store defines the interface for conversation storage backends.
type Store interface {
Get(id string) (*Conversation, bool)
Create(id string, model string, messages []api.Message) *Conversation
Append(id string, messages ...api.Message) (*Conversation, bool)
Delete(id string)
Size() int
}
// MemoryStore manages conversation history in-memory with automatic expiration.
type MemoryStore struct {
conversations map[string]*Conversation
mu sync.RWMutex
ttl time.Duration
@@ -23,21 +32,23 @@ type Conversation struct {
UpdatedAt time.Time
}
// NewStore creates a conversation store with the given TTL.
func NewStore(ttl time.Duration) *Store {
s := &Store{
// NewMemoryStore creates an in-memory conversation store with the given TTL.
func NewMemoryStore(ttl time.Duration) *MemoryStore {
s := &MemoryStore{
conversations: make(map[string]*Conversation),
ttl: ttl,
}
// Start cleanup goroutine
go s.cleanup()
// Start cleanup goroutine if TTL is set
if ttl > 0 {
go s.cleanup()
}
return s
}
// Get retrieves a conversation by ID.
func (s *Store) Get(id string) (*Conversation, bool) {
func (s *MemoryStore) Get(id string) (*Conversation, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -46,7 +57,7 @@ func (s *Store) Get(id string) (*Conversation, bool) {
}
// Create creates a new conversation with the given messages.
func (s *Store) Create(id string, model string, messages []api.Message) *Conversation {
func (s *MemoryStore) Create(id string, model string, messages []api.Message) *Conversation {
s.mu.Lock()
defer s.mu.Unlock()
@@ -64,7 +75,7 @@ func (s *Store) Create(id string, model string, messages []api.Message) *Convers
}
// Append adds new messages to an existing conversation.
func (s *Store) Append(id string, messages ...api.Message) (*Conversation, bool) {
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -80,7 +91,7 @@ func (s *Store) Append(id string, messages ...api.Message) (*Conversation, bool)
}
// Delete removes a conversation from the store.
func (s *Store) Delete(id string) {
func (s *MemoryStore) Delete(id string) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -88,7 +99,7 @@ func (s *Store) Delete(id string) {
}
// cleanup periodically removes expired conversations.
func (s *Store) cleanup() {
func (s *MemoryStore) cleanup() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -105,7 +116,7 @@ func (s *Store) cleanup() {
}
// Size returns the number of active conversations.
func (s *Store) Size() int {
func (s *MemoryStore) Size() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.conversations)

View File

@@ -0,0 +1,133 @@
package conversation
import (
"database/sql"
"encoding/json"
"time"
"github.com/yourusername/go-llm-gateway/internal/api"
)
// sqlDialect holds driver-specific SQL statements.
type sqlDialect struct {
getByID string
upsert string
update string
deleteByID string
cleanup string
}
func newDialect(driver string) sqlDialect {
if driver == "pgx" || driver == "postgres" {
return sqlDialect{
getByID: `SELECT id, model, messages, created_at, updated_at FROM conversations WHERE id = $1`,
upsert: `INSERT INTO conversations (id, model, messages, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (id) DO UPDATE SET model = EXCLUDED.model, messages = EXCLUDED.messages, updated_at = EXCLUDED.updated_at`,
update: `UPDATE conversations SET messages = $1, updated_at = $2 WHERE id = $3`,
deleteByID: `DELETE FROM conversations WHERE id = $1`,
cleanup: `DELETE FROM conversations WHERE updated_at < $1`,
}
}
return sqlDialect{
getByID: `SELECT id, model, messages, created_at, updated_at FROM conversations WHERE id = ?`,
upsert: `REPLACE INTO conversations (id, model, messages, created_at, updated_at) VALUES (?, ?, ?, ?, ?)`,
update: `UPDATE conversations SET messages = ?, updated_at = ? WHERE id = ?`,
deleteByID: `DELETE FROM conversations WHERE id = ?`,
cleanup: `DELETE FROM conversations WHERE updated_at < ?`,
}
}
// SQLStore manages conversation history in a SQL database with automatic expiration.
type SQLStore struct {
db *sql.DB
ttl time.Duration
dialect sqlDialect
}
// NewSQLStore creates a SQL-backed conversation store. It creates the
// conversations table if it does not already exist and starts a background
// goroutine to remove expired rows.
func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error) {
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY,
model TEXT NOT NULL,
messages TEXT NOT NULL,
created_at TIMESTAMP NOT NULL,
updated_at TIMESTAMP NOT NULL
)`)
if err != nil {
return nil, err
}
s := &SQLStore{db: db, ttl: ttl, dialect: newDialect(driver)}
if ttl > 0 {
go s.cleanup()
}
return s, nil
}
func (s *SQLStore) Get(id string) (*Conversation, bool) {
row := s.db.QueryRow(s.dialect.getByID, id)
var conv Conversation
var msgJSON string
err := row.Scan(&conv.ID, &conv.Model, &msgJSON, &conv.CreatedAt, &conv.UpdatedAt)
if err != nil {
return nil, false
}
if err := json.Unmarshal([]byte(msgJSON), &conv.Messages); err != nil {
return nil, false
}
return &conv, true
}
func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conversation {
now := time.Now()
msgJSON, _ := json.Marshal(messages)
_, _ = s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now)
return &Conversation{
ID: id,
Messages: messages,
Model: model,
CreatedAt: now,
UpdatedAt: now,
}
}
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
conv, ok := s.Get(id)
if !ok {
return nil, false
}
conv.Messages = append(conv.Messages, messages...)
conv.UpdatedAt = time.Now()
msgJSON, _ := json.Marshal(conv.Messages)
_, _ = s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id)
return conv, true
}
func (s *SQLStore) Delete(id string) {
_, _ = s.db.Exec(s.dialect.deleteByID, id)
}
func (s *SQLStore) Size() int {
var count int
_ = s.db.QueryRow(`SELECT COUNT(*) FROM conversations`).Scan(&count)
return count
}
func (s *SQLStore) cleanup() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for range ticker.C {
cutoff := time.Now().Add(-s.ttl)
_, _ = s.db.Exec(s.dialect.cleanup, cutoff)
}
}

View File

@@ -14,12 +14,12 @@ import (
// GatewayServer hosts the Open Responses API for the gateway.
type GatewayServer struct {
registry *providers.Registry
convs *conversation.Store
convs conversation.Store
logger *log.Logger
}
// New creates a GatewayServer bound to the provider registry.
func New(registry *providers.Registry, convs *conversation.Store, logger *log.Logger) *GatewayServer {
func New(registry *providers.Registry, convs conversation.Store, logger *log.Logger) *GatewayServer {
return &GatewayServer{
registry: registry,
convs: convs,