1161 lines
34 KiB
Go
1161 lines
34 KiB
Go
package server
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/ajac-zero/latticelm/internal/conversation"
|
|
)
|
|
|
|
func TestHandleModels(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
method string
|
|
setupServer func() *GatewayServer
|
|
expectStatus int
|
|
validate func(t *testing.T, body string)
|
|
}{
|
|
{
|
|
name: "GET returns model list",
|
|
method: http.MethodGet,
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
registry.addModel("gpt-4", "openai")
|
|
registry.addModel("claude-3", "anthropic")
|
|
registry.addProvider("openai", newMockProvider("openai"))
|
|
registry.addProvider("anthropic", newMockProvider("anthropic"))
|
|
return New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
},
|
|
expectStatus: http.StatusOK,
|
|
validate: func(t *testing.T, body string) {
|
|
var resp api.ModelsResponse
|
|
err := json.Unmarshal([]byte(body), &resp)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "list", resp.Object)
|
|
assert.Len(t, resp.Data, 2)
|
|
},
|
|
},
|
|
{
|
|
name: "POST returns 405",
|
|
method: http.MethodPost,
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
return New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
},
|
|
expectStatus: http.StatusMethodNotAllowed,
|
|
},
|
|
{
|
|
name: "empty registry returns empty list",
|
|
method: http.MethodGet,
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
return New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
},
|
|
expectStatus: http.StatusOK,
|
|
validate: func(t *testing.T, body string) {
|
|
var resp api.ModelsResponse
|
|
err := json.Unmarshal([]byte(body), &resp)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "list", resp.Object)
|
|
assert.Len(t, resp.Data, 0)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
server := tt.setupServer()
|
|
req := httptest.NewRequest(tt.method, "/v1/models", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
server.handleModels(rec, req)
|
|
|
|
assert.Equal(t, tt.expectStatus, rec.Code)
|
|
if tt.validate != nil {
|
|
tt.validate(t, rec.Body.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleResponses_Validation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
method string
|
|
body string
|
|
expectStatus int
|
|
expectBody string
|
|
}{
|
|
{
|
|
name: "GET returns 405",
|
|
method: http.MethodGet,
|
|
body: "",
|
|
expectStatus: http.StatusMethodNotAllowed,
|
|
},
|
|
{
|
|
name: "invalid JSON returns 400",
|
|
method: http.MethodPost,
|
|
body: `{invalid json}`,
|
|
expectStatus: http.StatusBadRequest,
|
|
expectBody: "invalid JSON payload",
|
|
},
|
|
{
|
|
name: "missing model returns 400",
|
|
method: http.MethodPost,
|
|
body: `{"input": "hello"}`,
|
|
expectStatus: http.StatusBadRequest,
|
|
expectBody: "model is required",
|
|
},
|
|
{
|
|
name: "missing input returns 400",
|
|
method: http.MethodPost,
|
|
body: `{"model": "gpt-4"}`,
|
|
expectStatus: http.StatusBadRequest,
|
|
expectBody: "input is required",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
registry := newMockRegistry()
|
|
server := New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
|
|
req := httptest.NewRequest(tt.method, "/v1/responses", strings.NewReader(tt.body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
rec := httptest.NewRecorder()
|
|
|
|
server.handleResponses(rec, req)
|
|
|
|
assert.Equal(t, tt.expectStatus, rec.Code)
|
|
if tt.expectBody != "" {
|
|
assert.Contains(t, rec.Body.String(), tt.expectBody)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleResponses_Sync_Success(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
requestBody string
|
|
setupMock func(p *mockProvider)
|
|
validate func(t *testing.T, resp *api.Response, store *mockConversationStore)
|
|
}{
|
|
{
|
|
name: "simple text response",
|
|
requestBody: `{"model": "gpt-4", "input": "hello"}`,
|
|
setupMock: func(p *mockProvider) {
|
|
p.generateFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
|
return &api.ProviderResult{
|
|
Model: "gpt-4-turbo",
|
|
Text: "Hello! How can I help you?",
|
|
Usage: api.Usage{
|
|
InputTokens: 5,
|
|
OutputTokens: 10,
|
|
TotalTokens: 15,
|
|
},
|
|
}, nil
|
|
}
|
|
},
|
|
validate: func(t *testing.T, resp *api.Response, store *mockConversationStore) {
|
|
assert.Equal(t, "response", resp.Object)
|
|
assert.Equal(t, "completed", resp.Status)
|
|
assert.Equal(t, "gpt-4-turbo", resp.Model)
|
|
assert.Equal(t, "openai", resp.Provider)
|
|
require.Len(t, resp.Output, 1)
|
|
assert.Equal(t, "message", resp.Output[0].Type)
|
|
assert.Equal(t, "completed", resp.Output[0].Status)
|
|
assert.Equal(t, "assistant", resp.Output[0].Role)
|
|
require.Len(t, resp.Output[0].Content, 1)
|
|
assert.Equal(t, "output_text", resp.Output[0].Content[0].Type)
|
|
assert.Equal(t, "Hello! How can I help you?", resp.Output[0].Content[0].Text)
|
|
require.NotNil(t, resp.Usage)
|
|
assert.Equal(t, 5, resp.Usage.InputTokens)
|
|
assert.Equal(t, 10, resp.Usage.OutputTokens)
|
|
assert.Equal(t, 15, resp.Usage.TotalTokens)
|
|
assert.Equal(t, 1, store.Size())
|
|
},
|
|
},
|
|
{
|
|
name: "response with tool calls",
|
|
requestBody: `{"model": "gpt-4", "input": "what's the weather?"}`,
|
|
setupMock: func(p *mockProvider) {
|
|
p.generateFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
|
return &api.ProviderResult{
|
|
Model: "gpt-4",
|
|
Text: "Let me check that for you.",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_123",
|
|
Name: "get_weather",
|
|
Arguments: `{"location":"San Francisco"}`,
|
|
},
|
|
},
|
|
Usage: api.Usage{
|
|
InputTokens: 10,
|
|
OutputTokens: 20,
|
|
TotalTokens: 30,
|
|
},
|
|
}, nil
|
|
}
|
|
},
|
|
validate: func(t *testing.T, resp *api.Response, store *mockConversationStore) {
|
|
assert.Equal(t, "completed", resp.Status)
|
|
require.Len(t, resp.Output, 2)
|
|
assert.Equal(t, "message", resp.Output[0].Type)
|
|
assert.Equal(t, "Let me check that for you.", resp.Output[0].Content[0].Text)
|
|
assert.Equal(t, "function_call", resp.Output[1].Type)
|
|
assert.Equal(t, "completed", resp.Output[1].Status)
|
|
assert.Equal(t, "call_123", resp.Output[1].CallID)
|
|
assert.Equal(t, "get_weather", resp.Output[1].Name)
|
|
assert.JSONEq(t, `{"location":"San Francisco"}`, resp.Output[1].Arguments)
|
|
},
|
|
},
|
|
{
|
|
name: "response with multiple tool calls",
|
|
requestBody: `{"model": "gpt-4", "input": "check NYC and LA weather"}`,
|
|
setupMock: func(p *mockProvider) {
|
|
p.generateFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
|
return &api.ProviderResult{
|
|
Model: "gpt-4",
|
|
Text: "Checking both cities.",
|
|
ToolCalls: []api.ToolCall{
|
|
{ID: "call_1", Name: "get_weather", Arguments: `{"location":"NYC"}`},
|
|
{ID: "call_2", Name: "get_weather", Arguments: `{"location":"LA"}`},
|
|
},
|
|
}, nil
|
|
}
|
|
},
|
|
validate: func(t *testing.T, resp *api.Response, store *mockConversationStore) {
|
|
require.Len(t, resp.Output, 3)
|
|
assert.Equal(t, "message", resp.Output[0].Type)
|
|
assert.Equal(t, "function_call", resp.Output[1].Type)
|
|
assert.Equal(t, "function_call", resp.Output[2].Type)
|
|
assert.Equal(t, "call_1", resp.Output[1].CallID)
|
|
assert.Equal(t, "call_2", resp.Output[2].CallID)
|
|
},
|
|
},
|
|
{
|
|
name: "response with only tool calls (no text)",
|
|
requestBody: `{"model": "gpt-4", "input": "search"}`,
|
|
setupMock: func(p *mockProvider) {
|
|
p.generateFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
|
return &api.ProviderResult{
|
|
Model: "gpt-4",
|
|
ToolCalls: []api.ToolCall{
|
|
{ID: "call_xyz", Name: "search", Arguments: `{}`},
|
|
},
|
|
}, nil
|
|
}
|
|
},
|
|
validate: func(t *testing.T, resp *api.Response, store *mockConversationStore) {
|
|
require.Len(t, resp.Output, 1)
|
|
assert.Equal(t, "function_call", resp.Output[0].Type)
|
|
assert.Nil(t, resp.Usage)
|
|
},
|
|
},
|
|
{
|
|
name: "response echoes request parameters",
|
|
requestBody: `{"model": "gpt-4", "input": "hi", "temperature": 0.7, "top_p": 0.9, "parallel_tool_calls": false}`,
|
|
setupMock: nil,
|
|
validate: func(t *testing.T, resp *api.Response, store *mockConversationStore) {
|
|
assert.Equal(t, 0.7, resp.Temperature)
|
|
assert.Equal(t, 0.9, resp.TopP)
|
|
assert.False(t, resp.ParallelToolCalls)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
registry := newMockRegistry()
|
|
provider := newMockProvider("openai")
|
|
if tt.setupMock != nil {
|
|
tt.setupMock(provider)
|
|
}
|
|
registry.addProvider("openai", provider)
|
|
registry.addModel("gpt-4", "openai")
|
|
|
|
store := newMockConversationStore()
|
|
server := New(registry, store, newMockLogger().asLogger())
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(tt.requestBody))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
rec := httptest.NewRecorder()
|
|
|
|
server.handleResponses(rec, req)
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
|
var resp api.Response
|
|
err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
|
require.NoError(t, err)
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, &resp, store)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleResponses_Sync_ConversationHistory(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupServer func() *GatewayServer
|
|
requestBody string
|
|
expectStatus int
|
|
expectBody string
|
|
validate func(t *testing.T, provider *mockProvider)
|
|
}{
|
|
{
|
|
name: "without previous_response_id",
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
provider := newMockProvider("openai")
|
|
registry.addProvider("openai", provider)
|
|
registry.addModel("gpt-4", "openai")
|
|
return New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
},
|
|
requestBody: `{"model": "gpt-4", "input": "hello"}`,
|
|
expectStatus: http.StatusOK,
|
|
validate: func(t *testing.T, provider *mockProvider) {
|
|
assert.Equal(t, 1, provider.getGenerateCalled())
|
|
},
|
|
},
|
|
{
|
|
name: "with valid previous_response_id",
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
provider := newMockProvider("openai")
|
|
provider.generateFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
|
// Should receive history + new message
|
|
if len(messages) != 2 {
|
|
return nil, fmt.Errorf("expected 2 messages, got %d", len(messages))
|
|
}
|
|
return &api.ProviderResult{
|
|
Model: req.Model,
|
|
Text: "response",
|
|
}, nil
|
|
}
|
|
registry.addProvider("openai", provider)
|
|
registry.addModel("gpt-4", "openai")
|
|
|
|
store := newMockConversationStore()
|
|
store.setConversation("prev-123", &conversation.Conversation{
|
|
ID: "prev-123",
|
|
Model: "gpt-4",
|
|
Messages: []api.Message{
|
|
{
|
|
Role: "user",
|
|
Content: []api.ContentBlock{{Type: "input_text", Text: "previous message"}},
|
|
},
|
|
},
|
|
})
|
|
return New(registry, store, newMockLogger().asLogger())
|
|
},
|
|
requestBody: `{"model": "gpt-4", "input": "new message", "previous_response_id": "prev-123"}`,
|
|
expectStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "with instructions prepends developer message",
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
provider := newMockProvider("openai")
|
|
provider.generateFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
|
// Should have developer message first
|
|
if len(messages) < 1 || messages[0].Role != "developer" {
|
|
return nil, fmt.Errorf("expected developer message first")
|
|
}
|
|
if messages[0].Content[0].Text != "Be helpful" {
|
|
return nil, fmt.Errorf("unexpected instructions: %s", messages[0].Content[0].Text)
|
|
}
|
|
return &api.ProviderResult{
|
|
Model: req.Model,
|
|
Text: "response",
|
|
}, nil
|
|
}
|
|
registry.addProvider("openai", provider)
|
|
registry.addModel("gpt-4", "openai")
|
|
return New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
},
|
|
requestBody: `{"model": "gpt-4", "input": "hello", "instructions": "Be helpful"}`,
|
|
expectStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "nonexistent conversation returns 404",
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
provider := newMockProvider("openai")
|
|
registry.addProvider("openai", provider)
|
|
registry.addModel("gpt-4", "openai")
|
|
return New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
},
|
|
requestBody: `{"model": "gpt-4", "input": "hello", "previous_response_id": "nonexistent"}`,
|
|
expectStatus: http.StatusNotFound,
|
|
expectBody: "conversation not found",
|
|
},
|
|
{
|
|
name: "conversation store error returns 500",
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
provider := newMockProvider("openai")
|
|
registry.addProvider("openai", provider)
|
|
registry.addModel("gpt-4", "openai")
|
|
|
|
store := newMockConversationStore()
|
|
store.getErr = fmt.Errorf("database error")
|
|
return New(registry, store, newMockLogger().asLogger())
|
|
},
|
|
requestBody: `{"model": "gpt-4", "input": "hello", "previous_response_id": "any"}`,
|
|
expectStatus: http.StatusInternalServerError,
|
|
expectBody: "error retrieving conversation",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
server := tt.setupServer()
|
|
|
|
// Get the provider for validation if needed
|
|
var provider *mockProvider
|
|
if registry, ok := server.registry.(*mockRegistry); ok {
|
|
if p, exists := registry.Get("openai"); exists {
|
|
provider = p.(*mockProvider)
|
|
}
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(tt.requestBody))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
rec := httptest.NewRecorder()
|
|
|
|
server.handleResponses(rec, req)
|
|
|
|
assert.Equal(t, tt.expectStatus, rec.Code)
|
|
if tt.expectBody != "" {
|
|
assert.Contains(t, rec.Body.String(), tt.expectBody)
|
|
}
|
|
if tt.validate != nil && provider != nil {
|
|
tt.validate(t, provider)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleResponses_Sync_ProviderErrors(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupMock func(p *mockProvider)
|
|
expectStatus int
|
|
expectBody string
|
|
}{
|
|
{
|
|
name: "provider returns error",
|
|
setupMock: func(p *mockProvider) {
|
|
p.generateFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
|
return nil, fmt.Errorf("rate limit exceeded")
|
|
}
|
|
},
|
|
expectStatus: http.StatusBadGateway,
|
|
expectBody: "provider error",
|
|
},
|
|
{
|
|
name: "provider not configured",
|
|
setupMock: func(p *mockProvider) {
|
|
// Don't set up this provider, request will use explicit provider
|
|
},
|
|
expectStatus: http.StatusBadGateway,
|
|
expectBody: "provider nonexistent not configured",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
registry := newMockRegistry()
|
|
provider := newMockProvider("openai")
|
|
if tt.setupMock != nil {
|
|
tt.setupMock(provider)
|
|
}
|
|
registry.addProvider("openai", provider)
|
|
registry.addModel("gpt-4", "openai")
|
|
|
|
server := New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
|
|
body := `{"model": "gpt-4", "input": "hello"}`
|
|
if tt.name == "provider not configured" {
|
|
body = `{"model": "gpt-4", "input": "hello", "provider": "nonexistent"}`
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
rec := httptest.NewRecorder()
|
|
|
|
server.handleResponses(rec, req)
|
|
|
|
assert.Equal(t, tt.expectStatus, rec.Code)
|
|
if tt.expectBody != "" {
|
|
assert.Contains(t, rec.Body.String(), tt.expectBody)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleResponses_Stream_Success(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
requestBody string
|
|
setupMock func(p *mockProvider)
|
|
validate func(t *testing.T, events []api.StreamEvent)
|
|
}{
|
|
{
|
|
name: "simple text streaming",
|
|
requestBody: `{"model": "gpt-4", "input": "hello", "stream": true}`,
|
|
setupMock: func(p *mockProvider) {
|
|
p.streamFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) {
|
|
deltaChan := make(chan *api.ProviderStreamDelta)
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(deltaChan)
|
|
defer close(errChan)
|
|
deltaChan <- &api.ProviderStreamDelta{Model: "gpt-4-turbo", Text: "Hello"}
|
|
deltaChan <- &api.ProviderStreamDelta{Text: " there"}
|
|
deltaChan <- &api.ProviderStreamDelta{Done: true}
|
|
}()
|
|
return deltaChan, errChan
|
|
}
|
|
},
|
|
validate: func(t *testing.T, events []api.StreamEvent) {
|
|
require.GreaterOrEqual(t, len(events), 5)
|
|
assert.Equal(t, "response.created", events[0].Type)
|
|
assert.Equal(t, "response.in_progress", events[1].Type)
|
|
assert.Equal(t, "response.output_item.added", events[2].Type)
|
|
|
|
// Find text deltas
|
|
var textDeltas []string
|
|
for _, e := range events {
|
|
if e.Type == "response.output_text.delta" {
|
|
textDeltas = append(textDeltas, e.Delta)
|
|
}
|
|
}
|
|
assert.Equal(t, []string{"Hello", " there"}, textDeltas)
|
|
|
|
// Last event should be response.completed
|
|
lastEvent := events[len(events)-1]
|
|
assert.Equal(t, "response.completed", lastEvent.Type)
|
|
require.NotNil(t, lastEvent.Response)
|
|
assert.Equal(t, "completed", lastEvent.Response.Status)
|
|
assert.Equal(t, "gpt-4-turbo", lastEvent.Response.Model)
|
|
},
|
|
},
|
|
{
|
|
name: "streaming with tool calls",
|
|
requestBody: `{"model": "gpt-4", "input": "weather?", "stream": true}`,
|
|
setupMock: func(p *mockProvider) {
|
|
p.streamFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) {
|
|
deltaChan := make(chan *api.ProviderStreamDelta)
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(deltaChan)
|
|
defer close(errChan)
|
|
deltaChan <- &api.ProviderStreamDelta{Model: "gpt-4", Text: "Let me check"}
|
|
deltaChan <- &api.ProviderStreamDelta{
|
|
ToolCallDelta: &api.ToolCallDelta{
|
|
Index: 0,
|
|
ID: "call_abc",
|
|
Name: "get_weather",
|
|
},
|
|
}
|
|
deltaChan <- &api.ProviderStreamDelta{
|
|
ToolCallDelta: &api.ToolCallDelta{
|
|
Index: 0,
|
|
Arguments: `{"location":"NYC"}`,
|
|
},
|
|
}
|
|
deltaChan <- &api.ProviderStreamDelta{Done: true}
|
|
}()
|
|
return deltaChan, errChan
|
|
}
|
|
},
|
|
validate: func(t *testing.T, events []api.StreamEvent) {
|
|
// Find tool call events
|
|
var toolCallAdded bool
|
|
var argsDeltas []string
|
|
for _, e := range events {
|
|
if e.Type == "response.output_item.added" && e.Item != nil && e.Item.Type == "function_call" {
|
|
toolCallAdded = true
|
|
assert.Equal(t, "call_abc", e.Item.CallID)
|
|
assert.Equal(t, "get_weather", e.Item.Name)
|
|
}
|
|
if e.Type == "response.function_call_arguments.delta" {
|
|
argsDeltas = append(argsDeltas, e.Delta)
|
|
}
|
|
}
|
|
assert.True(t, toolCallAdded, "should have tool call added event")
|
|
assert.Equal(t, []string{`{"location":"NYC"}`}, argsDeltas)
|
|
},
|
|
},
|
|
{
|
|
name: "streaming with multiple tool calls",
|
|
requestBody: `{"model": "gpt-4", "input": "check multiple", "stream": true}`,
|
|
setupMock: func(p *mockProvider) {
|
|
p.streamFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) {
|
|
deltaChan := make(chan *api.ProviderStreamDelta)
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(deltaChan)
|
|
defer close(errChan)
|
|
// First tool call
|
|
deltaChan <- &api.ProviderStreamDelta{
|
|
ToolCallDelta: &api.ToolCallDelta{
|
|
Index: 0,
|
|
ID: "call_1",
|
|
Name: "tool_a",
|
|
},
|
|
}
|
|
deltaChan <- &api.ProviderStreamDelta{
|
|
ToolCallDelta: &api.ToolCallDelta{
|
|
Index: 0,
|
|
Arguments: `{"a":1}`,
|
|
},
|
|
}
|
|
// Second tool call
|
|
deltaChan <- &api.ProviderStreamDelta{
|
|
ToolCallDelta: &api.ToolCallDelta{
|
|
Index: 1,
|
|
ID: "call_2",
|
|
Name: "tool_b",
|
|
},
|
|
}
|
|
deltaChan <- &api.ProviderStreamDelta{
|
|
ToolCallDelta: &api.ToolCallDelta{
|
|
Index: 1,
|
|
Arguments: `{"b":2}`,
|
|
},
|
|
}
|
|
deltaChan <- &api.ProviderStreamDelta{Done: true}
|
|
}()
|
|
return deltaChan, errChan
|
|
}
|
|
},
|
|
validate: func(t *testing.T, events []api.StreamEvent) {
|
|
var toolCallCount int
|
|
for _, e := range events {
|
|
if e.Type == "response.output_item.added" && e.Item != nil && e.Item.Type == "function_call" {
|
|
toolCallCount++
|
|
}
|
|
}
|
|
assert.Equal(t, 2, toolCallCount)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
registry := newMockRegistry()
|
|
provider := newMockProvider("openai")
|
|
if tt.setupMock != nil {
|
|
tt.setupMock(provider)
|
|
}
|
|
registry.addProvider("openai", provider)
|
|
registry.addModel("gpt-4", "openai")
|
|
|
|
server := New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(tt.requestBody))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
rec := newFlushableRecorder()
|
|
|
|
server.handleResponses(rec, req)
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "text/event-stream", rec.Header().Get("Content-Type"))
|
|
|
|
events, err := parseSSEEvents(rec.Body)
|
|
require.NoError(t, err)
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, events)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleResponses_Stream_Errors(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupMock func(p *mockProvider)
|
|
validate func(t *testing.T, events []api.StreamEvent)
|
|
}{
|
|
{
|
|
name: "stream error returns failed event",
|
|
setupMock: func(p *mockProvider) {
|
|
p.streamFunc = func(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) {
|
|
deltaChan := make(chan *api.ProviderStreamDelta)
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(deltaChan)
|
|
defer close(errChan)
|
|
errChan <- fmt.Errorf("stream error occurred")
|
|
}()
|
|
return deltaChan, errChan
|
|
}
|
|
},
|
|
validate: func(t *testing.T, events []api.StreamEvent) {
|
|
// Should have initial events and then failed event
|
|
var foundFailed bool
|
|
for _, e := range events {
|
|
if e.Type == "response.failed" {
|
|
foundFailed = true
|
|
require.NotNil(t, e.Response)
|
|
assert.Equal(t, "failed", e.Response.Status)
|
|
require.NotNil(t, e.Response.Error)
|
|
assert.Contains(t, e.Response.Error.Message, "stream error")
|
|
}
|
|
}
|
|
assert.True(t, foundFailed)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
registry := newMockRegistry()
|
|
provider := newMockProvider("openai")
|
|
if tt.setupMock != nil {
|
|
tt.setupMock(provider)
|
|
}
|
|
registry.addProvider("openai", provider)
|
|
registry.addModel("gpt-4", "openai")
|
|
|
|
server := New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
|
|
body := `{"model": "gpt-4", "input": "hello", "stream": true}`
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
rec := newFlushableRecorder()
|
|
|
|
server.handleResponses(rec, req)
|
|
|
|
events, err := parseSSEEvents(rec.Body)
|
|
require.NoError(t, err)
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, events)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResolveProvider(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupServer func() *GatewayServer
|
|
request api.ResponseRequest
|
|
expectError bool
|
|
errorMsg string
|
|
validate func(t *testing.T, provider any)
|
|
}{
|
|
{
|
|
name: "explicit provider selection",
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
registry.addProvider("openai", newMockProvider("openai"))
|
|
registry.addProvider("anthropic", newMockProvider("anthropic"))
|
|
return New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
},
|
|
request: api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
Provider: "anthropic",
|
|
},
|
|
validate: func(t *testing.T, provider any) {
|
|
assert.Equal(t, "anthropic", provider.(*mockProvider).Name())
|
|
},
|
|
},
|
|
{
|
|
name: "default by model name",
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
registry.addProvider("openai", newMockProvider("openai"))
|
|
registry.addModel("gpt-4", "openai")
|
|
return New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
},
|
|
request: api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
},
|
|
validate: func(t *testing.T, provider any) {
|
|
assert.Equal(t, "openai", provider.(*mockProvider).Name())
|
|
},
|
|
},
|
|
{
|
|
name: "provider not found returns error",
|
|
setupServer: func() *GatewayServer {
|
|
registry := newMockRegistry()
|
|
registry.addProvider("openai", newMockProvider("openai"))
|
|
return New(registry, newMockConversationStore(), newMockLogger().asLogger())
|
|
},
|
|
request: api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
Provider: "nonexistent",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "provider nonexistent not configured",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
server := tt.setupServer()
|
|
provider, err := server.resolveProvider(&tt.request)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
if tt.errorMsg != "" {
|
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
|
}
|
|
return
|
|
}
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, provider)
|
|
if tt.validate != nil {
|
|
tt.validate(t, provider)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateID(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
prefix string
|
|
}{
|
|
{
|
|
name: "resp_ prefix",
|
|
prefix: "resp_",
|
|
},
|
|
{
|
|
name: "msg_ prefix",
|
|
prefix: "msg_",
|
|
},
|
|
{
|
|
name: "item_ prefix",
|
|
prefix: "item_",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
id := generateID(tt.prefix)
|
|
assert.True(t, strings.HasPrefix(id, tt.prefix))
|
|
assert.Len(t, id, len(tt.prefix)+24)
|
|
|
|
// Generate another to ensure uniqueness
|
|
id2 := generateID(tt.prefix)
|
|
assert.NotEqual(t, id, id2)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBuildResponse(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
request *api.ResponseRequest
|
|
result *api.ProviderResult
|
|
provider string
|
|
id string
|
|
validate func(t *testing.T, resp *api.Response)
|
|
}{
|
|
{
|
|
name: "minimal response structure",
|
|
request: &api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
},
|
|
result: &api.ProviderResult{
|
|
Model: "gpt-4-turbo",
|
|
Text: "Hello",
|
|
},
|
|
provider: "openai",
|
|
id: "resp_123",
|
|
validate: func(t *testing.T, resp *api.Response) {
|
|
assert.Equal(t, "resp_123", resp.ID)
|
|
assert.Equal(t, "response", resp.Object)
|
|
assert.Equal(t, "completed", resp.Status)
|
|
assert.Equal(t, "gpt-4-turbo", resp.Model)
|
|
assert.Equal(t, "openai", resp.Provider)
|
|
assert.NotNil(t, resp.CompletedAt)
|
|
assert.Len(t, resp.Output, 1)
|
|
assert.Equal(t, "message", resp.Output[0].Type)
|
|
},
|
|
},
|
|
{
|
|
name: "response with tool calls",
|
|
request: &api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
},
|
|
result: &api.ProviderResult{
|
|
Model: "gpt-4",
|
|
Text: "Let me check",
|
|
ToolCalls: []api.ToolCall{
|
|
{ID: "call_1", Name: "get_weather", Arguments: `{"location":"NYC"}`},
|
|
},
|
|
},
|
|
provider: "openai",
|
|
id: "resp_456",
|
|
validate: func(t *testing.T, resp *api.Response) {
|
|
assert.Len(t, resp.Output, 2)
|
|
assert.Equal(t, "message", resp.Output[0].Type)
|
|
assert.Equal(t, "function_call", resp.Output[1].Type)
|
|
assert.Equal(t, "call_1", resp.Output[1].CallID)
|
|
assert.Equal(t, "get_weather", resp.Output[1].Name)
|
|
},
|
|
},
|
|
{
|
|
name: "parameter echoing with defaults",
|
|
request: &api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
},
|
|
result: &api.ProviderResult{
|
|
Model: "gpt-4",
|
|
Text: "response",
|
|
},
|
|
provider: "openai",
|
|
id: "resp_789",
|
|
validate: func(t *testing.T, resp *api.Response) {
|
|
assert.Equal(t, 1.0, resp.Temperature)
|
|
assert.Equal(t, 1.0, resp.TopP)
|
|
assert.Equal(t, 0.0, resp.PresencePenalty)
|
|
assert.Equal(t, 0.0, resp.FrequencyPenalty)
|
|
assert.Equal(t, 0, resp.TopLogprobs)
|
|
assert.True(t, resp.ParallelToolCalls)
|
|
assert.True(t, resp.Store)
|
|
assert.False(t, resp.Background)
|
|
assert.Equal(t, "disabled", resp.Truncation)
|
|
assert.Equal(t, "default", resp.ServiceTier)
|
|
},
|
|
},
|
|
{
|
|
name: "parameter echoing with custom values",
|
|
request: &api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
Temperature: floatPtr(0.7),
|
|
TopP: floatPtr(0.9),
|
|
PresencePenalty: floatPtr(0.5),
|
|
FrequencyPenalty: floatPtr(0.3),
|
|
TopLogprobs: intPtr(5),
|
|
ParallelToolCalls: boolPtr(false),
|
|
Store: boolPtr(false),
|
|
Background: boolPtr(true),
|
|
Truncation: stringPtr("auto"),
|
|
ServiceTier: stringPtr("premium"),
|
|
},
|
|
result: &api.ProviderResult{
|
|
Model: "gpt-4",
|
|
Text: "response",
|
|
},
|
|
provider: "openai",
|
|
id: "resp_custom",
|
|
validate: func(t *testing.T, resp *api.Response) {
|
|
assert.Equal(t, 0.7, resp.Temperature)
|
|
assert.Equal(t, 0.9, resp.TopP)
|
|
assert.Equal(t, 0.5, resp.PresencePenalty)
|
|
assert.Equal(t, 0.3, resp.FrequencyPenalty)
|
|
assert.Equal(t, 5, resp.TopLogprobs)
|
|
assert.False(t, resp.ParallelToolCalls)
|
|
assert.False(t, resp.Store)
|
|
assert.True(t, resp.Background)
|
|
assert.Equal(t, "auto", resp.Truncation)
|
|
assert.Equal(t, "premium", resp.ServiceTier)
|
|
},
|
|
},
|
|
{
|
|
name: "usage included when text present",
|
|
request: &api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
},
|
|
result: &api.ProviderResult{
|
|
Model: "gpt-4",
|
|
Text: "response",
|
|
Usage: api.Usage{
|
|
InputTokens: 10,
|
|
OutputTokens: 20,
|
|
TotalTokens: 30,
|
|
},
|
|
},
|
|
provider: "openai",
|
|
id: "resp_usage",
|
|
validate: func(t *testing.T, resp *api.Response) {
|
|
require.NotNil(t, resp.Usage)
|
|
assert.Equal(t, 10, resp.Usage.InputTokens)
|
|
assert.Equal(t, 20, resp.Usage.OutputTokens)
|
|
assert.Equal(t, 30, resp.Usage.TotalTokens)
|
|
},
|
|
},
|
|
{
|
|
name: "no usage when no text",
|
|
request: &api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
},
|
|
result: &api.ProviderResult{
|
|
Model: "gpt-4",
|
|
ToolCalls: []api.ToolCall{
|
|
{ID: "call_1", Name: "func", Arguments: "{}"},
|
|
},
|
|
},
|
|
provider: "openai",
|
|
id: "resp_no_usage",
|
|
validate: func(t *testing.T, resp *api.Response) {
|
|
assert.Nil(t, resp.Usage)
|
|
},
|
|
},
|
|
{
|
|
name: "instructions prepended",
|
|
request: &api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
Instructions: stringPtr("Be helpful"),
|
|
},
|
|
result: &api.ProviderResult{
|
|
Model: "gpt-4",
|
|
Text: "response",
|
|
},
|
|
provider: "openai",
|
|
id: "resp_instr",
|
|
validate: func(t *testing.T, resp *api.Response) {
|
|
require.NotNil(t, resp.Instructions)
|
|
assert.Equal(t, "Be helpful", *resp.Instructions)
|
|
},
|
|
},
|
|
{
|
|
name: "previous_response_id included",
|
|
request: &api.ResponseRequest{
|
|
Model: "gpt-4",
|
|
PreviousResponseID: stringPtr("prev_123"),
|
|
},
|
|
result: &api.ProviderResult{
|
|
Model: "gpt-4",
|
|
Text: "response",
|
|
},
|
|
provider: "openai",
|
|
id: "resp_prev",
|
|
validate: func(t *testing.T, resp *api.Response) {
|
|
require.NotNil(t, resp.PreviousResponseID)
|
|
assert.Equal(t, "prev_123", *resp.PreviousResponseID)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
server := New(newMockRegistry(), newMockConversationStore(), newMockLogger().asLogger())
|
|
resp := server.buildResponse(tt.request, tt.result, tt.provider, tt.id)
|
|
|
|
require.NotNil(t, resp)
|
|
if tt.validate != nil {
|
|
tt.validate(t, resp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSendSSE(t *testing.T) {
|
|
server := New(newMockRegistry(), newMockConversationStore(), newMockLogger().asLogger())
|
|
rec := newFlushableRecorder()
|
|
seq := 0
|
|
|
|
event := &api.StreamEvent{
|
|
Type: "test.event",
|
|
}
|
|
|
|
server.sendSSE(rec, rec, &seq, "test.event", event)
|
|
|
|
assert.Equal(t, 1, seq)
|
|
assert.Equal(t, 0, event.SequenceNumber)
|
|
body := rec.Body.String()
|
|
assert.Contains(t, body, "event: test.event")
|
|
assert.Contains(t, body, "data:")
|
|
assert.Contains(t, body, `"type":"test.event"`)
|
|
}
|
|
|
|
// Helper functions

// stringPtr returns a pointer to a fresh copy of s, for filling optional
// (pointer-typed) request fields in test tables.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
|
|
|
|
// intPtr returns a pointer to a fresh copy of i, for optional int fields.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
|
|
|
|
// floatPtr returns a pointer to a fresh copy of f, for optional float64
// fields such as temperature or top_p.
func floatPtr(f float64) *float64 {
	p := new(float64)
	*p = f
	return p
}
|
|
|
|
// boolPtr returns a pointer to a fresh copy of b, for optional bool fields.
func boolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}
|
|
|
|
// flushableRecorder wraps httptest.ResponseRecorder and counts how many times
// the handler under test flushed the response.
//
// httptest.ResponseRecorder already implements http.Flusher. This wrapper
// keeps that behavior: Flush still reaches the embedded recorder, which sets
// its Flushed field and commits a 200 status if no header was written yet,
// just as a real flusher would. It also records the number of calls.
type flushableRecorder struct {
	*httptest.ResponseRecorder
	flushed int // number of Flush calls observed
}

// Compile-time check that the wrapper still satisfies http.Flusher.
var _ http.Flusher = (*flushableRecorder)(nil)

// newFlushableRecorder returns a flushableRecorder backed by a fresh
// httptest.ResponseRecorder.
func newFlushableRecorder() *flushableRecorder {
	return &flushableRecorder{
		ResponseRecorder: httptest.NewRecorder(),
	}
}

// Flush counts the call and forwards it to the embedded ResponseRecorder so
// that its Flushed field reflects what the handler actually did.
func (f *flushableRecorder) Flush() {
	f.flushed++
	f.ResponseRecorder.Flush()
}
|
|
|
|
// parseSSEEvents parses Server-Sent Events from a reader
|
|
func parseSSEEvents(body io.Reader) ([]api.StreamEvent, error) {
|
|
var events []api.StreamEvent
|
|
scanner := bufio.NewScanner(body)
|
|
|
|
var currentEvent string
|
|
var currentData bytes.Buffer
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
|
|
if line == "" {
|
|
// Empty line marks end of event
|
|
if currentEvent != "" && currentData.Len() > 0 {
|
|
var event api.StreamEvent
|
|
if err := json.Unmarshal(currentData.Bytes(), &event); err != nil {
|
|
return nil, fmt.Errorf("failed to parse event data: %w", err)
|
|
}
|
|
events = append(events, event)
|
|
currentEvent = ""
|
|
currentData.Reset()
|
|
}
|
|
continue
|
|
}
|
|
|
|
if strings.HasPrefix(line, "event: ") {
|
|
currentEvent = strings.TrimPrefix(line, "event: ")
|
|
} else if strings.HasPrefix(line, "data: ") {
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
currentData.WriteString(data)
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return events, nil
|
|
}
|