Fix tool calling
This commit is contained in:
@@ -94,9 +94,11 @@ type InputItem struct {
|
||||
|
||||
// Message is the normalized internal message representation.
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
CallID string `json:"call_id,omitempty"` // for tool messages
|
||||
Role string `json:"role"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
CallID string `json:"call_id,omitempty"` // for tool messages
|
||||
Name string `json:"name,omitempty"` // for tool messages
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // for assistant messages
|
||||
}
|
||||
|
||||
// ContentBlock is a typed content element.
|
||||
@@ -129,9 +131,35 @@ func (r *ResponseRequest) NormalizeInput() []Message {
|
||||
}
|
||||
msg.Content = []ContentBlock{{Type: contentType, Text: s}}
|
||||
} else {
|
||||
var blocks []ContentBlock
|
||||
_ = json.Unmarshal(item.Content, &blocks)
|
||||
msg.Content = blocks
|
||||
// Content is an array of blocks - parse them
|
||||
var rawBlocks []map[string]interface{}
|
||||
if err := json.Unmarshal(item.Content, &rawBlocks); err == nil {
|
||||
// Extract content blocks and tool calls
|
||||
for _, block := range rawBlocks {
|
||||
blockType, _ := block["type"].(string)
|
||||
|
||||
if blockType == "tool_use" {
|
||||
// Extract tool call information
|
||||
toolCall := ToolCall{
|
||||
ID: getStringField(block, "id"),
|
||||
Name: getStringField(block, "name"),
|
||||
}
|
||||
// input field contains the arguments as a map
|
||||
if input, ok := block["input"].(map[string]interface{}); ok {
|
||||
if inputJSON, err := json.Marshal(input); err == nil {
|
||||
toolCall.Arguments = string(inputJSON)
|
||||
}
|
||||
}
|
||||
msg.ToolCalls = append(msg.ToolCalls, toolCall)
|
||||
} else if blockType == "output_text" || blockType == "input_text" {
|
||||
// Regular text content block
|
||||
msg.Content = append(msg.Content, ContentBlock{
|
||||
Type: blockType,
|
||||
Text: getStringField(block, "text"),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
@@ -140,6 +168,7 @@ func (r *ResponseRequest) NormalizeInput() []Message {
|
||||
Role: "tool",
|
||||
Content: []ContentBlock{{Type: "input_text", Text: item.Output}},
|
||||
CallID: item.CallID,
|
||||
Name: item.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -338,3 +367,11 @@ func (r *ResponseRequest) Validate() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getStringField is a helper to safely extract string fields from a map
|
||||
func getStringField(m map[string]interface{}, key string) string {
|
||||
if val, ok := m[key].(string); ok {
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -85,7 +85,23 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
case "user":
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
|
||||
case "assistant":
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
|
||||
// Build content blocks including text and tool calls
|
||||
var contentBlocks []anthropic.ContentBlockParamUnion
|
||||
if content != "" {
|
||||
contentBlocks = append(contentBlocks, anthropic.NewTextBlock(content))
|
||||
}
|
||||
// Add tool use blocks
|
||||
for _, tc := range msg.ToolCalls {
|
||||
var input map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(tc.Arguments), &input); err != nil {
|
||||
// If unmarshal fails, skip this tool call
|
||||
continue
|
||||
}
|
||||
contentBlocks = append(contentBlocks, anthropic.NewToolUseBlock(tc.ID, input, tc.Name))
|
||||
}
|
||||
if len(contentBlocks) > 0 {
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(contentBlocks...))
|
||||
}
|
||||
case "tool":
|
||||
// Tool results must be in user message with tool_result blocks
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
|
||||
@@ -213,7 +229,23 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
case "user":
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
|
||||
case "assistant":
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
|
||||
// Build content blocks including text and tool calls
|
||||
var contentBlocks []anthropic.ContentBlockParamUnion
|
||||
if content != "" {
|
||||
contentBlocks = append(contentBlocks, anthropic.NewTextBlock(content))
|
||||
}
|
||||
// Add tool use blocks
|
||||
for _, tc := range msg.ToolCalls {
|
||||
var input map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(tc.Arguments), &input); err != nil {
|
||||
// If unmarshal fails, skip this tool call
|
||||
continue
|
||||
}
|
||||
contentBlocks = append(contentBlocks, anthropic.NewToolUseBlock(tc.ID, input, tc.Name))
|
||||
}
|
||||
if len(contentBlocks) > 0 {
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(contentBlocks...))
|
||||
}
|
||||
case "tool":
|
||||
// Tool results must be in user message with tool_result blocks
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
|
||||
|
||||
@@ -232,6 +232,19 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
||||
var contents []*genai.Content
|
||||
var systemText string
|
||||
|
||||
// Build a map of CallID -> Name from assistant tool calls
|
||||
// This allows us to look up function names when processing tool results
|
||||
callIDToName := make(map[string]string)
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "assistant" || msg.Role == "model" {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
if tc.ID != "" && tc.Name != "" {
|
||||
callIDToName[tc.ID] = tc.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" || msg.Role == "developer" {
|
||||
for _, block := range msg.Content {
|
||||
@@ -258,11 +271,17 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
||||
responseMap = map[string]any{"output": output}
|
||||
}
|
||||
|
||||
// Create FunctionResponse part with CallID from message
|
||||
// Get function name from message or look it up from CallID
|
||||
name := msg.Name
|
||||
if name == "" && msg.CallID != "" {
|
||||
name = callIDToName[msg.CallID]
|
||||
}
|
||||
|
||||
// Create FunctionResponse part with CallID and Name from message
|
||||
part := &genai.Part{
|
||||
FunctionResponse: &genai.FunctionResponse{
|
||||
ID: msg.CallID,
|
||||
Name: "", // Name is optional for responses
|
||||
Name: name, // Name is required by Google
|
||||
Response: responseMap,
|
||||
},
|
||||
}
|
||||
@@ -282,6 +301,27 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Add tool calls for assistant messages
|
||||
if msg.Role == "assistant" || msg.Role == "model" {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
// Parse arguments JSON into map
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil {
|
||||
// If unmarshal fails, skip this tool call
|
||||
continue
|
||||
}
|
||||
|
||||
// Create FunctionCall part
|
||||
parts = append(parts, &genai.Part{
|
||||
FunctionCall: &genai.FunctionCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Args: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
role := "user"
|
||||
if msg.Role == "assistant" || msg.Role == "model" {
|
||||
role = "model"
|
||||
|
||||
@@ -86,7 +86,32 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
case "user":
|
||||
oaiMessages = append(oaiMessages, openai.UserMessage(content))
|
||||
case "assistant":
|
||||
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
|
||||
// If assistant message has tool calls, include them
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, len(msg.ToolCalls))
|
||||
for i, tc := range msg.ToolCalls {
|
||||
toolCalls[i] = openai.ChatCompletionMessageToolCallUnionParam{
|
||||
OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
|
||||
ID: tc.ID,
|
||||
Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
msgParam := openai.ChatCompletionAssistantMessageParam{
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
if content != "" {
|
||||
msgParam.Content.OfString = openai.String(content)
|
||||
}
|
||||
oaiMessages = append(oaiMessages, openai.ChatCompletionMessageParamUnion{
|
||||
OfAssistant: &msgParam,
|
||||
})
|
||||
} else {
|
||||
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
|
||||
}
|
||||
case "system":
|
||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||
case "developer":
|
||||
@@ -194,7 +219,32 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
case "user":
|
||||
oaiMessages = append(oaiMessages, openai.UserMessage(content))
|
||||
case "assistant":
|
||||
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
|
||||
// If assistant message has tool calls, include them
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, len(msg.ToolCalls))
|
||||
for i, tc := range msg.ToolCalls {
|
||||
toolCalls[i] = openai.ChatCompletionMessageToolCallUnionParam{
|
||||
OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
|
||||
ID: tc.ID,
|
||||
Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
msgParam := openai.ChatCompletionAssistantMessageParam{
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
if content != "" {
|
||||
msgParam.Content.OfString = openai.String(content)
|
||||
}
|
||||
oaiMessages = append(oaiMessages, openai.ChatCompletionMessageParamUnion{
|
||||
OfAssistant: &msgParam,
|
||||
})
|
||||
} else {
|
||||
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
|
||||
}
|
||||
case "system":
|
||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||
case "developer":
|
||||
|
||||
@@ -141,8 +141,9 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
||||
|
||||
// Build assistant message for conversation store
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
|
||||
ToolCalls: result.ToolCalls,
|
||||
}
|
||||
allMsgs := append(storeMsgs, assistantMsg)
|
||||
if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil {
|
||||
@@ -460,10 +461,11 @@ loop:
|
||||
})
|
||||
|
||||
// Store conversation
|
||||
if fullText != "" {
|
||||
if fullText != "" || len(toolCalls) > 0 {
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
|
||||
Role: "assistant",
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
allMsgs := append(storeMsgs, assistantMsg)
|
||||
if _, err := s.convs.Create(responseID, model, allMsgs); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user