diff --git a/internal/providers/google/convert.go b/internal/providers/google/convert.go new file mode 100644 index 0000000..1aa0dce --- /dev/null +++ b/internal/providers/google/convert.go @@ -0,0 +1,212 @@ +package google + +import ( + "encoding/json" + "fmt" + "math/rand" + "time" + + "google.golang.org/genai" + + "github.com/ajac-zero/latticelm/internal/api" +) + +// parseTools converts generic tool definitions from req.Tools (JSON) to Google's []*genai.Tool format. +func parseTools(req *api.ResponseRequest) ([]*genai.Tool, error) { + if req.Tools == nil || len(req.Tools) == 0 { + return nil, nil + } + + // Unmarshal to slice of tool definitions + var toolDefs []map[string]interface{} + if err := json.Unmarshal(req.Tools, &toolDefs); err != nil { + return nil, fmt.Errorf("unmarshal tools: %w", err) + } + + var functionDeclarations []*genai.FunctionDeclaration + + for _, toolDef := range toolDefs { + // Extract function details + // Support both flat format (name/description/parameters at top level) + // and nested format (under "function" key) + var name, description string + var parameters interface{} + + if functionData, ok := toolDef["function"].(map[string]interface{}); ok { + // Nested format: {"type": "function", "function": {...}} + name, _ = functionData["name"].(string) + description, _ = functionData["description"].(string) + parameters = functionData["parameters"] + } else { + // Flat format: {"type": "function", "name": "...", ...} + name, _ = toolDef["name"].(string) + description, _ = toolDef["description"].(string) + parameters = toolDef["parameters"] + } + + if name == "" { + continue + } + + // Create function declaration + funcDecl := &genai.FunctionDeclaration{ + Name: name, + Description: description, + } + + // Google accepts parameters as raw JSON schema + if parameters != nil { + funcDecl.ParametersJsonSchema = parameters + } + + functionDeclarations = append(functionDeclarations, funcDecl) + } + + // Return single Tool with all function declarations + if len(functionDeclarations) > 0 { + return []*genai.Tool{{FunctionDeclarations: functionDeclarations}}, nil + } + + return nil, nil +} + +// parseToolChoice converts req.ToolChoice to Google's ToolConfig with FunctionCallingConfig. +func parseToolChoice(req *api.ResponseRequest) (*genai.ToolConfig, error) { + if req.ToolChoice == nil || len(req.ToolChoice) == 0 { + return nil, nil + } + + var choice interface{} + if err := json.Unmarshal(req.ToolChoice, &choice); err != nil { + return nil, fmt.Errorf("unmarshal tool_choice: %w", err) + } + + config := &genai.ToolConfig{ + FunctionCallingConfig: &genai.FunctionCallingConfig{}, + } + + // Handle string values: "auto", "none", "required"/"any" + if str, ok := choice.(string); ok { + switch str { + case "auto": + config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAuto + case "none": + config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeNone + case "required", "any": + config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny + default: + return nil, fmt.Errorf("unknown tool_choice string: %s", str) + } + return config, nil + } + + // Handle object format: {"type": "function", "function": {"name": "..."}} + if obj, ok := choice.(map[string]interface{}); ok { + if typeVal, ok := obj["type"].(string); ok && typeVal == "function" { + config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny + if funcObj, ok := obj["function"].(map[string]interface{}); ok { + if name, ok := funcObj["name"].(string); ok { + config.FunctionCallingConfig.AllowedFunctionNames = []string{name} + } + } + return config, nil + } + } + + return nil, fmt.Errorf("unsupported tool_choice format") +} + +// extractToolCalls extracts tool calls from Google's response format to generic api.ToolCall slice. +func extractToolCalls(resp *genai.GenerateContentResponse) []api.ToolCall { + var toolCalls []api.ToolCall + + for _, candidate := range resp.Candidates { + if candidate.Content == nil { + continue + } + + for _, part := range candidate.Content.Parts { + if part == nil || part.FunctionCall == nil { + continue + } + + // Extract function call details + fc := part.FunctionCall + + // Marshal arguments to JSON string + var argsJSON string + if fc.Args != nil { + argsBytes, err := json.Marshal(fc.Args) + if err == nil { + argsJSON = string(argsBytes) + } else { + // Fallback to empty object + argsJSON = "{}" + } + } else { + argsJSON = "{}" + } + + // Generate ID if Google doesn't provide one + callID := fc.ID + if callID == "" { + callID = fmt.Sprintf("call_%s", generateRandomID()) + } + + toolCalls = append(toolCalls, api.ToolCall{ + ID: callID, + Name: fc.Name, + Arguments: argsJSON, + }) + } + } + + return toolCalls +} + +// extractToolCallDelta extracts streaming tool call information from response parts. +func extractToolCallDelta(part *genai.Part, index int) *api.ToolCallDelta { + if part == nil || part.FunctionCall == nil { + return nil + } + + fc := part.FunctionCall + + // Marshal arguments to JSON string + var argsJSON string + if fc.Args != nil { + argsBytes, err := json.Marshal(fc.Args) + if err == nil { + argsJSON = string(argsBytes) + } else { + argsJSON = "{}" + } + } else { + argsJSON = "{}" + } + + // Generate ID if Google doesn't provide one + callID := fc.ID + if callID == "" { + callID = fmt.Sprintf("call_%s", generateRandomID()) + } + + return &api.ToolCallDelta{ + Index: index, + ID: callID, + Name: fc.Name, + Arguments: argsJSON, + } +} + +// generateRandomID generates a random alphanumeric ID +func generateRandomID() string { + const charset = "abcdefghijklmnopqrstuvwxyz0123456789" + const length = 24 + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + b := make([]byte, length) + for i := range b { + b[i] = charset[rng.Intn(len(charset))] + } + return string(b) +} diff --git a/internal/providers/google/google.go b/internal/providers/google/google.go index 5be93f1..76423e3 100644 --- a/internal/providers/google/google.go +++ b/internal/providers/google/google.go @@ -2,6 +2,7 @@ package google import ( "context" + "encoding/json" "fmt" "github.com/google/uuid" @@ -76,7 +77,27 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap contents, systemText := convertMessages(messages) - config := buildConfig(systemText, req) + // Parse tools if present + var tools []*genai.Tool + if req.Tools != nil && len(req.Tools) > 0 { + var err error + tools, err = parseTools(req) + if err != nil { + return nil, fmt.Errorf("parse tools: %w", err) + } + } + + // Parse tool_choice if present + var toolConfig *genai.ToolConfig + if req.ToolChoice != nil && len(req.ToolChoice) > 0 { + var err error + toolConfig, err = parseToolChoice(req) + if err != nil { + return nil, fmt.Errorf("parse tool_choice: %w", err) + } + } + + config := buildConfig(systemText, req, tools, toolConfig) resp, err := p.client.Models.GenerateContent(ctx, model, contents, config) if err != nil { @@ -92,6 +113,11 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap } } + var toolCalls []api.ToolCall + if len(resp.Candidates) > 0 { + toolCalls = extractToolCalls(resp) + } + var inputTokens, outputTokens int if resp.UsageMetadata != nil { inputTokens = int(resp.UsageMetadata.PromptTokenCount) @@ -99,9 +125,10 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap } return &api.ProviderResult{ - ID: uuid.NewString(), - Model: model, - Text: text, + ID: uuid.NewString(), + Model: model, + Text: text, + ToolCalls: toolCalls, Usage: api.Usage{ InputTokens: inputTokens, OutputTokens: outputTokens, @@ -128,7 +155,29 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r contents, systemText := convertMessages(messages) - config := buildConfig(systemText, req) + // Parse tools if present + var tools []*genai.Tool + if req.Tools != nil && len(req.Tools) > 0 { + var err error + tools, err = parseTools(req) + if err != nil { + errChan <- fmt.Errorf("parse tools: %w", err) + return + } + } + + // Parse tool_choice if present + var toolConfig *genai.ToolConfig + if req.ToolChoice != nil && len(req.ToolChoice) > 0 { + var err error + toolConfig, err = parseToolChoice(req) + if err != nil { + errChan <- fmt.Errorf("parse tool_choice: %w", err) + return + } + } + + config := buildConfig(systemText, req, tools, toolConfig) stream := p.client.Models.GenerateContentStream(ctx, model, contents, config) @@ -138,21 +187,32 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r return } - var text string if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { + for partIndex, part := range resp.Candidates[0].Content.Parts { if part != nil { - text += part.Text - } - } - } + // Handle text content + if part.Text != "" { + select { + case deltaChan <- &api.ProviderStreamDelta{Text: part.Text}: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + } - if text != "" { - select { - case deltaChan <- &api.ProviderStreamDelta{Text: text}: - case <-ctx.Done(): - errChan <- ctx.Err() - return + // Handle tool call content + if part.FunctionCall != nil { + delta := extractToolCallDelta(part, partIndex) + if delta != nil { + select { + case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + } + } + } } } } @@ -182,6 +242,39 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) { continue } + if msg.Role == "tool" { + // Tool results are sent as FunctionResponse in user role message + var output string + for _, block := range msg.Content { + if block.Type == "input_text" || block.Type == "output_text" { + output += block.Text + } + } + + // Parse output as JSON map, or wrap in {"output": "..."} if not JSON + var responseMap map[string]any + if err := json.Unmarshal([]byte(output), &responseMap); err != nil { + // Not JSON, wrap it + responseMap = map[string]any{"output": output} + } + + // Create FunctionResponse part with CallID from message + part := &genai.Part{ + FunctionResponse: &genai.FunctionResponse{ + ID: msg.CallID, + Name: "", // Name is optional for responses + Response: responseMap, + }, + } + + // Add to user role message + contents = append(contents, &genai.Content{ + Role: "user", + Parts: []*genai.Part{part}, + }) + continue + } + var parts []*genai.Part for _, block := range msg.Content { if block.Type == "input_text" || block.Type == "output_text" { @@ -204,10 +297,10 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) { } // buildConfig constructs a GenerateContentConfig from system text and request params. -func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateContentConfig { +func buildConfig(systemText string, req *api.ResponseRequest, tools []*genai.Tool, toolConfig *genai.ToolConfig) *genai.GenerateContentConfig { var cfg *genai.GenerateContentConfig - needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil + needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil || tools != nil || toolConfig != nil if !needsCfg { return nil } @@ -234,6 +327,14 @@ func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateCon cfg.TopP = &tp } + if tools != nil { + cfg.Tools = tools + } + + if toolConfig != nil { + cfg.ToolConfig = toolConfig + } + return cfg }