Add Google tool calling
This commit is contained in:
212
internal/providers/google/convert.go
Normal file
212
internal/providers/google/convert.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package google
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"math/rand"
	"sync"
	"time"

	"google.golang.org/genai"

	"github.com/ajac-zero/latticelm/internal/api"
)
|
||||
|
||||
// parseTools converts generic tool definitions from req.Tools (JSON) to Google's []*genai.Tool format.
|
||||
func parseTools(req *api.ResponseRequest) ([]*genai.Tool, error) {
|
||||
if req.Tools == nil || len(req.Tools) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Unmarshal to slice of tool definitions
|
||||
var toolDefs []map[string]interface{}
|
||||
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tools: %w", err)
|
||||
}
|
||||
|
||||
var functionDeclarations []*genai.FunctionDeclaration
|
||||
|
||||
for _, toolDef := range toolDefs {
|
||||
// Extract function details
|
||||
// Support both flat format (name/description/parameters at top level)
|
||||
// and nested format (under "function" key)
|
||||
var name, description string
|
||||
var parameters interface{}
|
||||
|
||||
if functionData, ok := toolDef["function"].(map[string]interface{}); ok {
|
||||
// Nested format: {"type": "function", "function": {...}}
|
||||
name, _ = functionData["name"].(string)
|
||||
description, _ = functionData["description"].(string)
|
||||
parameters = functionData["parameters"]
|
||||
} else {
|
||||
// Flat format: {"type": "function", "name": "...", ...}
|
||||
name, _ = toolDef["name"].(string)
|
||||
description, _ = toolDef["description"].(string)
|
||||
parameters = toolDef["parameters"]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create function declaration
|
||||
funcDecl := &genai.FunctionDeclaration{
|
||||
Name: name,
|
||||
Description: description,
|
||||
}
|
||||
|
||||
// Google accepts parameters as raw JSON schema
|
||||
if parameters != nil {
|
||||
funcDecl.ParametersJsonSchema = parameters
|
||||
}
|
||||
|
||||
functionDeclarations = append(functionDeclarations, funcDecl)
|
||||
}
|
||||
|
||||
// Return single Tool with all function declarations
|
||||
if len(functionDeclarations) > 0 {
|
||||
return []*genai.Tool{{FunctionDeclarations: functionDeclarations}}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// parseToolChoice converts req.ToolChoice to Google's ToolConfig with FunctionCallingConfig.
|
||||
func parseToolChoice(req *api.ResponseRequest) (*genai.ToolConfig, error) {
|
||||
if req.ToolChoice == nil || len(req.ToolChoice) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var choice interface{}
|
||||
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tool_choice: %w", err)
|
||||
}
|
||||
|
||||
config := &genai.ToolConfig{
|
||||
FunctionCallingConfig: &genai.FunctionCallingConfig{},
|
||||
}
|
||||
|
||||
// Handle string values: "auto", "none", "required"/"any"
|
||||
if str, ok := choice.(string); ok {
|
||||
switch str {
|
||||
case "auto":
|
||||
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAuto
|
||||
case "none":
|
||||
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeNone
|
||||
case "required", "any":
|
||||
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown tool_choice string: %s", str)
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Handle object format: {"type": "function", "function": {"name": "..."}}
|
||||
if obj, ok := choice.(map[string]interface{}); ok {
|
||||
if typeVal, ok := obj["type"].(string); ok && typeVal == "function" {
|
||||
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny
|
||||
if funcObj, ok := obj["function"].(map[string]interface{}); ok {
|
||||
if name, ok := funcObj["name"].(string); ok {
|
||||
config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
|
||||
}
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported tool_choice format")
|
||||
}
|
||||
|
||||
// extractToolCalls extracts tool calls from Google's response format to generic api.ToolCall slice.
|
||||
func extractToolCalls(resp *genai.GenerateContentResponse) []api.ToolCall {
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, candidate := range resp.Candidates {
|
||||
if candidate.Content == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part == nil || part.FunctionCall == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract function call details
|
||||
fc := part.FunctionCall
|
||||
|
||||
// Marshal arguments to JSON string
|
||||
var argsJSON string
|
||||
if fc.Args != nil {
|
||||
argsBytes, err := json.Marshal(fc.Args)
|
||||
if err == nil {
|
||||
argsJSON = string(argsBytes)
|
||||
} else {
|
||||
// Fallback to empty object
|
||||
argsJSON = "{}"
|
||||
}
|
||||
} else {
|
||||
argsJSON = "{}"
|
||||
}
|
||||
|
||||
// Generate ID if Google doesn't provide one
|
||||
callID := fc.ID
|
||||
if callID == "" {
|
||||
callID = fmt.Sprintf("call_%s", generateRandomID())
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: callID,
|
||||
Name: fc.Name,
|
||||
Arguments: argsJSON,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// extractToolCallDelta extracts streaming tool call information from response parts.
|
||||
func extractToolCallDelta(part *genai.Part, index int) *api.ToolCallDelta {
|
||||
if part == nil || part.FunctionCall == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fc := part.FunctionCall
|
||||
|
||||
// Marshal arguments to JSON string
|
||||
var argsJSON string
|
||||
if fc.Args != nil {
|
||||
argsBytes, err := json.Marshal(fc.Args)
|
||||
if err == nil {
|
||||
argsJSON = string(argsBytes)
|
||||
} else {
|
||||
argsJSON = "{}"
|
||||
}
|
||||
} else {
|
||||
argsJSON = "{}"
|
||||
}
|
||||
|
||||
// Generate ID if Google doesn't provide one
|
||||
callID := fc.ID
|
||||
if callID == "" {
|
||||
callID = fmt.Sprintf("call_%s", generateRandomID())
|
||||
}
|
||||
|
||||
return &api.ToolCallDelta{
|
||||
Index: index,
|
||||
ID: callID,
|
||||
Name: fc.Name,
|
||||
Arguments: argsJSON,
|
||||
}
|
||||
}
|
||||
|
||||
// idRand is the shared generator behind generateRandomID. It is seeded once,
// at package initialisation. idRandMu serialises access to it because a
// *rand.Rand created with rand.New is not safe for concurrent use.
var (
	idRandMu sync.Mutex
	idRand   = rand.New(rand.NewSource(time.Now().UnixNano()))
)

// generateRandomID returns a 24-character ID made of lowercase ASCII letters
// and digits.
//
// All calls draw from one generator seeded a single time. Seeding a fresh
// generator from the wall clock on every call would make two calls that
// observe the same timestamp return the same ID, and would repeat the
// seeding work on each call. The IDs are not cryptographically secure; they
// are only meant to tell tool calls apart.
func generateRandomID() string {
	const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
	const length = 24

	b := make([]byte, length)

	idRandMu.Lock()
	defer idRandMu.Unlock()
	for i := range b {
		b[i] = charset[idRand.Intn(len(charset))]
	}
	return string(b)
}
|
||||
@@ -2,6 +2,7 @@ package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -76,7 +77,27 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
|
||||
contents, systemText := convertMessages(messages)
|
||||
|
||||
config := buildConfig(systemText, req)
|
||||
// Parse tools if present
|
||||
var tools []*genai.Tool
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
var err error
|
||||
tools, err = parseTools(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tools: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse tool_choice if present
|
||||
var toolConfig *genai.ToolConfig
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
var err error
|
||||
toolConfig, err = parseToolChoice(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tool_choice: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
config := buildConfig(systemText, req, tools, toolConfig)
|
||||
|
||||
resp, err := p.client.Models.GenerateContent(ctx, model, contents, config)
|
||||
if err != nil {
|
||||
@@ -92,6 +113,11 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
}
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
if len(resp.Candidates) > 0 {
|
||||
toolCalls = extractToolCalls(resp)
|
||||
}
|
||||
|
||||
var inputTokens, outputTokens int
|
||||
if resp.UsageMetadata != nil {
|
||||
inputTokens = int(resp.UsageMetadata.PromptTokenCount)
|
||||
@@ -102,6 +128,7 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
ID: uuid.NewString(),
|
||||
Model: model,
|
||||
Text: text,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: api.Usage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
@@ -128,7 +155,29 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
|
||||
contents, systemText := convertMessages(messages)
|
||||
|
||||
config := buildConfig(systemText, req)
|
||||
// Parse tools if present
|
||||
var tools []*genai.Tool
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
var err error
|
||||
tools, err = parseTools(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tools: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Parse tool_choice if present
|
||||
var toolConfig *genai.ToolConfig
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
var err error
|
||||
toolConfig, err = parseToolChoice(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tool_choice: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
config := buildConfig(systemText, req, tools, toolConfig)
|
||||
|
||||
stream := p.client.Models.GenerateContentStream(ctx, model, contents, config)
|
||||
|
||||
@@ -138,23 +187,34 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
return
|
||||
}
|
||||
|
||||
var text string
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
for partIndex, part := range resp.Candidates[0].Content.Parts {
|
||||
if part != nil {
|
||||
text += part.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if text != "" {
|
||||
// Handle text content
|
||||
if part.Text != "" {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{Text: text}:
|
||||
case deltaChan <- &api.ProviderStreamDelta{Text: part.Text}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool call content
|
||||
if part.FunctionCall != nil {
|
||||
delta := extractToolCallDelta(part, partIndex)
|
||||
if delta != nil {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -182,6 +242,39 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.Role == "tool" {
|
||||
// Tool results are sent as FunctionResponse in user role message
|
||||
var output string
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "input_text" || block.Type == "output_text" {
|
||||
output += block.Text
|
||||
}
|
||||
}
|
||||
|
||||
// Parse output as JSON map, or wrap in {"output": "..."} if not JSON
|
||||
var responseMap map[string]any
|
||||
if err := json.Unmarshal([]byte(output), &responseMap); err != nil {
|
||||
// Not JSON, wrap it
|
||||
responseMap = map[string]any{"output": output}
|
||||
}
|
||||
|
||||
// Create FunctionResponse part with CallID from message
|
||||
part := &genai.Part{
|
||||
FunctionResponse: &genai.FunctionResponse{
|
||||
ID: msg.CallID,
|
||||
Name: "", // Name is optional for responses
|
||||
Response: responseMap,
|
||||
},
|
||||
}
|
||||
|
||||
// Add to user role message
|
||||
contents = append(contents, &genai.Content{
|
||||
Role: "user",
|
||||
Parts: []*genai.Part{part},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []*genai.Part
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "input_text" || block.Type == "output_text" {
|
||||
@@ -204,10 +297,10 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
||||
}
|
||||
|
||||
// buildConfig constructs a GenerateContentConfig from system text and request params.
|
||||
func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateContentConfig {
|
||||
func buildConfig(systemText string, req *api.ResponseRequest, tools []*genai.Tool, toolConfig *genai.ToolConfig) *genai.GenerateContentConfig {
|
||||
var cfg *genai.GenerateContentConfig
|
||||
|
||||
needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil
|
||||
needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil || tools != nil || toolConfig != nil
|
||||
if !needsCfg {
|
||||
return nil
|
||||
}
|
||||
@@ -234,6 +327,14 @@ func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateCon
|
||||
cfg.TopP = &tp
|
||||
}
|
||||
|
||||
if tools != nil {
|
||||
cfg.Tools = tools
|
||||
}
|
||||
|
||||
if toolConfig != nil {
|
||||
cfg.ToolConfig = toolConfig
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user