From 157680bb13e161944cf5a381fffb187a88240baa Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Mon, 2 Mar 2026 15:27:28 +0000 Subject: [PATCH] Add OpenAI tool calling support --- internal/api/types.go | 51 ++++-- internal/providers/openai/convert.go | 120 ++++++++++++++ internal/providers/openai/openai.go | 102 ++++++++++-- internal/server/server.go | 233 +++++++++++++++++++++------ 4 files changed, 429 insertions(+), 77 deletions(-) create mode 100644 internal/providers/openai/convert.go diff --git a/internal/api/types.go b/internal/api/types.go index b0b3565..0c39d34 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -96,6 +96,7 @@ type InputItem struct { type Message struct { Role string `json:"role"` Content []ContentBlock `json:"content"` + CallID string `json:"call_id,omitempty"` // for tool messages } // ContentBlock is a typed content element. @@ -138,6 +139,7 @@ func (r *ResponseRequest) NormalizeInput() []Message { msgs = append(msgs, Message{ Role: "tool", Content: []ContentBlock{{Type: "input_text", Text: item.Output}}, + CallID: item.CallID, }) } } @@ -188,11 +190,14 @@ type Response struct { // OutputItem represents a typed item in the response output. type OutputItem struct { - ID string `json:"id"` - Type string `json:"type"` - Status string `json:"status"` - Role string `json:"role,omitempty"` - Content []ContentPart `json:"content,omitempty"` + ID string `json:"id"` + Type string `json:"type"` + Status string `json:"status"` + Role string `json:"role,omitempty"` + Content []ContentPart `json:"content,omitempty"` + CallID string `json:"call_id,omitempty"` // for function_call + Name string `json:"name,omitempty"` // for function_call + Arguments string `json:"arguments,omitempty"` // for function_call } // ContentPart is a content block within an output item. 
@@ -259,6 +264,7 @@ type StreamEvent struct { Part *ContentPart `json:"part,omitempty"` Delta string `json:"delta,omitempty"` Text string `json:"text,omitempty"` + Arguments string `json:"arguments,omitempty"` // for function_call_arguments.done } // ============================================================ @@ -267,19 +273,36 @@ type StreamEvent struct { // ProviderResult is returned by Provider.Generate. type ProviderResult struct { - ID string - Model string - Text string - Usage Usage + ID string + Model string + Text string + Usage Usage + ToolCalls []ToolCall } // ProviderStreamDelta is sent through the stream channel. type ProviderStreamDelta struct { - ID string - Model string - Text string - Done bool - Usage *Usage + ID string + Model string + Text string + Done bool + Usage *Usage + ToolCallDelta *ToolCallDelta +} + +// ToolCall represents a function call from the model. +type ToolCall struct { + ID string + Name string + Arguments string // JSON string +} + +// ToolCallDelta represents a streaming chunk of a tool call. 
+type ToolCallDelta struct {
+	Index     int
+	ID        string
+	Name      string
+	Arguments string
 }
 
 // ============================================================
diff --git a/internal/providers/openai/convert.go b/internal/providers/openai/convert.go
new file mode 100644
index 0000000..e805425
--- /dev/null
+++ b/internal/providers/openai/convert.go
@@ -0,0 +1,148 @@
+package openai
+
+import (
+	"encoding/json"
+	"fmt"
+
+	"github.com/ajac-zero/latticelm/internal/api"
+	"github.com/openai/openai-go"
+	"github.com/openai/openai-go/shared"
+)
+
+// parseTools converts the raw JSON tool definitions in req.Tools into OpenAI
+// Chat Completions tool parameters.
+//
+// Each definition is decoded as a JSON object and its top-level "name",
+// "description" and "parameters" keys are copied over; a key that is absent or
+// has an unexpected type is silently treated as empty. It returns (nil, nil)
+// when the request carries no tools, and an error when req.Tools cannot be
+// decoded as a JSON array of objects.
+func parseTools(req *api.ResponseRequest) ([]openai.ChatCompletionToolParam, error) {
+	if len(req.Tools) == 0 {
+		return nil, nil
+	}
+
+	var toolDefs []map[string]interface{}
+	if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
+		return nil, fmt.Errorf("unmarshal tools: %w", err)
+	}
+
+	var tools []openai.ChatCompletionToolParam
+	for _, td := range toolDefs {
+		// A failed type assertion is ignored here; the variable then holds
+		// its zero value ("" or a nil map).
+		name, _ := td["name"].(string)
+		desc, _ := td["description"].(string)
+		params, _ := td["parameters"].(map[string]interface{})
+
+		tool := openai.ChatCompletionToolParam{
+			Function: shared.FunctionDefinitionParam{
+				Name: name,
+			},
+		}
+
+		// Only set description and parameters when the definition has them.
+		if desc != "" {
+			tool.Function.Description = openai.String(desc)
+		}
+		if params != nil {
+			tool.Function.Parameters = shared.FunctionParameters(params)
+		}
+
+		tools = append(tools, tool)
+	}
+
+	return tools, nil
+}
+
+// parseToolChoice converts the raw JSON tool_choice in req.ToolChoice into the
+// OpenAI Chat Completions tool_choice union.
+//
+// A JSON string (for example "auto", "none" or "required") is forwarded
+// unchanged. A JSON object selects a single function by name, read from the
+// top-level "name" key ({"type": "function", "name": "..."}) or, failing that,
+// from the nested form ({"function": {"name": "..."}}). It returns the zero
+// union when the request carries no tool_choice, and an error for undecodable
+// JSON, for an object without a function name, and for any other JSON type.
+func parseToolChoice(req *api.ResponseRequest) (openai.ChatCompletionToolChoiceOptionUnionParam, error) {
+	var result openai.ChatCompletionToolChoiceOptionUnionParam
+
+	if len(req.ToolChoice) == 0 {
+		return result, nil
+	}
+
+	var choice interface{}
+	if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
+		return result, fmt.Errorf("unmarshal tool_choice: %w", err)
+	}
+
+	switch v := choice.(type) {
+	case string:
+		// The string is not validated here; it is passed on as given.
+		result.OfAuto = openai.String(v)
+		return result, nil
+
+	case map[string]interface{}:
+		name, _ := v["name"].(string)
+		if name == "" {
+			// Fall back to the nested form. Indexing a nil map is safe in Go,
+			// so a missing "function" key simply leaves name empty.
+			fn, _ := v["function"].(map[string]interface{})
+			name, _ = fn["name"].(string)
+		}
+		if name == "" {
+			// Do not build a named tool choice that names no function.
+			return result, fmt.Errorf("tool_choice: missing function name")
+		}
+
+		result.OfChatCompletionNamedToolChoice = &openai.ChatCompletionNamedToolChoiceParam{
+			Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
+				Name: name,
+			},
+		}
+		return result, nil
+	}
+
+	return result, fmt.Errorf("invalid tool_choice format")
+}
+
+// extractToolCalls copies the ID, function name and raw argument string of
+// every tool call on a completed message into api.ToolCall values. It returns
+// nil when the message has no tool calls.
+func extractToolCalls(message openai.ChatCompletionMessage) []api.ToolCall {
+	if len(message.ToolCalls) == 0 {
+		return nil
+	}
+
+	toolCalls := make([]api.ToolCall, 0, len(message.ToolCalls))
+	for _, tc := range message.ToolCalls {
+		toolCalls = append(toolCalls, api.ToolCall{
+			ID:        tc.ID,
+			Name:      tc.Function.Name,
+			Arguments: tc.Function.Arguments,
+		})
+	}
+	return toolCalls
+}
+
+// extractToolCallDelta converts the first tool-call fragment of a streaming
+// chunk choice into an api.ToolCallDelta. It returns nil when the choice
+// carries no tool-call fragments.
+//
+// NOTE(review): only choice.Delta.ToolCalls[0] is forwarded; further fragments
+// in the same chunk are dropped. Forwarding all of them needs a slice return
+// type, which has to change together with the caller.
+func extractToolCallDelta(choice openai.ChatCompletionChunkChoice) *api.ToolCallDelta {
+	if len(choice.Delta.ToolCalls) == 0 {
+		return nil
+	}
+
+	// OpenAI sends tool calls with index in the delta.
+	tc := choice.Delta.ToolCalls[0]
+	return &api.ToolCallDelta{
+		Index:     int(tc.Index),
+		ID:        tc.ID,
+		Name:      tc.Function.Name,
+		Arguments: tc.Function.Arguments,
+	}
+}
diff --git a/internal/providers/openai/openai.go b/internal/providers/openai/openai.go
index 61e3deb..98fcb69 100644
--- a/internal/providers/openai/openai.go
+++ b/internal/providers/openai/openai.go
@@ -91,6 +91,8 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
 			oaiMessages = append(oaiMessages, openai.SystemMessage(content))
 		case "developer":
oaiMessages = append(oaiMessages, openai.SystemMessage(content)) + case "tool": + oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID)) } } @@ -108,6 +110,29 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap params.TopP = openai.Float(*req.TopP) } + // Add tools if present + if req.Tools != nil && len(req.Tools) > 0 { + tools, err := parseTools(req) + if err != nil { + return nil, fmt.Errorf("parse tools: %w", err) + } + params.Tools = tools + } + + // Add tool_choice if present + if req.ToolChoice != nil && len(req.ToolChoice) > 0 { + toolChoice, err := parseToolChoice(req) + if err != nil { + return nil, fmt.Errorf("parse tool_choice: %w", err) + } + params.ToolChoice = toolChoice + } + + // Add parallel_tool_calls if specified + if req.ParallelToolCalls != nil { + params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls) + } + // Call OpenAI API resp, err := p.client.Chat.Completions.New(ctx, params) if err != nil { @@ -115,14 +140,20 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap } var combinedText string + var toolCalls []api.ToolCall + for _, choice := range resp.Choices { combinedText += choice.Message.Content + if len(choice.Message.ToolCalls) > 0 { + toolCalls = append(toolCalls, extractToolCalls(choice.Message)...) 
+ } } return &api.ProviderResult{ - ID: resp.ID, - Model: resp.Model, - Text: combinedText, + ID: resp.ID, + Model: resp.Model, + Text: combinedText, + ToolCalls: toolCalls, Usage: api.Usage{ InputTokens: int(resp.Usage.PromptTokens), OutputTokens: int(resp.Usage.CompletionTokens), @@ -168,6 +199,8 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r oaiMessages = append(oaiMessages, openai.SystemMessage(content)) case "developer": oaiMessages = append(oaiMessages, openai.SystemMessage(content)) + case "tool": + oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID)) } } @@ -185,6 +218,31 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r params.TopP = openai.Float(*req.TopP) } + // Add tools if present + if req.Tools != nil && len(req.Tools) > 0 { + tools, err := parseTools(req) + if err != nil { + errChan <- fmt.Errorf("parse tools: %w", err) + return + } + params.Tools = tools + } + + // Add tool_choice if present + if req.ToolChoice != nil && len(req.ToolChoice) > 0 { + toolChoice, err := parseToolChoice(req) + if err != nil { + errChan <- fmt.Errorf("parse tool_choice: %w", err) + return + } + params.ToolChoice = toolChoice + } + + // Add parallel_tool_calls if specified + if req.ParallelToolCalls != nil { + params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls) + } + // Create streaming request stream := p.client.Chat.Completions.NewStreaming(ctx, params) @@ -193,19 +251,35 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r chunk := stream.Current() for _, choice := range chunk.Choices { - if choice.Delta.Content == "" { - continue + // Handle text content + if choice.Delta.Content != "" { + select { + case deltaChan <- &api.ProviderStreamDelta{ + ID: chunk.ID, + Model: chunk.Model, + Text: choice.Delta.Content, + }: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } } - select { - case deltaChan <- &api.ProviderStreamDelta{ - ID: 
chunk.ID, - Model: chunk.Model, - Text: choice.Delta.Content, - }: - case <-ctx.Done(): - errChan <- ctx.Err() - return + // Handle tool call deltas + if len(choice.Delta.ToolCalls) > 0 { + delta := extractToolCallDelta(choice) + if delta != nil { + select { + case deltaChan <- &api.ProviderStreamDelta{ + ID: chunk.ID, + Model: chunk.Model, + ToolCallDelta: delta, + }: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + } } } } diff --git a/internal/server/server.go b/internal/server/server.go index e3d79e7..b581201 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -224,6 +224,17 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R var streamErr error var providerModel string + // Track tool calls being built + type toolCallBuilder struct { + itemID string + id string + name string + arguments string + } + toolCallsInProgress := make(map[int]*toolCallBuilder) + nextOutputIdx := 0 + textItemAdded := false + loop: for { select { @@ -234,7 +245,14 @@ loop: if delta.Model != "" && providerModel == "" { providerModel = delta.Model } + + // Handle text content if delta.Text != "" { + // Add text item on first text delta + if !textItemAdded { + textItemAdded = true + nextOutputIdx++ + } fullText += delta.Text s.sendSSE(w, flusher, &seq, "response.output_text.delta", &api.StreamEvent{ Type: "response.output_text.delta", @@ -244,6 +262,53 @@ loop: Delta: delta.Text, }) } + + // Handle tool call delta + if delta.ToolCallDelta != nil { + tc := delta.ToolCallDelta + + // First chunk for this tool call index + if _, exists := toolCallsInProgress[tc.Index]; !exists { + toolItemID := generateID("item_") + toolOutputIdx := nextOutputIdx + nextOutputIdx++ + + // Send response.output_item.added + s.sendSSE(w, flusher, &seq, "response.output_item.added", &api.StreamEvent{ + Type: "response.output_item.added", + OutputIndex: &toolOutputIdx, + Item: &api.OutputItem{ + ID: toolItemID, + Type: "function_call", + Status: 
"in_progress", + CallID: tc.ID, + Name: tc.Name, + }, + }) + + toolCallsInProgress[tc.Index] = &toolCallBuilder{ + itemID: toolItemID, + id: tc.ID, + name: tc.Name, + arguments: "", + } + } + + // Send function_call_arguments.delta + if tc.Arguments != "" { + builder := toolCallsInProgress[tc.Index] + builder.arguments += tc.Arguments + toolOutputIdx := outputIdx + 1 + tc.Index + + s.sendSSE(w, flusher, &seq, "response.function_call_arguments.delta", &api.StreamEvent{ + Type: "response.function_call_arguments.delta", + ItemID: builder.itemID, + OutputIndex: &toolOutputIdx, + Delta: tc.Arguments, + }) + } + } + if delta.Done { break loop } @@ -277,54 +342,108 @@ loop: return } - // response.output_text.done - s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{ - Type: "response.output_text.done", - ItemID: itemID, - OutputIndex: &outputIdx, - ContentIndex: &contentIdx, - Text: fullText, - }) + // Send done events for text output if text was added + if textItemAdded && fullText != "" { + // response.output_text.done + s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{ + Type: "response.output_text.done", + ItemID: itemID, + OutputIndex: &outputIdx, + ContentIndex: &contentIdx, + Text: fullText, + }) - // response.content_part.done - completedPart := &api.ContentPart{ - Type: "output_text", - Text: fullText, - Annotations: []api.Annotation{}, - } - s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{ - Type: "response.content_part.done", - ItemID: itemID, - OutputIndex: &outputIdx, - ContentIndex: &contentIdx, - Part: completedPart, - }) + // response.content_part.done + completedPart := &api.ContentPart{ + Type: "output_text", + Text: fullText, + Annotations: []api.Annotation{}, + } + s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{ + Type: "response.content_part.done", + ItemID: itemID, + OutputIndex: &outputIdx, + ContentIndex: &contentIdx, + Part: completedPart, + }) - 
// response.output_item.done - completedItem := &api.OutputItem{ - ID: itemID, - Type: "message", - Status: "completed", - Role: "assistant", - Content: []api.ContentPart{*completedPart}, + // response.output_item.done + completedItem := &api.OutputItem{ + ID: itemID, + Type: "message", + Status: "completed", + Role: "assistant", + Content: []api.ContentPart{*completedPart}, + } + s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{ + Type: "response.output_item.done", + OutputIndex: &outputIdx, + Item: completedItem, + }) + } + + // Send done events for each tool call + for idx, builder := range toolCallsInProgress { + toolOutputIdx := outputIdx + 1 + idx + + s.sendSSE(w, flusher, &seq, "response.function_call_arguments.done", &api.StreamEvent{ + Type: "response.function_call_arguments.done", + ItemID: builder.itemID, + OutputIndex: &toolOutputIdx, + Arguments: builder.arguments, + }) + + s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{ + Type: "response.output_item.done", + OutputIndex: &toolOutputIdx, + Item: &api.OutputItem{ + ID: builder.itemID, + Type: "function_call", + Status: "completed", + CallID: builder.id, + Name: builder.name, + Arguments: builder.arguments, + }, + }) } - s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{ - Type: "response.output_item.done", - OutputIndex: &outputIdx, - Item: completedItem, - }) // Build final completed response model := origReq.Model if providerModel != "" { model = providerModel } + + // Collect tool calls for result + var toolCalls []api.ToolCall + for _, builder := range toolCallsInProgress { + toolCalls = append(toolCalls, api.ToolCall{ + ID: builder.id, + Name: builder.name, + Arguments: builder.arguments, + }) + } + finalResult := &api.ProviderResult{ - Model: model, - Text: fullText, + Model: model, + Text: fullText, + ToolCalls: toolCalls, } completedResp := s.buildResponse(origReq, finalResult, provider.Name(), responseID) - 
completedResp.Output[0].ID = itemID + + // Update item IDs to match what we sent during streaming + if textItemAdded && len(completedResp.Output) > 0 { + completedResp.Output[0].ID = itemID + } + for idx, builder := range toolCallsInProgress { + // Find the corresponding output item + for i := range completedResp.Output { + if completedResp.Output[i].Type == "function_call" && completedResp.Output[i].CallID == builder.id { + completedResp.Output[i].ID = builder.itemID + break + } + } + _ = idx // unused + } // response.completed s.sendSSE(w, flusher, &seq, "response.completed", &api.StreamEvent{ @@ -363,18 +482,34 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov model = req.Model } - // Build output item - itemID := generateID("msg_") - outputItem := api.OutputItem{ - ID: itemID, - Type: "message", - Status: "completed", - Role: "assistant", - Content: []api.ContentPart{{ - Type: "output_text", - Text: result.Text, - Annotations: []api.Annotation{}, - }}, + // Build output items array + outputItems := []api.OutputItem{} + + // Add message item if there's text + if result.Text != "" { + outputItems = append(outputItems, api.OutputItem{ + ID: generateID("msg_"), + Type: "message", + Status: "completed", + Role: "assistant", + Content: []api.ContentPart{{ + Type: "output_text", + Text: result.Text, + Annotations: []api.Annotation{}, + }}, + }) + } + + // Add function_call items + for _, tc := range result.ToolCalls { + outputItems = append(outputItems, api.OutputItem{ + ID: generateID("item_"), + Type: "function_call", + Status: "completed", + CallID: tc.ID, + Name: tc.Name, + Arguments: tc.Arguments, + }) } // Echo back request params with defaults @@ -454,7 +589,7 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov Model: model, PreviousResponseID: req.PreviousResponseID, Instructions: req.Instructions, - Output: []api.OutputItem{outputItem}, + Output: outputItems, Error: nil, Tools: tools, 
ToolChoice: toolChoice,