From 3e645a35257d82d825456e3e33b92d5fdf482a93 Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Mon, 2 Mar 2026 04:21:29 +0000 Subject: [PATCH] Make gateway Open Responses compliant --- internal/api/types.go | 306 ++++++++++++--- internal/providers/anthropic/anthropic.go | 168 ++++---- internal/providers/google/google.go | 222 +++++------ internal/providers/openai/openai.go | 157 ++++---- internal/providers/providers.go | 4 +- internal/server/server.go | 454 +++++++++++++++++----- 6 files changed, 858 insertions(+), 453 deletions(-) diff --git a/internal/api/types.go b/internal/api/types.go index a96ddd6..b0b3565 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -1,81 +1,307 @@ package api import ( + "encoding/json" "errors" "fmt" ) -// ResponseRequest models the Open Responses create request payload. +// ============================================================ +// Request Types (CreateResponseBody) +// ============================================================ + +// ResponseRequest models the OpenResponses CreateResponseBody. type ResponseRequest struct { Model string `json:"model"` - Provider string `json:"provider,omitempty"` - MaxOutputTokens int `json:"max_output_tokens,omitempty"` + Input InputUnion `json:"input"` + Instructions *string `json:"instructions,omitempty"` + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` Metadata map[string]string `json:"metadata,omitempty"` - Input []Message `json:"input"` Stream bool `json:"stream,omitempty"` - PreviousResponseID string `json:"previous_response_id,omitempty"` + PreviousResponseID *string `json:"previous_response_id,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + Truncation *string `json:"truncation,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + Store *bool `json:"store,omitempty"` + Text json.RawMessage `json:"text,omitempty"` + Reasoning json.RawMessage `json:"reasoning,omitempty"` + Include []string `json:"include,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` + Background *bool `json:"background,omitempty"` + StreamOptions json.RawMessage `json:"stream_options,omitempty"` + MaxToolCalls *int `json:"max_tool_calls,omitempty"` + + // Non-spec extension: allows client to select a specific provider. + Provider string `json:"provider,omitempty"` } -// Message captures user, assistant, or system roles. +// InputUnion handles the polymorphic "input" field: string or []InputItem. +type InputUnion struct { + String *string + Items []InputItem +} + +func (u *InputUnion) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + return nil + } + var s string + if err := json.Unmarshal(data, &s); err == nil { + u.String = &s + return nil + } + var items []InputItem + if err := json.Unmarshal(data, &items); err == nil { + u.Items = items + return nil + } + return fmt.Errorf("input must be a string or array of items") +} + +func (u InputUnion) MarshalJSON() ([]byte, error) { + if u.String != nil { + return json.Marshal(*u.String) + } + if u.Items != nil { + return json.Marshal(u.Items) + } + return []byte("null"), nil +} + +// InputItem is a discriminated union on "type". 
+// Valid types: message, item_reference, function_call, function_call_output, reasoning. +type InputItem struct { + Type string `json:"type"` + Role string `json:"role,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + ID string `json:"id,omitempty"` + CallID string `json:"call_id,omitempty"` + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + Output string `json:"output,omitempty"` + Status string `json:"status,omitempty"` +} + +// ============================================================ +// Internal Types (providers + conversation store) +// ============================================================ + +// Message is the normalized internal message representation. type Message struct { Role string `json:"role"` Content []ContentBlock `json:"content"` } -// ContentBlock represents a typed content element (text, data, tool call, etc.). +// ContentBlock is a typed content element. type ContentBlock struct { Type string `json:"type"` Text string `json:"text,omitempty"` } -// Response is a simplified Open Responses response payload. +// NormalizeInput converts the request Input into messages for providers. +// Does NOT include instructions (the server prepends those separately). +func (r *ResponseRequest) NormalizeInput() []Message { + if r.Input.String != nil { + return []Message{{ + Role: "user", + Content: []ContentBlock{{Type: "input_text", Text: *r.Input.String}}, + }} + } + + var msgs []Message + for _, item := range r.Input.Items { + switch item.Type { + case "message", "": + msg := Message{Role: item.Role} + if item.Content != nil { + var s string + if err := json.Unmarshal(item.Content, &s); err == nil { + contentType := "input_text" + if item.Role == "assistant" { + contentType = "output_text" + } + msg.Content = []ContentBlock{{Type: contentType, Text: s}} + } else { + var blocks []ContentBlock + _ = json.Unmarshal(item.Content, &blocks) + msg.Content = blocks + } + } + msgs = append(msgs, msg) + case "function_call_output": + msgs = append(msgs, Message{ + Role: "tool", + Content: []ContentBlock{{Type: "input_text", Text: item.Output}}, + }) + } + } + return msgs +} + +// ============================================================ +// Response Types (ResponseResource) +// ============================================================ + +// Response is the spec-compliant ResponseResource. 
type Response struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Provider string `json:"provider"` - Output []Message `json:"output"` - Usage Usage `json:"usage"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + CompletedAt *int64 `json:"completed_at"` + Status string `json:"status"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details"` + Model string `json:"model"` + PreviousResponseID *string `json:"previous_response_id"` + Instructions *string `json:"instructions"` + Output []OutputItem `json:"output"` + Error *ResponseError `json:"error"` + Tools json.RawMessage `json:"tools"` + ToolChoice json.RawMessage `json:"tool_choice"` + Truncation string `json:"truncation"` + ParallelToolCalls bool `json:"parallel_tool_calls"` + Text json.RawMessage `json:"text"` + TopP float64 `json:"top_p"` + PresencePenalty float64 `json:"presence_penalty"` + FrequencyPenalty float64 `json:"frequency_penalty"` + TopLogprobs int `json:"top_logprobs"` + Temperature float64 `json:"temperature"` + Reasoning json.RawMessage `json:"reasoning"` + Usage *Usage `json:"usage"` + MaxOutputTokens *int `json:"max_output_tokens"` + MaxToolCalls *int `json:"max_tool_calls"` + Store bool `json:"store"` + Background bool `json:"background"` + ServiceTier string `json:"service_tier"` + Metadata map[string]string `json:"metadata"` + SafetyIdentifier *string `json:"safety_identifier"` + PromptCacheKey *string `json:"prompt_cache_key"` + + // Non-spec extension + Provider string `json:"provider,omitempty"` } -// Usage captures token accounting. +// OutputItem represents a typed item in the response output. +type OutputItem struct { + ID string `json:"id"` + Type string `json:"type"` + Status string `json:"status"` + Role string `json:"role,omitempty"` + Content []ContentPart `json:"content,omitempty"` +} + +// ContentPart is a content block within an output item. +type ContentPart struct { + Type string `json:"type"` + Text string `json:"text"` + Annotations []Annotation `json:"annotations"` +} + +// Annotation on output text content. +type Annotation struct { + Type string `json:"type"` +} + +// IncompleteDetails explains why a response is incomplete. +type IncompleteDetails struct { + Reason string `json:"reason"` +} + +// ResponseError describes an error in the response. +type ResponseError struct { + Type string `json:"type"` + Message string `json:"message"` + Code *string `json:"code"` +} + +// ============================================================ +// Usage Types +// ============================================================ + +// Usage captures token accounting with sub-details. type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + InputTokensDetails InputTokensDetails `json:"input_tokens_details"` + OutputTokensDetails OutputTokensDetails `json:"output_tokens_details"` } -// StreamChunk represents a single Server-Sent Event in a streaming response. 
-type StreamChunk struct { - ID string `json:"id,omitempty"` - Object string `json:"object"` - Created int64 `json:"created,omitempty"` - Model string `json:"model,omitempty"` - Provider string `json:"provider,omitempty"` - Delta *StreamDelta `json:"delta,omitempty"` - Usage *Usage `json:"usage,omitempty"` - Done bool `json:"done,omitempty"` +// InputTokensDetails breaks down input token usage. +type InputTokensDetails struct { + CachedTokens int `json:"cached_tokens"` } -// StreamDelta represents incremental content in a stream chunk. -type StreamDelta struct { - Role string `json:"role,omitempty"` - Content []ContentBlock `json:"content,omitempty"` +// OutputTokensDetails breaks down output token usage. +type OutputTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` } +// ============================================================ +// Streaming Types +// ============================================================ + +// StreamEvent represents a single SSE event in the streaming response. +// Fields are selectively populated based on the event Type. +type StreamEvent struct { + Type string `json:"type"` + SequenceNumber int `json:"sequence_number"` + Response *Response `json:"response,omitempty"` + OutputIndex *int `json:"output_index,omitempty"` + Item *OutputItem `json:"item,omitempty"` + ItemID string `json:"item_id,omitempty"` + ContentIndex *int `json:"content_index,omitempty"` + Part *ContentPart `json:"part,omitempty"` + Delta string `json:"delta,omitempty"` + Text string `json:"text,omitempty"` +} + +// ============================================================ +// Provider Result Types (internal, not exposed via HTTP) +// ============================================================ + +// ProviderResult is returned by Provider.Generate. +type ProviderResult struct { + ID string + Model string + Text string + Usage Usage +} + +// ProviderStreamDelta is sent through the stream channel. +type ProviderStreamDelta struct { + ID string + Model string + Text string + Done bool + Usage *Usage +} + +// ============================================================ +// Models Endpoint Types +// ============================================================ + // ModelInfo describes a single model available through the gateway. type ModelInfo struct { ID string `json:"id"` Provider string `json:"provider"` } -// ModelsResponse is returned by the GET /v1/models endpoint. +// ModelsResponse is returned by GET /v1/models. type ModelsResponse struct { Object string `json:"object"` Data []ModelInfo `json:"data"` } +// ============================================================ +// Validation +// ============================================================ + // Validate performs basic structural validation. 
func (r *ResponseRequest) Validate() error { if r == nil { @@ -84,16 +310,8 @@ func (r *ResponseRequest) Validate() error { if r.Model == "" { return errors.New("model is required") } - if len(r.Input) == 0 { - return errors.New("input messages are required") - } - for i, msg := range r.Input { - if msg.Role == "" { - return fmt.Errorf("input[%d] role is required", i) - } - if len(msg.Content) == 0 { - return fmt.Errorf("input[%d] content is required", i) - } + if r.Input.String == nil && len(r.Input.Items) == 0 { + return errors.New("input is required") } return nil } diff --git a/internal/providers/anthropic/anthropic.go b/internal/providers/anthropic/anthropic.go index 44e6ce4..97b41ed 100644 --- a/internal/providers/anthropic/anthropic.go +++ b/internal/providers/anthropic/anthropic.go @@ -3,7 +3,6 @@ package anthropic import ( "context" "fmt" - "time" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" @@ -60,8 +59,8 @@ func NewAzure(azureCfg config.AzureAnthropicConfig) *Provider { func (p *Provider) Name() string { return Name } -// Generate routes the Open Responses request to Anthropic's API. -func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api.Response, error) { +// Generate routes the request to Anthropic's API. +func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) { if p.cfg.APIKey == "" { return nil, fmt.Errorf("anthropic api key missing") } @@ -69,37 +68,40 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api return nil, fmt.Errorf("anthropic client not initialized") } - model := chooseModel(req.Model, p.cfg.Model) - - // Convert Open Responses messages to Anthropic format - messages := make([]anthropic.MessageParam, 0, len(req.Input)) + // Convert messages to Anthropic format + anthropicMsgs := make([]anthropic.MessageParam, 0, len(messages)) var system string - - for _, msg := range req.Input { + + for _, msg := range messages { var content string for _, block := range msg.Content { if block.Type == "input_text" || block.Type == "output_text" { content += block.Text } } - + switch msg.Role { case "user": - messages = append(messages, anthropic.NewUserMessage(anthropic.NewTextBlock(content))) + anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content))) case "assistant": - messages = append(messages, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content))) - case "system": + anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content))) + case "system", "developer": system = content } } // Build request params - params := anthropic.MessageNewParams{ - Model: anthropic.Model(model), - Messages: messages, - MaxTokens: int64(4096), + maxTokens := int64(4096) + if req.MaxOutputTokens != nil { + maxTokens = int64(*req.MaxOutputTokens) } - + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(req.Model), + Messages: anthropicMsgs, + MaxTokens: maxTokens, + } + if system != "" { systemBlocks := []anthropic.TextBlockParam{ {Text: system, Type: "text"}, @@ -107,36 +109,31 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api params.System = systemBlocks } + if req.Temperature != nil { + params.Temperature = anthropic.Float(*req.Temperature) + } + if req.TopP != nil { + params.TopP = anthropic.Float(*req.TopP) + } + // Call Anthropic API resp, err := p.client.Messages.New(ctx, params) if err 
!= nil { return nil, fmt.Errorf("anthropic api error: %w", err) } - // Convert Anthropic response to Open Responses format - output := make([]api.Message, 0, 1) + // Extract text from response var text string - for _, block := range resp.Content { if block.Type == "text" { text += block.Text } } - - output = append(output, api.Message{ - Role: "assistant", - Content: []api.ContentBlock{ - {Type: "output_text", Text: text}, - }, - }) - return &api.Response{ - ID: resp.ID, - Object: "response", - Created: time.Now().Unix(), - Model: string(resp.Model), - Provider: Name, - Output: output, + return &api.ProviderResult{ + ID: resp.ID, + Model: string(resp.Model), + Text: text, Usage: api.Usage{ InputTokens: int(resp.Usage.InputTokens), OutputTokens: int(resp.Usage.OutputTokens), @@ -146,12 +143,12 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api } // GenerateStream handles streaming requests to Anthropic. -func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) (<-chan *api.StreamChunk, <-chan error) { - chunkChan := make(chan *api.StreamChunk) +func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) { + deltaChan := make(chan *api.ProviderStreamDelta) errChan := make(chan error, 1) go func() { - defer close(chunkChan) + defer close(deltaChan) defer close(errChan) if p.cfg.APIKey == "" { @@ -163,37 +160,40 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) return } - model := chooseModel(req.Model, p.cfg.Model) - - // Convert messages - messages := make([]anthropic.MessageParam, 0, len(req.Input)) + // Convert messages to Anthropic format + anthropicMsgs := make([]anthropic.MessageParam, 0, len(messages)) var system string - - for _, msg := range req.Input { + + for _, msg := range messages { var content string for _, block := range msg.Content { if block.Type == "input_text" || block.Type == "output_text" { content += block.Text } - } - - switch msg.Role { - case "user": - messages = append(messages, anthropic.NewUserMessage(anthropic.NewTextBlock(content))) - case "assistant": - messages = append(messages, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content))) - case "system": - system = content - } - } + } - // Build params - params := anthropic.MessageNewParams{ - Model: anthropic.Model(model), - Messages: messages, - MaxTokens: int64(4096), + switch msg.Role { + case "user": + anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content))) + case "assistant": + anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content))) + case "system", "developer": + system = content + } } - + + // Build params + maxTokens := int64(4096) + if req.MaxOutputTokens != nil { + maxTokens = int64(*req.MaxOutputTokens) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(req.Model), + Messages: anthropicMsgs, + MaxTokens: maxTokens, + } + if system != "" { systemBlocks := []anthropic.TextBlockParam{ {Text: system, Type: "text"}, @@ -201,42 +201,28 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) params.System = systemBlocks } + if req.Temperature != nil { + params.Temperature = anthropic.Float(*req.Temperature) + } + if req.TopP != nil { + params.TopP = anthropic.Float(*req.TopP) + } + // Create stream stream := p.client.Messages.NewStreaming(ctx, params) // Process stream for stream.Next() 
{ event := stream.Current() - - delta := &api.StreamDelta{} - var text string - - // Handle different event types + if event.Type == "content_block_delta" && event.Delta.Type == "text_delta" { - text = event.Delta.Text - delta.Content = []api.ContentBlock{ - {Type: "output_text", Text: text}, + select { + case deltaChan <- &api.ProviderStreamDelta{Text: event.Delta.Text}: + case <-ctx.Done(): + errChan <- ctx.Err() + return } } - - if event.Type == "message_start" { - delta.Role = "assistant" - } - - streamChunk := &api.StreamChunk{ - Object: "response.chunk", - Created: time.Now().Unix(), - Model: string(model), - Provider: Name, - Delta: delta, - } - - select { - case chunkChan <- streamChunk: - case <-ctx.Done(): - errChan <- ctx.Err() - return - } } if err := stream.Err(); err != nil { @@ -244,15 +230,15 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) return } - // Send final chunk + // Send final delta select { - case chunkChan <- &api.StreamChunk{Object: "response.chunk", Done: true}: + case deltaChan <- &api.ProviderStreamDelta{Done: true}: case <-ctx.Done(): errChan <- ctx.Err() } }() - return chunkChan, errChan + return deltaChan, errChan } func chooseModel(requested, defaultModel string) string { diff --git a/internal/providers/google/google.go b/internal/providers/google/google.go index 5d4ef82..39b0807 100644 --- a/internal/providers/google/google.go +++ b/internal/providers/google/google.go @@ -3,7 +3,6 @@ package google import ( "context" "fmt" - "time" "github.com/google/uuid" "google.golang.org/genai" @@ -41,8 +40,8 @@ func New(cfg config.ProviderConfig) *Provider { func (p *Provider) Name() string { return Name } -// Generate routes the Open Responses request to Gemini. -func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api.Response, error) { +// Generate routes the request to Gemini and returns a ProviderResult. 
+func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) { if p.cfg.APIKey == "" { return nil, fmt.Errorf("google api key missing") } @@ -50,60 +49,18 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api return nil, fmt.Errorf("google client not initialized") } - model := chooseModel(req.Model, p.cfg.Model) + model := req.Model - // Convert Open Responses messages to Gemini format - var contents []*genai.Content - var systemText string - - for _, msg := range req.Input { - if msg.Role == "system" { - for _, block := range msg.Content { - if block.Type == "input_text" || block.Type == "output_text" { - systemText += block.Text - } - } - continue - } + contents, systemText := convertMessages(messages) - var parts []*genai.Part - for _, block := range msg.Content { - if block.Type == "input_text" || block.Type == "output_text" { - parts = append(parts, genai.NewPartFromText(block.Text)) - } - } - - role := "user" - if msg.Role == "assistant" || msg.Role == "model" { - role = "model" - } - - contents = append(contents, &genai.Content{ - Role: role, - Parts: parts, - }) - } + config := buildConfig(systemText, req) - // Build config with system instruction if present - var config *genai.GenerateContentConfig - if systemText != "" { - config = &genai.GenerateContentConfig{ - SystemInstruction: &genai.Content{ - Parts: []*genai.Part{genai.NewPartFromText(systemText)}, - }, - } - } - - // Generate content resp, err := p.client.Models.GenerateContent(ctx, model, contents, config) if err != nil { return nil, fmt.Errorf("google api error: %w", err) } - // Convert Gemini response to Open Responses format - output := make([]api.Message, 0, 1) var text string - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { for _, part := range resp.Candidates[0].Content.Parts { if part != nil { @@ -111,28 +68,17 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api } } } - - output = append(output, api.Message{ - Role: "assistant", - Content: []api.ContentBlock{ - {Type: "output_text", Text: text}, - }, - }) - // Extract usage info if available var inputTokens, outputTokens int if resp.UsageMetadata != nil { inputTokens = int(resp.UsageMetadata.PromptTokenCount) outputTokens = int(resp.UsageMetadata.CandidatesTokenCount) } - return &api.Response{ - ID: uuid.NewString(), - Object: "response", - Created: time.Now().Unix(), - Model: model, - Provider: Name, - Output: output, + return &api.ProviderResult{ + ID: uuid.NewString(), + Model: model, + Text: text, Usage: api.Usage{ InputTokens: inputTokens, OutputTokens: outputTokens, @@ -142,12 +88,12 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api } // GenerateStream handles streaming requests to Google. 
-func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) (<-chan *api.StreamChunk, <-chan error) { - chunkChan := make(chan *api.StreamChunk) +func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) { + deltaChan := make(chan *api.ProviderStreamDelta) errChan := make(chan error, 1) go func() { - defer close(chunkChan) + defer close(deltaChan) defer close(errChan) if p.cfg.APIKey == "" { @@ -159,54 +105,14 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) return } - model := chooseModel(req.Model, p.cfg.Model) + model := req.Model - // Convert messages - var contents []*genai.Content - var systemText string - - for _, msg := range req.Input { - if msg.Role == "system" { - for _, block := range msg.Content { - if block.Type == "input_text" || block.Type == "output_text" { - systemText += block.Text - } - } - continue - } + contents, systemText := convertMessages(messages) - var parts []*genai.Part - for _, block := range msg.Content { - if block.Type == "input_text" || block.Type == "output_text" { - parts = append(parts, genai.NewPartFromText(block.Text)) - } - } - - role := "user" - if msg.Role == "assistant" || msg.Role == "model" { - role = "model" - } - - contents = append(contents, &genai.Content{ - Role: role, - Parts: parts, - }) - } + config := buildConfig(systemText, req) - // Build config with system instruction if present - var config *genai.GenerateContentConfig - if systemText != "" { - config = &genai.GenerateContentConfig{ - SystemInstruction: &genai.Content{ - Parts: []*genai.Part{genai.NewPartFromText(systemText)}, - }, - } - } - - // Create stream stream := p.client.Models.GenerateContentStream(ctx, model, contents, config) - // Process stream for resp, err := range stream { if err != nil { errChan <- fmt.Errorf("google stream error: %w", err) @@ -222,38 +128,94 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) } } - delta := &api.StreamDelta{} if text != "" { - delta.Content = []api.ContentBlock{ - {Type: "output_text", Text: text}, + select { + case deltaChan <- &api.ProviderStreamDelta{Text: text}: + case <-ctx.Done(): + errChan <- ctx.Err() + return } } - - streamChunk := &api.StreamChunk{ - Object: "response.chunk", - Created: time.Now().Unix(), - Model: model, - Provider: Name, - Delta: delta, - } - - select { - case chunkChan <- streamChunk: - case <-ctx.Done(): - errChan <- ctx.Err() - return - } } - // Send final chunk select { - case chunkChan <- &api.StreamChunk{Object: "response.chunk", Done: true}: + case deltaChan <- &api.ProviderStreamDelta{Done: true}: case <-ctx.Done(): errChan <- ctx.Err() } }() - return chunkChan, errChan + return deltaChan, errChan +} + +// convertMessages splits messages into Gemini contents and system text. 
+func convertMessages(messages []api.Message) ([]*genai.Content, string) { + var contents []*genai.Content + var systemText string + + for _, msg := range messages { + if msg.Role == "system" || msg.Role == "developer" { + for _, block := range msg.Content { + if block.Type == "input_text" || block.Type == "output_text" { + systemText += block.Text + } + } + continue + } + + var parts []*genai.Part + for _, block := range msg.Content { + if block.Type == "input_text" || block.Type == "output_text" { + parts = append(parts, genai.NewPartFromText(block.Text)) + } + } + + role := "user" + if msg.Role == "assistant" || msg.Role == "model" { + role = "model" + } + + contents = append(contents, &genai.Content{ + Role: role, + Parts: parts, + }) + } + + return contents, systemText +} + +// buildConfig constructs a GenerateContentConfig from system text and request params. +func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateContentConfig { + var cfg *genai.GenerateContentConfig + + needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil + if !needsCfg { + return nil + } + + cfg = &genai.GenerateContentConfig{} + + if systemText != "" { + cfg.SystemInstruction = &genai.Content{ + Parts: []*genai.Part{genai.NewPartFromText(systemText)}, + } + } + + if req.MaxOutputTokens != nil { + cfg.MaxOutputTokens = int32(*req.MaxOutputTokens) + } + + if req.Temperature != nil { + t := float32(*req.Temperature) + cfg.Temperature = &t + } + + if req.TopP != nil { + tp := float32(*req.TopP) + cfg.TopP = &tp + } + + return cfg } func chooseModel(requested, defaultModel string) string { diff --git a/internal/providers/openai/openai.go b/internal/providers/openai/openai.go index ac9bc03..45e5c33 100644 --- a/internal/providers/openai/openai.go +++ b/internal/providers/openai/openai.go @@ -3,7 +3,6 @@ package openai import ( "context" "fmt" - "time" "github.com/openai/openai-go" "github.com/openai/openai-go/azure" @@ -64,8 +63,8 @@ func NewAzure(azureCfg config.AzureOpenAIConfig) *Provider { // Name returns the provider identifier. func (p *Provider) Name() string { return Name } -// Generate routes the Open Responses request to OpenAI. -func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api.Response, error) { +// Generate routes the request to OpenAI and returns a ProviderResult. 
+func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) { if p.cfg.APIKey == "" { return nil, fmt.Errorf("openai api key missing") } @@ -73,55 +72,57 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api return nil, fmt.Errorf("openai client not initialized") } - model := chooseModel(req.Model, p.cfg.Model) - - // Convert Open Responses messages to OpenAI format - messages := make([]openai.ChatCompletionMessageParamUnion, 0, len(req.Input)) - for _, msg := range req.Input { + // Convert messages to OpenAI format + oaiMessages := make([]openai.ChatCompletionMessageParamUnion, 0, len(messages)) + for _, msg := range messages { var content string for _, block := range msg.Content { if block.Type == "input_text" || block.Type == "output_text" { content += block.Text } } - + switch msg.Role { case "user": - messages = append(messages, openai.UserMessage(content)) + oaiMessages = append(oaiMessages, openai.UserMessage(content)) case "assistant": - messages = append(messages, openai.AssistantMessage(content)) + oaiMessages = append(oaiMessages, openai.AssistantMessage(content)) case "system": - messages = append(messages, openai.SystemMessage(content)) + oaiMessages = append(oaiMessages, openai.SystemMessage(content)) + case "developer": + oaiMessages = append(oaiMessages, openai.SystemMessage(content)) } } + params := openai.ChatCompletionNewParams{ + Model: openai.ChatModel(req.Model), + Messages: oaiMessages, + } + if req.MaxOutputTokens != nil { + params.MaxTokens = openai.Int(int64(*req.MaxOutputTokens)) + } + if req.Temperature != nil { + params.Temperature = openai.Float(*req.Temperature) + } + if req.TopP != nil { + params.TopP = openai.Float(*req.TopP) + } + // Call OpenAI API - resp, err := p.client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: openai.ChatModel(model), - Messages: messages, - }) + resp, err := p.client.Chat.Completions.New(ctx, params) if err != nil { return nil, fmt.Errorf("openai api error: %w", err) } - // Convert OpenAI response to Open Responses format - output := make([]api.Message, 0, len(resp.Choices)) + var combinedText string for _, choice := range resp.Choices { - output = append(output, api.Message{ - Role: "assistant", - Content: []api.ContentBlock{ - {Type: "output_text", Text: choice.Message.Content}, - }, - }) + combinedText += choice.Message.Content } - return &api.Response{ - ID: resp.ID, - Object: "response", - Created: time.Now().Unix(), - Model: resp.Model, - Provider: Name, - Output: output, + return &api.ProviderResult{ + ID: resp.ID, + Model: resp.Model, + Text: combinedText, Usage: api.Usage{ InputTokens: int(resp.Usage.PromptTokens), OutputTokens: int(resp.Usage.CompletionTokens), @@ -131,12 +132,12 @@ func (p *Provider) Generate(ctx context.Context, req *api.ResponseRequest) (*api } // GenerateStream handles streaming requests to OpenAI. 
-func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) (<-chan *api.StreamChunk, <-chan error) { - chunkChan := make(chan *api.StreamChunk) +func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) { + deltaChan := make(chan *api.ProviderStreamDelta) errChan := make(chan error, 1) go func() { - defer close(chunkChan) + defer close(deltaChan) defer close(errChan) if p.cfg.APIKey == "" { @@ -148,62 +149,60 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) return } - model := chooseModel(req.Model, p.cfg.Model) - - // Convert messages - messages := make([]openai.ChatCompletionMessageParamUnion, 0, len(req.Input)) - for _, msg := range req.Input { + // Convert messages to OpenAI format + oaiMessages := make([]openai.ChatCompletionMessageParamUnion, 0, len(messages)) + for _, msg := range messages { var content string for _, block := range msg.Content { if block.Type == "input_text" || block.Type == "output_text" { content += block.Text } - } - - switch msg.Role { - case "user": - messages = append(messages, openai.UserMessage(content)) - case "assistant": - messages = append(messages, openai.AssistantMessage(content)) - case "system": - messages = append(messages, openai.SystemMessage(content)) - } - } + } - // Create streaming request - stream := p.client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{ - Model: openai.ChatModel(model), - Messages: messages, - }) + switch msg.Role { + case "user": + oaiMessages = append(oaiMessages, openai.UserMessage(content)) + case "assistant": + oaiMessages = append(oaiMessages, openai.AssistantMessage(content)) + case "system": + oaiMessages = append(oaiMessages, openai.SystemMessage(content)) + case "developer": + oaiMessages = append(oaiMessages, openai.SystemMessage(content)) + } + } + + params := openai.ChatCompletionNewParams{ + Model: openai.ChatModel(req.Model), + Messages: oaiMessages, + } + if req.MaxOutputTokens != nil { + params.MaxTokens = openai.Int(int64(*req.MaxOutputTokens)) + } + if req.Temperature != nil { + params.Temperature = openai.Float(*req.Temperature) + } + if req.TopP != nil { + params.TopP = openai.Float(*req.TopP) + } + + // Create streaming request + stream := p.client.Chat.Completions.NewStreaming(ctx, params) // Process stream for stream.Next() { chunk := stream.Current() - - for _, choice := range chunk.Choices { - delta := &api.StreamDelta{} - - if choice.Delta.Role != "" { - delta.Role = string(choice.Delta.Role) - } - - if choice.Delta.Content != "" { - delta.Content = []api.ContentBlock{ - {Type: "output_text", Text: choice.Delta.Content}, - } - } - streamChunk := &api.StreamChunk{ - ID: chunk.ID, - Object: "response.chunk", - Created: time.Now().Unix(), - Model: chunk.Model, - Provider: Name, - Delta: delta, + for _, choice := range chunk.Choices { + if choice.Delta.Content == "" { + continue } select { - case chunkChan <- streamChunk: + case deltaChan <- &api.ProviderStreamDelta{ + ID: chunk.ID, + Model: chunk.Model, + Text: choice.Delta.Content, + }: case <-ctx.Done(): errChan <- ctx.Err() return @@ -216,15 +215,15 @@ func (p *Provider) GenerateStream(ctx context.Context, req *api.ResponseRequest) return } - // Send final chunk + // Send final delta select { - case chunkChan <- &api.StreamChunk{Object: "response.chunk", Done: true}: + case deltaChan <- &api.ProviderStreamDelta{Done: true}: case <-ctx.Done(): errChan <- ctx.Err() } }() - return 
chunkChan, errChan + return deltaChan, errChan } func chooseModel(requested, defaultModel string) string { diff --git a/internal/providers/providers.go b/internal/providers/providers.go index f32c79f..1affc26 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -14,8 +14,8 @@ import ( // Provider represents a unified interface that each LLM provider must implement. type Provider interface { Name() string - Generate(ctx context.Context, req *api.ResponseRequest) (*api.Response, error) - GenerateStream(ctx context.Context, req *api.ResponseRequest) (<-chan *api.StreamChunk, <-chan error) + Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) + GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) } // Registry keeps track of registered providers and model-to-provider mappings. diff --git a/internal/server/server.go b/internal/server/server.go index b846415..f61435f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -5,6 +5,10 @@ import ( "fmt" "log" "net/http" + "strings" + "time" + + "github.com/google/uuid" "github.com/yourusername/go-llm-gateway/internal/api" "github.com/yourusername/go-llm-gateway/internal/conversation" @@ -74,16 +78,34 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) return } - // Build full message history - messages := s.buildMessageHistory(&req) - if messages == nil { - http.Error(w, "conversation not found", http.StatusNotFound) - return + // Normalize input to internal messages + inputMsgs := req.NormalizeInput() + + // Build full message history from previous conversation + var historyMsgs []api.Message + if req.PreviousResponseID != nil && *req.PreviousResponseID != "" { + conv, ok := s.convs.Get(*req.PreviousResponseID) + if !ok { + http.Error(w, "conversation not found", http.StatusNotFound) + return + } + historyMsgs = conv.Messages } - // Update request with full history for provider - fullReq := req - fullReq.Input = messages + // Combined messages for conversation storage (history + new input, no instructions) + storeMsgs := make([]api.Message, 0, len(historyMsgs)+len(inputMsgs)) + storeMsgs = append(storeMsgs, historyMsgs...) + storeMsgs = append(storeMsgs, inputMsgs...) + + // Build provider messages: instructions + history + input + var providerMsgs []api.Message + if req.Instructions != nil && *req.Instructions != "" { + providerMsgs = append(providerMsgs, api.Message{ + Role: "developer", + Content: []api.ContentBlock{{Type: "input_text", Text: *req.Instructions}}, + }) + } + providerMsgs = append(providerMsgs, storeMsgs...) 
provider, err := s.resolveProvider(&req) if err != nil { @@ -91,64 +113,44 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) return } - // Resolve provider_model_id (e.g., Azure deployment name) before sending to provider - fullReq.Model = s.registry.ResolveModelID(req.Model) + // Resolve provider_model_id (e.g., Azure deployment name) + resolvedReq := req + resolvedReq.Model = s.registry.ResolveModelID(req.Model) - // Handle streaming vs non-streaming if req.Stream { - s.handleStreamingResponse(w, r, provider, &fullReq, &req) + s.handleStreamingResponse(w, r, provider, providerMsgs, &resolvedReq, &req, storeMsgs) } else { - s.handleSyncResponse(w, r, provider, &fullReq, &req) + s.handleSyncResponse(w, r, provider, providerMsgs, &resolvedReq, &req, storeMsgs) } } -func (s *GatewayServer) buildMessageHistory(req *api.ResponseRequest) []api.Message { - // If no previous_response_id, use input as-is - if req.PreviousResponseID == "" { - return req.Input - } - - // Load previous conversation - conv, ok := s.convs.Get(req.PreviousResponseID) - if !ok { - return nil - } - - // Append new input to conversation history - messages := make([]api.Message, len(conv.Messages)) - copy(messages, conv.Messages) - messages = append(messages, req.Input...) - - return messages -} - -func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, fullReq *api.ResponseRequest, origReq *api.ResponseRequest) { - resp, err := provider.Generate(r.Context(), fullReq) +func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) { + result, err := provider.Generate(r.Context(), providerMsgs, resolvedReq) if err != nil { s.logger.Printf("provider %s error: %v", provider.Name(), err) http.Error(w, "provider error", http.StatusBadGateway) return } - // Store conversation - use previous_response_id if continuing, otherwise use new ID - conversationID := origReq.PreviousResponseID - if conversationID == "" { - conversationID = resp.ID + responseID := generateID("resp_") + + // Build assistant message for conversation store + assistantMsg := api.Message{ + Role: "assistant", + Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}}, } - - messages := append(fullReq.Input, resp.Output...) 
- s.convs.Create(conversationID, resp.Model, messages) - - // Return the conversation ID (not the provider's response ID) - resp.ID = conversationID + allMsgs := append(storeMsgs, assistantMsg) + s.convs.Create(responseID, result.Model, allMsgs) + + // Build spec-compliant response + resp := s.buildResponse(origReq, result, provider.Name(), responseID) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(resp) } -func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, fullReq *api.ResponseRequest, origReq *api.ResponseRequest) { - // Set headers for SSE +func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") @@ -160,89 +162,322 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R return } - chunkChan, errChan := provider.GenerateStream(r.Context(), fullReq) - - var responseID string - var fullText string + responseID := generateID("resp_") + itemID := generateID("msg_") + seq := 0 + outputIdx := 0 + contentIdx := 0 + // Build initial response snapshot (in_progress, no output yet) + initialResp := s.buildResponse(origReq, &api.ProviderResult{ + Model: origReq.Model, + }, provider.Name(), responseID) + initialResp.Status = "in_progress" + initialResp.CompletedAt = nil + initialResp.Output = []api.OutputItem{} + initialResp.Usage = nil + + // response.created + s.sendSSE(w, flusher, &seq, "response.created", &api.StreamEvent{ + Type: "response.created", + Response: initialResp, + }) + + // response.in_progress + s.sendSSE(w, flusher, &seq, "response.in_progress", &api.StreamEvent{ + Type: "response.in_progress", + Response: initialResp, + }) + + // response.output_item.added + inProgressItem := &api.OutputItem{ + ID: itemID, + Type: "message", + Status: "in_progress", + Role: "assistant", + Content: []api.ContentPart{}, + } + s.sendSSE(w, flusher, &seq, "response.output_item.added", &api.StreamEvent{ + Type: "response.output_item.added", + OutputIndex: &outputIdx, + Item: inProgressItem, + }) + + // response.content_part.added + emptyPart := &api.ContentPart{ + Type: "output_text", + Text: "", + Annotations: []api.Annotation{}, + } + s.sendSSE(w, flusher, &seq, "response.content_part.added", &api.StreamEvent{ + Type: "response.content_part.added", + ItemID: itemID, + OutputIndex: &outputIdx, + ContentIndex: &contentIdx, + Part: emptyPart, + }) + + // Start provider stream + deltaChan, errChan := provider.GenerateStream(r.Context(), providerMsgs, resolvedReq) + + var fullText string + var streamErr error + var providerModel string + +loop: for { select { - case chunk, ok := <-chunkChan: + case delta, ok := <-deltaChan: if !ok { - return + break loop } - - // Capture response ID - if chunk.ID != "" && responseID == "" { - responseID = chunk.ID + if delta.Model != "" && providerModel == "" { + providerModel = delta.Model } - - // Override chunk ID with conversation ID - if origReq.PreviousResponseID != "" { - chunk.ID = origReq.PreviousResponseID - } else if responseID != "" { - chunk.ID = responseID + if delta.Text != "" { + fullText += delta.Text + s.sendSSE(w, flusher, &seq, "response.output_text.delta", &api.StreamEvent{ + Type: 
"response.output_text.delta", + ItemID: itemID, + OutputIndex: &outputIdx, + ContentIndex: &contentIdx, + Delta: delta.Text, + }) } - - // Accumulate text from deltas - if chunk.Delta != nil && len(chunk.Delta.Content) > 0 { - for _, block := range chunk.Delta.Content { - fullText += block.Text - } + if delta.Done { + break loop } - - data, err := json.Marshal(chunk) - if err != nil { - s.logger.Printf("failed to marshal chunk: %v", err) - continue - } - - fmt.Fprintf(w, "data: %s\n\n", data) - flusher.Flush() - - if chunk.Done { - // Store conversation with a single consolidated assistant message - s.storeStreamConversation(fullReq, origReq, responseID, fullText) - return - } - case err := <-errChan: if err != nil { - s.logger.Printf("stream error: %v", err) - errData, _ := json.Marshal(map[string]string{"error": err.Error()}) - fmt.Fprintf(w, "data: %s\n\n", errData) - flusher.Flush() + streamErr = err } - // Store whatever we accumulated before the error - s.storeStreamConversation(fullReq, origReq, responseID, fullText) - return - + break loop case <-r.Context().Done(): s.logger.Printf("client disconnected") return } } -} -func (s *GatewayServer) storeStreamConversation(fullReq *api.ResponseRequest, origReq *api.ResponseRequest, responseID string, fullText string) { - if responseID == "" || fullText == "" { + if streamErr != nil { + s.logger.Printf("stream error: %v", streamErr) + failedResp := s.buildResponse(origReq, &api.ProviderResult{ + Model: origReq.Model, + }, provider.Name(), responseID) + failedResp.Status = "failed" + failedResp.CompletedAt = nil + failedResp.Output = []api.OutputItem{} + failedResp.Error = &api.ResponseError{ + Type: "server_error", + Message: streamErr.Error(), + } + s.sendSSE(w, flusher, &seq, "response.failed", &api.StreamEvent{ + Type: "response.failed", + Response: failedResp, + }) return } - assistantMsg := api.Message{ - Role: "assistant", - Content: []api.ContentBlock{ - {Type: "output_text", Text: fullText}, - }, - } - messages := append(fullReq.Input, assistantMsg) + // response.output_text.done + s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{ + Type: "response.output_text.done", + ItemID: itemID, + OutputIndex: &outputIdx, + ContentIndex: &contentIdx, + Text: fullText, + }) - conversationID := origReq.PreviousResponseID - if conversationID == "" { - conversationID = responseID + // response.content_part.done + completedPart := &api.ContentPart{ + Type: "output_text", + Text: fullText, + Annotations: []api.Annotation{}, + } + s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{ + Type: "response.content_part.done", + ItemID: itemID, + OutputIndex: &outputIdx, + ContentIndex: &contentIdx, + Part: completedPart, + }) + + // response.output_item.done + completedItem := &api.OutputItem{ + ID: itemID, + Type: "message", + Status: "completed", + Role: "assistant", + Content: []api.ContentPart{*completedPart}, + } + s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{ + Type: "response.output_item.done", + OutputIndex: &outputIdx, + Item: completedItem, + }) + + // Build final completed response + model := origReq.Model + if providerModel != "" { + model = providerModel + } + finalResult := &api.ProviderResult{ + Model: model, + Text: fullText, + } + completedResp := s.buildResponse(origReq, finalResult, provider.Name(), responseID) + completedResp.Output[0].ID = itemID + + // response.completed + s.sendSSE(w, flusher, &seq, "response.completed", &api.StreamEvent{ + Type: 
"response.completed", + Response: completedResp, + }) + + // Store conversation + if fullText != "" { + assistantMsg := api.Message{ + Role: "assistant", + Content: []api.ContentBlock{{Type: "output_text", Text: fullText}}, + } + allMsgs := append(storeMsgs, assistantMsg) + s.convs.Create(responseID, model, allMsgs) + } +} + +func (s *GatewayServer) sendSSE(w http.ResponseWriter, flusher http.Flusher, seq *int, eventType string, event *api.StreamEvent) { + event.SequenceNumber = *seq + *seq++ + data, err := json.Marshal(event) + if err != nil { + s.logger.Printf("failed to marshal SSE event: %v", err) + return + } + fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, data) + flusher.Flush() +} + +func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.ProviderResult, providerName string, responseID string) *api.Response { + now := time.Now().Unix() + + model := result.Model + if model == "" { + model = req.Model } - s.convs.Create(conversationID, fullReq.Model, messages) + // Build output item + itemID := generateID("msg_") + outputItem := api.OutputItem{ + ID: itemID, + Type: "message", + Status: "completed", + Role: "assistant", + Content: []api.ContentPart{{ + Type: "output_text", + Text: result.Text, + Annotations: []api.Annotation{}, + }}, + } + + // Echo back request params with defaults + tools := req.Tools + if tools == nil { + tools = json.RawMessage(`[]`) + } + toolChoice := req.ToolChoice + if toolChoice == nil { + toolChoice = json.RawMessage(`"auto"`) + } + text := req.Text + if text == nil { + text = json.RawMessage(`{"format":{"type":"text"}}`) + } + truncation := "disabled" + if req.Truncation != nil { + truncation = *req.Truncation + } + temperature := 1.0 + if req.Temperature != nil { + temperature = *req.Temperature + } + topP := 1.0 + if req.TopP != nil { + topP = *req.TopP + } + presencePenalty := 0.0 + if req.PresencePenalty != nil { + presencePenalty = *req.PresencePenalty + } + frequencyPenalty := 0.0 + if req.FrequencyPenalty != nil { + frequencyPenalty = *req.FrequencyPenalty + } + topLogprobs := 0 + if req.TopLogprobs != nil { + topLogprobs = *req.TopLogprobs + } + parallelToolCalls := true + if req.ParallelToolCalls != nil { + parallelToolCalls = *req.ParallelToolCalls + } + store := true + if req.Store != nil { + store = *req.Store + } + background := false + if req.Background != nil { + background = *req.Background + } + serviceTier := "default" + if req.ServiceTier != nil { + serviceTier = *req.ServiceTier + } + var reasoning json.RawMessage + if req.Reasoning != nil { + reasoning = req.Reasoning + } + metadata := req.Metadata + if metadata == nil { + metadata = map[string]string{} + } + + var usage *api.Usage + if result.Text != "" { + usage = &result.Usage + } + + return &api.Response{ + ID: responseID, + Object: "response", + CreatedAt: now, + CompletedAt: &now, + Status: "completed", + IncompleteDetails: nil, + Model: model, + PreviousResponseID: req.PreviousResponseID, + Instructions: req.Instructions, + Output: []api.OutputItem{outputItem}, + Error: nil, + Tools: tools, + ToolChoice: toolChoice, + Truncation: truncation, + ParallelToolCalls: parallelToolCalls, + Text: text, + TopP: topP, + PresencePenalty: presencePenalty, + FrequencyPenalty: frequencyPenalty, + TopLogprobs: topLogprobs, + Temperature: temperature, + Reasoning: reasoning, + Usage: usage, + MaxOutputTokens: req.MaxOutputTokens, + MaxToolCalls: req.MaxToolCalls, + Store: store, + Background: background, + ServiceTier: serviceTier, + Metadata: metadata, + 
SafetyIdentifier: nil, + PromptCacheKey: nil, + Provider: providerName, + } } func (s *GatewayServer) resolveProvider(req *api.ResponseRequest) (providers.Provider, error) { @@ -254,3 +489,8 @@ func (s *GatewayServer) resolveProvider(req *api.ResponseRequest) (providers.Pro } return s.registry.Default(req.Model) } + +func generateID(prefix string) string { + id := strings.ReplaceAll(uuid.NewString(), "-", "") + return prefix + id[:24] +}
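---

Notes (not part of the patch):

A minimal sketch, written as a test inside the internal/api package, of how the new InputUnion and NormalizeInput behave. The payloads are illustrative, not spec fixtures.

package api

import (
	"encoding/json"
	"testing"
)

// Sketch: the polymorphic "input" field accepts both spec forms.
func TestInputUnionForms(t *testing.T) {
	// String form normalizes to a single user message.
	var strReq ResponseRequest
	if err := json.Unmarshal([]byte(`{"model":"gpt-4o","input":"hello"}`), &strReq); err != nil {
		t.Fatal(err)
	}
	if msgs := strReq.NormalizeInput(); len(msgs) != 1 || msgs[0].Role != "user" {
		t.Fatalf("unexpected normalization: %+v", msgs)
	}

	// Item-list form: a user message plus a tool result.
	payload := `{"model":"gpt-4o","input":[
		{"type":"message","role":"user","content":"hi"},
		{"type":"function_call_output","call_id":"call_1","output":"42"}
	]}`
	var listReq ResponseRequest
	if err := json.Unmarshal([]byte(payload), &listReq); err != nil {
		t.Fatal(err)
	}
	msgs := listReq.NormalizeInput()
	// function_call_output becomes an internal "tool" message.
	if len(msgs) != 2 || msgs[1].Role != "tool" || msgs[1].Content[0].Text != "42" {
		t.Fatalf("unexpected normalization: %+v", msgs)
	}
}

Note the fresh ResponseRequest per decode: InputUnion.UnmarshalJSON sets whichever branch matches but never clears the other, so reusing one decoded value across payloads could leave a stale String alongside Items.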
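To illustrate the revised Provider contract (Generate and GenerateStream now take pre-normalized []api.Message and return ProviderResult or ProviderStreamDelta), here is a hypothetical echo provider; the "echo" package and everything in it are made up for illustration, not part of this patch. It mirrors the channel discipline the three real providers follow: the goroutine owns both channels, sends a terminal Done delta last, and aborts on context cancellation.

package echo

import (
	"context"

	"github.com/yourusername/go-llm-gateway/internal/api"
)

// Provider echoes the most recent user message; handy for exercising
// the gateway without a real upstream.
type Provider struct{}

func (p *Provider) Name() string { return "echo" }

func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
	var text string
	for _, m := range messages {
		if m.Role != "user" {
			continue
		}
		text = ""
		for _, b := range m.Content {
			text += b.Text
		}
	}
	return &api.ProviderResult{ID: "echo_1", Model: req.Model, Text: text}, nil
}

func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) {
	deltaChan := make(chan *api.ProviderStreamDelta)
	errChan := make(chan error, 1)
	go func() {
		defer close(deltaChan)
		defer close(errChan)
		result, _ := p.Generate(ctx, messages, req)
		// One text delta, then the terminal Done delta.
		for _, d := range []*api.ProviderStreamDelta{
			{ID: result.ID, Model: result.Model, Text: result.Text},
			{Done: true},
		} {
			select {
			case deltaChan <- d:
			case <-ctx.Done():
				errChan <- ctx.Err()
				return
			}
		}
	}()
	return deltaChan, errChan
}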
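Client-side, the stream can be consumed with nothing more than the "event:"/"data:" framing that sendSSE writes. A rough sketch; the listen address and the POST /v1/responses route are assumptions, since route registration is outside this diff:

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := strings.NewReader(`{"model":"gpt-4o","input":"hello","stream":true}`)
	// Address and path are assumed for illustration.
	resp, err := http.Post("http://localhost:8080/v1/responses", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// sendSSE emits "event: <type>", then "data: <json>", then a blank line.
	scanner := bufio.NewScanner(resp.Body)
	var event string
	for scanner.Scan() {
		line := scanner.Text()
		switch {
		case strings.HasPrefix(line, "event: "):
			event = strings.TrimPrefix(line, "event: ")
		case strings.HasPrefix(line, "data: "):
			fmt.Printf("%s -> %s\n", event, strings.TrimPrefix(line, "data: "))
		}
	}
	if err := scanner.Err(); err != nil {
		panic(err)
	}
}

One practical caveat: bufio.Scanner defaults to a 64 KB line limit, so a client reading large response.completed snapshots may want to raise it with scanner.Buffer before the loop.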