Add OpenAI tool calling support
This commit is contained in:
@@ -224,6 +224,17 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
|
||||
var streamErr error
|
||||
var providerModel string
|
||||
|
||||
// Track tool calls being built
|
||||
type toolCallBuilder struct {
|
||||
itemID string
|
||||
id string
|
||||
name string
|
||||
arguments string
|
||||
}
|
||||
toolCallsInProgress := make(map[int]*toolCallBuilder)
|
||||
nextOutputIdx := 0
|
||||
textItemAdded := false
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
@@ -234,7 +245,14 @@ loop:
|
||||
if delta.Model != "" && providerModel == "" {
|
||||
providerModel = delta.Model
|
||||
}
|
||||
|
||||
// Handle text content
|
||||
if delta.Text != "" {
|
||||
// Add text item on first text delta
|
||||
if !textItemAdded {
|
||||
textItemAdded = true
|
||||
nextOutputIdx++
|
||||
}
|
||||
fullText += delta.Text
|
||||
s.sendSSE(w, flusher, &seq, "response.output_text.delta", &api.StreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
@@ -244,6 +262,53 @@ loop:
|
||||
Delta: delta.Text,
|
||||
})
|
||||
}
|
||||
|
||||
// Handle tool call delta
|
||||
if delta.ToolCallDelta != nil {
|
||||
tc := delta.ToolCallDelta
|
||||
|
||||
// First chunk for this tool call index
|
||||
if _, exists := toolCallsInProgress[tc.Index]; !exists {
|
||||
toolItemID := generateID("item_")
|
||||
toolOutputIdx := nextOutputIdx
|
||||
nextOutputIdx++
|
||||
|
||||
// Send response.output_item.added
|
||||
s.sendSSE(w, flusher, &seq, "response.output_item.added", &api.StreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: &toolOutputIdx,
|
||||
Item: &api.OutputItem{
|
||||
ID: toolItemID,
|
||||
Type: "function_call",
|
||||
Status: "in_progress",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Name,
|
||||
},
|
||||
})
|
||||
|
||||
toolCallsInProgress[tc.Index] = &toolCallBuilder{
|
||||
itemID: toolItemID,
|
||||
id: tc.ID,
|
||||
name: tc.Name,
|
||||
arguments: "",
|
||||
}
|
||||
}
|
||||
|
||||
// Send function_call_arguments.delta
|
||||
if tc.Arguments != "" {
|
||||
builder := toolCallsInProgress[tc.Index]
|
||||
builder.arguments += tc.Arguments
|
||||
toolOutputIdx := outputIdx + 1 + tc.Index
|
||||
|
||||
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.delta", &api.StreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
ItemID: builder.itemID,
|
||||
OutputIndex: &toolOutputIdx,
|
||||
Delta: tc.Arguments,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if delta.Done {
|
||||
break loop
|
||||
}
|
||||
@@ -277,54 +342,108 @@ loop:
|
||||
return
|
||||
}
|
||||
|
||||
// response.output_text.done
|
||||
s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{
|
||||
Type: "response.output_text.done",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Text: fullText,
|
||||
})
|
||||
// Send done events for text output if text was added
|
||||
if textItemAdded && fullText != "" {
|
||||
// response.output_text.done
|
||||
s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{
|
||||
Type: "response.output_text.done",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Text: fullText,
|
||||
})
|
||||
|
||||
// response.content_part.done
|
||||
completedPart := &api.ContentPart{
|
||||
Type: "output_text",
|
||||
Text: fullText,
|
||||
Annotations: []api.Annotation{},
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Part: completedPart,
|
||||
})
|
||||
// response.content_part.done
|
||||
completedPart := &api.ContentPart{
|
||||
Type: "output_text",
|
||||
Text: fullText,
|
||||
Annotations: []api.Annotation{},
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Part: completedPart,
|
||||
})
|
||||
|
||||
// response.output_item.done
|
||||
completedItem := &api.OutputItem{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{*completedPart},
|
||||
// response.output_item.done
|
||||
completedItem := &api.OutputItem{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{*completedPart},
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
OutputIndex: &outputIdx,
|
||||
Item: completedItem,
|
||||
})
|
||||
}
|
||||
|
||||
// Send done events for each tool call
|
||||
for idx, builder := range toolCallsInProgress {
|
||||
toolOutputIdx := outputIdx + 1 + idx
|
||||
|
||||
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.done", &api.StreamEvent{
|
||||
Type: "response.function_call_arguments.done",
|
||||
ItemID: builder.itemID,
|
||||
OutputIndex: &toolOutputIdx,
|
||||
Arguments: builder.arguments,
|
||||
})
|
||||
|
||||
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
OutputIndex: &toolOutputIdx,
|
||||
Item: &api.OutputItem{
|
||||
ID: builder.itemID,
|
||||
Type: "function_call",
|
||||
Status: "completed",
|
||||
CallID: builder.id,
|
||||
Name: builder.name,
|
||||
Arguments: builder.arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
OutputIndex: &outputIdx,
|
||||
Item: completedItem,
|
||||
})
|
||||
|
||||
// Build final completed response
|
||||
model := origReq.Model
|
||||
if providerModel != "" {
|
||||
model = providerModel
|
||||
}
|
||||
|
||||
// Collect tool calls for result
|
||||
var toolCalls []api.ToolCall
|
||||
for _, builder := range toolCallsInProgress {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: builder.id,
|
||||
Name: builder.name,
|
||||
Arguments: builder.arguments,
|
||||
})
|
||||
}
|
||||
|
||||
finalResult := &api.ProviderResult{
|
||||
Model: model,
|
||||
Text: fullText,
|
||||
Model: model,
|
||||
Text: fullText,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
completedResp := s.buildResponse(origReq, finalResult, provider.Name(), responseID)
|
||||
completedResp.Output[0].ID = itemID
|
||||
|
||||
// Update item IDs to match what we sent during streaming
|
||||
if textItemAdded && len(completedResp.Output) > 0 {
|
||||
completedResp.Output[0].ID = itemID
|
||||
}
|
||||
for idx, builder := range toolCallsInProgress {
|
||||
// Find the corresponding output item
|
||||
for i := range completedResp.Output {
|
||||
if completedResp.Output[i].Type == "function_call" && completedResp.Output[i].CallID == builder.id {
|
||||
completedResp.Output[i].ID = builder.itemID
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = idx // unused
|
||||
}
|
||||
|
||||
// response.completed
|
||||
s.sendSSE(w, flusher, &seq, "response.completed", &api.StreamEvent{
|
||||
@@ -363,18 +482,34 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
|
||||
model = req.Model
|
||||
}
|
||||
|
||||
// Build output item
|
||||
itemID := generateID("msg_")
|
||||
outputItem := api.OutputItem{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{{
|
||||
Type: "output_text",
|
||||
Text: result.Text,
|
||||
Annotations: []api.Annotation{},
|
||||
}},
|
||||
// Build output items array
|
||||
outputItems := []api.OutputItem{}
|
||||
|
||||
// Add message item if there's text
|
||||
if result.Text != "" {
|
||||
outputItems = append(outputItems, api.OutputItem{
|
||||
ID: generateID("msg_"),
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{{
|
||||
Type: "output_text",
|
||||
Text: result.Text,
|
||||
Annotations: []api.Annotation{},
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
// Add function_call items
|
||||
for _, tc := range result.ToolCalls {
|
||||
outputItems = append(outputItems, api.OutputItem{
|
||||
ID: generateID("item_"),
|
||||
Type: "function_call",
|
||||
Status: "completed",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
// Echo back request params with defaults
|
||||
@@ -454,7 +589,7 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
|
||||
Model: model,
|
||||
PreviousResponseID: req.PreviousResponseID,
|
||||
Instructions: req.Instructions,
|
||||
Output: []api.OutputItem{outputItem},
|
||||
Output: outputItems,
|
||||
Error: nil,
|
||||
Tools: tools,
|
||||
ToolChoice: toolChoice,
|
||||
|
||||
Reference in New Issue
Block a user