Add conversation management

2026-02-28 22:10:55 +00:00
parent 4439567ccd
commit ae4c7ab489
8 changed files with 311 additions and 41 deletions

View File

@@ -54,6 +54,7 @@ Go LLM Gateway (unified API)
 **Streaming support** (Server-Sent Events for all providers)
 **OAuth2/OIDC authentication** (Google, Auth0, any OIDC provider)
 **Terminal chat client** (Python with Rich UI, PEP 723)
+**Conversation tracking** (previous_response_id for efficient context)
 
 ## Quick Start
@@ -186,8 +187,21 @@ You> /model claude
 You> /models  # List all available models
 ```
 
+The chat client automatically uses `previous_response_id` to reduce token usage by only sending new messages instead of the full conversation history.
+
 See **[CHAT_CLIENT.md](./CHAT_CLIENT.md)** for full documentation.
+
+## Conversation Management
+
+The gateway implements conversation tracking using `previous_response_id` from the Open Responses spec:
+
+- 📉 **Reduced token usage** - Only send new messages
+- ⚡ **Smaller requests** - Less bandwidth
+- 🧠 **Server-side context** - Gateway maintains history
+- ⏰ **Auto-expire** - Conversations expire after 1 hour
+
+See **[CONVERSATIONS.md](./CONVERSATIONS.md)** for details.
+
 ## Authentication
 
 The gateway supports OAuth2/OIDC authentication. See **[AUTH.md](./AUTH.md)** for setup instructions.
@@ -216,6 +230,8 @@ curl -X POST http://localhost:8080/v1/responses \
 - ~~Implement streaming responses~~
 - ~~Add OAuth2/OIDC authentication~~
+- ~~Implement conversation tracking with previous_response_id~~
 - ⬜ Add structured logging, tracing, and request-level metrics
 - ⬜ Support tool/function calling
+- ⬜ Persistent conversation storage (Redis/database)
 - ⬜ Expand configuration to support routing policies (cost, latency, failover)
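
To make the Conversation Management flow above concrete, here is a minimal two-turn client sketch against the gateway. The endpoint, `previous_response_id` field, and message shape come from this commit; the model name, prompts, and the assumption that the response JSON exposes the conversation ID as `id` are illustrative.

```go
// Sketch: two-turn conversation using previous_response_id.
// Assumes a local gateway and an "id" field in the response JSON,
// as in the Open Responses spec; model and prompts are placeholders.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func post(payload map[string]any) (map[string]any, error) {
	body, _ := json.Marshal(payload)
	resp, err := http.Post("http://localhost:8080/v1/responses", "application/json", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	var out map[string]any
	return out, json.NewDecoder(resp.Body).Decode(&out)
}

func userMessage(text string) []map[string]any {
	return []map[string]any{{
		"role":    "user",
		"content": []map[string]any{{"type": "input_text", "text": text}},
	}}
}

func main() {
	// Turn 1: send the full input.
	first, err := post(map[string]any{"model": "gpt-4o", "input": userMessage("Hi, I'm Sam.")})
	if err != nil {
		panic(err)
	}

	// Turn 2: send only the new message plus the ID from turn 1;
	// the gateway rebuilds the history server-side.
	second, err := post(map[string]any{
		"model":                "gpt-4o",
		"previous_response_id": first["id"],
		"input":                userMessage("What's my name?"),
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(second["id"]) // same conversation ID echoed back
}
```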

View File

@@ -9,6 +9,7 @@ import (
"github.com/yourusername/go-llm-gateway/internal/auth" "github.com/yourusername/go-llm-gateway/internal/auth"
"github.com/yourusername/go-llm-gateway/internal/config" "github.com/yourusername/go-llm-gateway/internal/config"
"github.com/yourusername/go-llm-gateway/internal/conversation"
"github.com/yourusername/go-llm-gateway/internal/providers" "github.com/yourusername/go-llm-gateway/internal/providers"
"github.com/yourusername/go-llm-gateway/internal/server" "github.com/yourusername/go-llm-gateway/internal/server"
) )
@@ -47,7 +48,11 @@ func main() {
 		logger.Printf("Authentication disabled - WARNING: API is publicly accessible")
 	}
 
-	gatewayServer := server.New(registry, logger)
+	// Initialize conversation store (1 hour TTL)
+	convStore := conversation.NewStore(1 * time.Hour)
+	logger.Printf("Conversation store initialized (TTL: 1h)")
+
+	gatewayServer := server.New(registry, convStore, logger)
 	mux := http.NewServeMux()
 	gatewayServer.RegisterRoutes(mux)

View File

@@ -7,12 +7,13 @@ import (
 // ResponseRequest models the Open Responses create request payload.
 type ResponseRequest struct {
 	Model              string            `json:"model"`
 	Provider           string            `json:"provider,omitempty"`
 	MaxOutputTokens    int               `json:"max_output_tokens,omitempty"`
 	Metadata           map[string]string `json:"metadata,omitempty"`
 	Input              []Message         `json:"input"`
 	Stream             bool              `json:"stream,omitempty"`
+	PreviousResponseID string            `json:"previous_response_id,omitempty"`
 }
 
 // Message captures user, assistant, or system roles.
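
For reference, a small sketch of the wire format the new field produces, using a trimmed copy of the struct above (values are placeholders):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copy of ResponseRequest, for illustration only.
type ResponseRequest struct {
	Model              string `json:"model"`
	Stream             bool   `json:"stream,omitempty"`
	PreviousResponseID string `json:"previous_response_id,omitempty"`
}

func main() {
	req := ResponseRequest{Model: "gpt-4o", PreviousResponseID: "resp_abc123"}
	out, _ := json.Marshal(req)
	fmt.Println(string(out))
	// {"model":"gpt-4o","previous_response_id":"resp_abc123"}
}
```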

View File

@@ -0,0 +1,112 @@
+package conversation
+
+import (
+	"sync"
+	"time"
+
+	"github.com/yourusername/go-llm-gateway/internal/api"
+)
+
+// Store manages conversation history with automatic expiration.
+type Store struct {
+	conversations map[string]*Conversation
+	mu            sync.RWMutex
+	ttl           time.Duration
+}
+
+// Conversation holds the message history for a single conversation thread.
+type Conversation struct {
+	ID        string
+	Messages  []api.Message
+	Model     string
+	CreatedAt time.Time
+	UpdatedAt time.Time
+}
+
+// NewStore creates a conversation store with the given TTL.
+func NewStore(ttl time.Duration) *Store {
+	s := &Store{
+		conversations: make(map[string]*Conversation),
+		ttl:           ttl,
+	}
+	// Start cleanup goroutine
+	go s.cleanup()
+	return s
+}
+
+// Get retrieves a conversation by ID.
+func (s *Store) Get(id string) (*Conversation, bool) {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	conv, ok := s.conversations[id]
+	return conv, ok
+}
+
+// Create creates a new conversation with the given messages.
+func (s *Store) Create(id string, model string, messages []api.Message) *Conversation {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	now := time.Now()
+	conv := &Conversation{
+		ID:        id,
+		Messages:  messages,
+		Model:     model,
+		CreatedAt: now,
+		UpdatedAt: now,
+	}
+	s.conversations[id] = conv
+	return conv
+}
+
+// Append adds new messages to an existing conversation.
+func (s *Store) Append(id string, messages ...api.Message) (*Conversation, bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	conv, ok := s.conversations[id]
+	if !ok {
+		return nil, false
+	}
+	conv.Messages = append(conv.Messages, messages...)
+	conv.UpdatedAt = time.Now()
+	return conv, true
+}
+
+// Delete removes a conversation from the store.
+func (s *Store) Delete(id string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	delete(s.conversations, id)
+}
+
+// cleanup periodically removes expired conversations.
+func (s *Store) cleanup() {
+	ticker := time.NewTicker(1 * time.Minute)
+	defer ticker.Stop()
+	for range ticker.C {
+		s.mu.Lock()
+		now := time.Now()
+		for id, conv := range s.conversations {
+			if now.Sub(conv.UpdatedAt) > s.ttl {
+				delete(s.conversations, id)
+			}
+		}
+		s.mu.Unlock()
+	}
+}
+
+// Size returns the number of active conversations.
+func (s *Store) Size() int {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	return len(s.conversations)
+}
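
A short usage sketch of the store API defined above (IDs, model name, and message text are illustrative; note the handlers later in this commit rebuild conversations with Create on every turn, while Append is available for in-place updates):

```go
package main

import (
	"fmt"
	"time"

	"github.com/yourusername/go-llm-gateway/internal/api"
	"github.com/yourusername/go-llm-gateway/internal/conversation"
)

func main() {
	store := conversation.NewStore(1 * time.Hour)

	// First turn: keyed by the provider's response ID.
	store.Create("resp_123", "gpt-4o", []api.Message{
		{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hi"}}},
	})

	// Follow-up messages can be appended in place.
	if conv, ok := store.Append("resp_123", api.Message{
		Role:    "assistant",
		Content: []api.ContentBlock{{Type: "output_text", Text: "Hello!"}},
	}); ok {
		fmt.Printf("conversation %s has %d messages\n", conv.ID, len(conv.Messages))
	}

	// Entries idle for longer than the TTL are dropped by the cleanup goroutine.
	fmt.Println("active conversations:", store.Size())
}
```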

View File

@@ -53,7 +53,7 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api
 	for _, msg := range req.Input {
 		var content string
 		for _, block := range msg.Content {
-			if block.Type == "input_text" {
+			if block.Type == "input_text" || block.Type == "output_text" {
 				content += block.Text
 			}
 		}
@@ -147,22 +147,22 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest)
 	for _, msg := range req.Input {
 		var content string
 		for _, block := range msg.Content {
-			if block.Type == "input_text" {
+			if block.Type == "input_text" || block.Type == "output_text" {
 				content += block.Text
 			}
 		}
 
 		switch msg.Role {
 		case "user":
 			messages = append(messages, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
 		case "assistant":
 			messages = append(messages, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
 		case "system":
 			system = content
 		}
 	}
 
 	// Build params
 	params := anthropic.MessageNewParams{
 		Model:    anthropic.Model(model),
 		Messages: messages,

View File

@@ -54,11 +54,21 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api
 	// Convert Open Responses messages to Gemini format
 	var contents []*genai.Content
+	var systemText string
 	for _, msg := range req.Input {
+		if msg.Role == "system" {
+			for _, block := range msg.Content {
+				if block.Type == "input_text" || block.Type == "output_text" {
+					systemText += block.Text
+				}
+			}
+			continue
+		}
+
 		var parts []*genai.Part
 		for _, block := range msg.Content {
-			if block.Type == "input_text" {
+			if block.Type == "input_text" || block.Type == "output_text" {
 				parts = append(parts, genai.NewPartFromText(block.Text))
 			}
 		}
@@ -74,8 +84,18 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api
 		})
 	}
 
+	// Build config with system instruction if present
+	var config *genai.GenerateContentConfig
+	if systemText != "" {
+		config = &genai.GenerateContentConfig{
+			SystemInstruction: &genai.Content{
+				Parts: []*genai.Part{genai.NewPartFromText(systemText)},
+			},
+		}
+	}
+
 	// Generate content
-	resp, err := p.client.Models.GenerateContent(ctx, model, contents, nil)
+	resp, err := p.client.Models.GenerateContent(ctx, model, contents, config)
 	if err != nil {
 		return nil, fmt.Errorf("google api error: %w", err)
 	}
@@ -143,11 +163,21 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest)
 	// Convert messages
 	var contents []*genai.Content
+	var systemText string
 	for _, msg := range req.Input {
+		if msg.Role == "system" {
+			for _, block := range msg.Content {
+				if block.Type == "input_text" || block.Type == "output_text" {
+					systemText += block.Text
+				}
+			}
+			continue
+		}
+
 		var parts []*genai.Part
 		for _, block := range msg.Content {
-			if block.Type == "input_text" {
+			if block.Type == "input_text" || block.Type == "output_text" {
 				parts = append(parts, genai.NewPartFromText(block.Text))
 			}
 		}
@@ -163,8 +193,18 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest)
 		})
 	}
 
+	// Build config with system instruction if present
+	var config *genai.GenerateContentConfig
+	if systemText != "" {
+		config = &genai.GenerateContentConfig{
+			SystemInstruction: &genai.Content{
+				Parts: []*genai.Part{genai.NewPartFromText(systemText)},
+			},
+		}
+	}
+
 	// Create stream
-	stream := p.client.Models.GenerateContentStream(ctx, model, contents, nil)
+	stream := p.client.Models.GenerateContentStream(ctx, model, contents, config)
 
 	// Process stream
 	for resp, err := range stream {

View File

@@ -52,7 +52,7 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api
 	for _, msg := range req.Input {
 		var content string
 		for _, block := range msg.Content {
-			if block.Type == "input_text" {
+			if block.Type == "input_text" || block.Type == "output_text" {
 				content += block.Text
 			}
 		}
@@ -127,22 +127,22 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest)
 	for _, msg := range req.Input {
 		var content string
 		for _, block := range msg.Content {
-			if block.Type == "input_text" {
+			if block.Type == "input_text" || block.Type == "output_text" {
 				content += block.Text
 			}
 		}
 
 		switch msg.Role {
 		case "user":
 			messages = append(messages, openai.UserMessage(content))
 		case "assistant":
 			messages = append(messages, openai.AssistantMessage(content))
 		case "system":
 			messages = append(messages, openai.SystemMessage(content))
 		}
 	}
 
 	// Create streaming request
 	stream := p.client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
 		Model:    openai.ChatModel(model),
 		Messages: messages,

View File

@@ -7,18 +7,24 @@ import (
"net/http" "net/http"
"github.com/yourusername/go-llm-gateway/internal/api" "github.com/yourusername/go-llm-gateway/internal/api"
"github.com/yourusername/go-llm-gateway/internal/conversation"
"github.com/yourusername/go-llm-gateway/internal/providers" "github.com/yourusername/go-llm-gateway/internal/providers"
) )
// GatewayServer hosts the Open Responses API for the gateway. // GatewayServer hosts the Open Responses API for the gateway.
type GatewayServer struct { type GatewayServer struct {
registry *providers.Registry registry *providers.Registry
convs *conversation.Store
logger *log.Logger logger *log.Logger
} }
// New creates a GatewayServer bound to the provider registry. // New creates a GatewayServer bound to the provider registry.
func New(registry *providers.Registry, logger *log.Logger) *GatewayServer { func New(registry *providers.Registry, convs *conversation.Store, logger *log.Logger) *GatewayServer {
return &GatewayServer{registry: registry, logger: logger} return &GatewayServer{
registry: registry,
convs: convs,
logger: logger,
}
} }
// RegisterRoutes wires the HTTP handlers onto the provided mux. // RegisterRoutes wires the HTTP handlers onto the provided mux.
@@ -43,6 +49,17 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
 		return
 	}
 
+	// Build full message history
+	messages := s.buildMessageHistory(&req)
+	if messages == nil {
+		http.Error(w, "conversation not found", http.StatusNotFound)
+		return
+	}
+
+	// Update request with full history for provider
+	fullReq := req
+	fullReq.Input = messages
+
 	provider, err := s.resolveProvider(&req)
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusBadGateway)
@@ -51,26 +68,58 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
 	// Handle streaming vs non-streaming
 	if req.Stream {
-		s.handleStreamingResponse(w, r, provider, &req)
+		s.handleStreamingResponse(w, r, provider, &fullReq, &req)
 	} else {
-		s.handleSyncResponse(w, r, provider, &req)
+		s.handleSyncResponse(w, r, provider, &fullReq, &req)
 	}
 }
 
-func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, req *api.ResponseRequest) {
-	resp, err := provider.Generate(r.Context(), req)
+func (s *GatewayServer) buildMessageHistory(req *api.ResponseRequest) []api.Message {
+	// If no previous_response_id, use input as-is
+	if req.PreviousResponseID == "" {
+		return req.Input
+	}
+
+	// Load previous conversation
+	conv, ok := s.convs.Get(req.PreviousResponseID)
+	if !ok {
+		return nil
+	}
+
+	// Append new input to conversation history
+	messages := make([]api.Message, len(conv.Messages))
+	copy(messages, conv.Messages)
+	messages = append(messages, req.Input...)
+	return messages
+}
+
+func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, fullReq *api.ResponseRequest, origReq *api.ResponseRequest) {
+	resp, err := provider.Generate(r.Context(), fullReq)
 	if err != nil {
 		s.logger.Printf("provider %s error: %v", provider.Name(), err)
 		http.Error(w, "provider error", http.StatusBadGateway)
 		return
 	}
 
+	// Store conversation - use previous_response_id if continuing, otherwise use new ID
+	conversationID := origReq.PreviousResponseID
+	if conversationID == "" {
+		conversationID = resp.ID
+	}
+	messages := append(fullReq.Input, resp.Output...)
+	s.convs.Create(conversationID, resp.Model, messages)
+
+	// Return the conversation ID (not the provider's response ID)
+	resp.ID = conversationID
+
 	w.Header().Set("Content-Type", "application/json")
 	w.WriteHeader(http.StatusOK)
 	_ = json.NewEncoder(w).Encode(resp)
 }
 
-func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, req *api.ResponseRequest) {
+func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, fullReq *api.ResponseRequest, origReq *api.ResponseRequest) {
 	// Set headers for SSE
 	w.Header().Set("Content-Type", "text/event-stream")
 	w.Header().Set("Cache-Control", "no-cache")
@@ -83,7 +132,10 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
 		return
 	}
 
-	chunkChan, errChan := provider.GenerateStream(r.Context(), req)
+	chunkChan, errChan := provider.GenerateStream(r.Context(), fullReq)
+
+	var responseID string
+	var fullText string
 
 	for {
 		select {
@@ -92,6 +144,25 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
 				return
 			}
 
+			// Capture response ID
+			if chunk.ID != "" && responseID == "" {
+				responseID = chunk.ID
+			}
+
+			// Override chunk ID with conversation ID
+			if origReq.PreviousResponseID != "" {
+				chunk.ID = origReq.PreviousResponseID
+			} else if responseID != "" {
+				chunk.ID = responseID
+			}
+
+			// Accumulate text from deltas
+			if chunk.Delta != nil && len(chunk.Delta.Content) > 0 {
+				for _, block := range chunk.Delta.Content {
+					fullText += block.Text
+				}
+			}
+
 			data, err := json.Marshal(chunk)
 			if err != nil {
 				s.logger.Printf("failed to marshal chunk: %v", err)
@@ -102,6 +173,8 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
 			flusher.Flush()
 
 			if chunk.Done {
+				// Store conversation with a single consolidated assistant message
+				s.storeStreamConversation(fullReq, origReq, responseID, fullText)
 				return
 			}
@@ -112,6 +185,8 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
 				fmt.Fprintf(w, "data: %s\n\n", errData)
 				flusher.Flush()
 			}
+			// Store whatever we accumulated before the error
+			s.storeStreamConversation(fullReq, origReq, responseID, fullText)
 			return
 
 		case <-r.Context().Done():
@@ -121,6 +196,27 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
 	}
 }
 
+func (s *GatewayServer) storeStreamConversation(fullReq *api.ResponseRequest, origReq *api.ResponseRequest, responseID string, fullText string) {
+	if responseID == "" || fullText == "" {
+		return
+	}
+
+	assistantMsg := api.Message{
+		Role: "assistant",
+		Content: []api.ContentBlock{
+			{Type: "output_text", Text: fullText},
+		},
+	}
+	messages := append(fullReq.Input, assistantMsg)
+
+	conversationID := origReq.PreviousResponseID
+	if conversationID == "" {
+		conversationID = responseID
+	}
+	s.convs.Create(conversationID, fullReq.Model, messages)
+}
+
 func (s *GatewayServer) resolveProvider(req *api.ResponseRequest) (providers.Provider, error) {
 	if req.Provider != "" {
 		if provider, ok := s.registry.Get(req.Provider); ok {
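
On the client side of the streaming path, a rough sketch of consuming the SSE stream and capturing the conversation ID. The `data:` framing matches the handler above; the JSON key names (`id`, `done`, `delta`, `content`, `text`) are assumptions inferred from the chunk fields referenced in this commit, not confirmed against the api package.

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"io"
	"strings"
)

// readStream consumes an SSE body from the gateway, printing text deltas and
// returning the conversation ID to echo back as previous_response_id.
func readStream(body io.Reader) (string, error) {
	scanner := bufio.NewScanner(body)
	var id string
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		var chunk struct {
			ID    string `json:"id"`
			Done  bool   `json:"done"`
			Delta struct {
				Content []struct {
					Text string `json:"text"`
				} `json:"content"`
			} `json:"delta"`
		}
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &chunk); err != nil {
			continue // skip frames we can't parse
		}
		if id == "" && chunk.ID != "" {
			id = chunk.ID // the gateway overrides this with the conversation ID
		}
		for _, block := range chunk.Delta.Content {
			fmt.Print(block.Text)
		}
		if chunk.Done {
			break
		}
	}
	return id, scanner.Err()
}

func main() {
	// Canned stream for illustration; in practice pass resp.Body from a
	// POST to /v1/responses with "stream": true.
	sse := "data: {\"id\":\"resp_1\",\"delta\":{\"content\":[{\"text\":\"Hi\"}]}}\n\ndata: {\"id\":\"resp_1\",\"done\":true}\n\n"
	id, _ := readStream(strings.NewReader(sse))
	fmt.Println("\nconversation id:", id)
}
```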