Fix tool calling
This commit is contained in:
@@ -94,9 +94,10 @@ type InputItem struct {
|
||||
|
||||
// Message is the normalized internal message representation.
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
CallID string `json:"call_id,omitempty"` // for tool messages
|
||||
Role string `json:"role"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
CallID string `json:"call_id,omitempty"` // for tool messages
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // for assistant messages
|
||||
}
|
||||
|
||||
// ContentBlock is a typed content element.
|
||||
@@ -129,9 +130,35 @@ func (r *ResponseRequest) NormalizeInput() []Message {
|
||||
}
|
||||
msg.Content = []ContentBlock{{Type: contentType, Text: s}}
|
||||
} else {
|
||||
var blocks []ContentBlock
|
||||
_ = json.Unmarshal(item.Content, &blocks)
|
||||
msg.Content = blocks
|
||||
// Content is an array of blocks - parse them
|
||||
var rawBlocks []map[string]interface{}
|
||||
if err := json.Unmarshal(item.Content, &rawBlocks); err == nil {
|
||||
// Extract content blocks and tool calls
|
||||
for _, block := range rawBlocks {
|
||||
blockType, _ := block["type"].(string)
|
||||
|
||||
if blockType == "tool_use" {
|
||||
// Extract tool call information
|
||||
toolCall := ToolCall{
|
||||
ID: getStringField(block, "id"),
|
||||
Name: getStringField(block, "name"),
|
||||
}
|
||||
// input field contains the arguments as a map
|
||||
if input, ok := block["input"].(map[string]interface{}); ok {
|
||||
if inputJSON, err := json.Marshal(input); err == nil {
|
||||
toolCall.Arguments = string(inputJSON)
|
||||
}
|
||||
}
|
||||
msg.ToolCalls = append(msg.ToolCalls, toolCall)
|
||||
} else if blockType == "output_text" || blockType == "input_text" {
|
||||
// Regular text content block
|
||||
msg.Content = append(msg.Content, ContentBlock{
|
||||
Type: blockType,
|
||||
Text: getStringField(block, "text"),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
@@ -338,3 +365,11 @@ func (r *ResponseRequest) Validate() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getStringField is a helper to safely extract string fields from a map
|
||||
func getStringField(m map[string]interface{}, key string) string {
|
||||
if val, ok := m[key].(string); ok {
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user