Add conversation management
This commit is contained in:
16
README.md
16
README.md
@@ -54,6 +54,7 @@ Go LLM Gateway (unified API)
|
|||||||
✅ **Streaming support** (Server-Sent Events for all providers)
|
✅ **Streaming support** (Server-Sent Events for all providers)
|
||||||
✅ **OAuth2/OIDC authentication** (Google, Auth0, any OIDC provider)
|
✅ **OAuth2/OIDC authentication** (Google, Auth0, any OIDC provider)
|
||||||
✅ **Terminal chat client** (Python with Rich UI, PEP 723)
|
✅ **Terminal chat client** (Python with Rich UI, PEP 723)
|
||||||
|
✅ **Conversation tracking** (previous_response_id for efficient context)
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
@@ -186,8 +187,21 @@ You> /model claude
|
|||||||
You> /models # List all available models
|
You> /models # List all available models
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The chat client automatically uses `previous_response_id` to reduce token usage by only sending new messages instead of the full conversation history.
|
||||||
|
|
||||||
See **[CHAT_CLIENT.md](./CHAT_CLIENT.md)** for full documentation.
|
See **[CHAT_CLIENT.md](./CHAT_CLIENT.md)** for full documentation.
|
||||||
|
|
||||||
|
## Conversation Management
|
||||||
|
|
||||||
|
The gateway implements conversation tracking using `previous_response_id` from the Open Responses spec:
|
||||||
|
|
||||||
|
- 📉 **Reduced token usage** - Only send new messages
|
||||||
|
- ⚡ **Smaller requests** - Less bandwidth
|
||||||
|
- 🧠 **Server-side context** - Gateway maintains history
|
||||||
|
- ⏰ **Auto-expire** - Conversations expire after 1 hour
|
||||||
|
|
||||||
|
See **[CONVERSATIONS.md](./CONVERSATIONS.md)** for details.
|
||||||
|
|
||||||
## Authentication
|
## Authentication
|
||||||
|
|
||||||
The gateway supports OAuth2/OIDC authentication. See **[AUTH.md](./AUTH.md)** for setup instructions.
|
The gateway supports OAuth2/OIDC authentication. See **[AUTH.md](./AUTH.md)** for setup instructions.
|
||||||
@@ -216,6 +230,8 @@ curl -X POST http://localhost:8080/v1/responses \
|
|||||||
|
|
||||||
- ✅ ~~Implement streaming responses~~
|
- ✅ ~~Implement streaming responses~~
|
||||||
- ✅ ~~Add OAuth2/OIDC authentication~~
|
- ✅ ~~Add OAuth2/OIDC authentication~~
|
||||||
|
- ✅ ~~Implement conversation tracking with previous_response_id~~
|
||||||
- ⬜ Add structured logging, tracing, and request-level metrics
|
- ⬜ Add structured logging, tracing, and request-level metrics
|
||||||
- ⬜ Support tool/function calling
|
- ⬜ Support tool/function calling
|
||||||
|
- ⬜ Persistent conversation storage (Redis/database)
|
||||||
- ⬜ Expand configuration to support routing policies (cost, latency, failover)
|
- ⬜ Expand configuration to support routing policies (cost, latency, failover)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/yourusername/go-llm-gateway/internal/auth"
|
"github.com/yourusername/go-llm-gateway/internal/auth"
|
||||||
"github.com/yourusername/go-llm-gateway/internal/config"
|
"github.com/yourusername/go-llm-gateway/internal/config"
|
||||||
|
"github.com/yourusername/go-llm-gateway/internal/conversation"
|
||||||
"github.com/yourusername/go-llm-gateway/internal/providers"
|
"github.com/yourusername/go-llm-gateway/internal/providers"
|
||||||
"github.com/yourusername/go-llm-gateway/internal/server"
|
"github.com/yourusername/go-llm-gateway/internal/server"
|
||||||
)
|
)
|
||||||
@@ -47,7 +48,11 @@ func main() {
|
|||||||
logger.Printf("Authentication disabled - WARNING: API is publicly accessible")
|
logger.Printf("Authentication disabled - WARNING: API is publicly accessible")
|
||||||
}
|
}
|
||||||
|
|
||||||
gatewayServer := server.New(registry, logger)
|
// Initialize conversation store (1 hour TTL)
|
||||||
|
convStore := conversation.NewStore(1 * time.Hour)
|
||||||
|
logger.Printf("Conversation store initialized (TTL: 1h)")
|
||||||
|
|
||||||
|
gatewayServer := server.New(registry, convStore, logger)
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
gatewayServer.RegisterRoutes(mux)
|
gatewayServer.RegisterRoutes(mux)
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ type ResponseRequest struct {
|
|||||||
Metadata map[string]string `json:"metadata,omitempty"`
|
Metadata map[string]string `json:"metadata,omitempty"`
|
||||||
Input []Message `json:"input"`
|
Input []Message `json:"input"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Message captures user, assistant, or system roles.
|
// Message captures user, assistant, or system roles.
|
||||||
|
|||||||
112
internal/conversation/conversation.go
Normal file
112
internal/conversation/conversation.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store manages conversation history with automatic expiration.
|
||||||
|
type Store struct {
|
||||||
|
conversations map[string]*Conversation
|
||||||
|
mu sync.RWMutex
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conversation holds the message history for a single conversation thread.
|
||||||
|
type Conversation struct {
|
||||||
|
ID string
|
||||||
|
Messages []api.Message
|
||||||
|
Model string
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStore creates a conversation store with the given TTL.
|
||||||
|
func NewStore(ttl time.Duration) *Store {
|
||||||
|
s := &Store{
|
||||||
|
conversations: make(map[string]*Conversation),
|
||||||
|
ttl: ttl,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start cleanup goroutine
|
||||||
|
go s.cleanup()
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a conversation by ID.
|
||||||
|
func (s *Store) Get(id string) (*Conversation, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
conv, ok := s.conversations[id]
|
||||||
|
return conv, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create creates a new conversation with the given messages.
|
||||||
|
func (s *Store) Create(id string, model string, messages []api.Message) *Conversation {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
conv := &Conversation{
|
||||||
|
ID: id,
|
||||||
|
Messages: messages,
|
||||||
|
Model: model,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.conversations[id] = conv
|
||||||
|
return conv
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append adds new messages to an existing conversation.
|
||||||
|
func (s *Store) Append(id string, messages ...api.Message) (*Conversation, bool) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
conv, ok := s.conversations[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
conv.Messages = append(conv.Messages, messages...)
|
||||||
|
conv.UpdatedAt = time.Now()
|
||||||
|
|
||||||
|
return conv, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a conversation from the store.
|
||||||
|
func (s *Store) Delete(id string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
delete(s.conversations, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup periodically removes expired conversations.
|
||||||
|
func (s *Store) cleanup() {
|
||||||
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
s.mu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
for id, conv := range s.conversations {
|
||||||
|
if now.Sub(conv.UpdatedAt) > s.ttl {
|
||||||
|
delete(s.conversations, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the number of active conversations.
|
||||||
|
func (s *Store) Size() int {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return len(s.conversations)
|
||||||
|
}
|
||||||
@@ -53,7 +53,7 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api
|
|||||||
for _, msg := range req.Input {
|
for _, msg := range req.Input {
|
||||||
var content string
|
var content string
|
||||||
for _, block := range msg.Content {
|
for _, block := range msg.Content {
|
||||||
if block.Type == "input_text" {
|
if block.Type == "input_text" || block.Type == "output_text" {
|
||||||
content += block.Text
|
content += block.Text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -147,7 +147,7 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest)
|
|||||||
for _, msg := range req.Input {
|
for _, msg := range req.Input {
|
||||||
var content string
|
var content string
|
||||||
for _, block := range msg.Content {
|
for _, block := range msg.Content {
|
||||||
if block.Type == "input_text" {
|
if block.Type == "input_text" || block.Type == "output_text" {
|
||||||
content += block.Text
|
content += block.Text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,11 +54,21 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api
|
|||||||
|
|
||||||
// Convert Open Responses messages to Gemini format
|
// Convert Open Responses messages to Gemini format
|
||||||
var contents []*genai.Content
|
var contents []*genai.Content
|
||||||
|
var systemText string
|
||||||
|
|
||||||
for _, msg := range req.Input {
|
for _, msg := range req.Input {
|
||||||
|
if msg.Role == "system" {
|
||||||
|
for _, block := range msg.Content {
|
||||||
|
if block.Type == "input_text" || block.Type == "output_text" {
|
||||||
|
systemText += block.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
var parts []*genai.Part
|
var parts []*genai.Part
|
||||||
for _, block := range msg.Content {
|
for _, block := range msg.Content {
|
||||||
if block.Type == "input_text" {
|
if block.Type == "input_text" || block.Type == "output_text" {
|
||||||
parts = append(parts, genai.NewPartFromText(block.Text))
|
parts = append(parts, genai.NewPartFromText(block.Text))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -74,8 +84,18 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build config with system instruction if present
|
||||||
|
var config *genai.GenerateContentConfig
|
||||||
|
if systemText != "" {
|
||||||
|
config = &genai.GenerateContentConfig{
|
||||||
|
SystemInstruction: &genai.Content{
|
||||||
|
Parts: []*genai.Part{genai.NewPartFromText(systemText)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Generate content
|
// Generate content
|
||||||
resp, err := p.client.Models.GenerateContent(ctx, model, contents, nil)
|
resp, err := p.client.Models.GenerateContent(ctx, model, contents, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("google api error: %w", err)
|
return nil, fmt.Errorf("google api error: %w", err)
|
||||||
}
|
}
|
||||||
@@ -143,11 +163,21 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest)
|
|||||||
|
|
||||||
// Convert messages
|
// Convert messages
|
||||||
var contents []*genai.Content
|
var contents []*genai.Content
|
||||||
|
var systemText string
|
||||||
|
|
||||||
for _, msg := range req.Input {
|
for _, msg := range req.Input {
|
||||||
|
if msg.Role == "system" {
|
||||||
|
for _, block := range msg.Content {
|
||||||
|
if block.Type == "input_text" || block.Type == "output_text" {
|
||||||
|
systemText += block.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
var parts []*genai.Part
|
var parts []*genai.Part
|
||||||
for _, block := range msg.Content {
|
for _, block := range msg.Content {
|
||||||
if block.Type == "input_text" {
|
if block.Type == "input_text" || block.Type == "output_text" {
|
||||||
parts = append(parts, genai.NewPartFromText(block.Text))
|
parts = append(parts, genai.NewPartFromText(block.Text))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -163,8 +193,18 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest)
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build config with system instruction if present
|
||||||
|
var config *genai.GenerateContentConfig
|
||||||
|
if systemText != "" {
|
||||||
|
config = &genai.GenerateContentConfig{
|
||||||
|
SystemInstruction: &genai.Content{
|
||||||
|
Parts: []*genai.Part{genai.NewPartFromText(systemText)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create stream
|
// Create stream
|
||||||
stream := p.client.Models.GenerateContentStream(ctx, model, contents, nil)
|
stream := p.client.Models.GenerateContentStream(ctx, model, contents, config)
|
||||||
|
|
||||||
// Process stream
|
// Process stream
|
||||||
for resp, err := range stream {
|
for resp, err := range stream {
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api
|
|||||||
for _, msg := range req.Input {
|
for _, msg := range req.Input {
|
||||||
var content string
|
var content string
|
||||||
for _, block := range msg.Content {
|
for _, block := range msg.Content {
|
||||||
if block.Type == "input_text" {
|
if block.Type == "input_text" || block.Type == "output_text" {
|
||||||
content += block.Text
|
content += block.Text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -127,7 +127,7 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest)
|
|||||||
for _, msg := range req.Input {
|
for _, msg := range req.Input {
|
||||||
var content string
|
var content string
|
||||||
for _, block := range msg.Content {
|
for _, block := range msg.Content {
|
||||||
if block.Type == "input_text" {
|
if block.Type == "input_text" || block.Type == "output_text" {
|
||||||
content += block.Text
|
content += block.Text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,18 +7,24 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||||
|
"github.com/yourusername/go-llm-gateway/internal/conversation"
|
||||||
"github.com/yourusername/go-llm-gateway/internal/providers"
|
"github.com/yourusername/go-llm-gateway/internal/providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GatewayServer hosts the Open Responses API for the gateway.
|
// GatewayServer hosts the Open Responses API for the gateway.
|
||||||
type GatewayServer struct {
|
type GatewayServer struct {
|
||||||
registry *providers.Registry
|
registry *providers.Registry
|
||||||
|
convs *conversation.Store
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a GatewayServer bound to the provider registry.
|
// New creates a GatewayServer bound to the provider registry.
|
||||||
func New(registry *providers.Registry, logger *log.Logger) *GatewayServer {
|
func New(registry *providers.Registry, convs *conversation.Store, logger *log.Logger) *GatewayServer {
|
||||||
return &GatewayServer{registry: registry, logger: logger}
|
return &GatewayServer{
|
||||||
|
registry: registry,
|
||||||
|
convs: convs,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterRoutes wires the HTTP handlers onto the provided mux.
|
// RegisterRoutes wires the HTTP handlers onto the provided mux.
|
||||||
@@ -43,6 +49,17 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build full message history
|
||||||
|
messages := s.buildMessageHistory(&req)
|
||||||
|
if messages == nil {
|
||||||
|
http.Error(w, "conversation not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update request with full history for provider
|
||||||
|
fullReq := req
|
||||||
|
fullReq.Input = messages
|
||||||
|
|
||||||
provider, err := s.resolveProvider(&req)
|
provider, err := s.resolveProvider(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
@@ -51,26 +68,58 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
|
|||||||
|
|
||||||
// Handle streaming vs non-streaming
|
// Handle streaming vs non-streaming
|
||||||
if req.Stream {
|
if req.Stream {
|
||||||
s.handleStreamingResponse(w, r, provider, &req)
|
s.handleStreamingResponse(w, r, provider, &fullReq, &req)
|
||||||
} else {
|
} else {
|
||||||
s.handleSyncResponse(w, r, provider, &req)
|
s.handleSyncResponse(w, r, provider, &fullReq, &req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, req *api.ResponseRequest) {
|
func (s *GatewayServer) buildMessageHistory(req *api.ResponseRequest) []api.Message {
|
||||||
resp, err := provider.Generate(r.Context(), req)
|
// If no previous_response_id, use input as-is
|
||||||
|
if req.PreviousResponseID == "" {
|
||||||
|
return req.Input
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load previous conversation
|
||||||
|
conv, ok := s.convs.Get(req.PreviousResponseID)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append new input to conversation history
|
||||||
|
messages := make([]api.Message, len(conv.Messages))
|
||||||
|
copy(messages, conv.Messages)
|
||||||
|
messages = append(messages, req.Input...)
|
||||||
|
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, fullReq *api.ResponseRequest, origReq *api.ResponseRequest) {
|
||||||
|
resp, err := provider.Generate(r.Context(), fullReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Printf("provider %s error: %v", provider.Name(), err)
|
s.logger.Printf("provider %s error: %v", provider.Name(), err)
|
||||||
http.Error(w, "provider error", http.StatusBadGateway)
|
http.Error(w, "provider error", http.StatusBadGateway)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store conversation - use previous_response_id if continuing, otherwise use new ID
|
||||||
|
conversationID := origReq.PreviousResponseID
|
||||||
|
if conversationID == "" {
|
||||||
|
conversationID = resp.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := append(fullReq.Input, resp.Output...)
|
||||||
|
s.convs.Create(conversationID, resp.Model, messages)
|
||||||
|
|
||||||
|
// Return the conversation ID (not the provider's response ID)
|
||||||
|
resp.ID = conversationID
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
_ = json.NewEncoder(w).Encode(resp)
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, req *api.ResponseRequest) {
|
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, fullReq *api.ResponseRequest, origReq *api.ResponseRequest) {
|
||||||
// Set headers for SSE
|
// Set headers for SSE
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
@@ -83,7 +132,10 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
chunkChan, errChan := provider.GenerateStream(r.Context(), req)
|
chunkChan, errChan := provider.GenerateStream(r.Context(), fullReq)
|
||||||
|
|
||||||
|
var responseID string
|
||||||
|
var fullText string
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -92,6 +144,25 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture response ID
|
||||||
|
if chunk.ID != "" && responseID == "" {
|
||||||
|
responseID = chunk.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override chunk ID with conversation ID
|
||||||
|
if origReq.PreviousResponseID != "" {
|
||||||
|
chunk.ID = origReq.PreviousResponseID
|
||||||
|
} else if responseID != "" {
|
||||||
|
chunk.ID = responseID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulate text from deltas
|
||||||
|
if chunk.Delta != nil && len(chunk.Delta.Content) > 0 {
|
||||||
|
for _, block := range chunk.Delta.Content {
|
||||||
|
fullText += block.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(chunk)
|
data, err := json.Marshal(chunk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Printf("failed to marshal chunk: %v", err)
|
s.logger.Printf("failed to marshal chunk: %v", err)
|
||||||
@@ -102,6 +173,8 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|
||||||
if chunk.Done {
|
if chunk.Done {
|
||||||
|
// Store conversation with a single consolidated assistant message
|
||||||
|
s.storeStreamConversation(fullReq, origReq, responseID, fullText)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,6 +185,8 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
|
|||||||
fmt.Fprintf(w, "data: %s\n\n", errData)
|
fmt.Fprintf(w, "data: %s\n\n", errData)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
// Store whatever we accumulated before the error
|
||||||
|
s.storeStreamConversation(fullReq, origReq, responseID, fullText)
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-r.Context().Done():
|
case <-r.Context().Done():
|
||||||
@@ -121,6 +196,27 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GatewayServer) storeStreamConversation(fullReq *api.ResponseRequest, origReq *api.ResponseRequest, responseID string, fullText string) {
|
||||||
|
if responseID == "" || fullText == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assistantMsg := api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: []api.ContentBlock{
|
||||||
|
{Type: "output_text", Text: fullText},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
messages := append(fullReq.Input, assistantMsg)
|
||||||
|
|
||||||
|
conversationID := origReq.PreviousResponseID
|
||||||
|
if conversationID == "" {
|
||||||
|
conversationID = responseID
|
||||||
|
}
|
||||||
|
|
||||||
|
s.convs.Create(conversationID, fullReq.Model, messages)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayServer) resolveProvider(req *api.ResponseRequest) (providers.Provider, error) {
|
func (s *GatewayServer) resolveProvider(req *api.ResponseRequest) (providers.Provider, error) {
|
||||||
if req.Provider != "" {
|
if req.Provider != "" {
|
||||||
if provider, ok := s.registry.Get(req.Provider); ok {
|
if provider, ok := s.registry.Get(req.Provider); ok {
|
||||||
|
|||||||
Reference in New Issue
Block a user