diff --git a/README.md b/README.md index 38676ab..2864bcd 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ Go LLM Gateway (unified API) ✅ **Streaming support** (Server-Sent Events for all providers) ✅ **OAuth2/OIDC authentication** (Google, Auth0, any OIDC provider) ✅ **Terminal chat client** (Python with Rich UI, PEP 723) +✅ **Conversation tracking** (`previous_response_id` for efficient context) ## Quick Start @@ -186,8 +187,21 @@ You> /model claude You> /models # List all available models ``` +The chat client automatically uses `previous_response_id` to reduce token usage by only sending new messages instead of the full conversation history. + See **[CHAT_CLIENT.md](./CHAT_CLIENT.md)** for full documentation. +## Conversation Management + +The gateway implements conversation tracking using `previous_response_id` from the Open Responses spec: + +- 📉 **Reduced token usage** - Clients only send new messages +- ⚡ **Smaller requests** - Less bandwidth +- 🧠 **Server-side context** - Gateway maintains history +- ⏰ **Auto-expire** - Conversations expire after 1 hour of inactivity + +See **[CONVERSATIONS.md](./CONVERSATIONS.md)** for details. + ## Authentication The gateway supports OAuth2/OIDC authentication. See **[AUTH.md](./AUTH.md)** for setup instructions. 
@@ -216,6 +230,8 @@ curl -X POST http://localhost:8080/v1/responses \ - ✅ ~~Implement streaming responses~~ - ✅ ~~Add OAuth2/OIDC authentication~~ +- ✅ ~~Implement conversation tracking with previous_response_id~~ - ⬜ Add structured logging, tracing, and request-level metrics - ⬜ Support tool/function calling +- ⬜ Persistent conversation storage (Redis/database) - ⬜ Expand configuration to support routing policies (cost, latency, failover) diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index d404696..97d5260 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -9,6 +9,7 @@ import ( "github.com/yourusername/go-llm-gateway/internal/auth" "github.com/yourusername/go-llm-gateway/internal/config" + "github.com/yourusername/go-llm-gateway/internal/conversation" "github.com/yourusername/go-llm-gateway/internal/providers" "github.com/yourusername/go-llm-gateway/internal/server" ) @@ -47,7 +48,11 @@ func main() { logger.Printf("Authentication disabled - WARNING: API is publicly accessible") } - gatewayServer := server.New(registry, logger) + // Initialize conversation store (1 hour TTL) + convStore := conversation.NewStore(1 * time.Hour) + logger.Printf("Conversation store initialized (TTL: 1h)") + + gatewayServer := server.New(registry, convStore, logger) mux := http.NewServeMux() gatewayServer.RegisterRoutes(mux) diff --git a/internal/api/types.go b/internal/api/types.go index 52a2310..5b7c94b 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -7,12 +7,13 @@ import ( // ResponseRequest models the Open Responses create request payload. 
type ResponseRequest struct { - Model string `json:"model"` - Provider string `json:"provider,omitempty"` - MaxOutputTokens int `json:"max_output_tokens,omitempty"` - Metadata map[string]string `json:"metadata,omitempty"` - Input []Message `json:"input"` - Stream bool `json:"stream,omitempty"` + Model string `json:"model"` + Provider string `json:"provider,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + Input []Message `json:"input"` + Stream bool `json:"stream,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` } // Message captures user, assistant, or system roles. diff --git a/internal/conversation/conversation.go b/internal/conversation/conversation.go new file mode 100644 index 0000000..d65be22 --- /dev/null +++ b/internal/conversation/conversation.go @@ -0,0 +1,112 @@ +package conversation + +import ( + "sync" + "time" + + "github.com/yourusername/go-llm-gateway/internal/api" +) + +// Store manages conversation history with automatic expiration. +type Store struct { + conversations map[string]*Conversation + mu sync.RWMutex + ttl time.Duration +} + +// Conversation holds the message history for a single conversation thread. +type Conversation struct { + ID string + Messages []api.Message + Model string + CreatedAt time.Time + UpdatedAt time.Time +} + +// NewStore creates a conversation store with the given TTL. +func NewStore(ttl time.Duration) *Store { + s := &Store{ + conversations: make(map[string]*Conversation), + ttl: ttl, + } + + // Start cleanup goroutine + go s.cleanup() + + return s +} + +// Get retrieves a conversation by ID. +func (s *Store) Get(id string) (*Conversation, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + conv, ok := s.conversations[id] + return conv, ok +} + +// Create creates a new conversation with the given messages. 
+func (s *Store) Create(id string, model string, messages []api.Message) *Conversation { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + conv := &Conversation{ + ID: id, + Messages: messages, + Model: model, + CreatedAt: now, + UpdatedAt: now, + } + + s.conversations[id] = conv + return conv +} + +// Append adds new messages to an existing conversation. +func (s *Store) Append(id string, messages ...api.Message) (*Conversation, bool) { + s.mu.Lock() + defer s.mu.Unlock() + + conv, ok := s.conversations[id] + if !ok { + return nil, false + } + + conv.Messages = append(conv.Messages, messages...) + conv.UpdatedAt = time.Now() + + return conv, true +} + +// Delete removes a conversation from the store. +func (s *Store) Delete(id string) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.conversations, id) +} + +// cleanup periodically removes expired conversations. +func (s *Store) cleanup() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + s.mu.Lock() + now := time.Now() + for id, conv := range s.conversations { + if now.Sub(conv.UpdatedAt) > s.ttl { + delete(s.conversations, id) + } + } + s.mu.Unlock() + } +} + +// Size returns the number of active conversations. 
+func (s *Store) Size() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.conversations) +} diff --git a/internal/providers/anthropic/anthropic.go b/internal/providers/anthropic/anthropic.go index 10dafa2..b0dc3c4 100644 --- a/internal/providers/anthropic/anthropic.go +++ b/internal/providers/anthropic/anthropic.go @@ -53,7 +53,7 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api for _, msg := range req.Input { var content string for _, block := range msg.Content { - if block.Type == "input_text" { + if block.Type == "input_text" || block.Type == "output_text" { content += block.Text } } @@ -147,22 +147,22 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) for _, msg := range req.Input { var content string for _, block := range msg.Content { - if block.Type == "input_text" { + if block.Type == "input_text" || block.Type == "output_text" { content += block.Text } - } - - switch msg.Role { - case "user": + } + + switch msg.Role { + case "user": messages = append(messages, anthropic.NewUserMessage(anthropic.NewTextBlock(content))) - case "assistant": + case "assistant": messages = append(messages, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content))) - case "system": + case "system": system = content - } - } + } + } - // Build params + // Build params params := anthropic.MessageNewParams{ Model: anthropic.Model(model), Messages: messages, diff --git a/internal/providers/google/google.go b/internal/providers/google/google.go index f8253ac..5d4ef82 100644 --- a/internal/providers/google/google.go +++ b/internal/providers/google/google.go @@ -54,11 +54,21 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api // Convert Open Responses messages to Gemini format var contents []*genai.Content + var systemText string for _, msg := range req.Input { + if msg.Role == "system" { + for _, block := range msg.Content { + if block.Type == "input_text" || block.Type == 
"output_text" { + systemText += block.Text + } + } + continue + } + var parts []*genai.Part for _, block := range msg.Content { - if block.Type == "input_text" { + if block.Type == "input_text" || block.Type == "output_text" { parts = append(parts, genai.NewPartFromText(block.Text)) } } @@ -74,8 +84,18 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api }) } + // Build config with system instruction if present + var config *genai.GenerateContentConfig + if systemText != "" { + config = &genai.GenerateContentConfig{ + SystemInstruction: &genai.Content{ + Parts: []*genai.Part{genai.NewPartFromText(systemText)}, + }, + } + } + // Generate content - resp, err := p.client.Models.GenerateContent(ctx, model, contents, nil) + resp, err := p.client.Models.GenerateContent(ctx, model, contents, config) if err != nil { return nil, fmt.Errorf("google api error: %w", err) } @@ -143,11 +163,21 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) // Convert messages var contents []*genai.Content + var systemText string for _, msg := range req.Input { + if msg.Role == "system" { + for _, block := range msg.Content { + if block.Type == "input_text" || block.Type == "output_text" { + systemText += block.Text + } + } + continue + } + var parts []*genai.Part for _, block := range msg.Content { - if block.Type == "input_text" { + if block.Type == "input_text" || block.Type == "output_text" { parts = append(parts, genai.NewPartFromText(block.Text)) } } @@ -163,8 +193,18 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) }) } + // Build config with system instruction if present + var config *genai.GenerateContentConfig + if systemText != "" { + config = &genai.GenerateContentConfig{ + SystemInstruction: &genai.Content{ + Parts: []*genai.Part{genai.NewPartFromText(systemText)}, + }, + } + } + // Create stream - stream := p.client.Models.GenerateContentStream(ctx, model, contents, nil) + stream 
:= p.client.Models.GenerateContentStream(ctx, model, contents, config) // Process stream for resp, err := range stream { diff --git a/internal/providers/openai/openai.go b/internal/providers/openai/openai.go index f53c7bd..42018b6 100644 --- a/internal/providers/openai/openai.go +++ b/internal/providers/openai/openai.go @@ -52,7 +52,7 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api for _, msg := range req.Input { var content string for _, block := range msg.Content { - if block.Type == "input_text" { + if block.Type == "input_text" || block.Type == "output_text" { content += block.Text } } @@ -127,22 +127,22 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) for _, msg := range req.Input { var content string for _, block := range msg.Content { - if block.Type == "input_text" { + if block.Type == "input_text" || block.Type == "output_text" { content += block.Text } - } - - switch msg.Role { - case "user": + } + + switch msg.Role { + case "user": messages = append(messages, openai.UserMessage(content)) - case "assistant": + case "assistant": messages = append(messages, openai.AssistantMessage(content)) - case "system": + case "system": messages = append(messages, openai.SystemMessage(content)) - } - } + } + } - // Create streaming request + // Create streaming request stream := p.client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{ Model: openai.ChatModel(model), Messages: messages, diff --git a/internal/server/server.go b/internal/server/server.go index 12e5f7c..fce079f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,18 +7,24 @@ import ( "net/http" "github.com/yourusername/go-llm-gateway/internal/api" + "github.com/yourusername/go-llm-gateway/internal/conversation" "github.com/yourusername/go-llm-gateway/internal/providers" ) // GatewayServer hosts the Open Responses API for the gateway. 
type GatewayServer struct { registry *providers.Registry + convs *conversation.Store logger *log.Logger } // New creates a GatewayServer bound to the provider registry. -func New(registry *providers.Registry, logger *log.Logger) *GatewayServer { - return &GatewayServer{registry: registry, logger: logger} +func New(registry *providers.Registry, convs *conversation.Store, logger *log.Logger) *GatewayServer { + return &GatewayServer{ + registry: registry, + convs: convs, + logger: logger, + } } // RegisterRoutes wires the HTTP handlers onto the provided mux. @@ -43,6 +49,17 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) return } + // Build full message history + messages := s.buildMessageHistory(&req) + if messages == nil { + http.Error(w, "conversation not found", http.StatusNotFound) + return + } + + // Update request with full history for provider + fullReq := req + fullReq.Input = messages + provider, err := s.resolveProvider(&req) if err != nil { http.Error(w, err.Error(), http.StatusBadGateway) @@ -51,26 +68,58 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) // Handle streaming vs non-streaming if req.Stream { - s.handleStreamingResponse(w, r, provider, &req) + s.handleStreamingResponse(w, r, provider, &fullReq, &req) } else { - s.handleSyncResponse(w, r, provider, &req) + s.handleSyncResponse(w, r, provider, &fullReq, &req) } } -func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, req *api.ResponseRequest) { - resp, err := provider.Generate(r.Context(), req) +func (s *GatewayServer) buildMessageHistory(req *api.ResponseRequest) []api.Message { + // If no previous_response_id, use input as-is + if req.PreviousResponseID == "" { + return req.Input + } + + // Load previous conversation + conv, ok := s.convs.Get(req.PreviousResponseID) + if !ok { + return nil + } + + // Append new input to conversation history + messages := 
make([]api.Message, len(conv.Messages)) + copy(messages, conv.Messages) + messages = append(messages, req.Input...) + + return messages +} + +func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, fullReq *api.ResponseRequest, origReq *api.ResponseRequest) { + resp, err := provider.Generate(r.Context(), fullReq) if err != nil { s.logger.Printf("provider %s error: %v", provider.Name(), err) http.Error(w, "provider error", http.StatusBadGateway) return } + // Store conversation - use previous_response_id if continuing, otherwise use new ID + conversationID := origReq.PreviousResponseID + if conversationID == "" { + conversationID = resp.ID + } + + messages := append(fullReq.Input, resp.Output...) + s.convs.Create(conversationID, resp.Model, messages) + + // Return the conversation ID (not the provider's response ID) + resp.ID = conversationID + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(resp) } -func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, req *api.ResponseRequest) { +func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, fullReq *api.ResponseRequest, origReq *api.ResponseRequest) { // Set headers for SSE w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -83,7 +132,10 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R return } - chunkChan, errChan := provider.GenerateStream(r.Context(), req) + chunkChan, errChan := provider.GenerateStream(r.Context(), fullReq) + + var responseID string + var fullText string for { select { @@ -92,6 +144,25 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R return } + // Capture response ID + if chunk.ID != "" && responseID == "" { + responseID = chunk.ID + } + + // 
Override chunk ID with conversation ID + if origReq.PreviousResponseID != "" { + chunk.ID = origReq.PreviousResponseID + } else if responseID != "" { + chunk.ID = responseID + } + + // Accumulate text from deltas + if chunk.Delta != nil && len(chunk.Delta.Content) > 0 { + for _, block := range chunk.Delta.Content { + fullText += block.Text + } + } + data, err := json.Marshal(chunk) if err != nil { s.logger.Printf("failed to marshal chunk: %v", err) @@ -102,6 +173,8 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R flusher.Flush() if chunk.Done { + // Store conversation with a single consolidated assistant message + s.storeStreamConversation(fullReq, origReq, responseID, fullText) return } @@ -112,6 +185,8 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R fmt.Fprintf(w, "data: %s\n\n", errData) flusher.Flush() } + // Store whatever we accumulated before the error + s.storeStreamConversation(fullReq, origReq, responseID, fullText) return case <-r.Context().Done(): @@ -121,6 +196,27 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R } } +func (s *GatewayServer) storeStreamConversation(fullReq *api.ResponseRequest, origReq *api.ResponseRequest, responseID string, fullText string) { + if responseID == "" || fullText == "" { + return + } + + assistantMsg := api.Message{ + Role: "assistant", + Content: []api.ContentBlock{ + {Type: "output_text", Text: fullText}, + }, + } + messages := append(fullReq.Input, assistantMsg) + + conversationID := origReq.PreviousResponseID + if conversationID == "" { + conversationID = responseID + } + + s.convs.Create(conversationID, fullReq.Model, messages) +} + func (s *GatewayServer) resolveProvider(req *api.ResponseRequest) (providers.Provider, error) { if req.Provider != "" { if provider, ok := s.registry.Get(req.Provider); ok {