Add OpenAI tool calling support
This commit is contained in:
120
internal/providers/openai/convert.go
Normal file
120
internal/providers/openai/convert.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/shared"
|
||||
)
|
||||
|
||||
// parseTools converts Open Responses tools to OpenAI format
|
||||
func parseTools(req *api.ResponseRequest) ([]openai.ChatCompletionToolParam, error) {
|
||||
if req.Tools == nil || len(req.Tools) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var toolDefs []map[string]interface{}
|
||||
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tools: %w", err)
|
||||
}
|
||||
|
||||
var tools []openai.ChatCompletionToolParam
|
||||
for _, td := range toolDefs {
|
||||
// Convert Open Responses tool to OpenAI ChatCompletionToolParam
|
||||
// Extract: name, description, parameters
|
||||
name, _ := td["name"].(string)
|
||||
desc, _ := td["description"].(string)
|
||||
params, _ := td["parameters"].(map[string]interface{})
|
||||
|
||||
tool := openai.ChatCompletionToolParam{
|
||||
Function: shared.FunctionDefinitionParam{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
|
||||
if desc != "" {
|
||||
tool.Function.Description = openai.String(desc)
|
||||
}
|
||||
|
||||
if params != nil {
|
||||
tool.Function.Parameters = shared.FunctionParameters(params)
|
||||
}
|
||||
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
// parseToolChoice converts Open Responses tool_choice to OpenAI format
|
||||
func parseToolChoice(req *api.ResponseRequest) (openai.ChatCompletionToolChoiceOptionUnionParam, error) {
|
||||
var result openai.ChatCompletionToolChoiceOptionUnionParam
|
||||
|
||||
if req.ToolChoice == nil || len(req.ToolChoice) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var choice interface{}
|
||||
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
|
||||
return result, fmt.Errorf("unmarshal tool_choice: %w", err)
|
||||
}
|
||||
|
||||
// Handle string values: "auto", "none", "required"
|
||||
if str, ok := choice.(string); ok {
|
||||
result.OfAuto = openai.String(str)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Handle specific function selection: {"type": "function", "name": "..."}
|
||||
if obj, ok := choice.(map[string]interface{}); ok {
|
||||
funcObj, _ := obj["function"].(map[string]interface{})
|
||||
name, _ := funcObj["name"].(string)
|
||||
|
||||
result.OfChatCompletionNamedToolChoice = &openai.ChatCompletionNamedToolChoiceParam{
|
||||
Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
return result, fmt.Errorf("invalid tool_choice format")
|
||||
}
|
||||
|
||||
// extractToolCalls converts OpenAI tool calls to api.ToolCall
|
||||
func extractToolCalls(message openai.ChatCompletionMessage) []api.ToolCall {
|
||||
if len(message.ToolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
for _, tc := range message.ToolCalls {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// extractToolCallDelta extracts tool call delta from streaming chunk choice
|
||||
func extractToolCallDelta(choice openai.ChatCompletionChunkChoice) *api.ToolCallDelta {
|
||||
if len(choice.Delta.ToolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenAI sends tool calls with index in the delta
|
||||
for _, tc := range choice.Delta.ToolCalls {
|
||||
return &api.ToolCallDelta{
|
||||
Index: int(tc.Index),
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -91,6 +91,8 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||
case "developer":
|
||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||
case "tool":
|
||||
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,6 +110,29 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
params.TopP = openai.Float(*req.TopP)
|
||||
}
|
||||
|
||||
// Add tools if present
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
tools, err := parseTools(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tools: %w", err)
|
||||
}
|
||||
params.Tools = tools
|
||||
}
|
||||
|
||||
// Add tool_choice if present
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
toolChoice, err := parseToolChoice(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tool_choice: %w", err)
|
||||
}
|
||||
params.ToolChoice = toolChoice
|
||||
}
|
||||
|
||||
// Add parallel_tool_calls if specified
|
||||
if req.ParallelToolCalls != nil {
|
||||
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
|
||||
}
|
||||
|
||||
// Call OpenAI API
|
||||
resp, err := p.client.Chat.Completions.New(ctx, params)
|
||||
if err != nil {
|
||||
@@ -115,14 +140,20 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
}
|
||||
|
||||
var combinedText string
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, choice := range resp.Choices {
|
||||
combinedText += choice.Message.Content
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
toolCalls = append(toolCalls, extractToolCalls(choice.Message)...)
|
||||
}
|
||||
}
|
||||
|
||||
return &api.ProviderResult{
|
||||
ID: resp.ID,
|
||||
Model: resp.Model,
|
||||
Text: combinedText,
|
||||
ID: resp.ID,
|
||||
Model: resp.Model,
|
||||
Text: combinedText,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: api.Usage{
|
||||
InputTokens: int(resp.Usage.PromptTokens),
|
||||
OutputTokens: int(resp.Usage.CompletionTokens),
|
||||
@@ -168,6 +199,8 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||
case "developer":
|
||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||
case "tool":
|
||||
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,6 +218,31 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
params.TopP = openai.Float(*req.TopP)
|
||||
}
|
||||
|
||||
// Add tools if present
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
tools, err := parseTools(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tools: %w", err)
|
||||
return
|
||||
}
|
||||
params.Tools = tools
|
||||
}
|
||||
|
||||
// Add tool_choice if present
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
toolChoice, err := parseToolChoice(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tool_choice: %w", err)
|
||||
return
|
||||
}
|
||||
params.ToolChoice = toolChoice
|
||||
}
|
||||
|
||||
// Add parallel_tool_calls if specified
|
||||
if req.ParallelToolCalls != nil {
|
||||
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
|
||||
}
|
||||
|
||||
// Create streaming request
|
||||
stream := p.client.Chat.Completions.NewStreaming(ctx, params)
|
||||
|
||||
@@ -193,19 +251,35 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
chunk := stream.Current()
|
||||
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content == "" {
|
||||
continue
|
||||
// Handle text content
|
||||
if choice.Delta.Content != "" {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{
|
||||
ID: chunk.ID,
|
||||
Model: chunk.Model,
|
||||
Text: choice.Delta.Content,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{
|
||||
ID: chunk.ID,
|
||||
Model: chunk.Model,
|
||||
Text: choice.Delta.Content,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
// Handle tool call deltas
|
||||
if len(choice.Delta.ToolCalls) > 0 {
|
||||
delta := extractToolCallDelta(choice)
|
||||
if delta != nil {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{
|
||||
ID: chunk.ID,
|
||||
Model: chunk.Model,
|
||||
ToolCallDelta: delta,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user