689 lines
15 KiB
Go
689 lines
15 KiB
Go
package providers
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/config"
|
|
)
|
|
|
|
func TestNewRegistry(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
entries map[string]config.ProviderEntry
|
|
models []config.ModelEntry
|
|
expectError bool
|
|
errorMsg string
|
|
validate func(t *testing.T, reg *Registry)
|
|
}{
|
|
{
|
|
name: "valid config with OpenAI",
|
|
entries: map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "sk-test",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
},
|
|
validate: func(t *testing.T, reg *Registry) {
|
|
assert.Len(t, reg.providers, 1)
|
|
assert.Contains(t, reg.providers, "openai")
|
|
assert.Equal(t, "openai", reg.models["gpt-4"])
|
|
},
|
|
},
|
|
{
|
|
name: "valid config with multiple providers",
|
|
entries: map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "sk-test",
|
|
},
|
|
"anthropic": {
|
|
Type: "anthropic",
|
|
APIKey: "sk-ant-test",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
{Name: "claude-3", Provider: "anthropic"},
|
|
},
|
|
validate: func(t *testing.T, reg *Registry) {
|
|
assert.Len(t, reg.providers, 2)
|
|
assert.Contains(t, reg.providers, "openai")
|
|
assert.Contains(t, reg.providers, "anthropic")
|
|
assert.Equal(t, "openai", reg.models["gpt-4"])
|
|
assert.Equal(t, "anthropic", reg.models["claude-3"])
|
|
},
|
|
},
|
|
{
|
|
name: "no providers returns error",
|
|
entries: map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "", // Missing API key
|
|
},
|
|
},
|
|
models: []config.ModelEntry{},
|
|
expectError: true,
|
|
errorMsg: "no providers configured",
|
|
},
|
|
{
|
|
name: "Azure OpenAI without endpoint returns error",
|
|
entries: map[string]config.ProviderEntry{
|
|
"azure": {
|
|
Type: "azureopenai",
|
|
APIKey: "test-key",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{},
|
|
expectError: true,
|
|
errorMsg: "endpoint is required",
|
|
},
|
|
{
|
|
name: "Azure OpenAI with endpoint succeeds",
|
|
entries: map[string]config.ProviderEntry{
|
|
"azure": {
|
|
Type: "azureopenai",
|
|
APIKey: "test-key",
|
|
Endpoint: "https://test.openai.azure.com",
|
|
APIVersion: "2024-02-15-preview",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{
|
|
{Name: "gpt-4-azure", Provider: "azure"},
|
|
},
|
|
validate: func(t *testing.T, reg *Registry) {
|
|
assert.Len(t, reg.providers, 1)
|
|
assert.Contains(t, reg.providers, "azure")
|
|
},
|
|
},
|
|
{
|
|
name: "Azure Anthropic without endpoint returns error",
|
|
entries: map[string]config.ProviderEntry{
|
|
"azure-anthropic": {
|
|
Type: "azureanthropic",
|
|
APIKey: "test-key",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{},
|
|
expectError: true,
|
|
errorMsg: "endpoint is required",
|
|
},
|
|
{
|
|
name: "Azure Anthropic with endpoint succeeds",
|
|
entries: map[string]config.ProviderEntry{
|
|
"azure-anthropic": {
|
|
Type: "azureanthropic",
|
|
APIKey: "test-key",
|
|
Endpoint: "https://test.anthropic.azure.com",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{
|
|
{Name: "claude-3-azure", Provider: "azure-anthropic"},
|
|
},
|
|
validate: func(t *testing.T, reg *Registry) {
|
|
assert.Len(t, reg.providers, 1)
|
|
assert.Contains(t, reg.providers, "azure-anthropic")
|
|
},
|
|
},
|
|
{
|
|
name: "Google provider",
|
|
entries: map[string]config.ProviderEntry{
|
|
"google": {
|
|
Type: "google",
|
|
APIKey: "test-key",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{
|
|
{Name: "gemini-pro", Provider: "google"},
|
|
},
|
|
validate: func(t *testing.T, reg *Registry) {
|
|
assert.Len(t, reg.providers, 1)
|
|
assert.Contains(t, reg.providers, "google")
|
|
},
|
|
},
|
|
{
|
|
name: "Vertex AI without project/location returns error",
|
|
entries: map[string]config.ProviderEntry{
|
|
"vertex": {
|
|
Type: "vertexai",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{},
|
|
expectError: true,
|
|
errorMsg: "project and location are required",
|
|
},
|
|
{
|
|
name: "Vertex AI with project and location succeeds",
|
|
entries: map[string]config.ProviderEntry{
|
|
"vertex": {
|
|
Type: "vertexai",
|
|
Project: "my-project",
|
|
Location: "us-central1",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{
|
|
{Name: "gemini-pro-vertex", Provider: "vertex"},
|
|
},
|
|
validate: func(t *testing.T, reg *Registry) {
|
|
assert.Len(t, reg.providers, 1)
|
|
assert.Contains(t, reg.providers, "vertex")
|
|
},
|
|
},
|
|
{
|
|
name: "unknown provider type returns error",
|
|
entries: map[string]config.ProviderEntry{
|
|
"unknown": {
|
|
Type: "unknown-type",
|
|
APIKey: "test-key",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{},
|
|
expectError: true,
|
|
errorMsg: "unknown provider type",
|
|
},
|
|
{
|
|
name: "provider with no API key is skipped",
|
|
entries: map[string]config.ProviderEntry{
|
|
"openai-no-key": {
|
|
Type: "openai",
|
|
APIKey: "",
|
|
},
|
|
"anthropic-with-key": {
|
|
Type: "anthropic",
|
|
APIKey: "sk-ant-test",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{
|
|
{Name: "claude-3", Provider: "anthropic-with-key"},
|
|
},
|
|
validate: func(t *testing.T, reg *Registry) {
|
|
assert.Len(t, reg.providers, 1)
|
|
assert.Contains(t, reg.providers, "anthropic-with-key")
|
|
assert.NotContains(t, reg.providers, "openai-no-key")
|
|
},
|
|
},
|
|
{
|
|
name: "model with provider_model_id",
|
|
entries: map[string]config.ProviderEntry{
|
|
"azure": {
|
|
Type: "azureopenai",
|
|
APIKey: "test-key",
|
|
Endpoint: "https://test.openai.azure.com",
|
|
},
|
|
},
|
|
models: []config.ModelEntry{
|
|
{
|
|
Name: "gpt-4",
|
|
Provider: "azure",
|
|
ProviderModelID: "gpt-4-deployment-name",
|
|
},
|
|
},
|
|
validate: func(t *testing.T, reg *Registry) {
|
|
assert.Equal(t, "gpt-4-deployment-name", reg.providerModelIDs["gpt-4"])
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
reg, err := NewRegistry(tt.entries, tt.models)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
if tt.errorMsg != "" {
|
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
|
}
|
|
return
|
|
}
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, reg)
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, reg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegistry_Get(t *testing.T) {
|
|
reg, err := NewRegistry(
|
|
map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "sk-test",
|
|
},
|
|
"anthropic": {
|
|
Type: "anthropic",
|
|
APIKey: "sk-ant-test",
|
|
},
|
|
},
|
|
[]config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
tests := []struct {
|
|
name string
|
|
providerKey string
|
|
expectFound bool
|
|
validate func(t *testing.T, p Provider)
|
|
}{
|
|
{
|
|
name: "existing provider",
|
|
providerKey: "openai",
|
|
expectFound: true,
|
|
validate: func(t *testing.T, p Provider) {
|
|
assert.Equal(t, "openai", p.Name())
|
|
},
|
|
},
|
|
{
|
|
name: "another existing provider",
|
|
providerKey: "anthropic",
|
|
expectFound: true,
|
|
validate: func(t *testing.T, p Provider) {
|
|
assert.Equal(t, "anthropic", p.Name())
|
|
},
|
|
},
|
|
{
|
|
name: "nonexistent provider",
|
|
providerKey: "nonexistent",
|
|
expectFound: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
p, found := reg.Get(tt.providerKey)
|
|
|
|
if tt.expectFound {
|
|
assert.True(t, found)
|
|
require.NotNil(t, p)
|
|
if tt.validate != nil {
|
|
tt.validate(t, p)
|
|
}
|
|
} else {
|
|
assert.False(t, found)
|
|
assert.Nil(t, p)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegistry_Models(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
models []config.ModelEntry
|
|
validate func(t *testing.T, models []struct{ Provider, Model string })
|
|
}{
|
|
{
|
|
name: "single model",
|
|
models: []config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
},
|
|
validate: func(t *testing.T, models []struct{ Provider, Model string }) {
|
|
require.Len(t, models, 1)
|
|
assert.Equal(t, "gpt-4", models[0].Model)
|
|
assert.Equal(t, "openai", models[0].Provider)
|
|
},
|
|
},
|
|
{
|
|
name: "multiple models",
|
|
models: []config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
{Name: "claude-3", Provider: "anthropic"},
|
|
{Name: "gemini-pro", Provider: "google"},
|
|
},
|
|
validate: func(t *testing.T, models []struct{ Provider, Model string }) {
|
|
require.Len(t, models, 3)
|
|
modelNames := make([]string, len(models))
|
|
for i, m := range models {
|
|
modelNames[i] = m.Model
|
|
}
|
|
assert.Contains(t, modelNames, "gpt-4")
|
|
assert.Contains(t, modelNames, "claude-3")
|
|
assert.Contains(t, modelNames, "gemini-pro")
|
|
},
|
|
},
|
|
{
|
|
name: "no models",
|
|
models: []config.ModelEntry{},
|
|
validate: func(t *testing.T, models []struct{ Provider, Model string }) {
|
|
assert.Len(t, models, 0)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
reg, err := NewRegistry(
|
|
map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "sk-test",
|
|
},
|
|
"anthropic": {
|
|
Type: "anthropic",
|
|
APIKey: "sk-ant-test",
|
|
},
|
|
"google": {
|
|
Type: "google",
|
|
APIKey: "test-key",
|
|
},
|
|
},
|
|
tt.models,
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
models := reg.Models()
|
|
if tt.validate != nil {
|
|
tt.validate(t, models)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegistry_ResolveModelID(t *testing.T) {
|
|
reg, err := NewRegistry(
|
|
map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "sk-test",
|
|
},
|
|
"azure": {
|
|
Type: "azureopenai",
|
|
APIKey: "test-key",
|
|
Endpoint: "https://test.openai.azure.com",
|
|
},
|
|
},
|
|
[]config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
{Name: "gpt-4-azure", Provider: "azure", ProviderModelID: "gpt-4-deployment"},
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
tests := []struct {
|
|
name string
|
|
modelName string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "model without provider_model_id returns model name",
|
|
modelName: "gpt-4",
|
|
expected: "gpt-4",
|
|
},
|
|
{
|
|
name: "model with provider_model_id returns provider_model_id",
|
|
modelName: "gpt-4-azure",
|
|
expected: "gpt-4-deployment",
|
|
},
|
|
{
|
|
name: "unknown model returns model name",
|
|
modelName: "unknown-model",
|
|
expected: "unknown-model",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := reg.ResolveModelID(tt.modelName)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegistry_Default(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupReg func() *Registry
|
|
modelName string
|
|
expectError bool
|
|
errorMsg string
|
|
validate func(t *testing.T, p Provider)
|
|
}{
|
|
{
|
|
name: "returns provider for known model",
|
|
setupReg: func() *Registry {
|
|
reg, _ := NewRegistry(
|
|
map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "sk-test",
|
|
},
|
|
"anthropic": {
|
|
Type: "anthropic",
|
|
APIKey: "sk-ant-test",
|
|
},
|
|
},
|
|
[]config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
{Name: "claude-3", Provider: "anthropic"},
|
|
},
|
|
)
|
|
return reg
|
|
},
|
|
modelName: "gpt-4",
|
|
validate: func(t *testing.T, p Provider) {
|
|
assert.Equal(t, "openai", p.Name())
|
|
},
|
|
},
|
|
{
|
|
name: "returns error for unknown model",
|
|
setupReg: func() *Registry {
|
|
reg, _ := NewRegistry(
|
|
map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "sk-test",
|
|
},
|
|
},
|
|
[]config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
},
|
|
)
|
|
return reg
|
|
},
|
|
modelName: "unknown-model",
|
|
expectError: true,
|
|
errorMsg: "not configured",
|
|
},
|
|
{
|
|
name: "returns error for model whose provider is unavailable",
|
|
setupReg: func() *Registry {
|
|
reg, _ := NewRegistry(
|
|
map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "", // unavailable provider
|
|
},
|
|
"google": {
|
|
Type: "google",
|
|
APIKey: "test-key",
|
|
},
|
|
},
|
|
[]config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
{Name: "gemini-pro", Provider: "google"},
|
|
},
|
|
)
|
|
return reg
|
|
},
|
|
modelName: "gpt-4",
|
|
expectError: true,
|
|
errorMsg: "not available",
|
|
},
|
|
{
|
|
name: "returns first provider for empty model name",
|
|
setupReg: func() *Registry {
|
|
reg, _ := NewRegistry(
|
|
map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "sk-test",
|
|
},
|
|
},
|
|
[]config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
},
|
|
)
|
|
return reg
|
|
},
|
|
modelName: "",
|
|
validate: func(t *testing.T, p Provider) {
|
|
assert.NotNil(t, p)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
reg := tt.setupReg()
|
|
p, err := reg.Default(tt.modelName)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
if tt.errorMsg != "" {
|
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
|
}
|
|
return
|
|
}
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, p)
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, p)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegistry_Models_FiltersUnavailableProviders(t *testing.T) {
|
|
reg, err := NewRegistry(
|
|
map[string]config.ProviderEntry{
|
|
"openai": {
|
|
Type: "openai",
|
|
APIKey: "", // unavailable provider
|
|
},
|
|
"google": {
|
|
Type: "google",
|
|
APIKey: "test-key",
|
|
},
|
|
},
|
|
[]config.ModelEntry{
|
|
{Name: "gpt-4", Provider: "openai"},
|
|
{Name: "gemini-pro", Provider: "google"},
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
models := reg.Models()
|
|
require.Len(t, models, 1)
|
|
assert.Equal(t, "gemini-pro", models[0].Model)
|
|
assert.Equal(t, "google", models[0].Provider)
|
|
}
|
|
|
|
func TestBuildProvider(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
entry config.ProviderEntry
|
|
expectError bool
|
|
errorMsg string
|
|
expectNil bool
|
|
validate func(t *testing.T, p Provider)
|
|
}{
|
|
{
|
|
name: "OpenAI provider",
|
|
entry: config.ProviderEntry{
|
|
Type: "openai",
|
|
APIKey: "sk-test",
|
|
},
|
|
validate: func(t *testing.T, p Provider) {
|
|
assert.Equal(t, "openai", p.Name())
|
|
},
|
|
},
|
|
{
|
|
name: "OpenAI provider with custom endpoint",
|
|
entry: config.ProviderEntry{
|
|
Type: "openai",
|
|
APIKey: "sk-test",
|
|
Endpoint: "https://custom.openai.com",
|
|
},
|
|
validate: func(t *testing.T, p Provider) {
|
|
assert.Equal(t, "openai", p.Name())
|
|
},
|
|
},
|
|
{
|
|
name: "Anthropic provider",
|
|
entry: config.ProviderEntry{
|
|
Type: "anthropic",
|
|
APIKey: "sk-ant-test",
|
|
},
|
|
validate: func(t *testing.T, p Provider) {
|
|
assert.Equal(t, "anthropic", p.Name())
|
|
},
|
|
},
|
|
{
|
|
name: "Google provider",
|
|
entry: config.ProviderEntry{
|
|
Type: "google",
|
|
APIKey: "test-key",
|
|
},
|
|
validate: func(t *testing.T, p Provider) {
|
|
assert.Equal(t, "google", p.Name())
|
|
},
|
|
},
|
|
{
|
|
name: "provider without API key returns nil",
|
|
entry: config.ProviderEntry{
|
|
Type: "openai",
|
|
APIKey: "",
|
|
},
|
|
expectNil: true,
|
|
},
|
|
{
|
|
name: "unknown provider type",
|
|
entry: config.ProviderEntry{
|
|
Type: "unknown",
|
|
APIKey: "test-key",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "unknown provider type",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
p, err := buildProvider(tt.entry)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
if tt.errorMsg != "" {
|
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
|
}
|
|
return
|
|
}
|
|
|
|
require.NoError(t, err)
|
|
|
|
if tt.expectNil {
|
|
assert.Nil(t, p)
|
|
return
|
|
}
|
|
|
|
require.NotNil(t, p)
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, p)
|
|
}
|
|
})
|
|
}
|
|
}
|