Improve test coverage
Some checks failed
Some checks failed
This commit is contained in:
291
internal/providers/anthropic/anthropic_test.go
Normal file
291
internal/providers/anthropic/anthropic_test.go
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
package anthropic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ajac-zero/latticelm/internal/api"
|
||||||
|
"github.com/ajac-zero/latticelm/internal/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg config.ProviderConfig
|
||||||
|
validate func(t *testing.T, p *Provider)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "creates provider with API key",
|
||||||
|
cfg: config.ProviderConfig{
|
||||||
|
APIKey: "sk-ant-test-key",
|
||||||
|
Model: "claude-3-opus",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.NotNil(t, p.client)
|
||||||
|
assert.Equal(t, "sk-ant-test-key", p.cfg.APIKey)
|
||||||
|
assert.Equal(t, "claude-3-opus", p.cfg.Model)
|
||||||
|
assert.False(t, p.azure)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates provider without API key",
|
||||||
|
cfg: config.ProviderConfig{
|
||||||
|
APIKey: "",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.Nil(t, p.client)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := New(tt.cfg)
|
||||||
|
tt.validate(t, p)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAzure(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg config.AzureAnthropicConfig
|
||||||
|
validate func(t *testing.T, p *Provider)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "creates Azure provider with endpoint and API key",
|
||||||
|
cfg: config.AzureAnthropicConfig{
|
||||||
|
APIKey: "azure-key",
|
||||||
|
Endpoint: "https://test.services.ai.azure.com/anthropic",
|
||||||
|
Model: "claude-3-sonnet",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.NotNil(t, p.client)
|
||||||
|
assert.Equal(t, "azure-key", p.cfg.APIKey)
|
||||||
|
assert.Equal(t, "claude-3-sonnet", p.cfg.Model)
|
||||||
|
assert.True(t, p.azure)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates Azure provider without API key",
|
||||||
|
cfg: config.AzureAnthropicConfig{
|
||||||
|
APIKey: "",
|
||||||
|
Endpoint: "https://test.services.ai.azure.com/anthropic",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.Nil(t, p.client)
|
||||||
|
assert.True(t, p.azure)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates Azure provider without endpoint",
|
||||||
|
cfg: config.AzureAnthropicConfig{
|
||||||
|
APIKey: "azure-key",
|
||||||
|
Endpoint: "",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.Nil(t, p.client)
|
||||||
|
assert.True(t, p.azure)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := NewAzure(tt.cfg)
|
||||||
|
tt.validate(t, p)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_Name(t *testing.T) {
|
||||||
|
p := New(config.ProviderConfig{})
|
||||||
|
assert.Equal(t, "anthropic", p.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_Generate_Validation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
provider *Provider
|
||||||
|
messages []api.Message
|
||||||
|
req *api.ResponseRequest
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns error when API key missing",
|
||||||
|
provider: &Provider{
|
||||||
|
cfg: config.ProviderConfig{APIKey: ""},
|
||||||
|
client: nil,
|
||||||
|
},
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
|
},
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Model: "claude-3-opus",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "api key missing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns error when client not initialized",
|
||||||
|
provider: &Provider{
|
||||||
|
cfg: config.ProviderConfig{APIKey: "sk-ant-test"},
|
||||||
|
client: nil,
|
||||||
|
},
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
|
},
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Model: "claude-3-opus",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "client not initialized",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := tt.provider.Generate(context.Background(), tt.messages, tt.req)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
if tt.errorMsg != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_GenerateStream_Validation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
provider *Provider
|
||||||
|
messages []api.Message
|
||||||
|
req *api.ResponseRequest
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns error when API key missing",
|
||||||
|
provider: &Provider{
|
||||||
|
cfg: config.ProviderConfig{APIKey: ""},
|
||||||
|
client: nil,
|
||||||
|
},
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
|
},
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Model: "claude-3-opus",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "api key missing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns error when client not initialized",
|
||||||
|
provider: &Provider{
|
||||||
|
cfg: config.ProviderConfig{APIKey: "sk-ant-test"},
|
||||||
|
client: nil,
|
||||||
|
},
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
|
},
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Model: "claude-3-opus",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "client not initialized",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
deltaChan, errChan := tt.provider.GenerateStream(context.Background(), tt.messages, tt.req)
|
||||||
|
|
||||||
|
// Read from channels
|
||||||
|
var receivedError error
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case _, ok := <-deltaChan:
|
||||||
|
if !ok {
|
||||||
|
deltaChan = nil
|
||||||
|
}
|
||||||
|
case err, ok := <-errChan:
|
||||||
|
if ok && err != nil {
|
||||||
|
receivedError = err
|
||||||
|
}
|
||||||
|
errChan = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if deltaChan == nil && errChan == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, receivedError)
|
||||||
|
if tt.errorMsg != "" {
|
||||||
|
assert.Contains(t, receivedError.Error(), tt.errorMsg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, receivedError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChooseModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
requested string
|
||||||
|
defaultModel string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns requested model when provided",
|
||||||
|
requested: "claude-3-opus",
|
||||||
|
defaultModel: "claude-3-sonnet",
|
||||||
|
expected: "claude-3-opus",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns default model when requested is empty",
|
||||||
|
requested: "",
|
||||||
|
defaultModel: "claude-3-sonnet",
|
||||||
|
expected: "claude-3-sonnet",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns fallback when both empty",
|
||||||
|
requested: "",
|
||||||
|
defaultModel: "",
|
||||||
|
expected: "claude-3-5-sonnet",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := chooseModel(tt.requested, tt.defaultModel)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractToolCalls(t *testing.T) {
|
||||||
|
// Note: This function is already tested in convert_test.go
|
||||||
|
// This is a placeholder for additional integration tests if needed
|
||||||
|
t.Run("returns nil for empty content", func(t *testing.T) {
|
||||||
|
result := extractToolCalls(nil)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
574
internal/providers/google/google_test.go
Normal file
574
internal/providers/google/google_test.go
Normal file
@@ -0,0 +1,574 @@
|
|||||||
|
package google
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ajac-zero/latticelm/internal/api"
|
||||||
|
"github.com/ajac-zero/latticelm/internal/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/genai"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg config.ProviderConfig
|
||||||
|
expectError bool
|
||||||
|
validate func(t *testing.T, p *Provider, err error)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "creates provider with API key",
|
||||||
|
cfg: config.ProviderConfig{
|
||||||
|
APIKey: "test-api-key",
|
||||||
|
Model: "gemini-2.0-flash",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
validate: func(t *testing.T, p *Provider, err error) {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.NotNil(t, p.client)
|
||||||
|
assert.Equal(t, "test-api-key", p.cfg.APIKey)
|
||||||
|
assert.Equal(t, "gemini-2.0-flash", p.cfg.Model)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates provider without API key",
|
||||||
|
cfg: config.ProviderConfig{
|
||||||
|
APIKey: "",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
validate: func(t *testing.T, p *Provider, err error) {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.Nil(t, p.client)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p, err := New(tt.cfg)
|
||||||
|
tt.validate(t, p, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewVertexAI(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg config.VertexAIConfig
|
||||||
|
expectError bool
|
||||||
|
validate func(t *testing.T, p *Provider, err error)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "creates Vertex AI provider with project and location",
|
||||||
|
cfg: config.VertexAIConfig{
|
||||||
|
Project: "my-gcp-project",
|
||||||
|
Location: "us-central1",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
validate: func(t *testing.T, p *Provider, err error) {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
// Client creation may fail without proper GCP credentials in test env
|
||||||
|
// but provider should be created
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates Vertex AI provider without project",
|
||||||
|
cfg: config.VertexAIConfig{
|
||||||
|
Project: "",
|
||||||
|
Location: "us-central1",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
validate: func(t *testing.T, p *Provider, err error) {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.Nil(t, p.client)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates Vertex AI provider without location",
|
||||||
|
cfg: config.VertexAIConfig{
|
||||||
|
Project: "my-gcp-project",
|
||||||
|
Location: "",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
validate: func(t *testing.T, p *Provider, err error) {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.Nil(t, p.client)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p, err := NewVertexAI(tt.cfg)
|
||||||
|
tt.validate(t, p, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_Name(t *testing.T) {
|
||||||
|
p := &Provider{}
|
||||||
|
assert.Equal(t, "google", p.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_Generate_Validation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
provider *Provider
|
||||||
|
messages []api.Message
|
||||||
|
req *api.ResponseRequest
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns error when client not initialized",
|
||||||
|
provider: &Provider{
|
||||||
|
client: nil,
|
||||||
|
},
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
|
},
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Model: "gemini-2.0-flash",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "client not initialized",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := tt.provider.Generate(context.Background(), tt.messages, tt.req)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
if tt.errorMsg != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_GenerateStream_Validation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
provider *Provider
|
||||||
|
messages []api.Message
|
||||||
|
req *api.ResponseRequest
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns error when client not initialized",
|
||||||
|
provider: &Provider{
|
||||||
|
client: nil,
|
||||||
|
},
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
|
},
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Model: "gemini-2.0-flash",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "client not initialized",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
deltaChan, errChan := tt.provider.GenerateStream(context.Background(), tt.messages, tt.req)
|
||||||
|
|
||||||
|
// Read from channels
|
||||||
|
var receivedError error
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case _, ok := <-deltaChan:
|
||||||
|
if !ok {
|
||||||
|
deltaChan = nil
|
||||||
|
}
|
||||||
|
case err, ok := <-errChan:
|
||||||
|
if ok && err != nil {
|
||||||
|
receivedError = err
|
||||||
|
}
|
||||||
|
errChan = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if deltaChan == nil && errChan == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, receivedError)
|
||||||
|
if tt.errorMsg != "" {
|
||||||
|
assert.Contains(t, receivedError.Error(), tt.errorMsg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, receivedError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessages(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
messages []api.Message
|
||||||
|
expectedContents int
|
||||||
|
expectedSystem string
|
||||||
|
validate func(t *testing.T, contents []*genai.Content, systemText string)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "converts user message",
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: []api.ContentBlock{
|
||||||
|
{Type: "input_text", Text: "Hello"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedContents: 1,
|
||||||
|
expectedSystem: "",
|
||||||
|
validate: func(t *testing.T, contents []*genai.Content, systemText string) {
|
||||||
|
require.Len(t, contents, 1)
|
||||||
|
assert.Equal(t, "user", contents[0].Role)
|
||||||
|
assert.Equal(t, "", systemText)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extracts system message",
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: []api.ContentBlock{
|
||||||
|
{Type: "input_text", Text: "You are a helpful assistant"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: []api.ContentBlock{
|
||||||
|
{Type: "input_text", Text: "Hello"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedContents: 1,
|
||||||
|
expectedSystem: "You are a helpful assistant",
|
||||||
|
validate: func(t *testing.T, contents []*genai.Content, systemText string) {
|
||||||
|
require.Len(t, contents, 1)
|
||||||
|
assert.Equal(t, "You are a helpful assistant", systemText)
|
||||||
|
assert.Equal(t, "user", contents[0].Role)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "converts assistant message with tool calls",
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: []api.ContentBlock{
|
||||||
|
{Type: "output_text", Text: "Let me check the weather"},
|
||||||
|
},
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_123",
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: `{"location": "SF"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedContents: 1,
|
||||||
|
validate: func(t *testing.T, contents []*genai.Content, systemText string) {
|
||||||
|
require.Len(t, contents, 1)
|
||||||
|
assert.Equal(t, "model", contents[0].Role)
|
||||||
|
// Should have text part and function call part
|
||||||
|
assert.GreaterOrEqual(t, len(contents[0].Parts), 1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "converts tool result message",
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{ID: "call_123", Name: "get_weather", Arguments: "{}"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
CallID: "call_123",
|
||||||
|
Name: "get_weather",
|
||||||
|
Content: []api.ContentBlock{
|
||||||
|
{Type: "output_text", Text: `{"temp": 72}`},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedContents: 2,
|
||||||
|
validate: func(t *testing.T, contents []*genai.Content, systemText string) {
|
||||||
|
require.Len(t, contents, 2)
|
||||||
|
// Tool result should be in user role
|
||||||
|
assert.Equal(t, "user", contents[1].Role)
|
||||||
|
require.Len(t, contents[1].Parts, 1)
|
||||||
|
assert.NotNil(t, contents[1].Parts[0].FunctionResponse)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "handles developer message as system",
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "developer",
|
||||||
|
Content: []api.ContentBlock{
|
||||||
|
{Type: "input_text", Text: "Developer instruction"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedContents: 0,
|
||||||
|
expectedSystem: "Developer instruction",
|
||||||
|
validate: func(t *testing.T, contents []*genai.Content, systemText string) {
|
||||||
|
assert.Len(t, contents, 0)
|
||||||
|
assert.Equal(t, "Developer instruction", systemText)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
contents, systemText := convertMessages(tt.messages)
|
||||||
|
assert.Len(t, contents, tt.expectedContents)
|
||||||
|
assert.Equal(t, tt.expectedSystem, systemText)
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(t, contents, systemText)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
systemText string
|
||||||
|
req *api.ResponseRequest
|
||||||
|
tools []*genai.Tool
|
||||||
|
toolConfig *genai.ToolConfig
|
||||||
|
expectNil bool
|
||||||
|
validate func(t *testing.T, cfg *genai.GenerateContentConfig)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns nil when no config needed",
|
||||||
|
systemText: "",
|
||||||
|
req: &api.ResponseRequest{},
|
||||||
|
tools: nil,
|
||||||
|
toolConfig: nil,
|
||||||
|
expectNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates config with system text",
|
||||||
|
systemText: "You are helpful",
|
||||||
|
req: &api.ResponseRequest{},
|
||||||
|
expectNil: false,
|
||||||
|
validate: func(t *testing.T, cfg *genai.GenerateContentConfig) {
|
||||||
|
require.NotNil(t, cfg)
|
||||||
|
require.NotNil(t, cfg.SystemInstruction)
|
||||||
|
assert.Len(t, cfg.SystemInstruction.Parts, 1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates config with max tokens",
|
||||||
|
systemText: "",
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
MaxOutputTokens: intPtr(1000),
|
||||||
|
},
|
||||||
|
expectNil: false,
|
||||||
|
validate: func(t *testing.T, cfg *genai.GenerateContentConfig) {
|
||||||
|
require.NotNil(t, cfg)
|
||||||
|
assert.Equal(t, int32(1000), cfg.MaxOutputTokens)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates config with temperature",
|
||||||
|
systemText: "",
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Temperature: float64Ptr(0.7),
|
||||||
|
},
|
||||||
|
expectNil: false,
|
||||||
|
validate: func(t *testing.T, cfg *genai.GenerateContentConfig) {
|
||||||
|
require.NotNil(t, cfg)
|
||||||
|
require.NotNil(t, cfg.Temperature)
|
||||||
|
assert.Equal(t, float32(0.7), *cfg.Temperature)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates config with top_p",
|
||||||
|
systemText: "",
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
TopP: float64Ptr(0.9),
|
||||||
|
},
|
||||||
|
expectNil: false,
|
||||||
|
validate: func(t *testing.T, cfg *genai.GenerateContentConfig) {
|
||||||
|
require.NotNil(t, cfg)
|
||||||
|
require.NotNil(t, cfg.TopP)
|
||||||
|
assert.Equal(t, float32(0.9), *cfg.TopP)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates config with tools",
|
||||||
|
systemText: "",
|
||||||
|
req: &api.ResponseRequest{},
|
||||||
|
tools: []*genai.Tool{
|
||||||
|
{
|
||||||
|
FunctionDeclarations: []*genai.FunctionDeclaration{
|
||||||
|
{Name: "get_weather"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectNil: false,
|
||||||
|
validate: func(t *testing.T, cfg *genai.GenerateContentConfig) {
|
||||||
|
require.NotNil(t, cfg)
|
||||||
|
require.Len(t, cfg.Tools, 1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := buildConfig(tt.systemText, tt.req, tt.tools, tt.toolConfig)
|
||||||
|
if tt.expectNil {
|
||||||
|
assert.Nil(t, cfg)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, cfg)
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(t, cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChooseModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
requested string
|
||||||
|
defaultModel string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns requested model when provided",
|
||||||
|
requested: "gemini-1.5-pro",
|
||||||
|
defaultModel: "gemini-2.0-flash",
|
||||||
|
expected: "gemini-1.5-pro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns default model when requested is empty",
|
||||||
|
requested: "",
|
||||||
|
defaultModel: "gemini-2.0-flash",
|
||||||
|
expected: "gemini-2.0-flash",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns fallback when both empty",
|
||||||
|
requested: "",
|
||||||
|
defaultModel: "",
|
||||||
|
expected: "gemini-2.0-flash-exp",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := chooseModel(tt.requested, tt.defaultModel)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractToolCallDelta(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
part *genai.Part
|
||||||
|
index int
|
||||||
|
expected *api.ToolCallDelta
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extracts tool call delta",
|
||||||
|
part: &genai.Part{
|
||||||
|
FunctionCall: &genai.FunctionCall{
|
||||||
|
ID: "call_123",
|
||||||
|
Name: "get_weather",
|
||||||
|
Args: map[string]any{"location": "SF"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
index: 0,
|
||||||
|
expected: &api.ToolCallDelta{
|
||||||
|
Index: 0,
|
||||||
|
ID: "call_123",
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: `{"location":"SF"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns nil for nil part",
|
||||||
|
part: nil,
|
||||||
|
index: 0,
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns nil for part without function call",
|
||||||
|
part: &genai.Part{Text: "Hello"},
|
||||||
|
index: 0,
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "generates ID when not provided",
|
||||||
|
part: &genai.Part{
|
||||||
|
FunctionCall: &genai.FunctionCall{
|
||||||
|
ID: "",
|
||||||
|
Name: "get_time",
|
||||||
|
Args: map[string]any{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
index: 1,
|
||||||
|
expected: &api.ToolCallDelta{
|
||||||
|
Index: 1,
|
||||||
|
Name: "get_time",
|
||||||
|
Arguments: `{}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractToolCallDelta(tt.part, tt.index)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, tt.expected.Index, result.Index)
|
||||||
|
assert.Equal(t, tt.expected.Name, result.Name)
|
||||||
|
if tt.part != nil && tt.part.FunctionCall != nil && tt.part.FunctionCall.ID != "" {
|
||||||
|
assert.Equal(t, tt.expected.ID, result.ID)
|
||||||
|
} else if tt.expected.ID == "" {
|
||||||
|
// Generated ID should start with "call_"
|
||||||
|
assert.Contains(t, result.ID, "call_")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
func intPtr(i int) *int {
|
||||||
|
return &i
|
||||||
|
}
|
||||||
|
|
||||||
|
func float64Ptr(f float64) *float64 {
|
||||||
|
return &f
|
||||||
|
}
|
||||||
304
internal/providers/openai/openai_test.go
Normal file
304
internal/providers/openai/openai_test.go
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ajac-zero/latticelm/internal/api"
|
||||||
|
"github.com/ajac-zero/latticelm/internal/config"
|
||||||
|
"github.com/openai/openai-go/v3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg config.ProviderConfig
|
||||||
|
validate func(t *testing.T, p *Provider)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "creates provider with API key",
|
||||||
|
cfg: config.ProviderConfig{
|
||||||
|
APIKey: "sk-test-key",
|
||||||
|
Model: "gpt-4o",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.NotNil(t, p.client)
|
||||||
|
assert.Equal(t, "sk-test-key", p.cfg.APIKey)
|
||||||
|
assert.Equal(t, "gpt-4o", p.cfg.Model)
|
||||||
|
assert.False(t, p.azure)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates provider without API key",
|
||||||
|
cfg: config.ProviderConfig{
|
||||||
|
APIKey: "",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.Nil(t, p.client)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := New(tt.cfg)
|
||||||
|
tt.validate(t, p)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAzure(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg config.AzureOpenAIConfig
|
||||||
|
validate func(t *testing.T, p *Provider)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "creates Azure provider with endpoint and API key",
|
||||||
|
cfg: config.AzureOpenAIConfig{
|
||||||
|
APIKey: "azure-key",
|
||||||
|
Endpoint: "https://test.openai.azure.com",
|
||||||
|
APIVersion: "2024-02-15-preview",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.NotNil(t, p.client)
|
||||||
|
assert.Equal(t, "azure-key", p.cfg.APIKey)
|
||||||
|
assert.True(t, p.azure)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates Azure provider with default API version",
|
||||||
|
cfg: config.AzureOpenAIConfig{
|
||||||
|
APIKey: "azure-key",
|
||||||
|
Endpoint: "https://test.openai.azure.com",
|
||||||
|
APIVersion: "",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.NotNil(t, p.client)
|
||||||
|
assert.True(t, p.azure)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates Azure provider without API key",
|
||||||
|
cfg: config.AzureOpenAIConfig{
|
||||||
|
APIKey: "",
|
||||||
|
Endpoint: "https://test.openai.azure.com",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.Nil(t, p.client)
|
||||||
|
assert.True(t, p.azure)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates Azure provider without endpoint",
|
||||||
|
cfg: config.AzureOpenAIConfig{
|
||||||
|
APIKey: "azure-key",
|
||||||
|
Endpoint: "",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, p *Provider) {
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.Nil(t, p.client)
|
||||||
|
assert.True(t, p.azure)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := NewAzure(tt.cfg)
|
||||||
|
tt.validate(t, p)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_Name(t *testing.T) {
|
||||||
|
p := New(config.ProviderConfig{})
|
||||||
|
assert.Equal(t, "openai", p.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_Generate_Validation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
provider *Provider
|
||||||
|
messages []api.Message
|
||||||
|
req *api.ResponseRequest
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns error when API key missing",
|
||||||
|
provider: &Provider{
|
||||||
|
cfg: config.ProviderConfig{APIKey: ""},
|
||||||
|
client: nil,
|
||||||
|
},
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
|
},
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "api key missing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns error when client not initialized",
|
||||||
|
provider: &Provider{
|
||||||
|
cfg: config.ProviderConfig{APIKey: "sk-test"},
|
||||||
|
client: nil,
|
||||||
|
},
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
|
},
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "client not initialized",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := tt.provider.Generate(context.Background(), tt.messages, tt.req)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
if tt.errorMsg != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_GenerateStream_Validation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
provider *Provider
|
||||||
|
messages []api.Message
|
||||||
|
req *api.ResponseRequest
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns error when API key missing",
|
||||||
|
provider: &Provider{
|
||||||
|
cfg: config.ProviderConfig{APIKey: ""},
|
||||||
|
client: nil,
|
||||||
|
},
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
|
},
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "api key missing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns error when client not initialized",
|
||||||
|
provider: &Provider{
|
||||||
|
cfg: config.ProviderConfig{APIKey: "sk-test"},
|
||||||
|
client: nil,
|
||||||
|
},
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
|
},
|
||||||
|
req: &api.ResponseRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "client not initialized",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
deltaChan, errChan := tt.provider.GenerateStream(context.Background(), tt.messages, tt.req)
|
||||||
|
|
||||||
|
// Read from channels
|
||||||
|
var receivedError error
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case _, ok := <-deltaChan:
|
||||||
|
if !ok {
|
||||||
|
deltaChan = nil
|
||||||
|
}
|
||||||
|
case err, ok := <-errChan:
|
||||||
|
if ok && err != nil {
|
||||||
|
receivedError = err
|
||||||
|
}
|
||||||
|
errChan = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if deltaChan == nil && errChan == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, receivedError)
|
||||||
|
if tt.errorMsg != "" {
|
||||||
|
assert.Contains(t, receivedError.Error(), tt.errorMsg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, receivedError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChooseModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
requested string
|
||||||
|
defaultModel string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns requested model when provided",
|
||||||
|
requested: "gpt-4o",
|
||||||
|
defaultModel: "gpt-4o-mini",
|
||||||
|
expected: "gpt-4o",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns default model when requested is empty",
|
||||||
|
requested: "",
|
||||||
|
defaultModel: "gpt-4o-mini",
|
||||||
|
expected: "gpt-4o-mini",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns fallback when both empty",
|
||||||
|
requested: "",
|
||||||
|
defaultModel: "",
|
||||||
|
expected: "gpt-4o-mini",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := chooseModel(tt.requested, tt.defaultModel)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractToolCalls_Integration(t *testing.T) {
|
||||||
|
// Additional integration tests for extractToolCalls beyond convert_test.go
|
||||||
|
t.Run("handles empty message", func(t *testing.T) {
|
||||||
|
msg := openai.ChatCompletionMessage{}
|
||||||
|
result := extractToolCalls(msg)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user