Fix tool calling

This commit is contained in:
2026-03-02 17:14:20 +00:00
parent 6adf7eae54
commit cf263a2a8d
5 changed files with 423 additions and 58 deletions

View File

@@ -94,9 +94,10 @@ type InputItem struct {
// Message is the normalized internal message representation.
type Message struct {
Role string `json:"role"`
Content []ContentBlock `json:"content"`
CallID string `json:"call_id,omitempty"` // for tool messages
Role string `json:"role"`
Content []ContentBlock `json:"content"`
CallID string `json:"call_id,omitempty"` // for tool messages
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // for assistant messages
}
// ContentBlock is a typed content element.
@@ -129,9 +130,35 @@ func (r *ResponseRequest) NormalizeInput() []Message {
}
msg.Content = []ContentBlock{{Type: contentType, Text: s}}
} else {
var blocks []ContentBlock
_ = json.Unmarshal(item.Content, &blocks)
msg.Content = blocks
// Content is an array of blocks - parse them
var rawBlocks []map[string]interface{}
if err := json.Unmarshal(item.Content, &rawBlocks); err == nil {
// Extract content blocks and tool calls
for _, block := range rawBlocks {
blockType, _ := block["type"].(string)
if blockType == "tool_use" {
// Extract tool call information
toolCall := ToolCall{
ID: getStringField(block, "id"),
Name: getStringField(block, "name"),
}
// input field contains the arguments as a map
if input, ok := block["input"].(map[string]interface{}); ok {
if inputJSON, err := json.Marshal(input); err == nil {
toolCall.Arguments = string(inputJSON)
}
}
msg.ToolCalls = append(msg.ToolCalls, toolCall)
} else if blockType == "output_text" || blockType == "input_text" {
// Regular text content block
msg.Content = append(msg.Content, ContentBlock{
Type: blockType,
Text: getStringField(block, "text"),
})
}
}
}
}
}
msgs = append(msgs, msg)
@@ -338,3 +365,11 @@ func (r *ResponseRequest) Validate() error {
}
return nil
}
// getStringField is a helper to safely extract string fields from a map
func getStringField(m map[string]interface{}, key string) string {
if val, ok := m[key].(string); ok {
return val
}
return ""
}

View File

@@ -85,7 +85,23 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
case "user":
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
case "assistant":
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
// Build content blocks including text and tool calls
var contentBlocks []anthropic.ContentBlockParamUnion
if content != "" {
contentBlocks = append(contentBlocks, anthropic.NewTextBlock(content))
}
// Add tool use blocks
for _, tc := range msg.ToolCalls {
var input map[string]interface{}
if err := json.Unmarshal([]byte(tc.Arguments), &input); err != nil {
// If unmarshal fails, skip this tool call
continue
}
contentBlocks = append(contentBlocks, anthropic.NewToolUseBlock(tc.ID, input, tc.Name))
}
if len(contentBlocks) > 0 {
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(contentBlocks...))
}
case "tool":
// Tool results must be in user message with tool_result blocks
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
@@ -213,7 +229,23 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
case "user":
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
case "assistant":
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
// Build content blocks including text and tool calls
var contentBlocks []anthropic.ContentBlockParamUnion
if content != "" {
contentBlocks = append(contentBlocks, anthropic.NewTextBlock(content))
}
// Add tool use blocks
for _, tc := range msg.ToolCalls {
var input map[string]interface{}
if err := json.Unmarshal([]byte(tc.Arguments), &input); err != nil {
// If unmarshal fails, skip this tool call
continue
}
contentBlocks = append(contentBlocks, anthropic.NewToolUseBlock(tc.ID, input, tc.Name))
}
if len(contentBlocks) > 0 {
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(contentBlocks...))
}
case "tool":
// Tool results must be in user message with tool_result blocks
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(

View File

@@ -86,7 +86,32 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
case "user":
oaiMessages = append(oaiMessages, openai.UserMessage(content))
case "assistant":
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
// If assistant message has tool calls, include them
if len(msg.ToolCalls) > 0 {
toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, len(msg.ToolCalls))
for i, tc := range msg.ToolCalls {
toolCalls[i] = openai.ChatCompletionMessageToolCallUnionParam{
OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
ID: tc.ID,
Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
Name: tc.Name,
Arguments: tc.Arguments,
},
},
}
}
msgParam := openai.ChatCompletionAssistantMessageParam{
ToolCalls: toolCalls,
}
if content != "" {
msgParam.Content.OfString = openai.String(content)
}
oaiMessages = append(oaiMessages, openai.ChatCompletionMessageParamUnion{
OfAssistant: &msgParam,
})
} else {
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
}
case "system":
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "developer":
@@ -194,7 +219,32 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
case "user":
oaiMessages = append(oaiMessages, openai.UserMessage(content))
case "assistant":
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
// If assistant message has tool calls, include them
if len(msg.ToolCalls) > 0 {
toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, len(msg.ToolCalls))
for i, tc := range msg.ToolCalls {
toolCalls[i] = openai.ChatCompletionMessageToolCallUnionParam{
OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
ID: tc.ID,
Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
Name: tc.Name,
Arguments: tc.Arguments,
},
},
}
}
msgParam := openai.ChatCompletionAssistantMessageParam{
ToolCalls: toolCalls,
}
if content != "" {
msgParam.Content.OfString = openai.String(content)
}
oaiMessages = append(oaiMessages, openai.ChatCompletionMessageParamUnion{
OfAssistant: &msgParam,
})
} else {
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
}
case "system":
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "developer":

View File

@@ -141,8 +141,9 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
// Build assistant message for conversation store
assistantMsg := api.Message{
Role: "assistant",
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
Role: "assistant",
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
ToolCalls: result.ToolCalls,
}
allMsgs := append(storeMsgs, assistantMsg)
if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil {
@@ -460,10 +461,11 @@ loop:
})
// Store conversation
if fullText != "" {
if fullText != "" || len(toolCalls) > 0 {
assistantMsg := api.Message{
Role: "assistant",
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
Role: "assistant",
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
ToolCalls: toolCalls,
}
allMsgs := append(storeMsgs, assistantMsg)
if _, err := s.convs.Create(responseID, model, allMsgs); err != nil {