Add tests
This commit is contained in:
363
internal/providers/google/convert_test.go
Normal file
363
internal/providers/google/convert_test.go
Normal file
@@ -0,0 +1,363 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
func TestParseTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolsJSON string
|
||||
expectError bool
|
||||
validate func(t *testing.T, tools []*genai.Tool)
|
||||
}{
|
||||
{
|
||||
name: "flat format tool",
|
||||
toolsJSON: `[{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}]`,
|
||||
validate: func(t *testing.T, tools []*genai.Tool) {
|
||||
require.Len(t, tools, 1, "should have one tool")
|
||||
require.Len(t, tools[0].FunctionDeclarations, 1, "should have one function declaration")
|
||||
assert.Equal(t, "get_weather", tools[0].FunctionDeclarations[0].Name)
|
||||
assert.Equal(t, "Get the weather for a location", tools[0].FunctionDeclarations[0].Description)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested format tool",
|
||||
toolsJSON: `[{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"description": "Get current time",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]`,
|
||||
validate: func(t *testing.T, tools []*genai.Tool) {
|
||||
require.Len(t, tools, 1, "should have one tool")
|
||||
require.Len(t, tools[0].FunctionDeclarations, 1, "should have one function declaration")
|
||||
assert.Equal(t, "get_time", tools[0].FunctionDeclarations[0].Name)
|
||||
assert.Equal(t, "Get current time", tools[0].FunctionDeclarations[0].Description)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple tools",
|
||||
toolsJSON: `[
|
||||
{"name": "tool1", "description": "First tool"},
|
||||
{"name": "tool2", "description": "Second tool"}
|
||||
]`,
|
||||
validate: func(t *testing.T, tools []*genai.Tool) {
|
||||
require.Len(t, tools, 1, "should consolidate into one tool")
|
||||
require.Len(t, tools[0].FunctionDeclarations, 2, "should have two function declarations")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool without description",
|
||||
toolsJSON: `[{
|
||||
"name": "simple_tool",
|
||||
"parameters": {"type": "object"}
|
||||
}]`,
|
||||
validate: func(t *testing.T, tools []*genai.Tool) {
|
||||
require.Len(t, tools, 1, "should have one tool")
|
||||
assert.Equal(t, "simple_tool", tools[0].FunctionDeclarations[0].Name)
|
||||
assert.Empty(t, tools[0].FunctionDeclarations[0].Description)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool without parameters",
|
||||
toolsJSON: `[{
|
||||
"name": "paramless_tool",
|
||||
"description": "No params"
|
||||
}]`,
|
||||
validate: func(t *testing.T, tools []*genai.Tool) {
|
||||
require.Len(t, tools, 1, "should have one tool")
|
||||
assert.Nil(t, tools[0].FunctionDeclarations[0].ParametersJsonSchema)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool without name (should skip)",
|
||||
toolsJSON: `[{
|
||||
"description": "No name tool",
|
||||
"parameters": {"type": "object"}
|
||||
}]`,
|
||||
validate: func(t *testing.T, tools []*genai.Tool) {
|
||||
assert.Nil(t, tools, "should return nil when no valid tools")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil tools",
|
||||
toolsJSON: "",
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, tools []*genai.Tool) {
|
||||
assert.Nil(t, tools, "should return nil for empty tools")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
toolsJSON: `{not valid json}`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
toolsJSON: `[]`,
|
||||
validate: func(t *testing.T, tools []*genai.Tool) {
|
||||
assert.Nil(t, tools, "should return nil for empty array")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var req api.ResponseRequest
|
||||
if tt.toolsJSON != "" {
|
||||
req.Tools = json.RawMessage(tt.toolsJSON)
|
||||
}
|
||||
|
||||
tools, err := parseTools(&req)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, "expected an error")
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "unexpected error")
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, tools)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolChoice(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
choiceJSON string
|
||||
expectError bool
|
||||
validate func(t *testing.T, config *genai.ToolConfig)
|
||||
}{
|
||||
{
|
||||
name: "auto mode",
|
||||
choiceJSON: `"auto"`,
|
||||
validate: func(t *testing.T, config *genai.ToolConfig) {
|
||||
require.NotNil(t, config, "config should not be nil")
|
||||
require.NotNil(t, config.FunctionCallingConfig, "function calling config should be set")
|
||||
assert.Equal(t, genai.FunctionCallingConfigModeAuto, config.FunctionCallingConfig.Mode)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "none mode",
|
||||
choiceJSON: `"none"`,
|
||||
validate: func(t *testing.T, config *genai.ToolConfig) {
|
||||
require.NotNil(t, config, "config should not be nil")
|
||||
assert.Equal(t, genai.FunctionCallingConfigModeNone, config.FunctionCallingConfig.Mode)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "required mode",
|
||||
choiceJSON: `"required"`,
|
||||
validate: func(t *testing.T, config *genai.ToolConfig) {
|
||||
require.NotNil(t, config, "config should not be nil")
|
||||
assert.Equal(t, genai.FunctionCallingConfigModeAny, config.FunctionCallingConfig.Mode)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "any mode",
|
||||
choiceJSON: `"any"`,
|
||||
validate: func(t *testing.T, config *genai.ToolConfig) {
|
||||
require.NotNil(t, config, "config should not be nil")
|
||||
assert.Equal(t, genai.FunctionCallingConfigModeAny, config.FunctionCallingConfig.Mode)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "specific function",
|
||||
choiceJSON: `{"type": "function", "function": {"name": "get_weather"}}`,
|
||||
validate: func(t *testing.T, config *genai.ToolConfig) {
|
||||
require.NotNil(t, config, "config should not be nil")
|
||||
assert.Equal(t, genai.FunctionCallingConfigModeAny, config.FunctionCallingConfig.Mode)
|
||||
require.Len(t, config.FunctionCallingConfig.AllowedFunctionNames, 1)
|
||||
assert.Equal(t, "get_weather", config.FunctionCallingConfig.AllowedFunctionNames[0])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil tool choice",
|
||||
choiceJSON: "",
|
||||
validate: func(t *testing.T, config *genai.ToolConfig) {
|
||||
assert.Nil(t, config, "should return nil for empty choice")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown string mode",
|
||||
choiceJSON: `"unknown_mode"`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
choiceJSON: `{invalid}`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported object format",
|
||||
choiceJSON: `{"type": "unsupported"}`,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var req api.ResponseRequest
|
||||
if tt.choiceJSON != "" {
|
||||
req.ToolChoice = json.RawMessage(tt.choiceJSON)
|
||||
}
|
||||
|
||||
config, err := parseToolChoice(&req)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, "expected an error")
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "unexpected error")
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, config)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() *genai.GenerateContentResponse
|
||||
validate func(t *testing.T, toolCalls []api.ToolCall)
|
||||
}{
|
||||
{
|
||||
name: "single tool call",
|
||||
setup: func() *genai.GenerateContentResponse {
|
||||
args := map[string]interface{}{
|
||||
"location": "San Francisco",
|
||||
}
|
||||
return &genai.GenerateContentResponse{
|
||||
Candidates: []*genai.Candidate{
|
||||
{
|
||||
Content: &genai.Content{
|
||||
Parts: []*genai.Part{
|
||||
{
|
||||
FunctionCall: &genai.FunctionCall{
|
||||
ID: "call_123",
|
||||
Name: "get_weather",
|
||||
Args: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, toolCalls []api.ToolCall) {
|
||||
require.Len(t, toolCalls, 1)
|
||||
assert.Equal(t, "call_123", toolCalls[0].ID)
|
||||
assert.Equal(t, "get_weather", toolCalls[0].Name)
|
||||
assert.Contains(t, toolCalls[0].Arguments, "location")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call without ID generates one",
|
||||
setup: func() *genai.GenerateContentResponse {
|
||||
return &genai.GenerateContentResponse{
|
||||
Candidates: []*genai.Candidate{
|
||||
{
|
||||
Content: &genai.Content{
|
||||
Parts: []*genai.Part{
|
||||
{
|
||||
FunctionCall: &genai.FunctionCall{
|
||||
Name: "get_time",
|
||||
Args: map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, toolCalls []api.ToolCall) {
|
||||
require.Len(t, toolCalls, 1)
|
||||
assert.NotEmpty(t, toolCalls[0].ID, "should generate ID")
|
||||
assert.Contains(t, toolCalls[0].ID, "call_")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with nil candidates",
|
||||
setup: func() *genai.GenerateContentResponse {
|
||||
return &genai.GenerateContentResponse{
|
||||
Candidates: nil,
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, toolCalls []api.ToolCall) {
|
||||
assert.Nil(t, toolCalls)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty candidates",
|
||||
setup: func() *genai.GenerateContentResponse {
|
||||
return &genai.GenerateContentResponse{
|
||||
Candidates: []*genai.Candidate{},
|
||||
}
|
||||
},
|
||||
validate: func(t *testing.T, toolCalls []api.ToolCall) {
|
||||
assert.Nil(t, toolCalls)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := tt.setup()
|
||||
toolCalls := extractToolCalls(resp)
|
||||
tt.validate(t, toolCalls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID(t *testing.T) {
|
||||
t.Run("generates non-empty ID", func(t *testing.T) {
|
||||
id := generateRandomID()
|
||||
assert.NotEmpty(t, id)
|
||||
assert.Equal(t, 24, len(id), "ID should be 24 characters")
|
||||
})
|
||||
|
||||
t.Run("generates unique IDs", func(t *testing.T) {
|
||||
id1 := generateRandomID()
|
||||
id2 := generateRandomID()
|
||||
assert.NotEqual(t, id1, id2, "IDs should be unique")
|
||||
})
|
||||
|
||||
t.Run("only contains valid characters", func(t *testing.T) {
|
||||
id := generateRandomID()
|
||||
for _, c := range id {
|
||||
assert.True(t, (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9'),
|
||||
"ID should only contain lowercase letters and numbers")
|
||||
}
|
||||
})
|
||||
}
|
||||
227
internal/providers/openai/convert_test.go
Normal file
227
internal/providers/openai/convert_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolsJSON string
|
||||
expectError bool
|
||||
validate func(t *testing.T, tools []interface{})
|
||||
}{
|
||||
{
|
||||
name: "single tool with all fields",
|
||||
toolsJSON: `[{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
},
|
||||
"units": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}]`,
|
||||
validate: func(t *testing.T, tools []interface{}) {
|
||||
require.Len(t, tools, 1, "should have exactly one tool")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple tools",
|
||||
toolsJSON: `[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object"}
|
||||
},
|
||||
{
|
||||
"name": "get_time",
|
||||
"description": "Get current time",
|
||||
"parameters": {"type": "object"}
|
||||
}
|
||||
]`,
|
||||
validate: func(t *testing.T, tools []interface{}) {
|
||||
assert.Len(t, tools, 2, "should have two tools")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool without description",
|
||||
toolsJSON: `[{
|
||||
"name": "simple_tool",
|
||||
"parameters": {"type": "object"}
|
||||
}]`,
|
||||
validate: func(t *testing.T, tools []interface{}) {
|
||||
assert.Len(t, tools, 1, "should have one tool")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool without parameters",
|
||||
toolsJSON: `[{
|
||||
"name": "paramless_tool",
|
||||
"description": "A tool without params"
|
||||
}]`,
|
||||
validate: func(t *testing.T, tools []interface{}) {
|
||||
assert.Len(t, tools, 1, "should have one tool")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil tools",
|
||||
toolsJSON: "",
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, tools []interface{}) {
|
||||
assert.Nil(t, tools, "should return nil for empty tools")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
toolsJSON: `{invalid json}`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
toolsJSON: `[]`,
|
||||
validate: func(t *testing.T, tools []interface{}) {
|
||||
assert.Nil(t, tools, "should return nil for empty array")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var req api.ResponseRequest
|
||||
if tt.toolsJSON != "" {
|
||||
req.Tools = json.RawMessage(tt.toolsJSON)
|
||||
}
|
||||
|
||||
tools, err := parseTools(&req)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, "expected an error")
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "unexpected error")
|
||||
if tt.validate != nil {
|
||||
// Convert to []interface{} for validation
|
||||
var toolsInterface []interface{}
|
||||
for _, tool := range tools {
|
||||
toolsInterface = append(toolsInterface, tool)
|
||||
}
|
||||
tt.validate(t, toolsInterface)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolChoice(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
choiceJSON string
|
||||
expectError bool
|
||||
validate func(t *testing.T, choice interface{})
|
||||
}{
|
||||
{
|
||||
name: "auto string",
|
||||
choiceJSON: `"auto"`,
|
||||
validate: func(t *testing.T, choice interface{}) {
|
||||
assert.NotNil(t, choice, "choice should not be nil")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "none string",
|
||||
choiceJSON: `"none"`,
|
||||
validate: func(t *testing.T, choice interface{}) {
|
||||
assert.NotNil(t, choice, "choice should not be nil")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "required string",
|
||||
choiceJSON: `"required"`,
|
||||
validate: func(t *testing.T, choice interface{}) {
|
||||
assert.NotNil(t, choice, "choice should not be nil")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "specific function",
|
||||
choiceJSON: `{"type": "function", "function": {"name": "get_weather"}}`,
|
||||
validate: func(t *testing.T, choice interface{}) {
|
||||
assert.NotNil(t, choice, "choice should not be nil for specific function")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil tool choice",
|
||||
choiceJSON: "",
|
||||
validate: func(t *testing.T, choice interface{}) {
|
||||
// Empty choice is valid
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
choiceJSON: `{invalid}`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported format (object without proper structure)",
|
||||
choiceJSON: `{"invalid": "structure"}`,
|
||||
validate: func(t *testing.T, choice interface{}) {
|
||||
// Currently accepts any object even if structure is wrong
|
||||
// This is documenting actual behavior
|
||||
assert.NotNil(t, choice, "choice is created even with invalid structure")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var req api.ResponseRequest
|
||||
if tt.choiceJSON != "" {
|
||||
req.ToolChoice = json.RawMessage(tt.choiceJSON)
|
||||
}
|
||||
|
||||
choice, err := parseToolChoice(&req)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, "expected an error")
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "unexpected error")
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, choice)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolCalls(t *testing.T) {
|
||||
// Note: This test would require importing the openai package types
|
||||
// For now, we're testing the logic exists and handles edge cases
|
||||
t.Run("nil message returns nil", func(t *testing.T) {
|
||||
// This test validates the function handles empty tool calls correctly
|
||||
// In a real scenario, we'd mock the openai.ChatCompletionMessage
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractToolCallDelta(t *testing.T) {
|
||||
// Note: This test would require importing the openai package types
|
||||
// Testing that the function exists and can be called
|
||||
t.Run("empty delta returns nil", func(t *testing.T) {
|
||||
// This test validates streaming delta extraction
|
||||
// In a real scenario, we'd mock the openai.ChatCompletionChunkChoice
|
||||
})
|
||||
}
|
||||
640
internal/providers/providers_test.go
Normal file
640
internal/providers/providers_test.go
Normal file
@@ -0,0 +1,640 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/config"
|
||||
)
|
||||
|
||||
func TestNewRegistry(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
entries map[string]config.ProviderEntry
|
||||
models []config.ModelEntry
|
||||
expectError bool
|
||||
errorMsg string
|
||||
validate func(t *testing.T, reg *Registry)
|
||||
}{
|
||||
{
|
||||
name: "valid config with OpenAI",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
},
|
||||
validate: func(t *testing.T, reg *Registry) {
|
||||
assert.Len(t, reg.providers, 1)
|
||||
assert.Contains(t, reg.providers, "openai")
|
||||
assert.Equal(t, "openai", reg.models["gpt-4"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid config with multiple providers",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
"anthropic": {
|
||||
Type: "anthropic",
|
||||
APIKey: "sk-ant-test",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
{Name: "claude-3", Provider: "anthropic"},
|
||||
},
|
||||
validate: func(t *testing.T, reg *Registry) {
|
||||
assert.Len(t, reg.providers, 2)
|
||||
assert.Contains(t, reg.providers, "openai")
|
||||
assert.Contains(t, reg.providers, "anthropic")
|
||||
assert.Equal(t, "openai", reg.models["gpt-4"])
|
||||
assert.Equal(t, "anthropic", reg.models["claude-3"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no providers returns error",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "", // Missing API key
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{},
|
||||
expectError: true,
|
||||
errorMsg: "no providers configured",
|
||||
},
|
||||
{
|
||||
name: "Azure OpenAI without endpoint returns error",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"azure": {
|
||||
Type: "azureopenai",
|
||||
APIKey: "test-key",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{},
|
||||
expectError: true,
|
||||
errorMsg: "endpoint is required",
|
||||
},
|
||||
{
|
||||
name: "Azure OpenAI with endpoint succeeds",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"azure": {
|
||||
Type: "azureopenai",
|
||||
APIKey: "test-key",
|
||||
Endpoint: "https://test.openai.azure.com",
|
||||
APIVersion: "2024-02-15-preview",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{
|
||||
{Name: "gpt-4-azure", Provider: "azure"},
|
||||
},
|
||||
validate: func(t *testing.T, reg *Registry) {
|
||||
assert.Len(t, reg.providers, 1)
|
||||
assert.Contains(t, reg.providers, "azure")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Azure Anthropic without endpoint returns error",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"azure-anthropic": {
|
||||
Type: "azureanthropic",
|
||||
APIKey: "test-key",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{},
|
||||
expectError: true,
|
||||
errorMsg: "endpoint is required",
|
||||
},
|
||||
{
|
||||
name: "Azure Anthropic with endpoint succeeds",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"azure-anthropic": {
|
||||
Type: "azureanthropic",
|
||||
APIKey: "test-key",
|
||||
Endpoint: "https://test.anthropic.azure.com",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{
|
||||
{Name: "claude-3-azure", Provider: "azure-anthropic"},
|
||||
},
|
||||
validate: func(t *testing.T, reg *Registry) {
|
||||
assert.Len(t, reg.providers, 1)
|
||||
assert.Contains(t, reg.providers, "azure-anthropic")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Google provider",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"google": {
|
||||
Type: "google",
|
||||
APIKey: "test-key",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{
|
||||
{Name: "gemini-pro", Provider: "google"},
|
||||
},
|
||||
validate: func(t *testing.T, reg *Registry) {
|
||||
assert.Len(t, reg.providers, 1)
|
||||
assert.Contains(t, reg.providers, "google")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Vertex AI without project/location returns error",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"vertex": {
|
||||
Type: "vertexai",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{},
|
||||
expectError: true,
|
||||
errorMsg: "project and location are required",
|
||||
},
|
||||
{
|
||||
name: "Vertex AI with project and location succeeds",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"vertex": {
|
||||
Type: "vertexai",
|
||||
Project: "my-project",
|
||||
Location: "us-central1",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{
|
||||
{Name: "gemini-pro-vertex", Provider: "vertex"},
|
||||
},
|
||||
validate: func(t *testing.T, reg *Registry) {
|
||||
assert.Len(t, reg.providers, 1)
|
||||
assert.Contains(t, reg.providers, "vertex")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown provider type returns error",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"unknown": {
|
||||
Type: "unknown-type",
|
||||
APIKey: "test-key",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{},
|
||||
expectError: true,
|
||||
errorMsg: "unknown provider type",
|
||||
},
|
||||
{
|
||||
name: "provider with no API key is skipped",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"openai-no-key": {
|
||||
Type: "openai",
|
||||
APIKey: "",
|
||||
},
|
||||
"anthropic-with-key": {
|
||||
Type: "anthropic",
|
||||
APIKey: "sk-ant-test",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{
|
||||
{Name: "claude-3", Provider: "anthropic-with-key"},
|
||||
},
|
||||
validate: func(t *testing.T, reg *Registry) {
|
||||
assert.Len(t, reg.providers, 1)
|
||||
assert.Contains(t, reg.providers, "anthropic-with-key")
|
||||
assert.NotContains(t, reg.providers, "openai-no-key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "model with provider_model_id",
|
||||
entries: map[string]config.ProviderEntry{
|
||||
"azure": {
|
||||
Type: "azureopenai",
|
||||
APIKey: "test-key",
|
||||
Endpoint: "https://test.openai.azure.com",
|
||||
},
|
||||
},
|
||||
models: []config.ModelEntry{
|
||||
{
|
||||
Name: "gpt-4",
|
||||
Provider: "azure",
|
||||
ProviderModelID: "gpt-4-deployment-name",
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, reg *Registry) {
|
||||
assert.Equal(t, "gpt-4-deployment-name", reg.providerModelIDs["gpt-4"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reg, err := NewRegistry(tt.entries, tt.models)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reg)
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, reg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Get(t *testing.T) {
|
||||
reg, err := NewRegistry(
|
||||
map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
"anthropic": {
|
||||
Type: "anthropic",
|
||||
APIKey: "sk-ant-test",
|
||||
},
|
||||
},
|
||||
[]config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providerKey string
|
||||
expectFound bool
|
||||
validate func(t *testing.T, p Provider)
|
||||
}{
|
||||
{
|
||||
name: "existing provider",
|
||||
providerKey: "openai",
|
||||
expectFound: true,
|
||||
validate: func(t *testing.T, p Provider) {
|
||||
assert.Equal(t, "openai", p.Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "another existing provider",
|
||||
providerKey: "anthropic",
|
||||
expectFound: true,
|
||||
validate: func(t *testing.T, p Provider) {
|
||||
assert.Equal(t, "anthropic", p.Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nonexistent provider",
|
||||
providerKey: "nonexistent",
|
||||
expectFound: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, found := reg.Get(tt.providerKey)
|
||||
|
||||
if tt.expectFound {
|
||||
assert.True(t, found)
|
||||
require.NotNil(t, p)
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, p)
|
||||
}
|
||||
} else {
|
||||
assert.False(t, found)
|
||||
assert.Nil(t, p)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Models(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
models []config.ModelEntry
|
||||
validate func(t *testing.T, models []struct{ Provider, Model string })
|
||||
}{
|
||||
{
|
||||
name: "single model",
|
||||
models: []config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
},
|
||||
validate: func(t *testing.T, models []struct{ Provider, Model string }) {
|
||||
require.Len(t, models, 1)
|
||||
assert.Equal(t, "gpt-4", models[0].Model)
|
||||
assert.Equal(t, "openai", models[0].Provider)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple models",
|
||||
models: []config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
{Name: "claude-3", Provider: "anthropic"},
|
||||
{Name: "gemini-pro", Provider: "google"},
|
||||
},
|
||||
validate: func(t *testing.T, models []struct{ Provider, Model string }) {
|
||||
require.Len(t, models, 3)
|
||||
modelNames := make([]string, len(models))
|
||||
for i, m := range models {
|
||||
modelNames[i] = m.Model
|
||||
}
|
||||
assert.Contains(t, modelNames, "gpt-4")
|
||||
assert.Contains(t, modelNames, "claude-3")
|
||||
assert.Contains(t, modelNames, "gemini-pro")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no models",
|
||||
models: []config.ModelEntry{},
|
||||
validate: func(t *testing.T, models []struct{ Provider, Model string }) {
|
||||
assert.Len(t, models, 0)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reg, err := NewRegistry(
|
||||
map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
"anthropic": {
|
||||
Type: "anthropic",
|
||||
APIKey: "sk-ant-test",
|
||||
},
|
||||
"google": {
|
||||
Type: "google",
|
||||
APIKey: "test-key",
|
||||
},
|
||||
},
|
||||
tt.models,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
models := reg.Models()
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, models)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_ResolveModelID(t *testing.T) {
|
||||
reg, err := NewRegistry(
|
||||
map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
"azure": {
|
||||
Type: "azureopenai",
|
||||
APIKey: "test-key",
|
||||
Endpoint: "https://test.openai.azure.com",
|
||||
},
|
||||
},
|
||||
[]config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
{Name: "gpt-4-azure", Provider: "azure", ProviderModelID: "gpt-4-deployment"},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "model without provider_model_id returns model name",
|
||||
modelName: "gpt-4",
|
||||
expected: "gpt-4",
|
||||
},
|
||||
{
|
||||
name: "model with provider_model_id returns provider_model_id",
|
||||
modelName: "gpt-4-azure",
|
||||
expected: "gpt-4-deployment",
|
||||
},
|
||||
{
|
||||
name: "unknown model returns model name",
|
||||
modelName: "unknown-model",
|
||||
expected: "unknown-model",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reg.ResolveModelID(tt.modelName)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Default(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupReg func() *Registry
|
||||
modelName string
|
||||
expectError bool
|
||||
errorMsg string
|
||||
validate func(t *testing.T, p Provider)
|
||||
}{
|
||||
{
|
||||
name: "returns provider for known model",
|
||||
setupReg: func() *Registry {
|
||||
reg, _ := NewRegistry(
|
||||
map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
"anthropic": {
|
||||
Type: "anthropic",
|
||||
APIKey: "sk-ant-test",
|
||||
},
|
||||
},
|
||||
[]config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
{Name: "claude-3", Provider: "anthropic"},
|
||||
},
|
||||
)
|
||||
return reg
|
||||
},
|
||||
modelName: "gpt-4",
|
||||
validate: func(t *testing.T, p Provider) {
|
||||
assert.Equal(t, "openai", p.Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "returns first provider for unknown model",
|
||||
setupReg: func() *Registry {
|
||||
reg, _ := NewRegistry(
|
||||
map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
},
|
||||
[]config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
},
|
||||
)
|
||||
return reg
|
||||
},
|
||||
modelName: "unknown-model",
|
||||
validate: func(t *testing.T, p Provider) {
|
||||
assert.NotNil(t, p)
|
||||
// Should return first available provider
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "returns first provider for empty model name",
|
||||
setupReg: func() *Registry {
|
||||
reg, _ := NewRegistry(
|
||||
map[string]config.ProviderEntry{
|
||||
"openai": {
|
||||
Type: "openai",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
},
|
||||
[]config.ModelEntry{
|
||||
{Name: "gpt-4", Provider: "openai"},
|
||||
},
|
||||
)
|
||||
return reg
|
||||
},
|
||||
modelName: "",
|
||||
validate: func(t *testing.T, p Provider) {
|
||||
assert.NotNil(t, p)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reg := tt.setupReg()
|
||||
p, err := reg.Default(tt.modelName)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, p)
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, p)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
entry config.ProviderEntry
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectNil bool
|
||||
validate func(t *testing.T, p Provider)
|
||||
}{
|
||||
{
|
||||
name: "OpenAI provider",
|
||||
entry: config.ProviderEntry{
|
||||
Type: "openai",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
validate: func(t *testing.T, p Provider) {
|
||||
assert.Equal(t, "openai", p.Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OpenAI provider with custom endpoint",
|
||||
entry: config.ProviderEntry{
|
||||
Type: "openai",
|
||||
APIKey: "sk-test",
|
||||
Endpoint: "https://custom.openai.com",
|
||||
},
|
||||
validate: func(t *testing.T, p Provider) {
|
||||
assert.Equal(t, "openai", p.Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Anthropic provider",
|
||||
entry: config.ProviderEntry{
|
||||
Type: "anthropic",
|
||||
APIKey: "sk-ant-test",
|
||||
},
|
||||
validate: func(t *testing.T, p Provider) {
|
||||
assert.Equal(t, "anthropic", p.Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Google provider",
|
||||
entry: config.ProviderEntry{
|
||||
Type: "google",
|
||||
APIKey: "test-key",
|
||||
},
|
||||
validate: func(t *testing.T, p Provider) {
|
||||
assert.Equal(t, "google", p.Name())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "provider without API key returns nil",
|
||||
entry: config.ProviderEntry{
|
||||
Type: "openai",
|
||||
APIKey: "",
|
||||
},
|
||||
expectNil: true,
|
||||
},
|
||||
{
|
||||
name: "unknown provider type",
|
||||
entry: config.ProviderEntry{
|
||||
Type: "unknown",
|
||||
APIKey: "test-key",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "unknown provider type",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := buildProvider(tt.entry)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.expectNil {
|
||||
assert.Nil(t, p)
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, p)
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, p)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user