Fix tool calling

This commit is contained in:
2026-03-02 17:14:20 +00:00
parent 6adf7eae54
commit 841bcd0e8b
6 changed files with 467 additions and 60 deletions

View File

@@ -18,8 +18,10 @@ Usage:
"""
import argparse
import json
import sys
from typing import Optional
from datetime import datetime
from typing import Optional, Any
from openai import OpenAI, APIStatusError
from rich.console import Console
@@ -30,6 +32,85 @@ from rich.prompt import Prompt
from rich.table import Table
# Define available tools in OpenResponses format.
# Each entry is a flat function-tool spec: "type"/"name"/"description" at the
# top level with a JSON-Schema "parameters" object (Responses API style,
# not the nested Chat Completions {"function": {...}} shape).
TOOLS = [
    {
        "type": "function",
        "name": "calculator",
        "description": "Perform basic arithmetic operations. Supports addition, subtraction, multiplication, and division.",
        "parameters": {
            "type": "object",
            "properties": {
                "operation": {
                    "type": "string",
                    "enum": ["add", "subtract", "multiply", "divide"],
                    "description": "The arithmetic operation to perform"
                },
                "a": {
                    "type": "number",
                    "description": "The first number"
                },
                "b": {
                    "type": "number",
                    "description": "The second number"
                }
            },
            "required": ["operation", "a", "b"]
        }
    },
    {
        "type": "function",
        "name": "get_current_time",
        "description": "Get the current time in a specified timezone or UTC",
        "parameters": {
            "type": "object",
            "properties": {
                "timezone": {
                    "type": "string",
                    "description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London'). Defaults to UTC.",
                }
            }
            # No "required" key: the timezone argument is optional.
        }
    }
]
def execute_tool(tool_name: str, arguments: dict[str, Any]) -> str:
    """Execute a tool call locally and return its result as a JSON string.

    Args:
        tool_name: Name of the tool to run ("calculator" or "get_current_time").
        arguments: Parsed JSON arguments for the tool call.

    Returns:
        A JSON-encoded string: either a result payload or an {"error": ...}
        object. Errors are returned (not raised) so they can be fed back to
        the model as a tool output.
    """
    if tool_name == "calculator":
        operation = arguments["operation"]
        a = arguments["a"]
        b = arguments["b"]
        if operation == "add":
            result = a + b
        elif operation == "subtract":
            result = a - b
        elif operation == "multiply":
            result = a * b
        elif operation == "divide":
            # Guard against ZeroDivisionError; surface it as a tool error.
            if b == 0:
                return json.dumps({"error": "Division by zero"})
            result = a / b
        else:
            return json.dumps({"error": f"Unknown operation: {operation}"})
        return json.dumps({"result": result, "operation": operation, "a": a, "b": b})
    elif tool_name == "get_current_time":
        timezone = arguments.get("timezone", "UTC")
        # Resolve the requested IANA zone with the stdlib (Python 3.9+).
        # Function-local import so the module still loads if tzdata is absent.
        try:
            from zoneinfo import ZoneInfo
            now = datetime.now(ZoneInfo(timezone))
            note = f"Time shown in {timezone}"
        except Exception:
            # Unknown or unavailable zone: fall back to local system time.
            now = datetime.now()
            note = "Unknown timezone; showing local system time"
        return json.dumps({
            "current_time": now.isoformat(),
            "timezone": timezone,
            "note": note
        })
    else:
        return json.dumps({"error": f"Unknown tool: {tool_name}"})
class ChatClient:
def __init__(self, base_url: str, token: Optional[str] = None):
self.base_url = base_url.rstrip("/")
@@ -39,60 +120,199 @@ class ChatClient:
)
self.messages = []
self.console = Console()
self.tools_enabled = True
def chat(self, user_message: str, model: str, stream: bool = True):
    """Record the user's message and fetch the model's reply.

    The message is appended to the conversation history as an
    OpenResponses message-type input item, then either the streaming
    or the blocking response path is invoked.
    """
    user_item = {
        "type": "message",
        "role": "user",
        "content": [{"type": "input_text", "text": user_message}]
    }
    self.messages.append(user_item)
    handler = self._stream_response if stream else self._sync_response
    return handler(model)
def _sync_response(self, model: str) -> str:
    """Non-streaming response with tool support.

    Repeatedly calls the Responses API, executing any requested tools and
    feeding their outputs back as input items, until the model produces a
    reply with no tool calls (or the iteration cap is hit).

    Returns:
        The assistant's final text, or an error string if the tool-calling
        loop exceeds the iteration cap.
    """
    max_iterations = 10  # Prevent infinite tool-calling loops
    for _ in range(max_iterations):
        with self.console.status(f"[bold blue]Thinking ({model})..."):
            kwargs = {
                "model": model,
                "input": self.messages,
            }
            if self.tools_enabled:
                kwargs["tools"] = TOOLS
            response = self.client.responses.create(**kwargs)

        # Collect text output and any tool calls from the response items.
        tool_calls = []
        assistant_content = []
        text_parts = []
        for item in response.output:
            if item.type == "message":
                # Extract text from message content
                for content_block in item.content:
                    if content_block.type == "output_text":
                        text_parts.append(content_block.text)
                        assistant_content.append({"type": "output_text", "text": content_block.text})
            elif item.type == "function_call":
                # Arguments arrive as a JSON string; tolerate malformed JSON.
                try:
                    arguments = json.loads(item.arguments)
                except json.JSONDecodeError:
                    arguments = {}
                tool_calls.append({
                    "id": item.call_id,
                    "name": item.name,
                    "arguments": arguments
                })
                assistant_content.append({
                    "type": "tool_use",
                    "id": item.call_id,
                    "name": item.name,
                    "input": arguments
                })

        # Record the assistant turn in history as a message-type input item.
        if assistant_content:
            self.messages.append({
                "type": "message",
                "role": "assistant",
                "content": assistant_content
            })

        # No tool calls means the model is done.
        if not tool_calls:
            return "\n".join(text_parts) if text_parts else ""

        # Execute each requested tool and append its output to the input.
        self.console.print(f"[dim]Executing {len(tool_calls)} tool(s)...[/dim]")
        tool_results = []
        for tool_call in tool_calls:
            self.console.print(f"[dim] → {tool_call['name']}({json.dumps(tool_call['arguments'])})[/dim]")
            result = execute_tool(tool_call["name"], tool_call["arguments"])
            tool_results.append({
                "type": "function_call_output",
                "call_id": tool_call["id"],
                "output": result
            })
        self.messages.extend(tool_results)
        # Loop again to get the model's next response.
    return "[Error: Max iterations reached in tool calling loop]"
def _stream_response(self, model: str) -> str:
    """Streaming response with live rendering and tool support.

    Streams text deltas into a rich Live view, accumulates any
    function-call items as their argument fragments arrive, executes the
    completed tool calls, feeds the results back, and loops until the
    model replies without tool calls (or the iteration cap is hit).

    Returns:
        The assistant's final text, or an error string if the tool-calling
        loop exceeds the iteration cap.
    """
    max_iterations = 10  # Prevent infinite tool-calling loops
    for _ in range(max_iterations):
        assistant_text = ""
        tool_calls = {}       # In-flight tool calls keyed by item_id
        tool_calls_list = []  # Completed tool calls, in arrival order
        assistant_content = []
        with Live(console=self.console, refresh_per_second=10) as live:
            kwargs = {
                "model": model,
                "input": self.messages,
                "stream": True,
            }
            if self.tools_enabled:
                kwargs["tools"] = TOOLS
            stream = self.client.responses.create(**kwargs)
            for event in stream:
                if event.type == "response.output_text.delta":
                    assistant_text += event.delta
                    live.update(Markdown(assistant_text))
                elif event.type == "response.output_item.added":
                    if hasattr(event, 'item') and event.item.type == "function_call":
                        # Start tracking a new tool call
                        tool_calls[event.item.id] = {
                            "id": event.item.call_id,
                            "name": event.item.name,
                            "arguments": "",
                            "item_id": event.item.id
                        }
                elif event.type == "response.function_call_arguments.delta":
                    # Arguments stream in as JSON fragments; append them to
                    # the matching in-flight call, located by item_id.
                    if hasattr(event, 'item_id') and event.item_id in tool_calls:
                        tool_calls[event.item_id]["arguments"] += event.delta
                elif event.type == "response.output_item.done":
                    if hasattr(event, 'item') and event.item.type == "function_call":
                        # Function call is complete
                        if event.item.id in tool_calls:
                            tool_call = tool_calls[event.item.id]
                            try:
                                # Parse the fully accumulated arguments JSON
                                tool_call["arguments"] = json.loads(tool_call["arguments"])
                                tool_calls_list.append(tool_call)
                            except json.JSONDecodeError:
                                self.console.print(f"[red]Error parsing tool arguments JSON[/red]")

        # Record the assistant turn (text plus tool_use items) in history.
        if assistant_text:
            assistant_content.append({"type": "output_text", "text": assistant_text})
        for tool_call in tool_calls_list:
            assistant_content.append({
                "type": "tool_use",
                "id": tool_call["id"],
                "name": tool_call["name"],
                "input": tool_call["arguments"]
            })
        if assistant_content:
            self.messages.append({
                "type": "message",
                "role": "assistant",
                "content": assistant_content
            })

        # No tool calls means the model is done.
        if not tool_calls_list:
            return assistant_text

        # Execute each tool and append its output for the next round.
        self.console.print(f"\n[dim]Executing {len(tool_calls_list)} tool(s)...[/dim]")
        tool_results = []
        for tool_call in tool_calls_list:
            self.console.print(f"[dim] → {tool_call['name']}({json.dumps(tool_call['arguments'])})[/dim]")
            result = execute_tool(tool_call["name"], tool_call["arguments"])
            tool_results.append({
                "type": "function_call_output",
                "call_id": tool_call["id"],
                "output": result
            })
        self.messages.extend(tool_results)
        # Loop again to get the model's next response.
    return "[Error: Max iterations reached in tool calling loop]"
def clear_history(self):
"""Clear conversation history."""
@@ -118,6 +338,20 @@ def print_models_table(client: OpenAI):
console.print(table)
def print_tools_table():
    """Render the TOOLS registry as a rich table on stdout."""
    console = Console()
    table = Table(title="Available Tools", show_header=True, header_style="bold magenta")
    table.add_column("Tool Name", style="cyan")
    table.add_column("Description", style="green")
    # Only function-type entries carry a name/description pair.
    function_tools = (t for t in TOOLS if t.get("type") == "function")
    for spec in function_tools:
        table.add_row(spec["name"], spec["description"])
    console.print(table)
def main():
parser = argparse.ArgumentParser(description="Chat with latticelm")
parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL")
@@ -151,11 +385,14 @@ def main():
"[bold cyan]latticelm Chat Interface[/bold cyan]\n"
f"Connected to: [green]{args.url}[/green]\n"
f"Model: [yellow]{current_model}[/yellow]\n"
f"Streaming: [{'green' if stream_enabled else 'red'}]{stream_enabled}[/]\n\n"
f"Streaming: [{'green' if stream_enabled else 'red'}]{stream_enabled}[/]\n"
f"Tools: [{'green' if client.tools_enabled else 'red'}]{client.tools_enabled}[/]\n\n"
"Commands:\n"
" [bold]/model <name>[/bold] - Switch model\n"
" [bold]/models[/bold] - List available models\n"
" [bold]/stream[/bold] - Toggle streaming\n"
" [bold]/tools[/bold] - Toggle tool calling\n"
" [bold]/listtools[/bold] - List available tools\n"
" [bold]/clear[/bold] - Clear conversation\n"
" [bold]/quit[/bold] or [bold]/exit[/bold] - Exit\n"
" [bold]/help[/bold] - Show this help",
@@ -196,6 +433,8 @@ def main():
" /model <name> - Switch model\n"
" /models - List available models\n"
" /stream - Toggle streaming\n"
" /tools - Toggle tool calling\n"
" /listtools - List available tools\n"
" /clear - Clear conversation\n"
" /quit - Exit",
title="Help",
@@ -204,7 +443,10 @@ def main():
elif cmd == "/models":
print_models_table(client.client)
elif cmd == "/listtools":
print_tools_table()
elif cmd == "/model":
if len(cmd_parts) < 2:
console.print("[red]Usage: /model <model-name>[/red]")
@@ -219,7 +461,11 @@ def main():
elif cmd == "/stream":
stream_enabled = not stream_enabled
console.print(f"[green]Streaming {'enabled' if stream_enabled else 'disabled'}[/green]")
elif cmd == "/tools":
client.tools_enabled = not client.tools_enabled
console.print(f"[green]Tools {'enabled' if client.tools_enabled else 'disabled'}[/green]")
elif cmd == "/clear":
client.clear_history()
console.print("[green]Conversation history cleared[/green]")