Add OpenAI tool calling support

This commit is contained in:
2026-03-02 15:27:28 +00:00
parent 8ceb831e84
commit 157680bb13
4 changed files with 429 additions and 77 deletions

View File

@@ -96,6 +96,7 @@ type InputItem struct {
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content []ContentBlock `json:"content"` Content []ContentBlock `json:"content"`
CallID string `json:"call_id,omitempty"` // for tool messages
} }
// ContentBlock is a typed content element. // ContentBlock is a typed content element.
@@ -138,6 +139,7 @@ func (r *ResponseRequest) NormalizeInput() []Message {
msgs = append(msgs, Message{ msgs = append(msgs, Message{
Role: "tool", Role: "tool",
Content: []ContentBlock{{Type: "input_text", Text: item.Output}}, Content: []ContentBlock{{Type: "input_text", Text: item.Output}},
CallID: item.CallID,
}) })
} }
} }
@@ -188,11 +190,14 @@ type Response struct {
// OutputItem represents a typed item in the response output. // OutputItem represents a typed item in the response output.
type OutputItem struct { type OutputItem struct {
ID string `json:"id"` ID string `json:"id"`
Type string `json:"type"` Type string `json:"type"`
Status string `json:"status"` Status string `json:"status"`
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
Content []ContentPart `json:"content,omitempty"` Content []ContentPart `json:"content,omitempty"`
CallID string `json:"call_id,omitempty"` // for function_call
Name string `json:"name,omitempty"` // for function_call
Arguments string `json:"arguments,omitempty"` // for function_call
} }
// ContentPart is a content block within an output item. // ContentPart is a content block within an output item.
@@ -259,6 +264,7 @@ type StreamEvent struct {
Part *ContentPart `json:"part,omitempty"` Part *ContentPart `json:"part,omitempty"`
Delta string `json:"delta,omitempty"` Delta string `json:"delta,omitempty"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
Arguments string `json:"arguments,omitempty"` // for function_call_arguments.done
} }
// ============================================================ // ============================================================
@@ -267,19 +273,36 @@ type StreamEvent struct {
// ProviderResult is returned by Provider.Generate. // ProviderResult is returned by Provider.Generate.
type ProviderResult struct { type ProviderResult struct {
ID string ID string
Model string Model string
Text string Text string
Usage Usage Usage Usage
ToolCalls []ToolCall
} }
// ProviderStreamDelta is sent through the stream channel. // ProviderStreamDelta is sent through the stream channel.
type ProviderStreamDelta struct { type ProviderStreamDelta struct {
ID string ID string
Model string Model string
Text string Text string
Done bool Done bool
Usage *Usage Usage *Usage
ToolCallDelta *ToolCallDelta
}
// ToolCall represents a function call from the model.
type ToolCall struct {
ID string
Name string
Arguments string // JSON string
}
// ToolCallDelta represents a streaming chunk of a tool call.
type ToolCallDelta struct {
Index int
ID string
Name string
Arguments string
} }
// ============================================================ // ============================================================

View File

@@ -0,0 +1,120 @@
package openai
import (
"encoding/json"
"fmt"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/openai/openai-go"
"github.com/openai/openai-go/shared"
)
// parseTools converts the Open Responses tool definitions (a raw JSON
// array on the request) into OpenAI ChatCompletion tool parameters.
//
// Each definition is expected to carry "name", and optionally
// "description" and "parameters" (a JSON-Schema object). Returns
// (nil, nil) when the request declares no tools.
func parseTools(req *api.ResponseRequest) ([]openai.ChatCompletionToolParam, error) {
	// len() is safe on a nil slice, so a separate nil check is redundant.
	if len(req.Tools) == 0 {
		return nil, nil
	}
	var toolDefs []map[string]interface{}
	if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
		return nil, fmt.Errorf("unmarshal tools: %w", err)
	}
	tools := make([]openai.ChatCompletionToolParam, 0, len(toolDefs))
	for _, td := range toolDefs {
		name, _ := td["name"].(string)
		// OpenAI rejects tools without a name; fail fast with a clear error
		// instead of forwarding a malformed definition.
		if name == "" {
			return nil, fmt.Errorf("tool definition missing name")
		}
		desc, _ := td["description"].(string)
		params, _ := td["parameters"].(map[string]interface{})

		tool := openai.ChatCompletionToolParam{
			Function: shared.FunctionDefinitionParam{
				Name: name,
			},
		}
		if desc != "" {
			tool.Function.Description = openai.String(desc)
		}
		if params != nil {
			tool.Function.Parameters = shared.FunctionParameters(params)
		}
		tools = append(tools, tool)
	}
	return tools, nil
}
// parseToolChoice converts the Open Responses tool_choice field (raw JSON)
// into the OpenAI union parameter.
//
// Accepted forms:
//   - a string mode: "auto", "none", or "required"
//   - an object naming a specific function, in either the Open Responses
//     shape {"type": "function", "name": "..."} or the nested OpenAI
//     chat-completions shape {"type": "function", "function": {"name": "..."}}
func parseToolChoice(req *api.ResponseRequest) (openai.ChatCompletionToolChoiceOptionUnionParam, error) {
	var result openai.ChatCompletionToolChoiceOptionUnionParam
	// len() is safe on a nil slice, so a separate nil check is redundant.
	if len(req.ToolChoice) == 0 {
		return result, nil
	}
	var choice interface{}
	if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
		return result, fmt.Errorf("unmarshal tool_choice: %w", err)
	}
	switch c := choice.(type) {
	case string:
		// Mode strings ("auto", "none", "required") pass straight through.
		result.OfAuto = openai.String(c)
		return result, nil
	case map[string]interface{}:
		// BUG FIX: Open Responses puts the function name at the top level
		// ({"type":"function","name":"..."}); the previous code only looked
		// at a nested "function" object and so always produced an empty
		// name for the documented shape. Read the top-level name first and
		// fall back to the nested OpenAI shape for compatibility.
		name, _ := c["name"].(string)
		if name == "" {
			if funcObj, ok := c["function"].(map[string]interface{}); ok {
				name, _ = funcObj["name"].(string)
			}
		}
		if name == "" {
			return result, fmt.Errorf("tool_choice object missing function name")
		}
		result.OfChatCompletionNamedToolChoice = &openai.ChatCompletionNamedToolChoiceParam{
			Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
				Name: name,
			},
		}
		return result, nil
	}
	return result, fmt.Errorf("invalid tool_choice format")
}
// extractToolCalls converts the tool calls on an OpenAI chat completion
// message into the provider-agnostic api.ToolCall representation.
// Returns nil when the message carries no tool calls.
func extractToolCalls(message openai.ChatCompletionMessage) []api.ToolCall {
	if len(message.ToolCalls) == 0 {
		return nil
	}
	// Pre-size: the output has exactly one entry per provider tool call.
	toolCalls := make([]api.ToolCall, 0, len(message.ToolCalls))
	for _, tc := range message.ToolCalls {
		toolCalls = append(toolCalls, api.ToolCall{
			ID:   tc.ID,
			Name: tc.Function.Name,
			// Arguments is the provider's raw JSON string, passed through
			// verbatim; it is not parsed or validated here.
			Arguments: tc.Function.Arguments,
		})
	}
	return toolCalls
}
// extractToolCallDelta extracts the tool-call delta from a streaming chunk
// choice, or nil when the chunk carries none.
//
// NOTE(review): only the first entry is surfaced. If the provider ever
// packs multiple tool-call deltas into a single chunk, the extras are
// dropped here — confirm against observed OpenAI stream behavior.
func extractToolCallDelta(choice openai.ChatCompletionChunkChoice) *api.ToolCallDelta {
	if len(choice.Delta.ToolCalls) == 0 {
		return nil
	}
	// The original loop returned unconditionally on its first iteration;
	// index directly to make that intent explicit.
	tc := choice.Delta.ToolCalls[0]
	return &api.ToolCallDelta{
		Index:     int(tc.Index),
		ID:        tc.ID,
		Name:      tc.Function.Name,
		Arguments: tc.Function.Arguments,
	}
}

View File

@@ -91,6 +91,8 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
oaiMessages = append(oaiMessages, openai.SystemMessage(content)) oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "developer": case "developer":
oaiMessages = append(oaiMessages, openai.SystemMessage(content)) oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "tool":
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))
} }
} }
@@ -108,6 +110,29 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
params.TopP = openai.Float(*req.TopP) params.TopP = openai.Float(*req.TopP)
} }
// Add tools if present
if req.Tools != nil && len(req.Tools) > 0 {
tools, err := parseTools(req)
if err != nil {
return nil, fmt.Errorf("parse tools: %w", err)
}
params.Tools = tools
}
// Add tool_choice if present
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
toolChoice, err := parseToolChoice(req)
if err != nil {
return nil, fmt.Errorf("parse tool_choice: %w", err)
}
params.ToolChoice = toolChoice
}
// Add parallel_tool_calls if specified
if req.ParallelToolCalls != nil {
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
}
// Call OpenAI API // Call OpenAI API
resp, err := p.client.Chat.Completions.New(ctx, params) resp, err := p.client.Chat.Completions.New(ctx, params)
if err != nil { if err != nil {
@@ -115,14 +140,20 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
} }
var combinedText string var combinedText string
var toolCalls []api.ToolCall
for _, choice := range resp.Choices { for _, choice := range resp.Choices {
combinedText += choice.Message.Content combinedText += choice.Message.Content
if len(choice.Message.ToolCalls) > 0 {
toolCalls = append(toolCalls, extractToolCalls(choice.Message)...)
}
} }
return &api.ProviderResult{ return &api.ProviderResult{
ID: resp.ID, ID: resp.ID,
Model: resp.Model, Model: resp.Model,
Text: combinedText, Text: combinedText,
ToolCalls: toolCalls,
Usage: api.Usage{ Usage: api.Usage{
InputTokens: int(resp.Usage.PromptTokens), InputTokens: int(resp.Usage.PromptTokens),
OutputTokens: int(resp.Usage.CompletionTokens), OutputTokens: int(resp.Usage.CompletionTokens),
@@ -168,6 +199,8 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
oaiMessages = append(oaiMessages, openai.SystemMessage(content)) oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "developer": case "developer":
oaiMessages = append(oaiMessages, openai.SystemMessage(content)) oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "tool":
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))
} }
} }
@@ -185,6 +218,31 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
params.TopP = openai.Float(*req.TopP) params.TopP = openai.Float(*req.TopP)
} }
// Add tools if present
if req.Tools != nil && len(req.Tools) > 0 {
tools, err := parseTools(req)
if err != nil {
errChan <- fmt.Errorf("parse tools: %w", err)
return
}
params.Tools = tools
}
// Add tool_choice if present
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
toolChoice, err := parseToolChoice(req)
if err != nil {
errChan <- fmt.Errorf("parse tool_choice: %w", err)
return
}
params.ToolChoice = toolChoice
}
// Add parallel_tool_calls if specified
if req.ParallelToolCalls != nil {
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
}
// Create streaming request // Create streaming request
stream := p.client.Chat.Completions.NewStreaming(ctx, params) stream := p.client.Chat.Completions.NewStreaming(ctx, params)
@@ -193,19 +251,35 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
chunk := stream.Current() chunk := stream.Current()
for _, choice := range chunk.Choices { for _, choice := range chunk.Choices {
if choice.Delta.Content == "" { // Handle text content
continue if choice.Delta.Content != "" {
select {
case deltaChan <- &api.ProviderStreamDelta{
ID: chunk.ID,
Model: chunk.Model,
Text: choice.Delta.Content,
}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
} }
select { // Handle tool call deltas
case deltaChan <- &api.ProviderStreamDelta{ if len(choice.Delta.ToolCalls) > 0 {
ID: chunk.ID, delta := extractToolCallDelta(choice)
Model: chunk.Model, if delta != nil {
Text: choice.Delta.Content, select {
}: case deltaChan <- &api.ProviderStreamDelta{
case <-ctx.Done(): ID: chunk.ID,
errChan <- ctx.Err() Model: chunk.Model,
return ToolCallDelta: delta,
}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
}
} }
} }
} }

View File

@@ -224,6 +224,17 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
var streamErr error var streamErr error
var providerModel string var providerModel string
// Track tool calls being built
type toolCallBuilder struct {
itemID string
id string
name string
arguments string
}
toolCallsInProgress := make(map[int]*toolCallBuilder)
nextOutputIdx := 0
textItemAdded := false
loop: loop:
for { for {
select { select {
@@ -234,7 +245,14 @@ loop:
if delta.Model != "" && providerModel == "" { if delta.Model != "" && providerModel == "" {
providerModel = delta.Model providerModel = delta.Model
} }
// Handle text content
if delta.Text != "" { if delta.Text != "" {
// Add text item on first text delta
if !textItemAdded {
textItemAdded = true
nextOutputIdx++
}
fullText += delta.Text fullText += delta.Text
s.sendSSE(w, flusher, &seq, "response.output_text.delta", &api.StreamEvent{ s.sendSSE(w, flusher, &seq, "response.output_text.delta", &api.StreamEvent{
Type: "response.output_text.delta", Type: "response.output_text.delta",
@@ -244,6 +262,53 @@ loop:
Delta: delta.Text, Delta: delta.Text,
}) })
} }
// Handle tool call delta
if delta.ToolCallDelta != nil {
tc := delta.ToolCallDelta
// First chunk for this tool call index
if _, exists := toolCallsInProgress[tc.Index]; !exists {
toolItemID := generateID("item_")
toolOutputIdx := nextOutputIdx
nextOutputIdx++
// Send response.output_item.added
s.sendSSE(w, flusher, &seq, "response.output_item.added", &api.StreamEvent{
Type: "response.output_item.added",
OutputIndex: &toolOutputIdx,
Item: &api.OutputItem{
ID: toolItemID,
Type: "function_call",
Status: "in_progress",
CallID: tc.ID,
Name: tc.Name,
},
})
toolCallsInProgress[tc.Index] = &toolCallBuilder{
itemID: toolItemID,
id: tc.ID,
name: tc.Name,
arguments: "",
}
}
// Send function_call_arguments.delta
if tc.Arguments != "" {
builder := toolCallsInProgress[tc.Index]
builder.arguments += tc.Arguments
toolOutputIdx := outputIdx + 1 + tc.Index
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.delta", &api.StreamEvent{
Type: "response.function_call_arguments.delta",
ItemID: builder.itemID,
OutputIndex: &toolOutputIdx,
Delta: tc.Arguments,
})
}
}
if delta.Done { if delta.Done {
break loop break loop
} }
@@ -277,54 +342,108 @@ loop:
return return
} }
// response.output_text.done // Send done events for text output if text was added
s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{ if textItemAdded && fullText != "" {
Type: "response.output_text.done", // response.output_text.done
ItemID: itemID, s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{
OutputIndex: &outputIdx, Type: "response.output_text.done",
ContentIndex: &contentIdx, ItemID: itemID,
Text: fullText, OutputIndex: &outputIdx,
}) ContentIndex: &contentIdx,
Text: fullText,
})
// response.content_part.done // response.content_part.done
completedPart := &api.ContentPart{ completedPart := &api.ContentPart{
Type: "output_text", Type: "output_text",
Text: fullText, Text: fullText,
Annotations: []api.Annotation{}, Annotations: []api.Annotation{},
} }
s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{ s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{
Type: "response.content_part.done", Type: "response.content_part.done",
ItemID: itemID, ItemID: itemID,
OutputIndex: &outputIdx, OutputIndex: &outputIdx,
ContentIndex: &contentIdx, ContentIndex: &contentIdx,
Part: completedPart, Part: completedPart,
}) })
// response.output_item.done // response.output_item.done
completedItem := &api.OutputItem{ completedItem := &api.OutputItem{
ID: itemID, ID: itemID,
Type: "message", Type: "message",
Status: "completed", Status: "completed",
Role: "assistant", Role: "assistant",
Content: []api.ContentPart{*completedPart}, Content: []api.ContentPart{*completedPart},
}
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
Type: "response.output_item.done",
OutputIndex: &outputIdx,
Item: completedItem,
})
}
// Send done events for each tool call
for idx, builder := range toolCallsInProgress {
toolOutputIdx := outputIdx + 1 + idx
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.done", &api.StreamEvent{
Type: "response.function_call_arguments.done",
ItemID: builder.itemID,
OutputIndex: &toolOutputIdx,
Arguments: builder.arguments,
})
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
Type: "response.output_item.done",
OutputIndex: &toolOutputIdx,
Item: &api.OutputItem{
ID: builder.itemID,
Type: "function_call",
Status: "completed",
CallID: builder.id,
Name: builder.name,
Arguments: builder.arguments,
},
})
} }
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
Type: "response.output_item.done",
OutputIndex: &outputIdx,
Item: completedItem,
})
// Build final completed response // Build final completed response
model := origReq.Model model := origReq.Model
if providerModel != "" { if providerModel != "" {
model = providerModel model = providerModel
} }
// Collect tool calls for result
var toolCalls []api.ToolCall
for _, builder := range toolCallsInProgress {
toolCalls = append(toolCalls, api.ToolCall{
ID: builder.id,
Name: builder.name,
Arguments: builder.arguments,
})
}
finalResult := &api.ProviderResult{ finalResult := &api.ProviderResult{
Model: model, Model: model,
Text: fullText, Text: fullText,
ToolCalls: toolCalls,
} }
completedResp := s.buildResponse(origReq, finalResult, provider.Name(), responseID) completedResp := s.buildResponse(origReq, finalResult, provider.Name(), responseID)
completedResp.Output[0].ID = itemID
// Update item IDs to match what we sent during streaming
if textItemAdded && len(completedResp.Output) > 0 {
completedResp.Output[0].ID = itemID
}
for idx, builder := range toolCallsInProgress {
// Find the corresponding output item
for i := range completedResp.Output {
if completedResp.Output[i].Type == "function_call" && completedResp.Output[i].CallID == builder.id {
completedResp.Output[i].ID = builder.itemID
break
}
}
_ = idx // unused
}
// response.completed // response.completed
s.sendSSE(w, flusher, &seq, "response.completed", &api.StreamEvent{ s.sendSSE(w, flusher, &seq, "response.completed", &api.StreamEvent{
@@ -363,18 +482,34 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
model = req.Model model = req.Model
} }
// Build output item // Build output items array
itemID := generateID("msg_") outputItems := []api.OutputItem{}
outputItem := api.OutputItem{
ID: itemID, // Add message item if there's text
Type: "message", if result.Text != "" {
Status: "completed", outputItems = append(outputItems, api.OutputItem{
Role: "assistant", ID: generateID("msg_"),
Content: []api.ContentPart{{ Type: "message",
Type: "output_text", Status: "completed",
Text: result.Text, Role: "assistant",
Annotations: []api.Annotation{}, Content: []api.ContentPart{{
}}, Type: "output_text",
Text: result.Text,
Annotations: []api.Annotation{},
}},
})
}
// Add function_call items
for _, tc := range result.ToolCalls {
outputItems = append(outputItems, api.OutputItem{
ID: generateID("item_"),
Type: "function_call",
Status: "completed",
CallID: tc.ID,
Name: tc.Name,
Arguments: tc.Arguments,
})
} }
// Echo back request params with defaults // Echo back request params with defaults
@@ -454,7 +589,7 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
Model: model, Model: model,
PreviousResponseID: req.PreviousResponseID, PreviousResponseID: req.PreviousResponseID,
Instructions: req.Instructions, Instructions: req.Instructions,
Output: []api.OutputItem{outputItem}, Output: outputItems,
Error: nil, Error: nil,
Tools: tools, Tools: tools,
ToolChoice: toolChoice, ToolChoice: toolChoice,