diff --git a/.gitignore b/.gitignore index e27d1fa..d4129b3 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,6 @@ logs/ # Python scripts __pycache__/* + +# Node.js (compliance tests) +tests/node_modules/ diff --git a/tests/bin/compliance-test.ts b/tests/bin/compliance-test.ts new file mode 100644 index 0000000..520fde4 --- /dev/null +++ b/tests/bin/compliance-test.ts @@ -0,0 +1,270 @@ +import { + testTemplates, + runAllTests, + type TestConfig, + type TestResult, +} from "../src/compliance-tests.ts"; + +const colors = { + green: (s: string) => `\x1b[32m${s}\x1b[0m`, + red: (s: string) => `\x1b[31m${s}\x1b[0m`, + yellow: (s: string) => `\x1b[33m${s}\x1b[0m`, + gray: (s: string) => `\x1b[90m${s}\x1b[0m`, +}; + +interface CliArgs { + baseUrl?: string; + apiKey?: string; + model?: string; + authHeader?: string; + noBearer?: boolean; + noAuth?: boolean; + filter?: string[]; + verbose?: boolean; + json?: boolean; + help?: boolean; +} + +function parseArgs(argv: string[]): CliArgs { + const args: CliArgs = {}; + let i = 0; + + while (i < argv.length) { + const arg = argv[i]; + const nextArg = argv[i + 1]; + + switch (arg) { + case "--base-url": + case "-u": + args.baseUrl = nextArg; + i += 2; + break; + case "--api-key": + case "-k": + args.apiKey = nextArg; + i += 2; + break; + case "--model": + case "-m": + args.model = nextArg; + i += 2; + break; + case "--auth-header": + args.authHeader = nextArg; + i += 2; + break; + case "--no-bearer": + args.noBearer = true; + i += 1; + break; + case "--no-auth": + args.noAuth = true; + i += 1; + break; + case "--filter": + case "-f": + args.filter = nextArg.split(",").map((s) => s.trim()); + i += 2; + break; + case "--verbose": + case "-v": + args.verbose = true; + i += 1; + break; + case "--json": + args.json = true; + i += 1; + break; + case "--help": + case "-h": + args.help = true; + i += 1; + break; + default: + i += 1; + } + } + + return args; +} + +function printHelp() { + console.log(` +Usage: npm run test:compliance -- [options] + +Options: + -u, --base-url Gateway base URL (default: http://localhost:8080) + -k, --api-key API key (or set OPENRESPONSES_API_KEY env var) + --no-auth Skip authentication header entirely + -m, --model Model name (default: gpt-4o-mini) + --auth-header Auth header name (default: Authorization) + --no-bearer Disable Bearer prefix in auth header + -f, --filter Filter tests by ID (comma-separated) + -v, --verbose Verbose output with request/response details + --json Output results as JSON + -h, --help Show this help message + +Test IDs: + ${testTemplates.map((t) => t.id).join(", ")} + +Examples: + npm run test:compliance + npm run test:compliance -- --model claude-3-5-sonnet-20241022 + npm run test:compliance -- --filter basic-response,streaming-response + npm run test:compliance -- --verbose --filter basic-response + npm run test:compliance -- --json > results.json +`); +} + +function getStatusIcon(status: TestResult["status"]): string { + switch (status) { + case "passed": + return colors.green("✓"); + case "failed": + return colors.red("✗"); + case "running": + return colors.yellow("◉"); + case "pending": + return colors.gray("○"); + } +} + +function printResult(result: TestResult, verbose: boolean) { + const icon = getStatusIcon(result.status); + const duration = result.duration ? ` (${result.duration}ms)` : ""; + const events = + result.streamEvents !== undefined ? ` [${result.streamEvents} events]` : ""; + const name = + result.status === "failed" ? colors.red(result.name) : result.name; + + console.log(`${icon} ${name}${duration}${events}`); + + if (result.status === "failed" && result.errors?.length) { + for (const error of result.errors) { + console.log(` ${colors.red("✗")} ${error}`); + } + + if (verbose) { + if (result.request) { + console.log(`\n Request:`); + console.log( + ` ${JSON.stringify(result.request, null, 2).split("\n").join("\n ")}`, + ); + } + if (result.response) { + console.log(`\n Response:`); + const responseStr = + typeof result.response === "string" + ? result.response + : JSON.stringify(result.response, null, 2); + console.log(` ${responseStr.split("\n").join("\n ")}`); + } + } + } +} + +async function main() { + const args = parseArgs(process.argv.slice(2)); + + if (args.help) { + printHelp(); + process.exit(0); + } + + const baseUrl = args.baseUrl || "http://localhost:8080"; + const apiKey = args.apiKey || process.env.OPENRESPONSES_API_KEY || ""; + + if (!apiKey && !args.noAuth) { + // No auth is fine for local gateway without auth enabled + } + + const config: TestConfig = { + baseUrl, + apiKey, + model: args.model || "gpt-4o-mini", + authHeaderName: args.authHeader || "Authorization", + useBearerPrefix: !args.noBearer, + }; + + if (args.filter?.length) { + const availableIds = testTemplates.map((t) => t.id); + const invalidFilters = args.filter.filter( + (id) => !availableIds.includes(id), + ); + if (invalidFilters.length) { + console.error( + `${colors.red("Error:")} Invalid test IDs: ${invalidFilters.join(", ")}`, + ); + console.error(`Available test IDs: ${availableIds.join(", ")}`); + process.exit(1); + } + } + + const allUpdates: TestResult[] = []; + + const onProgress = (result: TestResult) => { + if (args.filter && !args.filter.includes(result.id)) { + return; + } + allUpdates.push(result); + if (!args.json && result.status !== "running") { + printResult(result, args.verbose || false); + } + }; + + if (!args.json) { + console.log(`Running compliance tests against: ${baseUrl}`); + console.log(`Model: ${config.model}`); + if (args.filter) { + console.log(`Filter: ${args.filter.join(", ")}`); + } + console.log(); + } + + await runAllTests(config, onProgress); + + const finalResults = allUpdates.filter( + (r) => r.status === "passed" || r.status === "failed", + ); + const passed = finalResults.filter((r) => r.status === "passed").length; + const failed = finalResults.filter((r) => r.status === "failed").length; + + if (args.json) { + console.log( + JSON.stringify( + { + summary: { passed, failed, total: finalResults.length }, + results: finalResults, + }, + null, + 2, + ), + ); + } else { + console.log(`\n${"=".repeat(50)}`); + console.log( + `Results: ${colors.green(`${passed} passed`)}, ${colors.red(`${failed} failed`)}, ${finalResults.length} total`, + ); + + if (failed > 0) { + console.log(`\nFailed tests:`); + for (const r of finalResults) { + if (r.status === "failed") { + console.log(`\n${r.name}:`); + for (const e of r.errors || []) { + console.log(` - ${e}`); + } + } + } + } else { + console.log(`\n${colors.green("✓ All tests passed!")}`); + } + } + + process.exit(failed > 0 ? 1 : 0); +} + +main().catch((error) => { + console.error(colors.red("Fatal error:"), error); + process.exit(1); +}); diff --git a/tests/package-lock.json b/tests/package-lock.json new file mode 100644 index 0000000..de009be --- /dev/null +++ b/tests/package-lock.json @@ -0,0 +1,58 @@ +{ + "name": "go-llm-gateway-compliance-tests", + "version": "1.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "go-llm-gateway-compliance-tests", + "version": "1.0.0", + "devDependencies": { + "@types/node": "^22.0.0", + "typescript": "^5.7.0", + "zod": "^3.24.0" + } + }, + "node_modules/@types/node": { + "version": "22.19.13", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.13.tgz", + "integrity": "sha512-akNQMv0wW5uyRpD2v2IEyRSZiR+BeGuoB6L310EgGObO44HSMNT8z1xzio28V8qOrgYaopIDNA18YgdXd+qTiw==", + "dev": true, + "license": "MIT", + "dependencies": { + "undici-types": "~6.21.0" + } + }, + "node_modules/typescript": { + "version": "5.9.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz", + "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/undici-types": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz", + "integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/zod": { + "version": "3.25.76", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + } + } +} diff --git a/tests/package.json b/tests/package.json new file mode 100644 index 0000000..7a7670f --- /dev/null +++ b/tests/package.json @@ -0,0 +1,17 @@ +{ + "name": "go-llm-gateway-compliance-tests", + "version": "1.0.0", + "private": true, + "description": "Open Responses compliance tests for go-llm-gateway", + "type": "module", + "scripts": { + "test:compliance": "node --experimental-strip-types bin/compliance-test.ts", + "test:compliance:verbose": "node --experimental-strip-types bin/compliance-test.ts --verbose", + "test:compliance:json": "node --experimental-strip-types bin/compliance-test.ts --json" + }, + "devDependencies": { + "zod": "^3.24.0", + "typescript": "^5.7.0", + "@types/node": "^22.0.0" + } +} diff --git a/tests/src/compliance-tests.ts b/tests/src/compliance-tests.ts new file mode 100644 index 0000000..b76deeb --- /dev/null +++ b/tests/src/compliance-tests.ts @@ -0,0 +1,370 @@ +import { responseResourceSchema, type ResponseResource } from "./schemas.ts"; +import { parseSSEStream, type SSEParseResult } from "./sse-parser.ts"; + +export interface TestConfig { + baseUrl: string; + apiKey: string; + authHeaderName: string; + useBearerPrefix: boolean; + model: string; +} + +export interface TestResult { + id: string; + name: string; + description: string; + status: "pending" | "running" | "passed" | "failed"; + duration?: number; + request?: unknown; + response?: unknown; + errors?: string[]; + streamEvents?: number; +} + +interface ValidatorContext { + streaming: boolean; + sseResult?: SSEParseResult; +} + +type ResponseValidator = ( + response: ResponseResource, + context: ValidatorContext, +) => string[]; + +export interface TestTemplate { + id: string; + name: string; + description: string; + getRequest: (config: TestConfig) => Record; + streaming?: boolean; + validators: ResponseValidator[]; +} + +// ============================================================ +// Validators +// ============================================================ + +const hasOutput: ResponseValidator = (response) => { + if (!response.output || response.output.length === 0) { + return ["Response has no output items"]; + } + return []; +}; + +const hasOutputType = + (type: string): ResponseValidator => + (response) => { + const hasType = response.output?.some((item) => item.type === type); + if (!hasType) { + return [`Expected output item of type "${type}" but none found`]; + } + return []; + }; + +const completedStatus: ResponseValidator = (response) => { + if (response.status !== "completed") { + return [`Expected status "completed" but got "${response.status}"`]; + } + return []; +}; + +const streamingEvents: ResponseValidator = (_, context) => { + if (!context.streaming) return []; + if (!context.sseResult || context.sseResult.events.length === 0) { + return ["No streaming events received"]; + } + return []; +}; + +const streamingSchema: ResponseValidator = (_, context) => { + if (!context.streaming || !context.sseResult) return []; + return context.sseResult.errors; +}; + +// ============================================================ +// Test Templates +// ============================================================ + +export const testTemplates: TestTemplate[] = [ + { + id: "basic-response", + name: "Basic Text Response", + description: "Simple user message, validates ResponseResource schema", + getRequest: (config) => ({ + model: config.model, + input: [ + { + type: "message", + role: "user", + content: [{ type: "input_text", text: "Say hello in exactly 3 words." }], + }, + ], + }), + validators: [hasOutput, completedStatus], + }, + + { + id: "streaming-response", + name: "Streaming Response", + description: "Validates SSE streaming events and final response", + streaming: true, + getRequest: (config) => ({ + model: config.model, + input: [ + { + type: "message", + role: "user", + content: [{ type: "input_text", text: "Count from 1 to 5." }], + }, + ], + }), + validators: [streamingEvents, streamingSchema, completedStatus], + }, + + { + id: "system-prompt", + name: "System Prompt", + description: "Include system instructions via the instructions field", + getRequest: (config) => ({ + model: config.model, + instructions: "You are a pirate. Always respond in pirate speak.", + input: [ + { + type: "message", + role: "user", + content: [{ type: "input_text", text: "Say hello." }], + }, + ], + }), + validators: [hasOutput, completedStatus], + }, + + { + id: "tool-calling", + name: "Tool Calling", + description: "Define a function tool and verify function_call output", + getRequest: (config) => ({ + model: config.model, + input: [ + { + type: "message", + role: "user", + content: [ + { + type: "input_text", + text: "What's the weather like in San Francisco?", + }, + ], + }, + ], + tools: [ + { + type: "function", + name: "get_weather", + description: "Get the current weather for a location", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The city and state, e.g. San Francisco, CA", + }, + }, + required: ["location"], + }, + }, + ], + }), + validators: [hasOutput, hasOutputType("function_call")], + }, + + { + id: "image-input", + name: "Image Input", + description: "Send image URL in user content", + getRequest: (config) => ({ + model: config.model, + input: [ + { + type: "message", + role: "user", + content: [ + { + type: "input_text", + text: "What do you see in this image? Answer in one sentence.", + }, + { + type: "input_image", + image_url: + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAABmklEQVR42tyWAaTyUBzFew/eG4AHz+MBSAHKBiJRGFKwIgQQJKLUIioBIhCAiCAAEizAQIAECaASqFFJq84nudjnaqvuPnxzgP9xfrq5938csPn7PwHTKSoViCIEAYEAMhmoKsU2mUCWEQqB5xEMIp/HaGQG2G6RSuH9HQ7H34rFrtPbdz4jl6PbwmEsl3QA1mt4vcRKk8dz9eg6IpF7tt9fzGY0gCgafFRFo5Blc5vLhf3eCOj1yNhM5GRMVK0aATxPZoz09YXjkQDmczJgquGQAPp9WwCNBgG027YACgUC6HRsAZRKBDAY2AJoNv/ZnwzA6WScznG3p4UAymXGAEkyXrTFAh8fLAGqagQAyGaZpYsi7bHTNPz8MEj//LxuFPo+UBS8vb0KaLXubrRa7aX0RMLCykwmn0z3+XA4WACcTpCkh9MFAZpmuVXo+mO/w+/HZvNgbblcUCxaSo/Hyck80Yu6XXDcvfVZr79cvMZjuN2U9O9vKAqjZrfbIZ0mV4TUi9Xqz6jddNy//7+e3n8Fhf/Llo2kxi8AQyGRoDkmAhAAAAAASUVORK5CYII=", + }, + ], + }, + ], + }), + validators: [hasOutput, completedStatus], + }, + + { + id: "multi-turn", + name: "Multi-turn Conversation", + description: "Send assistant + user messages as conversation history", + getRequest: (config) => ({ + model: config.model, + input: [ + { + type: "message", + role: "user", + content: [{ type: "input_text", text: "My name is Alice." }], + }, + { + type: "message", + role: "assistant", + content: [ + { + type: "output_text", + text: "Hello Alice! Nice to meet you. How can I help you today?", + }, + ], + }, + { + type: "message", + role: "user", + content: [{ type: "input_text", text: "What is my name?" }], + }, + ], + }), + validators: [hasOutput, completedStatus], + }, +]; + +// ============================================================ +// Test Runner +// ============================================================ + +async function makeRequest( + config: TestConfig, + body: Record, + streaming = false, +): Promise { + const headers: Record = { + "Content-Type": "application/json", + }; + + if (config.apiKey) { + const authValue = config.useBearerPrefix + ? `Bearer ${config.apiKey}` + : config.apiKey; + headers[config.authHeaderName] = authValue; + } + + return fetch(`${config.baseUrl}/v1/responses`, { + method: "POST", + headers, + body: JSON.stringify({ ...body, stream: streaming }), + }); +} + +async function runTest( + template: TestTemplate, + config: TestConfig, +): Promise { + const startTime = Date.now(); + const requestBody = template.getRequest(config); + const streaming = template.streaming ?? false; + + try { + const response = await makeRequest(config, requestBody, streaming); + const duration = Date.now() - startTime; + + if (!response.ok) { + const errorText = await response.text(); + return { + id: template.id, + name: template.name, + description: template.description, + status: "failed", + duration, + request: requestBody, + response: errorText, + errors: [`HTTP ${response.status}: ${errorText}`], + }; + } + + let rawData: unknown; + let sseResult: SSEParseResult | undefined; + + if (streaming) { + sseResult = await parseSSEStream(response); + rawData = sseResult.finalResponse; + } else { + rawData = await response.json(); + } + + // Schema validation with Zod + const parseResult = responseResourceSchema.safeParse(rawData); + if (!parseResult.success) { + return { + id: template.id, + name: template.name, + description: template.description, + status: "failed", + duration, + request: streaming ? { ...requestBody, stream: true } : requestBody, + response: rawData, + errors: parseResult.error.issues.map( + (issue) => `${issue.path.join(".")}: ${issue.message}`, + ), + streamEvents: sseResult?.events.length, + }; + } + + // Semantic validators + const context: ValidatorContext = { streaming, sseResult }; + const errors = template.validators.flatMap((v) => + v(parseResult.data, context), + ); + + return { + id: template.id, + name: template.name, + description: template.description, + status: errors.length === 0 ? "passed" : "failed", + duration, + request: streaming ? { ...requestBody, stream: true } : requestBody, + response: parseResult.data, + errors, + streamEvents: sseResult?.events.length, + }; + } catch (error) { + return { + id: template.id, + name: template.name, + description: template.description, + status: "failed", + duration: Date.now() - startTime, + request: requestBody, + errors: [error instanceof Error ? error.message : String(error)], + }; + } +} + +export async function runAllTests( + config: TestConfig, + onProgress: (result: TestResult) => void, +): Promise { + const promises = testTemplates.map(async (template) => { + onProgress({ + id: template.id, + name: template.name, + description: template.description, + status: "running", + }); + + const result = await runTest(template, config); + onProgress(result); + return result; + }); + + return Promise.all(promises); +} diff --git a/tests/src/schemas.ts b/tests/src/schemas.ts new file mode 100644 index 0000000..6f9ae5a --- /dev/null +++ b/tests/src/schemas.ts @@ -0,0 +1,253 @@ +import { z } from "zod"; + +// ============================================================ +// Content Parts +// ============================================================ + +const outputTextContentSchema = z.object({ + type: z.literal("output_text"), + text: z.string(), + annotations: z.array(z.object({ + type: z.string(), + })), +}); + +const inputTextContentSchema = z.object({ + type: z.literal("input_text"), + text: z.string(), +}); + +const refusalContentSchema = z.object({ + type: z.literal("refusal"), + refusal: z.string(), +}); + +const contentPartSchema = z.discriminatedUnion("type", [ + outputTextContentSchema, + inputTextContentSchema, + refusalContentSchema, +]); + +// ============================================================ +// Output Items +// ============================================================ + +const messageOutputItemSchema = z.object({ + type: z.literal("message"), + id: z.string(), + status: z.enum(["in_progress", "completed", "incomplete"]), + role: z.enum(["user", "assistant", "system", "developer"]), + content: z.array(contentPartSchema), +}); + +const functionCallOutputItemSchema = z.object({ + type: z.literal("function_call"), + id: z.string(), + call_id: z.string(), + name: z.string(), + arguments: z.string(), + status: z.enum(["in_progress", "completed", "incomplete"]), +}); + +const outputItemSchema = z.discriminatedUnion("type", [ + messageOutputItemSchema, + functionCallOutputItemSchema, +]); + +// ============================================================ +// Usage +// ============================================================ + +const usageSchema = z.object({ + input_tokens: z.number().int(), + output_tokens: z.number().int(), + total_tokens: z.number().int(), + input_tokens_details: z.object({ + cached_tokens: z.number().int(), + }), + output_tokens_details: z.object({ + reasoning_tokens: z.number().int(), + }), +}); + +// ============================================================ +// ResponseResource +// ============================================================ + +export const responseResourceSchema = z.object({ + id: z.string(), + object: z.literal("response"), + created_at: z.number().int(), + completed_at: z.number().int().nullable(), + status: z.string(), + incomplete_details: z.object({ reason: z.string() }).nullable(), + model: z.string(), + previous_response_id: z.string().nullable(), + instructions: z.string().nullable(), + output: z.array(outputItemSchema), + error: z.object({ type: z.string(), message: z.string() }).nullable(), + tools: z.any(), + tool_choice: z.any(), + truncation: z.string(), + parallel_tool_calls: z.boolean(), + text: z.any(), + top_p: z.number(), + presence_penalty: z.number(), + frequency_penalty: z.number(), + top_logprobs: z.number().int(), + temperature: z.number(), + reasoning: z.any().nullable(), + usage: usageSchema.nullable(), + max_output_tokens: z.number().int().nullable(), + max_tool_calls: z.number().int().nullable(), + store: z.boolean(), + background: z.boolean(), + service_tier: z.string(), + metadata: z.any(), + safety_identifier: z.string().nullable(), + prompt_cache_key: z.string().nullable(), +}); + +export type ResponseResource = z.infer; + +// ============================================================ +// Streaming Event Schemas +// ============================================================ + +const responseCreatedEventSchema = z.object({ + type: z.literal("response.created"), + sequence_number: z.number().int(), + response: responseResourceSchema, +}); + +const responseInProgressEventSchema = z.object({ + type: z.literal("response.in_progress"), + sequence_number: z.number().int(), + response: responseResourceSchema, +}); + +const responseCompletedEventSchema = z.object({ + type: z.literal("response.completed"), + sequence_number: z.number().int(), + response: responseResourceSchema, +}); + +const responseFailedEventSchema = z.object({ + type: z.literal("response.failed"), + sequence_number: z.number().int(), + response: responseResourceSchema, +}); + +const outputItemAddedEventSchema = z.object({ + type: z.literal("response.output_item.added"), + sequence_number: z.number().int(), + output_index: z.number().int(), + item: z.object({ + id: z.string(), + type: z.string(), + status: z.string(), + role: z.string().optional(), + content: z.array(z.any()).optional(), + }), +}); + +const outputItemDoneEventSchema = z.object({ + type: z.literal("response.output_item.done"), + sequence_number: z.number().int(), + output_index: z.number().int(), + item: z.object({ + id: z.string(), + type: z.string(), + status: z.string(), + role: z.string().optional(), + content: z.array(z.any()).optional(), + }), +}); + +const contentPartAddedEventSchema = z.object({ + type: z.literal("response.content_part.added"), + sequence_number: z.number().int(), + item_id: z.string(), + output_index: z.number().int(), + content_index: z.number().int(), + part: z.object({ + type: z.string(), + text: z.string().optional(), + annotations: z.array(z.any()).optional(), + }), +}); + +const contentPartDoneEventSchema = z.object({ + type: z.literal("response.content_part.done"), + sequence_number: z.number().int(), + item_id: z.string(), + output_index: z.number().int(), + content_index: z.number().int(), + part: z.object({ + type: z.string(), + text: z.string().optional(), + annotations: z.array(z.any()).optional(), + }), +}); + +const outputTextDeltaEventSchema = z.object({ + type: z.literal("response.output_text.delta"), + sequence_number: z.number().int(), + item_id: z.string(), + output_index: z.number().int(), + content_index: z.number().int(), + delta: z.string(), +}); + +const outputTextDoneEventSchema = z.object({ + type: z.literal("response.output_text.done"), + sequence_number: z.number().int(), + item_id: z.string(), + output_index: z.number().int(), + content_index: z.number().int(), + text: z.string(), +}); + +const functionCallArgsDeltaEventSchema = z.object({ + type: z.literal("response.function_call_arguments.delta"), + sequence_number: z.number().int(), + item_id: z.string(), + output_index: z.number().int(), + delta: z.string(), +}); + +const functionCallArgsDoneEventSchema = z.object({ + type: z.literal("response.function_call_arguments.done"), + sequence_number: z.number().int(), + item_id: z.string(), + output_index: z.number().int(), + arguments: z.string(), +}); + +const errorEventSchema = z.object({ + type: z.literal("error"), + sequence_number: z.number().int(), + error: z.object({ + type: z.string(), + message: z.string(), + code: z.string().nullable().optional(), + }), +}); + +export const streamingEventSchema = z.discriminatedUnion("type", [ + responseCreatedEventSchema, + responseInProgressEventSchema, + responseCompletedEventSchema, + responseFailedEventSchema, + outputItemAddedEventSchema, + outputItemDoneEventSchema, + contentPartAddedEventSchema, + contentPartDoneEventSchema, + outputTextDeltaEventSchema, + outputTextDoneEventSchema, + functionCallArgsDeltaEventSchema, + functionCallArgsDoneEventSchema, + errorEventSchema, +]); + +export type StreamingEvent = z.infer; diff --git a/tests/src/sse-parser.ts b/tests/src/sse-parser.ts new file mode 100644 index 0000000..222b542 --- /dev/null +++ b/tests/src/sse-parser.ts @@ -0,0 +1,92 @@ +import type { z } from "zod"; +import { + streamingEventSchema, + type StreamingEvent, + type ResponseResource, +} from "./schemas.ts"; + +export interface ParsedEvent { + event: string; + data: unknown; + validationResult: z.SafeParseReturnType; +} + +export interface SSEParseResult { + events: ParsedEvent[]; + errors: string[]; + finalResponse: ResponseResource | null; +} + +export async function parseSSEStream( + response: Response, +): Promise { + const events: ParsedEvent[] = []; + const errors: string[] = []; + let finalResponse: ResponseResource | null = null; + + const reader = response.body?.getReader(); + if (!reader) { + return { events, errors: ["No response body"], finalResponse }; + } + + const decoder = new TextDecoder(); + let buffer = ""; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() || ""; + + let currentEvent = ""; + let currentData = ""; + + for (const line of lines) { + if (line.startsWith("event:")) { + currentEvent = line.slice(6).trim(); + } else if (line.startsWith("data:")) { + currentData = line.slice(5).trim(); + } else if (line === "" && currentData) { + if (currentData === "[DONE]") { + // Skip sentinel + } else { + try { + const parsed = JSON.parse(currentData); + const validationResult = streamingEventSchema.safeParse(parsed); + + events.push({ + event: currentEvent || parsed.type || "unknown", + data: parsed, + validationResult, + }); + + if (!validationResult.success) { + errors.push( + `Event validation failed for ${parsed.type || "unknown"}: ${JSON.stringify(validationResult.error.issues)}`, + ); + } + + if ( + parsed.type === "response.completed" || + parsed.type === "response.failed" + ) { + finalResponse = parsed.response; + } + } catch { + errors.push(`Failed to parse event data: ${currentData}`); + } + } + currentEvent = ""; + currentData = ""; + } + } + } + } finally { + reader.releaseLock(); + } + + return { events, errors, finalResponse }; +} diff --git a/tests/tsconfig.json b/tests/tsconfig.json new file mode 100644 index 0000000..d055c5c --- /dev/null +++ b/tests/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "ES2022", + "module": "NodeNext", + "moduleResolution": "NodeNext", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "outDir": "dist", + "rootDir": ".", + "declaration": true + }, + "include": ["src/**/*.ts", "bin/**/*.ts"] +}