From 2188e3cba80f14995c35187110c60074fce07392 Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Mon, 2 Mar 2026 15:46:50 +0000 Subject: [PATCH] Add Anthropic tool calling support --- internal/providers/anthropic/anthropic.go | 124 +++++++++++++-- internal/providers/anthropic/convert.go | 154 +++++++++++++++++++ internal/providers/anthropic/convert_test.go | 119 ++++++++++++++ internal/providers/openai/convert.go | 2 +- 4 files changed, 386 insertions(+), 13 deletions(-) create mode 100644 internal/providers/anthropic/convert.go create mode 100644 internal/providers/anthropic/convert_test.go diff --git a/internal/providers/anthropic/anthropic.go b/internal/providers/anthropic/anthropic.go index afa1f9e..3d84103 100644 --- a/internal/providers/anthropic/anthropic.go +++ b/internal/providers/anthropic/anthropic.go @@ -2,6 +2,7 @@ package anthropic import ( "context" + "encoding/json" "fmt" "github.com/anthropics/anthropic-sdk-go" @@ -85,6 +86,11 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content))) case "assistant": anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content))) + case "tool": + // Tool results must be in user message with tool_result blocks + anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage( + anthropic.NewToolResultBlock(msg.CallID, content, false), + )) case "system", "developer": system = content } @@ -116,24 +122,55 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap params.TopP = anthropic.Float(*req.TopP) } + // Add tools if present + if req.Tools != nil && len(req.Tools) > 0 { + tools, err := parseTools(req) + if err != nil { + return nil, fmt.Errorf("parse tools: %w", err) + } + params.Tools = tools + } + + // Add tool_choice if present + if req.ToolChoice != nil && len(req.ToolChoice) > 0 { + toolChoice, err := parseToolChoice(req) + if err != nil { + return nil, fmt.Errorf("parse tool_choice: %w", err) + } + params.ToolChoice = toolChoice + } + // Call Anthropic API resp, err := p.client.Messages.New(ctx, params) if err != nil { return nil, fmt.Errorf("anthropic api error: %w", err) } - // Extract text from response + // Extract text and tool calls from response var text string + var toolCalls []api.ToolCall + for _, block := range resp.Content { - if block.Type == "text" { - text += block.Text + switch block.Type { + case "text": + text += block.AsText().Text + case "tool_use": + // Extract tool calls + toolUse := block.AsToolUse() + argsJSON, _ := json.Marshal(toolUse.Input) + toolCalls = append(toolCalls, api.ToolCall{ + ID: toolUse.ID, + Name: toolUse.Name, + Arguments: string(argsJSON), + }) } } return &api.ProviderResult{ - ID: resp.ID, - Model: string(resp.Model), - Text: text, + ID: resp.ID, + Model: string(resp.Model), + Text: text, + ToolCalls: toolCalls, Usage: api.Usage{ InputTokens: int(resp.Usage.InputTokens), OutputTokens: int(resp.Usage.OutputTokens), @@ -177,6 +214,11 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content))) case "assistant": anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content))) + case "tool": + // Tool results must be in user message with tool_result blocks + anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage( + anthropic.NewToolResultBlock(msg.CallID, content, false), + )) case "system", "developer": system = content } @@ -208,19 +250,77 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r params.TopP = anthropic.Float(*req.TopP) } + // Add tools if present + if req.Tools != nil && len(req.Tools) > 0 { + tools, err := parseTools(req) + if err != nil { + errChan <- fmt.Errorf("parse tools: %w", err) + return + } + params.Tools = tools + } + + // Add tool_choice if present + if req.ToolChoice != nil && len(req.ToolChoice) > 0 { + toolChoice, err := parseToolChoice(req) + if err != nil { + errChan <- fmt.Errorf("parse tool_choice: %w", err) + return + } + params.ToolChoice = toolChoice + } + // Create stream stream := p.client.Messages.NewStreaming(ctx, params) + // Track content block index and tool call state + var contentBlockIndex int + // Process stream for stream.Next() { event := stream.Current() - if event.Type == "content_block_delta" && event.Delta.Type == "text_delta" { - select { - case deltaChan <- &api.ProviderStreamDelta{Text: event.Delta.Text}: - case <-ctx.Done(): - errChan <- ctx.Err() - return + switch event.Type { + case "content_block_start": + // New content block (text or tool_use) + contentBlockIndex = int(event.Index) + if event.ContentBlock.Type == "tool_use" { + // Send tool call delta with ID and name + toolUse := event.ContentBlock.AsToolUse() + delta := &api.ToolCallDelta{ + Index: contentBlockIndex, + ID: toolUse.ID, + Name: toolUse.Name, + } + select { + case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + } + + case "content_block_delta": + if event.Delta.Type == "text_delta" { + // Text streaming + select { + case deltaChan <- &api.ProviderStreamDelta{Text: event.Delta.Text}: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + } else if event.Delta.Type == "input_json_delta" { + // Tool arguments streaming + delta := &api.ToolCallDelta{ + Index: int(event.Index), + Arguments: event.Delta.PartialJSON, + } + select { + case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } } } } diff --git a/internal/providers/anthropic/convert.go b/internal/providers/anthropic/convert.go new file mode 100644 index 0000000..18154d4 --- /dev/null +++ b/internal/providers/anthropic/convert.go @@ -0,0 +1,154 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + + "github.com/ajac-zero/latticelm/internal/api" + "github.com/anthropics/anthropic-sdk-go" +) + +// parseTools converts Open Responses tools to Anthropic format +func parseTools(req *api.ResponseRequest) ([]anthropic.ToolUnionParam, error) { + if req.Tools == nil || len(req.Tools) == 0 { + return nil, nil + } + + var toolDefs []map[string]interface{} + if err := json.Unmarshal(req.Tools, &toolDefs); err != nil { + return nil, fmt.Errorf("unmarshal tools: %w", err) + } + + var tools []anthropic.ToolUnionParam + for _, td := range toolDefs { + // Extract: name, description, parameters + // Note: Anthropic uses "input_schema" instead of "parameters" + name, _ := td["name"].(string) + desc, _ := td["description"].(string) + params, _ := td["parameters"].(map[string]interface{}) + + inputSchema := anthropic.ToolInputSchemaParam{ + Type: "object", + Properties: params["properties"], + } + + // Add required fields if present + if required, ok := params["required"].([]interface{}); ok { + requiredStrs := make([]string, 0, len(required)) + for _, r := range required { + if str, ok := r.(string); ok { + requiredStrs = append(requiredStrs, str) + } + } + inputSchema.Required = requiredStrs + } + + // Create the tool using ToolUnionParamOfTool + tool := anthropic.ToolUnionParamOfTool(inputSchema, name) + + if desc != "" { + tool.OfTool.Description = anthropic.String(desc) + } + + tools = append(tools, tool) + } + + return tools, nil +} + +// parseToolChoice converts Open Responses tool_choice to Anthropic format +func parseToolChoice(req *api.ResponseRequest) (anthropic.ToolChoiceUnionParam, error) { + var result anthropic.ToolChoiceUnionParam + + if req.ToolChoice == nil || len(req.ToolChoice) == 0 { + return result, nil + } + + var choice interface{} + if err := json.Unmarshal(req.ToolChoice, &choice); err != nil { + return result, fmt.Errorf("unmarshal tool_choice: %w", err) + } + + // Handle string values: "auto", "any", "required" + if str, ok := choice.(string); ok { + switch str { + case "auto": + result.OfAuto = &anthropic.ToolChoiceAutoParam{ + Type: "auto", + } + case "any", "required": + result.OfAny = &anthropic.ToolChoiceAnyParam{ + Type: "any", + } + case "none": + result.OfNone = &anthropic.ToolChoiceNoneParam{ + Type: "none", + } + default: + return result, fmt.Errorf("unknown tool_choice string: %s", str) + } + return result, nil + } + + // Handle specific tool selection: {"type": "tool", "function": {"name": "..."}} + if obj, ok := choice.(map[string]interface{}); ok { + // Check for OpenAI format: {"type": "function", "function": {"name": "..."}} + if funcObj, ok := obj["function"].(map[string]interface{}); ok { + if name, ok := funcObj["name"].(string); ok { + result.OfTool = &anthropic.ToolChoiceToolParam{ + Type: "tool", + Name: name, + } + return result, nil + } + } + + // Check for direct name field + if name, ok := obj["name"].(string); ok { + result.OfTool = &anthropic.ToolChoiceToolParam{ + Type: "tool", + Name: name, + } + return result, nil + } + } + + return result, fmt.Errorf("invalid tool_choice format") +} + +// extractToolCalls converts Anthropic content blocks to api.ToolCall +func extractToolCalls(content []anthropic.ContentBlockUnion) []api.ToolCall { + var toolCalls []api.ToolCall + + for _, block := range content { + // Check if this is a tool_use block + if block.Type == "tool_use" { + // Cast to ToolUseBlock to access the fields + toolUse := block.AsToolUse() + + // Marshal the input to JSON string for Arguments + argsJSON, _ := json.Marshal(toolUse.Input) + + toolCalls = append(toolCalls, api.ToolCall{ + ID: toolUse.ID, + Name: toolUse.Name, + Arguments: string(argsJSON), + }) + } + } + + return toolCalls +} + +// extractToolCallDelta extracts tool call delta from streaming content block delta +func extractToolCallDelta(delta anthropic.RawContentBlockDeltaUnion, index int) *api.ToolCallDelta { + // Check if this is an input_json_delta (streaming tool arguments) + if delta.Type == "input_json_delta" { + return &api.ToolCallDelta{ + Index: index, + Arguments: delta.PartialJSON, + } + } + + return nil +} diff --git a/internal/providers/anthropic/convert_test.go b/internal/providers/anthropic/convert_test.go new file mode 100644 index 0000000..4298f46 --- /dev/null +++ b/internal/providers/anthropic/convert_test.go @@ -0,0 +1,119 @@ +package anthropic + +import ( + "encoding/json" + "testing" + + "github.com/ajac-zero/latticelm/internal/api" +) + +func TestParseTools(t *testing.T) { + // Create a sample tool definition + toolsJSON := `[{ + "type": "function", + "name": "get_weather", + "description": "Get the weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state" + } + }, + "required": ["location"] + } + }]` + + req := &api.ResponseRequest{ + Tools: json.RawMessage(toolsJSON), + } + + tools, err := parseTools(req) + if err != nil { + t.Fatalf("parseTools failed: %v", err) + } + + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + + tool := tools[0] + if tool.OfTool == nil { + t.Fatal("expected OfTool to be set") + } + + if tool.OfTool.Name != "get_weather" { + t.Errorf("expected name 'get_weather', got '%s'", tool.OfTool.Name) + } + + desc := tool.GetDescription() + if desc == nil || *desc != "Get the weather for a location" { + t.Errorf("expected description 'Get the weather for a location', got '%v'", desc) + } + + if len(tool.OfTool.InputSchema.Required) != 1 || tool.OfTool.InputSchema.Required[0] != "location" { + t.Errorf("expected required=['location'], got %v", tool.OfTool.InputSchema.Required) + } +} + +func TestParseToolChoice(t *testing.T) { + tests := []struct { + name string + choiceJSON string + expectAuto bool + expectAny bool + expectTool bool + expectedName string + }{ + { + name: "auto", + choiceJSON: `"auto"`, + expectAuto: true, + }, + { + name: "any", + choiceJSON: `"any"`, + expectAny: true, + }, + { + name: "required", + choiceJSON: `"required"`, + expectAny: true, + }, + { + name: "specific tool", + choiceJSON: `{"type": "function", "function": {"name": "get_weather"}}`, + expectTool: true, + expectedName: "get_weather", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &api.ResponseRequest{ + ToolChoice: json.RawMessage(tt.choiceJSON), + } + + choice, err := parseToolChoice(req) + if err != nil { + t.Fatalf("parseToolChoice failed: %v", err) + } + + if tt.expectAuto && choice.OfAuto == nil { + t.Error("expected OfAuto to be set") + } + if tt.expectAny && choice.OfAny == nil { + t.Error("expected OfAny to be set") + } + if tt.expectTool { + if choice.OfTool == nil { + t.Fatal("expected OfTool to be set") + } + if choice.OfTool.Name != tt.expectedName { + t.Errorf("expected name '%s', got '%s'", tt.expectedName, choice.OfTool.Name) + } + } + }) + } +} diff --git a/internal/providers/openai/convert.go b/internal/providers/openai/convert.go index 308a992..9b8d60b 100644 --- a/internal/providers/openai/convert.go +++ b/internal/providers/openai/convert.go @@ -22,7 +22,7 @@ func parseTools(req *api.ResponseRequest) ([]openai.ChatCompletionToolUnionParam var tools []openai.ChatCompletionToolUnionParam for _, td := range toolDefs { - // Convert Open Responses tool to OpenAI function tool + // Convert Open Responses tool to OpenAI ChatCompletionFunctionToolParam // Extract: name, description, parameters name, _ := td["name"].(string) desc, _ := td["description"].(string)