Add OpenAI tool calling support

This commit is contained in:
2026-03-02 15:27:28 +00:00
parent 8ceb831e84
commit 157680bb13
4 changed files with 429 additions and 77 deletions

View File

@@ -91,6 +91,8 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "developer":
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "tool":
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))
}
}
@@ -108,6 +110,29 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
params.TopP = openai.Float(*req.TopP)
}
// Add tools if present
if req.Tools != nil && len(req.Tools) > 0 {
tools, err := parseTools(req)
if err != nil {
return nil, fmt.Errorf("parse tools: %w", err)
}
params.Tools = tools
}
// Add tool_choice if present
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
toolChoice, err := parseToolChoice(req)
if err != nil {
return nil, fmt.Errorf("parse tool_choice: %w", err)
}
params.ToolChoice = toolChoice
}
// Add parallel_tool_calls if specified
if req.ParallelToolCalls != nil {
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
}
// Call OpenAI API
resp, err := p.client.Chat.Completions.New(ctx, params)
if err != nil {
@@ -115,14 +140,20 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
}
var combinedText string
var toolCalls []api.ToolCall
for _, choice := range resp.Choices {
combinedText += choice.Message.Content
if len(choice.Message.ToolCalls) > 0 {
toolCalls = append(toolCalls, extractToolCalls(choice.Message)...)
}
}
return &api.ProviderResult{
ID: resp.ID,
Model: resp.Model,
Text: combinedText,
ID: resp.ID,
Model: resp.Model,
Text: combinedText,
ToolCalls: toolCalls,
Usage: api.Usage{
InputTokens: int(resp.Usage.PromptTokens),
OutputTokens: int(resp.Usage.CompletionTokens),
@@ -168,6 +199,8 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "developer":
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "tool":
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))
}
}
@@ -185,6 +218,31 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
params.TopP = openai.Float(*req.TopP)
}
// Add tools if present
if req.Tools != nil && len(req.Tools) > 0 {
tools, err := parseTools(req)
if err != nil {
errChan <- fmt.Errorf("parse tools: %w", err)
return
}
params.Tools = tools
}
// Add tool_choice if present
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
toolChoice, err := parseToolChoice(req)
if err != nil {
errChan <- fmt.Errorf("parse tool_choice: %w", err)
return
}
params.ToolChoice = toolChoice
}
// Add parallel_tool_calls if specified
if req.ParallelToolCalls != nil {
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
}
// Create streaming request
stream := p.client.Chat.Completions.NewStreaming(ctx, params)
@@ -193,19 +251,35 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
chunk := stream.Current()
for _, choice := range chunk.Choices {
if choice.Delta.Content == "" {
continue
// Handle text content
if choice.Delta.Content != "" {
select {
case deltaChan <- &api.ProviderStreamDelta{
ID: chunk.ID,
Model: chunk.Model,
Text: choice.Delta.Content,
}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
}
select {
case deltaChan <- &api.ProviderStreamDelta{
ID: chunk.ID,
Model: chunk.Model,
Text: choice.Delta.Content,
}:
case <-ctx.Done():
errChan <- ctx.Err()
return
// Handle tool call deltas
if len(choice.Delta.ToolCalls) > 0 {
delta := extractToolCallDelta(choice)
if delta != nil {
select {
case deltaChan <- &api.ProviderStreamDelta{
ID: chunk.ID,
Model: chunk.Model,
ToolCallDelta: delta,
}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
}
}
}
}