Add OpenAI tool calling support

This commit is contained in:
2026-03-02 15:27:28 +00:00
parent 8ceb831e84
commit 157680bb13
4 changed files with 429 additions and 77 deletions

View File

@@ -224,6 +224,17 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
var streamErr error
var providerModel string
// Track tool calls being built
type toolCallBuilder struct {
itemID string
id string
name string
arguments string
}
toolCallsInProgress := make(map[int]*toolCallBuilder)
nextOutputIdx := 0
textItemAdded := false
loop:
for {
select {
@@ -234,7 +245,14 @@ loop:
if delta.Model != "" && providerModel == "" {
providerModel = delta.Model
}
// Handle text content
if delta.Text != "" {
// Add text item on first text delta
if !textItemAdded {
textItemAdded = true
nextOutputIdx++
}
fullText += delta.Text
s.sendSSE(w, flusher, &seq, "response.output_text.delta", &api.StreamEvent{
Type: "response.output_text.delta",
@@ -244,6 +262,53 @@ loop:
Delta: delta.Text,
})
}
// Handle tool call delta
if delta.ToolCallDelta != nil {
tc := delta.ToolCallDelta
// First chunk for this tool call index
if _, exists := toolCallsInProgress[tc.Index]; !exists {
toolItemID := generateID("item_")
toolOutputIdx := nextOutputIdx
nextOutputIdx++
// Send response.output_item.added
s.sendSSE(w, flusher, &seq, "response.output_item.added", &api.StreamEvent{
Type: "response.output_item.added",
OutputIndex: &toolOutputIdx,
Item: &api.OutputItem{
ID: toolItemID,
Type: "function_call",
Status: "in_progress",
CallID: tc.ID,
Name: tc.Name,
},
})
toolCallsInProgress[tc.Index] = &toolCallBuilder{
itemID: toolItemID,
id: tc.ID,
name: tc.Name,
arguments: "",
}
}
// Send function_call_arguments.delta
if tc.Arguments != "" {
builder := toolCallsInProgress[tc.Index]
builder.arguments += tc.Arguments
toolOutputIdx := outputIdx + 1 + tc.Index
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.delta", &api.StreamEvent{
Type: "response.function_call_arguments.delta",
ItemID: builder.itemID,
OutputIndex: &toolOutputIdx,
Delta: tc.Arguments,
})
}
}
if delta.Done {
break loop
}
@@ -277,54 +342,108 @@ loop:
return
}
// response.output_text.done
s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{
Type: "response.output_text.done",
ItemID: itemID,
OutputIndex: &outputIdx,
ContentIndex: &contentIdx,
Text: fullText,
})
// Send done events for text output if text was added
if textItemAdded && fullText != "" {
// response.output_text.done
s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{
Type: "response.output_text.done",
ItemID: itemID,
OutputIndex: &outputIdx,
ContentIndex: &contentIdx,
Text: fullText,
})
// response.content_part.done
completedPart := &api.ContentPart{
Type: "output_text",
Text: fullText,
Annotations: []api.Annotation{},
}
s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{
Type: "response.content_part.done",
ItemID: itemID,
OutputIndex: &outputIdx,
ContentIndex: &contentIdx,
Part: completedPart,
})
// response.content_part.done
completedPart := &api.ContentPart{
Type: "output_text",
Text: fullText,
Annotations: []api.Annotation{},
}
s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{
Type: "response.content_part.done",
ItemID: itemID,
OutputIndex: &outputIdx,
ContentIndex: &contentIdx,
Part: completedPart,
})
// response.output_item.done
completedItem := &api.OutputItem{
ID: itemID,
Type: "message",
Status: "completed",
Role: "assistant",
Content: []api.ContentPart{*completedPart},
// response.output_item.done
completedItem := &api.OutputItem{
ID: itemID,
Type: "message",
Status: "completed",
Role: "assistant",
Content: []api.ContentPart{*completedPart},
}
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
Type: "response.output_item.done",
OutputIndex: &outputIdx,
Item: completedItem,
})
}
// Send done events for each tool call
for idx, builder := range toolCallsInProgress {
toolOutputIdx := outputIdx + 1 + idx
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.done", &api.StreamEvent{
Type: "response.function_call_arguments.done",
ItemID: builder.itemID,
OutputIndex: &toolOutputIdx,
Arguments: builder.arguments,
})
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
Type: "response.output_item.done",
OutputIndex: &toolOutputIdx,
Item: &api.OutputItem{
ID: builder.itemID,
Type: "function_call",
Status: "completed",
CallID: builder.id,
Name: builder.name,
Arguments: builder.arguments,
},
})
}
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
Type: "response.output_item.done",
OutputIndex: &outputIdx,
Item: completedItem,
})
// Build final completed response
model := origReq.Model
if providerModel != "" {
model = providerModel
}
// Collect tool calls for result
var toolCalls []api.ToolCall
for _, builder := range toolCallsInProgress {
toolCalls = append(toolCalls, api.ToolCall{
ID: builder.id,
Name: builder.name,
Arguments: builder.arguments,
})
}
finalResult := &api.ProviderResult{
Model: model,
Text: fullText,
Model: model,
Text: fullText,
ToolCalls: toolCalls,
}
completedResp := s.buildResponse(origReq, finalResult, provider.Name(), responseID)
completedResp.Output[0].ID = itemID
// Update item IDs to match what we sent during streaming
if textItemAdded && len(completedResp.Output) > 0 {
completedResp.Output[0].ID = itemID
}
for idx, builder := range toolCallsInProgress {
// Find the corresponding output item
for i := range completedResp.Output {
if completedResp.Output[i].Type == "function_call" && completedResp.Output[i].CallID == builder.id {
completedResp.Output[i].ID = builder.itemID
break
}
}
_ = idx // unused
}
// response.completed
s.sendSSE(w, flusher, &seq, "response.completed", &api.StreamEvent{
@@ -363,18 +482,34 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
model = req.Model
}
// Build output item
itemID := generateID("msg_")
outputItem := api.OutputItem{
ID: itemID,
Type: "message",
Status: "completed",
Role: "assistant",
Content: []api.ContentPart{{
Type: "output_text",
Text: result.Text,
Annotations: []api.Annotation{},
}},
// Build output items array
outputItems := []api.OutputItem{}
// Add message item if there's text
if result.Text != "" {
outputItems = append(outputItems, api.OutputItem{
ID: generateID("msg_"),
Type: "message",
Status: "completed",
Role: "assistant",
Content: []api.ContentPart{{
Type: "output_text",
Text: result.Text,
Annotations: []api.Annotation{},
}},
})
}
// Add function_call items
for _, tc := range result.ToolCalls {
outputItems = append(outputItems, api.OutputItem{
ID: generateID("item_"),
Type: "function_call",
Status: "completed",
CallID: tc.ID,
Name: tc.Name,
Arguments: tc.Arguments,
})
}
// Echo back request params with defaults
@@ -454,7 +589,7 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
Model: model,
PreviousResponseID: req.PreviousResponseID,
Instructions: req.Instructions,
Output: []api.OutputItem{outputItem},
Output: outputItems,
Error: nil,
Tools: tools,
ToolChoice: toolChoice,