337 lines
7.9 KiB
Go
337 lines
7.9 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"reflect"
|
|
"sync"
|
|
"unsafe"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/ajac-zero/latticelm/internal/config"
|
|
"github.com/ajac-zero/latticelm/internal/conversation"
|
|
"github.com/ajac-zero/latticelm/internal/providers"
|
|
)
|
|
|
|
// mockProvider implements providers.Provider for testing
|
|
type mockProvider struct {
|
|
name string
|
|
generateFunc func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error)
|
|
streamFunc func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error)
|
|
generateCalled int
|
|
streamCalled int
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func newMockProvider(name string) *mockProvider {
|
|
return &mockProvider{
|
|
name: name,
|
|
}
|
|
}
|
|
|
|
func (m *mockProvider) Name() string {
|
|
return m.name
|
|
}
|
|
|
|
func (m *mockProvider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
|
m.mu.Lock()
|
|
m.generateCalled++
|
|
m.mu.Unlock()
|
|
|
|
if m.generateFunc != nil {
|
|
return m.generateFunc(ctx, messages, req)
|
|
}
|
|
return &api.ProviderResult{
|
|
ID: "mock-id",
|
|
Model: req.Model,
|
|
Text: "mock response",
|
|
Usage: api.Usage{
|
|
InputTokens: 10,
|
|
OutputTokens: 20,
|
|
TotalTokens: 30,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (m *mockProvider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) {
|
|
m.mu.Lock()
|
|
m.streamCalled++
|
|
m.mu.Unlock()
|
|
|
|
if m.streamFunc != nil {
|
|
return m.streamFunc(ctx, messages, req)
|
|
}
|
|
|
|
// Default behavior: send a simple text stream
|
|
deltaChan := make(chan *api.ProviderStreamDelta, 3)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() {
|
|
defer close(deltaChan)
|
|
defer close(errChan)
|
|
|
|
deltaChan <- &api.ProviderStreamDelta{
|
|
Model: req.Model,
|
|
Text: "Hello",
|
|
}
|
|
deltaChan <- &api.ProviderStreamDelta{
|
|
Text: " world",
|
|
}
|
|
deltaChan <- &api.ProviderStreamDelta{
|
|
Done: true,
|
|
}
|
|
}()
|
|
|
|
return deltaChan, errChan
|
|
}
|
|
|
|
func (m *mockProvider) getGenerateCalled() int {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.generateCalled
|
|
}
|
|
|
|
func (m *mockProvider) getStreamCalled() int {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.streamCalled
|
|
}
|
|
|
|
// buildTestRegistry creates a providers.Registry for testing with mock providers
|
|
// Uses reflection to inject mock providers into the registry
|
|
func buildTestRegistry(mockProviders map[string]providers.Provider, modelConfigs []config.ModelEntry) *providers.Registry {
|
|
// Create empty registry
|
|
reg := &providers.Registry{}
|
|
|
|
// Use reflection to set private fields
|
|
regValue := reflect.ValueOf(reg).Elem()
|
|
|
|
// Set providers field
|
|
providersField := regValue.FieldByName("providers")
|
|
providersPtr := unsafe.Pointer(providersField.UnsafeAddr())
|
|
*(*map[string]providers.Provider)(providersPtr) = mockProviders
|
|
|
|
// Set modelList field
|
|
modelListField := regValue.FieldByName("modelList")
|
|
modelListPtr := unsafe.Pointer(modelListField.UnsafeAddr())
|
|
*(*[]config.ModelEntry)(modelListPtr) = modelConfigs
|
|
|
|
// Set models map (model name -> provider name)
|
|
modelsField := regValue.FieldByName("models")
|
|
modelsPtr := unsafe.Pointer(modelsField.UnsafeAddr())
|
|
modelsMap := make(map[string]string)
|
|
for _, m := range modelConfigs {
|
|
modelsMap[m.Name] = m.Provider
|
|
}
|
|
*(*map[string]string)(modelsPtr) = modelsMap
|
|
|
|
// Set providerModelIDs map
|
|
providerModelIDsField := regValue.FieldByName("providerModelIDs")
|
|
providerModelIDsPtr := unsafe.Pointer(providerModelIDsField.UnsafeAddr())
|
|
providerModelIDsMap := make(map[string]string)
|
|
for _, m := range modelConfigs {
|
|
if m.ProviderModelID != "" {
|
|
providerModelIDsMap[m.Name] = m.ProviderModelID
|
|
}
|
|
}
|
|
*(*map[string]string)(providerModelIDsPtr) = providerModelIDsMap
|
|
|
|
return reg
|
|
}
|
|
|
|
// mockConversationStore implements conversation.Store for testing
|
|
type mockConversationStore struct {
|
|
conversations map[string]*conversation.Conversation
|
|
createErr error
|
|
getErr error
|
|
appendErr error
|
|
deleteErr error
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func newMockConversationStore() *mockConversationStore {
|
|
return &mockConversationStore{
|
|
conversations: make(map[string]*conversation.Conversation),
|
|
}
|
|
}
|
|
|
|
func (m *mockConversationStore) Get(id string) (*conversation.Conversation, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.getErr != nil {
|
|
return nil, m.getErr
|
|
}
|
|
conv, ok := m.conversations[id]
|
|
if !ok {
|
|
return nil, nil
|
|
}
|
|
return conv, nil
|
|
}
|
|
|
|
func (m *mockConversationStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.createErr != nil {
|
|
return nil, m.createErr
|
|
}
|
|
|
|
conv := &conversation.Conversation{
|
|
ID: id,
|
|
Model: model,
|
|
Messages: messages,
|
|
}
|
|
m.conversations[id] = conv
|
|
return conv, nil
|
|
}
|
|
|
|
func (m *mockConversationStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.appendErr != nil {
|
|
return nil, m.appendErr
|
|
}
|
|
|
|
conv, ok := m.conversations[id]
|
|
if !ok {
|
|
return nil, nil
|
|
}
|
|
conv.Messages = append(conv.Messages, messages...)
|
|
return conv, nil
|
|
}
|
|
|
|
func (m *mockConversationStore) Delete(id string) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.deleteErr != nil {
|
|
return m.deleteErr
|
|
}
|
|
delete(m.conversations, id)
|
|
return nil
|
|
}
|
|
|
|
func (m *mockConversationStore) Size() int {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return len(m.conversations)
|
|
}
|
|
|
|
func (m *mockConversationStore) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockConversationStore) setConversation(id string, conv *conversation.Conversation) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.conversations[id] = conv
|
|
}
|
|
|
|
// mockLogger captures log output for testing
|
|
type mockLogger struct {
|
|
logs []string
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func newMockLogger() *mockLogger {
|
|
return &mockLogger{
|
|
logs: []string{},
|
|
}
|
|
}
|
|
|
|
func (m *mockLogger) Printf(format string, args ...interface{}) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.logs = append(m.logs, fmt.Sprintf(format, args...))
|
|
}
|
|
|
|
func (m *mockLogger) getLogs() []string {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return append([]string{}, m.logs...)
|
|
}
|
|
|
|
func (m *mockLogger) asLogger() *slog.Logger {
|
|
return slog.New(slog.NewTextHandler(m, &slog.HandlerOptions{
|
|
Level: slog.LevelDebug,
|
|
}))
|
|
}
|
|
|
|
func (m *mockLogger) Write(p []byte) (n int, err error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.logs = append(m.logs, string(p))
|
|
return len(p), nil
|
|
}
|
|
|
|
// mockRegistry is a simple mock for providers.Registry
|
|
type mockRegistry struct {
|
|
providers map[string]providers.Provider
|
|
models map[string]string // model name -> provider name
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
func newMockRegistry() *mockRegistry {
|
|
return &mockRegistry{
|
|
providers: make(map[string]providers.Provider),
|
|
models: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
func (m *mockRegistry) Get(name string) (providers.Provider, bool) {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
p, ok := m.providers[name]
|
|
return p, ok
|
|
}
|
|
|
|
func (m *mockRegistry) Default(model string) (providers.Provider, error) {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
providerName, ok := m.models[model]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no provider configured for model %s", model)
|
|
}
|
|
|
|
p, ok := m.providers[providerName]
|
|
if !ok {
|
|
return nil, fmt.Errorf("provider %s not found", providerName)
|
|
}
|
|
return p, nil
|
|
}
|
|
|
|
func (m *mockRegistry) Models() []struct{ Provider, Model string } {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
var models []struct{ Provider, Model string }
|
|
for modelName, providerName := range m.models {
|
|
models = append(models, struct{ Provider, Model string }{
|
|
Model: modelName,
|
|
Provider: providerName,
|
|
})
|
|
}
|
|
return models
|
|
}
|
|
|
|
func (m *mockRegistry) ResolveModelID(model string) string {
|
|
// Simple implementation - just return the model name as-is
|
|
return model
|
|
}
|
|
|
|
func (m *mockRegistry) addProvider(name string, provider providers.Provider) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.providers[name] = provider
|
|
}
|
|
|
|
func (m *mockRegistry) addModel(model, provider string) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.models[model] = provider
|
|
}
|