Compare commits
2 Commits
6adf7eae54
...
cb631479a1
| Author | SHA1 | Date | |
|---|---|---|---|
| cb631479a1 | |||
| 841bcd0e8b |
@@ -94,9 +94,11 @@ type InputItem struct {
|
|||||||
|
|
||||||
// Message is the normalized internal message representation.
|
// Message is the normalized internal message representation.
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content []ContentBlock `json:"content"`
|
Content []ContentBlock `json:"content"`
|
||||||
CallID string `json:"call_id,omitempty"` // for tool messages
|
CallID string `json:"call_id,omitempty"` // for tool messages
|
||||||
|
Name string `json:"name,omitempty"` // for tool messages
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // for assistant messages
|
||||||
}
|
}
|
||||||
|
|
||||||
// ContentBlock is a typed content element.
|
// ContentBlock is a typed content element.
|
||||||
@@ -129,9 +131,35 @@ func (r *ResponseRequest) NormalizeInput() []Message {
|
|||||||
}
|
}
|
||||||
msg.Content = []ContentBlock{{Type: contentType, Text: s}}
|
msg.Content = []ContentBlock{{Type: contentType, Text: s}}
|
||||||
} else {
|
} else {
|
||||||
var blocks []ContentBlock
|
// Content is an array of blocks - parse them
|
||||||
_ = json.Unmarshal(item.Content, &blocks)
|
var rawBlocks []map[string]interface{}
|
||||||
msg.Content = blocks
|
if err := json.Unmarshal(item.Content, &rawBlocks); err == nil {
|
||||||
|
// Extract content blocks and tool calls
|
||||||
|
for _, block := range rawBlocks {
|
||||||
|
blockType, _ := block["type"].(string)
|
||||||
|
|
||||||
|
if blockType == "tool_use" {
|
||||||
|
// Extract tool call information
|
||||||
|
toolCall := ToolCall{
|
||||||
|
ID: getStringField(block, "id"),
|
||||||
|
Name: getStringField(block, "name"),
|
||||||
|
}
|
||||||
|
// input field contains the arguments as a map
|
||||||
|
if input, ok := block["input"].(map[string]interface{}); ok {
|
||||||
|
if inputJSON, err := json.Marshal(input); err == nil {
|
||||||
|
toolCall.Arguments = string(inputJSON)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msg.ToolCalls = append(msg.ToolCalls, toolCall)
|
||||||
|
} else if blockType == "output_text" || blockType == "input_text" {
|
||||||
|
// Regular text content block
|
||||||
|
msg.Content = append(msg.Content, ContentBlock{
|
||||||
|
Type: blockType,
|
||||||
|
Text: getStringField(block, "text"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msgs = append(msgs, msg)
|
msgs = append(msgs, msg)
|
||||||
@@ -140,6 +168,7 @@ func (r *ResponseRequest) NormalizeInput() []Message {
|
|||||||
Role: "tool",
|
Role: "tool",
|
||||||
Content: []ContentBlock{{Type: "input_text", Text: item.Output}},
|
Content: []ContentBlock{{Type: "input_text", Text: item.Output}},
|
||||||
CallID: item.CallID,
|
CallID: item.CallID,
|
||||||
|
Name: item.Name,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -338,3 +367,11 @@ func (r *ResponseRequest) Validate() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getStringField is a helper to safely extract string fields from a map
|
||||||
|
func getStringField(m map[string]interface{}, key string) string {
|
||||||
|
if val, ok := m[key].(string); ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
@@ -85,7 +85,23 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
|||||||
case "user":
|
case "user":
|
||||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
|
||||||
case "assistant":
|
case "assistant":
|
||||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
|
// Build content blocks including text and tool calls
|
||||||
|
var contentBlocks []anthropic.ContentBlockParamUnion
|
||||||
|
if content != "" {
|
||||||
|
contentBlocks = append(contentBlocks, anthropic.NewTextBlock(content))
|
||||||
|
}
|
||||||
|
// Add tool use blocks
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
var input map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(tc.Arguments), &input); err != nil {
|
||||||
|
// If unmarshal fails, skip this tool call
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
contentBlocks = append(contentBlocks, anthropic.NewToolUseBlock(tc.ID, input, tc.Name))
|
||||||
|
}
|
||||||
|
if len(contentBlocks) > 0 {
|
||||||
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(contentBlocks...))
|
||||||
|
}
|
||||||
case "tool":
|
case "tool":
|
||||||
// Tool results must be in user message with tool_result blocks
|
// Tool results must be in user message with tool_result blocks
|
||||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
|
||||||
@@ -213,7 +229,23 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
|||||||
case "user":
|
case "user":
|
||||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
|
||||||
case "assistant":
|
case "assistant":
|
||||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
|
// Build content blocks including text and tool calls
|
||||||
|
var contentBlocks []anthropic.ContentBlockParamUnion
|
||||||
|
if content != "" {
|
||||||
|
contentBlocks = append(contentBlocks, anthropic.NewTextBlock(content))
|
||||||
|
}
|
||||||
|
// Add tool use blocks
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
var input map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(tc.Arguments), &input); err != nil {
|
||||||
|
// If unmarshal fails, skip this tool call
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
contentBlocks = append(contentBlocks, anthropic.NewToolUseBlock(tc.ID, input, tc.Name))
|
||||||
|
}
|
||||||
|
if len(contentBlocks) > 0 {
|
||||||
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(contentBlocks...))
|
||||||
|
}
|
||||||
case "tool":
|
case "tool":
|
||||||
// Tool results must be in user message with tool_result blocks
|
// Tool results must be in user message with tool_result blocks
|
||||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
|
||||||
|
|||||||
@@ -232,6 +232,19 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
|||||||
var contents []*genai.Content
|
var contents []*genai.Content
|
||||||
var systemText string
|
var systemText string
|
||||||
|
|
||||||
|
// Build a map of CallID -> Name from assistant tool calls
|
||||||
|
// This allows us to look up function names when processing tool results
|
||||||
|
callIDToName := make(map[string]string)
|
||||||
|
for _, msg := range messages {
|
||||||
|
if msg.Role == "assistant" || msg.Role == "model" {
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
if tc.ID != "" && tc.Name != "" {
|
||||||
|
callIDToName[tc.ID] = tc.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, msg := range messages {
|
for _, msg := range messages {
|
||||||
if msg.Role == "system" || msg.Role == "developer" {
|
if msg.Role == "system" || msg.Role == "developer" {
|
||||||
for _, block := range msg.Content {
|
for _, block := range msg.Content {
|
||||||
@@ -258,11 +271,17 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
|||||||
responseMap = map[string]any{"output": output}
|
responseMap = map[string]any{"output": output}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create FunctionResponse part with CallID from message
|
// Get function name from message or look it up from CallID
|
||||||
|
name := msg.Name
|
||||||
|
if name == "" && msg.CallID != "" {
|
||||||
|
name = callIDToName[msg.CallID]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create FunctionResponse part with CallID and Name from message
|
||||||
part := &genai.Part{
|
part := &genai.Part{
|
||||||
FunctionResponse: &genai.FunctionResponse{
|
FunctionResponse: &genai.FunctionResponse{
|
||||||
ID: msg.CallID,
|
ID: msg.CallID,
|
||||||
Name: "", // Name is optional for responses
|
Name: name, // Name is required by Google
|
||||||
Response: responseMap,
|
Response: responseMap,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -282,6 +301,27 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add tool calls for assistant messages
|
||||||
|
if msg.Role == "assistant" || msg.Role == "model" {
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
// Parse arguments JSON into map
|
||||||
|
var args map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil {
|
||||||
|
// If unmarshal fails, skip this tool call
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create FunctionCall part
|
||||||
|
parts = append(parts, &genai.Part{
|
||||||
|
FunctionCall: &genai.FunctionCall{
|
||||||
|
ID: tc.ID,
|
||||||
|
Name: tc.Name,
|
||||||
|
Args: args,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
role := "user"
|
role := "user"
|
||||||
if msg.Role == "assistant" || msg.Role == "model" {
|
if msg.Role == "assistant" || msg.Role == "model" {
|
||||||
role = "model"
|
role = "model"
|
||||||
|
|||||||
@@ -86,7 +86,32 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
|||||||
case "user":
|
case "user":
|
||||||
oaiMessages = append(oaiMessages, openai.UserMessage(content))
|
oaiMessages = append(oaiMessages, openai.UserMessage(content))
|
||||||
case "assistant":
|
case "assistant":
|
||||||
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
|
// If assistant message has tool calls, include them
|
||||||
|
if len(msg.ToolCalls) > 0 {
|
||||||
|
toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, len(msg.ToolCalls))
|
||||||
|
for i, tc := range msg.ToolCalls {
|
||||||
|
toolCalls[i] = openai.ChatCompletionMessageToolCallUnionParam{
|
||||||
|
OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
|
||||||
|
ID: tc.ID,
|
||||||
|
Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
|
||||||
|
Name: tc.Name,
|
||||||
|
Arguments: tc.Arguments,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msgParam := openai.ChatCompletionAssistantMessageParam{
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
msgParam.Content.OfString = openai.String(content)
|
||||||
|
}
|
||||||
|
oaiMessages = append(oaiMessages, openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfAssistant: &msgParam,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
|
||||||
|
}
|
||||||
case "system":
|
case "system":
|
||||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||||
case "developer":
|
case "developer":
|
||||||
@@ -194,7 +219,32 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
|||||||
case "user":
|
case "user":
|
||||||
oaiMessages = append(oaiMessages, openai.UserMessage(content))
|
oaiMessages = append(oaiMessages, openai.UserMessage(content))
|
||||||
case "assistant":
|
case "assistant":
|
||||||
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
|
// If assistant message has tool calls, include them
|
||||||
|
if len(msg.ToolCalls) > 0 {
|
||||||
|
toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, len(msg.ToolCalls))
|
||||||
|
for i, tc := range msg.ToolCalls {
|
||||||
|
toolCalls[i] = openai.ChatCompletionMessageToolCallUnionParam{
|
||||||
|
OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
|
||||||
|
ID: tc.ID,
|
||||||
|
Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
|
||||||
|
Name: tc.Name,
|
||||||
|
Arguments: tc.Arguments,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msgParam := openai.ChatCompletionAssistantMessageParam{
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
msgParam.Content.OfString = openai.String(content)
|
||||||
|
}
|
||||||
|
oaiMessages = append(oaiMessages, openai.ChatCompletionMessageParamUnion{
|
||||||
|
OfAssistant: &msgParam,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
|
||||||
|
}
|
||||||
case "system":
|
case "system":
|
||||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||||
case "developer":
|
case "developer":
|
||||||
|
|||||||
@@ -141,8 +141,9 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
|||||||
|
|
||||||
// Build assistant message for conversation store
|
// Build assistant message for conversation store
|
||||||
assistantMsg := api.Message{
|
assistantMsg := api.Message{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
|
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
|
||||||
|
ToolCalls: result.ToolCalls,
|
||||||
}
|
}
|
||||||
allMsgs := append(storeMsgs, assistantMsg)
|
allMsgs := append(storeMsgs, assistantMsg)
|
||||||
if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil {
|
if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil {
|
||||||
@@ -460,10 +461,11 @@ loop:
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Store conversation
|
// Store conversation
|
||||||
if fullText != "" {
|
if fullText != "" || len(toolCalls) > 0 {
|
||||||
assistantMsg := api.Message{
|
assistantMsg := api.Message{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
|
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
|
||||||
|
ToolCalls: toolCalls,
|
||||||
}
|
}
|
||||||
allMsgs := append(storeMsgs, assistantMsg)
|
allMsgs := append(storeMsgs, assistantMsg)
|
||||||
if _, err := s.convs.Create(responseID, model, allMsgs); err != nil {
|
if _, err := s.convs.Create(responseID, model, allMsgs); err != nil {
|
||||||
|
|||||||
314
scripts/chat.py
314
scripts/chat.py
@@ -18,8 +18,10 @@ Usage:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import sys
|
import sys
|
||||||
from typing import Optional
|
from datetime import datetime
|
||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
from openai import OpenAI, APIStatusError
|
from openai import OpenAI, APIStatusError
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
@@ -30,6 +32,85 @@ from rich.prompt import Prompt
|
|||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
|
|
||||||
|
|
||||||
|
# Define available tools in OpenResponses format
|
||||||
|
TOOLS = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"name": "calculator",
|
||||||
|
"description": "Perform basic arithmetic operations. Supports addition, subtraction, multiplication, and division.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["add", "subtract", "multiply", "divide"],
|
||||||
|
"description": "The arithmetic operation to perform"
|
||||||
|
},
|
||||||
|
"a": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The first number"
|
||||||
|
},
|
||||||
|
"b": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The second number"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["operation", "a", "b"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"name": "get_current_time",
|
||||||
|
"description": "Get the current time in a specified timezone or UTC",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"timezone": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London'). Defaults to UTC.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def execute_tool(tool_name: str, arguments: dict[str, Any]) -> str:
|
||||||
|
"""Execute a tool and return the result as a string."""
|
||||||
|
if tool_name == "calculator":
|
||||||
|
operation = arguments["operation"]
|
||||||
|
a = arguments["a"]
|
||||||
|
b = arguments["b"]
|
||||||
|
|
||||||
|
if operation == "add":
|
||||||
|
result = a + b
|
||||||
|
elif operation == "subtract":
|
||||||
|
result = a - b
|
||||||
|
elif operation == "multiply":
|
||||||
|
result = a * b
|
||||||
|
elif operation == "divide":
|
||||||
|
if b == 0:
|
||||||
|
return json.dumps({"error": "Division by zero"})
|
||||||
|
result = a / b
|
||||||
|
else:
|
||||||
|
return json.dumps({"error": f"Unknown operation: {operation}"})
|
||||||
|
|
||||||
|
return json.dumps({"result": result, "operation": operation, "a": a, "b": b})
|
||||||
|
|
||||||
|
elif tool_name == "get_current_time":
|
||||||
|
# Simple implementation without pytz
|
||||||
|
timezone = arguments.get("timezone", "UTC")
|
||||||
|
now = datetime.now()
|
||||||
|
return json.dumps({
|
||||||
|
"current_time": now.isoformat(),
|
||||||
|
"timezone": timezone,
|
||||||
|
"note": "Showing local system time (timezone parameter not fully implemented)"
|
||||||
|
})
|
||||||
|
|
||||||
|
else:
|
||||||
|
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
||||||
|
|
||||||
|
|
||||||
class ChatClient:
|
class ChatClient:
|
||||||
def __init__(self, base_url: str, token: Optional[str] = None):
|
def __init__(self, base_url: str, token: Optional[str] = None):
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
@@ -39,11 +120,13 @@ class ChatClient:
|
|||||||
)
|
)
|
||||||
self.messages = []
|
self.messages = []
|
||||||
self.console = Console()
|
self.console = Console()
|
||||||
|
self.tools_enabled = True
|
||||||
|
|
||||||
def chat(self, user_message: str, model: str, stream: bool = True):
|
def chat(self, user_message: str, model: str, stream: bool = True):
|
||||||
"""Send a chat message and get response."""
|
"""Send a chat message and get response."""
|
||||||
# Add user message to history
|
# Add user message to history as a message-type input item
|
||||||
self.messages.append({
|
self.messages.append({
|
||||||
|
"type": "message",
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [{"type": "input_text", "text": user_message}]
|
"content": [{"type": "input_text", "text": user_message}]
|
||||||
})
|
})
|
||||||
@@ -54,45 +137,182 @@ class ChatClient:
|
|||||||
return self._sync_response(model)
|
return self._sync_response(model)
|
||||||
|
|
||||||
def _sync_response(self, model: str) -> str:
|
def _sync_response(self, model: str) -> str:
|
||||||
"""Non-streaming response."""
|
"""Non-streaming response with tool support."""
|
||||||
with self.console.status(f"[bold blue]Thinking ({model})..."):
|
max_iterations = 10 # Prevent infinite loops
|
||||||
response = self.client.responses.create(
|
iteration = 0
|
||||||
model=model,
|
|
||||||
input=self.messages,
|
|
||||||
)
|
|
||||||
|
|
||||||
assistant_text = response.output_text
|
while iteration < max_iterations:
|
||||||
|
iteration += 1
|
||||||
|
|
||||||
# Add to history
|
with self.console.status(f"[bold blue]Thinking ({model})..."):
|
||||||
self.messages.append({
|
kwargs = {
|
||||||
"role": "assistant",
|
"model": model,
|
||||||
"content": [{"type": "output_text", "text": assistant_text}]
|
"input": self.messages,
|
||||||
})
|
}
|
||||||
|
if self.tools_enabled:
|
||||||
|
kwargs["tools"] = TOOLS
|
||||||
|
|
||||||
return assistant_text
|
response = self.client.responses.create(**kwargs)
|
||||||
|
|
||||||
|
# Check if there are tool calls
|
||||||
|
tool_calls = []
|
||||||
|
assistant_content = []
|
||||||
|
text_parts = []
|
||||||
|
|
||||||
|
for item in response.output:
|
||||||
|
if item.type == "message":
|
||||||
|
# Extract text from message content
|
||||||
|
for content_block in item.content:
|
||||||
|
if content_block.type == "output_text":
|
||||||
|
text_parts.append(content_block.text)
|
||||||
|
assistant_content.append({"type": "output_text", "text": content_block.text})
|
||||||
|
elif item.type == "function_call":
|
||||||
|
# Parse arguments JSON string
|
||||||
|
try:
|
||||||
|
arguments = json.loads(item.arguments)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
arguments = {}
|
||||||
|
|
||||||
|
tool_calls.append({
|
||||||
|
"id": item.call_id,
|
||||||
|
"name": item.name,
|
||||||
|
"arguments": arguments
|
||||||
|
})
|
||||||
|
assistant_content.append({
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": item.call_id,
|
||||||
|
"name": item.name,
|
||||||
|
"input": arguments
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add assistant message to history as a message-type input item
|
||||||
|
if assistant_content:
|
||||||
|
self.messages.append({
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": assistant_content
|
||||||
|
})
|
||||||
|
|
||||||
|
# If no tool calls, we're done
|
||||||
|
if not tool_calls:
|
||||||
|
return "\n".join(text_parts) if text_parts else ""
|
||||||
|
|
||||||
|
# Execute tools and add results
|
||||||
|
self.console.print(f"[dim]Executing {len(tool_calls)} tool(s)...[/dim]")
|
||||||
|
tool_results = []
|
||||||
|
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
self.console.print(f"[dim] → {tool_call['name']}({json.dumps(tool_call['arguments'])})[/dim]")
|
||||||
|
result = execute_tool(tool_call["name"], tool_call["arguments"])
|
||||||
|
tool_results.append({
|
||||||
|
"type": "function_call_output",
|
||||||
|
"call_id": tool_call["id"],
|
||||||
|
"output": result
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add tool results to input
|
||||||
|
self.messages.extend(tool_results)
|
||||||
|
|
||||||
|
# Continue the loop to get the next response
|
||||||
|
|
||||||
|
return "[Error: Max iterations reached in tool calling loop]"
|
||||||
|
|
||||||
def _stream_response(self, model: str) -> str:
|
def _stream_response(self, model: str) -> str:
|
||||||
"""Streaming response with live rendering."""
|
"""Streaming response with live rendering and tool support."""
|
||||||
assistant_text = ""
|
max_iterations = 10
|
||||||
|
iteration = 0
|
||||||
|
|
||||||
with Live(console=self.console, refresh_per_second=10) as live:
|
while iteration < max_iterations:
|
||||||
stream = self.client.responses.create(
|
iteration += 1
|
||||||
model=model,
|
assistant_text = ""
|
||||||
input=self.messages,
|
tool_calls = {} # Dict to track tool calls by item_id
|
||||||
stream=True,
|
tool_calls_list = [] # Final list of completed tool calls
|
||||||
)
|
assistant_content = []
|
||||||
for event in stream:
|
|
||||||
if event.type == "response.output_text.delta":
|
|
||||||
assistant_text += event.delta
|
|
||||||
live.update(Markdown(assistant_text))
|
|
||||||
|
|
||||||
# Add to history
|
with Live(console=self.console, refresh_per_second=10) as live:
|
||||||
self.messages.append({
|
kwargs = {
|
||||||
"role": "assistant",
|
"model": model,
|
||||||
"content": [{"type": "output_text", "text": assistant_text}]
|
"input": self.messages,
|
||||||
})
|
"stream": True,
|
||||||
|
}
|
||||||
|
if self.tools_enabled:
|
||||||
|
kwargs["tools"] = TOOLS
|
||||||
|
|
||||||
return assistant_text
|
stream = self.client.responses.create(**kwargs)
|
||||||
|
|
||||||
|
for event in stream:
|
||||||
|
if event.type == "response.output_text.delta":
|
||||||
|
assistant_text += event.delta
|
||||||
|
live.update(Markdown(assistant_text))
|
||||||
|
elif event.type == "response.output_item.added":
|
||||||
|
if hasattr(event, 'item') and event.item.type == "function_call":
|
||||||
|
# Start tracking a new tool call
|
||||||
|
tool_calls[event.item.id] = {
|
||||||
|
"id": event.item.call_id,
|
||||||
|
"name": event.item.name,
|
||||||
|
"arguments": "",
|
||||||
|
"item_id": event.item.id
|
||||||
|
}
|
||||||
|
elif event.type == "response.function_call_arguments.delta":
|
||||||
|
# Accumulate arguments for the current function call
|
||||||
|
# Find which tool call this belongs to by item_id
|
||||||
|
if hasattr(event, 'item_id') and event.item_id in tool_calls:
|
||||||
|
tool_calls[event.item_id]["arguments"] += event.delta
|
||||||
|
elif event.type == "response.output_item.done":
|
||||||
|
if hasattr(event, 'item') and event.item.type == "function_call":
|
||||||
|
# Function call is complete
|
||||||
|
if event.item.id in tool_calls:
|
||||||
|
tool_call = tool_calls[event.item.id]
|
||||||
|
try:
|
||||||
|
# Parse the complete arguments JSON
|
||||||
|
tool_call["arguments"] = json.loads(tool_call["arguments"])
|
||||||
|
tool_calls_list.append(tool_call)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self.console.print(f"[red]Error parsing tool arguments JSON[/red]")
|
||||||
|
|
||||||
|
# Build assistant content
|
||||||
|
if assistant_text:
|
||||||
|
assistant_content.append({"type": "output_text", "text": assistant_text})
|
||||||
|
|
||||||
|
for tool_call in tool_calls_list:
|
||||||
|
assistant_content.append({
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": tool_call["id"],
|
||||||
|
"name": tool_call["name"],
|
||||||
|
"input": tool_call["arguments"]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add to history as a message-type input item
|
||||||
|
if assistant_content:
|
||||||
|
self.messages.append({
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": assistant_content
|
||||||
|
})
|
||||||
|
|
||||||
|
# If no tool calls, we're done
|
||||||
|
if not tool_calls_list:
|
||||||
|
return assistant_text
|
||||||
|
|
||||||
|
# Execute tools
|
||||||
|
self.console.print(f"\n[dim]Executing {len(tool_calls_list)} tool(s)...[/dim]")
|
||||||
|
tool_results = []
|
||||||
|
|
||||||
|
for tool_call in tool_calls_list:
|
||||||
|
self.console.print(f"[dim] → {tool_call['name']}({json.dumps(tool_call['arguments'])})[/dim]")
|
||||||
|
result = execute_tool(tool_call["name"], tool_call["arguments"])
|
||||||
|
tool_results.append({
|
||||||
|
"type": "function_call_output",
|
||||||
|
"call_id": tool_call["id"],
|
||||||
|
"output": result
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add tool results to input
|
||||||
|
self.messages.extend(tool_results)
|
||||||
|
|
||||||
|
# Continue loop for next response
|
||||||
|
|
||||||
|
return "[Error: Max iterations reached in tool calling loop]"
|
||||||
|
|
||||||
def clear_history(self):
|
def clear_history(self):
|
||||||
"""Clear conversation history."""
|
"""Clear conversation history."""
|
||||||
@@ -118,6 +338,20 @@ def print_models_table(client: OpenAI):
|
|||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
def print_tools_table():
|
||||||
|
"""Print available tools."""
|
||||||
|
console = Console()
|
||||||
|
table = Table(title="Available Tools", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("Tool Name", style="cyan")
|
||||||
|
table.add_column("Description", style="green")
|
||||||
|
|
||||||
|
for tool in TOOLS:
|
||||||
|
if tool.get("type") == "function":
|
||||||
|
table.add_row(tool["name"], tool["description"])
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Chat with latticelm")
|
parser = argparse.ArgumentParser(description="Chat with latticelm")
|
||||||
parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL")
|
parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL")
|
||||||
@@ -151,11 +385,14 @@ def main():
|
|||||||
"[bold cyan]latticelm Chat Interface[/bold cyan]\n"
|
"[bold cyan]latticelm Chat Interface[/bold cyan]\n"
|
||||||
f"Connected to: [green]{args.url}[/green]\n"
|
f"Connected to: [green]{args.url}[/green]\n"
|
||||||
f"Model: [yellow]{current_model}[/yellow]\n"
|
f"Model: [yellow]{current_model}[/yellow]\n"
|
||||||
f"Streaming: [{'green' if stream_enabled else 'red'}]{stream_enabled}[/]\n\n"
|
f"Streaming: [{'green' if stream_enabled else 'red'}]{stream_enabled}[/]\n"
|
||||||
|
f"Tools: [{'green' if client.tools_enabled else 'red'}]{client.tools_enabled}[/]\n\n"
|
||||||
"Commands:\n"
|
"Commands:\n"
|
||||||
" [bold]/model <name>[/bold] - Switch model\n"
|
" [bold]/model <name>[/bold] - Switch model\n"
|
||||||
" [bold]/models[/bold] - List available models\n"
|
" [bold]/models[/bold] - List available models\n"
|
||||||
" [bold]/stream[/bold] - Toggle streaming\n"
|
" [bold]/stream[/bold] - Toggle streaming\n"
|
||||||
|
" [bold]/tools[/bold] - Toggle tool calling\n"
|
||||||
|
" [bold]/listtools[/bold] - List available tools\n"
|
||||||
" [bold]/clear[/bold] - Clear conversation\n"
|
" [bold]/clear[/bold] - Clear conversation\n"
|
||||||
" [bold]/quit[/bold] or [bold]/exit[/bold] - Exit\n"
|
" [bold]/quit[/bold] or [bold]/exit[/bold] - Exit\n"
|
||||||
" [bold]/help[/bold] - Show this help",
|
" [bold]/help[/bold] - Show this help",
|
||||||
@@ -196,6 +433,8 @@ def main():
|
|||||||
" /model <name> - Switch model\n"
|
" /model <name> - Switch model\n"
|
||||||
" /models - List available models\n"
|
" /models - List available models\n"
|
||||||
" /stream - Toggle streaming\n"
|
" /stream - Toggle streaming\n"
|
||||||
|
" /tools - Toggle tool calling\n"
|
||||||
|
" /listtools - List available tools\n"
|
||||||
" /clear - Clear conversation\n"
|
" /clear - Clear conversation\n"
|
||||||
" /quit - Exit",
|
" /quit - Exit",
|
||||||
title="Help",
|
title="Help",
|
||||||
@@ -205,6 +444,9 @@ def main():
|
|||||||
elif cmd == "/models":
|
elif cmd == "/models":
|
||||||
print_models_table(client.client)
|
print_models_table(client.client)
|
||||||
|
|
||||||
|
elif cmd == "/listtools":
|
||||||
|
print_tools_table()
|
||||||
|
|
||||||
elif cmd == "/model":
|
elif cmd == "/model":
|
||||||
if len(cmd_parts) < 2:
|
if len(cmd_parts) < 2:
|
||||||
console.print("[red]Usage: /model <model-name>[/red]")
|
console.print("[red]Usage: /model <model-name>[/red]")
|
||||||
@@ -220,6 +462,10 @@ def main():
|
|||||||
stream_enabled = not stream_enabled
|
stream_enabled = not stream_enabled
|
||||||
console.print(f"[green]Streaming {'enabled' if stream_enabled else 'disabled'}[/green]")
|
console.print(f"[green]Streaming {'enabled' if stream_enabled else 'disabled'}[/green]")
|
||||||
|
|
||||||
|
elif cmd == "/tools":
|
||||||
|
client.tools_enabled = not client.tools_enabled
|
||||||
|
console.print(f"[green]Tools {'enabled' if client.tools_enabled else 'disabled'}[/green]")
|
||||||
|
|
||||||
elif cmd == "/clear":
|
elif cmd == "/clear":
|
||||||
client.clear_history()
|
client.clear_history()
|
||||||
console.print("[green]Conversation history cleared[/green]")
|
console.print("[green]Conversation history cleared[/green]")
|
||||||
|
|||||||
Reference in New Issue
Block a user