Add tests
This commit is contained in:
330
internal/server/mocks_test.go
Normal file
330
internal/server/mocks_test.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/ajac-zero/latticelm/internal/config"
|
||||
"github.com/ajac-zero/latticelm/internal/conversation"
|
||||
"github.com/ajac-zero/latticelm/internal/providers"
|
||||
)
|
||||
|
||||
// mockProvider implements providers.Provider for testing
|
||||
type mockProvider struct {
|
||||
name string
|
||||
generateFunc func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error)
|
||||
streamFunc func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error)
|
||||
generateCalled int
|
||||
streamCalled int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockProvider(name string) *mockProvider {
|
||||
return &mockProvider{
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockProvider) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *mockProvider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
||||
m.mu.Lock()
|
||||
m.generateCalled++
|
||||
m.mu.Unlock()
|
||||
|
||||
if m.generateFunc != nil {
|
||||
return m.generateFunc(ctx, messages, req)
|
||||
}
|
||||
return &api.ProviderResult{
|
||||
ID: "mock-id",
|
||||
Model: req.Model,
|
||||
Text: "mock response",
|
||||
Usage: api.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) {
|
||||
m.mu.Lock()
|
||||
m.streamCalled++
|
||||
m.mu.Unlock()
|
||||
|
||||
if m.streamFunc != nil {
|
||||
return m.streamFunc(ctx, messages, req)
|
||||
}
|
||||
|
||||
// Default behavior: send a simple text stream
|
||||
deltaChan := make(chan *api.ProviderStreamDelta, 3)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
defer close(deltaChan)
|
||||
defer close(errChan)
|
||||
|
||||
deltaChan <- &api.ProviderStreamDelta{
|
||||
Model: req.Model,
|
||||
Text: "Hello",
|
||||
}
|
||||
deltaChan <- &api.ProviderStreamDelta{
|
||||
Text: " world",
|
||||
}
|
||||
deltaChan <- &api.ProviderStreamDelta{
|
||||
Done: true,
|
||||
}
|
||||
}()
|
||||
|
||||
return deltaChan, errChan
|
||||
}
|
||||
|
||||
func (m *mockProvider) getGenerateCalled() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.generateCalled
|
||||
}
|
||||
|
||||
func (m *mockProvider) getStreamCalled() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.streamCalled
|
||||
}
|
||||
|
||||
// buildTestRegistry creates a providers.Registry for testing with mock providers
|
||||
// Uses reflection to inject mock providers into the registry
|
||||
func buildTestRegistry(mockProviders map[string]providers.Provider, modelConfigs []config.ModelEntry) *providers.Registry {
|
||||
// Create empty registry
|
||||
reg := &providers.Registry{}
|
||||
|
||||
// Use reflection to set private fields
|
||||
regValue := reflect.ValueOf(reg).Elem()
|
||||
|
||||
// Set providers field
|
||||
providersField := regValue.FieldByName("providers")
|
||||
providersPtr := unsafe.Pointer(providersField.UnsafeAddr())
|
||||
*(*map[string]providers.Provider)(providersPtr) = mockProviders
|
||||
|
||||
// Set modelList field
|
||||
modelListField := regValue.FieldByName("modelList")
|
||||
modelListPtr := unsafe.Pointer(modelListField.UnsafeAddr())
|
||||
*(*[]config.ModelEntry)(modelListPtr) = modelConfigs
|
||||
|
||||
// Set models map (model name -> provider name)
|
||||
modelsField := regValue.FieldByName("models")
|
||||
modelsPtr := unsafe.Pointer(modelsField.UnsafeAddr())
|
||||
modelsMap := make(map[string]string)
|
||||
for _, m := range modelConfigs {
|
||||
modelsMap[m.Name] = m.Provider
|
||||
}
|
||||
*(*map[string]string)(modelsPtr) = modelsMap
|
||||
|
||||
// Set providerModelIDs map
|
||||
providerModelIDsField := regValue.FieldByName("providerModelIDs")
|
||||
providerModelIDsPtr := unsafe.Pointer(providerModelIDsField.UnsafeAddr())
|
||||
providerModelIDsMap := make(map[string]string)
|
||||
for _, m := range modelConfigs {
|
||||
if m.ProviderModelID != "" {
|
||||
providerModelIDsMap[m.Name] = m.ProviderModelID
|
||||
}
|
||||
}
|
||||
*(*map[string]string)(providerModelIDsPtr) = providerModelIDsMap
|
||||
|
||||
return reg
|
||||
}
|
||||
|
||||
// mockConversationStore implements conversation.Store for testing
|
||||
type mockConversationStore struct {
|
||||
conversations map[string]*conversation.Conversation
|
||||
createErr error
|
||||
getErr error
|
||||
appendErr error
|
||||
deleteErr error
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockConversationStore() *mockConversationStore {
|
||||
return &mockConversationStore{
|
||||
conversations: make(map[string]*conversation.Conversation),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockConversationStore) Get(id string) (*conversation.Conversation, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.getErr != nil {
|
||||
return nil, m.getErr
|
||||
}
|
||||
conv, ok := m.conversations[id]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
func (m *mockConversationStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.createErr != nil {
|
||||
return nil, m.createErr
|
||||
}
|
||||
|
||||
conv := &conversation.Conversation{
|
||||
ID: id,
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
}
|
||||
m.conversations[id] = conv
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
func (m *mockConversationStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.appendErr != nil {
|
||||
return nil, m.appendErr
|
||||
}
|
||||
|
||||
conv, ok := m.conversations[id]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
conv.Messages = append(conv.Messages, messages...)
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
func (m *mockConversationStore) Delete(id string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.deleteErr != nil {
|
||||
return m.deleteErr
|
||||
}
|
||||
delete(m.conversations, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConversationStore) Size() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return len(m.conversations)
|
||||
}
|
||||
|
||||
func (m *mockConversationStore) setConversation(id string, conv *conversation.Conversation) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.conversations[id] = conv
|
||||
}
|
||||
|
||||
// mockLogger captures log output for testing
|
||||
type mockLogger struct {
|
||||
logs []string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockLogger() *mockLogger {
|
||||
return &mockLogger{
|
||||
logs: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockLogger) Printf(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.logs = append(m.logs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) getLogs() []string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return append([]string{}, m.logs...)
|
||||
}
|
||||
|
||||
func (m *mockLogger) asLogger() *log.Logger {
|
||||
return log.New(m, "", 0)
|
||||
}
|
||||
|
||||
func (m *mockLogger) Write(p []byte) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.logs = append(m.logs, string(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// mockRegistry is a simple mock for providers.Registry
|
||||
type mockRegistry struct {
|
||||
providers map[string]providers.Provider
|
||||
models map[string]string // model name -> provider name
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockRegistry() *mockRegistry {
|
||||
return &mockRegistry{
|
||||
providers: make(map[string]providers.Provider),
|
||||
models: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Get(name string) (providers.Provider, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
p, ok := m.providers[name]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Default(model string) (providers.Provider, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
providerName, ok := m.models[model]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no provider configured for model %s", model)
|
||||
}
|
||||
|
||||
p, ok := m.providers[providerName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("provider %s not found", providerName)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Models() []struct{ Provider, Model string } {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var models []struct{ Provider, Model string }
|
||||
for modelName, providerName := range m.models {
|
||||
models = append(models, struct{ Provider, Model string }{
|
||||
Model: modelName,
|
||||
Provider: providerName,
|
||||
})
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func (m *mockRegistry) ResolveModelID(model string) string {
|
||||
// Simple implementation - just return the model name as-is
|
||||
return model
|
||||
}
|
||||
|
||||
func (m *mockRegistry) addProvider(name string, provider providers.Provider) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.providers[name] = provider
|
||||
}
|
||||
|
||||
func (m *mockRegistry) addModel(model, provider string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.models[model] = provider
|
||||
}
|
||||
Reference in New Issue
Block a user