Compare commits
10 Commits
3e645a3525
...
6adf7eae54
| Author | SHA1 | Date | |
|---|---|---|---|
| 6adf7eae54 | |||
| 38d44f104a | |||
| 2188e3cba8 | |||
| 830a87afa1 | |||
| 259d02d140 | |||
| 09d687b45b | |||
| 157680bb13 | |||
| 8ceb831e84 | |||
| f79af84afb | |||
| cf47ad444a |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -53,3 +53,6 @@ logs/
|
||||
|
||||
# Python scripts
|
||||
__pycache__/*
|
||||
|
||||
# Node.js (compliance tests)
|
||||
tests/node_modules/
|
||||
|
||||
17
README.md
17
README.md
@@ -1,4 +1,4 @@
|
||||
# Go LLM Gateway
|
||||
# latticelm
|
||||
|
||||
## Overview
|
||||
|
||||
@@ -11,6 +11,7 @@ Simplify LLM integration by exposing a single, consistent API that routes reques
|
||||
- **Azure OpenAI** (Azure-deployed models)
|
||||
- **Anthropic** (Claude)
|
||||
- **Google Generative AI** (Gemini)
|
||||
- **Vertex AI** (Google Cloud-hosted Gemini models)
|
||||
|
||||
Instead of managing multiple SDK integrations in your application, call one endpoint and let the gateway handle provider-specific implementations.
|
||||
|
||||
@@ -19,12 +20,13 @@ Instead of managing multiple SDK integrations in your application, call one endp
|
||||
```
|
||||
Client Request
|
||||
↓
|
||||
Go LLM Gateway (unified API)
|
||||
latticelm (unified API)
|
||||
↓
|
||||
├─→ OpenAI SDK
|
||||
├─→ Azure OpenAI (OpenAI SDK + Azure auth)
|
||||
├─→ Anthropic SDK
|
||||
└─→ Google Gen AI SDK
|
||||
├─→ Google Gen AI SDK
|
||||
└─→ Vertex AI (Google Gen AI SDK + GCP auth)
|
||||
```
|
||||
|
||||
## Key Features
|
||||
@@ -45,11 +47,12 @@ Go LLM Gateway (unified API)
|
||||
|
||||
## 🎉 Status: **WORKING!**
|
||||
|
||||
✅ **All four providers integrated with official Go SDKs:**
|
||||
- OpenAI → `github.com/openai/openai-go`
|
||||
- Azure OpenAI → `github.com/openai/openai-go` (with Azure auth)
|
||||
✅ **All providers integrated with official Go SDKs:**
|
||||
- OpenAI → `github.com/openai/openai-go/v3`
|
||||
- Azure OpenAI → `github.com/openai/openai-go/v3` (with Azure auth)
|
||||
- Anthropic → `github.com/anthropics/anthropic-sdk-go`
|
||||
- Google → `google.golang.org/genai`
|
||||
- Vertex AI → `google.golang.org/genai` (with GCP auth)
|
||||
|
||||
✅ **Compiles successfully** (36MB binary)
|
||||
✅ **Provider auto-selection** (gpt→Azure/OpenAI, claude→Anthropic, gemini→Google)
|
||||
@@ -68,7 +71,7 @@ export ANTHROPIC_API_KEY="your-key"
|
||||
export GOOGLE_API_KEY="your-key"
|
||||
|
||||
# 2. Build
|
||||
cd go-llm-gateway
|
||||
cd latticelm
|
||||
go build -o gateway ./cmd/gateway
|
||||
|
||||
# 3. Run
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
@@ -12,12 +13,13 @@ import (
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/yourusername/go-llm-gateway/internal/auth"
|
||||
"github.com/yourusername/go-llm-gateway/internal/config"
|
||||
"github.com/yourusername/go-llm-gateway/internal/conversation"
|
||||
"github.com/yourusername/go-llm-gateway/internal/providers"
|
||||
"github.com/yourusername/go-llm-gateway/internal/server"
|
||||
"github.com/ajac-zero/latticelm/internal/auth"
|
||||
"github.com/ajac-zero/latticelm/internal/config"
|
||||
"github.com/ajac-zero/latticelm/internal/conversation"
|
||||
"github.com/ajac-zero/latticelm/internal/providers"
|
||||
"github.com/ajac-zero/latticelm/internal/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -112,6 +114,22 @@ func initConversationStore(cfg config.ConversationConfig, logger *log.Logger) (c
|
||||
}
|
||||
logger.Printf("Conversation store initialized (sql/%s, TTL: %s)", driver, ttl)
|
||||
return store, nil
|
||||
case "redis":
|
||||
opts, err := redis.ParseURL(cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse redis dsn: %w", err)
|
||||
}
|
||||
client := redis.NewClient(opts)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("connect to redis: %w", err)
|
||||
}
|
||||
|
||||
logger.Printf("Conversation store initialized (redis, TTL: %s)", ttl)
|
||||
return conversation.NewRedisStore(client, ttl), nil
|
||||
default:
|
||||
logger.Printf("Conversation store initialized (memory, TTL: %s)", ttl)
|
||||
return conversation.NewMemoryStore(ttl), nil
|
||||
|
||||
@@ -14,6 +14,12 @@ providers:
|
||||
type: "openai"
|
||||
api_key: "YOUR_OPENAI_API_KEY"
|
||||
endpoint: "https://api.openai.com"
|
||||
# Vertex AI (Google Cloud) - optional
|
||||
# Uses Application Default Credentials (ADC) or service account
|
||||
# vertexai:
|
||||
# type: "vertexai"
|
||||
# project: "your-gcp-project-id"
|
||||
# location: "us-central1" # or other GCP region
|
||||
# Azure OpenAI - optional
|
||||
# azureopenai:
|
||||
# type: "azureopenai"
|
||||
@@ -27,16 +33,19 @@ providers:
|
||||
# endpoint: "https://your-resource.services.ai.azure.com/anthropic"
|
||||
|
||||
# conversations:
|
||||
# store: "sql" # "memory" (default) or "sql"
|
||||
# store: "sql" # "memory" (default), "sql", or "redis"
|
||||
# ttl: "1h" # conversation expiration (default: 1h)
|
||||
# driver: "sqlite3" # SQL driver: "sqlite3", "mysql", "pgx"
|
||||
# dsn: "conversations.db" # connection string
|
||||
# driver: "sqlite3" # SQL driver: "sqlite3", "mysql", "pgx" (required for sql store)
|
||||
# dsn: "conversations.db" # connection string (required for sql/redis store)
|
||||
# # MySQL example:
|
||||
# # driver: "mysql"
|
||||
# # dsn: "user:password@tcp(localhost:3306)/dbname?parseTime=true"
|
||||
# # PostgreSQL example:
|
||||
# # driver: "pgx"
|
||||
# # dsn: "postgres://user:password@localhost:5432/dbname?sslmode=disable"
|
||||
# # Redis example:
|
||||
# # store: "redis"
|
||||
# # dsn: "redis://:password@localhost:6379/0"
|
||||
|
||||
models:
|
||||
- name: "gemini-1.5-flash"
|
||||
@@ -45,6 +54,8 @@ models:
|
||||
provider: "anthropic"
|
||||
- name: "gpt-4o-mini"
|
||||
provider: "openai"
|
||||
# - name: "gemini-2.0-flash-exp"
|
||||
# provider: "vertexai" # Use Vertex AI instead of Google AI API
|
||||
# - name: "gpt-4o"
|
||||
# provider: "azureopenai"
|
||||
# provider_model_id: "my-gpt4o-deployment" # optional: defaults to name
|
||||
|
||||
17
go.mod
17
go.mod
@@ -1,11 +1,17 @@
|
||||
module github.com/yourusername/go-llm-gateway
|
||||
module github.com/ajac-zero/latticelm
|
||||
|
||||
go 1.25.7
|
||||
|
||||
require (
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/mattn/go-sqlite3 v1.14.34
|
||||
github.com/openai/openai-go v1.12.0
|
||||
github.com/openai/openai-go/v3 v3.2.0
|
||||
github.com/redis/go-redis/v9 v9.18.0
|
||||
google.golang.org/genai v1.48.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -15,11 +21,10 @@ require (
|
||||
cloud.google.com/go/auth v0.9.3 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.5.0 // indirect
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/s2a-go v0.1.8 // indirect
|
||||
@@ -27,15 +32,13 @@ require (
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.8.0 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.34 // indirect
|
||||
github.com/openai/openai-go/v3 v3.2.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
golang.org/x/crypto v0.47.0 // indirect
|
||||
golang.org/x/net v0.49.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
|
||||
48
go.sum
48
go.sum
@@ -7,21 +7,31 @@ cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJ
|
||||
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.9.0 h1:t/DLMixbb8ygU11RAHJ8quXwJD7FwlC7+u6XodmSi1w=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.9.0/go.mod h1:Bb4vy1c7tXIqFrypNxCO7I5xlDSbpQiOWu/XvF5htP8=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
|
||||
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
@@ -71,15 +81,29 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
|
||||
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
|
||||
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||
github.com/openai/openai-go/v3 v3.2.0 h1:2AbqFUCsoW2pm/2pUtPRuwK89dnoGHaQokzWsfoQO/U=
|
||||
github.com/openai/openai-go/v3 v3.2.0/go.mod h1:UOpNxkqC9OdNXNUfpNByKOtB4jAL0EssQXq5p8gO0Xs=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
|
||||
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -88,9 +112,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
@@ -101,12 +124,14 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
|
||||
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
|
||||
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@@ -119,30 +144,22 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -178,8 +195,9 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD
|
||||
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -96,6 +96,7 @@ type InputItem struct {
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
CallID string `json:"call_id,omitempty"` // for tool messages
|
||||
}
|
||||
|
||||
// ContentBlock is a typed content element.
|
||||
@@ -138,6 +139,7 @@ func (r *ResponseRequest) NormalizeInput() []Message {
|
||||
msgs = append(msgs, Message{
|
||||
Role: "tool",
|
||||
Content: []ContentBlock{{Type: "input_text", Text: item.Output}},
|
||||
CallID: item.CallID,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -188,11 +190,14 @@ type Response struct {
|
||||
|
||||
// OutputItem represents a typed item in the response output.
|
||||
type OutputItem struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ContentPart `json:"content,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ContentPart `json:"content,omitempty"`
|
||||
CallID string `json:"call_id,omitempty"` // for function_call
|
||||
Name string `json:"name,omitempty"` // for function_call
|
||||
Arguments string `json:"arguments,omitempty"` // for function_call
|
||||
}
|
||||
|
||||
// ContentPart is a content block within an output item.
|
||||
@@ -259,6 +264,7 @@ type StreamEvent struct {
|
||||
Part *ContentPart `json:"part,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"` // for function_call_arguments.done
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
@@ -267,19 +273,36 @@ type StreamEvent struct {
|
||||
|
||||
// ProviderResult is returned by Provider.Generate.
|
||||
type ProviderResult struct {
|
||||
ID string
|
||||
Model string
|
||||
Text string
|
||||
Usage Usage
|
||||
ID string
|
||||
Model string
|
||||
Text string
|
||||
Usage Usage
|
||||
ToolCalls []ToolCall
|
||||
}
|
||||
|
||||
// ProviderStreamDelta is sent through the stream channel.
|
||||
type ProviderStreamDelta struct {
|
||||
ID string
|
||||
Model string
|
||||
Text string
|
||||
Done bool
|
||||
Usage *Usage
|
||||
ID string
|
||||
Model string
|
||||
Text string
|
||||
Done bool
|
||||
Usage *Usage
|
||||
ToolCallDelta *ToolCallDelta
|
||||
}
|
||||
|
||||
// ToolCall represents a function call from the model.
|
||||
type ToolCall struct {
|
||||
ID string
|
||||
Name string
|
||||
Arguments string // JSON string
|
||||
}
|
||||
|
||||
// ToolCallDelta represents a streaming chunk of a tool call.
|
||||
type ToolCallDelta struct {
|
||||
Index int
|
||||
ID string
|
||||
Name string
|
||||
Arguments string
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
|
||||
@@ -18,12 +18,12 @@ type Config struct {
|
||||
|
||||
// ConversationConfig controls conversation storage.
|
||||
type ConversationConfig struct {
|
||||
// Store is the storage backend: "memory" (default) or "sql".
|
||||
// Store is the storage backend: "memory" (default), "sql", or "redis".
|
||||
Store string `yaml:"store"`
|
||||
// TTL is the conversation expiration duration (e.g. "1h", "30m"). Defaults to "1h".
|
||||
TTL string `yaml:"ttl"`
|
||||
// DSN is the database connection string, required when store is "sql".
|
||||
// Examples: "conversations.db" (SQLite), "postgres://user:pass@host/db".
|
||||
// DSN is the database/Redis connection string, required when store is "sql" or "redis".
|
||||
// Examples: "conversations.db" (SQLite), "postgres://user:pass@host/db", "redis://:password@localhost:6379/0".
|
||||
DSN string `yaml:"dsn"`
|
||||
// Driver is the SQL driver name, required when store is "sql".
|
||||
// Examples: "sqlite3", "postgres", "mysql".
|
||||
@@ -48,6 +48,8 @@ type ProviderEntry struct {
|
||||
APIKey string `yaml:"api_key"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
APIVersion string `yaml:"api_version"`
|
||||
Project string `yaml:"project"` // For Vertex AI
|
||||
Location string `yaml:"location"` // For Vertex AI
|
||||
}
|
||||
|
||||
// ModelEntry maps a model name to a provider entry.
|
||||
@@ -78,6 +80,12 @@ type AzureAnthropicConfig struct {
|
||||
Model string `yaml:"model"`
|
||||
}
|
||||
|
||||
// VertexAIConfig contains Vertex AI-specific settings used internally by the Google provider.
|
||||
type VertexAIConfig struct {
|
||||
Project string `yaml:"project"`
|
||||
Location string `yaml:"location"`
|
||||
}
|
||||
|
||||
// Load reads and parses a YAML configuration file, expanding ${VAR} env references.
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
|
||||
@@ -4,15 +4,15 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
)
|
||||
|
||||
// Store defines the interface for conversation storage backends.
|
||||
type Store interface {
|
||||
Get(id string) (*Conversation, bool)
|
||||
Create(id string, model string, messages []api.Message) *Conversation
|
||||
Append(id string, messages ...api.Message) (*Conversation, bool)
|
||||
Delete(id string)
|
||||
Get(id string) (*Conversation, error)
|
||||
Create(id string, model string, messages []api.Message) (*Conversation, error)
|
||||
Append(id string, messages ...api.Message) (*Conversation, error)
|
||||
Delete(id string) error
|
||||
Size() int
|
||||
}
|
||||
|
||||
@@ -47,55 +47,93 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore {
|
||||
return s
|
||||
}
|
||||
|
||||
// Get retrieves a conversation by ID.
|
||||
func (s *MemoryStore) Get(id string) (*Conversation, bool) {
|
||||
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
|
||||
func (s *MemoryStore) Get(id string) (*Conversation, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
conv, ok := s.conversations[id]
|
||||
return conv, ok
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Return a deep copy to prevent data races
|
||||
msgsCopy := make([]api.Message, len(conv.Messages))
|
||||
copy(msgsCopy, conv.Messages)
|
||||
|
||||
return &Conversation{
|
||||
ID: conv.ID,
|
||||
Messages: msgsCopy,
|
||||
Model: conv.Model,
|
||||
CreatedAt: conv.CreatedAt,
|
||||
UpdatedAt: conv.UpdatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create creates a new conversation with the given messages.
|
||||
func (s *MemoryStore) Create(id string, model string, messages []api.Message) *Conversation {
|
||||
func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Store a copy to prevent external modifications
|
||||
msgsCopy := make([]api.Message, len(messages))
|
||||
copy(msgsCopy, messages)
|
||||
|
||||
conv := &Conversation{
|
||||
ID: id,
|
||||
Messages: messages,
|
||||
Messages: msgsCopy,
|
||||
Model: model,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
s.conversations[id] = conv
|
||||
return conv
|
||||
|
||||
// Return a copy
|
||||
return &Conversation{
|
||||
ID: id,
|
||||
Messages: messages,
|
||||
Model: model,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Append adds new messages to an existing conversation.
|
||||
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
|
||||
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
conv, ok := s.conversations[id]
|
||||
if !ok {
|
||||
return nil, false
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conv.Messages = append(conv.Messages, messages...)
|
||||
conv.UpdatedAt = time.Now()
|
||||
|
||||
return conv, true
|
||||
// Return a deep copy
|
||||
msgsCopy := make([]api.Message, len(conv.Messages))
|
||||
copy(msgsCopy, conv.Messages)
|
||||
|
||||
return &Conversation{
|
||||
ID: conv.ID,
|
||||
Messages: msgsCopy,
|
||||
Model: conv.Model,
|
||||
CreatedAt: conv.CreatedAt,
|
||||
UpdatedAt: conv.UpdatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Delete removes a conversation from the store.
|
||||
func (s *MemoryStore) Delete(id string) {
|
||||
func (s *MemoryStore) Delete(id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.conversations, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanup periodically removes expired conversations.
|
||||
|
||||
124
internal/conversation/redis_store.go
Normal file
124
internal/conversation/redis_store.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisStore manages conversation history in Redis with automatic expiration.
|
||||
type RedisStore struct {
|
||||
client *redis.Client
|
||||
ttl time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewRedisStore creates a Redis-backed conversation store.
|
||||
func NewRedisStore(client *redis.Client, ttl time.Duration) *RedisStore {
|
||||
return &RedisStore{
|
||||
client: client,
|
||||
ttl: ttl,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// key returns the Redis key for a conversation ID.
|
||||
func (s *RedisStore) key(id string) string {
|
||||
return "conv:" + id
|
||||
}
|
||||
|
||||
// Get retrieves a conversation by ID from Redis.
|
||||
func (s *RedisStore) Get(id string) (*Conversation, error) {
|
||||
data, err := s.client.Get(s.ctx, s.key(id)).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var conv Conversation
|
||||
if err := json.Unmarshal(data, &conv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// Create creates a new conversation with the given messages.
|
||||
func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
now := time.Now()
|
||||
conv := &Conversation{
|
||||
ID: id,
|
||||
Messages: messages,
|
||||
Model: model,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(conv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// Append adds new messages to an existing conversation.
|
||||
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conv == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conv.Messages = append(conv.Messages, messages...)
|
||||
conv.UpdatedAt = time.Now()
|
||||
|
||||
data, err := json.Marshal(conv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// Delete removes a conversation from Redis.
|
||||
func (s *RedisStore) Delete(id string) error {
|
||||
return s.client.Del(s.ctx, s.key(id)).Err()
|
||||
}
|
||||
|
||||
// Size returns the number of active conversations in Redis.
|
||||
func (s *RedisStore) Size() int {
|
||||
var count int
|
||||
var cursor uint64
|
||||
|
||||
for {
|
||||
keys, nextCursor, err := s.client.Scan(s.ctx, cursor, "conv:*", 100).Result()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
count += len(keys)
|
||||
cursor = nextCursor
|
||||
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
)
|
||||
|
||||
// sqlDialect holds driver-specific SQL statements.
|
||||
@@ -65,28 +65,36 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Get(id string) (*Conversation, bool) {
|
||||
func (s *SQLStore) Get(id string) (*Conversation, error) {
|
||||
row := s.db.QueryRow(s.dialect.getByID, id)
|
||||
|
||||
var conv Conversation
|
||||
var msgJSON string
|
||||
err := row.Scan(&conv.ID, &conv.Model, &msgJSON, &conv.CreatedAt, &conv.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(msgJSON), &conv.Messages); err != nil {
|
||||
return nil, false
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &conv, true
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conversation {
|
||||
func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
||||
now := time.Now()
|
||||
msgJSON, _ := json.Marshal(messages)
|
||||
msgJSON, err := json.Marshal(messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _ = s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now)
|
||||
if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Conversation{
|
||||
ID: id,
|
||||
@@ -94,26 +102,36 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conv
|
||||
Model: model,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
|
||||
conv, ok := s.Get(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
||||
conv, err := s.Get(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conv == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conv.Messages = append(conv.Messages, messages...)
|
||||
conv.UpdatedAt = time.Now()
|
||||
|
||||
msgJSON, _ := json.Marshal(conv.Messages)
|
||||
_, _ = s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id)
|
||||
msgJSON, err := json.Marshal(conv.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, true
|
||||
if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) Delete(id string) {
|
||||
_, _ = s.db.Exec(s.dialect.deleteByID, id)
|
||||
func (s *SQLStore) Delete(id string) error {
|
||||
_, err := s.db.Exec(s.dialect.deleteByID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLStore) Size() int {
|
||||
|
||||
@@ -2,13 +2,14 @@ package anthropic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
|
||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||
"github.com/yourusername/go-llm-gateway/internal/config"
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/ajac-zero/latticelm/internal/config"
|
||||
)
|
||||
|
||||
const Name = "anthropic"
|
||||
@@ -85,6 +86,11 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
|
||||
case "assistant":
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
|
||||
case "tool":
|
||||
// Tool results must be in user message with tool_result blocks
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
|
||||
anthropic.NewToolResultBlock(msg.CallID, content, false),
|
||||
))
|
||||
case "system", "developer":
|
||||
system = content
|
||||
}
|
||||
@@ -116,24 +122,55 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
params.TopP = anthropic.Float(*req.TopP)
|
||||
}
|
||||
|
||||
// Add tools if present
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
tools, err := parseTools(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tools: %w", err)
|
||||
}
|
||||
params.Tools = tools
|
||||
}
|
||||
|
||||
// Add tool_choice if present
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
toolChoice, err := parseToolChoice(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tool_choice: %w", err)
|
||||
}
|
||||
params.ToolChoice = toolChoice
|
||||
}
|
||||
|
||||
// Call Anthropic API
|
||||
resp, err := p.client.Messages.New(ctx, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("anthropic api error: %w", err)
|
||||
}
|
||||
|
||||
// Extract text from response
|
||||
// Extract text and tool calls from response
|
||||
var text string
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
if block.Type == "text" {
|
||||
text += block.Text
|
||||
switch block.Type {
|
||||
case "text":
|
||||
text += block.AsText().Text
|
||||
case "tool_use":
|
||||
// Extract tool calls
|
||||
toolUse := block.AsToolUse()
|
||||
argsJSON, _ := json.Marshal(toolUse.Input)
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: toolUse.ID,
|
||||
Name: toolUse.Name,
|
||||
Arguments: string(argsJSON),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &api.ProviderResult{
|
||||
ID: resp.ID,
|
||||
Model: string(resp.Model),
|
||||
Text: text,
|
||||
ID: resp.ID,
|
||||
Model: string(resp.Model),
|
||||
Text: text,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: api.Usage{
|
||||
InputTokens: int(resp.Usage.InputTokens),
|
||||
OutputTokens: int(resp.Usage.OutputTokens),
|
||||
@@ -177,6 +214,11 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
|
||||
case "assistant":
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
|
||||
case "tool":
|
||||
// Tool results must be in user message with tool_result blocks
|
||||
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
|
||||
anthropic.NewToolResultBlock(msg.CallID, content, false),
|
||||
))
|
||||
case "system", "developer":
|
||||
system = content
|
||||
}
|
||||
@@ -208,19 +250,77 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
params.TopP = anthropic.Float(*req.TopP)
|
||||
}
|
||||
|
||||
// Add tools if present
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
tools, err := parseTools(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tools: %w", err)
|
||||
return
|
||||
}
|
||||
params.Tools = tools
|
||||
}
|
||||
|
||||
// Add tool_choice if present
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
toolChoice, err := parseToolChoice(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tool_choice: %w", err)
|
||||
return
|
||||
}
|
||||
params.ToolChoice = toolChoice
|
||||
}
|
||||
|
||||
// Create stream
|
||||
stream := p.client.Messages.NewStreaming(ctx, params)
|
||||
|
||||
// Track content block index and tool call state
|
||||
var contentBlockIndex int
|
||||
|
||||
// Process stream
|
||||
for stream.Next() {
|
||||
event := stream.Current()
|
||||
|
||||
if event.Type == "content_block_delta" && event.Delta.Type == "text_delta" {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{Text: event.Delta.Text}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
switch event.Type {
|
||||
case "content_block_start":
|
||||
// New content block (text or tool_use)
|
||||
contentBlockIndex = int(event.Index)
|
||||
if event.ContentBlock.Type == "tool_use" {
|
||||
// Send tool call delta with ID and name
|
||||
toolUse := event.ContentBlock.AsToolUse()
|
||||
delta := &api.ToolCallDelta{
|
||||
Index: contentBlockIndex,
|
||||
ID: toolUse.ID,
|
||||
Name: toolUse.Name,
|
||||
}
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
if event.Delta.Type == "text_delta" {
|
||||
// Text streaming
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{Text: event.Delta.Text}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
} else if event.Delta.Type == "input_json_delta" {
|
||||
// Tool arguments streaming
|
||||
delta := &api.ToolCallDelta{
|
||||
Index: int(event.Index),
|
||||
Arguments: event.Delta.PartialJSON,
|
||||
}
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
154
internal/providers/anthropic/convert.go
Normal file
154
internal/providers/anthropic/convert.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
)
|
||||
|
||||
// parseTools converts Open Responses tools to Anthropic format
|
||||
func parseTools(req *api.ResponseRequest) ([]anthropic.ToolUnionParam, error) {
|
||||
if req.Tools == nil || len(req.Tools) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var toolDefs []map[string]interface{}
|
||||
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tools: %w", err)
|
||||
}
|
||||
|
||||
var tools []anthropic.ToolUnionParam
|
||||
for _, td := range toolDefs {
|
||||
// Extract: name, description, parameters
|
||||
// Note: Anthropic uses "input_schema" instead of "parameters"
|
||||
name, _ := td["name"].(string)
|
||||
desc, _ := td["description"].(string)
|
||||
params, _ := td["parameters"].(map[string]interface{})
|
||||
|
||||
inputSchema := anthropic.ToolInputSchemaParam{
|
||||
Type: "object",
|
||||
Properties: params["properties"],
|
||||
}
|
||||
|
||||
// Add required fields if present
|
||||
if required, ok := params["required"].([]interface{}); ok {
|
||||
requiredStrs := make([]string, 0, len(required))
|
||||
for _, r := range required {
|
||||
if str, ok := r.(string); ok {
|
||||
requiredStrs = append(requiredStrs, str)
|
||||
}
|
||||
}
|
||||
inputSchema.Required = requiredStrs
|
||||
}
|
||||
|
||||
// Create the tool using ToolUnionParamOfTool
|
||||
tool := anthropic.ToolUnionParamOfTool(inputSchema, name)
|
||||
|
||||
if desc != "" {
|
||||
tool.OfTool.Description = anthropic.String(desc)
|
||||
}
|
||||
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
// parseToolChoice converts Open Responses tool_choice to Anthropic format
|
||||
func parseToolChoice(req *api.ResponseRequest) (anthropic.ToolChoiceUnionParam, error) {
|
||||
var result anthropic.ToolChoiceUnionParam
|
||||
|
||||
if req.ToolChoice == nil || len(req.ToolChoice) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var choice interface{}
|
||||
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
|
||||
return result, fmt.Errorf("unmarshal tool_choice: %w", err)
|
||||
}
|
||||
|
||||
// Handle string values: "auto", "any", "required"
|
||||
if str, ok := choice.(string); ok {
|
||||
switch str {
|
||||
case "auto":
|
||||
result.OfAuto = &anthropic.ToolChoiceAutoParam{
|
||||
Type: "auto",
|
||||
}
|
||||
case "any", "required":
|
||||
result.OfAny = &anthropic.ToolChoiceAnyParam{
|
||||
Type: "any",
|
||||
}
|
||||
case "none":
|
||||
result.OfNone = &anthropic.ToolChoiceNoneParam{
|
||||
Type: "none",
|
||||
}
|
||||
default:
|
||||
return result, fmt.Errorf("unknown tool_choice string: %s", str)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Handle specific tool selection: {"type": "tool", "function": {"name": "..."}}
|
||||
if obj, ok := choice.(map[string]interface{}); ok {
|
||||
// Check for OpenAI format: {"type": "function", "function": {"name": "..."}}
|
||||
if funcObj, ok := obj["function"].(map[string]interface{}); ok {
|
||||
if name, ok := funcObj["name"].(string); ok {
|
||||
result.OfTool = &anthropic.ToolChoiceToolParam{
|
||||
Type: "tool",
|
||||
Name: name,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check for direct name field
|
||||
if name, ok := obj["name"].(string); ok {
|
||||
result.OfTool = &anthropic.ToolChoiceToolParam{
|
||||
Type: "tool",
|
||||
Name: name,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
return result, fmt.Errorf("invalid tool_choice format")
|
||||
}
|
||||
|
||||
// extractToolCalls converts Anthropic content blocks to api.ToolCall
|
||||
func extractToolCalls(content []anthropic.ContentBlockUnion) []api.ToolCall {
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, block := range content {
|
||||
// Check if this is a tool_use block
|
||||
if block.Type == "tool_use" {
|
||||
// Cast to ToolUseBlock to access the fields
|
||||
toolUse := block.AsToolUse()
|
||||
|
||||
// Marshal the input to JSON string for Arguments
|
||||
argsJSON, _ := json.Marshal(toolUse.Input)
|
||||
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: toolUse.ID,
|
||||
Name: toolUse.Name,
|
||||
Arguments: string(argsJSON),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// extractToolCallDelta extracts tool call delta from streaming content block delta
|
||||
func extractToolCallDelta(delta anthropic.RawContentBlockDeltaUnion, index int) *api.ToolCallDelta {
|
||||
// Check if this is an input_json_delta (streaming tool arguments)
|
||||
if delta.Type == "input_json_delta" {
|
||||
return &api.ToolCallDelta{
|
||||
Index: index,
|
||||
Arguments: delta.PartialJSON,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
119
internal/providers/anthropic/convert_test.go
Normal file
119
internal/providers/anthropic/convert_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
)
|
||||
|
||||
func TestParseTools(t *testing.T) {
|
||||
// Create a sample tool definition
|
||||
toolsJSON := `[{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}]`
|
||||
|
||||
req := &api.ResponseRequest{
|
||||
Tools: json.RawMessage(toolsJSON),
|
||||
}
|
||||
|
||||
tools, err := parseTools(req)
|
||||
if err != nil {
|
||||
t.Fatalf("parseTools failed: %v", err)
|
||||
}
|
||||
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
|
||||
tool := tools[0]
|
||||
if tool.OfTool == nil {
|
||||
t.Fatal("expected OfTool to be set")
|
||||
}
|
||||
|
||||
if tool.OfTool.Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got '%s'", tool.OfTool.Name)
|
||||
}
|
||||
|
||||
desc := tool.GetDescription()
|
||||
if desc == nil || *desc != "Get the weather for a location" {
|
||||
t.Errorf("expected description 'Get the weather for a location', got '%v'", desc)
|
||||
}
|
||||
|
||||
if len(tool.OfTool.InputSchema.Required) != 1 || tool.OfTool.InputSchema.Required[0] != "location" {
|
||||
t.Errorf("expected required=['location'], got %v", tool.OfTool.InputSchema.Required)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolChoice(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
choiceJSON string
|
||||
expectAuto bool
|
||||
expectAny bool
|
||||
expectTool bool
|
||||
expectedName string
|
||||
}{
|
||||
{
|
||||
name: "auto",
|
||||
choiceJSON: `"auto"`,
|
||||
expectAuto: true,
|
||||
},
|
||||
{
|
||||
name: "any",
|
||||
choiceJSON: `"any"`,
|
||||
expectAny: true,
|
||||
},
|
||||
{
|
||||
name: "required",
|
||||
choiceJSON: `"required"`,
|
||||
expectAny: true,
|
||||
},
|
||||
{
|
||||
name: "specific tool",
|
||||
choiceJSON: `{"type": "function", "function": {"name": "get_weather"}}`,
|
||||
expectTool: true,
|
||||
expectedName: "get_weather",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &api.ResponseRequest{
|
||||
ToolChoice: json.RawMessage(tt.choiceJSON),
|
||||
}
|
||||
|
||||
choice, err := parseToolChoice(req)
|
||||
if err != nil {
|
||||
t.Fatalf("parseToolChoice failed: %v", err)
|
||||
}
|
||||
|
||||
if tt.expectAuto && choice.OfAuto == nil {
|
||||
t.Error("expected OfAuto to be set")
|
||||
}
|
||||
if tt.expectAny && choice.OfAny == nil {
|
||||
t.Error("expected OfAny to be set")
|
||||
}
|
||||
if tt.expectTool {
|
||||
if choice.OfTool == nil {
|
||||
t.Fatal("expected OfTool to be set")
|
||||
}
|
||||
if choice.OfTool.Name != tt.expectedName {
|
||||
t.Errorf("expected name '%s', got '%s'", tt.expectedName, choice.OfTool.Name)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
212
internal/providers/google/convert.go
Normal file
212
internal/providers/google/convert.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"google.golang.org/genai"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
)
|
||||
|
||||
// parseTools converts generic tool definitions from req.Tools (JSON) to Google's []*genai.Tool format.
|
||||
func parseTools(req *api.ResponseRequest) ([]*genai.Tool, error) {
|
||||
if req.Tools == nil || len(req.Tools) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Unmarshal to slice of tool definitions
|
||||
var toolDefs []map[string]interface{}
|
||||
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tools: %w", err)
|
||||
}
|
||||
|
||||
var functionDeclarations []*genai.FunctionDeclaration
|
||||
|
||||
for _, toolDef := range toolDefs {
|
||||
// Extract function details
|
||||
// Support both flat format (name/description/parameters at top level)
|
||||
// and nested format (under "function" key)
|
||||
var name, description string
|
||||
var parameters interface{}
|
||||
|
||||
if functionData, ok := toolDef["function"].(map[string]interface{}); ok {
|
||||
// Nested format: {"type": "function", "function": {...}}
|
||||
name, _ = functionData["name"].(string)
|
||||
description, _ = functionData["description"].(string)
|
||||
parameters = functionData["parameters"]
|
||||
} else {
|
||||
// Flat format: {"type": "function", "name": "...", ...}
|
||||
name, _ = toolDef["name"].(string)
|
||||
description, _ = toolDef["description"].(string)
|
||||
parameters = toolDef["parameters"]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create function declaration
|
||||
funcDecl := &genai.FunctionDeclaration{
|
||||
Name: name,
|
||||
Description: description,
|
||||
}
|
||||
|
||||
// Google accepts parameters as raw JSON schema
|
||||
if parameters != nil {
|
||||
funcDecl.ParametersJsonSchema = parameters
|
||||
}
|
||||
|
||||
functionDeclarations = append(functionDeclarations, funcDecl)
|
||||
}
|
||||
|
||||
// Return single Tool with all function declarations
|
||||
if len(functionDeclarations) > 0 {
|
||||
return []*genai.Tool{{FunctionDeclarations: functionDeclarations}}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// parseToolChoice converts req.ToolChoice to Google's ToolConfig with FunctionCallingConfig.
|
||||
func parseToolChoice(req *api.ResponseRequest) (*genai.ToolConfig, error) {
|
||||
if req.ToolChoice == nil || len(req.ToolChoice) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var choice interface{}
|
||||
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tool_choice: %w", err)
|
||||
}
|
||||
|
||||
config := &genai.ToolConfig{
|
||||
FunctionCallingConfig: &genai.FunctionCallingConfig{},
|
||||
}
|
||||
|
||||
// Handle string values: "auto", "none", "required"/"any"
|
||||
if str, ok := choice.(string); ok {
|
||||
switch str {
|
||||
case "auto":
|
||||
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAuto
|
||||
case "none":
|
||||
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeNone
|
||||
case "required", "any":
|
||||
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown tool_choice string: %s", str)
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Handle object format: {"type": "function", "function": {"name": "..."}}
|
||||
if obj, ok := choice.(map[string]interface{}); ok {
|
||||
if typeVal, ok := obj["type"].(string); ok && typeVal == "function" {
|
||||
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny
|
||||
if funcObj, ok := obj["function"].(map[string]interface{}); ok {
|
||||
if name, ok := funcObj["name"].(string); ok {
|
||||
config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
|
||||
}
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported tool_choice format")
|
||||
}
|
||||
|
||||
// extractToolCalls extracts tool calls from Google's response format to generic api.ToolCall slice.
|
||||
func extractToolCalls(resp *genai.GenerateContentResponse) []api.ToolCall {
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, candidate := range resp.Candidates {
|
||||
if candidate.Content == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part == nil || part.FunctionCall == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract function call details
|
||||
fc := part.FunctionCall
|
||||
|
||||
// Marshal arguments to JSON string
|
||||
var argsJSON string
|
||||
if fc.Args != nil {
|
||||
argsBytes, err := json.Marshal(fc.Args)
|
||||
if err == nil {
|
||||
argsJSON = string(argsBytes)
|
||||
} else {
|
||||
// Fallback to empty object
|
||||
argsJSON = "{}"
|
||||
}
|
||||
} else {
|
||||
argsJSON = "{}"
|
||||
}
|
||||
|
||||
// Generate ID if Google doesn't provide one
|
||||
callID := fc.ID
|
||||
if callID == "" {
|
||||
callID = fmt.Sprintf("call_%s", generateRandomID())
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: callID,
|
||||
Name: fc.Name,
|
||||
Arguments: argsJSON,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// extractToolCallDelta extracts streaming tool call information from response parts.
|
||||
func extractToolCallDelta(part *genai.Part, index int) *api.ToolCallDelta {
|
||||
if part == nil || part.FunctionCall == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fc := part.FunctionCall
|
||||
|
||||
// Marshal arguments to JSON string
|
||||
var argsJSON string
|
||||
if fc.Args != nil {
|
||||
argsBytes, err := json.Marshal(fc.Args)
|
||||
if err == nil {
|
||||
argsJSON = string(argsBytes)
|
||||
} else {
|
||||
argsJSON = "{}"
|
||||
}
|
||||
} else {
|
||||
argsJSON = "{}"
|
||||
}
|
||||
|
||||
// Generate ID if Google doesn't provide one
|
||||
callID := fc.ID
|
||||
if callID == "" {
|
||||
callID = fmt.Sprintf("call_%s", generateRandomID())
|
||||
}
|
||||
|
||||
return &api.ToolCallDelta{
|
||||
Index: index,
|
||||
ID: callID,
|
||||
Name: fc.Name,
|
||||
Arguments: argsJSON,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRandomID generates a random alphanumeric ID
|
||||
func generateRandomID() string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
const length = 24
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[rng.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
@@ -2,13 +2,14 @@ package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/genai"
|
||||
|
||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||
"github.com/yourusername/go-llm-gateway/internal/config"
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/ajac-zero/latticelm/internal/config"
|
||||
)
|
||||
|
||||
const Name = "google"
|
||||
@@ -19,7 +20,7 @@ type Provider struct {
|
||||
client *genai.Client
|
||||
}
|
||||
|
||||
// New constructs a Provider using the provided configuration.
|
||||
// New constructs a Provider using the Google AI API with API key authentication.
|
||||
func New(cfg config.ProviderConfig) *Provider {
|
||||
var client *genai.Client
|
||||
if cfg.APIKey != "" {
|
||||
@@ -38,13 +39,36 @@ func New(cfg config.ProviderConfig) *Provider {
|
||||
}
|
||||
}
|
||||
|
||||
// NewVertexAI constructs a Provider targeting Vertex AI.
|
||||
// Vertex AI uses the same genai SDK but with GCP project/location configuration
|
||||
// and Application Default Credentials (ADC) or service account authentication.
|
||||
func NewVertexAI(vertexCfg config.VertexAIConfig) *Provider {
|
||||
var client *genai.Client
|
||||
if vertexCfg.Project != "" && vertexCfg.Location != "" {
|
||||
var err error
|
||||
client, err = genai.NewClient(context.Background(), &genai.ClientConfig{
|
||||
Project: vertexCfg.Project,
|
||||
Location: vertexCfg.Location,
|
||||
Backend: genai.BackendVertexAI,
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but don't fail construction - will fail on Generate
|
||||
fmt.Printf("warning: failed to create vertex ai client: %v\n", err)
|
||||
}
|
||||
}
|
||||
return &Provider{
|
||||
cfg: config.ProviderConfig{
|
||||
// Vertex AI doesn't use API key, but set empty for consistency
|
||||
APIKey: "",
|
||||
},
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) Name() string { return Name }
|
||||
|
||||
// Generate routes the request to Gemini and returns a ProviderResult.
|
||||
func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
||||
if p.cfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("google api key missing")
|
||||
}
|
||||
if p.client == nil {
|
||||
return nil, fmt.Errorf("google client not initialized")
|
||||
}
|
||||
@@ -53,7 +77,27 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
|
||||
contents, systemText := convertMessages(messages)
|
||||
|
||||
config := buildConfig(systemText, req)
|
||||
// Parse tools if present
|
||||
var tools []*genai.Tool
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
var err error
|
||||
tools, err = parseTools(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tools: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse tool_choice if present
|
||||
var toolConfig *genai.ToolConfig
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
var err error
|
||||
toolConfig, err = parseToolChoice(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tool_choice: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
config := buildConfig(systemText, req, tools, toolConfig)
|
||||
|
||||
resp, err := p.client.Models.GenerateContent(ctx, model, contents, config)
|
||||
if err != nil {
|
||||
@@ -69,6 +113,11 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
}
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
if len(resp.Candidates) > 0 {
|
||||
toolCalls = extractToolCalls(resp)
|
||||
}
|
||||
|
||||
var inputTokens, outputTokens int
|
||||
if resp.UsageMetadata != nil {
|
||||
inputTokens = int(resp.UsageMetadata.PromptTokenCount)
|
||||
@@ -76,9 +125,10 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
}
|
||||
|
||||
return &api.ProviderResult{
|
||||
ID: uuid.NewString(),
|
||||
Model: model,
|
||||
Text: text,
|
||||
ID: uuid.NewString(),
|
||||
Model: model,
|
||||
Text: text,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: api.Usage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
@@ -96,10 +146,6 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
defer close(deltaChan)
|
||||
defer close(errChan)
|
||||
|
||||
if p.cfg.APIKey == "" {
|
||||
errChan <- fmt.Errorf("google api key missing")
|
||||
return
|
||||
}
|
||||
if p.client == nil {
|
||||
errChan <- fmt.Errorf("google client not initialized")
|
||||
return
|
||||
@@ -109,7 +155,29 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
|
||||
contents, systemText := convertMessages(messages)
|
||||
|
||||
config := buildConfig(systemText, req)
|
||||
// Parse tools if present
|
||||
var tools []*genai.Tool
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
var err error
|
||||
tools, err = parseTools(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tools: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Parse tool_choice if present
|
||||
var toolConfig *genai.ToolConfig
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
var err error
|
||||
toolConfig, err = parseToolChoice(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tool_choice: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
config := buildConfig(systemText, req, tools, toolConfig)
|
||||
|
||||
stream := p.client.Models.GenerateContentStream(ctx, model, contents, config)
|
||||
|
||||
@@ -119,21 +187,32 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
return
|
||||
}
|
||||
|
||||
var text string
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
for partIndex, part := range resp.Candidates[0].Content.Parts {
|
||||
if part != nil {
|
||||
text += part.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle text content
|
||||
if part.Text != "" {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{Text: part.Text}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if text != "" {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{Text: text}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
// Handle tool call content
|
||||
if part.FunctionCall != nil {
|
||||
delta := extractToolCallDelta(part, partIndex)
|
||||
if delta != nil {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -163,6 +242,39 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.Role == "tool" {
|
||||
// Tool results are sent as FunctionResponse in user role message
|
||||
var output string
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "input_text" || block.Type == "output_text" {
|
||||
output += block.Text
|
||||
}
|
||||
}
|
||||
|
||||
// Parse output as JSON map, or wrap in {"output": "..."} if not JSON
|
||||
var responseMap map[string]any
|
||||
if err := json.Unmarshal([]byte(output), &responseMap); err != nil {
|
||||
// Not JSON, wrap it
|
||||
responseMap = map[string]any{"output": output}
|
||||
}
|
||||
|
||||
// Create FunctionResponse part with CallID from message
|
||||
part := &genai.Part{
|
||||
FunctionResponse: &genai.FunctionResponse{
|
||||
ID: msg.CallID,
|
||||
Name: "", // Name is optional for responses
|
||||
Response: responseMap,
|
||||
},
|
||||
}
|
||||
|
||||
// Add to user role message
|
||||
contents = append(contents, &genai.Content{
|
||||
Role: "user",
|
||||
Parts: []*genai.Part{part},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []*genai.Part
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "input_text" || block.Type == "output_text" {
|
||||
@@ -185,10 +297,10 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
|
||||
}
|
||||
|
||||
// buildConfig constructs a GenerateContentConfig from system text and request params.
|
||||
func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateContentConfig {
|
||||
func buildConfig(systemText string, req *api.ResponseRequest, tools []*genai.Tool, toolConfig *genai.ToolConfig) *genai.GenerateContentConfig {
|
||||
var cfg *genai.GenerateContentConfig
|
||||
|
||||
needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil
|
||||
needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil || tools != nil || toolConfig != nil
|
||||
if !needsCfg {
|
||||
return nil
|
||||
}
|
||||
@@ -215,6 +327,14 @@ func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateCon
|
||||
cfg.TopP = &tp
|
||||
}
|
||||
|
||||
if tools != nil {
|
||||
cfg.Tools = tools
|
||||
}
|
||||
|
||||
if toolConfig != nil {
|
||||
cfg.ToolConfig = toolConfig
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
|
||||
117
internal/providers/openai/convert.go
Normal file
117
internal/providers/openai/convert.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/shared"
|
||||
)
|
||||
|
||||
// parseTools converts Open Responses tools to OpenAI format
|
||||
func parseTools(req *api.ResponseRequest) ([]openai.ChatCompletionToolUnionParam, error) {
|
||||
if req.Tools == nil || len(req.Tools) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var toolDefs []map[string]interface{}
|
||||
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tools: %w", err)
|
||||
}
|
||||
|
||||
var tools []openai.ChatCompletionToolUnionParam
|
||||
for _, td := range toolDefs {
|
||||
// Convert Open Responses tool to OpenAI ChatCompletionFunctionToolParam
|
||||
// Extract: name, description, parameters
|
||||
name, _ := td["name"].(string)
|
||||
desc, _ := td["description"].(string)
|
||||
params, _ := td["parameters"].(map[string]interface{})
|
||||
|
||||
funcDef := shared.FunctionDefinitionParam{
|
||||
Name: name,
|
||||
}
|
||||
|
||||
if desc != "" {
|
||||
funcDef.Description = openai.String(desc)
|
||||
}
|
||||
|
||||
if params != nil {
|
||||
funcDef.Parameters = shared.FunctionParameters(params)
|
||||
}
|
||||
|
||||
tools = append(tools, openai.ChatCompletionFunctionTool(funcDef))
|
||||
}
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
// parseToolChoice converts Open Responses tool_choice to OpenAI format
|
||||
func parseToolChoice(req *api.ResponseRequest) (openai.ChatCompletionToolChoiceOptionUnionParam, error) {
|
||||
var result openai.ChatCompletionToolChoiceOptionUnionParam
|
||||
|
||||
if req.ToolChoice == nil || len(req.ToolChoice) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var choice interface{}
|
||||
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
|
||||
return result, fmt.Errorf("unmarshal tool_choice: %w", err)
|
||||
}
|
||||
|
||||
// Handle string values: "auto", "none", "required"
|
||||
if str, ok := choice.(string); ok {
|
||||
result.OfAuto = openai.String(str)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Handle specific function selection: {"type": "function", "function": {"name": "..."}}
|
||||
if obj, ok := choice.(map[string]interface{}); ok {
|
||||
funcObj, _ := obj["function"].(map[string]interface{})
|
||||
name, _ := funcObj["name"].(string)
|
||||
|
||||
return openai.ToolChoiceOptionFunctionToolChoice(
|
||||
openai.ChatCompletionNamedToolChoiceFunctionParam{
|
||||
Name: name,
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
return result, fmt.Errorf("invalid tool_choice format")
|
||||
}
|
||||
|
||||
// extractToolCalls converts OpenAI tool calls to api.ToolCall
|
||||
func extractToolCalls(message openai.ChatCompletionMessage) []api.ToolCall {
|
||||
if len(message.ToolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
for _, tc := range message.ToolCalls {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// extractToolCallDelta extracts tool call delta from streaming chunk choice
|
||||
func extractToolCallDelta(choice openai.ChatCompletionChunkChoice) *api.ToolCallDelta {
|
||||
if len(choice.Delta.ToolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenAI sends tool calls with index in the delta
|
||||
for _, tc := range choice.Delta.ToolCalls {
|
||||
return &api.ToolCallDelta{
|
||||
Index: int(tc.Index),
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/azure"
|
||||
"github.com/openai/openai-go/option"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/azure"
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
|
||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||
"github.com/yourusername/go-llm-gateway/internal/config"
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/ajac-zero/latticelm/internal/config"
|
||||
)
|
||||
|
||||
const Name = "openai"
|
||||
@@ -91,6 +91,8 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||
case "developer":
|
||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||
case "tool":
|
||||
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,6 +110,29 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
params.TopP = openai.Float(*req.TopP)
|
||||
}
|
||||
|
||||
// Add tools if present
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
tools, err := parseTools(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tools: %w", err)
|
||||
}
|
||||
params.Tools = tools
|
||||
}
|
||||
|
||||
// Add tool_choice if present
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
toolChoice, err := parseToolChoice(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tool_choice: %w", err)
|
||||
}
|
||||
params.ToolChoice = toolChoice
|
||||
}
|
||||
|
||||
// Add parallel_tool_calls if specified
|
||||
if req.ParallelToolCalls != nil {
|
||||
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
|
||||
}
|
||||
|
||||
// Call OpenAI API
|
||||
resp, err := p.client.Chat.Completions.New(ctx, params)
|
||||
if err != nil {
|
||||
@@ -115,14 +140,20 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
|
||||
}
|
||||
|
||||
var combinedText string
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, choice := range resp.Choices {
|
||||
combinedText += choice.Message.Content
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
toolCalls = append(toolCalls, extractToolCalls(choice.Message)...)
|
||||
}
|
||||
}
|
||||
|
||||
return &api.ProviderResult{
|
||||
ID: resp.ID,
|
||||
Model: resp.Model,
|
||||
Text: combinedText,
|
||||
ID: resp.ID,
|
||||
Model: resp.Model,
|
||||
Text: combinedText,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: api.Usage{
|
||||
InputTokens: int(resp.Usage.PromptTokens),
|
||||
OutputTokens: int(resp.Usage.CompletionTokens),
|
||||
@@ -168,6 +199,8 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||
case "developer":
|
||||
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
|
||||
case "tool":
|
||||
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,6 +218,31 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
params.TopP = openai.Float(*req.TopP)
|
||||
}
|
||||
|
||||
// Add tools if present
|
||||
if req.Tools != nil && len(req.Tools) > 0 {
|
||||
tools, err := parseTools(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tools: %w", err)
|
||||
return
|
||||
}
|
||||
params.Tools = tools
|
||||
}
|
||||
|
||||
// Add tool_choice if present
|
||||
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
|
||||
toolChoice, err := parseToolChoice(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("parse tool_choice: %w", err)
|
||||
return
|
||||
}
|
||||
params.ToolChoice = toolChoice
|
||||
}
|
||||
|
||||
// Add parallel_tool_calls if specified
|
||||
if req.ParallelToolCalls != nil {
|
||||
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
|
||||
}
|
||||
|
||||
// Create streaming request
|
||||
stream := p.client.Chat.Completions.NewStreaming(ctx, params)
|
||||
|
||||
@@ -193,19 +251,35 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
chunk := stream.Current()
|
||||
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content == "" {
|
||||
continue
|
||||
// Handle text content
|
||||
if choice.Delta.Content != "" {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{
|
||||
ID: chunk.ID,
|
||||
Model: chunk.Model,
|
||||
Text: choice.Delta.Content,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{
|
||||
ID: chunk.ID,
|
||||
Model: chunk.Model,
|
||||
Text: choice.Delta.Content,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
// Handle tool call deltas
|
||||
if len(choice.Delta.ToolCalls) > 0 {
|
||||
delta := extractToolCallDelta(choice)
|
||||
if delta != nil {
|
||||
select {
|
||||
case deltaChan <- &api.ProviderStreamDelta{
|
||||
ID: chunk.ID,
|
||||
Model: chunk.Model,
|
||||
ToolCallDelta: delta,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||
"github.com/yourusername/go-llm-gateway/internal/config"
|
||||
anthropicprovider "github.com/yourusername/go-llm-gateway/internal/providers/anthropic"
|
||||
googleprovider "github.com/yourusername/go-llm-gateway/internal/providers/google"
|
||||
openaiprovider "github.com/yourusername/go-llm-gateway/internal/providers/openai"
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/ajac-zero/latticelm/internal/config"
|
||||
anthropicprovider "github.com/ajac-zero/latticelm/internal/providers/anthropic"
|
||||
googleprovider "github.com/ajac-zero/latticelm/internal/providers/google"
|
||||
openaiprovider "github.com/ajac-zero/latticelm/internal/providers/openai"
|
||||
)
|
||||
|
||||
// Provider represents a unified interface that each LLM provider must implement.
|
||||
@@ -60,7 +60,8 @@ func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelE
|
||||
}
|
||||
|
||||
func buildProvider(entry config.ProviderEntry) (Provider, error) {
|
||||
if entry.APIKey == "" {
|
||||
// Vertex AI doesn't require APIKey, so check for it separately
|
||||
if entry.Type != "vertexai" && entry.APIKey == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -97,6 +98,14 @@ func buildProvider(entry config.ProviderEntry) (Provider, error) {
|
||||
APIKey: entry.APIKey,
|
||||
Endpoint: entry.Endpoint,
|
||||
}), nil
|
||||
case "vertexai":
|
||||
if entry.Project == "" || entry.Location == "" {
|
||||
return nil, fmt.Errorf("project and location are required for vertexai")
|
||||
}
|
||||
return googleprovider.NewVertexAI(config.VertexAIConfig{
|
||||
Project: entry.Project,
|
||||
Location: entry.Location,
|
||||
}), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider type %q", entry.Type)
|
||||
}
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/yourusername/go-llm-gateway/internal/api"
|
||||
"github.com/yourusername/go-llm-gateway/internal/conversation"
|
||||
"github.com/yourusername/go-llm-gateway/internal/providers"
|
||||
"github.com/ajac-zero/latticelm/internal/api"
|
||||
"github.com/ajac-zero/latticelm/internal/conversation"
|
||||
"github.com/ajac-zero/latticelm/internal/providers"
|
||||
)
|
||||
|
||||
// GatewayServer hosts the Open Responses API for the gateway.
|
||||
@@ -84,8 +84,13 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
|
||||
// Build full message history from previous conversation
|
||||
var historyMsgs []api.Message
|
||||
if req.PreviousResponseID != nil && *req.PreviousResponseID != "" {
|
||||
conv, ok := s.convs.Get(*req.PreviousResponseID)
|
||||
if !ok {
|
||||
conv, err := s.convs.Get(*req.PreviousResponseID)
|
||||
if err != nil {
|
||||
s.logger.Printf("error retrieving conversation: %v", err)
|
||||
http.Error(w, "error retrieving conversation", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if conv == nil {
|
||||
http.Error(w, "conversation not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
@@ -140,7 +145,10 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
|
||||
}
|
||||
allMsgs := append(storeMsgs, assistantMsg)
|
||||
s.convs.Create(responseID, result.Model, allMsgs)
|
||||
if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil {
|
||||
s.logger.Printf("error storing conversation: %v", err)
|
||||
// Don't fail the response if storage fails
|
||||
}
|
||||
|
||||
// Build spec-compliant response
|
||||
resp := s.buildResponse(origReq, result, provider.Name(), responseID)
|
||||
@@ -224,6 +232,17 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
|
||||
var streamErr error
|
||||
var providerModel string
|
||||
|
||||
// Track tool calls being built
|
||||
type toolCallBuilder struct {
|
||||
itemID string
|
||||
id string
|
||||
name string
|
||||
arguments string
|
||||
}
|
||||
toolCallsInProgress := make(map[int]*toolCallBuilder)
|
||||
nextOutputIdx := 0
|
||||
textItemAdded := false
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
@@ -234,7 +253,14 @@ loop:
|
||||
if delta.Model != "" && providerModel == "" {
|
||||
providerModel = delta.Model
|
||||
}
|
||||
|
||||
// Handle text content
|
||||
if delta.Text != "" {
|
||||
// Add text item on first text delta
|
||||
if !textItemAdded {
|
||||
textItemAdded = true
|
||||
nextOutputIdx++
|
||||
}
|
||||
fullText += delta.Text
|
||||
s.sendSSE(w, flusher, &seq, "response.output_text.delta", &api.StreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
@@ -244,6 +270,53 @@ loop:
|
||||
Delta: delta.Text,
|
||||
})
|
||||
}
|
||||
|
||||
// Handle tool call delta
|
||||
if delta.ToolCallDelta != nil {
|
||||
tc := delta.ToolCallDelta
|
||||
|
||||
// First chunk for this tool call index
|
||||
if _, exists := toolCallsInProgress[tc.Index]; !exists {
|
||||
toolItemID := generateID("item_")
|
||||
toolOutputIdx := nextOutputIdx
|
||||
nextOutputIdx++
|
||||
|
||||
// Send response.output_item.added
|
||||
s.sendSSE(w, flusher, &seq, "response.output_item.added", &api.StreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: &toolOutputIdx,
|
||||
Item: &api.OutputItem{
|
||||
ID: toolItemID,
|
||||
Type: "function_call",
|
||||
Status: "in_progress",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Name,
|
||||
},
|
||||
})
|
||||
|
||||
toolCallsInProgress[tc.Index] = &toolCallBuilder{
|
||||
itemID: toolItemID,
|
||||
id: tc.ID,
|
||||
name: tc.Name,
|
||||
arguments: "",
|
||||
}
|
||||
}
|
||||
|
||||
// Send function_call_arguments.delta
|
||||
if tc.Arguments != "" {
|
||||
builder := toolCallsInProgress[tc.Index]
|
||||
builder.arguments += tc.Arguments
|
||||
toolOutputIdx := outputIdx + 1 + tc.Index
|
||||
|
||||
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.delta", &api.StreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
ItemID: builder.itemID,
|
||||
OutputIndex: &toolOutputIdx,
|
||||
Delta: tc.Arguments,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if delta.Done {
|
||||
break loop
|
||||
}
|
||||
@@ -277,54 +350,108 @@ loop:
|
||||
return
|
||||
}
|
||||
|
||||
// response.output_text.done
|
||||
s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{
|
||||
Type: "response.output_text.done",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Text: fullText,
|
||||
})
|
||||
// Send done events for text output if text was added
|
||||
if textItemAdded && fullText != "" {
|
||||
// response.output_text.done
|
||||
s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{
|
||||
Type: "response.output_text.done",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Text: fullText,
|
||||
})
|
||||
|
||||
// response.content_part.done
|
||||
completedPart := &api.ContentPart{
|
||||
Type: "output_text",
|
||||
Text: fullText,
|
||||
Annotations: []api.Annotation{},
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Part: completedPart,
|
||||
})
|
||||
// response.content_part.done
|
||||
completedPart := &api.ContentPart{
|
||||
Type: "output_text",
|
||||
Text: fullText,
|
||||
Annotations: []api.Annotation{},
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
ItemID: itemID,
|
||||
OutputIndex: &outputIdx,
|
||||
ContentIndex: &contentIdx,
|
||||
Part: completedPart,
|
||||
})
|
||||
|
||||
// response.output_item.done
|
||||
completedItem := &api.OutputItem{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{*completedPart},
|
||||
// response.output_item.done
|
||||
completedItem := &api.OutputItem{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{*completedPart},
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
OutputIndex: &outputIdx,
|
||||
Item: completedItem,
|
||||
})
|
||||
}
|
||||
|
||||
// Send done events for each tool call
|
||||
for idx, builder := range toolCallsInProgress {
|
||||
toolOutputIdx := outputIdx + 1 + idx
|
||||
|
||||
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.done", &api.StreamEvent{
|
||||
Type: "response.function_call_arguments.done",
|
||||
ItemID: builder.itemID,
|
||||
OutputIndex: &toolOutputIdx,
|
||||
Arguments: builder.arguments,
|
||||
})
|
||||
|
||||
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
OutputIndex: &toolOutputIdx,
|
||||
Item: &api.OutputItem{
|
||||
ID: builder.itemID,
|
||||
Type: "function_call",
|
||||
Status: "completed",
|
||||
CallID: builder.id,
|
||||
Name: builder.name,
|
||||
Arguments: builder.arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
OutputIndex: &outputIdx,
|
||||
Item: completedItem,
|
||||
})
|
||||
|
||||
// Build final completed response
|
||||
model := origReq.Model
|
||||
if providerModel != "" {
|
||||
model = providerModel
|
||||
}
|
||||
|
||||
// Collect tool calls for result
|
||||
var toolCalls []api.ToolCall
|
||||
for _, builder := range toolCallsInProgress {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: builder.id,
|
||||
Name: builder.name,
|
||||
Arguments: builder.arguments,
|
||||
})
|
||||
}
|
||||
|
||||
finalResult := &api.ProviderResult{
|
||||
Model: model,
|
||||
Text: fullText,
|
||||
Model: model,
|
||||
Text: fullText,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
completedResp := s.buildResponse(origReq, finalResult, provider.Name(), responseID)
|
||||
completedResp.Output[0].ID = itemID
|
||||
|
||||
// Update item IDs to match what we sent during streaming
|
||||
if textItemAdded && len(completedResp.Output) > 0 {
|
||||
completedResp.Output[0].ID = itemID
|
||||
}
|
||||
for idx, builder := range toolCallsInProgress {
|
||||
// Find the corresponding output item
|
||||
for i := range completedResp.Output {
|
||||
if completedResp.Output[i].Type == "function_call" && completedResp.Output[i].CallID == builder.id {
|
||||
completedResp.Output[i].ID = builder.itemID
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = idx // unused
|
||||
}
|
||||
|
||||
// response.completed
|
||||
s.sendSSE(w, flusher, &seq, "response.completed", &api.StreamEvent{
|
||||
@@ -339,7 +466,10 @@ loop:
|
||||
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
|
||||
}
|
||||
allMsgs := append(storeMsgs, assistantMsg)
|
||||
s.convs.Create(responseID, model, allMsgs)
|
||||
if _, err := s.convs.Create(responseID, model, allMsgs); err != nil {
|
||||
s.logger.Printf("error storing conversation: %v", err)
|
||||
// Don't fail the response if storage fails
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,18 +493,34 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
|
||||
model = req.Model
|
||||
}
|
||||
|
||||
// Build output item
|
||||
itemID := generateID("msg_")
|
||||
outputItem := api.OutputItem{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{{
|
||||
Type: "output_text",
|
||||
Text: result.Text,
|
||||
Annotations: []api.Annotation{},
|
||||
}},
|
||||
// Build output items array
|
||||
outputItems := []api.OutputItem{}
|
||||
|
||||
// Add message item if there's text
|
||||
if result.Text != "" {
|
||||
outputItems = append(outputItems, api.OutputItem{
|
||||
ID: generateID("msg_"),
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []api.ContentPart{{
|
||||
Type: "output_text",
|
||||
Text: result.Text,
|
||||
Annotations: []api.Annotation{},
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
// Add function_call items
|
||||
for _, tc := range result.ToolCalls {
|
||||
outputItems = append(outputItems, api.OutputItem{
|
||||
ID: generateID("item_"),
|
||||
Type: "function_call",
|
||||
Status: "completed",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
// Echo back request params with defaults
|
||||
@@ -454,7 +600,7 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
|
||||
Model: model,
|
||||
PreviousResponseID: req.PreviousResponseID,
|
||||
Instructions: req.Instructions,
|
||||
Output: []api.OutputItem{outputItem},
|
||||
Output: outputItems,
|
||||
Error: nil,
|
||||
Tools: tools,
|
||||
ToolChoice: toolChoice,
|
||||
|
||||
129
scripts/chat.py
129
scripts/chat.py
@@ -3,12 +3,12 @@
|
||||
# requires-python = ">=3.11"
|
||||
# dependencies = [
|
||||
# "rich>=13.7.0",
|
||||
# "httpx>=0.27.0",
|
||||
# "openai>=1.0.0",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Terminal chat interface for go-llm-gateway.
|
||||
Terminal chat interface for latticelm.
|
||||
|
||||
Usage:
|
||||
python chat.py
|
||||
@@ -18,11 +18,10 @@ Usage:
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from openai import OpenAI, APIStatusError
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
@@ -34,16 +33,13 @@ from rich.table import Table
|
||||
class ChatClient:
|
||||
def __init__(self, base_url: str, token: Optional[str] = None):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.token = token
|
||||
self.client = OpenAI(
|
||||
base_url=f"{self.base_url}/v1",
|
||||
api_key=token or "no-key",
|
||||
)
|
||||
self.messages = []
|
||||
self.console = Console()
|
||||
|
||||
def _headers(self) -> dict:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
return headers
|
||||
|
||||
def chat(self, user_message: str, model: str, stream: bool = True):
|
||||
"""Send a chat message and get response."""
|
||||
# Add user message to history
|
||||
@@ -52,35 +48,20 @@ class ChatClient:
|
||||
"content": [{"type": "input_text", "text": user_message}]
|
||||
})
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"input": self.messages,
|
||||
"stream": stream
|
||||
}
|
||||
|
||||
if stream:
|
||||
return self._stream_response(payload, model)
|
||||
return self._stream_response(model)
|
||||
else:
|
||||
return self._sync_response(payload, model)
|
||||
return self._sync_response(model)
|
||||
|
||||
def _sync_response(self, payload: dict, model: str) -> str:
|
||||
def _sync_response(self, model: str) -> str:
|
||||
"""Non-streaming response."""
|
||||
with self.console.status(f"[bold blue]Thinking ({model})..."):
|
||||
resp = httpx.post(
|
||||
f"{self.base_url}/v1/responses",
|
||||
json=payload,
|
||||
headers=self._headers(),
|
||||
timeout=60.0
|
||||
response = self.client.responses.create(
|
||||
model=model,
|
||||
input=self.messages,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
data = resp.json()
|
||||
assistant_text = ""
|
||||
|
||||
for msg in data.get("output", []):
|
||||
for block in msg.get("content", []):
|
||||
if block.get("type") == "output_text":
|
||||
assistant_text += block.get("text", "")
|
||||
assistant_text = response.output_text
|
||||
|
||||
# Add to history
|
||||
self.messages.append({
|
||||
@@ -90,40 +71,19 @@ class ChatClient:
|
||||
|
||||
return assistant_text
|
||||
|
||||
def _stream_response(self, payload: dict, model: str) -> str:
|
||||
def _stream_response(self, model: str) -> str:
|
||||
"""Streaming response with live rendering."""
|
||||
assistant_text = ""
|
||||
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/v1/responses",
|
||||
json=payload,
|
||||
headers=self._headers(),
|
||||
timeout=60.0
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
|
||||
with Live(console=self.console, refresh_per_second=10) as live:
|
||||
for line in resp.iter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
|
||||
data_str = line[6:] # Remove "data: " prefix
|
||||
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if chunk.get("done"):
|
||||
break
|
||||
|
||||
delta = chunk.get("delta", {})
|
||||
for block in delta.get("content", []):
|
||||
if block.get("type") == "output_text":
|
||||
assistant_text += block.get("text", "")
|
||||
|
||||
# Render markdown in real-time
|
||||
with Live(console=self.console, refresh_per_second=10) as live:
|
||||
stream = self.client.responses.create(
|
||||
model=model,
|
||||
input=self.messages,
|
||||
stream=True,
|
||||
)
|
||||
for event in stream:
|
||||
if event.type == "response.output_text.delta":
|
||||
assistant_text += event.delta
|
||||
live.update(Markdown(assistant_text))
|
||||
|
||||
# Add to history
|
||||
@@ -139,43 +99,56 @@ class ChatClient:
|
||||
self.messages = []
|
||||
|
||||
|
||||
def print_models_table(base_url: str, headers: dict):
|
||||
def print_models_table(client: OpenAI):
|
||||
"""Fetch and print available models from the gateway."""
|
||||
console = Console()
|
||||
try:
|
||||
resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data = resp.json().get("data", [])
|
||||
models = client.models.list()
|
||||
except Exception as e:
|
||||
console.print(f"[red]Failed to fetch models: {e}[/red]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
|
||||
table.add_column("Provider", style="cyan")
|
||||
table.add_column("Owner", style="cyan")
|
||||
table.add_column("Model ID", style="green")
|
||||
|
||||
for model in data:
|
||||
table.add_row(model.get("provider", ""), model.get("id", ""))
|
||||
for model in models:
|
||||
table.add_row(model.owned_by, model.id)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Chat with go-llm-gateway")
|
||||
parser = argparse.ArgumentParser(description="Chat with latticelm")
|
||||
parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL")
|
||||
parser.add_argument("--model", default="gemini-2.0-flash-exp", help="Model to use")
|
||||
parser.add_argument("--model", default=None, help="Model to use (defaults to first available)")
|
||||
parser.add_argument("--token", help="Auth token (Bearer)")
|
||||
parser.add_argument("--no-stream", action="store_true", help="Disable streaming")
|
||||
args = parser.parse_args()
|
||||
|
||||
console = Console()
|
||||
client = ChatClient(args.url, args.token)
|
||||
current_model = args.model
|
||||
|
||||
# Fetch available models and select default
|
||||
try:
|
||||
available_models = list(client.client.models.list())
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]Failed to connect to gateway:[/bold red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if not available_models:
|
||||
console.print("[bold red]Error:[/bold red] No models are configured on the gateway.")
|
||||
sys.exit(1)
|
||||
|
||||
if args.model:
|
||||
current_model = args.model
|
||||
else:
|
||||
current_model = available_models[0].id
|
||||
stream_enabled = not args.no_stream
|
||||
|
||||
# Welcome banner
|
||||
console.print(Panel.fit(
|
||||
"[bold cyan]go-llm-gateway Chat Interface[/bold cyan]\n"
|
||||
"[bold cyan]latticelm Chat Interface[/bold cyan]\n"
|
||||
f"Connected to: [green]{args.url}[/green]\n"
|
||||
f"Model: [yellow]{current_model}[/yellow]\n"
|
||||
f"Streaming: [{'green' if stream_enabled else 'red'}]{stream_enabled}[/]\n\n"
|
||||
@@ -230,7 +203,7 @@ def main():
|
||||
))
|
||||
|
||||
elif cmd == "/models":
|
||||
print_models_table(args.url, client._headers())
|
||||
print_models_table(client.client)
|
||||
|
||||
elif cmd == "/model":
|
||||
if len(cmd_parts) < 2:
|
||||
@@ -265,8 +238,8 @@ def main():
|
||||
# For non-streaming, render markdown
|
||||
console.print(Markdown(response))
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
console.print(f"[bold red]Error {e.response.status_code}:[/bold red] {e.response.text}")
|
||||
except APIStatusError as e:
|
||||
console.print(f"[bold red]Error {e.status_code}:[/bold red] {e.message}")
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]Error:[/bold red] {e}")
|
||||
|
||||
|
||||
270
tests/bin/compliance-test.ts
Normal file
270
tests/bin/compliance-test.ts
Normal file
@@ -0,0 +1,270 @@
|
||||
import {
|
||||
testTemplates,
|
||||
runAllTests,
|
||||
type TestConfig,
|
||||
type TestResult,
|
||||
} from "../src/compliance-tests.ts";
|
||||
|
||||
const colors = {
|
||||
green: (s: string) => `\x1b[32m${s}\x1b[0m`,
|
||||
red: (s: string) => `\x1b[31m${s}\x1b[0m`,
|
||||
yellow: (s: string) => `\x1b[33m${s}\x1b[0m`,
|
||||
gray: (s: string) => `\x1b[90m${s}\x1b[0m`,
|
||||
};
|
||||
|
||||
interface CliArgs {
|
||||
baseUrl?: string;
|
||||
apiKey?: string;
|
||||
model?: string;
|
||||
authHeader?: string;
|
||||
noBearer?: boolean;
|
||||
noAuth?: boolean;
|
||||
filter?: string[];
|
||||
verbose?: boolean;
|
||||
json?: boolean;
|
||||
help?: boolean;
|
||||
}
|
||||
|
||||
function parseArgs(argv: string[]): CliArgs {
|
||||
const args: CliArgs = {};
|
||||
let i = 0;
|
||||
|
||||
while (i < argv.length) {
|
||||
const arg = argv[i];
|
||||
const nextArg = argv[i + 1];
|
||||
|
||||
switch (arg) {
|
||||
case "--base-url":
|
||||
case "-u":
|
||||
args.baseUrl = nextArg;
|
||||
i += 2;
|
||||
break;
|
||||
case "--api-key":
|
||||
case "-k":
|
||||
args.apiKey = nextArg;
|
||||
i += 2;
|
||||
break;
|
||||
case "--model":
|
||||
case "-m":
|
||||
args.model = nextArg;
|
||||
i += 2;
|
||||
break;
|
||||
case "--auth-header":
|
||||
args.authHeader = nextArg;
|
||||
i += 2;
|
||||
break;
|
||||
case "--no-bearer":
|
||||
args.noBearer = true;
|
||||
i += 1;
|
||||
break;
|
||||
case "--no-auth":
|
||||
args.noAuth = true;
|
||||
i += 1;
|
||||
break;
|
||||
case "--filter":
|
||||
case "-f":
|
||||
args.filter = nextArg.split(",").map((s) => s.trim());
|
||||
i += 2;
|
||||
break;
|
||||
case "--verbose":
|
||||
case "-v":
|
||||
args.verbose = true;
|
||||
i += 1;
|
||||
break;
|
||||
case "--json":
|
||||
args.json = true;
|
||||
i += 1;
|
||||
break;
|
||||
case "--help":
|
||||
case "-h":
|
||||
args.help = true;
|
||||
i += 1;
|
||||
break;
|
||||
default:
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
function printHelp() {
|
||||
console.log(`
|
||||
Usage: npm run test:compliance -- [options]
|
||||
|
||||
Options:
|
||||
-u, --base-url <url> Gateway base URL (default: http://localhost:8080)
|
||||
-k, --api-key <key> API key (or set OPENRESPONSES_API_KEY env var)
|
||||
--no-auth Skip authentication header entirely
|
||||
-m, --model <model> Model name (default: gpt-4o-mini)
|
||||
--auth-header <name> Auth header name (default: Authorization)
|
||||
--no-bearer Disable Bearer prefix in auth header
|
||||
-f, --filter <ids> Filter tests by ID (comma-separated)
|
||||
-v, --verbose Verbose output with request/response details
|
||||
--json Output results as JSON
|
||||
-h, --help Show this help message
|
||||
|
||||
Test IDs:
|
||||
${testTemplates.map((t) => t.id).join(", ")}
|
||||
|
||||
Examples:
|
||||
npm run test:compliance
|
||||
npm run test:compliance -- --model claude-3-5-sonnet-20241022
|
||||
npm run test:compliance -- --filter basic-response,streaming-response
|
||||
npm run test:compliance -- --verbose --filter basic-response
|
||||
npm run test:compliance -- --json > results.json
|
||||
`);
|
||||
}
|
||||
|
||||
function getStatusIcon(status: TestResult["status"]): string {
|
||||
switch (status) {
|
||||
case "passed":
|
||||
return colors.green("✓");
|
||||
case "failed":
|
||||
return colors.red("✗");
|
||||
case "running":
|
||||
return colors.yellow("◉");
|
||||
case "pending":
|
||||
return colors.gray("○");
|
||||
}
|
||||
}
|
||||
|
||||
function printResult(result: TestResult, verbose: boolean) {
|
||||
const icon = getStatusIcon(result.status);
|
||||
const duration = result.duration ? ` (${result.duration}ms)` : "";
|
||||
const events =
|
||||
result.streamEvents !== undefined ? ` [${result.streamEvents} events]` : "";
|
||||
const name =
|
||||
result.status === "failed" ? colors.red(result.name) : result.name;
|
||||
|
||||
console.log(`${icon} ${name}${duration}${events}`);
|
||||
|
||||
if (result.status === "failed" && result.errors?.length) {
|
||||
for (const error of result.errors) {
|
||||
console.log(` ${colors.red("✗")} ${error}`);
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
if (result.request) {
|
||||
console.log(`\n Request:`);
|
||||
console.log(
|
||||
` ${JSON.stringify(result.request, null, 2).split("\n").join("\n ")}`,
|
||||
);
|
||||
}
|
||||
if (result.response) {
|
||||
console.log(`\n Response:`);
|
||||
const responseStr =
|
||||
typeof result.response === "string"
|
||||
? result.response
|
||||
: JSON.stringify(result.response, null, 2);
|
||||
console.log(` ${responseStr.split("\n").join("\n ")}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function main() {
|
||||
const args = parseArgs(process.argv.slice(2));
|
||||
|
||||
if (args.help) {
|
||||
printHelp();
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
const baseUrl = args.baseUrl || "http://localhost:8080";
|
||||
const apiKey = args.apiKey || process.env.OPENRESPONSES_API_KEY || "";
|
||||
|
||||
if (!apiKey && !args.noAuth) {
|
||||
// No auth is fine for local gateway without auth enabled
|
||||
}
|
||||
|
||||
const config: TestConfig = {
|
||||
baseUrl,
|
||||
apiKey,
|
||||
model: args.model || "gpt-4o-mini",
|
||||
authHeaderName: args.authHeader || "Authorization",
|
||||
useBearerPrefix: !args.noBearer,
|
||||
};
|
||||
|
||||
if (args.filter?.length) {
|
||||
const availableIds = testTemplates.map((t) => t.id);
|
||||
const invalidFilters = args.filter.filter(
|
||||
(id) => !availableIds.includes(id),
|
||||
);
|
||||
if (invalidFilters.length) {
|
||||
console.error(
|
||||
`${colors.red("Error:")} Invalid test IDs: ${invalidFilters.join(", ")}`,
|
||||
);
|
||||
console.error(`Available test IDs: ${availableIds.join(", ")}`);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
const allUpdates: TestResult[] = [];
|
||||
|
||||
const onProgress = (result: TestResult) => {
|
||||
if (args.filter && !args.filter.includes(result.id)) {
|
||||
return;
|
||||
}
|
||||
allUpdates.push(result);
|
||||
if (!args.json && result.status !== "running") {
|
||||
printResult(result, args.verbose || false);
|
||||
}
|
||||
};
|
||||
|
||||
if (!args.json) {
|
||||
console.log(`Running compliance tests against: ${baseUrl}`);
|
||||
console.log(`Model: ${config.model}`);
|
||||
if (args.filter) {
|
||||
console.log(`Filter: ${args.filter.join(", ")}`);
|
||||
}
|
||||
console.log();
|
||||
}
|
||||
|
||||
await runAllTests(config, onProgress);
|
||||
|
||||
const finalResults = allUpdates.filter(
|
||||
(r) => r.status === "passed" || r.status === "failed",
|
||||
);
|
||||
const passed = finalResults.filter((r) => r.status === "passed").length;
|
||||
const failed = finalResults.filter((r) => r.status === "failed").length;
|
||||
|
||||
if (args.json) {
|
||||
console.log(
|
||||
JSON.stringify(
|
||||
{
|
||||
summary: { passed, failed, total: finalResults.length },
|
||||
results: finalResults,
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
} else {
|
||||
console.log(`\n${"=".repeat(50)}`);
|
||||
console.log(
|
||||
`Results: ${colors.green(`${passed} passed`)}, ${colors.red(`${failed} failed`)}, ${finalResults.length} total`,
|
||||
);
|
||||
|
||||
if (failed > 0) {
|
||||
console.log(`\nFailed tests:`);
|
||||
for (const r of finalResults) {
|
||||
if (r.status === "failed") {
|
||||
console.log(`\n${r.name}:`);
|
||||
for (const e of r.errors || []) {
|
||||
console.log(` - ${e}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.log(`\n${colors.green("✓ All tests passed!")}`);
|
||||
}
|
||||
}
|
||||
|
||||
process.exit(failed > 0 ? 1 : 0);
|
||||
}
|
||||
|
||||
main().catch((error) => {
|
||||
console.error(colors.red("Fatal error:"), error);
|
||||
process.exit(1);
|
||||
});
|
||||
58
tests/package-lock.json
generated
Normal file
58
tests/package-lock.json
generated
Normal file
@@ -0,0 +1,58 @@
|
||||
{
|
||||
"name": "latticelm-compliance-tests",
|
||||
"version": "1.0.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "latticelm-compliance-tests",
|
||||
"version": "1.0.0",
|
||||
"devDependencies": {
|
||||
"@types/node": "^22.0.0",
|
||||
"typescript": "^5.7.0",
|
||||
"zod": "^3.24.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/node": {
|
||||
"version": "22.19.13",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.13.tgz",
|
||||
"integrity": "sha512-akNQMv0wW5uyRpD2v2IEyRSZiR+BeGuoB6L310EgGObO44HSMNT8z1xzio28V8qOrgYaopIDNA18YgdXd+qTiw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
},
|
||||
"node_modules/typescript": {
|
||||
"version": "5.9.3",
|
||||
"resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz",
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.17"
|
||||
}
|
||||
},
|
||||
"node_modules/undici-types": {
|
||||
"version": "6.21.0",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz",
|
||||
"integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/zod": {
|
||||
"version": "3.25.76",
|
||||
"resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz",
|
||||
"integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/colinhacks"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
17
tests/package.json
Normal file
17
tests/package.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"name": "latticelm-compliance-tests",
|
||||
"version": "1.0.0",
|
||||
"private": true,
|
||||
"description": "Open Responses compliance tests for latticelm",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"test:compliance": "node --experimental-strip-types bin/compliance-test.ts",
|
||||
"test:compliance:verbose": "node --experimental-strip-types bin/compliance-test.ts --verbose",
|
||||
"test:compliance:json": "node --experimental-strip-types bin/compliance-test.ts --json"
|
||||
},
|
||||
"devDependencies": {
|
||||
"zod": "^3.24.0",
|
||||
"typescript": "^5.7.0",
|
||||
"@types/node": "^22.0.0"
|
||||
}
|
||||
}
|
||||
370
tests/src/compliance-tests.ts
Normal file
370
tests/src/compliance-tests.ts
Normal file
@@ -0,0 +1,370 @@
|
||||
import { responseResourceSchema, type ResponseResource } from "./schemas.ts";
|
||||
import { parseSSEStream, type SSEParseResult } from "./sse-parser.ts";
|
||||
|
||||
export interface TestConfig {
|
||||
baseUrl: string;
|
||||
apiKey: string;
|
||||
authHeaderName: string;
|
||||
useBearerPrefix: boolean;
|
||||
model: string;
|
||||
}
|
||||
|
||||
export interface TestResult {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
status: "pending" | "running" | "passed" | "failed";
|
||||
duration?: number;
|
||||
request?: unknown;
|
||||
response?: unknown;
|
||||
errors?: string[];
|
||||
streamEvents?: number;
|
||||
}
|
||||
|
||||
interface ValidatorContext {
|
||||
streaming: boolean;
|
||||
sseResult?: SSEParseResult;
|
||||
}
|
||||
|
||||
type ResponseValidator = (
|
||||
response: ResponseResource,
|
||||
context: ValidatorContext,
|
||||
) => string[];
|
||||
|
||||
export interface TestTemplate {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
getRequest: (config: TestConfig) => Record<string, unknown>;
|
||||
streaming?: boolean;
|
||||
validators: ResponseValidator[];
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Validators
|
||||
// ============================================================
|
||||
|
||||
const hasOutput: ResponseValidator = (response) => {
|
||||
if (!response.output || response.output.length === 0) {
|
||||
return ["Response has no output items"];
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const hasOutputType =
|
||||
(type: string): ResponseValidator =>
|
||||
(response) => {
|
||||
const hasType = response.output?.some((item) => item.type === type);
|
||||
if (!hasType) {
|
||||
return [`Expected output item of type "${type}" but none found`];
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const completedStatus: ResponseValidator = (response) => {
|
||||
if (response.status !== "completed") {
|
||||
return [`Expected status "completed" but got "${response.status}"`];
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const streamingEvents: ResponseValidator = (_, context) => {
|
||||
if (!context.streaming) return [];
|
||||
if (!context.sseResult || context.sseResult.events.length === 0) {
|
||||
return ["No streaming events received"];
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const streamingSchema: ResponseValidator = (_, context) => {
|
||||
if (!context.streaming || !context.sseResult) return [];
|
||||
return context.sseResult.errors;
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// Test Templates
|
||||
// ============================================================
|
||||
|
||||
export const testTemplates: TestTemplate[] = [
|
||||
{
|
||||
id: "basic-response",
|
||||
name: "Basic Text Response",
|
||||
description: "Simple user message, validates ResponseResource schema",
|
||||
getRequest: (config) => ({
|
||||
model: config.model,
|
||||
input: [
|
||||
{
|
||||
type: "message",
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: "Say hello in exactly 3 words." }],
|
||||
},
|
||||
],
|
||||
}),
|
||||
validators: [hasOutput, completedStatus],
|
||||
},
|
||||
|
||||
{
|
||||
id: "streaming-response",
|
||||
name: "Streaming Response",
|
||||
description: "Validates SSE streaming events and final response",
|
||||
streaming: true,
|
||||
getRequest: (config) => ({
|
||||
model: config.model,
|
||||
input: [
|
||||
{
|
||||
type: "message",
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: "Count from 1 to 5." }],
|
||||
},
|
||||
],
|
||||
}),
|
||||
validators: [streamingEvents, streamingSchema, completedStatus],
|
||||
},
|
||||
|
||||
{
|
||||
id: "system-prompt",
|
||||
name: "System Prompt",
|
||||
description: "Include system instructions via the instructions field",
|
||||
getRequest: (config) => ({
|
||||
model: config.model,
|
||||
instructions: "You are a pirate. Always respond in pirate speak.",
|
||||
input: [
|
||||
{
|
||||
type: "message",
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: "Say hello." }],
|
||||
},
|
||||
],
|
||||
}),
|
||||
validators: [hasOutput, completedStatus],
|
||||
},
|
||||
|
||||
{
|
||||
id: "tool-calling",
|
||||
name: "Tool Calling",
|
||||
description: "Define a function tool and verify function_call output",
|
||||
getRequest: (config) => ({
|
||||
model: config.model,
|
||||
input: [
|
||||
{
|
||||
type: "message",
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "input_text",
|
||||
text: "What's the weather like in San Francisco?",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
tools: [
|
||||
{
|
||||
type: "function",
|
||||
name: "get_weather",
|
||||
description: "Get the current weather for a location",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
validators: [hasOutput, hasOutputType("function_call")],
|
||||
},
|
||||
|
||||
{
|
||||
id: "image-input",
|
||||
name: "Image Input",
|
||||
description: "Send image URL in user content",
|
||||
getRequest: (config) => ({
|
||||
model: config.model,
|
||||
input: [
|
||||
{
|
||||
type: "message",
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "input_text",
|
||||
text: "What do you see in this image? Answer in one sentence.",
|
||||
},
|
||||
{
|
||||
type: "input_image",
|
||||
image_url:
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAABmklEQVR42tyWAaTyUBzFew/eG4AHz+MBSAHKBiJRGFKwIgQQJKLUIioBIhCAiCAAEizAQIAECaASqFFJq84nudjnaqvuPnxzgP9xfrq5938csPn7PwHTKSoViCIEAYEAMhmoKsU2mUCWEQqB5xEMIp/HaGQG2G6RSuH9HQ7H34rFrtPbdz4jl6PbwmEsl3QA1mt4vcRKk8dz9eg6IpF7tt9fzGY0gCgafFRFo5Blc5vLhf3eCOj1yNhM5GRMVK0aATxPZoz09YXjkQDmczJgquGQAPp9WwCNBgG027YACgUC6HRsAZRKBDAY2AJoNv/ZnwzA6WScznG3p4UAymXGAEkyXrTFAh8fLAGqagQAyGaZpYsi7bHTNPz8MEj//LxuFPo+UBS8vb0KaLXubrRa7aX0RMLCykwmn0z3+XA4WACcTpCkh9MFAZpmuVXo+mO/w+/HZvNgbblcUCxaSo/Hyck80Yu6XXDcvfVZr79cvMZjuN2U9O9vKAqjZrfbIZ0mV4TUi9Xqz6jddNy//7+e3n8Fhf/Llo2kxi8AQyGRoDkmAhAAAAAASUVORK5CYII=",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
}),
|
||||
validators: [hasOutput, completedStatus],
|
||||
},
|
||||
|
||||
{
|
||||
id: "multi-turn",
|
||||
name: "Multi-turn Conversation",
|
||||
description: "Send assistant + user messages as conversation history",
|
||||
getRequest: (config) => ({
|
||||
model: config.model,
|
||||
input: [
|
||||
{
|
||||
type: "message",
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: "My name is Alice." }],
|
||||
},
|
||||
{
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "output_text",
|
||||
text: "Hello Alice! Nice to meet you. How can I help you today?",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
type: "message",
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: "What is my name?" }],
|
||||
},
|
||||
],
|
||||
}),
|
||||
validators: [hasOutput, completedStatus],
|
||||
},
|
||||
];
|
||||
|
||||
// ============================================================
|
||||
// Test Runner
|
||||
// ============================================================
|
||||
|
||||
async function makeRequest(
|
||||
config: TestConfig,
|
||||
body: Record<string, unknown>,
|
||||
streaming = false,
|
||||
): Promise<Response> {
|
||||
const headers: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
};
|
||||
|
||||
if (config.apiKey) {
|
||||
const authValue = config.useBearerPrefix
|
||||
? `Bearer ${config.apiKey}`
|
||||
: config.apiKey;
|
||||
headers[config.authHeaderName] = authValue;
|
||||
}
|
||||
|
||||
return fetch(`${config.baseUrl}/v1/responses`, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify({ ...body, stream: streaming }),
|
||||
});
|
||||
}
|
||||
|
||||
async function runTest(
|
||||
template: TestTemplate,
|
||||
config: TestConfig,
|
||||
): Promise<TestResult> {
|
||||
const startTime = Date.now();
|
||||
const requestBody = template.getRequest(config);
|
||||
const streaming = template.streaming ?? false;
|
||||
|
||||
try {
|
||||
const response = await makeRequest(config, requestBody, streaming);
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
return {
|
||||
id: template.id,
|
||||
name: template.name,
|
||||
description: template.description,
|
||||
status: "failed",
|
||||
duration,
|
||||
request: requestBody,
|
||||
response: errorText,
|
||||
errors: [`HTTP ${response.status}: ${errorText}`],
|
||||
};
|
||||
}
|
||||
|
||||
let rawData: unknown;
|
||||
let sseResult: SSEParseResult | undefined;
|
||||
|
||||
if (streaming) {
|
||||
sseResult = await parseSSEStream(response);
|
||||
rawData = sseResult.finalResponse;
|
||||
} else {
|
||||
rawData = await response.json();
|
||||
}
|
||||
|
||||
// Schema validation with Zod
|
||||
const parseResult = responseResourceSchema.safeParse(rawData);
|
||||
if (!parseResult.success) {
|
||||
return {
|
||||
id: template.id,
|
||||
name: template.name,
|
||||
description: template.description,
|
||||
status: "failed",
|
||||
duration,
|
||||
request: streaming ? { ...requestBody, stream: true } : requestBody,
|
||||
response: rawData,
|
||||
errors: parseResult.error.issues.map(
|
||||
(issue) => `${issue.path.join(".")}: ${issue.message}`,
|
||||
),
|
||||
streamEvents: sseResult?.events.length,
|
||||
};
|
||||
}
|
||||
|
||||
// Semantic validators
|
||||
const context: ValidatorContext = { streaming, sseResult };
|
||||
const errors = template.validators.flatMap((v) =>
|
||||
v(parseResult.data, context),
|
||||
);
|
||||
|
||||
return {
|
||||
id: template.id,
|
||||
name: template.name,
|
||||
description: template.description,
|
||||
status: errors.length === 0 ? "passed" : "failed",
|
||||
duration,
|
||||
request: streaming ? { ...requestBody, stream: true } : requestBody,
|
||||
response: parseResult.data,
|
||||
errors,
|
||||
streamEvents: sseResult?.events.length,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
id: template.id,
|
||||
name: template.name,
|
||||
description: template.description,
|
||||
status: "failed",
|
||||
duration: Date.now() - startTime,
|
||||
request: requestBody,
|
||||
errors: [error instanceof Error ? error.message : String(error)],
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export async function runAllTests(
|
||||
config: TestConfig,
|
||||
onProgress: (result: TestResult) => void,
|
||||
): Promise<TestResult[]> {
|
||||
const promises = testTemplates.map(async (template) => {
|
||||
onProgress({
|
||||
id: template.id,
|
||||
name: template.name,
|
||||
description: template.description,
|
||||
status: "running",
|
||||
});
|
||||
|
||||
const result = await runTest(template, config);
|
||||
onProgress(result);
|
||||
return result;
|
||||
});
|
||||
|
||||
return Promise.all(promises);
|
||||
}
|
||||
253
tests/src/schemas.ts
Normal file
253
tests/src/schemas.ts
Normal file
@@ -0,0 +1,253 @@
|
||||
import { z } from "zod";
|
||||
|
||||
// ============================================================
|
||||
// Content Parts
|
||||
// ============================================================
|
||||
|
||||
const outputTextContentSchema = z.object({
|
||||
type: z.literal("output_text"),
|
||||
text: z.string(),
|
||||
annotations: z.array(z.object({
|
||||
type: z.string(),
|
||||
})),
|
||||
});
|
||||
|
||||
const inputTextContentSchema = z.object({
|
||||
type: z.literal("input_text"),
|
||||
text: z.string(),
|
||||
});
|
||||
|
||||
const refusalContentSchema = z.object({
|
||||
type: z.literal("refusal"),
|
||||
refusal: z.string(),
|
||||
});
|
||||
|
||||
const contentPartSchema = z.discriminatedUnion("type", [
|
||||
outputTextContentSchema,
|
||||
inputTextContentSchema,
|
||||
refusalContentSchema,
|
||||
]);
|
||||
|
||||
// ============================================================
|
||||
// Output Items
|
||||
// ============================================================
|
||||
|
||||
const messageOutputItemSchema = z.object({
|
||||
type: z.literal("message"),
|
||||
id: z.string(),
|
||||
status: z.enum(["in_progress", "completed", "incomplete"]),
|
||||
role: z.enum(["user", "assistant", "system", "developer"]),
|
||||
content: z.array(contentPartSchema),
|
||||
});
|
||||
|
||||
const functionCallOutputItemSchema = z.object({
|
||||
type: z.literal("function_call"),
|
||||
id: z.string(),
|
||||
call_id: z.string(),
|
||||
name: z.string(),
|
||||
arguments: z.string(),
|
||||
status: z.enum(["in_progress", "completed", "incomplete"]),
|
||||
});
|
||||
|
||||
const outputItemSchema = z.discriminatedUnion("type", [
|
||||
messageOutputItemSchema,
|
||||
functionCallOutputItemSchema,
|
||||
]);
|
||||
|
||||
// ============================================================
|
||||
// Usage
|
||||
// ============================================================
|
||||
|
||||
const usageSchema = z.object({
|
||||
input_tokens: z.number().int(),
|
||||
output_tokens: z.number().int(),
|
||||
total_tokens: z.number().int(),
|
||||
input_tokens_details: z.object({
|
||||
cached_tokens: z.number().int(),
|
||||
}),
|
||||
output_tokens_details: z.object({
|
||||
reasoning_tokens: z.number().int(),
|
||||
}),
|
||||
});
|
||||
|
||||
// ============================================================
|
||||
// ResponseResource
|
||||
// ============================================================
|
||||
|
||||
export const responseResourceSchema = z.object({
|
||||
id: z.string(),
|
||||
object: z.literal("response"),
|
||||
created_at: z.number().int(),
|
||||
completed_at: z.number().int().nullable(),
|
||||
status: z.string(),
|
||||
incomplete_details: z.object({ reason: z.string() }).nullable(),
|
||||
model: z.string(),
|
||||
previous_response_id: z.string().nullable(),
|
||||
instructions: z.string().nullable(),
|
||||
output: z.array(outputItemSchema),
|
||||
error: z.object({ type: z.string(), message: z.string() }).nullable(),
|
||||
tools: z.any(),
|
||||
tool_choice: z.any(),
|
||||
truncation: z.string(),
|
||||
parallel_tool_calls: z.boolean(),
|
||||
text: z.any(),
|
||||
top_p: z.number(),
|
||||
presence_penalty: z.number(),
|
||||
frequency_penalty: z.number(),
|
||||
top_logprobs: z.number().int(),
|
||||
temperature: z.number(),
|
||||
reasoning: z.any().nullable(),
|
||||
usage: usageSchema.nullable(),
|
||||
max_output_tokens: z.number().int().nullable(),
|
||||
max_tool_calls: z.number().int().nullable(),
|
||||
store: z.boolean(),
|
||||
background: z.boolean(),
|
||||
service_tier: z.string(),
|
||||
metadata: z.any(),
|
||||
safety_identifier: z.string().nullable(),
|
||||
prompt_cache_key: z.string().nullable(),
|
||||
});
|
||||
|
||||
export type ResponseResource = z.infer<typeof responseResourceSchema>;
|
||||
|
||||
// ============================================================
|
||||
// Streaming Event Schemas
|
||||
// ============================================================
|
||||
|
||||
const responseCreatedEventSchema = z.object({
|
||||
type: z.literal("response.created"),
|
||||
sequence_number: z.number().int(),
|
||||
response: responseResourceSchema,
|
||||
});
|
||||
|
||||
const responseInProgressEventSchema = z.object({
|
||||
type: z.literal("response.in_progress"),
|
||||
sequence_number: z.number().int(),
|
||||
response: responseResourceSchema,
|
||||
});
|
||||
|
||||
const responseCompletedEventSchema = z.object({
|
||||
type: z.literal("response.completed"),
|
||||
sequence_number: z.number().int(),
|
||||
response: responseResourceSchema,
|
||||
});
|
||||
|
||||
const responseFailedEventSchema = z.object({
|
||||
type: z.literal("response.failed"),
|
||||
sequence_number: z.number().int(),
|
||||
response: responseResourceSchema,
|
||||
});
|
||||
|
||||
const outputItemAddedEventSchema = z.object({
|
||||
type: z.literal("response.output_item.added"),
|
||||
sequence_number: z.number().int(),
|
||||
output_index: z.number().int(),
|
||||
item: z.object({
|
||||
id: z.string(),
|
||||
type: z.string(),
|
||||
status: z.string(),
|
||||
role: z.string().optional(),
|
||||
content: z.array(z.any()).optional(),
|
||||
}),
|
||||
});
|
||||
|
||||
const outputItemDoneEventSchema = z.object({
|
||||
type: z.literal("response.output_item.done"),
|
||||
sequence_number: z.number().int(),
|
||||
output_index: z.number().int(),
|
||||
item: z.object({
|
||||
id: z.string(),
|
||||
type: z.string(),
|
||||
status: z.string(),
|
||||
role: z.string().optional(),
|
||||
content: z.array(z.any()).optional(),
|
||||
}),
|
||||
});
|
||||
|
||||
const contentPartAddedEventSchema = z.object({
|
||||
type: z.literal("response.content_part.added"),
|
||||
sequence_number: z.number().int(),
|
||||
item_id: z.string(),
|
||||
output_index: z.number().int(),
|
||||
content_index: z.number().int(),
|
||||
part: z.object({
|
||||
type: z.string(),
|
||||
text: z.string().optional(),
|
||||
annotations: z.array(z.any()).optional(),
|
||||
}),
|
||||
});
|
||||
|
||||
const contentPartDoneEventSchema = z.object({
|
||||
type: z.literal("response.content_part.done"),
|
||||
sequence_number: z.number().int(),
|
||||
item_id: z.string(),
|
||||
output_index: z.number().int(),
|
||||
content_index: z.number().int(),
|
||||
part: z.object({
|
||||
type: z.string(),
|
||||
text: z.string().optional(),
|
||||
annotations: z.array(z.any()).optional(),
|
||||
}),
|
||||
});
|
||||
|
||||
const outputTextDeltaEventSchema = z.object({
|
||||
type: z.literal("response.output_text.delta"),
|
||||
sequence_number: z.number().int(),
|
||||
item_id: z.string(),
|
||||
output_index: z.number().int(),
|
||||
content_index: z.number().int(),
|
||||
delta: z.string(),
|
||||
});
|
||||
|
||||
const outputTextDoneEventSchema = z.object({
|
||||
type: z.literal("response.output_text.done"),
|
||||
sequence_number: z.number().int(),
|
||||
item_id: z.string(),
|
||||
output_index: z.number().int(),
|
||||
content_index: z.number().int(),
|
||||
text: z.string(),
|
||||
});
|
||||
|
||||
const functionCallArgsDeltaEventSchema = z.object({
|
||||
type: z.literal("response.function_call_arguments.delta"),
|
||||
sequence_number: z.number().int(),
|
||||
item_id: z.string(),
|
||||
output_index: z.number().int(),
|
||||
delta: z.string(),
|
||||
});
|
||||
|
||||
const functionCallArgsDoneEventSchema = z.object({
|
||||
type: z.literal("response.function_call_arguments.done"),
|
||||
sequence_number: z.number().int(),
|
||||
item_id: z.string(),
|
||||
output_index: z.number().int(),
|
||||
arguments: z.string(),
|
||||
});
|
||||
|
||||
const errorEventSchema = z.object({
|
||||
type: z.literal("error"),
|
||||
sequence_number: z.number().int(),
|
||||
error: z.object({
|
||||
type: z.string(),
|
||||
message: z.string(),
|
||||
code: z.string().nullable().optional(),
|
||||
}),
|
||||
});
|
||||
|
||||
export const streamingEventSchema = z.discriminatedUnion("type", [
|
||||
responseCreatedEventSchema,
|
||||
responseInProgressEventSchema,
|
||||
responseCompletedEventSchema,
|
||||
responseFailedEventSchema,
|
||||
outputItemAddedEventSchema,
|
||||
outputItemDoneEventSchema,
|
||||
contentPartAddedEventSchema,
|
||||
contentPartDoneEventSchema,
|
||||
outputTextDeltaEventSchema,
|
||||
outputTextDoneEventSchema,
|
||||
functionCallArgsDeltaEventSchema,
|
||||
functionCallArgsDoneEventSchema,
|
||||
errorEventSchema,
|
||||
]);
|
||||
|
||||
export type StreamingEvent = z.infer<typeof streamingEventSchema>;
|
||||
92
tests/src/sse-parser.ts
Normal file
92
tests/src/sse-parser.ts
Normal file
@@ -0,0 +1,92 @@
|
||||
import type { z } from "zod";
|
||||
import {
|
||||
streamingEventSchema,
|
||||
type StreamingEvent,
|
||||
type ResponseResource,
|
||||
} from "./schemas.ts";
|
||||
|
||||
/** A single server-sent event as decoded from the response stream. */
export interface ParsedEvent {
  // SSE `event:` field; falls back to the payload's `type`, then "unknown".
  event: string;
  // The raw JSON-parsed `data:` payload.
  data: unknown;
  // Outcome of validating `data` against streamingEventSchema.
  validationResult: z.SafeParseReturnType<unknown, StreamingEvent>;
}
|
||||
|
||||
/** Aggregate result of consuming an entire SSE response stream. */
export interface SSEParseResult {
  // Every event whose data was successfully JSON-parsed, in arrival order.
  events: ParsedEvent[];
  // Human-readable JSON-parse and schema-validation failures.
  errors: string[];
  // The `response` payload carried by a response.completed /
  // response.failed event, or null if the stream ended without one.
  finalResponse: ResponseResource | null;
}
|
||||
|
||||
export async function parseSSEStream(
|
||||
response: Response,
|
||||
): Promise<SSEParseResult> {
|
||||
const events: ParsedEvent[] = [];
|
||||
const errors: string[] = [];
|
||||
let finalResponse: ResponseResource | null = null;
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
return { events, errors: ["No response body"], finalResponse };
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
let currentEvent = "";
|
||||
let currentData = "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("event:")) {
|
||||
currentEvent = line.slice(6).trim();
|
||||
} else if (line.startsWith("data:")) {
|
||||
currentData = line.slice(5).trim();
|
||||
} else if (line === "" && currentData) {
|
||||
if (currentData === "[DONE]") {
|
||||
// Skip sentinel
|
||||
} else {
|
||||
try {
|
||||
const parsed = JSON.parse(currentData);
|
||||
const validationResult = streamingEventSchema.safeParse(parsed);
|
||||
|
||||
events.push({
|
||||
event: currentEvent || parsed.type || "unknown",
|
||||
data: parsed,
|
||||
validationResult,
|
||||
});
|
||||
|
||||
if (!validationResult.success) {
|
||||
errors.push(
|
||||
`Event validation failed for ${parsed.type || "unknown"}: ${JSON.stringify(validationResult.error.issues)}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
parsed.type === "response.completed" ||
|
||||
parsed.type === "response.failed"
|
||||
) {
|
||||
finalResponse = parsed.response;
|
||||
}
|
||||
} catch {
|
||||
errors.push(`Failed to parse event data: ${currentData}`);
|
||||
}
|
||||
}
|
||||
currentEvent = "";
|
||||
currentData = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
|
||||
return { events, errors, finalResponse };
|
||||
}
|
||||
14
tests/tsconfig.json
Normal file
14
tests/tsconfig.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "NodeNext",
|
||||
"moduleResolution": "NodeNext",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"outDir": "dist",
|
||||
"rootDir": ".",
|
||||
"declaration": true
|
||||
},
|
||||
"include": ["src/**/*.ts", "bin/**/*.ts"]
|
||||
}
|
||||
Reference in New Issue
Block a user