From 841bcd0e8bcb369e55b8e453eb39b738b2291200 Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Mon, 2 Mar 2026 17:14:20 +0000 Subject: [PATCH] Fix tool calling --- internal/api/types.go | 49 +++- internal/providers/anthropic/anthropic.go | 36 ++- internal/providers/google/google.go | 44 ++- internal/providers/openai/openai.go | 54 +++- internal/server/server.go | 12 +- scripts/chat.py | 332 +++++++++++++++++++--- 6 files changed, 467 insertions(+), 60 deletions(-) diff --git a/internal/api/types.go b/internal/api/types.go index 0c39d34..73688e4 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -94,9 +94,11 @@ type InputItem struct { // Message is the normalized internal message representation. type Message struct { - Role string `json:"role"` - Content []ContentBlock `json:"content"` - CallID string `json:"call_id,omitempty"` // for tool messages + Role string `json:"role"` + Content []ContentBlock `json:"content"` + CallID string `json:"call_id,omitempty"` // for tool messages + Name string `json:"name,omitempty"` // for tool messages + ToolCalls []ToolCall `json:"tool_calls,omitempty"` // for assistant messages } // ContentBlock is a typed content element. @@ -129,9 +131,35 @@ func (r *ResponseRequest) NormalizeInput() []Message { } msg.Content = []ContentBlock{{Type: contentType, Text: s}} } else { - var blocks []ContentBlock - _ = json.Unmarshal(item.Content, &blocks) - msg.Content = blocks + // Content is an array of blocks - parse them + var rawBlocks []map[string]interface{} + if err := json.Unmarshal(item.Content, &rawBlocks); err == nil { + // Extract content blocks and tool calls + for _, block := range rawBlocks { + blockType, _ := block["type"].(string) + + if blockType == "tool_use" { + // Extract tool call information + toolCall := ToolCall{ + ID: getStringField(block, "id"), + Name: getStringField(block, "name"), + } + // input field contains the arguments as a map + if input, ok := block["input"].(map[string]interface{}); ok { + if inputJSON, err := json.Marshal(input); err == nil { + toolCall.Arguments = string(inputJSON) + } + } + msg.ToolCalls = append(msg.ToolCalls, toolCall) + } else if blockType == "output_text" || blockType == "input_text" { + // Regular text content block + msg.Content = append(msg.Content, ContentBlock{ + Type: blockType, + Text: getStringField(block, "text"), + }) + } + } + } } } msgs = append(msgs, msg) @@ -140,6 +168,7 @@ func (r *ResponseRequest) NormalizeInput() []Message { Role: "tool", Content: []ContentBlock{{Type: "input_text", Text: item.Output}}, CallID: item.CallID, + Name: item.Name, }) } } @@ -338,3 +367,11 @@ func (r *ResponseRequest) Validate() error { } return nil } + +// getStringField is a helper to safely extract string fields from a map +func getStringField(m map[string]interface{}, key string) string { + if val, ok := m[key].(string); ok { + return val + } + return "" +} diff --git a/internal/providers/anthropic/anthropic.go b/internal/providers/anthropic/anthropic.go index 3d84103..7e1a406 100644 --- a/internal/providers/anthropic/anthropic.go +++ b/internal/providers/anthropic/anthropic.go @@ -85,7 +85,23 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap case "user": anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content))) case "assistant": - anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content))) + // Build content blocks including text and tool calls + var contentBlocks []anthropic.ContentBlockParamUnion + if content != "" { + contentBlocks = append(contentBlocks, anthropic.NewTextBlock(content)) + } + // Add tool use blocks + for _, tc := range msg.ToolCalls { + var input map[string]interface{} + if err := json.Unmarshal([]byte(tc.Arguments), &input); err != nil { + // If unmarshal fails, skip this tool call + continue + } + contentBlocks = append(contentBlocks, anthropic.NewToolUseBlock(tc.ID, input, tc.Name)) + } + if len(contentBlocks) > 0 { + anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(contentBlocks...)) + } case "tool": // Tool results must be in user message with tool_result blocks anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage( @@ -213,7 +229,23 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r case "user": anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content))) case "assistant": - anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content))) + // Build content blocks including text and tool calls + var contentBlocks []anthropic.ContentBlockParamUnion + if content != "" { + contentBlocks = append(contentBlocks, anthropic.NewTextBlock(content)) + } + // Add tool use blocks + for _, tc := range msg.ToolCalls { + var input map[string]interface{} + if err := json.Unmarshal([]byte(tc.Arguments), &input); err != nil { + // If unmarshal fails, skip this tool call + continue + } + contentBlocks = append(contentBlocks, anthropic.NewToolUseBlock(tc.ID, input, tc.Name)) + } + if len(contentBlocks) > 0 { + anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(contentBlocks...)) + } case "tool": // Tool results must be in user message with tool_result blocks anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage( diff --git a/internal/providers/google/google.go b/internal/providers/google/google.go index 76423e3..7b43b76 100644 --- a/internal/providers/google/google.go +++ b/internal/providers/google/google.go @@ -232,6 +232,19 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) { var contents []*genai.Content var systemText string + // Build a map of CallID -> Name from assistant tool calls + // This allows us to look up function names when processing tool results + callIDToName := make(map[string]string) + for _, msg := range messages { + if msg.Role == "assistant" || msg.Role == "model" { + for _, tc := range msg.ToolCalls { + if tc.ID != "" && tc.Name != "" { + callIDToName[tc.ID] = tc.Name + } + } + } + } + for _, msg := range messages { if msg.Role == "system" || msg.Role == "developer" { for _, block := range msg.Content { @@ -258,11 +271,17 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) { responseMap = map[string]any{"output": output} } - // Create FunctionResponse part with CallID from message + // Get function name from message or look it up from CallID + name := msg.Name + if name == "" && msg.CallID != "" { + name = callIDToName[msg.CallID] + } + + // Create FunctionResponse part with CallID and Name from message part := &genai.Part{ FunctionResponse: &genai.FunctionResponse{ ID: msg.CallID, - Name: "", // Name is optional for responses + Name: name, // Name is required by Google Response: responseMap, }, } @@ -282,6 +301,27 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) { } } + // Add tool calls for assistant messages + if msg.Role == "assistant" || msg.Role == "model" { + for _, tc := range msg.ToolCalls { + // Parse arguments JSON into map + var args map[string]any + if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil { + // If unmarshal fails, skip this tool call + continue + } + + // Create FunctionCall part + parts = append(parts, &genai.Part{ + FunctionCall: &genai.FunctionCall{ + ID: tc.ID, + Name: tc.Name, + Args: args, + }, + }) + } + } + role := "user" if msg.Role == "assistant" || msg.Role == "model" { role = "model" diff --git a/internal/providers/openai/openai.go b/internal/providers/openai/openai.go index d687037..3f408c6 100644 --- a/internal/providers/openai/openai.go +++ b/internal/providers/openai/openai.go @@ -86,7 +86,32 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap case "user": oaiMessages = append(oaiMessages, openai.UserMessage(content)) case "assistant": - oaiMessages = append(oaiMessages, openai.AssistantMessage(content)) + // If assistant message has tool calls, include them + if len(msg.ToolCalls) > 0 { + toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, len(msg.ToolCalls)) + for i, tc := range msg.ToolCalls { + toolCalls[i] = openai.ChatCompletionMessageToolCallUnionParam{ + OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{ + ID: tc.ID, + Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ + Name: tc.Name, + Arguments: tc.Arguments, + }, + }, + } + } + msgParam := openai.ChatCompletionAssistantMessageParam{ + ToolCalls: toolCalls, + } + if content != "" { + msgParam.Content.OfString = openai.String(content) + } + oaiMessages = append(oaiMessages, openai.ChatCompletionMessageParamUnion{ + OfAssistant: &msgParam, + }) + } else { + oaiMessages = append(oaiMessages, openai.AssistantMessage(content)) + } case "system": oaiMessages = append(oaiMessages, openai.SystemMessage(content)) case "developer": @@ -194,7 +219,32 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r case "user": oaiMessages = append(oaiMessages, openai.UserMessage(content)) case "assistant": - oaiMessages = append(oaiMessages, openai.AssistantMessage(content)) + // If assistant message has tool calls, include them + if len(msg.ToolCalls) > 0 { + toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, len(msg.ToolCalls)) + for i, tc := range msg.ToolCalls { + toolCalls[i] = openai.ChatCompletionMessageToolCallUnionParam{ + OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{ + ID: tc.ID, + Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ + Name: tc.Name, + Arguments: tc.Arguments, + }, + }, + } + } + msgParam := openai.ChatCompletionAssistantMessageParam{ + ToolCalls: toolCalls, + } + if content != "" { + msgParam.Content.OfString = openai.String(content) + } + oaiMessages = append(oaiMessages, openai.ChatCompletionMessageParamUnion{ + OfAssistant: &msgParam, + }) + } else { + oaiMessages = append(oaiMessages, openai.AssistantMessage(content)) + } case "system": oaiMessages = append(oaiMessages, openai.SystemMessage(content)) case "developer": diff --git a/internal/server/server.go b/internal/server/server.go index 1eff7c8..88e3cbd 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -141,8 +141,9 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques // Build assistant message for conversation store assistantMsg := api.Message{ - Role: "assistant", - Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}}, + Role: "assistant", + Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}}, + ToolCalls: result.ToolCalls, } allMsgs := append(storeMsgs, assistantMsg) if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil { @@ -460,10 +461,11 @@ loop: }) // Store conversation - if fullText != "" { + if fullText != "" || len(toolCalls) > 0 { assistantMsg := api.Message{ - Role: "assistant", - Content: []api.ContentBlock{{Type: "output_text", Text: fullText}}, + Role: "assistant", + Content: []api.ContentBlock{{Type: "output_text", Text: fullText}}, + ToolCalls: toolCalls, } allMsgs := append(storeMsgs, assistantMsg) if _, err := s.convs.Create(responseID, model, allMsgs); err != nil { diff --git a/scripts/chat.py b/scripts/chat.py index 092e5d9..545faeb 100755 --- a/scripts/chat.py +++ b/scripts/chat.py @@ -18,8 +18,10 @@ Usage: """ import argparse +import json import sys -from typing import Optional +from datetime import datetime +from typing import Optional, Any from openai import OpenAI, APIStatusError from rich.console import Console @@ -30,6 +32,85 @@ from rich.prompt import Prompt from rich.table import Table +# Define available tools in OpenResponses format +TOOLS = [ + { + "type": "function", + "name": "calculator", + "description": "Perform basic arithmetic operations. Supports addition, subtraction, multiplication, and division.", + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The arithmetic operation to perform" + }, + "a": { + "type": "number", + "description": "The first number" + }, + "b": { + "type": "number", + "description": "The second number" + } + }, + "required": ["operation", "a", "b"] + } + }, + { + "type": "function", + "name": "get_current_time", + "description": "Get the current time in a specified timezone or UTC", + "parameters": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London'). Defaults to UTC.", + } + } + } + } +] + + +def execute_tool(tool_name: str, arguments: dict[str, Any]) -> str: + """Execute a tool and return the result as a string.""" + if tool_name == "calculator": + operation = arguments["operation"] + a = arguments["a"] + b = arguments["b"] + + if operation == "add": + result = a + b + elif operation == "subtract": + result = a - b + elif operation == "multiply": + result = a * b + elif operation == "divide": + if b == 0: + return json.dumps({"error": "Division by zero"}) + result = a / b + else: + return json.dumps({"error": f"Unknown operation: {operation}"}) + + return json.dumps({"result": result, "operation": operation, "a": a, "b": b}) + + elif tool_name == "get_current_time": + # Simple implementation without pytz + timezone = arguments.get("timezone", "UTC") + now = datetime.now() + return json.dumps({ + "current_time": now.isoformat(), + "timezone": timezone, + "note": "Showing local system time (timezone parameter not fully implemented)" + }) + + else: + return json.dumps({"error": f"Unknown tool: {tool_name}"}) + + class ChatClient: def __init__(self, base_url: str, token: Optional[str] = None): self.base_url = base_url.rstrip("/") @@ -39,60 +120,199 @@ class ChatClient: ) self.messages = [] self.console = Console() + self.tools_enabled = True def chat(self, user_message: str, model: str, stream: bool = True): """Send a chat message and get response.""" - # Add user message to history + # Add user message to history as a message-type input item self.messages.append({ + "type": "message", "role": "user", "content": [{"type": "input_text", "text": user_message}] }) - + if stream: return self._stream_response(model) else: return self._sync_response(model) def _sync_response(self, model: str) -> str: - """Non-streaming response.""" - with self.console.status(f"[bold blue]Thinking ({model})..."): - response = self.client.responses.create( - model=model, - input=self.messages, - ) - - assistant_text = response.output_text - - # Add to history - self.messages.append({ - "role": "assistant", - "content": [{"type": "output_text", "text": assistant_text}] - }) - - return assistant_text + """Non-streaming response with tool support.""" + max_iterations = 10 # Prevent infinite loops + iteration = 0 + + while iteration < max_iterations: + iteration += 1 + + with self.console.status(f"[bold blue]Thinking ({model})..."): + kwargs = { + "model": model, + "input": self.messages, + } + if self.tools_enabled: + kwargs["tools"] = TOOLS + + response = self.client.responses.create(**kwargs) + + # Check if there are tool calls + tool_calls = [] + assistant_content = [] + text_parts = [] + + for item in response.output: + if item.type == "message": + # Extract text from message content + for content_block in item.content: + if content_block.type == "output_text": + text_parts.append(content_block.text) + assistant_content.append({"type": "output_text", "text": content_block.text}) + elif item.type == "function_call": + # Parse arguments JSON string + try: + arguments = json.loads(item.arguments) + except json.JSONDecodeError: + arguments = {} + + tool_calls.append({ + "id": item.call_id, + "name": item.name, + "arguments": arguments + }) + assistant_content.append({ + "type": "tool_use", + "id": item.call_id, + "name": item.name, + "input": arguments + }) + + # Add assistant message to history as a message-type input item + if assistant_content: + self.messages.append({ + "type": "message", + "role": "assistant", + "content": assistant_content + }) + + # If no tool calls, we're done + if not tool_calls: + return "\n".join(text_parts) if text_parts else "" + + # Execute tools and add results + self.console.print(f"[dim]Executing {len(tool_calls)} tool(s)...[/dim]") + tool_results = [] + + for tool_call in tool_calls: + self.console.print(f"[dim] → {tool_call['name']}({json.dumps(tool_call['arguments'])})[/dim]") + result = execute_tool(tool_call["name"], tool_call["arguments"]) + tool_results.append({ + "type": "function_call_output", + "call_id": tool_call["id"], + "output": result + }) + + # Add tool results to input + self.messages.extend(tool_results) + + # Continue the loop to get the next response + + return "[Error: Max iterations reached in tool calling loop]" def _stream_response(self, model: str) -> str: - """Streaming response with live rendering.""" - assistant_text = "" - - with Live(console=self.console, refresh_per_second=10) as live: - stream = self.client.responses.create( - model=model, - input=self.messages, - stream=True, - ) - for event in stream: - if event.type == "response.output_text.delta": - assistant_text += event.delta - live.update(Markdown(assistant_text)) - - # Add to history - self.messages.append({ - "role": "assistant", - "content": [{"type": "output_text", "text": assistant_text}] - }) - - return assistant_text + """Streaming response with live rendering and tool support.""" + max_iterations = 10 + iteration = 0 + + while iteration < max_iterations: + iteration += 1 + assistant_text = "" + tool_calls = {} # Dict to track tool calls by item_id + tool_calls_list = [] # Final list of completed tool calls + assistant_content = [] + + with Live(console=self.console, refresh_per_second=10) as live: + kwargs = { + "model": model, + "input": self.messages, + "stream": True, + } + if self.tools_enabled: + kwargs["tools"] = TOOLS + + stream = self.client.responses.create(**kwargs) + + for event in stream: + if event.type == "response.output_text.delta": + assistant_text += event.delta + live.update(Markdown(assistant_text)) + elif event.type == "response.output_item.added": + if hasattr(event, 'item') and event.item.type == "function_call": + # Start tracking a new tool call + tool_calls[event.item.id] = { + "id": event.item.call_id, + "name": event.item.name, + "arguments": "", + "item_id": event.item.id + } + elif event.type == "response.function_call_arguments.delta": + # Accumulate arguments for the current function call + # Find which tool call this belongs to by item_id + if hasattr(event, 'item_id') and event.item_id in tool_calls: + tool_calls[event.item_id]["arguments"] += event.delta + elif event.type == "response.output_item.done": + if hasattr(event, 'item') and event.item.type == "function_call": + # Function call is complete + if event.item.id in tool_calls: + tool_call = tool_calls[event.item.id] + try: + # Parse the complete arguments JSON + tool_call["arguments"] = json.loads(tool_call["arguments"]) + tool_calls_list.append(tool_call) + except json.JSONDecodeError: + self.console.print(f"[red]Error parsing tool arguments JSON[/red]") + + # Build assistant content + if assistant_text: + assistant_content.append({"type": "output_text", "text": assistant_text}) + + for tool_call in tool_calls_list: + assistant_content.append({ + "type": "tool_use", + "id": tool_call["id"], + "name": tool_call["name"], + "input": tool_call["arguments"] + }) + + # Add to history as a message-type input item + if assistant_content: + self.messages.append({ + "type": "message", + "role": "assistant", + "content": assistant_content + }) + + # If no tool calls, we're done + if not tool_calls_list: + return assistant_text + + # Execute tools + self.console.print(f"\n[dim]Executing {len(tool_calls_list)} tool(s)...[/dim]") + tool_results = [] + + for tool_call in tool_calls_list: + self.console.print(f"[dim] → {tool_call['name']}({json.dumps(tool_call['arguments'])})[/dim]") + result = execute_tool(tool_call["name"], tool_call["arguments"]) + tool_results.append({ + "type": "function_call_output", + "call_id": tool_call["id"], + "output": result + }) + + # Add tool results to input + self.messages.extend(tool_results) + + # Continue loop for next response + + return "[Error: Max iterations reached in tool calling loop]" def clear_history(self): """Clear conversation history.""" @@ -118,6 +338,20 @@ def print_models_table(client: OpenAI): console.print(table) +def print_tools_table(): + """Print available tools.""" + console = Console() + table = Table(title="Available Tools", show_header=True, header_style="bold magenta") + table.add_column("Tool Name", style="cyan") + table.add_column("Description", style="green") + + for tool in TOOLS: + if tool.get("type") == "function": + table.add_row(tool["name"], tool["description"]) + + console.print(table) + + def main(): parser = argparse.ArgumentParser(description="Chat with latticelm") parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL") @@ -151,11 +385,14 @@ def main(): "[bold cyan]latticelm Chat Interface[/bold cyan]\n" f"Connected to: [green]{args.url}[/green]\n" f"Model: [yellow]{current_model}[/yellow]\n" - f"Streaming: [{'green' if stream_enabled else 'red'}]{stream_enabled}[/]\n\n" + f"Streaming: [{'green' if stream_enabled else 'red'}]{stream_enabled}[/]\n" + f"Tools: [{'green' if client.tools_enabled else 'red'}]{client.tools_enabled}[/]\n\n" "Commands:\n" " [bold]/model [/bold] - Switch model\n" " [bold]/models[/bold] - List available models\n" " [bold]/stream[/bold] - Toggle streaming\n" + " [bold]/tools[/bold] - Toggle tool calling\n" + " [bold]/listtools[/bold] - List available tools\n" " [bold]/clear[/bold] - Clear conversation\n" " [bold]/quit[/bold] or [bold]/exit[/bold] - Exit\n" " [bold]/help[/bold] - Show this help", @@ -196,6 +433,8 @@ def main(): " /model - Switch model\n" " /models - List available models\n" " /stream - Toggle streaming\n" + " /tools - Toggle tool calling\n" + " /listtools - List available tools\n" " /clear - Clear conversation\n" " /quit - Exit", title="Help", @@ -204,7 +443,10 @@ def main(): elif cmd == "/models": print_models_table(client.client) - + + elif cmd == "/listtools": + print_tools_table() + elif cmd == "/model": if len(cmd_parts) < 2: console.print("[red]Usage: /model [/red]") @@ -219,7 +461,11 @@ def main(): elif cmd == "/stream": stream_enabled = not stream_enabled console.print(f"[green]Streaming {'enabled' if stream_enabled else 'disabled'}[/green]") - + + elif cmd == "/tools": + client.tools_enabled = not client.tools_enabled + console.print(f"[green]Tools {'enabled' if client.tools_enabled else 'disabled'}[/green]") + elif cmd == "/clear": client.clear_history() console.print("[green]Conversation history cleared[/green]") -- 2.49.1