Add Google tool calling

This commit is contained in:
2026-03-02 16:12:46 +00:00
parent 38d44f104a
commit 6adf7eae54
2 changed files with 332 additions and 19 deletions

View File

@@ -0,0 +1,212 @@
package google
import (
"encoding/json"
"fmt"
"math/rand"
"time"
"google.golang.org/genai"
"github.com/ajac-zero/latticelm/internal/api"
)
// parseTools converts generic tool definitions from req.Tools (JSON) to Google's []*genai.Tool format.
func parseTools(req *api.ResponseRequest) ([]*genai.Tool, error) {
if req.Tools == nil || len(req.Tools) == 0 {
return nil, nil
}
// Unmarshal to slice of tool definitions
var toolDefs []map[string]interface{}
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
return nil, fmt.Errorf("unmarshal tools: %w", err)
}
var functionDeclarations []*genai.FunctionDeclaration
for _, toolDef := range toolDefs {
// Extract function details
// Support both flat format (name/description/parameters at top level)
// and nested format (under "function" key)
var name, description string
var parameters interface{}
if functionData, ok := toolDef["function"].(map[string]interface{}); ok {
// Nested format: {"type": "function", "function": {...}}
name, _ = functionData["name"].(string)
description, _ = functionData["description"].(string)
parameters = functionData["parameters"]
} else {
// Flat format: {"type": "function", "name": "...", ...}
name, _ = toolDef["name"].(string)
description, _ = toolDef["description"].(string)
parameters = toolDef["parameters"]
}
if name == "" {
continue
}
// Create function declaration
funcDecl := &genai.FunctionDeclaration{
Name: name,
Description: description,
}
// Google accepts parameters as raw JSON schema
if parameters != nil {
funcDecl.ParametersJsonSchema = parameters
}
functionDeclarations = append(functionDeclarations, funcDecl)
}
// Return single Tool with all function declarations
if len(functionDeclarations) > 0 {
return []*genai.Tool{{FunctionDeclarations: functionDeclarations}}, nil
}
return nil, nil
}
// parseToolChoice converts req.ToolChoice to Google's ToolConfig with FunctionCallingConfig.
func parseToolChoice(req *api.ResponseRequest) (*genai.ToolConfig, error) {
if req.ToolChoice == nil || len(req.ToolChoice) == 0 {
return nil, nil
}
var choice interface{}
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
return nil, fmt.Errorf("unmarshal tool_choice: %w", err)
}
config := &genai.ToolConfig{
FunctionCallingConfig: &genai.FunctionCallingConfig{},
}
// Handle string values: "auto", "none", "required"/"any"
if str, ok := choice.(string); ok {
switch str {
case "auto":
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAuto
case "none":
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeNone
case "required", "any":
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny
default:
return nil, fmt.Errorf("unknown tool_choice string: %s", str)
}
return config, nil
}
// Handle object format: {"type": "function", "function": {"name": "..."}}
if obj, ok := choice.(map[string]interface{}); ok {
if typeVal, ok := obj["type"].(string); ok && typeVal == "function" {
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny
if funcObj, ok := obj["function"].(map[string]interface{}); ok {
if name, ok := funcObj["name"].(string); ok {
config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
}
}
return config, nil
}
}
return nil, fmt.Errorf("unsupported tool_choice format")
}
// extractToolCalls extracts tool calls from Google's response format to generic api.ToolCall slice.
func extractToolCalls(resp *genai.GenerateContentResponse) []api.ToolCall {
var toolCalls []api.ToolCall
for _, candidate := range resp.Candidates {
if candidate.Content == nil {
continue
}
for _, part := range candidate.Content.Parts {
if part == nil || part.FunctionCall == nil {
continue
}
// Extract function call details
fc := part.FunctionCall
// Marshal arguments to JSON string
var argsJSON string
if fc.Args != nil {
argsBytes, err := json.Marshal(fc.Args)
if err == nil {
argsJSON = string(argsBytes)
} else {
// Fallback to empty object
argsJSON = "{}"
}
} else {
argsJSON = "{}"
}
// Generate ID if Google doesn't provide one
callID := fc.ID
if callID == "" {
callID = fmt.Sprintf("call_%s", generateRandomID())
}
toolCalls = append(toolCalls, api.ToolCall{
ID: callID,
Name: fc.Name,
Arguments: argsJSON,
})
}
}
return toolCalls
}
// extractToolCallDelta extracts streaming tool call information from response parts.
func extractToolCallDelta(part *genai.Part, index int) *api.ToolCallDelta {
if part == nil || part.FunctionCall == nil {
return nil
}
fc := part.FunctionCall
// Marshal arguments to JSON string
var argsJSON string
if fc.Args != nil {
argsBytes, err := json.Marshal(fc.Args)
if err == nil {
argsJSON = string(argsBytes)
} else {
argsJSON = "{}"
}
} else {
argsJSON = "{}"
}
// Generate ID if Google doesn't provide one
callID := fc.ID
if callID == "" {
callID = fmt.Sprintf("call_%s", generateRandomID())
}
return &api.ToolCallDelta{
Index: index,
ID: callID,
Name: fc.Name,
Arguments: argsJSON,
}
}
// generateRandomID generates a random alphanumeric ID
func generateRandomID() string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
const length = 24
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, length)
for i := range b {
b[i] = charset[rng.Intn(len(charset))]
}
return string(b)
}

View File

@@ -2,6 +2,7 @@ package google
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
@@ -76,7 +77,27 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
contents, systemText := convertMessages(messages) contents, systemText := convertMessages(messages)
config := buildConfig(systemText, req) // Parse tools if present
var tools []*genai.Tool
if req.Tools != nil && len(req.Tools) > 0 {
var err error
tools, err = parseTools(req)
if err != nil {
return nil, fmt.Errorf("parse tools: %w", err)
}
}
// Parse tool_choice if present
var toolConfig *genai.ToolConfig
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
var err error
toolConfig, err = parseToolChoice(req)
if err != nil {
return nil, fmt.Errorf("parse tool_choice: %w", err)
}
}
config := buildConfig(systemText, req, tools, toolConfig)
resp, err := p.client.Models.GenerateContent(ctx, model, contents, config) resp, err := p.client.Models.GenerateContent(ctx, model, contents, config)
if err != nil { if err != nil {
@@ -92,6 +113,11 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
} }
} }
var toolCalls []api.ToolCall
if len(resp.Candidates) > 0 {
toolCalls = extractToolCalls(resp)
}
var inputTokens, outputTokens int var inputTokens, outputTokens int
if resp.UsageMetadata != nil { if resp.UsageMetadata != nil {
inputTokens = int(resp.UsageMetadata.PromptTokenCount) inputTokens = int(resp.UsageMetadata.PromptTokenCount)
@@ -102,6 +128,7 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
ID: uuid.NewString(), ID: uuid.NewString(),
Model: model, Model: model,
Text: text, Text: text,
ToolCalls: toolCalls,
Usage: api.Usage{ Usage: api.Usage{
InputTokens: inputTokens, InputTokens: inputTokens,
OutputTokens: outputTokens, OutputTokens: outputTokens,
@@ -128,7 +155,29 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
contents, systemText := convertMessages(messages) contents, systemText := convertMessages(messages)
config := buildConfig(systemText, req) // Parse tools if present
var tools []*genai.Tool
if req.Tools != nil && len(req.Tools) > 0 {
var err error
tools, err = parseTools(req)
if err != nil {
errChan <- fmt.Errorf("parse tools: %w", err)
return
}
}
// Parse tool_choice if present
var toolConfig *genai.ToolConfig
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
var err error
toolConfig, err = parseToolChoice(req)
if err != nil {
errChan <- fmt.Errorf("parse tool_choice: %w", err)
return
}
}
config := buildConfig(systemText, req, tools, toolConfig)
stream := p.client.Models.GenerateContentStream(ctx, model, contents, config) stream := p.client.Models.GenerateContentStream(ctx, model, contents, config)
@@ -138,23 +187,34 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
return return
} }
var text string
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts { for partIndex, part := range resp.Candidates[0].Content.Parts {
if part != nil { if part != nil {
text += part.Text // Handle text content
} if part.Text != "" {
}
}
if text != "" {
select { select {
case deltaChan <- &api.ProviderStreamDelta{Text: text}: case deltaChan <- &api.ProviderStreamDelta{Text: part.Text}:
case <-ctx.Done(): case <-ctx.Done():
errChan <- ctx.Err() errChan <- ctx.Err()
return return
} }
} }
// Handle tool call content
if part.FunctionCall != nil {
delta := extractToolCallDelta(part, partIndex)
if delta != nil {
select {
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
}
}
}
}
}
} }
select { select {
@@ -182,6 +242,39 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
continue continue
} }
if msg.Role == "tool" {
// Tool results are sent as FunctionResponse in user role message
var output string
for _, block := range msg.Content {
if block.Type == "input_text" || block.Type == "output_text" {
output += block.Text
}
}
// Parse output as JSON map, or wrap in {"output": "..."} if not JSON
var responseMap map[string]any
if err := json.Unmarshal([]byte(output), &responseMap); err != nil {
// Not JSON, wrap it
responseMap = map[string]any{"output": output}
}
// Create FunctionResponse part with CallID from message
part := &genai.Part{
FunctionResponse: &genai.FunctionResponse{
ID: msg.CallID,
Name: "", // Name is optional for responses
Response: responseMap,
},
}
// Add to user role message
contents = append(contents, &genai.Content{
Role: "user",
Parts: []*genai.Part{part},
})
continue
}
var parts []*genai.Part var parts []*genai.Part
for _, block := range msg.Content { for _, block := range msg.Content {
if block.Type == "input_text" || block.Type == "output_text" { if block.Type == "input_text" || block.Type == "output_text" {
@@ -204,10 +297,10 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
} }
// buildConfig constructs a GenerateContentConfig from system text and request params. // buildConfig constructs a GenerateContentConfig from system text and request params.
func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateContentConfig { func buildConfig(systemText string, req *api.ResponseRequest, tools []*genai.Tool, toolConfig *genai.ToolConfig) *genai.GenerateContentConfig {
var cfg *genai.GenerateContentConfig var cfg *genai.GenerateContentConfig
needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil || tools != nil || toolConfig != nil
if !needsCfg { if !needsCfg {
return nil return nil
} }
@@ -234,6 +327,14 @@ func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateCon
cfg.TopP = &tp cfg.TopP = &tp
} }
if tools != nil {
cfg.Tools = tools
}
if toolConfig != nil {
cfg.ToolConfig = toolConfig
}
return cfg return cfg
} }