Make gateway Open Responses compliant
This commit is contained in:
@@ -5,6 +5,10 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||
"github.com/yourusername/go-llm-gateway/internal/conversation"
|
||||
@@ -74,16 +78,34 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Build full message history
|
||||
messages := s.buildMessageHistory(&req)
|
||||
if messages == nil {
|
||||
http.Error(w, "conversation not found", http.StatusNotFound)
|
||||
return
|
||||
// Normalize input to internal messages
|
||||
inputMsgs := req.NormalizeInput()
|
||||
|
||||
// Build full message history from previous conversation
|
||||
var historyMsgs []api.Message
|
||||
if req.PreviousResponseID != nil && *req.PreviousResponseID != "" {
|
||||
conv, ok := s.convs.Get(*req.PreviousResponseID)
|
||||
if !ok {
|
||||
http.Error(w, "conversation not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
historyMsgs = conv.Messages
|
||||
}
|
||||
|
||||
// Update request with full history for provider
|
||||
fullReq := req
|
||||
fullReq.Input = messages
|
||||
// Combined messages for conversation storage (history + new input, no instructions)
|
||||
storeMsgs := make([]api.Message, 0, len(historyMsgs)+len(inputMsgs))
|
||||
storeMsgs = append(storeMsgs, historyMsgs...)
|
||||
storeMsgs = append(storeMsgs, inputMsgs...)
|
||||
|
||||
// Build provider messages: instructions + history + input
|
||||
var providerMsgs []api.Message
|
||||
if req.Instructions != nil && *req.Instructions != "" {
|
||||
providerMsgs = append(providerMsgs, api.Message{
|
||||
Role: "developer",
|
||||
Content: []api.ContentBlock{{Type: "input_text", Text: *req.Instructions}},
|
||||
})
|
||||
}
|
||||
providerMsgs = append(providerMsgs, storeMsgs...)
|
||||
|
||||
provider, err := s.resolveProvider(&req)
|
||||
if err != nil {
|
||||
@@ -91,64 +113,44 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve provider_model_id (e.g., Azure deployment name) before sending to provider
|
||||
fullReq.Model = s.registry.ResolveModelID(req.Model)
|
||||
// Resolve provider_model_id (e.g., Azure deployment name)
|
||||
resolvedReq := req
|
||||
resolvedReq.Model = s.registry.ResolveModelID(req.Model)
|
||||
|
||||
// Handle streaming vs non-streaming
|
||||
if req.Stream {
|
||||
s.handleStreamingResponse(w, r, provider, &fullReq, &req)
|
||||
s.handleStreamingResponse(w, r, provider, providerMsgs, &resolvedReq, &req, storeMsgs)
|
||||
} else {
|
||||
s.handleSyncResponse(w, r, provider, &fullReq, &req)
|
||||
s.handleSyncResponse(w, r, provider, providerMsgs, &resolvedReq, &req, storeMsgs)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayServer) buildMessageHistory(req *api.ResponseRequest) []api.Message {
|
||||
// If no previous_response_id, use input as-is
|
||||
if req.PreviousResponseID == "" {
|
||||
return req.Input
|
||||
}
|
||||
|
||||
// Load previous conversation
|
||||
conv, ok := s.convs.Get(req.PreviousResponseID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Append new input to conversation history
|
||||
messages := make([]api.Message, len(conv.Messages))
|
||||
copy(messages, conv.Messages)
|
||||
messages = append(messages, req.Input...)
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, fullReq *api.ResponseRequest, origReq *api.ResponseRequest) {
|
||||
resp, err := provider.Generate(r.Context(), fullReq)
|
||||
func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) {
|
||||
result, err := provider.Generate(r.Context(), providerMsgs, resolvedReq)
|
||||
if err != nil {
|
||||
s.logger.Printf("provider %s error: %v", provider.Name(), err)
|
||||
http.Error(w, "provider error", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// Store conversation - use previous_response_id if continuing, otherwise use new ID
|
||||
conversationID := origReq.PreviousResponseID
|
||||
if conversationID == "" {
|
||||
conversationID = resp.ID
|
||||
responseID := generateID("resp_")
|
||||
|
||||
// Build assistant message for conversation store
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
|
||||
}
|
||||
|
||||
messages := append(fullReq.Input, resp.Output...)
|
||||
s.convs.Create(conversationID, resp.Model, messages)
|
||||
|
||||
// Return the conversation ID (not the provider's response ID)
|
||||
resp.ID = conversationID
|
||||
allMsgs := append(storeMsgs, assistantMsg)
|
||||
s.convs.Create(responseID, result.Model, allMsgs)
|
||||
|
||||
// Build spec-compliant response
|
||||
resp := s.buildResponse(origReq, result, provider.Name(), responseID)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, fullReq *api.ResponseRequest, origReq *api.ResponseRequest) {
|
||||
// Set headers for SSE
|
||||
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
@@ -160,89 +162,322 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
|
||||
return
|
||||
}
|
||||
|
||||
chunkChan, errChan := provider.GenerateStream(r.Context(), fullReq)
|
||||
|
||||
var responseID string
|
||||
var fullText string
|
||||
responseID := generateID("resp_")
|
||||
itemID := generateID("msg_")
|
||||
seq := 0
|
||||
outputIdx := 0
|
||||
contentIdx := 0
|
||||
|
||||
// Build initial response snapshot (in_progress, no output yet)
|
||||
initialResp := s.buildResponse(origReq, &api.ProviderResult{
|
||||
Model: origReq.Model,
|
||||
}, provider.Name(), responseID)
|
||||
initialResp.Status = "in_progress"
|
||||
initialResp.CompletedAt = nil
|
||||
initialResp.Output = []api.OutputItem{}
|
||||
initialResp.Usage = nil
|
||||
|
||||
// response.created
|
||||
s.sendSSE(w, flusher, &seq, "response.created", &api.StreamEvent{
|
||||
Type: "response.created",
|
||||
Response: initialResp,
|
||||
})
|
||||
|
||||
// response.in_progress
|
||||
s.sendSSE(w, flusher, &seq, "response.in_progress", &api.StreamEvent{
|
||||
Type: "response.in_progress",
|
||||
Response: initialResp,
|
||||
})
|
||||
|
||||
// response.output_item.added
|
||||
inProgressItem := &api.OutputItem{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
Status: "in_progress",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{},
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.output_item.added", &api.StreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: &outputIdx,
|
||||
Item: inProgressItem,
|
||||
})
|
||||
|
||||
// response.content_part.added
|
||||
emptyPart := &api.ContentPart{
|
||||
Type: "output_text",
|
||||
Text: "",
|
||||
Annotations: []api.Annotation{},
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.content_part.added", &api.StreamEvent{
|
||||
Type: "response.content_part.added",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Part: emptyPart,
|
||||
})
|
||||
|
||||
// Start provider stream
|
||||
deltaChan, errChan := provider.GenerateStream(r.Context(), providerMsgs, resolvedReq)
|
||||
|
||||
var fullText string
|
||||
var streamErr error
|
||||
var providerModel string
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case chunk, ok := <-chunkChan:
|
||||
case delta, ok := <-deltaChan:
|
||||
if !ok {
|
||||
return
|
||||
break loop
|
||||
}
|
||||
|
||||
// Capture response ID
|
||||
if chunk.ID != "" && responseID == "" {
|
||||
responseID = chunk.ID
|
||||
if delta.Model != "" && providerModel == "" {
|
||||
providerModel = delta.Model
|
||||
}
|
||||
|
||||
// Override chunk ID with conversation ID
|
||||
if origReq.PreviousResponseID != "" {
|
||||
chunk.ID = origReq.PreviousResponseID
|
||||
} else if responseID != "" {
|
||||
chunk.ID = responseID
|
||||
if delta.Text != "" {
|
||||
fullText += delta.Text
|
||||
s.sendSSE(w, flusher, &seq, "response.output_text.delta", &api.StreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Delta: delta.Text,
|
||||
})
|
||||
}
|
||||
|
||||
// Accumulate text from deltas
|
||||
if chunk.Delta != nil && len(chunk.Delta.Content) > 0 {
|
||||
for _, block := range chunk.Delta.Content {
|
||||
fullText += block.Text
|
||||
}
|
||||
if delta.Done {
|
||||
break loop
|
||||
}
|
||||
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
s.logger.Printf("failed to marshal chunk: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
|
||||
if chunk.Done {
|
||||
// Store conversation with a single consolidated assistant message
|
||||
s.storeStreamConversation(fullReq, origReq, responseID, fullText)
|
||||
return
|
||||
}
|
||||
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
s.logger.Printf("stream error: %v", err)
|
||||
errData, _ := json.Marshal(map[string]string{"error": err.Error()})
|
||||
fmt.Fprintf(w, "data: %s\n\n", errData)
|
||||
flusher.Flush()
|
||||
streamErr = err
|
||||
}
|
||||
// Store whatever we accumulated before the error
|
||||
s.storeStreamConversation(fullReq, origReq, responseID, fullText)
|
||||
return
|
||||
|
||||
break loop
|
||||
case <-r.Context().Done():
|
||||
s.logger.Printf("client disconnected")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayServer) storeStreamConversation(fullReq *api.ResponseRequest, origReq *api.ResponseRequest, responseID string, fullText string) {
|
||||
if responseID == "" || fullText == "" {
|
||||
if streamErr != nil {
|
||||
s.logger.Printf("stream error: %v", streamErr)
|
||||
failedResp := s.buildResponse(origReq, &api.ProviderResult{
|
||||
Model: origReq.Model,
|
||||
}, provider.Name(), responseID)
|
||||
failedResp.Status = "failed"
|
||||
failedResp.CompletedAt = nil
|
||||
failedResp.Output = []api.OutputItem{}
|
||||
failedResp.Error = &api.ResponseError{
|
||||
Type: "server_error",
|
||||
Message: streamErr.Error(),
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.failed", &api.StreamEvent{
|
||||
Type: "response.failed",
|
||||
Response: failedResp,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{
|
||||
{Type: "output_text", Text: fullText},
|
||||
},
|
||||
}
|
||||
messages := append(fullReq.Input, assistantMsg)
|
||||
// response.output_text.done
|
||||
s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{
|
||||
Type: "response.output_text.done",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Text: fullText,
|
||||
})
|
||||
|
||||
conversationID := origReq.PreviousResponseID
|
||||
if conversationID == "" {
|
||||
conversationID = responseID
|
||||
// response.content_part.done
|
||||
completedPart := &api.ContentPart{
|
||||
Type: "output_text",
|
||||
Text: fullText,
|
||||
Annotations: []api.Annotation{},
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Part: completedPart,
|
||||
})
|
||||
|
||||
// response.output_item.done
|
||||
completedItem := &api.OutputItem{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{*completedPart},
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
OutputIndex: &outputIdx,
|
||||
Item: completedItem,
|
||||
})
|
||||
|
||||
// Build final completed response
|
||||
model := origReq.Model
|
||||
if providerModel != "" {
|
||||
model = providerModel
|
||||
}
|
||||
finalResult := &api.ProviderResult{
|
||||
Model: model,
|
||||
Text: fullText,
|
||||
}
|
||||
completedResp := s.buildResponse(origReq, finalResult, provider.Name(), responseID)
|
||||
completedResp.Output[0].ID = itemID
|
||||
|
||||
// response.completed
|
||||
s.sendSSE(w, flusher, &seq, "response.completed", &api.StreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: completedResp,
|
||||
})
|
||||
|
||||
// Store conversation
|
||||
if fullText != "" {
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
|
||||
}
|
||||
allMsgs := append(storeMsgs, assistantMsg)
|
||||
s.convs.Create(responseID, model, allMsgs)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayServer) sendSSE(w http.ResponseWriter, flusher http.Flusher, seq *int, eventType string, event *api.StreamEvent) {
|
||||
event.SequenceNumber = *seq
|
||||
*seq++
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.Printf("failed to marshal SSE event: %v", err)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, data)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.ProviderResult, providerName string, responseID string) *api.Response {
|
||||
now := time.Now().Unix()
|
||||
|
||||
model := result.Model
|
||||
if model == "" {
|
||||
model = req.Model
|
||||
}
|
||||
|
||||
s.convs.Create(conversationID, fullReq.Model, messages)
|
||||
// Build output item
|
||||
itemID := generateID("msg_")
|
||||
outputItem := api.OutputItem{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{{
|
||||
Type: "output_text",
|
||||
Text: result.Text,
|
||||
Annotations: []api.Annotation{},
|
||||
}},
|
||||
}
|
||||
|
||||
// Echo back request params with defaults
|
||||
tools := req.Tools
|
||||
if tools == nil {
|
||||
tools = json.RawMessage(`[]`)
|
||||
}
|
||||
toolChoice := req.ToolChoice
|
||||
if toolChoice == nil {
|
||||
toolChoice = json.RawMessage(`"auto"`)
|
||||
}
|
||||
text := req.Text
|
||||
if text == nil {
|
||||
text = json.RawMessage(`{"format":{"type":"text"}}`)
|
||||
}
|
||||
truncation := "disabled"
|
||||
if req.Truncation != nil {
|
||||
truncation = *req.Truncation
|
||||
}
|
||||
temperature := 1.0
|
||||
if req.Temperature != nil {
|
||||
temperature = *req.Temperature
|
||||
}
|
||||
topP := 1.0
|
||||
if req.TopP != nil {
|
||||
topP = *req.TopP
|
||||
}
|
||||
presencePenalty := 0.0
|
||||
if req.PresencePenalty != nil {
|
||||
presencePenalty = *req.PresencePenalty
|
||||
}
|
||||
frequencyPenalty := 0.0
|
||||
if req.FrequencyPenalty != nil {
|
||||
frequencyPenalty = *req.FrequencyPenalty
|
||||
}
|
||||
topLogprobs := 0
|
||||
if req.TopLogprobs != nil {
|
||||
topLogprobs = *req.TopLogprobs
|
||||
}
|
||||
parallelToolCalls := true
|
||||
if req.ParallelToolCalls != nil {
|
||||
parallelToolCalls = *req.ParallelToolCalls
|
||||
}
|
||||
store := true
|
||||
if req.Store != nil {
|
||||
store = *req.Store
|
||||
}
|
||||
background := false
|
||||
if req.Background != nil {
|
||||
background = *req.Background
|
||||
}
|
||||
serviceTier := "default"
|
||||
if req.ServiceTier != nil {
|
||||
serviceTier = *req.ServiceTier
|
||||
}
|
||||
var reasoning json.RawMessage
|
||||
if req.Reasoning != nil {
|
||||
reasoning = req.Reasoning
|
||||
}
|
||||
metadata := req.Metadata
|
||||
if metadata == nil {
|
||||
metadata = map[string]string{}
|
||||
}
|
||||
|
||||
var usage *api.Usage
|
||||
if result.Text != "" {
|
||||
usage = &result.Usage
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: now,
|
||||
CompletedAt: &now,
|
||||
Status: "completed",
|
||||
IncompleteDetails: nil,
|
||||
Model: model,
|
||||
PreviousResponseID: req.PreviousResponseID,
|
||||
Instructions: req.Instructions,
|
||||
Output: []api.OutputItem{outputItem},
|
||||
Error: nil,
|
||||
Tools: tools,
|
||||
ToolChoice: toolChoice,
|
||||
Truncation: truncation,
|
||||
ParallelToolCalls: parallelToolCalls,
|
||||
Text: text,
|
||||
TopP: topP,
|
||||
PresencePenalty: presencePenalty,
|
||||
FrequencyPenalty: frequencyPenalty,
|
||||
TopLogprobs: topLogprobs,
|
||||
Temperature: temperature,
|
||||
Reasoning: reasoning,
|
||||
Usage: usage,
|
||||
MaxOutputTokens: req.MaxOutputTokens,
|
||||
MaxToolCalls: req.MaxToolCalls,
|
||||
Store: store,
|
||||
Background: background,
|
||||
ServiceTier: serviceTier,
|
||||
Metadata: metadata,
|
||||
SafetyIdentifier: nil,
|
||||
PromptCacheKey: nil,
|
||||
Provider: providerName,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayServer) resolveProvider(req *api.ResponseRequest) (providers.Provider, error) {
|
||||
@@ -254,3 +489,8 @@ func (s *GatewayServer) resolveProvider(req *api.ResponseRequest) (providers.Pro
|
||||
}
|
||||
return s.registry.Default(req.Model)
|
||||
}
|
||||
|
||||
func generateID(prefix string) string {
|
||||
id := strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
return prefix + id[:24]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user