Add Store interface
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -49,7 +51,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize conversation store (1 hour TTL)
|
// Initialize conversation store (1 hour TTL)
|
||||||
convStore := conversation.NewStore(1 * time.Hour)
|
convStore := conversation.NewMemoryStore(1 * time.Hour)
|
||||||
logger.Printf("Conversation store initialized (TTL: 1h)")
|
logger.Printf("Conversation store initialized (TTL: 1h)")
|
||||||
|
|
||||||
gatewayServer := server.New(registry, convStore, logger)
|
gatewayServer := server.New(registry, convStore, logger)
|
||||||
@@ -78,6 +80,38 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initConversationStore(cfg config.ConversationConfig, logger *log.Logger) (conversation.Store, error) {
|
||||||
|
var ttl time.Duration
|
||||||
|
if cfg.TTL != "" {
|
||||||
|
parsed, err := time.ParseDuration(cfg.TTL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid conversation ttl %q: %w", cfg.TTL, err)
|
||||||
|
}
|
||||||
|
ttl = parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
switch cfg.Store {
|
||||||
|
case "sql":
|
||||||
|
driver := cfg.Driver
|
||||||
|
if driver == "" {
|
||||||
|
driver = "sqlite3"
|
||||||
|
}
|
||||||
|
db, err := sql.Open(driver, cfg.DSN)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open database: %w", err)
|
||||||
|
}
|
||||||
|
store, err := conversation.NewSQLStore(db, driver, ttl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init sql store: %w", err)
|
||||||
|
}
|
||||||
|
logger.Printf("Conversation store initialized (sql/%s, TTL: %s)", driver, ttl)
|
||||||
|
return store, nil
|
||||||
|
default:
|
||||||
|
logger.Printf("Conversation store initialized (memory, TTL: %s)", ttl)
|
||||||
|
return conversation.NewMemoryStore(ttl), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func loggingMiddleware(next http.Handler, logger *log.Logger) http.Handler {
|
func loggingMiddleware(next http.Handler, logger *log.Logger) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|||||||
@@ -7,8 +7,17 @@ import (
|
|||||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Store manages conversation history with automatic expiration.
|
// Store defines the interface for conversation storage backends.
|
||||||
type Store struct {
|
type Store interface {
|
||||||
|
Get(id string) (*Conversation, bool)
|
||||||
|
Create(id string, model string, messages []api.Message) *Conversation
|
||||||
|
Append(id string, messages ...api.Message) (*Conversation, bool)
|
||||||
|
Delete(id string)
|
||||||
|
Size() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryStore manages conversation history in-memory with automatic expiration.
|
||||||
|
type MemoryStore struct {
|
||||||
conversations map[string]*Conversation
|
conversations map[string]*Conversation
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
@@ -23,21 +32,23 @@ type Conversation struct {
|
|||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStore creates a conversation store with the given TTL.
|
// NewMemoryStore creates an in-memory conversation store with the given TTL.
|
||||||
func NewStore(ttl time.Duration) *Store {
|
func NewMemoryStore(ttl time.Duration) *MemoryStore {
|
||||||
s := &Store{
|
s := &MemoryStore{
|
||||||
conversations: make(map[string]*Conversation),
|
conversations: make(map[string]*Conversation),
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup goroutine
|
// Start cleanup goroutine if TTL is set
|
||||||
go s.cleanup()
|
if ttl > 0 {
|
||||||
|
go s.cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a conversation by ID.
|
// Get retrieves a conversation by ID.
|
||||||
func (s *Store) Get(id string) (*Conversation, bool) {
|
func (s *MemoryStore) Get(id string) (*Conversation, bool) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
@@ -46,7 +57,7 @@ func (s *Store) Get(id string) (*Conversation, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new conversation with the given messages.
|
// Create creates a new conversation with the given messages.
|
||||||
func (s *Store) Create(id string, model string, messages []api.Message) *Conversation {
|
func (s *MemoryStore) Create(id string, model string, messages []api.Message) *Conversation {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
@@ -64,7 +75,7 @@ func (s *Store) Create(id string, model string, messages []api.Message) *Convers
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Append adds new messages to an existing conversation.
|
// Append adds new messages to an existing conversation.
|
||||||
func (s *Store) Append(id string, messages ...api.Message) (*Conversation, bool) {
|
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
@@ -80,7 +91,7 @@ func (s *Store) Append(id string, messages ...api.Message) (*Conversation, bool)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a conversation from the store.
|
// Delete removes a conversation from the store.
|
||||||
func (s *Store) Delete(id string) {
|
func (s *MemoryStore) Delete(id string) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
@@ -88,7 +99,7 @@ func (s *Store) Delete(id string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// cleanup periodically removes expired conversations.
|
// cleanup periodically removes expired conversations.
|
||||||
func (s *Store) cleanup() {
|
func (s *MemoryStore) cleanup() {
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -105,7 +116,7 @@ func (s *Store) cleanup() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Size returns the number of active conversations.
|
// Size returns the number of active conversations.
|
||||||
func (s *Store) Size() int {
|
func (s *MemoryStore) Size() int {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return len(s.conversations)
|
return len(s.conversations)
|
||||||
|
|||||||
133
internal/conversation/sql_store.go
Normal file
133
internal/conversation/sql_store.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// sqlDialect holds driver-specific SQL statements.
|
||||||
|
type sqlDialect struct {
|
||||||
|
getByID string
|
||||||
|
upsert string
|
||||||
|
update string
|
||||||
|
deleteByID string
|
||||||
|
cleanup string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDialect(driver string) sqlDialect {
|
||||||
|
if driver == "pgx" || driver == "postgres" {
|
||||||
|
return sqlDialect{
|
||||||
|
getByID: `SELECT id, model, messages, created_at, updated_at FROM conversations WHERE id = $1`,
|
||||||
|
upsert: `INSERT INTO conversations (id, model, messages, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (id) DO UPDATE SET model = EXCLUDED.model, messages = EXCLUDED.messages, updated_at = EXCLUDED.updated_at`,
|
||||||
|
update: `UPDATE conversations SET messages = $1, updated_at = $2 WHERE id = $3`,
|
||||||
|
deleteByID: `DELETE FROM conversations WHERE id = $1`,
|
||||||
|
cleanup: `DELETE FROM conversations WHERE updated_at < $1`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sqlDialect{
|
||||||
|
getByID: `SELECT id, model, messages, created_at, updated_at FROM conversations WHERE id = ?`,
|
||||||
|
upsert: `REPLACE INTO conversations (id, model, messages, created_at, updated_at) VALUES (?, ?, ?, ?, ?)`,
|
||||||
|
update: `UPDATE conversations SET messages = ?, updated_at = ? WHERE id = ?`,
|
||||||
|
deleteByID: `DELETE FROM conversations WHERE id = ?`,
|
||||||
|
cleanup: `DELETE FROM conversations WHERE updated_at < ?`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLStore manages conversation history in a SQL database with automatic expiration.
|
||||||
|
type SQLStore struct {
|
||||||
|
db *sql.DB
|
||||||
|
ttl time.Duration
|
||||||
|
dialect sqlDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSQLStore creates a SQL-backed conversation store. It creates the
|
||||||
|
// conversations table if it does not already exist and starts a background
|
||||||
|
// goroutine to remove expired rows.
|
||||||
|
func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error) {
|
||||||
|
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS conversations (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
messages TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMP NOT NULL,
|
||||||
|
updated_at TIMESTAMP NOT NULL
|
||||||
|
)`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &SQLStore{db: db, ttl: ttl, dialect: newDialect(driver)}
|
||||||
|
if ttl > 0 {
|
||||||
|
go s.cleanup()
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) Get(id string) (*Conversation, bool) {
|
||||||
|
row := s.db.QueryRow(s.dialect.getByID, id)
|
||||||
|
|
||||||
|
var conv Conversation
|
||||||
|
var msgJSON string
|
||||||
|
err := row.Scan(&conv.ID, &conv.Model, &msgJSON, &conv.CreatedAt, &conv.UpdatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(msgJSON), &conv.Messages); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return &conv, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conversation {
|
||||||
|
now := time.Now()
|
||||||
|
msgJSON, _ := json.Marshal(messages)
|
||||||
|
|
||||||
|
_, _ = s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now)
|
||||||
|
|
||||||
|
return &Conversation{
|
||||||
|
ID: id,
|
||||||
|
Messages: messages,
|
||||||
|
Model: model,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
|
||||||
|
conv, ok := s.Get(id)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
conv.Messages = append(conv.Messages, messages...)
|
||||||
|
conv.UpdatedAt = time.Now()
|
||||||
|
|
||||||
|
msgJSON, _ := json.Marshal(conv.Messages)
|
||||||
|
_, _ = s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id)
|
||||||
|
|
||||||
|
return conv, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) Delete(id string) {
|
||||||
|
_, _ = s.db.Exec(s.dialect.deleteByID, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) Size() int {
|
||||||
|
var count int
|
||||||
|
_ = s.db.QueryRow(`SELECT COUNT(*) FROM conversations`).Scan(&count)
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) cleanup() {
|
||||||
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
cutoff := time.Now().Add(-s.ttl)
|
||||||
|
_, _ = s.db.Exec(s.dialect.cleanup, cutoff)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,12 +14,12 @@ import (
|
|||||||
// GatewayServer hosts the Open Responses API for the gateway.
|
// GatewayServer hosts the Open Responses API for the gateway.
|
||||||
type GatewayServer struct {
|
type GatewayServer struct {
|
||||||
registry *providers.Registry
|
registry *providers.Registry
|
||||||
convs *conversation.Store
|
convs conversation.Store
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a GatewayServer bound to the provider registry.
|
// New creates a GatewayServer bound to the provider registry.
|
||||||
func New(registry *providers.Registry, convs *conversation.Store, logger *log.Logger) *GatewayServer {
|
func New(registry *providers.Registry, convs conversation.Store, logger *log.Logger) *GatewayServer {
|
||||||
return &GatewayServer{
|
return &GatewayServer{
|
||||||
registry: registry,
|
registry: registry,
|
||||||
convs: convs,
|
convs: convs,
|
||||||
|
|||||||
Reference in New Issue
Block a user