Compare commits


10 Commits

Author SHA1 Message Date
6adf7eae54 Add Google tool calling 2026-03-02 17:12:15 +00:00
38d44f104a Add Vertex AI support 2026-03-02 16:52:57 +00:00
2188e3cba8 Add Anthropic tool calling support 2026-03-02 16:08:39 +00:00
830a87afa1 Improve Stores 2026-03-02 16:06:38 +00:00
259d02d140 Add Redis Store 2026-03-02 15:55:03 +00:00
09d687b45b Migrate to OpenAI v3 2026-03-02 15:36:56 +00:00
157680bb13 Add OpenAI tool calling support 2026-03-02 15:36:56 +00:00
8ceb831e84 Rebrand project 2026-03-02 14:32:10 +00:00
f79af84afb Add Open Responses compliance tests 2026-03-02 13:58:25 +00:00
cf47ad444a Update chat script to use openai lib 2026-03-02 13:40:27 +00:00
28 changed files with 2662 additions and 297 deletions

.gitignore vendored
View File

@@ -53,3 +53,6 @@ logs/
# Python scripts
__pycache__/*
# Node.js (compliance tests)
tests/node_modules/

View File

@@ -1,4 +1,4 @@
# Go LLM Gateway
# latticelm
## Overview
@@ -11,6 +11,7 @@ Simplify LLM integration by exposing a single, consistent API that routes reques
- **Azure OpenAI** (Azure-deployed models)
- **Anthropic** (Claude)
- **Google Generative AI** (Gemini)
- **Vertex AI** (Google Cloud-hosted Gemini models)
Instead of managing multiple SDK integrations in your application, call one endpoint and let the gateway handle provider-specific implementations.
@@ -19,12 +20,13 @@ Instead of managing multiple SDK integrations in your application, call one endp
```
Client Request
Go LLM Gateway (unified API)
latticelm (unified API)
├─→ OpenAI SDK
├─→ Azure OpenAI (OpenAI SDK + Azure auth)
├─→ Anthropic SDK
└─→ Google Gen AI SDK
├─→ Google Gen AI SDK
└─→ Vertex AI (Google Gen AI SDK + GCP auth)
```
## Key Features
@@ -45,11 +47,12 @@ Go LLM Gateway (unified API)
## 🎉 Status: **WORKING!**
**All four providers integrated with official Go SDKs:**
- OpenAI → `github.com/openai/openai-go`
- Azure OpenAI → `github.com/openai/openai-go` (with Azure auth)
**All providers integrated with official Go SDKs:**
- OpenAI → `github.com/openai/openai-go/v3`
- Azure OpenAI → `github.com/openai/openai-go/v3` (with Azure auth)
- Anthropic → `github.com/anthropics/anthropic-sdk-go`
- Google → `google.golang.org/genai`
- Vertex AI → `google.golang.org/genai` (with GCP auth)
**Compiles successfully** (36MB binary)
**Provider auto-selection** (gpt→Azure/OpenAI, claude→Anthropic, gemini→Google)
@@ -68,7 +71,7 @@ export ANTHROPIC_API_KEY="your-key"
export GOOGLE_API_KEY="your-key"
# 2. Build
cd go-llm-gateway
cd latticelm
go build -o gateway ./cmd/gateway
# 3. Run

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"database/sql"
"flag"
"fmt"
@@ -12,12 +13,13 @@ import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/jackc/pgx/v5/stdlib"
_ "github.com/mattn/go-sqlite3"
"github.com/redis/go-redis/v9"
"github.com/yourusername/go-llm-gateway/internal/auth"
"github.com/yourusername/go-llm-gateway/internal/config"
"github.com/yourusername/go-llm-gateway/internal/conversation"
"github.com/yourusername/go-llm-gateway/internal/providers"
"github.com/yourusername/go-llm-gateway/internal/server"
"github.com/ajac-zero/latticelm/internal/auth"
"github.com/ajac-zero/latticelm/internal/config"
"github.com/ajac-zero/latticelm/internal/conversation"
"github.com/ajac-zero/latticelm/internal/providers"
"github.com/ajac-zero/latticelm/internal/server"
)
func main() {
@@ -112,6 +114,22 @@ func initConversationStore(cfg config.ConversationConfig, logger *log.Logger) (c
}
logger.Printf("Conversation store initialized (sql/%s, TTL: %s)", driver, ttl)
return store, nil
case "redis":
opts, err := redis.ParseURL(cfg.DSN)
if err != nil {
return nil, fmt.Errorf("parse redis dsn: %w", err)
}
client := redis.NewClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("connect to redis: %w", err)
}
logger.Printf("Conversation store initialized (redis, TTL: %s)", ttl)
return conversation.NewRedisStore(client, ttl), nil
default:
logger.Printf("Conversation store initialized (memory, TTL: %s)", ttl)
return conversation.NewMemoryStore(ttl), nil

View File

@@ -14,6 +14,12 @@ providers:
type: "openai"
api_key: "YOUR_OPENAI_API_KEY"
endpoint: "https://api.openai.com"
# Vertex AI (Google Cloud) - optional
# Uses Application Default Credentials (ADC) or service account
# vertexai:
# type: "vertexai"
# project: "your-gcp-project-id"
# location: "us-central1" # or other GCP region
# Azure OpenAI - optional
# azureopenai:
# type: "azureopenai"
@@ -27,16 +33,19 @@ providers:
# endpoint: "https://your-resource.services.ai.azure.com/anthropic"
# conversations:
# store: "sql" # "memory" (default) or "sql"
# store: "sql" # "memory" (default), "sql", or "redis"
# ttl: "1h" # conversation expiration (default: 1h)
# driver: "sqlite3" # SQL driver: "sqlite3", "mysql", "pgx"
# dsn: "conversations.db" # connection string
# driver: "sqlite3" # SQL driver: "sqlite3", "mysql", "pgx" (required for sql store)
# dsn: "conversations.db" # connection string (required for sql/redis store)
# # MySQL example:
# # driver: "mysql"
# # dsn: "user:password@tcp(localhost:3306)/dbname?parseTime=true"
# # PostgreSQL example:
# # driver: "pgx"
# # dsn: "postgres://user:password@localhost:5432/dbname?sslmode=disable"
# # Redis example:
# # store: "redis"
# # dsn: "redis://:password@localhost:6379/0"
models:
- name: "gemini-1.5-flash"
@@ -45,6 +54,8 @@ models:
provider: "anthropic"
- name: "gpt-4o-mini"
provider: "openai"
# - name: "gemini-2.0-flash-exp"
# provider: "vertexai" # Use Vertex AI instead of Google AI API
# - name: "gpt-4o"
# provider: "azureopenai"
# provider_model_id: "my-gpt4o-deployment" # optional: defaults to name

go.mod
View File

@@ -1,11 +1,17 @@
module github.com/yourusername/go-llm-gateway
module github.com/ajac-zero/latticelm
go 1.25.7
require (
github.com/anthropics/anthropic-sdk-go v1.26.0
github.com/go-sql-driver/mysql v1.9.3
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.8.0
github.com/mattn/go-sqlite3 v1.14.34
github.com/openai/openai-go v1.12.0
github.com/openai/openai-go/v3 v3.2.0
github.com/redis/go-redis/v9 v9.18.0
google.golang.org/genai v1.48.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -15,11 +21,10 @@ require (
cloud.google.com/go/auth v0.9.3 // indirect
cloud.google.com/go/compute/metadata v0.5.0 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.9.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
@@ -27,15 +32,13 @@ require (
github.com/gorilla/websocket v1.5.3 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.8.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/mattn/go-sqlite3 v1.14.34 // indirect
github.com/openai/openai-go/v3 v3.2.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
go.opencensus.io v0.24.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/crypto v0.47.0 // indirect
golang.org/x/net v0.49.0 // indirect
golang.org/x/sync v0.19.0 // indirect

go.sum
View File

@@ -7,21 +7,31 @@ cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJ
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.9.0 h1:t/DLMixbb8ygU11RAHJ8quXwJD7FwlC7+u6XodmSi1w=
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.9.0/go.mod h1:Bb4vy1c7tXIqFrypNxCO7I5xlDSbpQiOWu/XvF5htP8=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -71,15 +81,29 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/openai/openai-go/v3 v3.2.0 h1:2AbqFUCsoW2pm/2pUtPRuwK89dnoGHaQokzWsfoQO/U=
github.com/openai/openai-go/v3 v3.2.0/go.mod h1:UOpNxkqC9OdNXNUfpNByKOtB4jAL0EssQXq5p8gO0Xs=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -88,9 +112,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -101,12 +124,14 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -119,30 +144,22 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -178,8 +195,9 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -96,6 +96,7 @@ type InputItem struct {
type Message struct {
Role string `json:"role"`
Content []ContentBlock `json:"content"`
CallID string `json:"call_id,omitempty"` // for tool messages
}
// ContentBlock is a typed content element.
@@ -138,6 +139,7 @@ func (r *ResponseRequest) NormalizeInput() []Message {
msgs = append(msgs, Message{
Role: "tool",
Content: []ContentBlock{{Type: "input_text", Text: item.Output}},
CallID: item.CallID,
})
}
}
@@ -193,6 +195,9 @@ type OutputItem struct {
Status string `json:"status"`
Role string `json:"role,omitempty"`
Content []ContentPart `json:"content,omitempty"`
CallID string `json:"call_id,omitempty"` // for function_call
Name string `json:"name,omitempty"` // for function_call
Arguments string `json:"arguments,omitempty"` // for function_call
}
// ContentPart is a content block within an output item.
@@ -259,6 +264,7 @@ type StreamEvent struct {
Part *ContentPart `json:"part,omitempty"`
Delta string `json:"delta,omitempty"`
Text string `json:"text,omitempty"`
Arguments string `json:"arguments,omitempty"` // for function_call_arguments.done
}
// ============================================================
@@ -271,6 +277,7 @@ type ProviderResult struct {
Model string
Text string
Usage Usage
ToolCalls []ToolCall
}
// ProviderStreamDelta is sent through the stream channel.
@@ -280,6 +287,22 @@ type ProviderStreamDelta struct {
Text string
Done bool
Usage *Usage
ToolCallDelta *ToolCallDelta
}
// ToolCall represents a function call from the model.
type ToolCall struct {
ID string
Name string
Arguments string // JSON string
}
// ToolCallDelta represents a streaming chunk of a tool call.
type ToolCallDelta struct {
Index int
ID string
Name string
Arguments string
}
// ============================================================
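
A brief, illustrative sketch of how these streaming types compose: a consumer can merge `ToolCallDelta` chunks (keyed by `Index`) into complete `ToolCall` values. The `accumulateToolCalls` helper and the sample deltas below are hypothetical, not part of this change.

```go
package main

import (
	"fmt"

	"github.com/ajac-zero/latticelm/internal/api"
)

// accumulateToolCalls merges streamed ToolCallDelta chunks into complete
// ToolCall values. ID and Name typically arrive on the first chunk for an
// index; the JSON arguments arrive in fragments and are concatenated.
func accumulateToolCalls(deltas []api.ToolCallDelta) []api.ToolCall {
	byIndex := map[int]*api.ToolCall{}
	var order []int
	for _, d := range deltas {
		call, ok := byIndex[d.Index]
		if !ok {
			call = &api.ToolCall{}
			byIndex[d.Index] = call
			order = append(order, d.Index)
		}
		if d.ID != "" {
			call.ID = d.ID
		}
		if d.Name != "" {
			call.Name = d.Name
		}
		call.Arguments += d.Arguments
	}
	calls := make([]api.ToolCall, 0, len(order))
	for _, i := range order {
		calls = append(calls, *byIndex[i])
	}
	return calls
}

func main() {
	deltas := []api.ToolCallDelta{
		{Index: 0, ID: "call_abc", Name: "get_weather"},
		{Index: 0, Arguments: `{"location":`},
		{Index: 0, Arguments: `"Paris"}`},
	}
	fmt.Printf("%+v\n", accumulateToolCalls(deltas))
}
```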

View File

@@ -18,12 +18,12 @@ type Config struct {
// ConversationConfig controls conversation storage.
type ConversationConfig struct {
// Store is the storage backend: "memory" (default) or "sql".
// Store is the storage backend: "memory" (default), "sql", or "redis".
Store string `yaml:"store"`
// TTL is the conversation expiration duration (e.g. "1h", "30m"). Defaults to "1h".
TTL string `yaml:"ttl"`
// DSN is the database connection string, required when store is "sql".
// Examples: "conversations.db" (SQLite), "postgres://user:pass@host/db".
// DSN is the database/Redis connection string, required when store is "sql" or "redis".
// Examples: "conversations.db" (SQLite), "postgres://user:pass@host/db", "redis://:password@localhost:6379/0".
DSN string `yaml:"dsn"`
// Driver is the SQL driver name, required when store is "sql".
// Examples: "sqlite3", "postgres", "mysql".
@@ -48,6 +48,8 @@ type ProviderEntry struct {
APIKey string `yaml:"api_key"`
Endpoint string `yaml:"endpoint"`
APIVersion string `yaml:"api_version"`
Project string `yaml:"project"` // For Vertex AI
Location string `yaml:"location"` // For Vertex AI
}
// ModelEntry maps a model name to a provider entry.
@@ -78,6 +80,12 @@ type AzureAnthropicConfig struct {
Model string `yaml:"model"`
}
// VertexAIConfig contains Vertex AI-specific settings used internally by the Google provider.
type VertexAIConfig struct {
Project string `yaml:"project"`
Location string `yaml:"location"`
}
// Load reads and parses a YAML configuration file, expanding ${VAR} env references.
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)

View File

@@ -4,15 +4,15 @@ import (
"sync"
"time"
"github.com/yourusername/go-llm-gateway/internal/api"
"github.com/ajac-zero/latticelm/internal/api"
)
// Store defines the interface for conversation storage backends.
type Store interface {
Get(id string) (*Conversation, bool)
Create(id string, model string, messages []api.Message) *Conversation
Append(id string, messages ...api.Message) (*Conversation, bool)
Delete(id string)
Get(id string) (*Conversation, error)
Create(id string, model string, messages []api.Message) (*Conversation, error)
Append(id string, messages ...api.Message) (*Conversation, error)
Delete(id string) error
Size() int
}
@@ -47,55 +47,93 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore {
return s
}
// Get retrieves a conversation by ID.
func (s *MemoryStore) Get(id string) (*Conversation, bool) {
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
func (s *MemoryStore) Get(id string) (*Conversation, error) {
s.mu.RLock()
defer s.mu.RUnlock()
conv, ok := s.conversations[id]
return conv, ok
if !ok {
return nil, nil
}
// Return a deep copy to prevent data races
msgsCopy := make([]api.Message, len(conv.Messages))
copy(msgsCopy, conv.Messages)
return &Conversation{
ID: conv.ID,
Messages: msgsCopy,
Model: conv.Model,
CreatedAt: conv.CreatedAt,
UpdatedAt: conv.UpdatedAt,
}, nil
}
// Create creates a new conversation with the given messages.
func (s *MemoryStore) Create(id string, model string, messages []api.Message) *Conversation {
func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
// Store a copy to prevent external modifications
msgsCopy := make([]api.Message, len(messages))
copy(msgsCopy, messages)
conv := &Conversation{
ID: id,
Messages: messages,
Messages: msgsCopy,
Model: model,
CreatedAt: now,
UpdatedAt: now,
}
s.conversations[id] = conv
return conv
// Return a copy
return &Conversation{
ID: id,
Messages: messages,
Model: model,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// Append adds new messages to an existing conversation.
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) {
s.mu.Lock()
defer s.mu.Unlock()
conv, ok := s.conversations[id]
if !ok {
return nil, false
return nil, nil
}
conv.Messages = append(conv.Messages, messages...)
conv.UpdatedAt = time.Now()
return conv, true
// Return a deep copy
msgsCopy := make([]api.Message, len(conv.Messages))
copy(msgsCopy, conv.Messages)
return &Conversation{
ID: conv.ID,
Messages: msgsCopy,
Model: conv.Model,
CreatedAt: conv.CreatedAt,
UpdatedAt: conv.UpdatedAt,
}, nil
}
// Delete removes a conversation from the store.
func (s *MemoryStore) Delete(id string) {
func (s *MemoryStore) Delete(id string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.conversations, id)
return nil
}
// cleanup periodically removes expired conversations.

View File

@@ -0,0 +1,124 @@
package conversation
import (
"context"
"encoding/json"
"time"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/redis/go-redis/v9"
)
// RedisStore manages conversation history in Redis with automatic expiration.
type RedisStore struct {
client *redis.Client
ttl time.Duration
ctx context.Context
}
// NewRedisStore creates a Redis-backed conversation store.
func NewRedisStore(client *redis.Client, ttl time.Duration) *RedisStore {
return &RedisStore{
client: client,
ttl: ttl,
ctx: context.Background(),
}
}
// key returns the Redis key for a conversation ID.
func (s *RedisStore) key(id string) string {
return "conv:" + id
}
// Get retrieves a conversation by ID from Redis.
func (s *RedisStore) Get(id string) (*Conversation, error) {
data, err := s.client.Get(s.ctx, s.key(id)).Bytes()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, err
}
var conv Conversation
if err := json.Unmarshal(data, &conv); err != nil {
return nil, err
}
return &conv, nil
}
// Create creates a new conversation with the given messages.
func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
now := time.Now()
conv := &Conversation{
ID: id,
Messages: messages,
Model: model,
CreatedAt: now,
UpdatedAt: now,
}
data, err := json.Marshal(conv)
if err != nil {
return nil, err
}
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
return nil, err
}
return conv, nil
}
// Append adds new messages to an existing conversation.
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) {
conv, err := s.Get(id)
if err != nil {
return nil, err
}
if conv == nil {
return nil, nil
}
conv.Messages = append(conv.Messages, messages...)
conv.UpdatedAt = time.Now()
data, err := json.Marshal(conv)
if err != nil {
return nil, err
}
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
return nil, err
}
return conv, nil
}
// Delete removes a conversation from Redis.
func (s *RedisStore) Delete(id string) error {
return s.client.Del(s.ctx, s.key(id)).Err()
}
// Size returns the number of active conversations in Redis.
func (s *RedisStore) Size() int {
var count int
var cursor uint64
for {
keys, nextCursor, err := s.client.Scan(s.ctx, cursor, "conv:*", 100).Result()
if err != nil {
return 0
}
count += len(keys)
cursor = nextCursor
if cursor == 0 {
break
}
}
return count
}
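
For reference, a minimal usage sketch of the reworked `Store` interface with the Redis backend; the DSN, conversation ID, and model name below are placeholders.

```go
package main

import (
	"fmt"
	"log"
	"time"

	"github.com/redis/go-redis/v9"

	"github.com/ajac-zero/latticelm/internal/api"
	"github.com/ajac-zero/latticelm/internal/conversation"
)

func main() {
	// Connect the same way initConversationStore does, from a Redis DSN.
	opts, err := redis.ParseURL("redis://:password@localhost:6379/0")
	if err != nil {
		log.Fatal(err)
	}
	store := conversation.NewRedisStore(redis.NewClient(opts), time.Hour)

	// Every Store method now returns an error instead of a bool.
	msg := api.Message{
		Role:    "user",
		Content: []api.ContentBlock{{Type: "input_text", Text: "hello"}},
	}
	if _, err := store.Create("conv_123", "gpt-4o-mini", []api.Message{msg}); err != nil {
		log.Fatal(err)
	}

	conv, err := store.Get("conv_123")
	if err != nil {
		log.Fatal(err)
	}
	if conv == nil {
		log.Fatal("conversation missing or expired") // (nil, nil) means not found
	}
	fmt.Println(len(conv.Messages), store.Size())
}
```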

View File

@@ -5,7 +5,7 @@ import (
"encoding/json"
"time"
"github.com/yourusername/go-llm-gateway/internal/api"
"github.com/ajac-zero/latticelm/internal/api"
)
// sqlDialect holds driver-specific SQL statements.
@@ -65,28 +65,36 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error
return s, nil
}
func (s *SQLStore) Get(id string) (*Conversation, bool) {
func (s *SQLStore) Get(id string) (*Conversation, error) {
row := s.db.QueryRow(s.dialect.getByID, id)
var conv Conversation
var msgJSON string
err := row.Scan(&conv.ID, &conv.Model, &msgJSON, &conv.CreatedAt, &conv.UpdatedAt)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, false
return nil, err
}
if err := json.Unmarshal([]byte(msgJSON), &conv.Messages); err != nil {
return nil, false
return nil, err
}
return &conv, true
return &conv, nil
}
func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conversation {
func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
now := time.Now()
msgJSON, _ := json.Marshal(messages)
msgJSON, err := json.Marshal(messages)
if err != nil {
return nil, err
}
_, _ = s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now)
if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
return nil, err
}
return &Conversation{
ID: id,
@@ -94,26 +102,36 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) *Conv
Model: model,
CreatedAt: now,
UpdatedAt: now,
}
}, nil
}
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
conv, ok := s.Get(id)
if !ok {
return nil, false
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) {
conv, err := s.Get(id)
if err != nil {
return nil, err
}
if conv == nil {
return nil, nil
}
conv.Messages = append(conv.Messages, messages...)
conv.UpdatedAt = time.Now()
msgJSON, _ := json.Marshal(conv.Messages)
_, _ = s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id)
msgJSON, err := json.Marshal(conv.Messages)
if err != nil {
return nil, err
}
return conv, true
if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
return nil, err
}
return conv, nil
}
func (s *SQLStore) Delete(id string) {
_, _ = s.db.Exec(s.dialect.deleteByID, id)
func (s *SQLStore) Delete(id string) error {
_, err := s.db.Exec(s.dialect.deleteByID, id)
return err
}
func (s *SQLStore) Size() int {

View File

@@ -2,13 +2,14 @@ package anthropic
import (
"context"
"encoding/json"
"fmt"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/yourusername/go-llm-gateway/internal/api"
"github.com/yourusername/go-llm-gateway/internal/config"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/ajac-zero/latticelm/internal/config"
)
const Name = "anthropic"
@@ -85,6 +86,11 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
case "assistant":
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
case "tool":
// Tool results must be in user message with tool_result blocks
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
anthropic.NewToolResultBlock(msg.CallID, content, false),
))
case "system", "developer":
system = content
}
@@ -116,17 +122,47 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
params.TopP = anthropic.Float(*req.TopP)
}
// Add tools if present
if req.Tools != nil && len(req.Tools) > 0 {
tools, err := parseTools(req)
if err != nil {
return nil, fmt.Errorf("parse tools: %w", err)
}
params.Tools = tools
}
// Add tool_choice if present
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
toolChoice, err := parseToolChoice(req)
if err != nil {
return nil, fmt.Errorf("parse tool_choice: %w", err)
}
params.ToolChoice = toolChoice
}
// Call Anthropic API
resp, err := p.client.Messages.New(ctx, params)
if err != nil {
return nil, fmt.Errorf("anthropic api error: %w", err)
}
// Extract text from response
// Extract text and tool calls from response
var text string
var toolCalls []api.ToolCall
for _, block := range resp.Content {
if block.Type == "text" {
text += block.Text
switch block.Type {
case "text":
text += block.AsText().Text
case "tool_use":
// Extract tool calls
toolUse := block.AsToolUse()
argsJSON, _ := json.Marshal(toolUse.Input)
toolCalls = append(toolCalls, api.ToolCall{
ID: toolUse.ID,
Name: toolUse.Name,
Arguments: string(argsJSON),
})
}
}
@@ -134,6 +170,7 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
ID: resp.ID,
Model: string(resp.Model),
Text: text,
ToolCalls: toolCalls,
Usage: api.Usage{
InputTokens: int(resp.Usage.InputTokens),
OutputTokens: int(resp.Usage.OutputTokens),
@@ -177,6 +214,11 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))
case "assistant":
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(anthropic.NewTextBlock(content)))
case "tool":
// Tool results must be in user message with tool_result blocks
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
anthropic.NewToolResultBlock(msg.CallID, content, false),
))
case "system", "developer":
system = content
}
@@ -208,20 +250,78 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
params.TopP = anthropic.Float(*req.TopP)
}
// Add tools if present
if req.Tools != nil && len(req.Tools) > 0 {
tools, err := parseTools(req)
if err != nil {
errChan <- fmt.Errorf("parse tools: %w", err)
return
}
params.Tools = tools
}
// Add tool_choice if present
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
toolChoice, err := parseToolChoice(req)
if err != nil {
errChan <- fmt.Errorf("parse tool_choice: %w", err)
return
}
params.ToolChoice = toolChoice
}
// Create stream
stream := p.client.Messages.NewStreaming(ctx, params)
// Track content block index and tool call state
var contentBlockIndex int
// Process stream
for stream.Next() {
event := stream.Current()
if event.Type == "content_block_delta" && event.Delta.Type == "text_delta" {
switch event.Type {
case "content_block_start":
// New content block (text or tool_use)
contentBlockIndex = int(event.Index)
if event.ContentBlock.Type == "tool_use" {
// Send tool call delta with ID and name
toolUse := event.ContentBlock.AsToolUse()
delta := &api.ToolCallDelta{
Index: contentBlockIndex,
ID: toolUse.ID,
Name: toolUse.Name,
}
select {
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
}
case "content_block_delta":
if event.Delta.Type == "text_delta" {
// Text streaming
select {
case deltaChan <- &api.ProviderStreamDelta{Text: event.Delta.Text}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
} else if event.Delta.Type == "input_json_delta" {
// Tool arguments streaming
delta := &api.ToolCallDelta{
Index: int(event.Index),
Arguments: event.Delta.PartialJSON,
}
select {
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
}
}
}

View File

@@ -0,0 +1,154 @@
package anthropic
import (
"encoding/json"
"fmt"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/anthropics/anthropic-sdk-go"
)
// parseTools converts Open Responses tools to Anthropic format
func parseTools(req *api.ResponseRequest) ([]anthropic.ToolUnionParam, error) {
if req.Tools == nil || len(req.Tools) == 0 {
return nil, nil
}
var toolDefs []map[string]interface{}
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
return nil, fmt.Errorf("unmarshal tools: %w", err)
}
var tools []anthropic.ToolUnionParam
for _, td := range toolDefs {
// Extract: name, description, parameters
// Note: Anthropic uses "input_schema" instead of "parameters"
name, _ := td["name"].(string)
desc, _ := td["description"].(string)
params, _ := td["parameters"].(map[string]interface{})
inputSchema := anthropic.ToolInputSchemaParam{
Type: "object",
Properties: params["properties"],
}
// Add required fields if present
if required, ok := params["required"].([]interface{}); ok {
requiredStrs := make([]string, 0, len(required))
for _, r := range required {
if str, ok := r.(string); ok {
requiredStrs = append(requiredStrs, str)
}
}
inputSchema.Required = requiredStrs
}
// Create the tool using ToolUnionParamOfTool
tool := anthropic.ToolUnionParamOfTool(inputSchema, name)
if desc != "" {
tool.OfTool.Description = anthropic.String(desc)
}
tools = append(tools, tool)
}
return tools, nil
}
// parseToolChoice converts Open Responses tool_choice to Anthropic format
func parseToolChoice(req *api.ResponseRequest) (anthropic.ToolChoiceUnionParam, error) {
var result anthropic.ToolChoiceUnionParam
if req.ToolChoice == nil || len(req.ToolChoice) == 0 {
return result, nil
}
var choice interface{}
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
return result, fmt.Errorf("unmarshal tool_choice: %w", err)
}
// Handle string values: "auto", "any", "required"
if str, ok := choice.(string); ok {
switch str {
case "auto":
result.OfAuto = &anthropic.ToolChoiceAutoParam{
Type: "auto",
}
case "any", "required":
result.OfAny = &anthropic.ToolChoiceAnyParam{
Type: "any",
}
case "none":
result.OfNone = &anthropic.ToolChoiceNoneParam{
Type: "none",
}
default:
return result, fmt.Errorf("unknown tool_choice string: %s", str)
}
return result, nil
}
// Handle specific tool selection: {"type": "tool", "function": {"name": "..."}}
if obj, ok := choice.(map[string]interface{}); ok {
// Check for OpenAI format: {"type": "function", "function": {"name": "..."}}
if funcObj, ok := obj["function"].(map[string]interface{}); ok {
if name, ok := funcObj["name"].(string); ok {
result.OfTool = &anthropic.ToolChoiceToolParam{
Type: "tool",
Name: name,
}
return result, nil
}
}
// Check for direct name field
if name, ok := obj["name"].(string); ok {
result.OfTool = &anthropic.ToolChoiceToolParam{
Type: "tool",
Name: name,
}
return result, nil
}
}
return result, fmt.Errorf("invalid tool_choice format")
}
// extractToolCalls converts Anthropic content blocks to api.ToolCall
func extractToolCalls(content []anthropic.ContentBlockUnion) []api.ToolCall {
var toolCalls []api.ToolCall
for _, block := range content {
// Check if this is a tool_use block
if block.Type == "tool_use" {
// Cast to ToolUseBlock to access the fields
toolUse := block.AsToolUse()
// Marshal the input to JSON string for Arguments
argsJSON, _ := json.Marshal(toolUse.Input)
toolCalls = append(toolCalls, api.ToolCall{
ID: toolUse.ID,
Name: toolUse.Name,
Arguments: string(argsJSON),
})
}
}
return toolCalls
}
// extractToolCallDelta extracts tool call delta from streaming content block delta
func extractToolCallDelta(delta anthropic.RawContentBlockDeltaUnion, index int) *api.ToolCallDelta {
// Check if this is an input_json_delta (streaming tool arguments)
if delta.Type == "input_json_delta" {
return &api.ToolCallDelta{
Index: index,
Arguments: delta.PartialJSON,
}
}
return nil
}

View File

@@ -0,0 +1,119 @@
package anthropic
import (
"encoding/json"
"testing"
"github.com/ajac-zero/latticelm/internal/api"
)
func TestParseTools(t *testing.T) {
// Create a sample tool definition
toolsJSON := `[{
"type": "function",
"name": "get_weather",
"description": "Get the weather for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}
}]`
req := &api.ResponseRequest{
Tools: json.RawMessage(toolsJSON),
}
tools, err := parseTools(req)
if err != nil {
t.Fatalf("parseTools failed: %v", err)
}
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
tool := tools[0]
if tool.OfTool == nil {
t.Fatal("expected OfTool to be set")
}
if tool.OfTool.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got '%s'", tool.OfTool.Name)
}
desc := tool.GetDescription()
if desc == nil || *desc != "Get the weather for a location" {
t.Errorf("expected description 'Get the weather for a location', got '%v'", desc)
}
if len(tool.OfTool.InputSchema.Required) != 1 || tool.OfTool.InputSchema.Required[0] != "location" {
t.Errorf("expected required=['location'], got %v", tool.OfTool.InputSchema.Required)
}
}
func TestParseToolChoice(t *testing.T) {
tests := []struct {
name string
choiceJSON string
expectAuto bool
expectAny bool
expectTool bool
expectedName string
}{
{
name: "auto",
choiceJSON: `"auto"`,
expectAuto: true,
},
{
name: "any",
choiceJSON: `"any"`,
expectAny: true,
},
{
name: "required",
choiceJSON: `"required"`,
expectAny: true,
},
{
name: "specific tool",
choiceJSON: `{"type": "function", "function": {"name": "get_weather"}}`,
expectTool: true,
expectedName: "get_weather",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &api.ResponseRequest{
ToolChoice: json.RawMessage(tt.choiceJSON),
}
choice, err := parseToolChoice(req)
if err != nil {
t.Fatalf("parseToolChoice failed: %v", err)
}
if tt.expectAuto && choice.OfAuto == nil {
t.Error("expected OfAuto to be set")
}
if tt.expectAny && choice.OfAny == nil {
t.Error("expected OfAny to be set")
}
if tt.expectTool {
if choice.OfTool == nil {
t.Fatal("expected OfTool to be set")
}
if choice.OfTool.Name != tt.expectedName {
t.Errorf("expected name '%s', got '%s'", tt.expectedName, choice.OfTool.Name)
}
}
})
}
}

View File

@@ -0,0 +1,212 @@
package google
import (
"encoding/json"
"fmt"
"math/rand"
"time"
"google.golang.org/genai"
"github.com/ajac-zero/latticelm/internal/api"
)
// parseTools converts generic tool definitions from req.Tools (JSON) to Google's []*genai.Tool format.
func parseTools(req *api.ResponseRequest) ([]*genai.Tool, error) {
if req.Tools == nil || len(req.Tools) == 0 {
return nil, nil
}
// Unmarshal to slice of tool definitions
var toolDefs []map[string]interface{}
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
return nil, fmt.Errorf("unmarshal tools: %w", err)
}
var functionDeclarations []*genai.FunctionDeclaration
for _, toolDef := range toolDefs {
// Extract function details
// Support both flat format (name/description/parameters at top level)
// and nested format (under "function" key)
var name, description string
var parameters interface{}
if functionData, ok := toolDef["function"].(map[string]interface{}); ok {
// Nested format: {"type": "function", "function": {...}}
name, _ = functionData["name"].(string)
description, _ = functionData["description"].(string)
parameters = functionData["parameters"]
} else {
// Flat format: {"type": "function", "name": "...", ...}
name, _ = toolDef["name"].(string)
description, _ = toolDef["description"].(string)
parameters = toolDef["parameters"]
}
if name == "" {
continue
}
// Create function declaration
funcDecl := &genai.FunctionDeclaration{
Name: name,
Description: description,
}
// Google accepts parameters as raw JSON schema
if parameters != nil {
funcDecl.ParametersJsonSchema = parameters
}
functionDeclarations = append(functionDeclarations, funcDecl)
}
// Return single Tool with all function declarations
if len(functionDeclarations) > 0 {
return []*genai.Tool{{FunctionDeclarations: functionDeclarations}}, nil
}
return nil, nil
}
// parseToolChoice converts req.ToolChoice to Google's ToolConfig with FunctionCallingConfig.
func parseToolChoice(req *api.ResponseRequest) (*genai.ToolConfig, error) {
if req.ToolChoice == nil || len(req.ToolChoice) == 0 {
return nil, nil
}
var choice interface{}
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
return nil, fmt.Errorf("unmarshal tool_choice: %w", err)
}
config := &genai.ToolConfig{
FunctionCallingConfig: &genai.FunctionCallingConfig{},
}
// Handle string values: "auto", "none", "required"/"any"
if str, ok := choice.(string); ok {
switch str {
case "auto":
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAuto
case "none":
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeNone
case "required", "any":
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny
default:
return nil, fmt.Errorf("unknown tool_choice string: %s", str)
}
return config, nil
}
// Handle object format: {"type": "function", "function": {"name": "..."}}
if obj, ok := choice.(map[string]interface{}); ok {
if typeVal, ok := obj["type"].(string); ok && typeVal == "function" {
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny
if funcObj, ok := obj["function"].(map[string]interface{}); ok {
if name, ok := funcObj["name"].(string); ok {
config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
}
}
return config, nil
}
}
return nil, fmt.Errorf("unsupported tool_choice format")
}
// extractToolCalls extracts tool calls from Google's response format to generic api.ToolCall slice.
func extractToolCalls(resp *genai.GenerateContentResponse) []api.ToolCall {
var toolCalls []api.ToolCall
for _, candidate := range resp.Candidates {
if candidate.Content == nil {
continue
}
for _, part := range candidate.Content.Parts {
if part == nil || part.FunctionCall == nil {
continue
}
// Extract function call details
fc := part.FunctionCall
// Marshal arguments to JSON string
var argsJSON string
if fc.Args != nil {
argsBytes, err := json.Marshal(fc.Args)
if err == nil {
argsJSON = string(argsBytes)
} else {
// Fallback to empty object
argsJSON = "{}"
}
} else {
argsJSON = "{}"
}
// Generate ID if Google doesn't provide one
callID := fc.ID
if callID == "" {
callID = fmt.Sprintf("call_%s", generateRandomID())
}
toolCalls = append(toolCalls, api.ToolCall{
ID: callID,
Name: fc.Name,
Arguments: argsJSON,
})
}
}
return toolCalls
}
// extractToolCallDelta extracts streaming tool call information from response parts.
func extractToolCallDelta(part *genai.Part, index int) *api.ToolCallDelta {
if part == nil || part.FunctionCall == nil {
return nil
}
fc := part.FunctionCall
// Marshal arguments to JSON string
var argsJSON string
if fc.Args != nil {
argsBytes, err := json.Marshal(fc.Args)
if err == nil {
argsJSON = string(argsBytes)
} else {
argsJSON = "{}"
}
} else {
argsJSON = "{}"
}
// Generate ID if Google doesn't provide one
callID := fc.ID
if callID == "" {
callID = fmt.Sprintf("call_%s", generateRandomID())
}
return &api.ToolCallDelta{
Index: index,
ID: callID,
Name: fc.Name,
Arguments: argsJSON,
}
}
// generateRandomID generates a random alphanumeric ID
func generateRandomID() string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
const length = 24
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, length)
for i := range b {
b[i] = charset[rng.Intn(len(charset))]
}
return string(b)
}
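
To illustrate the flat vs. nested formats accepted by `parseTools`, a small sketch that would sit alongside this file in the `google` package; the payloads are made up for the example.

```go
package google

import (
	"encoding/json"
	"fmt"

	"github.com/ajac-zero/latticelm/internal/api"
)

// demoParseTools shows that the flat Open Responses shape and the nested
// OpenAI-style shape both yield the same function declaration.
func demoParseTools() {
	flat := `[{"type":"function","name":"get_weather","description":"Get the weather",
	  "parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]`
	nested := `[{"type":"function","function":{"name":"get_weather","description":"Get the weather",
	  "parameters":{"type":"object","properties":{"location":{"type":"string"}}}}}]`

	for _, raw := range []string{flat, nested} {
		tools, err := parseTools(&api.ResponseRequest{Tools: json.RawMessage(raw)})
		if err != nil {
			panic(err)
		}
		fmt.Println(tools[0].FunctionDeclarations[0].Name) // "get_weather" in both cases
	}
}
```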

View File

@@ -2,13 +2,14 @@ package google
import (
"context"
"encoding/json"
"fmt"
"github.com/google/uuid"
"google.golang.org/genai"
"github.com/yourusername/go-llm-gateway/internal/api"
"github.com/yourusername/go-llm-gateway/internal/config"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/ajac-zero/latticelm/internal/config"
)
const Name = "google"
@@ -19,7 +20,7 @@ type Provider struct {
client *genai.Client
}
// New constructs a Provider using the provided configuration.
// New constructs a Provider using the Google AI API with API key authentication.
func New(cfg config.ProviderConfig) *Provider {
var client *genai.Client
if cfg.APIKey != "" {
@@ -38,13 +39,36 @@ func New(cfg config.ProviderConfig) *Provider {
}
}
// NewVertexAI constructs a Provider targeting Vertex AI.
// Vertex AI uses the same genai SDK but with GCP project/location configuration
// and Application Default Credentials (ADC) or service account authentication.
func NewVertexAI(vertexCfg config.VertexAIConfig) *Provider {
var client *genai.Client
if vertexCfg.Project != "" && vertexCfg.Location != "" {
var err error
client, err = genai.NewClient(context.Background(), &genai.ClientConfig{
Project: vertexCfg.Project,
Location: vertexCfg.Location,
Backend: genai.BackendVertexAI,
})
if err != nil {
// Log error but don't fail construction - will fail on Generate
fmt.Printf("warning: failed to create vertex ai client: %v\n", err)
}
}
return &Provider{
cfg: config.ProviderConfig{
// Vertex AI doesn't use API key, but set empty for consistency
APIKey: "",
},
client: client,
}
}
func (p *Provider) Name() string { return Name }
// Generate routes the request to Gemini and returns a ProviderResult.
func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
if p.cfg.APIKey == "" {
return nil, fmt.Errorf("google api key missing")
}
if p.client == nil {
return nil, fmt.Errorf("google client not initialized")
}
@@ -53,7 +77,27 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
contents, systemText := convertMessages(messages)
config := buildConfig(systemText, req)
// Parse tools if present
var tools []*genai.Tool
if req.Tools != nil && len(req.Tools) > 0 {
var err error
tools, err = parseTools(req)
if err != nil {
return nil, fmt.Errorf("parse tools: %w", err)
}
}
// Parse tool_choice if present
var toolConfig *genai.ToolConfig
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
var err error
toolConfig, err = parseToolChoice(req)
if err != nil {
return nil, fmt.Errorf("parse tool_choice: %w", err)
}
}
config := buildConfig(systemText, req, tools, toolConfig)
resp, err := p.client.Models.GenerateContent(ctx, model, contents, config)
if err != nil {
@@ -69,6 +113,11 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
}
}
var toolCalls []api.ToolCall
if len(resp.Candidates) > 0 {
toolCalls = extractToolCalls(resp)
}
var inputTokens, outputTokens int
if resp.UsageMetadata != nil {
inputTokens = int(resp.UsageMetadata.PromptTokenCount)
@@ -79,6 +128,7 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
ID: uuid.NewString(),
Model: model,
Text: text,
ToolCalls: toolCalls,
Usage: api.Usage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
@@ -96,10 +146,6 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
defer close(deltaChan)
defer close(errChan)
if p.cfg.APIKey == "" {
errChan <- fmt.Errorf("google api key missing")
return
}
if p.client == nil {
errChan <- fmt.Errorf("google client not initialized")
return
@@ -109,7 +155,29 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
contents, systemText := convertMessages(messages)
config := buildConfig(systemText, req)
// Parse tools if present
var tools []*genai.Tool
if req.Tools != nil && len(req.Tools) > 0 {
var err error
tools, err = parseTools(req)
if err != nil {
errChan <- fmt.Errorf("parse tools: %w", err)
return
}
}
// Parse tool_choice if present
var toolConfig *genai.ToolConfig
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
var err error
toolConfig, err = parseToolChoice(req)
if err != nil {
errChan <- fmt.Errorf("parse tool_choice: %w", err)
return
}
}
config := buildConfig(systemText, req, tools, toolConfig)
stream := p.client.Models.GenerateContentStream(ctx, model, contents, config)
@@ -119,23 +187,34 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
return
}
var text string
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
for partIndex, part := range resp.Candidates[0].Content.Parts {
if part != nil {
text += part.Text
}
}
}
if text != "" {
// Handle text content
if part.Text != "" {
select {
case deltaChan <- &api.ProviderStreamDelta{Text: text}:
case deltaChan <- &api.ProviderStreamDelta{Text: part.Text}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
}
// Handle tool call content
if part.FunctionCall != nil {
delta := extractToolCallDelta(part, partIndex)
if delta != nil {
select {
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
}
}
}
}
}
}
select {
@@ -163,6 +242,39 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
continue
}
if msg.Role == "tool" {
// Tool results are sent as FunctionResponse in user role message
var output string
for _, block := range msg.Content {
if block.Type == "input_text" || block.Type == "output_text" {
output += block.Text
}
}
// Parse output as JSON map, or wrap in {"output": "..."} if not JSON
var responseMap map[string]any
if err := json.Unmarshal([]byte(output), &responseMap); err != nil {
// Not JSON, wrap it
responseMap = map[string]any{"output": output}
}
// Create FunctionResponse part with CallID from message
part := &genai.Part{
FunctionResponse: &genai.FunctionResponse{
ID: msg.CallID,
Name: "", // Name is optional for responses
Response: responseMap,
},
}
// Add to user role message
contents = append(contents, &genai.Content{
Role: "user",
Parts: []*genai.Part{part},
})
continue
}
var parts []*genai.Part
for _, block := range msg.Content {
if block.Type == "input_text" || block.Type == "output_text" {
@@ -185,10 +297,10 @@ func convertMessages(messages []api.Message) ([]*genai.Content, string) {
}
// buildConfig constructs a GenerateContentConfig from system text and request params.
func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateContentConfig {
func buildConfig(systemText string, req *api.ResponseRequest, tools []*genai.Tool, toolConfig *genai.ToolConfig) *genai.GenerateContentConfig {
var cfg *genai.GenerateContentConfig
needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil
needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil || tools != nil || toolConfig != nil
if !needsCfg {
return nil
}
@@ -215,6 +327,14 @@ func buildConfig(systemText string, req *api.ResponseRequest) *genai.GenerateCon
cfg.TopP = &tp
}
if tools != nil {
cfg.Tools = tools
}
if toolConfig != nil {
cfg.ToolConfig = toolConfig
}
return cfg
}
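
A minimal sketch of constructing the Vertex AI-backed provider directly. The import path for the `google` provider package is assumed here, the project and location values are placeholders, and authentication comes from Application Default Credentials.

```go
package main

import (
	"github.com/ajac-zero/latticelm/internal/config"
	// Assumed location of the provider package shown in this diff.
	"github.com/ajac-zero/latticelm/internal/providers/google"
)

func main() {
	p := google.NewVertexAI(config.VertexAIConfig{
		Project:  "your-gcp-project-id",
		Location: "us-central1",
	})
	_ = p // hand off to the gateway's provider wiring
}
```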

View File

@@ -0,0 +1,117 @@
package openai
import (
"encoding/json"
"fmt"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/shared"
)
// parseTools converts Open Responses tools to OpenAI format
func parseTools(req *api.ResponseRequest) ([]openai.ChatCompletionToolUnionParam, error) {
if req.Tools == nil || len(req.Tools) == 0 {
return nil, nil
}
var toolDefs []map[string]interface{}
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil {
return nil, fmt.Errorf("unmarshal tools: %w", err)
}
var tools []openai.ChatCompletionToolUnionParam
for _, td := range toolDefs {
// Convert Open Responses tool to OpenAI ChatCompletionFunctionToolParam
// Extract: name, description, parameters
name, _ := td["name"].(string)
desc, _ := td["description"].(string)
params, _ := td["parameters"].(map[string]interface{})
funcDef := shared.FunctionDefinitionParam{
Name: name,
}
if desc != "" {
funcDef.Description = openai.String(desc)
}
if params != nil {
funcDef.Parameters = shared.FunctionParameters(params)
}
tools = append(tools, openai.ChatCompletionFunctionTool(funcDef))
}
return tools, nil
}
// parseToolChoice converts Open Responses tool_choice to OpenAI format
func parseToolChoice(req *api.ResponseRequest) (openai.ChatCompletionToolChoiceOptionUnionParam, error) {
var result openai.ChatCompletionToolChoiceOptionUnionParam
if req.ToolChoice == nil || len(req.ToolChoice) == 0 {
return result, nil
}
var choice interface{}
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil {
return result, fmt.Errorf("unmarshal tool_choice: %w", err)
}
// Handle string values: "auto", "none", "required"
if str, ok := choice.(string); ok {
result.OfAuto = openai.String(str)
return result, nil
}
// Handle specific function selection: {"type": "function", "function": {"name": "..."}}
if obj, ok := choice.(map[string]interface{}); ok {
funcObj, _ := obj["function"].(map[string]interface{})
name, _ := funcObj["name"].(string)
return openai.ToolChoiceOptionFunctionToolChoice(
openai.ChatCompletionNamedToolChoiceFunctionParam{
Name: name,
},
), nil
}
return result, fmt.Errorf("invalid tool_choice format")
}
// extractToolCalls converts OpenAI tool calls to api.ToolCall
func extractToolCalls(message openai.ChatCompletionMessage) []api.ToolCall {
if len(message.ToolCalls) == 0 {
return nil
}
var toolCalls []api.ToolCall
for _, tc := range message.ToolCalls {
toolCalls = append(toolCalls, api.ToolCall{
ID: tc.ID,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
})
}
return toolCalls
}
// extractToolCallDelta extracts tool call delta from streaming chunk choice
func extractToolCallDelta(choice openai.ChatCompletionChunkChoice) *api.ToolCallDelta {
if len(choice.Delta.ToolCalls) == 0 {
return nil
}
// OpenAI streams tool calls incrementally, tagging each delta with its index;
// return the first (typically only) entry in this chunk
for _, tc := range choice.Delta.ToolCalls {
return &api.ToolCallDelta{
Index: int(tc.Index),
ID: tc.ID,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
}
}
return nil
}
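
A hedged usage sketch for these helpers, assuming api.ResponseRequest carries Tools and ToolChoice as raw JSON (they are passed straight to json.Unmarshal above) and reusing the weather-tool shape from the compliance tests added later in this change; the function is illustrative, not part of the gateway:

```
package openai

import (
	"encoding/json"
	"fmt"

	"github.com/ajac-zero/latticelm/internal/api"
)

// exampleToolParams sketches feeding an Open Responses tool definition through
// parseTools and parseToolChoice; the raw-JSON field types are an assumption.
func exampleToolParams() error {
	req := &api.ResponseRequest{
		Tools: json.RawMessage(`[{
			"type": "function",
			"name": "get_weather",
			"description": "Get the current weather for a location",
			"parameters": {
				"type": "object",
				"properties": {"location": {"type": "string"}},
				"required": ["location"]
			}
		}]`),
		ToolChoice: json.RawMessage(`"auto"`),
	}

	tools, err := parseTools(req)
	if err != nil {
		return err
	}
	choice, err := parseToolChoice(req)
	if err != nil {
		return err
	}
	fmt.Printf("parsed %d tool(s), tool_choice=%+v\n", len(tools), choice)
	return nil
}
```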

View File

@@ -4,12 +4,12 @@ import (
"context"
"fmt"
"github.com/openai/openai-go"
"github.com/openai/openai-go/azure"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/azure"
"github.com/openai/openai-go/v3/option"
"github.com/yourusername/go-llm-gateway/internal/api"
"github.com/yourusername/go-llm-gateway/internal/config"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/ajac-zero/latticelm/internal/config"
)
const Name = "openai"
@@ -91,6 +91,8 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "developer":
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "tool":
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))
}
}
@@ -108,6 +110,29 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
params.TopP = openai.Float(*req.TopP)
}
// Add tools if present
if req.Tools != nil && len(req.Tools) > 0 {
tools, err := parseTools(req)
if err != nil {
return nil, fmt.Errorf("parse tools: %w", err)
}
params.Tools = tools
}
// Add tool_choice if present
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
toolChoice, err := parseToolChoice(req)
if err != nil {
return nil, fmt.Errorf("parse tool_choice: %w", err)
}
params.ToolChoice = toolChoice
}
// Add parallel_tool_calls if specified
if req.ParallelToolCalls != nil {
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
}
// Call OpenAI API
resp, err := p.client.Chat.Completions.New(ctx, params)
if err != nil {
@@ -115,14 +140,20 @@ func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *ap
}
var combinedText string
var toolCalls []api.ToolCall
for _, choice := range resp.Choices {
combinedText += choice.Message.Content
if len(choice.Message.ToolCalls) > 0 {
toolCalls = append(toolCalls, extractToolCalls(choice.Message)...)
}
}
return &api.ProviderResult{
ID: resp.ID,
Model: resp.Model,
Text: combinedText,
ToolCalls: toolCalls,
Usage: api.Usage{
InputTokens: int(resp.Usage.PromptTokens),
OutputTokens: int(resp.Usage.CompletionTokens),
@@ -168,6 +199,8 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "developer":
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case "tool":
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))
}
}
@@ -185,6 +218,31 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
params.TopP = openai.Float(*req.TopP)
}
// Add tools if present
if req.Tools != nil && len(req.Tools) > 0 {
tools, err := parseTools(req)
if err != nil {
errChan <- fmt.Errorf("parse tools: %w", err)
return
}
params.Tools = tools
}
// Add tool_choice if present
if req.ToolChoice != nil && len(req.ToolChoice) > 0 {
toolChoice, err := parseToolChoice(req)
if err != nil {
errChan <- fmt.Errorf("parse tool_choice: %w", err)
return
}
params.ToolChoice = toolChoice
}
// Add parallel_tool_calls if specified
if req.ParallelToolCalls != nil {
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
}
// Create streaming request
stream := p.client.Chat.Completions.NewStreaming(ctx, params)
@@ -193,10 +251,8 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
chunk := stream.Current()
for _, choice := range chunk.Choices {
if choice.Delta.Content == "" {
continue
}
// Handle text content
if choice.Delta.Content != "" {
select {
case deltaChan <- &api.ProviderStreamDelta{
ID: chunk.ID,
@@ -208,6 +264,24 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
return
}
}
// Handle tool call deltas
if len(choice.Delta.ToolCalls) > 0 {
delta := extractToolCallDelta(choice)
if delta != nil {
select {
case deltaChan <- &api.ProviderStreamDelta{
ID: chunk.ID,
Model: chunk.Model,
ToolCallDelta: delta,
}:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
}
}
}
}
if err := stream.Err(); err != nil {

View File

@@ -4,11 +4,11 @@ import (
"context"
"fmt"
"github.com/yourusername/go-llm-gateway/internal/api"
"github.com/yourusername/go-llm-gateway/internal/config"
anthropicprovider "github.com/yourusername/go-llm-gateway/internal/providers/anthropic"
googleprovider "github.com/yourusername/go-llm-gateway/internal/providers/google"
openaiprovider "github.com/yourusername/go-llm-gateway/internal/providers/openai"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/ajac-zero/latticelm/internal/config"
anthropicprovider "github.com/ajac-zero/latticelm/internal/providers/anthropic"
googleprovider "github.com/ajac-zero/latticelm/internal/providers/google"
openaiprovider "github.com/ajac-zero/latticelm/internal/providers/openai"
)
// Provider represents a unified interface that each LLM provider must implement.
@@ -60,7 +60,8 @@ func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelE
}
func buildProvider(entry config.ProviderEntry) (Provider, error) {
if entry.APIKey == "" {
// Vertex AI authenticates with GCP credentials rather than an API key, so only require APIKey for other provider types
if entry.Type != "vertexai" && entry.APIKey == "" {
return nil, nil
}
@@ -97,6 +98,14 @@ func buildProvider(entry config.ProviderEntry) (Provider, error) {
APIKey: entry.APIKey,
Endpoint: entry.Endpoint,
}), nil
case "vertexai":
if entry.Project == "" || entry.Location == "" {
return nil, fmt.Errorf("project and location are required for vertexai")
}
return googleprovider.NewVertexAI(config.VertexAIConfig{
Project: entry.Project,
Location: entry.Location,
}), nil
default:
return nil, fmt.Errorf("unknown provider type %q", entry.Type)
}
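
A short illustrative sketch of a Vertex AI provider entry with the fields this branch requires; the project and location values are placeholders, and how the entry is actually loaded (YAML, env, flags) is outside this hunk:

```
package providers

import (
	"log"

	"github.com/ajac-zero/latticelm/internal/config"
)

// exampleVertexAIEntry sketches how a vertexai entry reaches buildProvider.
func exampleVertexAIEntry() {
	entry := config.ProviderEntry{
		Type:     "vertexai",
		Project:  "my-gcp-project", // required for the vertexai type
		Location: "us-central1",    // required for the vertexai type
		// APIKey intentionally empty: Vertex AI authenticates with GCP credentials instead.
	}
	p, err := buildProvider(entry)
	if err != nil {
		log.Fatalf("vertexai provider: %v", err)
	}
	_ = p
}
```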

View File

@@ -10,9 +10,9 @@ import (
"github.com/google/uuid"
"github.com/yourusername/go-llm-gateway/internal/api"
"github.com/yourusername/go-llm-gateway/internal/conversation"
"github.com/yourusername/go-llm-gateway/internal/providers"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/ajac-zero/latticelm/internal/conversation"
"github.com/ajac-zero/latticelm/internal/providers"
)
// GatewayServer hosts the Open Responses API for the gateway.
@@ -84,8 +84,13 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
// Build full message history from previous conversation
var historyMsgs []api.Message
if req.PreviousResponseID != nil && *req.PreviousResponseID != "" {
conv, ok := s.convs.Get(*req.PreviousResponseID)
if !ok {
conv, err := s.convs.Get(*req.PreviousResponseID)
if err != nil {
s.logger.Printf("error retrieving conversation: %v", err)
http.Error(w, "error retrieving conversation", http.StatusInternalServerError)
return
}
if conv == nil {
http.Error(w, "conversation not found", http.StatusNotFound)
return
}
@@ -140,7 +145,10 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
}
allMsgs := append(storeMsgs, assistantMsg)
s.convs.Create(responseID, result.Model, allMsgs)
if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil {
s.logger.Printf("error storing conversation: %v", err)
// Don't fail the response if storage fails
}
// Build spec-compliant response
resp := s.buildResponse(origReq, result, provider.Name(), responseID)
@@ -224,6 +232,17 @@ func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.R
var streamErr error
var providerModel string
// Track tool calls being built
type toolCallBuilder struct {
itemID string
id string
name string
arguments string
}
toolCallsInProgress := make(map[int]*toolCallBuilder)
nextOutputIdx := 0
textItemAdded := false
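// Output indices are handed out in arrival order: the first text delta reserves a slot
// (textItemAdded), and each newly seen tool call takes the next one via nextOutputIdx.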
loop:
for {
select {
@@ -234,7 +253,14 @@ loop:
if delta.Model != "" && providerModel == "" {
providerModel = delta.Model
}
// Handle text content
if delta.Text != "" {
// Add text item on first text delta
if !textItemAdded {
textItemAdded = true
nextOutputIdx++
}
fullText += delta.Text
s.sendSSE(w, flusher, &seq, "response.output_text.delta", &api.StreamEvent{
Type: "response.output_text.delta",
@@ -244,6 +270,53 @@ loop:
Delta: delta.Text,
})
}
// Handle tool call delta
if delta.ToolCallDelta != nil {
tc := delta.ToolCallDelta
// First chunk for this tool call index
if _, exists := toolCallsInProgress[tc.Index]; !exists {
toolItemID := generateID("item_")
toolOutputIdx := nextOutputIdx
nextOutputIdx++
// Send response.output_item.added
s.sendSSE(w, flusher, &seq, "response.output_item.added", &api.StreamEvent{
Type: "response.output_item.added",
OutputIndex: &toolOutputIdx,
Item: &api.OutputItem{
ID: toolItemID,
Type: "function_call",
Status: "in_progress",
CallID: tc.ID,
Name: tc.Name,
},
})
toolCallsInProgress[tc.Index] = &toolCallBuilder{
itemID: toolItemID,
id: tc.ID,
name: tc.Name,
arguments: "",
}
}
// Send function_call_arguments.delta
if tc.Arguments != "" {
builder := toolCallsInProgress[tc.Index]
builder.arguments += tc.Arguments
toolOutputIdx := outputIdx + 1 + tc.Index
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.delta", &api.StreamEvent{
Type: "response.function_call_arguments.delta",
ItemID: builder.itemID,
OutputIndex: &toolOutputIdx,
Delta: tc.Arguments,
})
}
}
if delta.Done {
break loop
}
@@ -277,6 +350,8 @@ loop:
return
}
// Send done events for text output if text was added
if textItemAdded && fullText != "" {
// response.output_text.done
s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{
Type: "response.output_text.done",
@@ -313,18 +388,70 @@ loop:
OutputIndex: &outputIdx,
Item: completedItem,
})
}
// Send done events for each tool call
for idx, builder := range toolCallsInProgress {
toolOutputIdx := outputIdx + 1 + idx
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.done", &api.StreamEvent{
Type: "response.function_call_arguments.done",
ItemID: builder.itemID,
OutputIndex: &toolOutputIdx,
Arguments: builder.arguments,
})
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
Type: "response.output_item.done",
OutputIndex: &toolOutputIdx,
Item: &api.OutputItem{
ID: builder.itemID,
Type: "function_call",
Status: "completed",
CallID: builder.id,
Name: builder.name,
Arguments: builder.arguments,
},
})
}
// Build final completed response
model := origReq.Model
if providerModel != "" {
model = providerModel
}
// Collect tool calls for result
var toolCalls []api.ToolCall
for _, builder := range toolCallsInProgress {
toolCalls = append(toolCalls, api.ToolCall{
ID: builder.id,
Name: builder.name,
Arguments: builder.arguments,
})
}
finalResult := &api.ProviderResult{
Model: model,
Text: fullText,
ToolCalls: toolCalls,
}
completedResp := s.buildResponse(origReq, finalResult, provider.Name(), responseID)
// Update item IDs to match what we sent during streaming
if textItemAdded && len(completedResp.Output) > 0 {
completedResp.Output[0].ID = itemID
}
for _, builder := range toolCallsInProgress {
// Find the corresponding output item
for i := range completedResp.Output {
if completedResp.Output[i].Type == "function_call" && completedResp.Output[i].CallID == builder.id {
completedResp.Output[i].ID = builder.itemID
break
}
}
}
// response.completed
s.sendSSE(w, flusher, &seq, "response.completed", &api.StreamEvent{
@@ -339,7 +466,10 @@ loop:
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
}
allMsgs := append(storeMsgs, assistantMsg)
s.convs.Create(responseID, model, allMsgs)
if _, err := s.convs.Create(responseID, model, allMsgs); err != nil {
s.logger.Printf("error storing conversation: %v", err)
// Don't fail the response if storage fails
}
}
}
@@ -363,10 +493,13 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
model = req.Model
}
// Build output item
itemID := generateID("msg_")
outputItem := api.OutputItem{
ID: itemID,
// Build output items array
outputItems := []api.OutputItem{}
// Add message item if there's text
if result.Text != "" {
outputItems = append(outputItems, api.OutputItem{
ID: generateID("msg_"),
Type: "message",
Status: "completed",
Role: "assistant",
@@ -375,6 +508,19 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
Text: result.Text,
Annotations: []api.Annotation{},
}},
})
}
// Add function_call items
for _, tc := range result.ToolCalls {
outputItems = append(outputItems, api.OutputItem{
ID: generateID("item_"),
Type: "function_call",
Status: "completed",
CallID: tc.ID,
Name: tc.Name,
Arguments: tc.Arguments,
})
}
// Echo back request params with defaults
@@ -454,7 +600,7 @@ func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.Prov
Model: model,
PreviousResponseID: req.PreviousResponseID,
Instructions: req.Instructions,
Output: []api.OutputItem{outputItem},
Output: outputItems,
Error: nil,
Tools: tools,
ToolChoice: toolChoice,
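
As a rough end-to-end sketch, a standalone client could exercise the new function_call output items like this; the endpoint, request body, and response field names follow the compliance-test schemas added below, while the program itself (model name, prompt, absent auth header) is purely illustrative:

```
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Tool-enabled request in the Open Responses shape used by the compliance tests.
	body := []byte(`{
		"model": "gpt-4o-mini",
		"input": [{"type": "message", "role": "user",
			"content": [{"type": "input_text", "text": "What's the weather like in San Francisco?"}]}],
		"tools": [{"type": "function", "name": "get_weather",
			"description": "Get the current weather for a location",
			"parameters": {"type": "object",
				"properties": {"location": {"type": "string"}},
				"required": ["location"]}}]
	}`)
	resp, err := http.Post("http://localhost:8080/v1/responses", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode only the output items we care about here.
	var out struct {
		Output []struct {
			Type      string `json:"type"`
			CallID    string `json:"call_id"`
			Name      string `json:"name"`
			Arguments string `json:"arguments"`
		} `json:"output"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	for _, item := range out.Output {
		if item.Type == "function_call" {
			fmt.Printf("call %s -> %s(%s)\n", item.CallID, item.Name, item.Arguments)
		}
	}
}
```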

View File

@@ -3,12 +3,12 @@
# requires-python = ">=3.11"
# dependencies = [
# "rich>=13.7.0",
# "httpx>=0.27.0",
# "openai>=1.0.0",
# ]
# ///
"""
Terminal chat interface for go-llm-gateway.
Terminal chat interface for latticelm.
Usage:
python chat.py
@@ -18,11 +18,10 @@ Usage:
"""
import argparse
import json
import sys
from typing import Optional
import httpx
from openai import OpenAI, APIStatusError
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
@@ -34,16 +33,13 @@ from rich.table import Table
class ChatClient:
def __init__(self, base_url: str, token: Optional[str] = None):
self.base_url = base_url.rstrip("/")
self.token = token
self.client = OpenAI(
base_url=f"{self.base_url}/v1",
api_key=token or "no-key",
)
self.messages = []
self.console = Console()
def _headers(self) -> dict:
headers = {"Content-Type": "application/json"}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
return headers
def chat(self, user_message: str, model: str, stream: bool = True):
"""Send a chat message and get response."""
# Add user message to history
@@ -52,35 +48,20 @@ class ChatClient:
"content": [{"type": "input_text", "text": user_message}]
})
payload = {
"model": model,
"input": self.messages,
"stream": stream
}
if stream:
return self._stream_response(payload, model)
return self._stream_response(model)
else:
return self._sync_response(payload, model)
return self._sync_response(model)
def _sync_response(self, payload: dict, model: str) -> str:
def _sync_response(self, model: str) -> str:
"""Non-streaming response."""
with self.console.status(f"[bold blue]Thinking ({model})..."):
resp = httpx.post(
f"{self.base_url}/v1/responses",
json=payload,
headers=self._headers(),
timeout=60.0
response = self.client.responses.create(
model=model,
input=self.messages,
)
resp.raise_for_status()
data = resp.json()
assistant_text = ""
for msg in data.get("output", []):
for block in msg.get("content", []):
if block.get("type") == "output_text":
assistant_text += block.get("text", "")
assistant_text = response.output_text
# Add to history
self.messages.append({
@@ -90,40 +71,19 @@ class ChatClient:
return assistant_text
def _stream_response(self, payload: dict, model: str) -> str:
def _stream_response(self, model: str) -> str:
"""Streaming response with live rendering."""
assistant_text = ""
with httpx.stream(
"POST",
f"{self.base_url}/v1/responses",
json=payload,
headers=self._headers(),
timeout=60.0
) as resp:
resp.raise_for_status()
with Live(console=self.console, refresh_per_second=10) as live:
for line in resp.iter_lines():
if not line.startswith("data: "):
continue
data_str = line[6:] # Remove "data: " prefix
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
continue
if chunk.get("done"):
break
delta = chunk.get("delta", {})
for block in delta.get("content", []):
if block.get("type") == "output_text":
assistant_text += block.get("text", "")
# Render markdown in real-time
stream = self.client.responses.create(
model=model,
input=self.messages,
stream=True,
)
for event in stream:
if event.type == "response.output_text.delta":
assistant_text += event.delta
live.update(Markdown(assistant_text))
# Add to history
@@ -139,43 +99,56 @@ class ChatClient:
self.messages = []
def print_models_table(base_url: str, headers: dict):
def print_models_table(client: OpenAI):
"""Fetch and print available models from the gateway."""
console = Console()
try:
resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=10)
resp.raise_for_status()
data = resp.json().get("data", [])
models = client.models.list()
except Exception as e:
console.print(f"[red]Failed to fetch models: {e}[/red]")
return
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
table.add_column("Provider", style="cyan")
table.add_column("Owner", style="cyan")
table.add_column("Model ID", style="green")
for model in data:
table.add_row(model.get("provider", ""), model.get("id", ""))
for model in models:
table.add_row(model.owned_by, model.id)
console.print(table)
def main():
parser = argparse.ArgumentParser(description="Chat with go-llm-gateway")
parser = argparse.ArgumentParser(description="Chat with latticelm")
parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL")
parser.add_argument("--model", default="gemini-2.0-flash-exp", help="Model to use")
parser.add_argument("--model", default=None, help="Model to use (defaults to first available)")
parser.add_argument("--token", help="Auth token (Bearer)")
parser.add_argument("--no-stream", action="store_true", help="Disable streaming")
args = parser.parse_args()
console = Console()
client = ChatClient(args.url, args.token)
# Fetch available models and select default
try:
available_models = list(client.client.models.list())
except Exception as e:
console.print(f"[bold red]Failed to connect to gateway:[/bold red] {e}")
sys.exit(1)
if not available_models:
console.print("[bold red]Error:[/bold red] No models are configured on the gateway.")
sys.exit(1)
if args.model:
current_model = args.model
else:
current_model = available_models[0].id
stream_enabled = not args.no_stream
# Welcome banner
console.print(Panel.fit(
"[bold cyan]go-llm-gateway Chat Interface[/bold cyan]\n"
"[bold cyan]latticelm Chat Interface[/bold cyan]\n"
f"Connected to: [green]{args.url}[/green]\n"
f"Model: [yellow]{current_model}[/yellow]\n"
f"Streaming: [{'green' if stream_enabled else 'red'}]{stream_enabled}[/]\n\n"
@@ -230,7 +203,7 @@ def main():
))
elif cmd == "/models":
print_models_table(args.url, client._headers())
print_models_table(client.client)
elif cmd == "/model":
if len(cmd_parts) < 2:
@@ -265,8 +238,8 @@ def main():
# For non-streaming, render markdown
console.print(Markdown(response))
except httpx.HTTPStatusError as e:
console.print(f"[bold red]Error {e.response.status_code}:[/bold red] {e.response.text}")
except APIStatusError as e:
console.print(f"[bold red]Error {e.status_code}:[/bold red] {e.message}")
except Exception as e:
console.print(f"[bold red]Error:[/bold red] {e}")

View File

@@ -0,0 +1,270 @@
import {
testTemplates,
runAllTests,
type TestConfig,
type TestResult,
} from "../src/compliance-tests.ts";
const colors = {
green: (s: string) => `\x1b[32m${s}\x1b[0m`,
red: (s: string) => `\x1b[31m${s}\x1b[0m`,
yellow: (s: string) => `\x1b[33m${s}\x1b[0m`,
gray: (s: string) => `\x1b[90m${s}\x1b[0m`,
};
interface CliArgs {
baseUrl?: string;
apiKey?: string;
model?: string;
authHeader?: string;
noBearer?: boolean;
noAuth?: boolean;
filter?: string[];
verbose?: boolean;
json?: boolean;
help?: boolean;
}
function parseArgs(argv: string[]): CliArgs {
const args: CliArgs = {};
let i = 0;
while (i < argv.length) {
const arg = argv[i];
const nextArg = argv[i + 1];
switch (arg) {
case "--base-url":
case "-u":
args.baseUrl = nextArg;
i += 2;
break;
case "--api-key":
case "-k":
args.apiKey = nextArg;
i += 2;
break;
case "--model":
case "-m":
args.model = nextArg;
i += 2;
break;
case "--auth-header":
args.authHeader = nextArg;
i += 2;
break;
case "--no-bearer":
args.noBearer = true;
i += 1;
break;
case "--no-auth":
args.noAuth = true;
i += 1;
break;
case "--filter":
case "-f":
args.filter = nextArg.split(",").map((s) => s.trim());
i += 2;
break;
case "--verbose":
case "-v":
args.verbose = true;
i += 1;
break;
case "--json":
args.json = true;
i += 1;
break;
case "--help":
case "-h":
args.help = true;
i += 1;
break;
default:
i += 1;
}
}
return args;
}
function printHelp() {
console.log(`
Usage: npm run test:compliance -- [options]
Options:
-u, --base-url <url> Gateway base URL (default: http://localhost:8080)
-k, --api-key <key> API key (or set OPENRESPONSES_API_KEY env var)
--no-auth Skip authentication header entirely
-m, --model <model> Model name (default: gpt-4o-mini)
--auth-header <name> Auth header name (default: Authorization)
--no-bearer Disable Bearer prefix in auth header
-f, --filter <ids> Filter tests by ID (comma-separated)
-v, --verbose Verbose output with request/response details
--json Output results as JSON
-h, --help Show this help message
Test IDs:
${testTemplates.map((t) => t.id).join(", ")}
Examples:
npm run test:compliance
npm run test:compliance -- --model claude-3-5-sonnet-20241022
npm run test:compliance -- --filter basic-response,streaming-response
npm run test:compliance -- --verbose --filter basic-response
npm run test:compliance -- --json > results.json
`);
}
function getStatusIcon(status: TestResult["status"]): string {
switch (status) {
case "passed":
return colors.green("✓");
case "failed":
return colors.red("✗");
case "running":
return colors.yellow("◉");
case "pending":
return colors.gray("○");
}
}
function printResult(result: TestResult, verbose: boolean) {
const icon = getStatusIcon(result.status);
const duration = result.duration ? ` (${result.duration}ms)` : "";
const events =
result.streamEvents !== undefined ? ` [${result.streamEvents} events]` : "";
const name =
result.status === "failed" ? colors.red(result.name) : result.name;
console.log(`${icon} ${name}${duration}${events}`);
if (result.status === "failed" && result.errors?.length) {
for (const error of result.errors) {
console.log(` ${colors.red("✗")} ${error}`);
}
if (verbose) {
if (result.request) {
console.log(`\n Request:`);
console.log(
` ${JSON.stringify(result.request, null, 2).split("\n").join("\n ")}`,
);
}
if (result.response) {
console.log(`\n Response:`);
const responseStr =
typeof result.response === "string"
? result.response
: JSON.stringify(result.response, null, 2);
console.log(` ${responseStr.split("\n").join("\n ")}`);
}
}
}
}
async function main() {
const args = parseArgs(process.argv.slice(2));
if (args.help) {
printHelp();
process.exit(0);
}
const baseUrl = args.baseUrl || "http://localhost:8080";
const apiKey = args.apiKey || process.env.OPENRESPONSES_API_KEY || "";
// A missing API key is acceptable here: a local gateway without auth enabled needs none.
const config: TestConfig = {
baseUrl,
apiKey,
model: args.model || "gpt-4o-mini",
authHeaderName: args.authHeader || "Authorization",
useBearerPrefix: !args.noBearer,
};
if (args.filter?.length) {
const availableIds = testTemplates.map((t) => t.id);
const invalidFilters = args.filter.filter(
(id) => !availableIds.includes(id),
);
if (invalidFilters.length) {
console.error(
`${colors.red("Error:")} Invalid test IDs: ${invalidFilters.join(", ")}`,
);
console.error(`Available test IDs: ${availableIds.join(", ")}`);
process.exit(1);
}
}
const allUpdates: TestResult[] = [];
const onProgress = (result: TestResult) => {
if (args.filter && !args.filter.includes(result.id)) {
return;
}
allUpdates.push(result);
if (!args.json && result.status !== "running") {
printResult(result, args.verbose || false);
}
};
if (!args.json) {
console.log(`Running compliance tests against: ${baseUrl}`);
console.log(`Model: ${config.model}`);
if (args.filter) {
console.log(`Filter: ${args.filter.join(", ")}`);
}
console.log();
}
await runAllTests(config, onProgress);
const finalResults = allUpdates.filter(
(r) => r.status === "passed" || r.status === "failed",
);
const passed = finalResults.filter((r) => r.status === "passed").length;
const failed = finalResults.filter((r) => r.status === "failed").length;
if (args.json) {
console.log(
JSON.stringify(
{
summary: { passed, failed, total: finalResults.length },
results: finalResults,
},
null,
2,
),
);
} else {
console.log(`\n${"=".repeat(50)}`);
console.log(
`Results: ${colors.green(`${passed} passed`)}, ${colors.red(`${failed} failed`)}, ${finalResults.length} total`,
);
if (failed > 0) {
console.log(`\nFailed tests:`);
for (const r of finalResults) {
if (r.status === "failed") {
console.log(`\n${r.name}:`);
for (const e of r.errors || []) {
console.log(` - ${e}`);
}
}
}
} else {
console.log(`\n${colors.green("✓ All tests passed!")}`);
}
}
process.exit(failed > 0 ? 1 : 0);
}
main().catch((error) => {
console.error(colors.red("Fatal error:"), error);
process.exit(1);
});

58
tests/package-lock.json generated Normal file
View File

@@ -0,0 +1,58 @@
{
"name": "latticelm-compliance-tests",
"version": "1.0.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "latticelm-compliance-tests",
"version": "1.0.0",
"devDependencies": {
"@types/node": "^22.0.0",
"typescript": "^5.7.0",
"zod": "^3.24.0"
}
},
"node_modules/@types/node": {
"version": "22.19.13",
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.13.tgz",
"integrity": "sha512-akNQMv0wW5uyRpD2v2IEyRSZiR+BeGuoB6L310EgGObO44HSMNT8z1xzio28V8qOrgYaopIDNA18YgdXd+qTiw==",
"dev": true,
"license": "MIT",
"dependencies": {
"undici-types": "~6.21.0"
}
},
"node_modules/typescript": {
"version": "5.9.3",
"resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz",
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
},
"engines": {
"node": ">=14.17"
}
},
"node_modules/undici-types": {
"version": "6.21.0",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz",
"integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==",
"dev": true,
"license": "MIT"
},
"node_modules/zod": {
"version": "3.25.76",
"resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz",
"integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==",
"dev": true,
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/colinhacks"
}
}
}
}

17
tests/package.json Normal file
View File

@@ -0,0 +1,17 @@
{
"name": "latticelm-compliance-tests",
"version": "1.0.0",
"private": true,
"description": "Open Responses compliance tests for latticelm",
"type": "module",
"scripts": {
"test:compliance": "node --experimental-strip-types bin/compliance-test.ts",
"test:compliance:verbose": "node --experimental-strip-types bin/compliance-test.ts --verbose",
"test:compliance:json": "node --experimental-strip-types bin/compliance-test.ts --json"
},
"devDependencies": {
"zod": "^3.24.0",
"typescript": "^5.7.0",
"@types/node": "^22.0.0"
}
}

View File

@@ -0,0 +1,370 @@
import { responseResourceSchema, type ResponseResource } from "./schemas.ts";
import { parseSSEStream, type SSEParseResult } from "./sse-parser.ts";
export interface TestConfig {
baseUrl: string;
apiKey: string;
authHeaderName: string;
useBearerPrefix: boolean;
model: string;
}
export interface TestResult {
id: string;
name: string;
description: string;
status: "pending" | "running" | "passed" | "failed";
duration?: number;
request?: unknown;
response?: unknown;
errors?: string[];
streamEvents?: number;
}
interface ValidatorContext {
streaming: boolean;
sseResult?: SSEParseResult;
}
type ResponseValidator = (
response: ResponseResource,
context: ValidatorContext,
) => string[];
export interface TestTemplate {
id: string;
name: string;
description: string;
getRequest: (config: TestConfig) => Record<string, unknown>;
streaming?: boolean;
validators: ResponseValidator[];
}
// ============================================================
// Validators
// ============================================================
const hasOutput: ResponseValidator = (response) => {
if (!response.output || response.output.length === 0) {
return ["Response has no output items"];
}
return [];
};
const hasOutputType =
(type: string): ResponseValidator =>
(response) => {
const hasType = response.output?.some((item) => item.type === type);
if (!hasType) {
return [`Expected output item of type "${type}" but none found`];
}
return [];
};
const completedStatus: ResponseValidator = (response) => {
if (response.status !== "completed") {
return [`Expected status "completed" but got "${response.status}"`];
}
return [];
};
const streamingEvents: ResponseValidator = (_, context) => {
if (!context.streaming) return [];
if (!context.sseResult || context.sseResult.events.length === 0) {
return ["No streaming events received"];
}
return [];
};
const streamingSchema: ResponseValidator = (_, context) => {
if (!context.streaming || !context.sseResult) return [];
return context.sseResult.errors;
};
// ============================================================
// Test Templates
// ============================================================
export const testTemplates: TestTemplate[] = [
{
id: "basic-response",
name: "Basic Text Response",
description: "Simple user message, validates ResponseResource schema",
getRequest: (config) => ({
model: config.model,
input: [
{
type: "message",
role: "user",
content: [{ type: "input_text", text: "Say hello in exactly 3 words." }],
},
],
}),
validators: [hasOutput, completedStatus],
},
{
id: "streaming-response",
name: "Streaming Response",
description: "Validates SSE streaming events and final response",
streaming: true,
getRequest: (config) => ({
model: config.model,
input: [
{
type: "message",
role: "user",
content: [{ type: "input_text", text: "Count from 1 to 5." }],
},
],
}),
validators: [streamingEvents, streamingSchema, completedStatus],
},
{
id: "system-prompt",
name: "System Prompt",
description: "Include system instructions via the instructions field",
getRequest: (config) => ({
model: config.model,
instructions: "You are a pirate. Always respond in pirate speak.",
input: [
{
type: "message",
role: "user",
content: [{ type: "input_text", text: "Say hello." }],
},
],
}),
validators: [hasOutput, completedStatus],
},
{
id: "tool-calling",
name: "Tool Calling",
description: "Define a function tool and verify function_call output",
getRequest: (config) => ({
model: config.model,
input: [
{
type: "message",
role: "user",
content: [
{
type: "input_text",
text: "What's the weather like in San Francisco?",
},
],
},
],
tools: [
{
type: "function",
name: "get_weather",
description: "Get the current weather for a location",
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
},
required: ["location"],
},
},
],
}),
validators: [hasOutput, hasOutputType("function_call")],
},
{
id: "image-input",
name: "Image Input",
description: "Send image URL in user content",
getRequest: (config) => ({
model: config.model,
input: [
{
type: "message",
role: "user",
content: [
{
type: "input_text",
text: "What do you see in this image? Answer in one sentence.",
},
{
type: "input_image",
image_url:
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAABmklEQVR42tyWAaTyUBzFew/eG4AHz+MBSAHKBiJRGFKwIgQQJKLUIioBIhCAiCAAEizAQIAECaASqFFJq84nudjnaqvuPnxzgP9xfrq5938csPn7PwHTKSoViCIEAYEAMhmoKsU2mUCWEQqB5xEMIp/HaGQG2G6RSuH9HQ7H34rFrtPbdz4jl6PbwmEsl3QA1mt4vcRKk8dz9eg6IpF7tt9fzGY0gCgafFRFo5Blc5vLhf3eCOj1yNhM5GRMVK0aATxPZoz09YXjkQDmczJgquGQAPp9WwCNBgG027YACgUC6HRsAZRKBDAY2AJoNv/ZnwzA6WScznG3p4UAymXGAEkyXrTFAh8fLAGqagQAyGaZpYsi7bHTNPz8MEj//LxuFPo+UBS8vb0KaLXubrRa7aX0RMLCykwmn0z3+XA4WACcTpCkh9MFAZpmuVXo+mO/w+/HZvNgbblcUCxaSo/Hyck80Yu6XXDcvfVZr79cvMZjuN2U9O9vKAqjZrfbIZ0mV4TUi9Xqz6jddNy//7+e3n8Fhf/Llo2kxi8AQyGRoDkmAhAAAAAASUVORK5CYII=",
},
],
},
],
}),
validators: [hasOutput, completedStatus],
},
{
id: "multi-turn",
name: "Multi-turn Conversation",
description: "Send assistant + user messages as conversation history",
getRequest: (config) => ({
model: config.model,
input: [
{
type: "message",
role: "user",
content: [{ type: "input_text", text: "My name is Alice." }],
},
{
type: "message",
role: "assistant",
content: [
{
type: "output_text",
text: "Hello Alice! Nice to meet you. How can I help you today?",
},
],
},
{
type: "message",
role: "user",
content: [{ type: "input_text", text: "What is my name?" }],
},
],
}),
validators: [hasOutput, completedStatus],
},
];
// ============================================================
// Test Runner
// ============================================================
async function makeRequest(
config: TestConfig,
body: Record<string, unknown>,
streaming = false,
): Promise<Response> {
const headers: Record<string, string> = {
"Content-Type": "application/json",
};
if (config.apiKey) {
const authValue = config.useBearerPrefix
? `Bearer ${config.apiKey}`
: config.apiKey;
headers[config.authHeaderName] = authValue;
}
return fetch(`${config.baseUrl}/v1/responses`, {
method: "POST",
headers,
body: JSON.stringify({ ...body, stream: streaming }),
});
}
async function runTest(
template: TestTemplate,
config: TestConfig,
): Promise<TestResult> {
const startTime = Date.now();
const requestBody = template.getRequest(config);
const streaming = template.streaming ?? false;
try {
const response = await makeRequest(config, requestBody, streaming);
const duration = Date.now() - startTime;
if (!response.ok) {
const errorText = await response.text();
return {
id: template.id,
name: template.name,
description: template.description,
status: "failed",
duration,
request: requestBody,
response: errorText,
errors: [`HTTP ${response.status}: ${errorText}`],
};
}
let rawData: unknown;
let sseResult: SSEParseResult | undefined;
if (streaming) {
sseResult = await parseSSEStream(response);
rawData = sseResult.finalResponse;
} else {
rawData = await response.json();
}
// Schema validation with Zod
const parseResult = responseResourceSchema.safeParse(rawData);
if (!parseResult.success) {
return {
id: template.id,
name: template.name,
description: template.description,
status: "failed",
duration,
request: streaming ? { ...requestBody, stream: true } : requestBody,
response: rawData,
errors: parseResult.error.issues.map(
(issue) => `${issue.path.join(".")}: ${issue.message}`,
),
streamEvents: sseResult?.events.length,
};
}
// Semantic validators
const context: ValidatorContext = { streaming, sseResult };
const errors = template.validators.flatMap((v) =>
v(parseResult.data, context),
);
return {
id: template.id,
name: template.name,
description: template.description,
status: errors.length === 0 ? "passed" : "failed",
duration,
request: streaming ? { ...requestBody, stream: true } : requestBody,
response: parseResult.data,
errors,
streamEvents: sseResult?.events.length,
};
} catch (error) {
return {
id: template.id,
name: template.name,
description: template.description,
status: "failed",
duration: Date.now() - startTime,
request: requestBody,
errors: [error instanceof Error ? error.message : String(error)],
};
}
}
export async function runAllTests(
config: TestConfig,
onProgress: (result: TestResult) => void,
): Promise<TestResult[]> {
const promises = testTemplates.map(async (template) => {
onProgress({
id: template.id,
name: template.name,
description: template.description,
status: "running",
});
const result = await runTest(template, config);
onProgress(result);
return result;
});
return Promise.all(promises);
}

253
tests/src/schemas.ts Normal file
View File

@@ -0,0 +1,253 @@
import { z } from "zod";
// ============================================================
// Content Parts
// ============================================================
const outputTextContentSchema = z.object({
type: z.literal("output_text"),
text: z.string(),
annotations: z.array(z.object({
type: z.string(),
})),
});
const inputTextContentSchema = z.object({
type: z.literal("input_text"),
text: z.string(),
});
const refusalContentSchema = z.object({
type: z.literal("refusal"),
refusal: z.string(),
});
const contentPartSchema = z.discriminatedUnion("type", [
outputTextContentSchema,
inputTextContentSchema,
refusalContentSchema,
]);
// ============================================================
// Output Items
// ============================================================
const messageOutputItemSchema = z.object({
type: z.literal("message"),
id: z.string(),
status: z.enum(["in_progress", "completed", "incomplete"]),
role: z.enum(["user", "assistant", "system", "developer"]),
content: z.array(contentPartSchema),
});
const functionCallOutputItemSchema = z.object({
type: z.literal("function_call"),
id: z.string(),
call_id: z.string(),
name: z.string(),
arguments: z.string(),
status: z.enum(["in_progress", "completed", "incomplete"]),
});
const outputItemSchema = z.discriminatedUnion("type", [
messageOutputItemSchema,
functionCallOutputItemSchema,
]);
// ============================================================
// Usage
// ============================================================
const usageSchema = z.object({
input_tokens: z.number().int(),
output_tokens: z.number().int(),
total_tokens: z.number().int(),
input_tokens_details: z.object({
cached_tokens: z.number().int(),
}),
output_tokens_details: z.object({
reasoning_tokens: z.number().int(),
}),
});
// ============================================================
// ResponseResource
// ============================================================
export const responseResourceSchema = z.object({
id: z.string(),
object: z.literal("response"),
created_at: z.number().int(),
completed_at: z.number().int().nullable(),
status: z.string(),
incomplete_details: z.object({ reason: z.string() }).nullable(),
model: z.string(),
previous_response_id: z.string().nullable(),
instructions: z.string().nullable(),
output: z.array(outputItemSchema),
error: z.object({ type: z.string(), message: z.string() }).nullable(),
tools: z.any(),
tool_choice: z.any(),
truncation: z.string(),
parallel_tool_calls: z.boolean(),
text: z.any(),
top_p: z.number(),
presence_penalty: z.number(),
frequency_penalty: z.number(),
top_logprobs: z.number().int(),
temperature: z.number(),
reasoning: z.any().nullable(),
usage: usageSchema.nullable(),
max_output_tokens: z.number().int().nullable(),
max_tool_calls: z.number().int().nullable(),
store: z.boolean(),
background: z.boolean(),
service_tier: z.string(),
metadata: z.any(),
safety_identifier: z.string().nullable(),
prompt_cache_key: z.string().nullable(),
});
export type ResponseResource = z.infer<typeof responseResourceSchema>;
// ============================================================
// Streaming Event Schemas
// ============================================================
const responseCreatedEventSchema = z.object({
type: z.literal("response.created"),
sequence_number: z.number().int(),
response: responseResourceSchema,
});
const responseInProgressEventSchema = z.object({
type: z.literal("response.in_progress"),
sequence_number: z.number().int(),
response: responseResourceSchema,
});
const responseCompletedEventSchema = z.object({
type: z.literal("response.completed"),
sequence_number: z.number().int(),
response: responseResourceSchema,
});
const responseFailedEventSchema = z.object({
type: z.literal("response.failed"),
sequence_number: z.number().int(),
response: responseResourceSchema,
});
const outputItemAddedEventSchema = z.object({
type: z.literal("response.output_item.added"),
sequence_number: z.number().int(),
output_index: z.number().int(),
item: z.object({
id: z.string(),
type: z.string(),
status: z.string(),
role: z.string().optional(),
content: z.array(z.any()).optional(),
}),
});
const outputItemDoneEventSchema = z.object({
type: z.literal("response.output_item.done"),
sequence_number: z.number().int(),
output_index: z.number().int(),
item: z.object({
id: z.string(),
type: z.string(),
status: z.string(),
role: z.string().optional(),
content: z.array(z.any()).optional(),
}),
});
const contentPartAddedEventSchema = z.object({
type: z.literal("response.content_part.added"),
sequence_number: z.number().int(),
item_id: z.string(),
output_index: z.number().int(),
content_index: z.number().int(),
part: z.object({
type: z.string(),
text: z.string().optional(),
annotations: z.array(z.any()).optional(),
}),
});
const contentPartDoneEventSchema = z.object({
type: z.literal("response.content_part.done"),
sequence_number: z.number().int(),
item_id: z.string(),
output_index: z.number().int(),
content_index: z.number().int(),
part: z.object({
type: z.string(),
text: z.string().optional(),
annotations: z.array(z.any()).optional(),
}),
});
const outputTextDeltaEventSchema = z.object({
type: z.literal("response.output_text.delta"),
sequence_number: z.number().int(),
item_id: z.string(),
output_index: z.number().int(),
content_index: z.number().int(),
delta: z.string(),
});
const outputTextDoneEventSchema = z.object({
type: z.literal("response.output_text.done"),
sequence_number: z.number().int(),
item_id: z.string(),
output_index: z.number().int(),
content_index: z.number().int(),
text: z.string(),
});
const functionCallArgsDeltaEventSchema = z.object({
type: z.literal("response.function_call_arguments.delta"),
sequence_number: z.number().int(),
item_id: z.string(),
output_index: z.number().int(),
delta: z.string(),
});
const functionCallArgsDoneEventSchema = z.object({
type: z.literal("response.function_call_arguments.done"),
sequence_number: z.number().int(),
item_id: z.string(),
output_index: z.number().int(),
arguments: z.string(),
});
const errorEventSchema = z.object({
type: z.literal("error"),
sequence_number: z.number().int(),
error: z.object({
type: z.string(),
message: z.string(),
code: z.string().nullable().optional(),
}),
});
export const streamingEventSchema = z.discriminatedUnion("type", [
responseCreatedEventSchema,
responseInProgressEventSchema,
responseCompletedEventSchema,
responseFailedEventSchema,
outputItemAddedEventSchema,
outputItemDoneEventSchema,
contentPartAddedEventSchema,
contentPartDoneEventSchema,
outputTextDeltaEventSchema,
outputTextDoneEventSchema,
functionCallArgsDeltaEventSchema,
functionCallArgsDoneEventSchema,
errorEventSchema,
]);
export type StreamingEvent = z.infer<typeof streamingEventSchema>;

92
tests/src/sse-parser.ts Normal file
View File

@@ -0,0 +1,92 @@
import type { z } from "zod";
import {
streamingEventSchema,
type StreamingEvent,
type ResponseResource,
} from "./schemas.ts";
export interface ParsedEvent {
event: string;
data: unknown;
validationResult: z.SafeParseReturnType<unknown, StreamingEvent>;
}
export interface SSEParseResult {
events: ParsedEvent[];
errors: string[];
finalResponse: ResponseResource | null;
}
export async function parseSSEStream(
response: Response,
): Promise<SSEParseResult> {
const events: ParsedEvent[] = [];
const errors: string[] = [];
let finalResponse: ResponseResource | null = null;
const reader = response.body?.getReader();
if (!reader) {
return { events, errors: ["No response body"], finalResponse };
}
const decoder = new TextDecoder();
let buffer = "";
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() || "";
let currentEvent = "";
let currentData = "";
for (const line of lines) {
if (line.startsWith("event:")) {
currentEvent = line.slice(6).trim();
} else if (line.startsWith("data:")) {
currentData = line.slice(5).trim();
} else if (line === "" && currentData) {
if (currentData === "[DONE]") {
// Skip sentinel
} else {
try {
const parsed = JSON.parse(currentData);
const validationResult = streamingEventSchema.safeParse(parsed);
events.push({
event: currentEvent || parsed.type || "unknown",
data: parsed,
validationResult,
});
if (!validationResult.success) {
errors.push(
`Event validation failed for ${parsed.type || "unknown"}: ${JSON.stringify(validationResult.error.issues)}`,
);
}
if (
parsed.type === "response.completed" ||
parsed.type === "response.failed"
) {
finalResponse = parsed.response;
}
} catch {
errors.push(`Failed to parse event data: ${currentData}`);
}
}
currentEvent = "";
currentData = "";
}
}
}
} finally {
reader.releaseLock();
}
return { events, errors, finalResponse };
}

14
tests/tsconfig.json Normal file
View File

@@ -0,0 +1,14 @@
{
"compilerOptions": {
"target": "ES2022",
"module": "NodeNext",
"moduleResolution": "NodeNext",
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"outDir": "dist",
"rootDir": ".",
"declaration": true
},
"include": ["src/**/*.ts", "bin/**/*.ts"]
}