Add Redis Store

This commit is contained in:
2026-03-02 15:28:03 +00:00
parent 09d687b45b
commit 259d02d140
7 changed files with 167 additions and 23 deletions

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"database/sql" "database/sql"
"flag" "flag"
"fmt" "fmt"
@@ -12,6 +13,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/jackc/pgx/v5/stdlib" _ "github.com/jackc/pgx/v5/stdlib"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/redis/go-redis/v9"
"github.com/ajac-zero/latticelm/internal/auth" "github.com/ajac-zero/latticelm/internal/auth"
"github.com/ajac-zero/latticelm/internal/config" "github.com/ajac-zero/latticelm/internal/config"
@@ -112,6 +114,22 @@ func initConversationStore(cfg config.ConversationConfig, logger *log.Logger) (c
} }
logger.Printf("Conversation store initialized (sql/%s, TTL: %s)", driver, ttl) logger.Printf("Conversation store initialized (sql/%s, TTL: %s)", driver, ttl)
return store, nil return store, nil
case "redis":
opts, err := redis.ParseURL(cfg.DSN)
if err != nil {
return nil, fmt.Errorf("parse redis dsn: %w", err)
}
client := redis.NewClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("connect to redis: %w", err)
}
logger.Printf("Conversation store initialized (redis, TTL: %s)", ttl)
return conversation.NewRedisStore(client, ttl), nil
default: default:
logger.Printf("Conversation store initialized (memory, TTL: %s)", ttl) logger.Printf("Conversation store initialized (memory, TTL: %s)", ttl)
return conversation.NewMemoryStore(ttl), nil return conversation.NewMemoryStore(ttl), nil

View File

@@ -27,16 +27,19 @@ providers:
# endpoint: "https://your-resource.services.ai.azure.com/anthropic" # endpoint: "https://your-resource.services.ai.azure.com/anthropic"
# conversations: # conversations:
# store: "sql" # "memory" (default) or "sql" # store: "sql" # "memory" (default), "sql", or "redis"
# ttl: "1h" # conversation expiration (default: 1h) # ttl: "1h" # conversation expiration (default: 1h)
# driver: "sqlite3" # SQL driver: "sqlite3", "mysql", "pgx" # driver: "sqlite3" # SQL driver: "sqlite3", "mysql", "pgx" (required for sql store)
# dsn: "conversations.db" # connection string # dsn: "conversations.db" # connection string (required for sql/redis store)
# # MySQL example: # # MySQL example:
# # driver: "mysql" # # driver: "mysql"
# # dsn: "user:password@tcp(localhost:3306)/dbname?parseTime=true" # # dsn: "user:password@tcp(localhost:3306)/dbname?parseTime=true"
# # PostgreSQL example: # # PostgreSQL example:
# # driver: "pgx" # # driver: "pgx"
# # dsn: "postgres://user:password@localhost:5432/dbname?sslmode=disable" # # dsn: "postgres://user:password@localhost:5432/dbname?sslmode=disable"
# # Redis example:
# # store: "redis"
# # dsn: "redis://:password@localhost:6379/0"
models: models:
- name: "gemini-1.5-flash" - name: "gemini-1.5-flash"

4
go.mod
View File

@@ -10,6 +10,7 @@ require (
github.com/jackc/pgx/v5 v5.8.0 github.com/jackc/pgx/v5 v5.8.0
github.com/mattn/go-sqlite3 v1.14.34 github.com/mattn/go-sqlite3 v1.14.34
github.com/openai/openai-go/v3 v3.2.0 github.com/openai/openai-go/v3 v3.2.0
github.com/redis/go-redis/v9 v9.18.0
google.golang.org/genai v1.48.0 google.golang.org/genai v1.48.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
@@ -21,6 +22,8 @@ require (
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.6.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect github.com/google/s2a-go v0.1.8 // indirect
@@ -34,6 +37,7 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect github.com/tidwall/sjson v1.2.5 // indirect
go.opencensus.io v0.24.0 // indirect go.opencensus.io v0.24.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/crypto v0.47.0 // indirect golang.org/x/crypto v0.47.0 // indirect
golang.org/x/net v0.49.0 // indirect golang.org/x/net v0.49.0 // indirect
golang.org/x/sync v0.19.0 // indirect golang.org/x/sync v0.19.0 // indirect

16
go.sum
View File

@@ -18,12 +18,20 @@ github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY= github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q= github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -73,6 +81,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -88,6 +98,8 @@ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjL
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -110,8 +122,12 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=

View File

@@ -18,12 +18,12 @@ type Config struct {
// ConversationConfig controls conversation storage. // ConversationConfig controls conversation storage.
type ConversationConfig struct { type ConversationConfig struct {
// Store is the storage backend: "memory" (default) or "sql". // Store is the storage backend: "memory" (default), "sql", or "redis".
Store string `yaml:"store"` Store string `yaml:"store"`
// TTL is the conversation expiration duration (e.g. "1h", "30m"). Defaults to "1h". // TTL is the conversation expiration duration (e.g. "1h", "30m"). Defaults to "1h".
TTL string `yaml:"ttl"` TTL string `yaml:"ttl"`
// DSN is the database connection string, required when store is "sql". // DSN is the database/Redis connection string, required when store is "sql" or "redis".
// Examples: "conversations.db" (SQLite), "postgres://user:pass@host/db". // Examples: "conversations.db" (SQLite), "postgres://user:pass@host/db", "redis://:password@localhost:6379/0".
DSN string `yaml:"dsn"` DSN string `yaml:"dsn"`
// Driver is the SQL driver name, required when store is "sql". // Driver is the SQL driver name, required when store is "sql".
// Examples: "sqlite3", "postgres", "mysql". // Examples: "sqlite3", "postgres", "mysql".

View File

@@ -0,0 +1,106 @@
package conversation
import (
"context"
"encoding/json"
"time"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/redis/go-redis/v9"
)
// RedisStore manages conversation history in Redis with automatic expiration.
type RedisStore struct {
client *redis.Client
ttl time.Duration
ctx context.Context
}
// NewRedisStore creates a Redis-backed conversation store.
func NewRedisStore(client *redis.Client, ttl time.Duration) *RedisStore {
return &RedisStore{
client: client,
ttl: ttl,
ctx: context.Background(),
}
}
// key returns the Redis key for a conversation ID.
func (s *RedisStore) key(id string) string {
return "conv:" + id
}
// Get retrieves a conversation by ID from Redis.
func (s *RedisStore) Get(id string) (*Conversation, bool) {
data, err := s.client.Get(s.ctx, s.key(id)).Bytes()
if err != nil {
return nil, false
}
var conv Conversation
if err := json.Unmarshal(data, &conv); err != nil {
return nil, false
}
return &conv, true
}
// Create creates a new conversation with the given messages.
func (s *RedisStore) Create(id string, model string, messages []api.Message) *Conversation {
now := time.Now()
conv := &Conversation{
ID: id,
Messages: messages,
Model: model,
CreatedAt: now,
UpdatedAt: now,
}
data, _ := json.Marshal(conv)
_ = s.client.Set(s.ctx, s.key(id), data, s.ttl).Err()
return conv
}
// Append adds new messages to an existing conversation.
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, bool) {
conv, ok := s.Get(id)
if !ok {
return nil, false
}
conv.Messages = append(conv.Messages, messages...)
conv.UpdatedAt = time.Now()
data, _ := json.Marshal(conv)
_ = s.client.Set(s.ctx, s.key(id), data, s.ttl).Err()
return conv, true
}
// Delete removes a conversation from Redis.
func (s *RedisStore) Delete(id string) {
_ = s.client.Del(s.ctx, s.key(id)).Err()
}
// Size returns the number of active conversations in Redis.
func (s *RedisStore) Size() int {
var count int
var cursor uint64
for {
keys, nextCursor, err := s.client.Scan(s.ctx, cursor, "conv:*", 100).Result()
if err != nil {
return 0
}
count += len(keys)
cursor = nextCursor
if cursor == 0 {
break
}
}
return count
}

View File

@@ -5,12 +5,12 @@ import (
"fmt" "fmt"
"github.com/ajac-zero/latticelm/internal/api" "github.com/ajac-zero/latticelm/internal/api"
"github.com/openai/openai-go" "github.com/openai/openai-go/v3"
"github.com/openai/openai-go/shared" "github.com/openai/openai-go/v3/shared"
) )
// parseTools converts Open Responses tools to OpenAI format // parseTools converts Open Responses tools to OpenAI format
func parseTools(req *api.ResponseRequest) ([]openai.ChatCompletionToolParam, error) { func parseTools(req *api.ResponseRequest) ([]openai.ChatCompletionToolUnionParam, error) {
if req.Tools == nil || len(req.Tools) == 0 { if req.Tools == nil || len(req.Tools) == 0 {
return nil, nil return nil, nil
} }
@@ -20,29 +20,27 @@ func parseTools(req *api.ResponseRequest) ([]openai.ChatCompletionToolParam, err
return nil, fmt.Errorf("unmarshal tools: %w", err) return nil, fmt.Errorf("unmarshal tools: %w", err)
} }
var tools []openai.ChatCompletionToolParam var tools []openai.ChatCompletionToolUnionParam
for _, td := range toolDefs { for _, td := range toolDefs {
// Convert Open Responses tool to OpenAI ChatCompletionToolParam // Convert Open Responses tool to OpenAI function tool
// Extract: name, description, parameters // Extract: name, description, parameters
name, _ := td["name"].(string) name, _ := td["name"].(string)
desc, _ := td["description"].(string) desc, _ := td["description"].(string)
params, _ := td["parameters"].(map[string]interface{}) params, _ := td["parameters"].(map[string]interface{})
tool := openai.ChatCompletionToolParam{ funcDef := shared.FunctionDefinitionParam{
Function: shared.FunctionDefinitionParam{ Name: name,
Name: name,
},
} }
if desc != "" { if desc != "" {
tool.Function.Description = openai.String(desc) funcDef.Description = openai.String(desc)
} }
if params != nil { if params != nil {
tool.Function.Parameters = shared.FunctionParameters(params) funcDef.Parameters = shared.FunctionParameters(params)
} }
tools = append(tools, tool) tools = append(tools, openai.ChatCompletionFunctionTool(funcDef))
} }
return tools, nil return tools, nil
@@ -67,17 +65,16 @@ func parseToolChoice(req *api.ResponseRequest) (openai.ChatCompletionToolChoiceO
return result, nil return result, nil
} }
// Handle specific function selection: {"type": "function", "name": "..."} // Handle specific function selection: {"type": "function", "function": {"name": "..."}}
if obj, ok := choice.(map[string]interface{}); ok { if obj, ok := choice.(map[string]interface{}); ok {
funcObj, _ := obj["function"].(map[string]interface{}) funcObj, _ := obj["function"].(map[string]interface{})
name, _ := funcObj["name"].(string) name, _ := funcObj["name"].(string)
result.OfChatCompletionNamedToolChoice = &openai.ChatCompletionNamedToolChoiceParam{ return openai.ToolChoiceOptionFunctionToolChoice(
Function: openai.ChatCompletionNamedToolChoiceFunctionParam{ openai.ChatCompletionNamedToolChoiceFunctionParam{
Name: name, Name: name,
}, },
} ), nil
return result, nil
} }
return result, fmt.Errorf("invalid tool_choice format") return result, fmt.Errorf("invalid tool_choice format")