Add Google tool calling
This commit is contained in:
@@ -2,6 +2,7 @@ package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -76,7 +77,27 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
|
||||
contents, systemText := convertMessages(messages)
|
||||
|
||||
config := buildConfig(systemText, req)
|
||||
// Parse tools if present
|
||||
var tools []*genai.Tool
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
var err error
|
||||
tools, err = parseTools(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tools: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse tool_choice if present
|
||||
var toolConfig *genai.ToolConfig
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
var err error
|
||||
toolConfig, err = parseToolChoice(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tool_choice: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
config := buildConfig(systemText, req, tools, toolConfig)
|
||||
|
||||
resp, err := p.client.Models.GenerateContent(ctx, model, contents, config)
|
||||
if err != nil {
|
||||
@@ -92,6 +113,11 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
}
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
if len(resp.Candidates) > 0 {
|
||||
toolCalls = extractToolCalls(resp)
|
||||
}
|
||||
|
||||
var inputTokens, outputTokens int
|
||||
if resp.UsageMetadata != nil {
|
||||
inputTokens = int(resp.UsageMetadata.PromptTokenCount)
|
||||
@@ -99,9 +125,10 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
}
|
||||
|
||||
return &api.ProviderResult{
|
||||
ID: uuid.NewString(),
|
||||
Model: model,
|
||||
Text: text,
|
||||
ID: uuid.NewString(),
|
||||
Model: model,
|
||||
Text: text,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: api.Usage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
@@ -128,7 +155,29 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
|
||||
contents, systemText := convertMessages(messages)
|
||||
|
||||
config := buildConfig(systemText, req)
|
||||
// Parse tools if present
|
||||
var tools []*genai.Tool
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
var err error
|
||||
tools, err = parseTools(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tools: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Parse tool_choice if present
|
||||
var toolConfig *genai.ToolConfig
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
var err error
|
||||
toolConfig, err = parseToolChoice(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tool_choice: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
config := buildConfig(systemText, req, tools, toolConfig)
|
||||
|
||||
stream := p.client.Models.GenerateContentStream(ctx, model, contents, config)
|
||||
|
||||
@@ -138,21 +187,32 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
return
|
||||
}
|
||||
|
||||
var text string
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
for partIndex, part := range resp.Candidates[0].Content.Parts {
|
||||
if part != nil {
|
||||
text += part.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle text content
|
||||
if part.Text != "" {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{Text: part.Text}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if text != "" {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{Text: text}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
// Handle tool call content
|
||||
if part.FunctionCall != nil {
|
||||
delta := extractToolCallDelta(part, partIndex)
|
||||
if delta != nil {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -182,6 +242,39 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.Role == "tool" {
|
||||
// Tool results are sent as FunctionResponse in user role message
|
||||
var output string
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "input_text" || block.Type == "output_text" {
|
||||
output += block.Text
|
||||
}
|
||||
}
|
||||
|
||||
// Parse output as JSON map, or wrap in {"output": "..."} if not JSON
|
||||
var responseMap map[string]any
|
||||
if err := json.Unmarshal([]byte(output), &responseMap); err != nil {
|
||||
// Not JSON, wrap it
|
||||
responseMap = map[string]any{"output": output}
|
||||
}
|
||||
|
||||
// Create FunctionResponse part with CallID from message
|
||||
part := &genai.Part{
|
||||
FunctionResponse: &genai.FunctionResponse{
|
||||
ID: msg.CallID,
|
||||
Name: "", // Name is optional for responses
|
||||
Response: responseMap,
|
||||
},
|
||||
}
|
||||
|
||||
// Add to user role message
|
||||
contents = append(contents, &genai.Content{
|
||||
Role: "user",
|
||||
Parts: []*genai.Part{part},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []*genai.Part
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "input_text" || block.Type == "output_text" {
|
||||
@@ -204,10 +297,10 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
||||
}
|
||||
|
||||
// buildConfig constructs a GenerateContentConfig from system text and request params.
|
||||
func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateContentConfig {
|
||||
func buildConfig(systemText string, req *api.ResponseRequest, tools []*genai.Tool, toolConfig *genai.ToolConfig) *genai.GenerateContentConfig {
|
||||
var cfg *genai.GenerateContentConfig
|
||||
|
||||
needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil
|
||||
needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil || tools != nil || toolConfig != nil
|
||||
if !needsCfg {
|
||||
return nil
|
||||
}
|
||||
@@ -234,6 +327,14 @@ func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateCon
|
||||
cfg.TopP = &tp
|
||||
}
|
||||
|
||||
if tools != nil {
|
||||
cfg.Tools = tools
|
||||
}
|
||||
|
||||
if toolConfig != nil {
|
||||
cfg.ToolConfig = toolConfig
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user