292 lines
6.7 KiB
Go
292 lines
6.7 KiB
Go
package anthropic
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/ajac-zero/latticelm/internal/config"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestNew(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg config.ProviderConfig
|
|
validate func(t *testing.T, p *Provider)
|
|
}{
|
|
{
|
|
name: "creates provider with API key",
|
|
cfg: config.ProviderConfig{
|
|
APIKey: "sk-ant-test-key",
|
|
Model: "claude-3-opus",
|
|
},
|
|
validate: func(t *testing.T, p *Provider) {
|
|
assert.NotNil(t, p)
|
|
assert.NotNil(t, p.client)
|
|
assert.Equal(t, "sk-ant-test-key", p.cfg.APIKey)
|
|
assert.Equal(t, "claude-3-opus", p.cfg.Model)
|
|
assert.False(t, p.azure)
|
|
},
|
|
},
|
|
{
|
|
name: "creates provider without API key",
|
|
cfg: config.ProviderConfig{
|
|
APIKey: "",
|
|
},
|
|
validate: func(t *testing.T, p *Provider) {
|
|
assert.NotNil(t, p)
|
|
assert.Nil(t, p.client)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
p := New(tt.cfg)
|
|
tt.validate(t, p)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewAzure(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg config.AzureAnthropicConfig
|
|
validate func(t *testing.T, p *Provider)
|
|
}{
|
|
{
|
|
name: "creates Azure provider with endpoint and API key",
|
|
cfg: config.AzureAnthropicConfig{
|
|
APIKey: "azure-key",
|
|
Endpoint: "https://test.services.ai.azure.com/anthropic",
|
|
Model: "claude-3-sonnet",
|
|
},
|
|
validate: func(t *testing.T, p *Provider) {
|
|
assert.NotNil(t, p)
|
|
assert.NotNil(t, p.client)
|
|
assert.Equal(t, "azure-key", p.cfg.APIKey)
|
|
assert.Equal(t, "claude-3-sonnet", p.cfg.Model)
|
|
assert.True(t, p.azure)
|
|
},
|
|
},
|
|
{
|
|
name: "creates Azure provider without API key",
|
|
cfg: config.AzureAnthropicConfig{
|
|
APIKey: "",
|
|
Endpoint: "https://test.services.ai.azure.com/anthropic",
|
|
},
|
|
validate: func(t *testing.T, p *Provider) {
|
|
assert.NotNil(t, p)
|
|
assert.Nil(t, p.client)
|
|
assert.True(t, p.azure)
|
|
},
|
|
},
|
|
{
|
|
name: "creates Azure provider without endpoint",
|
|
cfg: config.AzureAnthropicConfig{
|
|
APIKey: "azure-key",
|
|
Endpoint: "",
|
|
},
|
|
validate: func(t *testing.T, p *Provider) {
|
|
assert.NotNil(t, p)
|
|
assert.Nil(t, p.client)
|
|
assert.True(t, p.azure)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
p := NewAzure(tt.cfg)
|
|
tt.validate(t, p)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProvider_Name(t *testing.T) {
|
|
p := New(config.ProviderConfig{})
|
|
assert.Equal(t, "anthropic", p.Name())
|
|
}
|
|
|
|
func TestProvider_Generate_Validation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
provider *Provider
|
|
messages []api.Message
|
|
req *api.ResponseRequest
|
|
expectError bool
|
|
errorMsg string
|
|
}{
|
|
{
|
|
name: "returns error when API key missing",
|
|
provider: &Provider{
|
|
cfg: config.ProviderConfig{APIKey: ""},
|
|
client: nil,
|
|
},
|
|
messages: []api.Message{
|
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
|
},
|
|
req: &api.ResponseRequest{
|
|
Model: "claude-3-opus",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "api key missing",
|
|
},
|
|
{
|
|
name: "returns error when client not initialized",
|
|
provider: &Provider{
|
|
cfg: config.ProviderConfig{APIKey: "sk-ant-test"},
|
|
client: nil,
|
|
},
|
|
messages: []api.Message{
|
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
|
},
|
|
req: &api.ResponseRequest{
|
|
Model: "claude-3-opus",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "client not initialized",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result, err := tt.provider.Generate(context.Background(), tt.messages, tt.req)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, result)
|
|
if tt.errorMsg != "" {
|
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
|
}
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProvider_GenerateStream_Validation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
provider *Provider
|
|
messages []api.Message
|
|
req *api.ResponseRequest
|
|
expectError bool
|
|
errorMsg string
|
|
}{
|
|
{
|
|
name: "returns error when API key missing",
|
|
provider: &Provider{
|
|
cfg: config.ProviderConfig{APIKey: ""},
|
|
client: nil,
|
|
},
|
|
messages: []api.Message{
|
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
|
},
|
|
req: &api.ResponseRequest{
|
|
Model: "claude-3-opus",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "api key missing",
|
|
},
|
|
{
|
|
name: "returns error when client not initialized",
|
|
provider: &Provider{
|
|
cfg: config.ProviderConfig{APIKey: "sk-ant-test"},
|
|
client: nil,
|
|
},
|
|
messages: []api.Message{
|
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
|
},
|
|
req: &api.ResponseRequest{
|
|
Model: "claude-3-opus",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "client not initialized",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
deltaChan, errChan := tt.provider.GenerateStream(context.Background(), tt.messages, tt.req)
|
|
|
|
// Read from channels
|
|
var receivedError error
|
|
for {
|
|
select {
|
|
case _, ok := <-deltaChan:
|
|
if !ok {
|
|
deltaChan = nil
|
|
}
|
|
case err, ok := <-errChan:
|
|
if ok && err != nil {
|
|
receivedError = err
|
|
}
|
|
errChan = nil
|
|
}
|
|
|
|
if deltaChan == nil && errChan == nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
if tt.expectError {
|
|
require.Error(t, receivedError)
|
|
if tt.errorMsg != "" {
|
|
assert.Contains(t, receivedError.Error(), tt.errorMsg)
|
|
}
|
|
} else {
|
|
assert.NoError(t, receivedError)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestChooseModel(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
requested string
|
|
defaultModel string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "returns requested model when provided",
|
|
requested: "claude-3-opus",
|
|
defaultModel: "claude-3-sonnet",
|
|
expected: "claude-3-opus",
|
|
},
|
|
{
|
|
name: "returns default model when requested is empty",
|
|
requested: "",
|
|
defaultModel: "claude-3-sonnet",
|
|
expected: "claude-3-sonnet",
|
|
},
|
|
{
|
|
name: "returns fallback when both empty",
|
|
requested: "",
|
|
defaultModel: "",
|
|
expected: "claude-3-5-sonnet",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := chooseModel(tt.requested, tt.defaultModel)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractToolCalls(t *testing.T) {
|
|
// Note: This function is already tested in convert_test.go
|
|
// This is a placeholder for additional integration tests if needed
|
|
t.Run("returns nil for empty content", func(t *testing.T) {
|
|
result := extractToolCalls(nil)
|
|
assert.Nil(t, result)
|
|
})
|
|
}
|