6272 lines
281 KiB
HTML
6272 lines
281 KiB
HTML
|
|
<!DOCTYPE html>
|
|
<html>
|
|
<head>
|
|
<meta http-equiv="Content-Type" content="text/html; charset=utf-8">
|
|
<title>gateway: Go Coverage Report</title>
|
|
<style>
|
|
body {
|
|
background: black;
|
|
color: rgb(80, 80, 80);
|
|
}
|
|
body, pre, #legend span {
|
|
font-family: Menlo, monospace;
|
|
font-weight: bold;
|
|
}
|
|
#topbar {
|
|
background: black;
|
|
position: fixed;
|
|
top: 0; left: 0; right: 0;
|
|
height: 42px;
|
|
border-bottom: 1px solid rgb(80, 80, 80);
|
|
}
|
|
#content {
|
|
margin-top: 50px;
|
|
}
|
|
#nav, #legend {
|
|
float: left;
|
|
margin-left: 10px;
|
|
}
|
|
#legend {
|
|
margin-top: 12px;
|
|
}
|
|
#nav {
|
|
margin-top: 10px;
|
|
}
|
|
#legend span {
|
|
margin: 0 5px;
|
|
}
|
|
.cov0 { color: rgb(192, 0, 0) }
|
|
.cov1 { color: rgb(128, 128, 128) }
|
|
.cov2 { color: rgb(116, 140, 131) }
|
|
.cov3 { color: rgb(104, 152, 134) }
|
|
.cov4 { color: rgb(92, 164, 137) }
|
|
.cov5 { color: rgb(80, 176, 140) }
|
|
.cov6 { color: rgb(68, 188, 143) }
|
|
.cov7 { color: rgb(56, 200, 146) }
|
|
.cov8 { color: rgb(44, 212, 149) }
|
|
.cov9 { color: rgb(32, 224, 152) }
|
|
.cov10 { color: rgb(20, 236, 155) }
|
|
|
|
</style>
|
|
</head>
|
|
<body>
|
|
<div id="topbar">
|
|
<div id="nav">
|
|
<select id="files">
|
|
|
|
<option value="file0">github.com/ajac-zero/latticelm/cmd/gateway/main.go (0.0%)</option>
|
|
|
|
<option value="file1">github.com/ajac-zero/latticelm/internal/api/types.go (100.0%)</option>
|
|
|
|
<option value="file2">github.com/ajac-zero/latticelm/internal/auth/auth.go (91.7%)</option>
|
|
|
|
<option value="file3">github.com/ajac-zero/latticelm/internal/config/config.go (100.0%)</option>
|
|
|
|
<option value="file4">github.com/ajac-zero/latticelm/internal/conversation/conversation.go (82.0%)</option>
|
|
|
|
<option value="file5">github.com/ajac-zero/latticelm/internal/conversation/redis_store.go (82.6%)</option>
|
|
|
|
<option value="file6">github.com/ajac-zero/latticelm/internal/conversation/sql_store.go (80.0%)</option>
|
|
|
|
<option value="file7">github.com/ajac-zero/latticelm/internal/conversation/testing.go (23.6%)</option>
|
|
|
|
<option value="file8">github.com/ajac-zero/latticelm/internal/logger/logger.go (0.0%)</option>
|
|
|
|
<option value="file9">github.com/ajac-zero/latticelm/internal/observability/init.go (0.0%)</option>
|
|
|
|
<option value="file10">github.com/ajac-zero/latticelm/internal/observability/metrics.go (100.0%)</option>
|
|
|
|
<option value="file11">github.com/ajac-zero/latticelm/internal/observability/metrics_middleware.go (0.0%)</option>
|
|
|
|
<option value="file12">github.com/ajac-zero/latticelm/internal/observability/provider_wrapper.go (88.2%)</option>
|
|
|
|
<option value="file13">github.com/ajac-zero/latticelm/internal/observability/store_wrapper.go (0.0%)</option>
|
|
|
|
<option value="file14">github.com/ajac-zero/latticelm/internal/observability/testing.go (22.2%)</option>
|
|
|
|
<option value="file15">github.com/ajac-zero/latticelm/internal/observability/tracing.go (36.7%)</option>
|
|
|
|
<option value="file16">github.com/ajac-zero/latticelm/internal/observability/tracing_middleware.go (0.0%)</option>
|
|
|
|
<option value="file17">github.com/ajac-zero/latticelm/internal/providers/anthropic/anthropic.go (0.0%)</option>
|
|
|
|
<option value="file18">github.com/ajac-zero/latticelm/internal/providers/anthropic/convert.go (63.0%)</option>
|
|
|
|
<option value="file19">github.com/ajac-zero/latticelm/internal/providers/circuitbreaker.go (14.3%)</option>
|
|
|
|
<option value="file20">github.com/ajac-zero/latticelm/internal/providers/google/convert.go (79.3%)</option>
|
|
|
|
<option value="file21">github.com/ajac-zero/latticelm/internal/providers/google/google.go (0.0%)</option>
|
|
|
|
<option value="file22">github.com/ajac-zero/latticelm/internal/providers/openai/convert.go (71.4%)</option>
|
|
|
|
<option value="file23">github.com/ajac-zero/latticelm/internal/providers/openai/openai.go (0.0%)</option>
|
|
|
|
<option value="file24">github.com/ajac-zero/latticelm/internal/providers/providers.go (98.0%)</option>
|
|
|
|
<option value="file25">github.com/ajac-zero/latticelm/internal/ratelimit/ratelimit.go (87.2%)</option>
|
|
|
|
<option value="file26">github.com/ajac-zero/latticelm/internal/server/health.go (89.2%)</option>
|
|
|
|
<option value="file27">github.com/ajac-zero/latticelm/internal/server/middleware.go (83.3%)</option>
|
|
|
|
<option value="file28">github.com/ajac-zero/latticelm/internal/server/server.go (91.6%)</option>
|
|
|
|
</select>
|
|
</div>
|
|
<div id="legend">
|
|
<span>not tracked</span>
|
|
|
|
<span class="cov0">not covered</span>
|
|
<span class="cov8">covered</span>
|
|
|
|
</div>
|
|
</div>
|
|
<div id="content">
|
|
|
|
<pre class="file" id="file0" style="display: none">package main
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"syscall"
|
|
"time"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
"github.com/google/uuid"
|
|
_ "github.com/jackc/pgx/v5/stdlib"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"github.com/redis/go-redis/v9"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/auth"
|
|
"github.com/ajac-zero/latticelm/internal/config"
|
|
"github.com/ajac-zero/latticelm/internal/conversation"
|
|
slogger "github.com/ajac-zero/latticelm/internal/logger"
|
|
"github.com/ajac-zero/latticelm/internal/observability"
|
|
"github.com/ajac-zero/latticelm/internal/providers"
|
|
"github.com/ajac-zero/latticelm/internal/ratelimit"
|
|
"github.com/ajac-zero/latticelm/internal/server"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
"go.opentelemetry.io/otel"
|
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
|
)
|
|
|
|
func main() <span class="cov0" title="0">{
|
|
var configPath string
|
|
flag.StringVar(&configPath, "config", "config.yaml", "path to config file")
|
|
flag.Parse()
|
|
|
|
cfg, err := config.Load(configPath)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
log.Fatalf("load config: %v", err)
|
|
}</span>
|
|
|
|
// Initialize logger from config
|
|
<span class="cov0" title="0">logFormat := cfg.Logging.Format
|
|
if logFormat == "" </span><span class="cov0" title="0">{
|
|
logFormat = "json"
|
|
}</span>
|
|
<span class="cov0" title="0">logLevel := cfg.Logging.Level
|
|
if logLevel == "" </span><span class="cov0" title="0">{
|
|
logLevel = "info"
|
|
}</span>
|
|
<span class="cov0" title="0">logger := slogger.New(logFormat, logLevel)
|
|
|
|
// Initialize tracing
|
|
var tracerProvider *sdktrace.TracerProvider
|
|
if cfg.Observability.Enabled && cfg.Observability.Tracing.Enabled </span><span class="cov0" title="0">{
|
|
// Set defaults
|
|
tracingCfg := cfg.Observability.Tracing
|
|
if tracingCfg.ServiceName == "" </span><span class="cov0" title="0">{
|
|
tracingCfg.ServiceName = "llm-gateway"
|
|
}</span>
|
|
<span class="cov0" title="0">if tracingCfg.Sampler.Type == "" </span><span class="cov0" title="0">{
|
|
tracingCfg.Sampler.Type = "probability"
|
|
tracingCfg.Sampler.Rate = 0.1
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">tp, err := observability.InitTracer(tracingCfg)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
logger.Error("failed to initialize tracing", slog.String("error", err.Error()))
|
|
}</span> else<span class="cov0" title="0"> {
|
|
tracerProvider = tp
|
|
otel.SetTracerProvider(tracerProvider)
|
|
logger.Info("tracing initialized",
|
|
slog.String("exporter", tracingCfg.Exporter.Type),
|
|
slog.String("sampler", tracingCfg.Sampler.Type),
|
|
)
|
|
}</span>
|
|
}
|
|
|
|
// Initialize metrics
|
|
<span class="cov0" title="0">var metricsRegistry *prometheus.Registry
|
|
if cfg.Observability.Enabled && cfg.Observability.Metrics.Enabled </span><span class="cov0" title="0">{
|
|
metricsRegistry = observability.InitMetrics()
|
|
metricsPath := cfg.Observability.Metrics.Path
|
|
if metricsPath == "" </span><span class="cov0" title="0">{
|
|
metricsPath = "/metrics"
|
|
}</span>
|
|
<span class="cov0" title="0">logger.Info("metrics initialized", slog.String("path", metricsPath))</span>
|
|
}
|
|
|
|
// Create provider registry with circuit breaker support
|
|
<span class="cov0" title="0">var baseRegistry *providers.Registry
|
|
if cfg.Observability.Enabled && cfg.Observability.Metrics.Enabled </span><span class="cov0" title="0">{
|
|
// Pass observability callback for circuit breaker state changes
|
|
baseRegistry, err = providers.NewRegistryWithCircuitBreaker(
|
|
cfg.Providers,
|
|
cfg.Models,
|
|
observability.RecordCircuitBreakerStateChange,
|
|
)
|
|
}</span> else<span class="cov0" title="0"> {
|
|
// No observability, use default registry
|
|
baseRegistry, err = providers.NewRegistry(cfg.Providers, cfg.Models)
|
|
}</span>
|
|
<span class="cov0" title="0">if err != nil </span><span class="cov0" title="0">{
|
|
logger.Error("failed to initialize providers", slog.String("error", err.Error()))
|
|
os.Exit(1)
|
|
}</span>
|
|
|
|
// Wrap providers with observability
|
|
<span class="cov0" title="0">var registry server.ProviderRegistry = baseRegistry
|
|
if cfg.Observability.Enabled </span><span class="cov0" title="0">{
|
|
registry = observability.WrapProviderRegistry(registry, metricsRegistry, tracerProvider)
|
|
logger.Info("providers instrumented")
|
|
}</span>
|
|
|
|
// Initialize authentication middleware
|
|
<span class="cov0" title="0">authConfig := auth.Config{
|
|
Enabled: cfg.Auth.Enabled,
|
|
Issuer: cfg.Auth.Issuer,
|
|
Audience: cfg.Auth.Audience,
|
|
}
|
|
authMiddleware, err := auth.New(authConfig, logger)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
logger.Error("failed to initialize auth", slog.String("error", err.Error()))
|
|
os.Exit(1)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">if cfg.Auth.Enabled </span><span class="cov0" title="0">{
|
|
logger.Info("authentication enabled", slog.String("issuer", cfg.Auth.Issuer))
|
|
}</span> else<span class="cov0" title="0"> {
|
|
logger.Warn("authentication disabled - API is publicly accessible")
|
|
}</span>
|
|
|
|
// Initialize conversation store
|
|
<span class="cov0" title="0">convStore, storeBackend, err := initConversationStore(cfg.Conversations, logger)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
logger.Error("failed to initialize conversation store", slog.String("error", err.Error()))
|
|
os.Exit(1)
|
|
}</span>
|
|
|
|
// Wrap conversation store with observability
|
|
<span class="cov0" title="0">if cfg.Observability.Enabled && convStore != nil </span><span class="cov0" title="0">{
|
|
convStore = observability.WrapConversationStore(convStore, storeBackend, metricsRegistry, tracerProvider)
|
|
logger.Info("conversation store instrumented")
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">gatewayServer := server.New(registry, convStore, logger)
|
|
mux := http.NewServeMux()
|
|
gatewayServer.RegisterRoutes(mux)
|
|
|
|
// Register metrics endpoint if enabled
|
|
if cfg.Observability.Enabled && cfg.Observability.Metrics.Enabled </span><span class="cov0" title="0">{
|
|
metricsPath := cfg.Observability.Metrics.Path
|
|
if metricsPath == "" </span><span class="cov0" title="0">{
|
|
metricsPath = "/metrics"
|
|
}</span>
|
|
<span class="cov0" title="0">mux.Handle(metricsPath, promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{}))
|
|
logger.Info("metrics endpoint registered", slog.String("path", metricsPath))</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">addr := cfg.Server.Address
|
|
if addr == "" </span><span class="cov0" title="0">{
|
|
addr = ":8080"
|
|
}</span>
|
|
|
|
// Initialize rate limiting
|
|
<span class="cov0" title="0">rateLimitConfig := ratelimit.Config{
|
|
Enabled: cfg.RateLimit.Enabled,
|
|
RequestsPerSecond: cfg.RateLimit.RequestsPerSecond,
|
|
Burst: cfg.RateLimit.Burst,
|
|
}
|
|
// Set defaults if not configured
|
|
if rateLimitConfig.Enabled && rateLimitConfig.RequestsPerSecond == 0 </span><span class="cov0" title="0">{
|
|
rateLimitConfig.RequestsPerSecond = 10 // default 10 req/s
|
|
}</span>
|
|
<span class="cov0" title="0">if rateLimitConfig.Enabled && rateLimitConfig.Burst == 0 </span><span class="cov0" title="0">{
|
|
rateLimitConfig.Burst = 20 // default burst of 20
|
|
}</span>
|
|
<span class="cov0" title="0">rateLimitMiddleware := ratelimit.New(rateLimitConfig, logger)
|
|
|
|
if cfg.RateLimit.Enabled </span><span class="cov0" title="0">{
|
|
logger.Info("rate limiting enabled",
|
|
slog.Float64("requests_per_second", rateLimitConfig.RequestsPerSecond),
|
|
slog.Int("burst", rateLimitConfig.Burst),
|
|
)
|
|
}</span>
|
|
|
|
// Determine max request body size
|
|
<span class="cov0" title="0">maxRequestBodySize := cfg.Server.MaxRequestBodySize
|
|
if maxRequestBodySize == 0 </span><span class="cov0" title="0">{
|
|
maxRequestBodySize = server.MaxRequestBodyBytes // default: 10MB
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">logger.Info("server configuration",
|
|
slog.Int64("max_request_body_bytes", maxRequestBodySize),
|
|
)
|
|
|
|
// Build handler chain: panic recovery -> request size limit -> logging -> tracing -> metrics -> rate limiting -> auth -> routes
|
|
handler := server.PanicRecoveryMiddleware(
|
|
server.RequestSizeLimitMiddleware(
|
|
loggingMiddleware(
|
|
observability.TracingMiddleware(
|
|
observability.MetricsMiddleware(
|
|
rateLimitMiddleware.Handler(authMiddleware.Handler(mux)),
|
|
metricsRegistry,
|
|
tracerProvider,
|
|
),
|
|
tracerProvider,
|
|
),
|
|
logger,
|
|
),
|
|
maxRequestBodySize,
|
|
),
|
|
logger,
|
|
)
|
|
|
|
srv := &http.Server{
|
|
Addr: addr,
|
|
Handler: handler,
|
|
ReadTimeout: 15 * time.Second,
|
|
WriteTimeout: 60 * time.Second,
|
|
IdleTimeout: 120 * time.Second,
|
|
}
|
|
|
|
// Set up signal handling for graceful shutdown
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
|
|
|
// Run server in a goroutine
|
|
serverErrors := make(chan error, 1)
|
|
go func() </span><span class="cov0" title="0">{
|
|
logger.Info("open responses gateway listening", slog.String("address", addr))
|
|
serverErrors <- srv.ListenAndServe()
|
|
}</span>()
|
|
|
|
// Wait for shutdown signal or server error
|
|
<span class="cov0" title="0">select </span>{
|
|
case err := <-serverErrors:<span class="cov0" title="0">
|
|
if err != nil && err != http.ErrServerClosed </span><span class="cov0" title="0">{
|
|
logger.Error("server error", slog.String("error", err.Error()))
|
|
os.Exit(1)
|
|
}</span>
|
|
case sig := <-sigChan:<span class="cov0" title="0">
|
|
logger.Info("received shutdown signal", slog.String("signal", sig.String()))
|
|
|
|
// Create shutdown context with timeout
|
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer shutdownCancel()
|
|
|
|
// Shutdown the HTTP server gracefully
|
|
logger.Info("shutting down server gracefully")
|
|
if err := srv.Shutdown(shutdownCtx); err != nil </span><span class="cov0" title="0">{
|
|
logger.Error("server shutdown error", slog.String("error", err.Error()))
|
|
}</span>
|
|
|
|
// Shutdown tracer provider
|
|
<span class="cov0" title="0">if tracerProvider != nil </span><span class="cov0" title="0">{
|
|
logger.Info("shutting down tracer")
|
|
shutdownTracerCtx, shutdownTracerCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer shutdownTracerCancel()
|
|
if err := observability.Shutdown(shutdownTracerCtx, tracerProvider); err != nil </span><span class="cov0" title="0">{
|
|
logger.Error("error shutting down tracer", slog.String("error", err.Error()))
|
|
}</span>
|
|
}
|
|
|
|
// Close conversation store
|
|
<span class="cov0" title="0">logger.Info("closing conversation store")
|
|
if err := convStore.Close(); err != nil </span><span class="cov0" title="0">{
|
|
logger.Error("error closing conversation store", slog.String("error", err.Error()))
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">logger.Info("shutdown complete")</span>
|
|
}
|
|
}
|
|
|
|
func initConversationStore(cfg config.ConversationConfig, logger *slog.Logger) (conversation.Store, string, error) <span class="cov0" title="0">{
|
|
var ttl time.Duration
|
|
if cfg.TTL != "" </span><span class="cov0" title="0">{
|
|
parsed, err := time.ParseDuration(cfg.TTL)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, "", fmt.Errorf("invalid conversation ttl %q: %w", cfg.TTL, err)
|
|
}</span>
|
|
<span class="cov0" title="0">ttl = parsed</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">switch cfg.Store </span>{
|
|
case "sql":<span class="cov0" title="0">
|
|
driver := cfg.Driver
|
|
if driver == "" </span><span class="cov0" title="0">{
|
|
driver = "sqlite3"
|
|
}</span>
|
|
<span class="cov0" title="0">db, err := sql.Open(driver, cfg.DSN)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, "", fmt.Errorf("open database: %w", err)
|
|
}</span>
|
|
<span class="cov0" title="0">store, err := conversation.NewSQLStore(db, driver, ttl)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, "", fmt.Errorf("init sql store: %w", err)
|
|
}</span>
|
|
<span class="cov0" title="0">logger.Info("conversation store initialized",
|
|
slog.String("backend", "sql"),
|
|
slog.String("driver", driver),
|
|
slog.Duration("ttl", ttl),
|
|
)
|
|
return store, "sql", nil</span>
|
|
case "redis":<span class="cov0" title="0">
|
|
opts, err := redis.ParseURL(cfg.DSN)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, "", fmt.Errorf("parse redis dsn: %w", err)
|
|
}</span>
|
|
<span class="cov0" title="0">client := redis.NewClient(opts)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
if err := client.Ping(ctx).Err(); err != nil </span><span class="cov0" title="0">{
|
|
return nil, "", fmt.Errorf("connect to redis: %w", err)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">logger.Info("conversation store initialized",
|
|
slog.String("backend", "redis"),
|
|
slog.Duration("ttl", ttl),
|
|
)
|
|
return conversation.NewRedisStore(client, ttl), "redis", nil</span>
|
|
default:<span class="cov0" title="0">
|
|
logger.Info("conversation store initialized",
|
|
slog.String("backend", "memory"),
|
|
slog.Duration("ttl", ttl),
|
|
)
|
|
return conversation.NewMemoryStore(ttl), "memory", nil</span>
|
|
}
|
|
}
|
|
type responseWriter struct {
|
|
http.ResponseWriter
|
|
statusCode int
|
|
bytesWritten int
|
|
}
|
|
|
|
func (rw *responseWriter) WriteHeader(code int) <span class="cov0" title="0">{
|
|
rw.statusCode = code
|
|
rw.ResponseWriter.WriteHeader(code)
|
|
}</span>
|
|
|
|
func (rw *responseWriter) Write(b []byte) (int, error) <span class="cov0" title="0">{
|
|
n, err := rw.ResponseWriter.Write(b)
|
|
rw.bytesWritten += n
|
|
return n, err
|
|
}</span>
|
|
|
|
func loggingMiddleware(next http.Handler, logger *slog.Logger) http.Handler <span class="cov0" title="0">{
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) </span><span class="cov0" title="0">{
|
|
start := time.Now()
|
|
|
|
// Generate request ID
|
|
requestID := uuid.NewString()
|
|
ctx := slogger.WithRequestID(r.Context(), requestID)
|
|
r = r.WithContext(ctx)
|
|
|
|
// Wrap response writer to capture status code
|
|
rw := &responseWriter{
|
|
ResponseWriter: w,
|
|
statusCode: http.StatusOK,
|
|
}
|
|
|
|
// Add request ID header
|
|
w.Header().Set("X-Request-ID", requestID)
|
|
|
|
// Log request start
|
|
logger.InfoContext(ctx, "request started",
|
|
slog.String("request_id", requestID),
|
|
slog.String("method", r.Method),
|
|
slog.String("path", r.URL.Path),
|
|
slog.String("remote_addr", r.RemoteAddr),
|
|
slog.String("user_agent", r.UserAgent()),
|
|
)
|
|
|
|
next.ServeHTTP(rw, r)
|
|
|
|
duration := time.Since(start)
|
|
|
|
// Log request completion with appropriate level
|
|
logLevel := slog.LevelInfo
|
|
if rw.statusCode >= 500 </span><span class="cov0" title="0">{
|
|
logLevel = slog.LevelError
|
|
}</span> else<span class="cov0" title="0"> if rw.statusCode >= 400 </span><span class="cov0" title="0">{
|
|
logLevel = slog.LevelWarn
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">logger.Log(ctx, logLevel, "request completed",
|
|
slog.String("request_id", requestID),
|
|
slog.String("method", r.Method),
|
|
slog.String("path", r.URL.Path),
|
|
slog.Int("status_code", rw.statusCode),
|
|
slog.Int("response_bytes", rw.bytesWritten),
|
|
slog.Duration("duration", duration),
|
|
slog.Float64("duration_ms", float64(duration.Milliseconds())),
|
|
)</span>
|
|
})
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file1" style="display: none">package api
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
)
|
|
|
|
// ============================================================
|
|
// Request Types (CreateResponseBody)
|
|
// ============================================================
|
|
|
|
// ResponseRequest models the OpenResponses CreateResponseBody.
|
|
type ResponseRequest struct {
|
|
Model string `json:"model"`
|
|
Input InputUnion `json:"input"`
|
|
Instructions *string `json:"instructions,omitempty"`
|
|
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
|
Metadata map[string]string `json:"metadata,omitempty"`
|
|
Stream bool `json:"stream,omitempty"`
|
|
PreviousResponseID *string `json:"previous_response_id,omitempty"`
|
|
Temperature *float64 `json:"temperature,omitempty"`
|
|
TopP *float64 `json:"top_p,omitempty"`
|
|
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
|
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
|
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
|
Truncation *string `json:"truncation,omitempty"`
|
|
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
|
Tools json.RawMessage `json:"tools,omitempty"`
|
|
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
|
Store *bool `json:"store,omitempty"`
|
|
Text json.RawMessage `json:"text,omitempty"`
|
|
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
|
Include []string `json:"include,omitempty"`
|
|
ServiceTier *string `json:"service_tier,omitempty"`
|
|
Background *bool `json:"background,omitempty"`
|
|
StreamOptions json.RawMessage `json:"stream_options,omitempty"`
|
|
MaxToolCalls *int `json:"max_tool_calls,omitempty"`
|
|
|
|
// Non-spec extension: allows client to select a specific provider.
|
|
Provider string `json:"provider,omitempty"`
|
|
}
|
|
|
|
// InputUnion handles the polymorphic "input" field: string or []InputItem.
|
|
type InputUnion struct {
|
|
String *string
|
|
Items []InputItem
|
|
}
|
|
|
|
func (u *InputUnion) UnmarshalJSON(data []byte) error <span class="cov8" title="1">{
|
|
if string(data) == "null" </span><span class="cov8" title="1">{
|
|
return nil
|
|
}</span>
|
|
<span class="cov8" title="1">var s string
|
|
if err := json.Unmarshal(data, &s); err == nil </span><span class="cov8" title="1">{
|
|
u.String = &s
|
|
return nil
|
|
}</span>
|
|
<span class="cov8" title="1">var items []InputItem
|
|
if err := json.Unmarshal(data, &items); err == nil </span><span class="cov8" title="1">{
|
|
u.Items = items
|
|
return nil
|
|
}</span>
|
|
<span class="cov8" title="1">return fmt.Errorf("input must be a string or array of items")</span>
|
|
}
|
|
|
|
func (u InputUnion) MarshalJSON() ([]byte, error) <span class="cov8" title="1">{
|
|
if u.String != nil </span><span class="cov8" title="1">{
|
|
return json.Marshal(*u.String)
|
|
}</span>
|
|
<span class="cov8" title="1">if u.Items != nil </span><span class="cov8" title="1">{
|
|
return json.Marshal(u.Items)
|
|
}</span>
|
|
<span class="cov8" title="1">return []byte("null"), nil</span>
|
|
}
|
|
|
|
// InputItem is a discriminated union on "type".
|
|
// Valid types: message, item_reference, function_call, function_call_output, reasoning.
|
|
type InputItem struct {
|
|
Type string `json:"type"`
|
|
Role string `json:"role,omitempty"`
|
|
Content json.RawMessage `json:"content,omitempty"`
|
|
ID string `json:"id,omitempty"`
|
|
CallID string `json:"call_id,omitempty"`
|
|
Name string `json:"name,omitempty"`
|
|
Arguments string `json:"arguments,omitempty"`
|
|
Output string `json:"output,omitempty"`
|
|
Status string `json:"status,omitempty"`
|
|
}
|
|
|
|
// ============================================================
|
|
// Internal Types (providers + conversation store)
|
|
// ============================================================
|
|
|
|
// Message is the normalized internal message representation.
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content []ContentBlock `json:"content"`
|
|
CallID string `json:"call_id,omitempty"` // for tool messages
|
|
Name string `json:"name,omitempty"` // for tool messages
|
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // for assistant messages
|
|
}
|
|
|
|
// ContentBlock is a typed content element.
|
|
type ContentBlock struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text,omitempty"`
|
|
}
|
|
|
|
// NormalizeInput converts the request Input into messages for providers.
|
|
// Does NOT include instructions (the server prepends those separately).
|
|
func (r *ResponseRequest) NormalizeInput() []Message <span class="cov8" title="1">{
|
|
if r.Input.String != nil </span><span class="cov8" title="1">{
|
|
return []Message{{
|
|
Role: "user",
|
|
Content: []ContentBlock{{Type: "input_text", Text: *r.Input.String}},
|
|
}}
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var msgs []Message
|
|
for _, item := range r.Input.Items </span><span class="cov8" title="1">{
|
|
switch item.Type </span>{
|
|
case "message", "":<span class="cov8" title="1">
|
|
msg := Message{Role: item.Role}
|
|
if item.Content != nil </span><span class="cov8" title="1">{
|
|
var s string
|
|
if err := json.Unmarshal(item.Content, &s); err == nil </span><span class="cov8" title="1">{
|
|
contentType := "input_text"
|
|
if item.Role == "assistant" </span><span class="cov8" title="1">{
|
|
contentType = "output_text"
|
|
}</span>
|
|
<span class="cov8" title="1">msg.Content = []ContentBlock{{Type: contentType, Text: s}}</span>
|
|
} else<span class="cov8" title="1"> {
|
|
// Content is an array of blocks - parse them
|
|
var rawBlocks []map[string]interface{}
|
|
if err := json.Unmarshal(item.Content, &rawBlocks); err == nil </span><span class="cov8" title="1">{
|
|
// Extract content blocks and tool calls
|
|
for _, block := range rawBlocks </span><span class="cov8" title="1">{
|
|
blockType, _ := block["type"].(string)
|
|
|
|
if blockType == "tool_use" </span><span class="cov8" title="1">{
|
|
// Extract tool call information
|
|
toolCall := ToolCall{
|
|
ID: getStringField(block, "id"),
|
|
Name: getStringField(block, "name"),
|
|
}
|
|
// input field contains the arguments as a map
|
|
if input, ok := block["input"].(map[string]interface{}); ok </span><span class="cov8" title="1">{
|
|
if inputJSON, err := json.Marshal(input); err == nil </span><span class="cov8" title="1">{
|
|
toolCall.Arguments = string(inputJSON)
|
|
}</span>
|
|
}
|
|
<span class="cov8" title="1">msg.ToolCalls = append(msg.ToolCalls, toolCall)</span>
|
|
} else<span class="cov8" title="1"> if blockType == "output_text" || blockType == "input_text" </span><span class="cov8" title="1">{
|
|
// Regular text content block
|
|
msg.Content = append(msg.Content, ContentBlock{
|
|
Type: blockType,
|
|
Text: getStringField(block, "text"),
|
|
})
|
|
}</span>
|
|
}
|
|
}
|
|
}
|
|
}
|
|
<span class="cov8" title="1">msgs = append(msgs, msg)</span>
|
|
case "function_call_output":<span class="cov8" title="1">
|
|
msgs = append(msgs, Message{
|
|
Role: "tool",
|
|
Content: []ContentBlock{{Type: "input_text", Text: item.Output}},
|
|
CallID: item.CallID,
|
|
Name: item.Name,
|
|
})</span>
|
|
}
|
|
}
|
|
<span class="cov8" title="1">return msgs</span>
|
|
}
|
|
|
|
// ============================================================
|
|
// Response Types (ResponseResource)
|
|
// ============================================================
|
|
|
|
// Response is the spec-compliant ResponseResource.
|
|
type Response struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
CreatedAt int64 `json:"created_at"`
|
|
CompletedAt *int64 `json:"completed_at"`
|
|
Status string `json:"status"`
|
|
IncompleteDetails *IncompleteDetails `json:"incomplete_details"`
|
|
Model string `json:"model"`
|
|
PreviousResponseID *string `json:"previous_response_id"`
|
|
Instructions *string `json:"instructions"`
|
|
Output []OutputItem `json:"output"`
|
|
Error *ResponseError `json:"error"`
|
|
Tools json.RawMessage `json:"tools"`
|
|
ToolChoice json.RawMessage `json:"tool_choice"`
|
|
Truncation string `json:"truncation"`
|
|
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
|
Text json.RawMessage `json:"text"`
|
|
TopP float64 `json:"top_p"`
|
|
PresencePenalty float64 `json:"presence_penalty"`
|
|
FrequencyPenalty float64 `json:"frequency_penalty"`
|
|
TopLogprobs int `json:"top_logprobs"`
|
|
Temperature float64 `json:"temperature"`
|
|
Reasoning json.RawMessage `json:"reasoning"`
|
|
Usage *Usage `json:"usage"`
|
|
MaxOutputTokens *int `json:"max_output_tokens"`
|
|
MaxToolCalls *int `json:"max_tool_calls"`
|
|
Store bool `json:"store"`
|
|
Background bool `json:"background"`
|
|
ServiceTier string `json:"service_tier"`
|
|
Metadata map[string]string `json:"metadata"`
|
|
SafetyIdentifier *string `json:"safety_identifier"`
|
|
PromptCacheKey *string `json:"prompt_cache_key"`
|
|
|
|
// Non-spec extension
|
|
Provider string `json:"provider,omitempty"`
|
|
}
|
|
|
|
// OutputItem represents a typed item in the response output.
|
|
type OutputItem struct {
|
|
ID string `json:"id"`
|
|
Type string `json:"type"`
|
|
Status string `json:"status"`
|
|
Role string `json:"role,omitempty"`
|
|
Content []ContentPart `json:"content,omitempty"`
|
|
CallID string `json:"call_id,omitempty"` // for function_call
|
|
Name string `json:"name,omitempty"` // for function_call
|
|
Arguments string `json:"arguments,omitempty"` // for function_call
|
|
}
|
|
|
|
// ContentPart is a content block within an output item.
|
|
type ContentPart struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text"`
|
|
Annotations []Annotation `json:"annotations"`
|
|
}
|
|
|
|
// Annotation on output text content.
|
|
type Annotation struct {
|
|
Type string `json:"type"`
|
|
}
|
|
|
|
// IncompleteDetails explains why a response is incomplete.
|
|
type IncompleteDetails struct {
|
|
Reason string `json:"reason"`
|
|
}
|
|
|
|
// ResponseError describes an error in the response.
|
|
type ResponseError struct {
|
|
Type string `json:"type"`
|
|
Message string `json:"message"`
|
|
Code *string `json:"code"`
|
|
}
|
|
|
|
// ============================================================
|
|
// Usage Types
|
|
// ============================================================
|
|
|
|
// Usage captures token accounting with sub-details.
|
|
type Usage struct {
|
|
InputTokens int `json:"input_tokens"`
|
|
OutputTokens int `json:"output_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
InputTokensDetails InputTokensDetails `json:"input_tokens_details"`
|
|
OutputTokensDetails OutputTokensDetails `json:"output_tokens_details"`
|
|
}
|
|
|
|
// InputTokensDetails breaks down input token usage.
|
|
type InputTokensDetails struct {
|
|
CachedTokens int `json:"cached_tokens"`
|
|
}
|
|
|
|
// OutputTokensDetails breaks down output token usage.
|
|
type OutputTokensDetails struct {
|
|
ReasoningTokens int `json:"reasoning_tokens"`
|
|
}
|
|
|
|
// ============================================================
|
|
// Streaming Types
|
|
// ============================================================
|
|
|
|
// StreamEvent represents a single SSE event in the streaming response.
|
|
// Fields are selectively populated based on the event Type.
|
|
type StreamEvent struct {
|
|
Type string `json:"type"`
|
|
SequenceNumber int `json:"sequence_number"`
|
|
Response *Response `json:"response,omitempty"`
|
|
OutputIndex *int `json:"output_index,omitempty"`
|
|
Item *OutputItem `json:"item,omitempty"`
|
|
ItemID string `json:"item_id,omitempty"`
|
|
ContentIndex *int `json:"content_index,omitempty"`
|
|
Part *ContentPart `json:"part,omitempty"`
|
|
Delta string `json:"delta,omitempty"`
|
|
Text string `json:"text,omitempty"`
|
|
Arguments string `json:"arguments,omitempty"` // for function_call_arguments.done
|
|
}
|
|
|
|
// ============================================================
|
|
// Provider Result Types (internal, not exposed via HTTP)
|
|
// ============================================================
|
|
|
|
// ProviderResult is returned by Provider.Generate.
|
|
type ProviderResult struct {
|
|
ID string
|
|
Model string
|
|
Text string
|
|
Usage Usage
|
|
ToolCalls []ToolCall
|
|
}
|
|
|
|
// ProviderStreamDelta is sent through the stream channel.
|
|
type ProviderStreamDelta struct {
|
|
ID string
|
|
Model string
|
|
Text string
|
|
Done bool
|
|
Usage *Usage
|
|
ToolCallDelta *ToolCallDelta
|
|
}
|
|
|
|
// ToolCall represents a function call from the model.
|
|
type ToolCall struct {
|
|
ID string
|
|
Name string
|
|
Arguments string // JSON string
|
|
}
|
|
|
|
// ToolCallDelta represents a streaming chunk of a tool call.
|
|
type ToolCallDelta struct {
|
|
Index int
|
|
ID string
|
|
Name string
|
|
Arguments string
|
|
}
|
|
|
|
// ============================================================
|
|
// Models Endpoint Types
|
|
// ============================================================
|
|
|
|
// ModelInfo describes a single model available through the gateway.
|
|
type ModelInfo struct {
|
|
ID string `json:"id"`
|
|
Provider string `json:"provider"`
|
|
}
|
|
|
|
// ModelsResponse is returned by GET /v1/models.
|
|
type ModelsResponse struct {
|
|
Object string `json:"object"`
|
|
Data []ModelInfo `json:"data"`
|
|
}
|
|
|
|
// ============================================================
|
|
// Validation
|
|
// ============================================================
|
|
|
|
// Validate performs basic structural validation.
|
|
func (r *ResponseRequest) Validate() error <span class="cov8" title="1">{
|
|
if r == nil </span><span class="cov8" title="1">{
|
|
return errors.New("request is nil")
|
|
}</span>
|
|
<span class="cov8" title="1">if r.Model == "" </span><span class="cov8" title="1">{
|
|
return errors.New("model is required")
|
|
}</span>
|
|
<span class="cov8" title="1">if r.Input.String == nil && len(r.Input.Items) == 0 </span><span class="cov8" title="1">{
|
|
return errors.New("input is required")
|
|
}</span>
|
|
<span class="cov8" title="1">return nil</span>
|
|
}
|
|
|
|
// getStringField is a helper to safely extract string fields from a map
|
|
func getStringField(m map[string]interface{}, key string) string <span class="cov8" title="1">{
|
|
if val, ok := m[key].(string); ok </span><span class="cov8" title="1">{
|
|
return val
|
|
}</span>
|
|
<span class="cov8" title="1">return ""</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file2" style="display: none">package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// Config holds OIDC authentication configuration.
|
|
type Config struct {
|
|
Enabled bool `yaml:"enabled"`
|
|
Issuer string `yaml:"issuer"` // e.g., "https://accounts.google.com"
|
|
Audience string `yaml:"audience"` // e.g., your client ID
|
|
}
|
|
|
|
// Middleware provides JWT validation middleware.
|
|
type Middleware struct {
|
|
cfg Config
|
|
keys map[string]*rsa.PublicKey
|
|
mu sync.RWMutex
|
|
client *http.Client
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// New creates an authentication middleware.
|
|
func New(cfg Config, logger *slog.Logger) (*Middleware, error) <span class="cov8" title="1">{
|
|
if !cfg.Enabled </span><span class="cov8" title="1">{
|
|
return &Middleware{cfg: cfg, logger: logger}, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">if cfg.Issuer == "" </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("auth enabled but issuer not configured")
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">m := &Middleware{
|
|
cfg: cfg,
|
|
keys: make(map[string]*rsa.PublicKey),
|
|
client: &http.Client{Timeout: 10 * time.Second},
|
|
logger: logger,
|
|
}
|
|
|
|
// Fetch JWKS on startup
|
|
if err := m.refreshJWKS(); err != nil </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
|
|
}</span>
|
|
|
|
// Refresh JWKS periodically
|
|
<span class="cov8" title="1">go m.periodicRefresh()
|
|
|
|
return m, nil</span>
|
|
}
|
|
|
|
// Handler wraps an HTTP handler with authentication.
|
|
func (m *Middleware) Handler(next http.Handler) http.Handler <span class="cov8" title="1">{
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) </span><span class="cov8" title="1">{
|
|
if !m.cfg.Enabled </span><span class="cov8" title="1">{
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}</span>
|
|
|
|
// Extract token from Authorization header
|
|
<span class="cov8" title="1">authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" </span><span class="cov8" title="1">{
|
|
http.Error(w, "missing authorization header", http.StatusUnauthorized)
|
|
return
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">parts := strings.SplitN(authHeader, " ", 2)
|
|
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" </span><span class="cov8" title="1">{
|
|
http.Error(w, "invalid authorization header format", http.StatusUnauthorized)
|
|
return
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">tokenString := parts[1]
|
|
|
|
// Validate token
|
|
claims, err := m.validateToken(tokenString)
|
|
if err != nil </span><span class="cov8" title="1">{
|
|
http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized)
|
|
return
|
|
}</span>
|
|
|
|
// Add claims to context
|
|
<span class="cov8" title="1">ctx := context.WithValue(r.Context(), claimsKey, claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))</span>
|
|
})
|
|
}
|
|
|
|
type contextKey string
|
|
|
|
const claimsKey contextKey = "jwt_claims"
|
|
|
|
// GetClaims extracts JWT claims from request context.
|
|
func GetClaims(ctx context.Context) (jwt.MapClaims, bool) <span class="cov8" title="1">{
|
|
claims, ok := ctx.Value(claimsKey).(jwt.MapClaims)
|
|
return claims, ok
|
|
}</span>
|
|
|
|
func (m *Middleware) validateToken(tokenString string) (jwt.MapClaims, error) <span class="cov8" title="1">{
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) </span><span class="cov8" title="1">{
|
|
// Verify signing method
|
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}</span>
|
|
|
|
// Get key ID from token header
|
|
<span class="cov8" title="1">kid, ok := token.Header["kid"].(string)
|
|
if !ok </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("missing kid in token header")
|
|
}</span>
|
|
|
|
// Get public key
|
|
<span class="cov8" title="1">m.mu.RLock()
|
|
key, exists := m.keys[kid]
|
|
m.mu.RUnlock()
|
|
|
|
if !exists </span><span class="cov8" title="1">{
|
|
// Try refreshing JWKS
|
|
if err := m.refreshJWKS(); err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("failed to refresh JWKS: %w", err)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">m.mu.RLock()
|
|
key, exists = m.keys[kid]
|
|
m.mu.RUnlock()
|
|
|
|
if !exists </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("unknown key ID: %s", kid)
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">return key, nil</span>
|
|
})
|
|
|
|
<span class="cov8" title="1">if err != nil </span><span class="cov8" title="1">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok || !token.Valid </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("invalid token claims")
|
|
}</span>
|
|
|
|
// Validate issuer
|
|
<span class="cov8" title="1">if iss, ok := claims["iss"].(string); !ok || iss != m.cfg.Issuer </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("invalid issuer: %s", iss)
|
|
}</span>
|
|
|
|
// Validate audience if configured
|
|
<span class="cov8" title="1">if m.cfg.Audience != "" </span><span class="cov8" title="1">{
|
|
aud, ok := claims["aud"].(string)
|
|
if !ok </span><span class="cov8" title="1">{
|
|
// aud might be an array
|
|
audArray, ok := claims["aud"].([]interface{})
|
|
if !ok </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("invalid audience format")
|
|
}</span>
|
|
<span class="cov8" title="1">found := false
|
|
for _, a := range audArray </span><span class="cov8" title="1">{
|
|
if audStr, ok := a.(string); ok && audStr == m.cfg.Audience </span><span class="cov8" title="1">{
|
|
found = true
|
|
break</span>
|
|
}
|
|
}
|
|
<span class="cov8" title="1">if !found </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("audience not matched")
|
|
}</span>
|
|
} else<span class="cov8" title="1"> if aud != m.cfg.Audience </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("invalid audience: %s", aud)
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">return claims, nil</span>
|
|
}
|
|
|
|
func (m *Middleware) refreshJWKS() error <span class="cov8" title="1">{
|
|
jwksURL := strings.TrimSuffix(m.cfg.Issuer, "/") + "/.well-known/openid-configuration"
|
|
|
|
// Fetch OIDC discovery document
|
|
resp, err := m.client.Get(jwksURL)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return err
|
|
}</span>
|
|
<span class="cov8" title="1">defer resp.Body.Close()
|
|
|
|
var oidcConfig struct {
|
|
JwksURI string `json:"jwks_uri"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&oidcConfig); err != nil </span><span class="cov8" title="1">{
|
|
return err
|
|
}</span>
|
|
|
|
// Fetch JWKS
|
|
<span class="cov8" title="1">resp, err = m.client.Get(oidcConfig.JwksURI)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return err
|
|
}</span>
|
|
<span class="cov8" title="1">defer resp.Body.Close()
|
|
|
|
var jwks struct {
|
|
Keys []struct {
|
|
Kid string `json:"kid"`
|
|
Kty string `json:"kty"`
|
|
Use string `json:"use"`
|
|
N string `json:"n"`
|
|
E string `json:"e"`
|
|
} `json:"keys"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil </span><span class="cov0" title="0">{
|
|
return err
|
|
}</span>
|
|
|
|
// Parse keys
|
|
<span class="cov8" title="1">newKeys := make(map[string]*rsa.PublicKey)
|
|
for _, key := range jwks.Keys </span><span class="cov8" title="1">{
|
|
if key.Kty != "RSA" || key.Use != "sig" </span><span class="cov8" title="1">{
|
|
continue</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
|
if err != nil </span><span class="cov8" title="1">{
|
|
continue</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
continue</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">pubKey := &rsa.PublicKey{
|
|
N: new(big.Int).SetBytes(nBytes),
|
|
E: int(new(big.Int).SetBytes(eBytes).Int64()),
|
|
}
|
|
|
|
newKeys[key.Kid] = pubKey</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">m.mu.Lock()
|
|
m.keys = newKeys
|
|
m.mu.Unlock()
|
|
|
|
return nil</span>
|
|
}
|
|
|
|
func (m *Middleware) periodicRefresh() <span class="cov8" title="1">{
|
|
ticker := time.NewTicker(1 * time.Hour)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C </span><span class="cov0" title="0">{
|
|
if err := m.refreshJWKS(); err != nil </span><span class="cov0" title="0">{
|
|
m.logger.Error("failed to refresh JWKS",
|
|
slog.String("issuer", m.cfg.Issuer),
|
|
slog.String("error", err.Error()),
|
|
)
|
|
}</span> else<span class="cov0" title="0"> {
|
|
m.logger.Debug("successfully refreshed JWKS",
|
|
slog.String("issuer", m.cfg.Issuer),
|
|
)
|
|
}</span>
|
|
}
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file3" style="display: none">package config
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
// Config describes the full gateway configuration file.
|
|
type Config struct {
|
|
Server ServerConfig `yaml:"server"`
|
|
Providers map[string]ProviderEntry `yaml:"providers"`
|
|
Models []ModelEntry `yaml:"models"`
|
|
Auth AuthConfig `yaml:"auth"`
|
|
Conversations ConversationConfig `yaml:"conversations"`
|
|
Logging LoggingConfig `yaml:"logging"`
|
|
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
|
Observability ObservabilityConfig `yaml:"observability"`
|
|
}
|
|
|
|
// ConversationConfig controls conversation storage.
|
|
type ConversationConfig struct {
|
|
// Store is the storage backend: "memory" (default), "sql", or "redis".
|
|
Store string `yaml:"store"`
|
|
// TTL is the conversation expiration duration (e.g. "1h", "30m"). Defaults to "1h".
|
|
TTL string `yaml:"ttl"`
|
|
// DSN is the database/Redis connection string, required when store is "sql" or "redis".
|
|
// Examples: "conversations.db" (SQLite), "postgres://user:pass@host/db", "redis://:password@localhost:6379/0".
|
|
DSN string `yaml:"dsn"`
|
|
// Driver is the SQL driver name, required when store is "sql".
|
|
// Examples: "sqlite3", "postgres", "mysql".
|
|
Driver string `yaml:"driver"`
|
|
}
|
|
|
|
// LoggingConfig controls logging format and level.
|
|
type LoggingConfig struct {
|
|
// Format is the log output format: "json" (default) or "text".
|
|
Format string `yaml:"format"`
|
|
// Level is the minimum log level: "debug", "info" (default), "warn", or "error".
|
|
Level string `yaml:"level"`
|
|
}
|
|
|
|
// RateLimitConfig controls rate limiting behavior.
|
|
type RateLimitConfig struct {
|
|
// Enabled controls whether rate limiting is active.
|
|
Enabled bool `yaml:"enabled"`
|
|
// RequestsPerSecond is the number of requests allowed per second per IP.
|
|
RequestsPerSecond float64 `yaml:"requests_per_second"`
|
|
// Burst is the maximum burst size allowed.
|
|
Burst int `yaml:"burst"`
|
|
}
|
|
|
|
// ObservabilityConfig controls observability features.
|
|
type ObservabilityConfig struct {
|
|
Enabled bool `yaml:"enabled"`
|
|
Metrics MetricsConfig `yaml:"metrics"`
|
|
Tracing TracingConfig `yaml:"tracing"`
|
|
}
|
|
|
|
// MetricsConfig controls Prometheus metrics.
|
|
type MetricsConfig struct {
|
|
Enabled bool `yaml:"enabled"`
|
|
Path string `yaml:"path"` // default: "/metrics"
|
|
}
|
|
|
|
// TracingConfig controls OpenTelemetry tracing.
|
|
type TracingConfig struct {
|
|
Enabled bool `yaml:"enabled"`
|
|
ServiceName string `yaml:"service_name"` // default: "llm-gateway"
|
|
Sampler SamplerConfig `yaml:"sampler"`
|
|
Exporter ExporterConfig `yaml:"exporter"`
|
|
}
|
|
|
|
// SamplerConfig controls trace sampling.
|
|
type SamplerConfig struct {
|
|
Type string `yaml:"type"` // "always", "never", "probability"
|
|
Rate float64 `yaml:"rate"` // 0.0 to 1.0
|
|
}
|
|
|
|
// ExporterConfig controls trace exporters.
|
|
type ExporterConfig struct {
|
|
Type string `yaml:"type"` // "otlp", "stdout"
|
|
Endpoint string `yaml:"endpoint"`
|
|
Insecure bool `yaml:"insecure"`
|
|
Headers map[string]string `yaml:"headers"`
|
|
}
|
|
|
|
// AuthConfig holds OIDC authentication settings.
|
|
type AuthConfig struct {
|
|
Enabled bool `yaml:"enabled"`
|
|
Issuer string `yaml:"issuer"`
|
|
Audience string `yaml:"audience"`
|
|
}
|
|
|
|
// ServerConfig controls HTTP server values.
|
|
type ServerConfig struct {
|
|
Address string `yaml:"address"`
|
|
MaxRequestBodySize int64 `yaml:"max_request_body_size"` // Maximum request body size in bytes (default: 10MB)
|
|
}
|
|
|
|
// ProviderEntry defines a named provider instance in the config file.
|
|
type ProviderEntry struct {
|
|
Type string `yaml:"type"`
|
|
APIKey string `yaml:"api_key"`
|
|
Endpoint string `yaml:"endpoint"`
|
|
APIVersion string `yaml:"api_version"`
|
|
Project string `yaml:"project"` // For Vertex AI
|
|
Location string `yaml:"location"` // For Vertex AI
|
|
}
|
|
|
|
// ModelEntry maps a model name to a provider entry.
|
|
type ModelEntry struct {
|
|
Name string `yaml:"name"`
|
|
Provider string `yaml:"provider"`
|
|
ProviderModelID string `yaml:"provider_model_id"`
|
|
}
|
|
|
|
// ProviderConfig contains shared provider configuration fields used internally by providers.
|
|
type ProviderConfig struct {
|
|
APIKey string `yaml:"api_key"`
|
|
Model string `yaml:"model"`
|
|
Endpoint string `yaml:"endpoint"`
|
|
}
|
|
|
|
// AzureOpenAIConfig contains Azure-specific settings used internally by the OpenAI provider.
|
|
type AzureOpenAIConfig struct {
|
|
APIKey string `yaml:"api_key"`
|
|
Endpoint string `yaml:"endpoint"`
|
|
APIVersion string `yaml:"api_version"`
|
|
}
|
|
|
|
// AzureAnthropicConfig contains Azure-specific settings for Anthropic used internally.
|
|
type AzureAnthropicConfig struct {
|
|
APIKey string `yaml:"api_key"`
|
|
Endpoint string `yaml:"endpoint"`
|
|
Model string `yaml:"model"`
|
|
}
|
|
|
|
// VertexAIConfig contains Vertex AI-specific settings used internally by the Google provider.
|
|
type VertexAIConfig struct {
|
|
Project string `yaml:"project"`
|
|
Location string `yaml:"location"`
|
|
}
|
|
|
|
// Load reads and parses a YAML configuration file, expanding ${VAR} env references.
|
|
func Load(path string) (*Config, error) <span class="cov8" title="1">{
|
|
data, err := os.ReadFile(path)
|
|
if err != nil </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("read config: %w", err)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">expanded := os.Expand(string(data), os.Getenv)
|
|
|
|
var cfg Config
|
|
if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("parse config: %w", err)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">if err := cfg.validate(); err != nil </span><span class="cov8" title="1">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return &cfg, nil</span>
|
|
}
|
|
|
|
func (cfg *Config) validate() error <span class="cov8" title="1">{
|
|
for _, m := range cfg.Models </span><span class="cov8" title="1">{
|
|
if _, ok := cfg.Providers[m.Provider]; !ok </span><span class="cov8" title="1">{
|
|
return fmt.Errorf("model %q references unknown provider %q", m.Name, m.Provider)
|
|
}</span>
|
|
}
|
|
<span class="cov8" title="1">return nil</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file4" style="display: none">package conversation
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
)
|
|
|
|
// Store defines the interface for conversation storage backends.
|
|
type Store interface {
|
|
Get(ctx context.Context, id string) (*Conversation, error)
|
|
Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error)
|
|
Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error)
|
|
Delete(ctx context.Context, id string) error
|
|
Size() int
|
|
Close() error
|
|
}
|
|
|
|
// MemoryStore manages conversation history in-memory with automatic expiration.
|
|
type MemoryStore struct {
|
|
conversations map[string]*Conversation
|
|
mu sync.RWMutex
|
|
ttl time.Duration
|
|
done chan struct{}
|
|
}
|
|
|
|
// Conversation holds the message history for a single conversation thread.
|
|
type Conversation struct {
|
|
ID string
|
|
Messages []api.Message
|
|
Model string
|
|
CreatedAt time.Time
|
|
UpdatedAt time.Time
|
|
}
|
|
|
|
// NewMemoryStore creates an in-memory conversation store with the given TTL.
|
|
func NewMemoryStore(ttl time.Duration) *MemoryStore <span class="cov8" title="1">{
|
|
s := &MemoryStore{
|
|
conversations: make(map[string]*Conversation),
|
|
ttl: ttl,
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
// Start cleanup goroutine if TTL is set
|
|
if ttl > 0 </span><span class="cov8" title="1">{
|
|
go s.cleanup()
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return s</span>
|
|
}
|
|
|
|
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
|
|
func (s *MemoryStore) Get(ctx context.Context, id string) (*Conversation, error) <span class="cov8" title="1">{
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
conv, ok := s.conversations[id]
|
|
if !ok </span><span class="cov8" title="1">{
|
|
return nil, nil
|
|
}</span>
|
|
|
|
// Return a deep copy to prevent data races
|
|
<span class="cov8" title="1">msgsCopy := make([]api.Message, len(conv.Messages))
|
|
copy(msgsCopy, conv.Messages)
|
|
|
|
return &Conversation{
|
|
ID: conv.ID,
|
|
Messages: msgsCopy,
|
|
Model: conv.Model,
|
|
CreatedAt: conv.CreatedAt,
|
|
UpdatedAt: conv.UpdatedAt,
|
|
}, nil</span>
|
|
}
|
|
|
|
// Create creates a new conversation with the given messages.
|
|
func (s *MemoryStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) <span class="cov8" title="1">{
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
|
|
// Store a copy to prevent external modifications
|
|
msgsCopy := make([]api.Message, len(messages))
|
|
copy(msgsCopy, messages)
|
|
|
|
conv := &Conversation{
|
|
ID: id,
|
|
Messages: msgsCopy,
|
|
Model: model,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
s.conversations[id] = conv
|
|
|
|
// Return a copy
|
|
return &Conversation{
|
|
ID: id,
|
|
Messages: messages,
|
|
Model: model,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}, nil
|
|
}</span>
|
|
|
|
// Append adds new messages to an existing conversation.
|
|
func (s *MemoryStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) <span class="cov8" title="1">{
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
conv, ok := s.conversations[id]
|
|
if !ok </span><span class="cov8" title="1">{
|
|
return nil, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">conv.Messages = append(conv.Messages, messages...)
|
|
conv.UpdatedAt = time.Now()
|
|
|
|
// Return a deep copy
|
|
msgsCopy := make([]api.Message, len(conv.Messages))
|
|
copy(msgsCopy, conv.Messages)
|
|
|
|
return &Conversation{
|
|
ID: conv.ID,
|
|
Messages: msgsCopy,
|
|
Model: conv.Model,
|
|
CreatedAt: conv.CreatedAt,
|
|
UpdatedAt: conv.UpdatedAt,
|
|
}, nil</span>
|
|
}
|
|
|
|
// Delete removes a conversation from the store.
|
|
func (s *MemoryStore) Delete(ctx context.Context, id string) error <span class="cov8" title="1">{
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
delete(s.conversations, id)
|
|
return nil
|
|
}</span>
|
|
|
|
// cleanup periodically removes expired conversations.
|
|
func (s *MemoryStore) cleanup() <span class="cov8" title="1">{
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for </span><span class="cov8" title="1">{
|
|
select </span>{
|
|
case <-ticker.C:<span class="cov0" title="0">
|
|
s.mu.Lock()
|
|
now := time.Now()
|
|
for id, conv := range s.conversations </span><span class="cov0" title="0">{
|
|
if now.Sub(conv.UpdatedAt) > s.ttl </span><span class="cov0" title="0">{
|
|
delete(s.conversations, id)
|
|
}</span>
|
|
}
|
|
<span class="cov0" title="0">s.mu.Unlock()</span>
|
|
case <-s.done:<span class="cov0" title="0">
|
|
return</span>
|
|
}
|
|
}
|
|
}
|
|
|
|
// Size returns the number of active conversations.
|
|
func (s *MemoryStore) Size() int <span class="cov8" title="1">{
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return len(s.conversations)
|
|
}</span>
|
|
|
|
// Close stops the cleanup goroutine and releases resources.
|
|
func (s *MemoryStore) Close() error <span class="cov0" title="0">{
|
|
close(s.done)
|
|
return nil
|
|
}</span>
|
|
</pre>
|
|
|
|
<pre class="file" id="file5" style="display: none">package conversation
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"time"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// RedisStore manages conversation history in Redis with automatic expiration.
|
|
type RedisStore struct {
|
|
client *redis.Client
|
|
ttl time.Duration
|
|
}
|
|
|
|
// NewRedisStore creates a Redis-backed conversation store.
|
|
func NewRedisStore(client *redis.Client, ttl time.Duration) *RedisStore <span class="cov8" title="1">{
|
|
return &RedisStore{
|
|
client: client,
|
|
ttl: ttl,
|
|
}
|
|
}</span>
|
|
|
|
// key returns the Redis key for a conversation ID.
|
|
func (s *RedisStore) key(id string) string <span class="cov8" title="1">{
|
|
return "conv:" + id
|
|
}</span>
|
|
|
|
// Get retrieves a conversation by ID from Redis.
|
|
func (s *RedisStore) Get(ctx context.Context, id string) (*Conversation, error) <span class="cov8" title="1">{
|
|
data, err := s.client.Get(ctx, s.key(id)).Bytes()
|
|
if err == redis.Nil </span><span class="cov8" title="1">{
|
|
return nil, nil
|
|
}</span>
|
|
<span class="cov8" title="1">if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var conv Conversation
|
|
if err := json.Unmarshal(data, &conv); err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return &conv, nil</span>
|
|
}
|
|
|
|
// Create creates a new conversation with the given messages.
|
|
func (s *RedisStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) <span class="cov8" title="1">{
|
|
now := time.Now()
|
|
conv := &Conversation{
|
|
ID: id,
|
|
Messages: messages,
|
|
Model: model,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
data, err := json.Marshal(conv)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil </span><span class="cov8" title="1">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return conv, nil</span>
|
|
}
|
|
|
|
// Append adds new messages to an existing conversation.
|
|
func (s *RedisStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) <span class="cov8" title="1">{
|
|
conv, err := s.Get(ctx, id)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
<span class="cov8" title="1">if conv == nil </span><span class="cov0" title="0">{
|
|
return nil, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">conv.Messages = append(conv.Messages, messages...)
|
|
conv.UpdatedAt = time.Now()
|
|
|
|
data, err := json.Marshal(conv)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return conv, nil</span>
|
|
}
|
|
|
|
// Delete removes a conversation from Redis.
|
|
func (s *RedisStore) Delete(ctx context.Context, id string) error <span class="cov8" title="1">{
|
|
return s.client.Del(ctx, s.key(id)).Err()
|
|
}</span>
|
|
|
|
// Size returns the number of active conversations in Redis.
|
|
func (s *RedisStore) Size() int <span class="cov8" title="1">{
|
|
var count int
|
|
var cursor uint64
|
|
ctx := context.Background()
|
|
|
|
for </span><span class="cov8" title="1">{
|
|
keys, nextCursor, err := s.client.Scan(ctx, cursor, "conv:*", 100).Result()
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return 0
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">count += len(keys)
|
|
cursor = nextCursor
|
|
|
|
if cursor == 0 </span><span class="cov8" title="1">{
|
|
break</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov8" title="1">return count</span>
|
|
}
|
|
|
|
// Close closes the Redis client connection.
|
|
func (s *RedisStore) Close() error <span class="cov8" title="1">{
|
|
return s.client.Close()
|
|
}</span>
|
|
</pre>
|
|
|
|
<pre class="file" id="file6" style="display: none">package conversation
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"time"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
)
|
|
|
|
// sqlDialect holds driver-specific SQL statements.
|
|
type sqlDialect struct {
|
|
getByID string
|
|
upsert string
|
|
update string
|
|
deleteByID string
|
|
cleanup string
|
|
}
|
|
|
|
func newDialect(driver string) sqlDialect <span class="cov8" title="1">{
|
|
if driver == "pgx" || driver == "postgres" </span><span class="cov0" title="0">{
|
|
return sqlDialect{
|
|
getByID: `SELECT id, model, messages, created_at, updated_at FROM conversations WHERE id = $1`,
|
|
upsert: `INSERT INTO conversations (id, model, messages, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (id) DO UPDATE SET model = EXCLUDED.model, messages = EXCLUDED.messages, updated_at = EXCLUDED.updated_at`,
|
|
update: `UPDATE conversations SET messages = $1, updated_at = $2 WHERE id = $3`,
|
|
deleteByID: `DELETE FROM conversations WHERE id = $1`,
|
|
cleanup: `DELETE FROM conversations WHERE updated_at < $1`,
|
|
}
|
|
}</span>
|
|
<span class="cov8" title="1">return sqlDialect{
|
|
getByID: `SELECT id, model, messages, created_at, updated_at FROM conversations WHERE id = ?`,
|
|
upsert: `REPLACE INTO conversations (id, model, messages, created_at, updated_at) VALUES (?, ?, ?, ?, ?)`,
|
|
update: `UPDATE conversations SET messages = ?, updated_at = ? WHERE id = ?`,
|
|
deleteByID: `DELETE FROM conversations WHERE id = ?`,
|
|
cleanup: `DELETE FROM conversations WHERE updated_at < ?`,
|
|
}</span>
|
|
}
|
|
|
|
// SQLStore manages conversation history in a SQL database with automatic expiration.
|
|
type SQLStore struct {
|
|
db *sql.DB
|
|
ttl time.Duration
|
|
dialect sqlDialect
|
|
done chan struct{}
|
|
}
|
|
|
|
// NewSQLStore creates a SQL-backed conversation store. It creates the
|
|
// conversations table if it does not already exist and starts a background
|
|
// goroutine to remove expired rows.
|
|
func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error) <span class="cov8" title="1">{
|
|
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS conversations (
|
|
id TEXT PRIMARY KEY,
|
|
model TEXT NOT NULL,
|
|
messages TEXT NOT NULL,
|
|
created_at TIMESTAMP NOT NULL,
|
|
updated_at TIMESTAMP NOT NULL
|
|
)`)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">s := &SQLStore{
|
|
db: db,
|
|
ttl: ttl,
|
|
dialect: newDialect(driver),
|
|
done: make(chan struct{}),
|
|
}
|
|
if ttl > 0 </span><span class="cov8" title="1">{
|
|
go s.cleanup()
|
|
}</span>
|
|
<span class="cov8" title="1">return s, nil</span>
|
|
}
|
|
|
|
func (s *SQLStore) Get(ctx context.Context, id string) (*Conversation, error) <span class="cov8" title="1">{
|
|
row := s.db.QueryRowContext(ctx, s.dialect.getByID, id)
|
|
|
|
var conv Conversation
|
|
var msgJSON string
|
|
err := row.Scan(&conv.ID, &conv.Model, &msgJSON, &conv.CreatedAt, &conv.UpdatedAt)
|
|
if err == sql.ErrNoRows </span><span class="cov8" title="1">{
|
|
return nil, nil
|
|
}</span>
|
|
<span class="cov8" title="1">if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">if err := json.Unmarshal([]byte(msgJSON), &conv.Messages); err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return &conv, nil</span>
|
|
}
|
|
|
|
func (s *SQLStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) <span class="cov8" title="1">{
|
|
now := time.Now()
|
|
msgJSON, err := json.Marshal(messages)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">if _, err := s.db.ExecContext(ctx, s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil </span><span class="cov8" title="1">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return &Conversation{
|
|
ID: id,
|
|
Messages: messages,
|
|
Model: model,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}, nil</span>
|
|
}
|
|
|
|
func (s *SQLStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) <span class="cov8" title="1">{
|
|
conv, err := s.Get(ctx, id)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
<span class="cov8" title="1">if conv == nil </span><span class="cov0" title="0">{
|
|
return nil, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">conv.Messages = append(conv.Messages, messages...)
|
|
conv.UpdatedAt = time.Now()
|
|
|
|
msgJSON, err := json.Marshal(conv.Messages)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">if _, err := s.db.ExecContext(ctx, s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return conv, nil</span>
|
|
}
|
|
|
|
func (s *SQLStore) Delete(ctx context.Context, id string) error <span class="cov8" title="1">{
|
|
_, err := s.db.ExecContext(ctx, s.dialect.deleteByID, id)
|
|
return err
|
|
}</span>
|
|
|
|
func (s *SQLStore) Size() int <span class="cov8" title="1">{
|
|
var count int
|
|
_ = s.db.QueryRow(`SELECT COUNT(*) FROM conversations`).Scan(&count)
|
|
return count
|
|
}</span>
|
|
|
|
func (s *SQLStore) cleanup() <span class="cov8" title="1">{
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for </span><span class="cov8" title="1">{
|
|
select </span>{
|
|
case <-ticker.C:<span class="cov0" title="0">
|
|
cutoff := time.Now().Add(-s.ttl)
|
|
_, _ = s.db.Exec(s.dialect.cleanup, cutoff)</span>
|
|
case <-s.done:<span class="cov8" title="1">
|
|
return</span>
|
|
}
|
|
}
|
|
}
|
|
|
|
// Close stops the cleanup goroutine and closes the database connection.
|
|
func (s *SQLStore) Close() error <span class="cov8" title="1">{
|
|
close(s.done)
|
|
return s.db.Close()
|
|
}</span>
|
|
</pre>
|
|
|
|
<pre class="file" id="file7" style="display: none">package conversation
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/alicebob/miniredis/v2"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"github.com/redis/go-redis/v9"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
)
|
|
|
|
// SetupTestDB creates an in-memory SQLite database for testing
|
|
func SetupTestDB(t *testing.T, driver string) *sql.DB <span class="cov0" title="0">{
|
|
t.Helper()
|
|
|
|
var dsn string
|
|
switch driver </span>{
|
|
case "sqlite3":<span class="cov0" title="0">
|
|
// Use in-memory SQLite database
|
|
dsn = ":memory:"</span>
|
|
case "postgres":<span class="cov0" title="0">
|
|
// For postgres tests, use a mock or skip
|
|
t.Skip("PostgreSQL tests require external database")
|
|
return nil</span>
|
|
case "mysql":<span class="cov0" title="0">
|
|
// For mysql tests, use a mock or skip
|
|
t.Skip("MySQL tests require external database")
|
|
return nil</span>
|
|
default:<span class="cov0" title="0">
|
|
t.Fatalf("unsupported driver: %s", driver)
|
|
return nil</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">db, err := sql.Open(driver, dsn)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
t.Fatalf("failed to open database: %v", err)
|
|
}</span>
|
|
|
|
// Create the conversations table
|
|
<span class="cov0" title="0">schema := `
|
|
CREATE TABLE IF NOT EXISTS conversations (
|
|
conversation_id TEXT PRIMARY KEY,
|
|
messages TEXT NOT NULL,
|
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
)
|
|
`
|
|
if _, err := db.Exec(schema); err != nil </span><span class="cov0" title="0">{
|
|
db.Close()
|
|
t.Fatalf("failed to create schema: %v", err)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return db</span>
|
|
}
|
|
|
|
// SetupTestRedis creates a miniredis instance for testing
|
|
func SetupTestRedis(t *testing.T) (*redis.Client, *miniredis.Miniredis) <span class="cov8" title="1">{
|
|
t.Helper()
|
|
|
|
mr := miniredis.RunT(t)
|
|
|
|
client := redis.NewClient(&redis.Options{
|
|
Addr: mr.Addr(),
|
|
})
|
|
|
|
// Test connection
|
|
ctx := context.Background()
|
|
if err := client.Ping(ctx).Err(); err != nil </span><span class="cov0" title="0">{
|
|
t.Fatalf("failed to connect to miniredis: %v", err)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return client, mr</span>
|
|
}
|
|
|
|
// CreateTestMessages generates test message fixtures
|
|
func CreateTestMessages(count int) []api.Message <span class="cov8" title="1">{
|
|
messages := make([]api.Message, count)
|
|
for i := 0; i < count; i++ </span><span class="cov8" title="1">{
|
|
role := "user"
|
|
if i%2 == 1 </span><span class="cov8" title="1">{
|
|
role = "assistant"
|
|
}</span>
|
|
<span class="cov8" title="1">messages[i] = api.Message{
|
|
Role: role,
|
|
Content: []api.ContentBlock{
|
|
{
|
|
Type: "text",
|
|
Text: fmt.Sprintf("Test message %d", i+1),
|
|
},
|
|
},
|
|
}</span>
|
|
}
|
|
<span class="cov8" title="1">return messages</span>
|
|
}
|
|
|
|
// CreateTestConversation creates a test conversation with the given ID and messages
|
|
func CreateTestConversation(conversationID string, messageCount int) *Conversation <span class="cov0" title="0">{
|
|
return &Conversation{
|
|
ID: conversationID,
|
|
Messages: CreateTestMessages(messageCount),
|
|
Model: "test-model",
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
}</span>
|
|
|
|
// MockStore is a simple in-memory store for testing
|
|
type MockStore struct {
|
|
conversations map[string]*Conversation
|
|
getCalled bool
|
|
createCalled bool
|
|
appendCalled bool
|
|
deleteCalled bool
|
|
sizeCalled bool
|
|
}
|
|
|
|
func NewMockStore() *MockStore <span class="cov0" title="0">{
|
|
return &MockStore{
|
|
conversations: make(map[string]*Conversation),
|
|
}
|
|
}</span>
|
|
|
|
func (m *MockStore) Get(ctx context.Context, conversationID string) (*Conversation, error) <span class="cov0" title="0">{
|
|
m.getCalled = true
|
|
conv, ok := m.conversations[conversationID]
|
|
if !ok </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("conversation not found")
|
|
}</span>
|
|
<span class="cov0" title="0">return conv, nil</span>
|
|
}
|
|
|
|
func (m *MockStore) Create(ctx context.Context, conversationID string, model string, messages []api.Message) (*Conversation, error) <span class="cov0" title="0">{
|
|
m.createCalled = true
|
|
m.conversations[conversationID] = &Conversation{
|
|
ID: conversationID,
|
|
Model: model,
|
|
Messages: messages,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
return m.conversations[conversationID], nil
|
|
}</span>
|
|
|
|
func (m *MockStore) Append(ctx context.Context, conversationID string, messages ...api.Message) (*Conversation, error) <span class="cov0" title="0">{
|
|
m.appendCalled = true
|
|
conv, ok := m.conversations[conversationID]
|
|
if !ok </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("conversation not found")
|
|
}</span>
|
|
<span class="cov0" title="0">conv.Messages = append(conv.Messages, messages...)
|
|
conv.UpdatedAt = time.Now()
|
|
return conv, nil</span>
|
|
}
|
|
|
|
func (m *MockStore) Delete(ctx context.Context, conversationID string) error <span class="cov0" title="0">{
|
|
m.deleteCalled = true
|
|
delete(m.conversations, conversationID)
|
|
return nil
|
|
}</span>
|
|
|
|
func (m *MockStore) Size() int <span class="cov0" title="0">{
|
|
m.sizeCalled = true
|
|
return len(m.conversations)
|
|
}</span>
|
|
|
|
func (m *MockStore) Close() error <span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
</pre>
|
|
|
|
<pre class="file" id="file8" style="display: none">package logger
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"os"
|
|
|
|
"go.opentelemetry.io/otel/trace"
|
|
)
|
|
|
|
type contextKey string
|
|
|
|
const requestIDKey contextKey = "request_id"
|
|
|
|
// New creates a logger with the specified format (json or text) and level.
|
|
func New(format string, level string) *slog.Logger <span class="cov0" title="0">{
|
|
var handler slog.Handler
|
|
|
|
logLevel := parseLevel(level)
|
|
opts := &slog.HandlerOptions{
|
|
Level: logLevel,
|
|
AddSource: true, // Add file:line info for debugging
|
|
}
|
|
|
|
if format == "json" </span><span class="cov0" title="0">{
|
|
handler = slog.NewJSONHandler(os.Stdout, opts)
|
|
}</span> else<span class="cov0" title="0"> {
|
|
handler = slog.NewTextHandler(os.Stdout, opts)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return slog.New(handler)</span>
|
|
}
|
|
|
|
// parseLevel converts a string level to slog.Level.
|
|
func parseLevel(level string) slog.Level <span class="cov0" title="0">{
|
|
switch level </span>{
|
|
case "debug":<span class="cov0" title="0">
|
|
return slog.LevelDebug</span>
|
|
case "info":<span class="cov0" title="0">
|
|
return slog.LevelInfo</span>
|
|
case "warn":<span class="cov0" title="0">
|
|
return slog.LevelWarn</span>
|
|
case "error":<span class="cov0" title="0">
|
|
return slog.LevelError</span>
|
|
default:<span class="cov0" title="0">
|
|
return slog.LevelInfo</span>
|
|
}
|
|
}
|
|
|
|
// WithRequestID adds a request ID to the context for tracing.
|
|
func WithRequestID(ctx context.Context, requestID string) context.Context <span class="cov0" title="0">{
|
|
return context.WithValue(ctx, requestIDKey, requestID)
|
|
}</span>
|
|
|
|
// FromContext extracts the request ID from context, or returns empty string.
|
|
func FromContext(ctx context.Context) string <span class="cov0" title="0">{
|
|
if id, ok := ctx.Value(requestIDKey).(string); ok </span><span class="cov0" title="0">{
|
|
return id
|
|
}</span>
|
|
<span class="cov0" title="0">return ""</span>
|
|
}
|
|
|
|
// LogAttrsWithTrace adds trace context to log attributes for correlation.
|
|
func LogAttrsWithTrace(ctx context.Context, attrs ...any) []any <span class="cov0" title="0">{
|
|
spanCtx := trace.SpanFromContext(ctx).SpanContext()
|
|
if spanCtx.IsValid() </span><span class="cov0" title="0">{
|
|
attrs = append(attrs,
|
|
slog.String("trace_id", spanCtx.TraceID().String()),
|
|
slog.String("span_id", spanCtx.SpanID().String()),
|
|
)
|
|
}</span>
|
|
<span class="cov0" title="0">return attrs</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file9" style="display: none">package observability
|
|
|
|
import (
|
|
"github.com/ajac-zero/latticelm/internal/conversation"
|
|
"github.com/ajac-zero/latticelm/internal/providers"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
|
)
|
|
|
|
// ProviderRegistry defines the interface for provider registries.
|
|
// This matches the interface expected by the server.
|
|
type ProviderRegistry interface {
|
|
Get(name string) (providers.Provider, bool)
|
|
Models() []struct{ Provider, Model string }
|
|
ResolveModelID(model string) string
|
|
Default(model string) (providers.Provider, error)
|
|
}
|
|
|
|
// WrapProviderRegistry wraps all providers in a registry with observability.
|
|
func WrapProviderRegistry(registry ProviderRegistry, metricsRegistry *prometheus.Registry, tp *sdktrace.TracerProvider) ProviderRegistry <span class="cov0" title="0">{
|
|
if registry == nil </span><span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
|
|
// We can't directly modify the registry's internal map, so we'll need to
|
|
// wrap providers as they're retrieved. Instead, create a new instrumented registry.
|
|
<span class="cov0" title="0">return &InstrumentedRegistry{
|
|
base: registry,
|
|
metrics: metricsRegistry,
|
|
tracer: tp,
|
|
wrappedProviders: make(map[string]providers.Provider),
|
|
}</span>
|
|
}
|
|
|
|
// InstrumentedRegistry wraps a provider registry to return instrumented providers.
|
|
type InstrumentedRegistry struct {
|
|
base ProviderRegistry
|
|
metrics *prometheus.Registry
|
|
tracer *sdktrace.TracerProvider
|
|
wrappedProviders map[string]providers.Provider
|
|
}
|
|
|
|
// Get returns an instrumented provider by entry name.
|
|
func (r *InstrumentedRegistry) Get(name string) (providers.Provider, bool) <span class="cov0" title="0">{
|
|
// Check if we've already wrapped this provider
|
|
if wrapped, ok := r.wrappedProviders[name]; ok </span><span class="cov0" title="0">{
|
|
return wrapped, true
|
|
}</span>
|
|
|
|
// Get the base provider
|
|
<span class="cov0" title="0">p, ok := r.base.Get(name)
|
|
if !ok </span><span class="cov0" title="0">{
|
|
return nil, false
|
|
}</span>
|
|
|
|
// Wrap it
|
|
<span class="cov0" title="0">wrapped := NewInstrumentedProvider(p, r.metrics, r.tracer)
|
|
r.wrappedProviders[name] = wrapped
|
|
return wrapped, true</span>
|
|
}
|
|
|
|
// Default returns the instrumented provider for the given model name.
|
|
func (r *InstrumentedRegistry) Default(model string) (providers.Provider, error) <span class="cov0" title="0">{
|
|
p, err := r.base.Default(model)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
// Check if we've already wrapped this provider
|
|
<span class="cov0" title="0">name := p.Name()
|
|
if wrapped, ok := r.wrappedProviders[name]; ok </span><span class="cov0" title="0">{
|
|
return wrapped, nil
|
|
}</span>
|
|
|
|
// Wrap it
|
|
<span class="cov0" title="0">wrapped := NewInstrumentedProvider(p, r.metrics, r.tracer)
|
|
r.wrappedProviders[name] = wrapped
|
|
return wrapped, nil</span>
|
|
}
|
|
|
|
// Models returns the list of configured models and their provider entry names.
|
|
func (r *InstrumentedRegistry) Models() []struct{ Provider, Model string } <span class="cov0" title="0">{
|
|
return r.base.Models()
|
|
}</span>
|
|
|
|
// ResolveModelID returns the provider_model_id for a model.
|
|
func (r *InstrumentedRegistry) ResolveModelID(model string) string <span class="cov0" title="0">{
|
|
return r.base.ResolveModelID(model)
|
|
}</span>
|
|
|
|
// WrapConversationStore wraps a conversation store with observability.
|
|
func WrapConversationStore(store conversation.Store, backend string, metricsRegistry *prometheus.Registry, tp *sdktrace.TracerProvider) conversation.Store <span class="cov0" title="0">{
|
|
if store == nil </span><span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return NewInstrumentedStore(store, backend, metricsRegistry, tp)</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file10" style="display: none">package observability
|
|
|
|
import (
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
)
|
|
|
|
var (
|
|
// HTTP Metrics
|
|
httpRequestsTotal = prometheus.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "http_requests_total",
|
|
Help: "Total number of HTTP requests",
|
|
},
|
|
[]string{"method", "path", "status"},
|
|
)
|
|
|
|
httpRequestDuration = prometheus.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "http_request_duration_seconds",
|
|
Help: "HTTP request latency in seconds",
|
|
Buckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 10, 30},
|
|
},
|
|
[]string{"method", "path", "status"},
|
|
)
|
|
|
|
httpRequestSize = prometheus.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "http_request_size_bytes",
|
|
Help: "HTTP request size in bytes",
|
|
Buckets: prometheus.ExponentialBuckets(100, 10, 7), // 100B to 100MB
|
|
},
|
|
[]string{"method", "path"},
|
|
)
|
|
|
|
httpResponseSize = prometheus.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "http_response_size_bytes",
|
|
Help: "HTTP response size in bytes",
|
|
Buckets: prometheus.ExponentialBuckets(100, 10, 7), // 100B to 100MB
|
|
},
|
|
[]string{"method", "path"},
|
|
)
|
|
|
|
// Provider Metrics
|
|
providerRequestsTotal = prometheus.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "provider_requests_total",
|
|
Help: "Total number of provider requests",
|
|
},
|
|
[]string{"provider", "model", "operation", "status"},
|
|
)
|
|
|
|
providerRequestDuration = prometheus.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "provider_request_duration_seconds",
|
|
Help: "Provider request latency in seconds",
|
|
Buckets: []float64{0.1, 0.5, 1, 2, 5, 10, 20, 30, 60},
|
|
},
|
|
[]string{"provider", "model", "operation"},
|
|
)
|
|
|
|
providerTokensTotal = prometheus.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "provider_tokens_total",
|
|
Help: "Total number of tokens processed",
|
|
},
|
|
[]string{"provider", "model", "type"}, // type: input, output
|
|
)
|
|
|
|
providerStreamTTFB = prometheus.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "provider_stream_ttfb_seconds",
|
|
Help: "Time to first byte for streaming requests in seconds",
|
|
Buckets: []float64{0.05, 0.1, 0.5, 1, 2, 5, 10},
|
|
},
|
|
[]string{"provider", "model"},
|
|
)
|
|
|
|
providerStreamChunks = prometheus.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "provider_stream_chunks_total",
|
|
Help: "Total number of stream chunks received",
|
|
},
|
|
[]string{"provider", "model"},
|
|
)
|
|
|
|
providerStreamDuration = prometheus.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "provider_stream_duration_seconds",
|
|
Help: "Total duration of streaming requests in seconds",
|
|
Buckets: []float64{0.1, 0.5, 1, 2, 5, 10, 20, 30, 60},
|
|
},
|
|
[]string{"provider", "model"},
|
|
)
|
|
|
|
// Conversation Store Metrics
|
|
conversationOperationsTotal = prometheus.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "conversation_operations_total",
|
|
Help: "Total number of conversation store operations",
|
|
},
|
|
[]string{"operation", "backend", "status"},
|
|
)
|
|
|
|
conversationOperationDuration = prometheus.NewHistogramVec(
|
|
prometheus.HistogramOpts{
|
|
Name: "conversation_operation_duration_seconds",
|
|
Help: "Conversation store operation latency in seconds",
|
|
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1},
|
|
},
|
|
[]string{"operation", "backend"},
|
|
)
|
|
|
|
conversationActiveCount = prometheus.NewGaugeVec(
|
|
prometheus.GaugeOpts{
|
|
Name: "conversation_active_count",
|
|
Help: "Number of active conversations",
|
|
},
|
|
[]string{"backend"},
|
|
)
|
|
|
|
// Circuit Breaker Metrics
|
|
circuitBreakerState = prometheus.NewGaugeVec(
|
|
prometheus.GaugeOpts{
|
|
Name: "circuit_breaker_state",
|
|
Help: "Circuit breaker state (0=closed, 1=open, 2=half-open)",
|
|
},
|
|
[]string{"provider"},
|
|
)
|
|
|
|
circuitBreakerStateTransitions = prometheus.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "circuit_breaker_state_transitions_total",
|
|
Help: "Total number of circuit breaker state transitions",
|
|
},
|
|
[]string{"provider", "from", "to"},
|
|
)
|
|
)
|
|
|
|
// InitMetrics registers all metrics with a new Prometheus registry.
|
|
func InitMetrics() *prometheus.Registry <span class="cov8" title="1">{
|
|
registry := prometheus.NewRegistry()
|
|
|
|
// Register HTTP metrics
|
|
registry.MustRegister(httpRequestsTotal)
|
|
registry.MustRegister(httpRequestDuration)
|
|
registry.MustRegister(httpRequestSize)
|
|
registry.MustRegister(httpResponseSize)
|
|
|
|
// Register provider metrics
|
|
registry.MustRegister(providerRequestsTotal)
|
|
registry.MustRegister(providerRequestDuration)
|
|
registry.MustRegister(providerTokensTotal)
|
|
registry.MustRegister(providerStreamTTFB)
|
|
registry.MustRegister(providerStreamChunks)
|
|
registry.MustRegister(providerStreamDuration)
|
|
|
|
// Register conversation store metrics
|
|
registry.MustRegister(conversationOperationsTotal)
|
|
registry.MustRegister(conversationOperationDuration)
|
|
registry.MustRegister(conversationActiveCount)
|
|
|
|
// Register circuit breaker metrics
|
|
registry.MustRegister(circuitBreakerState)
|
|
registry.MustRegister(circuitBreakerStateTransitions)
|
|
|
|
return registry
|
|
}</span>
|
|
|
|
// RecordCircuitBreakerStateChange records a circuit breaker state transition.
|
|
func RecordCircuitBreakerStateChange(provider, from, to string) <span class="cov8" title="1">{
|
|
// Record the transition
|
|
circuitBreakerStateTransitions.WithLabelValues(provider, from, to).Inc()
|
|
|
|
// Update the current state gauge
|
|
var stateValue float64
|
|
switch to </span>{
|
|
case "closed":<span class="cov8" title="1">
|
|
stateValue = 0</span>
|
|
case "open":<span class="cov8" title="1">
|
|
stateValue = 1</span>
|
|
case "half-open":<span class="cov8" title="1">
|
|
stateValue = 2</span>
|
|
}
|
|
<span class="cov8" title="1">circuitBreakerState.WithLabelValues(provider).Set(stateValue)</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file11" style="display: none">package observability
|
|
|
|
import (
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
)
|
|
|
|
// MetricsMiddleware creates a middleware that records HTTP metrics.
|
|
func MetricsMiddleware(next http.Handler, registry *prometheus.Registry, _ interface{}) http.Handler <span class="cov0" title="0">{
|
|
if registry == nil </span><span class="cov0" title="0">{
|
|
// If metrics are not enabled, pass through without modification
|
|
return next
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) </span><span class="cov0" title="0">{
|
|
start := time.Now()
|
|
|
|
// Record request size
|
|
if r.ContentLength > 0 </span><span class="cov0" title="0">{
|
|
httpRequestSize.WithLabelValues(r.Method, r.URL.Path).Observe(float64(r.ContentLength))
|
|
}</span>
|
|
|
|
// Wrap response writer to capture status code and response size
|
|
<span class="cov0" title="0">wrapped := &metricsResponseWriter{
|
|
ResponseWriter: w,
|
|
statusCode: http.StatusOK,
|
|
bytesWritten: 0,
|
|
}
|
|
|
|
// Call the next handler
|
|
next.ServeHTTP(wrapped, r)
|
|
|
|
// Record metrics after request completes
|
|
duration := time.Since(start).Seconds()
|
|
status := strconv.Itoa(wrapped.statusCode)
|
|
|
|
httpRequestsTotal.WithLabelValues(r.Method, r.URL.Path, status).Inc()
|
|
httpRequestDuration.WithLabelValues(r.Method, r.URL.Path, status).Observe(duration)
|
|
httpResponseSize.WithLabelValues(r.Method, r.URL.Path).Observe(float64(wrapped.bytesWritten))</span>
|
|
})
|
|
}
|
|
|
|
// metricsResponseWriter wraps http.ResponseWriter to capture status code and bytes written.
|
|
type metricsResponseWriter struct {
|
|
http.ResponseWriter
|
|
statusCode int
|
|
bytesWritten int
|
|
}
|
|
|
|
func (w *metricsResponseWriter) WriteHeader(statusCode int) <span class="cov0" title="0">{
|
|
w.statusCode = statusCode
|
|
w.ResponseWriter.WriteHeader(statusCode)
|
|
}</span>
|
|
|
|
func (w *metricsResponseWriter) Write(b []byte) (int, error) <span class="cov0" title="0">{
|
|
n, err := w.ResponseWriter.Write(b)
|
|
w.bytesWritten += n
|
|
return n, err
|
|
}</span>
|
|
</pre>
|
|
|
|
<pre class="file" id="file12" style="display: none">package observability
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/ajac-zero/latticelm/internal/providers"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/codes"
|
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
|
"go.opentelemetry.io/otel/trace"
|
|
)
|
|
|
|
// InstrumentedProvider wraps a provider with metrics and tracing.
|
|
type InstrumentedProvider struct {
|
|
base providers.Provider
|
|
registry *prometheus.Registry
|
|
tracer trace.Tracer
|
|
}
|
|
|
|
// NewInstrumentedProvider wraps a provider with observability.
|
|
func NewInstrumentedProvider(p providers.Provider, registry *prometheus.Registry, tp *sdktrace.TracerProvider) providers.Provider <span class="cov8" title="1">{
|
|
var tracer trace.Tracer
|
|
if tp != nil </span><span class="cov8" title="1">{
|
|
tracer = tp.Tracer("llm-gateway")
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return &InstrumentedProvider{
|
|
base: p,
|
|
registry: registry,
|
|
tracer: tracer,
|
|
}</span>
|
|
}
|
|
|
|
// Name returns the name of the underlying provider.
|
|
func (p *InstrumentedProvider) Name() string <span class="cov8" title="1">{
|
|
return p.base.Name()
|
|
}</span>
|
|
|
|
// Generate wraps the provider's Generate method with metrics and tracing.
|
|
func (p *InstrumentedProvider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) <span class="cov8" title="1">{
|
|
// Start span if tracing is enabled
|
|
if p.tracer != nil </span><span class="cov8" title="1">{
|
|
var span trace.Span
|
|
ctx, span = p.tracer.Start(ctx, "provider.generate",
|
|
trace.WithSpanKind(trace.SpanKindClient),
|
|
trace.WithAttributes(
|
|
attribute.String("provider.name", p.base.Name()),
|
|
attribute.String("provider.model", req.Model),
|
|
),
|
|
)
|
|
defer span.End()
|
|
}</span>
|
|
|
|
// Record start time
|
|
<span class="cov8" title="1">start := time.Now()
|
|
|
|
// Call underlying provider
|
|
result, err := p.base.Generate(ctx, messages, req)
|
|
|
|
// Record metrics
|
|
duration := time.Since(start).Seconds()
|
|
status := "success"
|
|
if err != nil </span><span class="cov8" title="1">{
|
|
status = "error"
|
|
if p.tracer != nil </span><span class="cov8" title="1">{
|
|
span := trace.SpanFromContext(ctx)
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}</span>
|
|
} else<span class="cov8" title="1"> if result != nil </span><span class="cov8" title="1">{
|
|
// Add token attributes to span
|
|
if p.tracer != nil </span><span class="cov8" title="1">{
|
|
span := trace.SpanFromContext(ctx)
|
|
span.SetAttributes(
|
|
attribute.Int64("provider.input_tokens", int64(result.Usage.InputTokens)),
|
|
attribute.Int64("provider.output_tokens", int64(result.Usage.OutputTokens)),
|
|
attribute.Int64("provider.total_tokens", int64(result.Usage.TotalTokens)),
|
|
)
|
|
span.SetStatus(codes.Ok, "")
|
|
}</span>
|
|
|
|
// Record token metrics
|
|
<span class="cov8" title="1">if p.registry != nil </span><span class="cov8" title="1">{
|
|
providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "input").Add(float64(result.Usage.InputTokens))
|
|
providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "output").Add(float64(result.Usage.OutputTokens))
|
|
}</span>
|
|
}
|
|
|
|
// Record request metrics
|
|
<span class="cov8" title="1">if p.registry != nil </span><span class="cov8" title="1">{
|
|
providerRequestsTotal.WithLabelValues(p.base.Name(), req.Model, "generate", status).Inc()
|
|
providerRequestDuration.WithLabelValues(p.base.Name(), req.Model, "generate").Observe(duration)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return result, err</span>
|
|
}
|
|
|
|
// GenerateStream wraps the provider's GenerateStream method with metrics and tracing.
|
|
func (p *InstrumentedProvider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) <span class="cov8" title="1">{
|
|
// Start span if tracing is enabled
|
|
if p.tracer != nil </span><span class="cov8" title="1">{
|
|
var span trace.Span
|
|
ctx, span = p.tracer.Start(ctx, "provider.generate_stream",
|
|
trace.WithSpanKind(trace.SpanKindClient),
|
|
trace.WithAttributes(
|
|
attribute.String("provider.name", p.base.Name()),
|
|
attribute.String("provider.model", req.Model),
|
|
),
|
|
)
|
|
defer span.End()
|
|
}</span>
|
|
|
|
// Record start time
|
|
<span class="cov8" title="1">start := time.Now()
|
|
var ttfb time.Duration
|
|
firstChunk := true
|
|
|
|
// Create instrumented channels
|
|
baseChan, baseErrChan := p.base.GenerateStream(ctx, messages, req)
|
|
outChan := make(chan *api.ProviderStreamDelta)
|
|
outErrChan := make(chan error, 1)
|
|
|
|
// Metrics tracking
|
|
var chunkCount int64
|
|
var totalInputTokens, totalOutputTokens int64
|
|
var streamErr error
|
|
|
|
go func() </span><span class="cov8" title="1">{
|
|
defer close(outChan)
|
|
defer close(outErrChan)
|
|
|
|
for </span><span class="cov8" title="1">{
|
|
select </span>{
|
|
case delta, ok := <-baseChan:<span class="cov8" title="1">
|
|
if !ok </span><span class="cov8" title="1">{
|
|
// Stream finished - record final metrics
|
|
duration := time.Since(start).Seconds()
|
|
status := "success"
|
|
if streamErr != nil </span><span class="cov0" title="0">{
|
|
status = "error"
|
|
if p.tracer != nil </span><span class="cov0" title="0">{
|
|
span := trace.SpanFromContext(ctx)
|
|
span.RecordError(streamErr)
|
|
span.SetStatus(codes.Error, streamErr.Error())
|
|
}</span>
|
|
} else<span class="cov8" title="1"> {
|
|
if p.tracer != nil </span><span class="cov0" title="0">{
|
|
span := trace.SpanFromContext(ctx)
|
|
span.SetAttributes(
|
|
attribute.Int64("provider.input_tokens", totalInputTokens),
|
|
attribute.Int64("provider.output_tokens", totalOutputTokens),
|
|
attribute.Int64("provider.chunk_count", chunkCount),
|
|
attribute.Float64("provider.ttfb_seconds", ttfb.Seconds()),
|
|
)
|
|
span.SetStatus(codes.Ok, "")
|
|
}</span>
|
|
|
|
// Record token metrics
|
|
<span class="cov8" title="1">if p.registry != nil && (totalInputTokens > 0 || totalOutputTokens > 0) </span><span class="cov0" title="0">{
|
|
providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "input").Add(float64(totalInputTokens))
|
|
providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "output").Add(float64(totalOutputTokens))
|
|
}</span>
|
|
}
|
|
|
|
// Record stream metrics
|
|
<span class="cov8" title="1">if p.registry != nil </span><span class="cov8" title="1">{
|
|
providerRequestsTotal.WithLabelValues(p.base.Name(), req.Model, "generate_stream", status).Inc()
|
|
providerStreamDuration.WithLabelValues(p.base.Name(), req.Model).Observe(duration)
|
|
providerStreamChunks.WithLabelValues(p.base.Name(), req.Model).Add(float64(chunkCount))
|
|
if ttfb > 0 </span><span class="cov8" title="1">{
|
|
providerStreamTTFB.WithLabelValues(p.base.Name(), req.Model).Observe(ttfb.Seconds())
|
|
}</span>
|
|
}
|
|
<span class="cov8" title="1">return</span>
|
|
}
|
|
|
|
// Record TTFB on first chunk
|
|
<span class="cov8" title="1">if firstChunk </span><span class="cov8" title="1">{
|
|
ttfb = time.Since(start)
|
|
firstChunk = false
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">chunkCount++
|
|
|
|
// Track token usage
|
|
if delta.Usage != nil </span><span class="cov8" title="1">{
|
|
totalInputTokens = int64(delta.Usage.InputTokens)
|
|
totalOutputTokens = int64(delta.Usage.OutputTokens)
|
|
}</span>
|
|
|
|
// Forward the delta
|
|
<span class="cov8" title="1">outChan <- delta</span>
|
|
|
|
case err, ok := <-baseErrChan:<span class="cov8" title="1">
|
|
if ok && err != nil </span><span class="cov8" title="1">{
|
|
streamErr = err
|
|
outErrChan <- err
|
|
}</span>
|
|
<span class="cov8" title="1">return</span>
|
|
}
|
|
}
|
|
}()
|
|
|
|
<span class="cov8" title="1">return outChan, outErrChan</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file13" style="display: none">package observability
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/ajac-zero/latticelm/internal/conversation"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/codes"
|
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
|
"go.opentelemetry.io/otel/trace"
|
|
)
|
|
|
|
// InstrumentedStore wraps a conversation store with metrics and tracing.
|
|
type InstrumentedStore struct {
|
|
base conversation.Store
|
|
registry *prometheus.Registry
|
|
tracer trace.Tracer
|
|
backend string
|
|
}
|
|
|
|
// NewInstrumentedStore wraps a conversation store with observability.
|
|
func NewInstrumentedStore(s conversation.Store, backend string, registry *prometheus.Registry, tp *sdktrace.TracerProvider) conversation.Store <span class="cov0" title="0">{
|
|
var tracer trace.Tracer
|
|
if tp != nil </span><span class="cov0" title="0">{
|
|
tracer = tp.Tracer("llm-gateway")
|
|
}</span>
|
|
|
|
// Initialize gauge with current size
|
|
<span class="cov0" title="0">if registry != nil </span><span class="cov0" title="0">{
|
|
conversationActiveCount.WithLabelValues(backend).Set(float64(s.Size()))
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return &InstrumentedStore{
|
|
base: s,
|
|
registry: registry,
|
|
tracer: tracer,
|
|
backend: backend,
|
|
}</span>
|
|
}
|
|
|
|
// Get wraps the store's Get method with metrics and tracing.
|
|
func (s *InstrumentedStore) Get(ctx context.Context, id string) (*conversation.Conversation, error) <span class="cov0" title="0">{
|
|
// Start span if tracing is enabled
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
var span trace.Span
|
|
ctx, span = s.tracer.Start(ctx, "conversation.get",
|
|
trace.WithAttributes(
|
|
attribute.String("conversation.id", id),
|
|
attribute.String("conversation.backend", s.backend),
|
|
),
|
|
)
|
|
defer span.End()
|
|
}</span>
|
|
|
|
// Record start time
|
|
<span class="cov0" title="0">start := time.Now()
|
|
|
|
// Call underlying store
|
|
conv, err := s.base.Get(ctx, id)
|
|
|
|
// Record metrics
|
|
duration := time.Since(start).Seconds()
|
|
status := "success"
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
status = "error"
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
span := trace.SpanFromContext(ctx)
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}</span>
|
|
} else<span class="cov0" title="0"> {
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
span := trace.SpanFromContext(ctx)
|
|
if conv != nil </span><span class="cov0" title="0">{
|
|
span.SetAttributes(
|
|
attribute.Int("conversation.message_count", len(conv.Messages)),
|
|
attribute.String("conversation.model", conv.Model),
|
|
)
|
|
}</span>
|
|
<span class="cov0" title="0">span.SetStatus(codes.Ok, "")</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">if s.registry != nil </span><span class="cov0" title="0">{
|
|
conversationOperationsTotal.WithLabelValues("get", s.backend, status).Inc()
|
|
conversationOperationDuration.WithLabelValues("get", s.backend).Observe(duration)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return conv, err</span>
|
|
}
|
|
|
|
// Create wraps the store's Create method with metrics and tracing.
|
|
func (s *InstrumentedStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*conversation.Conversation, error) <span class="cov0" title="0">{
|
|
// Start span if tracing is enabled
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
var span trace.Span
|
|
ctx, span = s.tracer.Start(ctx, "conversation.create",
|
|
trace.WithAttributes(
|
|
attribute.String("conversation.id", id),
|
|
attribute.String("conversation.backend", s.backend),
|
|
attribute.String("conversation.model", model),
|
|
attribute.Int("conversation.initial_messages", len(messages)),
|
|
),
|
|
)
|
|
defer span.End()
|
|
}</span>
|
|
|
|
// Record start time
|
|
<span class="cov0" title="0">start := time.Now()
|
|
|
|
// Call underlying store
|
|
conv, err := s.base.Create(ctx, id, model, messages)
|
|
|
|
// Record metrics
|
|
duration := time.Since(start).Seconds()
|
|
status := "success"
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
status = "error"
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
span := trace.SpanFromContext(ctx)
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}</span>
|
|
} else<span class="cov0" title="0"> {
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
span := trace.SpanFromContext(ctx)
|
|
span.SetStatus(codes.Ok, "")
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">if s.registry != nil </span><span class="cov0" title="0">{
|
|
conversationOperationsTotal.WithLabelValues("create", s.backend, status).Inc()
|
|
conversationOperationDuration.WithLabelValues("create", s.backend).Observe(duration)
|
|
// Update active count
|
|
conversationActiveCount.WithLabelValues(s.backend).Set(float64(s.base.Size()))
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return conv, err</span>
|
|
}
|
|
|
|
// Append wraps the store's Append method with metrics and tracing.
|
|
func (s *InstrumentedStore) Append(ctx context.Context, id string, messages ...api.Message) (*conversation.Conversation, error) <span class="cov0" title="0">{
|
|
// Start span if tracing is enabled
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
var span trace.Span
|
|
ctx, span = s.tracer.Start(ctx, "conversation.append",
|
|
trace.WithAttributes(
|
|
attribute.String("conversation.id", id),
|
|
attribute.String("conversation.backend", s.backend),
|
|
attribute.Int("conversation.appended_messages", len(messages)),
|
|
),
|
|
)
|
|
defer span.End()
|
|
}</span>
|
|
|
|
// Record start time
|
|
<span class="cov0" title="0">start := time.Now()
|
|
|
|
// Call underlying store
|
|
conv, err := s.base.Append(ctx, id, messages...)
|
|
|
|
// Record metrics
|
|
duration := time.Since(start).Seconds()
|
|
status := "success"
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
status = "error"
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
span := trace.SpanFromContext(ctx)
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}</span>
|
|
} else<span class="cov0" title="0"> {
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
span := trace.SpanFromContext(ctx)
|
|
if conv != nil </span><span class="cov0" title="0">{
|
|
span.SetAttributes(
|
|
attribute.Int("conversation.total_messages", len(conv.Messages)),
|
|
)
|
|
}</span>
|
|
<span class="cov0" title="0">span.SetStatus(codes.Ok, "")</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">if s.registry != nil </span><span class="cov0" title="0">{
|
|
conversationOperationsTotal.WithLabelValues("append", s.backend, status).Inc()
|
|
conversationOperationDuration.WithLabelValues("append", s.backend).Observe(duration)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return conv, err</span>
|
|
}
|
|
|
|
// Delete wraps the store's Delete method with metrics and tracing.
|
|
func (s *InstrumentedStore) Delete(ctx context.Context, id string) error <span class="cov0" title="0">{
|
|
// Start span if tracing is enabled
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
var span trace.Span
|
|
ctx, span = s.tracer.Start(ctx, "conversation.delete",
|
|
trace.WithAttributes(
|
|
attribute.String("conversation.id", id),
|
|
attribute.String("conversation.backend", s.backend),
|
|
),
|
|
)
|
|
defer span.End()
|
|
}</span>
|
|
|
|
// Record start time
|
|
<span class="cov0" title="0">start := time.Now()
|
|
|
|
// Call underlying store
|
|
err := s.base.Delete(ctx, id)
|
|
|
|
// Record metrics
|
|
duration := time.Since(start).Seconds()
|
|
status := "success"
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
status = "error"
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
span := trace.SpanFromContext(ctx)
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}</span>
|
|
} else<span class="cov0" title="0"> {
|
|
if s.tracer != nil </span><span class="cov0" title="0">{
|
|
span := trace.SpanFromContext(ctx)
|
|
span.SetStatus(codes.Ok, "")
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">if s.registry != nil </span><span class="cov0" title="0">{
|
|
conversationOperationsTotal.WithLabelValues("delete", s.backend, status).Inc()
|
|
conversationOperationDuration.WithLabelValues("delete", s.backend).Observe(duration)
|
|
// Update active count
|
|
conversationActiveCount.WithLabelValues(s.backend).Set(float64(s.base.Size()))
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return err</span>
|
|
}
|
|
|
|
// Size returns the size of the underlying store.
|
|
func (s *InstrumentedStore) Size() int <span class="cov0" title="0">{
|
|
return s.base.Size()
|
|
}</span>
|
|
|
|
// Close wraps the store's Close method.
|
|
func (s *InstrumentedStore) Close() error <span class="cov0" title="0">{
|
|
return s.base.Close()
|
|
}</span>
|
|
</pre>
|
|
|
|
<pre class="file" id="file14" style="display: none">package observability
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/testutil"
|
|
"go.opentelemetry.io/otel"
|
|
"go.opentelemetry.io/otel/sdk/resource"
|
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
|
"go.opentelemetry.io/otel/sdk/trace/tracetest"
|
|
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
|
)
|
|
|
|
// NewTestRegistry creates a new isolated Prometheus registry for testing
|
|
func NewTestRegistry() *prometheus.Registry <span class="cov8" title="1">{
|
|
return prometheus.NewRegistry()
|
|
}</span>
|
|
|
|
// NewTestTracer creates a no-op tracer for testing
|
|
func NewTestTracer() (*sdktrace.TracerProvider, *tracetest.InMemoryExporter) <span class="cov8" title="1">{
|
|
exporter := tracetest.NewInMemoryExporter()
|
|
res := resource.NewSchemaless(
|
|
semconv.ServiceNameKey.String("test-service"),
|
|
)
|
|
tp := sdktrace.NewTracerProvider(
|
|
sdktrace.WithSyncer(exporter),
|
|
sdktrace.WithResource(res),
|
|
)
|
|
otel.SetTracerProvider(tp)
|
|
return tp, exporter
|
|
}</span>
|
|
|
|
// GetMetricValue extracts a metric value from a registry
|
|
func GetMetricValue(registry *prometheus.Registry, metricName string) (float64, error) <span class="cov0" title="0">{
|
|
metrics, err := registry.Gather()
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return 0, err
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">for _, mf := range metrics </span><span class="cov0" title="0">{
|
|
if mf.GetName() == metricName </span><span class="cov0" title="0">{
|
|
if len(mf.GetMetric()) > 0 </span><span class="cov0" title="0">{
|
|
m := mf.GetMetric()[0]
|
|
if m.GetCounter() != nil </span><span class="cov0" title="0">{
|
|
return m.GetCounter().GetValue(), nil
|
|
}</span>
|
|
<span class="cov0" title="0">if m.GetGauge() != nil </span><span class="cov0" title="0">{
|
|
return m.GetGauge().GetValue(), nil
|
|
}</span>
|
|
<span class="cov0" title="0">if m.GetHistogram() != nil </span><span class="cov0" title="0">{
|
|
return float64(m.GetHistogram().GetSampleCount()), nil
|
|
}</span>
|
|
}
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">return 0, nil</span>
|
|
}
|
|
|
|
// CountMetricsWithName counts how many metrics match the given name
|
|
func CountMetricsWithName(registry *prometheus.Registry, metricName string) (int, error) <span class="cov0" title="0">{
|
|
metrics, err := registry.Gather()
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return 0, err
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">for _, mf := range metrics </span><span class="cov0" title="0">{
|
|
if mf.GetName() == metricName </span><span class="cov0" title="0">{
|
|
return len(mf.GetMetric()), nil
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">return 0, nil</span>
|
|
}
|
|
|
|
// GetCounterValue is a helper to get counter values using testutil
|
|
func GetCounterValue(counter prometheus.Counter) float64 <span class="cov0" title="0">{
|
|
return testutil.ToFloat64(counter)
|
|
}</span>
|
|
|
|
// NewNoOpTracerProvider creates a tracer provider that discards all spans
|
|
func NewNoOpTracerProvider() *sdktrace.TracerProvider <span class="cov0" title="0">{
|
|
return sdktrace.NewTracerProvider(
|
|
sdktrace.WithSpanProcessor(sdktrace.NewSimpleSpanProcessor(&noOpExporter{})),
|
|
)
|
|
}</span>
|
|
|
|
// noOpExporter is an exporter that discards all spans
|
|
type noOpExporter struct{}
|
|
|
|
func (e *noOpExporter) ExportSpans(context.Context, []sdktrace.ReadOnlySpan) error <span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
|
|
func (e *noOpExporter) Shutdown(context.Context) error <span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
|
|
// ShutdownTracer is a helper to safely shutdown a tracer provider
|
|
func ShutdownTracer(tp *sdktrace.TracerProvider) error <span class="cov8" title="1">{
|
|
if tp != nil </span><span class="cov8" title="1">{
|
|
return tp.Shutdown(context.Background())
|
|
}</span>
|
|
<span class="cov0" title="0">return nil</span>
|
|
}
|
|
|
|
// NewTestExporter creates a test exporter that writes to the provided writer
|
|
type TestExporter struct {
|
|
writer io.Writer
|
|
}
|
|
|
|
func (e *TestExporter) ExportSpans(ctx context.Context, spans []sdktrace.ReadOnlySpan) error <span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
|
|
func (e *TestExporter) Shutdown(ctx context.Context) error <span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
</pre>
|
|
|
|
<pre class="file" id="file15" style="display: none">package observability
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/config"
|
|
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
|
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
|
|
"go.opentelemetry.io/otel/sdk/resource"
|
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
|
semconv "go.opentelemetry.io/otel/semconv/v1.24.0"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
)
|
|
|
|
// InitTracer initializes the OpenTelemetry tracer provider.
|
|
func InitTracer(cfg config.TracingConfig) (*sdktrace.TracerProvider, error) <span class="cov8" title="1">{
|
|
// Create resource with service information
|
|
res, err := resource.Merge(
|
|
resource.Default(),
|
|
resource.NewWithAttributes(
|
|
semconv.SchemaURL,
|
|
semconv.ServiceName(cfg.ServiceName),
|
|
),
|
|
)
|
|
if err != nil </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("failed to create resource: %w", err)
|
|
}</span>
|
|
|
|
// Create exporter
|
|
<span class="cov0" title="0">var exporter sdktrace.SpanExporter
|
|
switch cfg.Exporter.Type </span>{
|
|
case "otlp":<span class="cov0" title="0">
|
|
exporter, err = createOTLPExporter(cfg.Exporter)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("failed to create OTLP exporter: %w", err)
|
|
}</span>
|
|
case "stdout":<span class="cov0" title="0">
|
|
exporter, err = stdouttrace.New(
|
|
stdouttrace.WithPrettyPrint(),
|
|
)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("failed to create stdout exporter: %w", err)
|
|
}</span>
|
|
default:<span class="cov0" title="0">
|
|
return nil, fmt.Errorf("unsupported exporter type: %s", cfg.Exporter.Type)</span>
|
|
}
|
|
|
|
// Create sampler
|
|
<span class="cov0" title="0">sampler := createSampler(cfg.Sampler)
|
|
|
|
// Create tracer provider
|
|
tp := sdktrace.NewTracerProvider(
|
|
sdktrace.WithBatcher(exporter),
|
|
sdktrace.WithResource(res),
|
|
sdktrace.WithSampler(sampler),
|
|
)
|
|
|
|
return tp, nil</span>
|
|
}
|
|
|
|
// createOTLPExporter creates an OTLP gRPC exporter.
|
|
func createOTLPExporter(cfg config.ExporterConfig) (sdktrace.SpanExporter, error) <span class="cov0" title="0">{
|
|
opts := []otlptracegrpc.Option{
|
|
otlptracegrpc.WithEndpoint(cfg.Endpoint),
|
|
}
|
|
|
|
if cfg.Insecure </span><span class="cov0" title="0">{
|
|
opts = append(opts, otlptracegrpc.WithTLSCredentials(insecure.NewCredentials()))
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">if len(cfg.Headers) > 0 </span><span class="cov0" title="0">{
|
|
opts = append(opts, otlptracegrpc.WithHeaders(cfg.Headers))
|
|
}</span>
|
|
|
|
// Add dial options to ensure connection
|
|
<span class="cov0" title="0">opts = append(opts, otlptracegrpc.WithDialOption(grpc.WithBlock()))
|
|
|
|
return otlptracegrpc.New(context.Background(), opts...)</span>
|
|
}
|
|
|
|
// createSampler creates a sampler based on the configuration.
|
|
func createSampler(cfg config.SamplerConfig) sdktrace.Sampler <span class="cov8" title="1">{
|
|
switch cfg.Type </span>{
|
|
case "always":<span class="cov8" title="1">
|
|
return sdktrace.AlwaysSample()</span>
|
|
case "never":<span class="cov8" title="1">
|
|
return sdktrace.NeverSample()</span>
|
|
case "probability":<span class="cov8" title="1">
|
|
return sdktrace.TraceIDRatioBased(cfg.Rate)</span>
|
|
default:<span class="cov8" title="1">
|
|
// Default to 10% sampling
|
|
return sdktrace.TraceIDRatioBased(0.1)</span>
|
|
}
|
|
}
|
|
|
|
// Shutdown gracefully shuts down the tracer provider.
|
|
func Shutdown(ctx context.Context, tp *sdktrace.TracerProvider) error <span class="cov8" title="1">{
|
|
if tp == nil </span><span class="cov8" title="1">{
|
|
return nil
|
|
}</span>
|
|
<span class="cov8" title="1">return tp.Shutdown(ctx)</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file16" style="display: none">package observability
|
|
|
|
import (
|
|
"net/http"
|
|
|
|
"go.opentelemetry.io/otel"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/codes"
|
|
"go.opentelemetry.io/otel/propagation"
|
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
|
"go.opentelemetry.io/otel/trace"
|
|
)
|
|
|
|
// TracingMiddleware creates a middleware that adds OpenTelemetry tracing to HTTP requests.
|
|
func TracingMiddleware(next http.Handler, tp *sdktrace.TracerProvider) http.Handler <span class="cov0" title="0">{
|
|
if tp == nil </span><span class="cov0" title="0">{
|
|
// If tracing is not enabled, pass through without modification
|
|
return next
|
|
}</span>
|
|
|
|
// Set up W3C Trace Context propagation
|
|
<span class="cov0" title="0">otel.SetTextMapPropagator(propagation.TraceContext{})
|
|
|
|
tracer := tp.Tracer("llm-gateway")
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) </span><span class="cov0" title="0">{
|
|
// Extract trace context from incoming request headers
|
|
ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header))
|
|
|
|
// Start a new span
|
|
ctx, span := tracer.Start(ctx, "HTTP "+r.Method+" "+r.URL.Path,
|
|
trace.WithSpanKind(trace.SpanKindServer),
|
|
trace.WithAttributes(
|
|
attribute.String("http.method", r.Method),
|
|
attribute.String("http.route", r.URL.Path),
|
|
attribute.String("http.scheme", r.URL.Scheme),
|
|
attribute.String("http.host", r.Host),
|
|
attribute.String("http.user_agent", r.Header.Get("User-Agent")),
|
|
),
|
|
)
|
|
defer span.End()
|
|
|
|
// Add request ID to span if present
|
|
if requestID := r.Header.Get("X-Request-ID"); requestID != "" </span><span class="cov0" title="0">{
|
|
span.SetAttributes(attribute.String("http.request_id", requestID))
|
|
}</span>
|
|
|
|
// Create a response writer wrapper to capture status code
|
|
<span class="cov0" title="0">wrapped := &statusResponseWriter{
|
|
ResponseWriter: w,
|
|
statusCode: http.StatusOK,
|
|
}
|
|
|
|
// Inject trace context into request for downstream services
|
|
r = r.WithContext(ctx)
|
|
|
|
// Call the next handler
|
|
next.ServeHTTP(wrapped, r)
|
|
|
|
// Record the status code in the span
|
|
span.SetAttributes(attribute.Int("http.status_code", wrapped.statusCode))
|
|
|
|
// Set span status based on HTTP status code
|
|
if wrapped.statusCode >= 400 </span><span class="cov0" title="0">{
|
|
span.SetStatus(codes.Error, http.StatusText(wrapped.statusCode))
|
|
}</span> else<span class="cov0" title="0"> {
|
|
span.SetStatus(codes.Ok, "")
|
|
}</span>
|
|
})
|
|
}
|
|
|
|
// statusResponseWriter wraps http.ResponseWriter to capture the status code.
|
|
type statusResponseWriter struct {
|
|
http.ResponseWriter
|
|
statusCode int
|
|
}
|
|
|
|
func (w *statusResponseWriter) WriteHeader(statusCode int) <span class="cov0" title="0">{
|
|
w.statusCode = statusCode
|
|
w.ResponseWriter.WriteHeader(statusCode)
|
|
}</span>
|
|
|
|
func (w *statusResponseWriter) Write(b []byte) (int, error) <span class="cov0" title="0">{
|
|
return w.ResponseWriter.Write(b)
|
|
}</span>
|
|
</pre>
|
|
|
|
<pre class="file" id="file17" style="display: none">package anthropic
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/anthropics/anthropic-sdk-go"
|
|
"github.com/anthropics/anthropic-sdk-go/option"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/ajac-zero/latticelm/internal/config"
|
|
)
|
|
|
|
const Name = "anthropic"
|
|
|
|
// Provider implements the Anthropic SDK integration.
|
|
// It supports both direct Anthropic API and Azure-hosted (Microsoft Foundry) endpoints.
|
|
type Provider struct {
|
|
cfg config.ProviderConfig
|
|
client *anthropic.Client
|
|
azure bool
|
|
}
|
|
|
|
// New constructs a Provider for the direct Anthropic API.
|
|
func New(cfg config.ProviderConfig) *Provider <span class="cov0" title="0">{
|
|
var client *anthropic.Client
|
|
if cfg.APIKey != "" </span><span class="cov0" title="0">{
|
|
c := anthropic.NewClient(option.WithAPIKey(cfg.APIKey))
|
|
client = &c
|
|
}</span>
|
|
<span class="cov0" title="0">return &Provider{
|
|
cfg: cfg,
|
|
client: client,
|
|
}</span>
|
|
}
|
|
|
|
// NewAzure constructs a Provider targeting Azure-hosted Anthropic (Microsoft Foundry).
|
|
// The Azure endpoint uses api-key header auth and a base URL like
|
|
// https://<resource>.services.ai.azure.com/anthropic.
|
|
func NewAzure(azureCfg config.AzureAnthropicConfig) *Provider <span class="cov0" title="0">{
|
|
var client *anthropic.Client
|
|
if azureCfg.APIKey != "" && azureCfg.Endpoint != "" </span><span class="cov0" title="0">{
|
|
c := anthropic.NewClient(
|
|
option.WithBaseURL(azureCfg.Endpoint),
|
|
option.WithAPIKey("unused"),
|
|
option.WithAuthToken(azureCfg.APIKey),
|
|
)
|
|
client = &c
|
|
}</span>
|
|
<span class="cov0" title="0">return &Provider{
|
|
cfg: config.ProviderConfig{
|
|
APIKey: azureCfg.APIKey,
|
|
Model: azureCfg.Model,
|
|
},
|
|
client: client,
|
|
azure: true,
|
|
}</span>
|
|
}
|
|
|
|
func (p *Provider) Name() string <span class="cov0" title="0">{ return Name }</span>
|
|
|
|
// Generate routes the request to Anthropic's API.
|
|
func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) <span class="cov0" title="0">{
|
|
if p.cfg.APIKey == "" </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("anthropic api key missing")
|
|
}</span>
|
|
<span class="cov0" title="0">if p.client == nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("anthropic client not initialized")
|
|
}</span>
|
|
|
|
// Convert messages to Anthropic format
|
|
<span class="cov0" title="0">anthropicMsgs := make([]anthropic.MessageParam, 0, len(messages))
|
|
var system string
|
|
|
|
for _, msg := range messages </span><span class="cov0" title="0">{
|
|
var content string
|
|
for _, block := range msg.Content </span><span class="cov0" title="0">{
|
|
if block.Type == "input_text" || block.Type == "output_text" </span><span class="cov0" title="0">{
|
|
content += block.Text
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">switch msg.Role </span>{
|
|
case "user":<span class="cov0" title="0">
|
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))</span>
|
|
case "assistant":<span class="cov0" title="0">
|
|
// Build content blocks including text and tool calls
|
|
var contentBlocks []anthropic.ContentBlockParamUnion
|
|
if content != "" </span><span class="cov0" title="0">{
|
|
contentBlocks = append(contentBlocks, anthropic.NewTextBlock(content))
|
|
}</span>
|
|
// Add tool use blocks
|
|
<span class="cov0" title="0">for _, tc := range msg.ToolCalls </span><span class="cov0" title="0">{
|
|
var input map[string]interface{}
|
|
if err := json.Unmarshal([]byte(tc.Arguments), &input); err != nil </span><span class="cov0" title="0">{
|
|
// If unmarshal fails, skip this tool call
|
|
continue</span>
|
|
}
|
|
<span class="cov0" title="0">contentBlocks = append(contentBlocks, anthropic.NewToolUseBlock(tc.ID, input, tc.Name))</span>
|
|
}
|
|
<span class="cov0" title="0">if len(contentBlocks) > 0 </span><span class="cov0" title="0">{
|
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(contentBlocks...))
|
|
}</span>
|
|
case "tool":<span class="cov0" title="0">
|
|
// Tool results must be in user message with tool_result blocks
|
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
|
|
anthropic.NewToolResultBlock(msg.CallID, content, false),
|
|
))</span>
|
|
case "system", "developer":<span class="cov0" title="0">
|
|
system = content</span>
|
|
}
|
|
}
|
|
|
|
// Build request params
|
|
<span class="cov0" title="0">maxTokens := int64(4096)
|
|
if req.MaxOutputTokens != nil </span><span class="cov0" title="0">{
|
|
maxTokens = int64(*req.MaxOutputTokens)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">params := anthropic.MessageNewParams{
|
|
Model: anthropic.Model(req.Model),
|
|
Messages: anthropicMsgs,
|
|
MaxTokens: maxTokens,
|
|
}
|
|
|
|
if system != "" </span><span class="cov0" title="0">{
|
|
systemBlocks := []anthropic.TextBlockParam{
|
|
{Text: system, Type: "text"},
|
|
}
|
|
params.System = systemBlocks
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">if req.Temperature != nil </span><span class="cov0" title="0">{
|
|
params.Temperature = anthropic.Float(*req.Temperature)
|
|
}</span>
|
|
<span class="cov0" title="0">if req.TopP != nil </span><span class="cov0" title="0">{
|
|
params.TopP = anthropic.Float(*req.TopP)
|
|
}</span>
|
|
|
|
// Add tools if present
|
|
<span class="cov0" title="0">if req.Tools != nil && len(req.Tools) > 0 </span><span class="cov0" title="0">{
|
|
tools, err := parseTools(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("parse tools: %w", err)
|
|
}</span>
|
|
<span class="cov0" title="0">params.Tools = tools</span>
|
|
}
|
|
|
|
// Add tool_choice if present
|
|
<span class="cov0" title="0">if req.ToolChoice != nil && len(req.ToolChoice) > 0 </span><span class="cov0" title="0">{
|
|
toolChoice, err := parseToolChoice(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("parse tool_choice: %w", err)
|
|
}</span>
|
|
<span class="cov0" title="0">params.ToolChoice = toolChoice</span>
|
|
}
|
|
|
|
// Call Anthropic API
|
|
<span class="cov0" title="0">resp, err := p.client.Messages.New(ctx, params)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("anthropic api error: %w", err)
|
|
}</span>
|
|
|
|
// Extract text and tool calls from response
|
|
<span class="cov0" title="0">var text string
|
|
var toolCalls []api.ToolCall
|
|
|
|
for _, block := range resp.Content </span><span class="cov0" title="0">{
|
|
switch block.Type </span>{
|
|
case "text":<span class="cov0" title="0">
|
|
text += block.AsText().Text</span>
|
|
case "tool_use":<span class="cov0" title="0">
|
|
// Extract tool calls
|
|
toolUse := block.AsToolUse()
|
|
argsJSON, _ := json.Marshal(toolUse.Input)
|
|
toolCalls = append(toolCalls, api.ToolCall{
|
|
ID: toolUse.ID,
|
|
Name: toolUse.Name,
|
|
Arguments: string(argsJSON),
|
|
})</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">return &api.ProviderResult{
|
|
ID: resp.ID,
|
|
Model: string(resp.Model),
|
|
Text: text,
|
|
ToolCalls: toolCalls,
|
|
Usage: api.Usage{
|
|
InputTokens: int(resp.Usage.InputTokens),
|
|
OutputTokens: int(resp.Usage.OutputTokens),
|
|
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
|
|
},
|
|
}, nil</span>
|
|
}
|
|
|
|
// GenerateStream handles streaming requests to Anthropic.
|
|
func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) <span class="cov0" title="0">{
|
|
deltaChan := make(chan *api.ProviderStreamDelta)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() </span><span class="cov0" title="0">{
|
|
defer close(deltaChan)
|
|
defer close(errChan)
|
|
|
|
if p.cfg.APIKey == "" </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("anthropic api key missing")
|
|
return
|
|
}</span>
|
|
<span class="cov0" title="0">if p.client == nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("anthropic client not initialized")
|
|
return
|
|
}</span>
|
|
|
|
// Convert messages to Anthropic format
|
|
<span class="cov0" title="0">anthropicMsgs := make([]anthropic.MessageParam, 0, len(messages))
|
|
var system string
|
|
|
|
for _, msg := range messages </span><span class="cov0" title="0">{
|
|
var content string
|
|
for _, block := range msg.Content </span><span class="cov0" title="0">{
|
|
if block.Type == "input_text" || block.Type == "output_text" </span><span class="cov0" title="0">{
|
|
content += block.Text
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">switch msg.Role </span>{
|
|
case "user":<span class="cov0" title="0">
|
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(anthropic.NewTextBlock(content)))</span>
|
|
case "assistant":<span class="cov0" title="0">
|
|
// Build content blocks including text and tool calls
|
|
var contentBlocks []anthropic.ContentBlockParamUnion
|
|
if content != "" </span><span class="cov0" title="0">{
|
|
contentBlocks = append(contentBlocks, anthropic.NewTextBlock(content))
|
|
}</span>
|
|
// Add tool use blocks
|
|
<span class="cov0" title="0">for _, tc := range msg.ToolCalls </span><span class="cov0" title="0">{
|
|
var input map[string]interface{}
|
|
if err := json.Unmarshal([]byte(tc.Arguments), &input); err != nil </span><span class="cov0" title="0">{
|
|
// If unmarshal fails, skip this tool call
|
|
continue</span>
|
|
}
|
|
<span class="cov0" title="0">contentBlocks = append(contentBlocks, anthropic.NewToolUseBlock(tc.ID, input, tc.Name))</span>
|
|
}
|
|
<span class="cov0" title="0">if len(contentBlocks) > 0 </span><span class="cov0" title="0">{
|
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewAssistantMessage(contentBlocks...))
|
|
}</span>
|
|
case "tool":<span class="cov0" title="0">
|
|
// Tool results must be in user message with tool_result blocks
|
|
anthropicMsgs = append(anthropicMsgs, anthropic.NewUserMessage(
|
|
anthropic.NewToolResultBlock(msg.CallID, content, false),
|
|
))</span>
|
|
case "system", "developer":<span class="cov0" title="0">
|
|
system = content</span>
|
|
}
|
|
}
|
|
|
|
// Build params
|
|
<span class="cov0" title="0">maxTokens := int64(4096)
|
|
if req.MaxOutputTokens != nil </span><span class="cov0" title="0">{
|
|
maxTokens = int64(*req.MaxOutputTokens)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">params := anthropic.MessageNewParams{
|
|
Model: anthropic.Model(req.Model),
|
|
Messages: anthropicMsgs,
|
|
MaxTokens: maxTokens,
|
|
}
|
|
|
|
if system != "" </span><span class="cov0" title="0">{
|
|
systemBlocks := []anthropic.TextBlockParam{
|
|
{Text: system, Type: "text"},
|
|
}
|
|
params.System = systemBlocks
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">if req.Temperature != nil </span><span class="cov0" title="0">{
|
|
params.Temperature = anthropic.Float(*req.Temperature)
|
|
}</span>
|
|
<span class="cov0" title="0">if req.TopP != nil </span><span class="cov0" title="0">{
|
|
params.TopP = anthropic.Float(*req.TopP)
|
|
}</span>
|
|
|
|
// Add tools if present
|
|
<span class="cov0" title="0">if req.Tools != nil && len(req.Tools) > 0 </span><span class="cov0" title="0">{
|
|
tools, err := parseTools(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("parse tools: %w", err)
|
|
return
|
|
}</span>
|
|
<span class="cov0" title="0">params.Tools = tools</span>
|
|
}
|
|
|
|
// Add tool_choice if present
|
|
<span class="cov0" title="0">if req.ToolChoice != nil && len(req.ToolChoice) > 0 </span><span class="cov0" title="0">{
|
|
toolChoice, err := parseToolChoice(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("parse tool_choice: %w", err)
|
|
return
|
|
}</span>
|
|
<span class="cov0" title="0">params.ToolChoice = toolChoice</span>
|
|
}
|
|
|
|
// Create stream
|
|
<span class="cov0" title="0">stream := p.client.Messages.NewStreaming(ctx, params)
|
|
|
|
// Track content block index and tool call state
|
|
var contentBlockIndex int
|
|
|
|
// Process stream
|
|
for stream.Next() </span><span class="cov0" title="0">{
|
|
event := stream.Current()
|
|
|
|
switch event.Type </span>{
|
|
case "content_block_start":<span class="cov0" title="0">
|
|
// New content block (text or tool_use)
|
|
contentBlockIndex = int(event.Index)
|
|
if event.ContentBlock.Type == "tool_use" </span><span class="cov0" title="0">{
|
|
// Send tool call delta with ID and name
|
|
toolUse := event.ContentBlock.AsToolUse()
|
|
delta := &api.ToolCallDelta{
|
|
Index: contentBlockIndex,
|
|
ID: toolUse.ID,
|
|
Name: toolUse.Name,
|
|
}
|
|
select </span>{
|
|
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:<span class="cov0" title="0"></span>
|
|
case <-ctx.Done():<span class="cov0" title="0">
|
|
errChan <- ctx.Err()
|
|
return</span>
|
|
}
|
|
}
|
|
|
|
case "content_block_delta":<span class="cov0" title="0">
|
|
if event.Delta.Type == "text_delta" </span><span class="cov0" title="0">{
|
|
// Text streaming
|
|
select </span>{
|
|
case deltaChan <- &api.ProviderStreamDelta{Text: event.Delta.Text}:<span class="cov0" title="0"></span>
|
|
case <-ctx.Done():<span class="cov0" title="0">
|
|
errChan <- ctx.Err()
|
|
return</span>
|
|
}
|
|
} else<span class="cov0" title="0"> if event.Delta.Type == "input_json_delta" </span><span class="cov0" title="0">{
|
|
// Tool arguments streaming
|
|
delta := &api.ToolCallDelta{
|
|
Index: int(event.Index),
|
|
Arguments: event.Delta.PartialJSON,
|
|
}
|
|
select </span>{
|
|
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:<span class="cov0" title="0"></span>
|
|
case <-ctx.Done():<span class="cov0" title="0">
|
|
errChan <- ctx.Err()
|
|
return</span>
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">if err := stream.Err(); err != nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("anthropic stream error: %w", err)
|
|
return
|
|
}</span>
|
|
|
|
// Send final delta
|
|
<span class="cov0" title="0">select </span>{
|
|
case deltaChan <- &api.ProviderStreamDelta{Done: true}:<span class="cov0" title="0"></span>
|
|
case <-ctx.Done():<span class="cov0" title="0">
|
|
errChan <- ctx.Err()</span>
|
|
}
|
|
}()
|
|
|
|
<span class="cov0" title="0">return deltaChan, errChan</span>
|
|
}
|
|
|
|
func chooseModel(requested, defaultModel string) string <span class="cov0" title="0">{
|
|
if requested != "" </span><span class="cov0" title="0">{
|
|
return requested
|
|
}</span>
|
|
<span class="cov0" title="0">if defaultModel != "" </span><span class="cov0" title="0">{
|
|
return defaultModel
|
|
}</span>
|
|
<span class="cov0" title="0">return "claude-3-5-sonnet"</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file18" style="display: none">package anthropic
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/anthropics/anthropic-sdk-go"
|
|
)
|
|
|
|
// parseTools converts Open Responses tools to Anthropic format
|
|
func parseTools(req *api.ResponseRequest) ([]anthropic.ToolUnionParam, error) <span class="cov8" title="1">{
|
|
if req.Tools == nil || len(req.Tools) == 0 </span><span class="cov0" title="0">{
|
|
return nil, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var toolDefs []map[string]interface{}
|
|
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("unmarshal tools: %w", err)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var tools []anthropic.ToolUnionParam
|
|
for _, td := range toolDefs </span><span class="cov8" title="1">{
|
|
// Extract: name, description, parameters
|
|
// Note: Anthropic uses "input_schema" instead of "parameters"
|
|
name, _ := td["name"].(string)
|
|
desc, _ := td["description"].(string)
|
|
params, _ := td["parameters"].(map[string]interface{})
|
|
|
|
inputSchema := anthropic.ToolInputSchemaParam{
|
|
Type: "object",
|
|
Properties: params["properties"],
|
|
}
|
|
|
|
// Add required fields if present
|
|
if required, ok := params["required"].([]interface{}); ok </span><span class="cov8" title="1">{
|
|
requiredStrs := make([]string, 0, len(required))
|
|
for _, r := range required </span><span class="cov8" title="1">{
|
|
if str, ok := r.(string); ok </span><span class="cov8" title="1">{
|
|
requiredStrs = append(requiredStrs, str)
|
|
}</span>
|
|
}
|
|
<span class="cov8" title="1">inputSchema.Required = requiredStrs</span>
|
|
}
|
|
|
|
// Create the tool using ToolUnionParamOfTool
|
|
<span class="cov8" title="1">tool := anthropic.ToolUnionParamOfTool(inputSchema, name)
|
|
|
|
if desc != "" </span><span class="cov8" title="1">{
|
|
tool.OfTool.Description = anthropic.String(desc)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">tools = append(tools, tool)</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">return tools, nil</span>
|
|
}
|
|
|
|
// parseToolChoice converts Open Responses tool_choice to Anthropic format
|
|
func parseToolChoice(req *api.ResponseRequest) (anthropic.ToolChoiceUnionParam, error) <span class="cov8" title="1">{
|
|
var result anthropic.ToolChoiceUnionParam
|
|
|
|
if req.ToolChoice == nil || len(req.ToolChoice) == 0 </span><span class="cov0" title="0">{
|
|
return result, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var choice interface{}
|
|
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil </span><span class="cov0" title="0">{
|
|
return result, fmt.Errorf("unmarshal tool_choice: %w", err)
|
|
}</span>
|
|
|
|
// Handle string values: "auto", "any", "required"
|
|
<span class="cov8" title="1">if str, ok := choice.(string); ok </span><span class="cov8" title="1">{
|
|
switch str </span>{
|
|
case "auto":<span class="cov8" title="1">
|
|
result.OfAuto = &anthropic.ToolChoiceAutoParam{
|
|
Type: "auto",
|
|
}</span>
|
|
case "any", "required":<span class="cov8" title="1">
|
|
result.OfAny = &anthropic.ToolChoiceAnyParam{
|
|
Type: "any",
|
|
}</span>
|
|
case "none":<span class="cov0" title="0">
|
|
result.OfNone = &anthropic.ToolChoiceNoneParam{
|
|
Type: "none",
|
|
}</span>
|
|
default:<span class="cov0" title="0">
|
|
return result, fmt.Errorf("unknown tool_choice string: %s", str)</span>
|
|
}
|
|
<span class="cov8" title="1">return result, nil</span>
|
|
}
|
|
|
|
// Handle specific tool selection: {"type": "tool", "function": {"name": "..."}}
|
|
<span class="cov8" title="1">if obj, ok := choice.(map[string]interface{}); ok </span><span class="cov8" title="1">{
|
|
// Check for OpenAI format: {"type": "function", "function": {"name": "..."}}
|
|
if funcObj, ok := obj["function"].(map[string]interface{}); ok </span><span class="cov8" title="1">{
|
|
if name, ok := funcObj["name"].(string); ok </span><span class="cov8" title="1">{
|
|
result.OfTool = &anthropic.ToolChoiceToolParam{
|
|
Type: "tool",
|
|
Name: name,
|
|
}
|
|
return result, nil
|
|
}</span>
|
|
}
|
|
|
|
// Check for direct name field
|
|
<span class="cov0" title="0">if name, ok := obj["name"].(string); ok </span><span class="cov0" title="0">{
|
|
result.OfTool = &anthropic.ToolChoiceToolParam{
|
|
Type: "tool",
|
|
Name: name,
|
|
}
|
|
return result, nil
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">return result, fmt.Errorf("invalid tool_choice format")</span>
|
|
}
|
|
|
|
// extractToolCalls converts Anthropic content blocks to api.ToolCall
|
|
func extractToolCalls(content []anthropic.ContentBlockUnion) []api.ToolCall <span class="cov0" title="0">{
|
|
var toolCalls []api.ToolCall
|
|
|
|
for _, block := range content </span><span class="cov0" title="0">{
|
|
// Check if this is a tool_use block
|
|
if block.Type == "tool_use" </span><span class="cov0" title="0">{
|
|
// Cast to ToolUseBlock to access the fields
|
|
toolUse := block.AsToolUse()
|
|
|
|
// Marshal the input to JSON string for Arguments
|
|
argsJSON, _ := json.Marshal(toolUse.Input)
|
|
|
|
toolCalls = append(toolCalls, api.ToolCall{
|
|
ID: toolUse.ID,
|
|
Name: toolUse.Name,
|
|
Arguments: string(argsJSON),
|
|
})
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">return toolCalls</span>
|
|
}
|
|
|
|
// extractToolCallDelta extracts tool call delta from streaming content block delta
|
|
func extractToolCallDelta(delta anthropic.RawContentBlockDeltaUnion, index int) *api.ToolCallDelta <span class="cov0" title="0">{
|
|
// Check if this is an input_json_delta (streaming tool arguments)
|
|
if delta.Type == "input_json_delta" </span><span class="cov0" title="0">{
|
|
return &api.ToolCallDelta{
|
|
Index: index,
|
|
Arguments: delta.PartialJSON,
|
|
}
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return nil</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file19" style="display: none">package providers
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/sony/gobreaker"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
)
|
|
|
|
// CircuitBreakerProvider wraps a Provider with circuit breaker functionality.
|
|
type CircuitBreakerProvider struct {
|
|
provider Provider
|
|
cb *gobreaker.CircuitBreaker
|
|
}
|
|
|
|
// CircuitBreakerConfig holds configuration for the circuit breaker.
|
|
type CircuitBreakerConfig struct {
|
|
// MaxRequests is the maximum number of requests allowed to pass through
|
|
// when the circuit breaker is half-open. Default: 3
|
|
MaxRequests uint32
|
|
|
|
// Interval is the cyclic period of the closed state for the circuit breaker
|
|
// to clear the internal Counts. Default: 30s
|
|
Interval time.Duration
|
|
|
|
// Timeout is the period of the open state, after which the state becomes half-open.
|
|
// Default: 60s
|
|
Timeout time.Duration
|
|
|
|
// MinRequests is the minimum number of requests needed before evaluating failure ratio.
|
|
// Default: 5
|
|
MinRequests uint32
|
|
|
|
// FailureRatio is the ratio of failures that will trip the circuit breaker.
|
|
// Default: 0.5 (50%)
|
|
FailureRatio float64
|
|
|
|
// OnStateChange is an optional callback invoked when circuit breaker state changes.
|
|
// Parameters: provider name, from state, to state
|
|
OnStateChange func(provider, from, to string)
|
|
}
|
|
|
|
// DefaultCircuitBreakerConfig returns a sensible default configuration.
|
|
func DefaultCircuitBreakerConfig() CircuitBreakerConfig <span class="cov8" title="1">{
|
|
return CircuitBreakerConfig{
|
|
MaxRequests: 3,
|
|
Interval: 30 * time.Second,
|
|
Timeout: 60 * time.Second,
|
|
MinRequests: 5,
|
|
FailureRatio: 0.5,
|
|
}
|
|
}</span>
|
|
|
|
// NewCircuitBreakerProvider wraps a provider with circuit breaker functionality.
|
|
func NewCircuitBreakerProvider(provider Provider, cfg CircuitBreakerConfig) *CircuitBreakerProvider <span class="cov8" title="1">{
|
|
providerName := provider.Name()
|
|
|
|
settings := gobreaker.Settings{
|
|
Name: fmt.Sprintf("%s-circuit-breaker", providerName),
|
|
MaxRequests: cfg.MaxRequests,
|
|
Interval: cfg.Interval,
|
|
Timeout: cfg.Timeout,
|
|
ReadyToTrip: func(counts gobreaker.Counts) bool </span><span class="cov0" title="0">{
|
|
// Only trip if we have enough requests to be statistically meaningful
|
|
if counts.Requests < cfg.MinRequests </span><span class="cov0" title="0">{
|
|
return false
|
|
}</span>
|
|
<span class="cov0" title="0">failureRatio := float64(counts.TotalFailures) / float64(counts.Requests)
|
|
return failureRatio >= cfg.FailureRatio</span>
|
|
},
|
|
OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) <span class="cov0" title="0">{
|
|
// Call the callback if provided
|
|
if cfg.OnStateChange != nil </span><span class="cov0" title="0">{
|
|
cfg.OnStateChange(providerName, from.String(), to.String())
|
|
}</span>
|
|
},
|
|
}
|
|
|
|
<span class="cov8" title="1">return &CircuitBreakerProvider{
|
|
provider: provider,
|
|
cb: gobreaker.NewCircuitBreaker(settings),
|
|
}</span>
|
|
}
|
|
|
|
// Name returns the underlying provider name.
|
|
func (p *CircuitBreakerProvider) Name() string <span class="cov8" title="1">{
|
|
return p.provider.Name()
|
|
}</span>
|
|
|
|
// Generate wraps the provider's Generate method with circuit breaker protection.
|
|
func (p *CircuitBreakerProvider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) <span class="cov0" title="0">{
|
|
result, err := p.cb.Execute(func() (interface{}, error) </span><span class="cov0" title="0">{
|
|
return p.provider.Generate(ctx, messages, req)
|
|
}</span>)
|
|
|
|
<span class="cov0" title="0">if err != nil </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return result.(*api.ProviderResult), nil</span>
|
|
}
|
|
|
|
// GenerateStream wraps the provider's GenerateStream method with circuit breaker protection.
|
|
func (p *CircuitBreakerProvider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) <span class="cov0" title="0">{
|
|
// For streaming, we check the circuit breaker state before initiating the stream
|
|
// If the circuit is open, we return an error immediately
|
|
state := p.cb.State()
|
|
if state == gobreaker.StateOpen </span><span class="cov0" title="0">{
|
|
errChan := make(chan error, 1)
|
|
deltaChan := make(chan *api.ProviderStreamDelta)
|
|
errChan <- gobreaker.ErrOpenState
|
|
close(deltaChan)
|
|
close(errChan)
|
|
return deltaChan, errChan
|
|
}</span>
|
|
|
|
// If circuit is closed or half-open, attempt the stream
|
|
<span class="cov0" title="0">deltaChan, errChan := p.provider.GenerateStream(ctx, messages, req)
|
|
|
|
// Wrap the error channel to report successes/failures to circuit breaker
|
|
wrappedErrChan := make(chan error, 1)
|
|
|
|
go func() </span><span class="cov0" title="0">{
|
|
defer close(wrappedErrChan)
|
|
|
|
// Wait for the error channel to signal completion
|
|
if err := <-errChan; err != nil </span><span class="cov0" title="0">{
|
|
// Record failure in circuit breaker
|
|
p.cb.Execute(func() (interface{}, error) </span><span class="cov0" title="0">{
|
|
return nil, err
|
|
}</span>)
|
|
<span class="cov0" title="0">wrappedErrChan <- err</span>
|
|
} else<span class="cov0" title="0"> {
|
|
// Record success in circuit breaker
|
|
p.cb.Execute(func() (interface{}, error) </span><span class="cov0" title="0">{
|
|
return nil, nil
|
|
}</span>)
|
|
}
|
|
}()
|
|
|
|
<span class="cov0" title="0">return deltaChan, wrappedErrChan</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file20" style="display: none">package google
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"math/rand"
|
|
"time"
|
|
|
|
"google.golang.org/genai"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
)
|
|
|
|
// parseTools converts generic tool definitions from req.Tools (JSON) to Google's []*genai.Tool format.
|
|
func parseTools(req *api.ResponseRequest) ([]*genai.Tool, error) <span class="cov8" title="1">{
|
|
if req.Tools == nil || len(req.Tools) == 0 </span><span class="cov8" title="1">{
|
|
return nil, nil
|
|
}</span>
|
|
|
|
// Unmarshal to slice of tool definitions
|
|
<span class="cov8" title="1">var toolDefs []map[string]interface{}
|
|
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("unmarshal tools: %w", err)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var functionDeclarations []*genai.FunctionDeclaration
|
|
|
|
for _, toolDef := range toolDefs </span><span class="cov8" title="1">{
|
|
// Extract function details
|
|
// Support both flat format (name/description/parameters at top level)
|
|
// and nested format (under "function" key)
|
|
var name, description string
|
|
var parameters interface{}
|
|
|
|
if functionData, ok := toolDef["function"].(map[string]interface{}); ok </span><span class="cov8" title="1">{
|
|
// Nested format: {"type": "function", "function": {...}}
|
|
name, _ = functionData["name"].(string)
|
|
description, _ = functionData["description"].(string)
|
|
parameters = functionData["parameters"]
|
|
}</span> else<span class="cov8" title="1"> {
|
|
// Flat format: {"type": "function", "name": "...", ...}
|
|
name, _ = toolDef["name"].(string)
|
|
description, _ = toolDef["description"].(string)
|
|
parameters = toolDef["parameters"]
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">if name == "" </span><span class="cov8" title="1">{
|
|
continue</span>
|
|
}
|
|
|
|
// Create function declaration
|
|
<span class="cov8" title="1">funcDecl := &genai.FunctionDeclaration{
|
|
Name: name,
|
|
Description: description,
|
|
}
|
|
|
|
// Google accepts parameters as raw JSON schema
|
|
if parameters != nil </span><span class="cov8" title="1">{
|
|
funcDecl.ParametersJsonSchema = parameters
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">functionDeclarations = append(functionDeclarations, funcDecl)</span>
|
|
}
|
|
|
|
// Return single Tool with all function declarations
|
|
<span class="cov8" title="1">if len(functionDeclarations) > 0 </span><span class="cov8" title="1">{
|
|
return []*genai.Tool{{FunctionDeclarations: functionDeclarations}}, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return nil, nil</span>
|
|
}
|
|
|
|
// parseToolChoice converts req.ToolChoice to Google's ToolConfig with FunctionCallingConfig.
|
|
func parseToolChoice(req *api.ResponseRequest) (*genai.ToolConfig, error) <span class="cov8" title="1">{
|
|
if req.ToolChoice == nil || len(req.ToolChoice) == 0 </span><span class="cov8" title="1">{
|
|
return nil, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var choice interface{}
|
|
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("unmarshal tool_choice: %w", err)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">config := &genai.ToolConfig{
|
|
FunctionCallingConfig: &genai.FunctionCallingConfig{},
|
|
}
|
|
|
|
// Handle string values: "auto", "none", "required"/"any"
|
|
if str, ok := choice.(string); ok </span><span class="cov8" title="1">{
|
|
switch str </span>{
|
|
case "auto":<span class="cov8" title="1">
|
|
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAuto</span>
|
|
case "none":<span class="cov8" title="1">
|
|
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeNone</span>
|
|
case "required", "any":<span class="cov8" title="1">
|
|
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny</span>
|
|
default:<span class="cov8" title="1">
|
|
return nil, fmt.Errorf("unknown tool_choice string: %s", str)</span>
|
|
}
|
|
<span class="cov8" title="1">return config, nil</span>
|
|
}
|
|
|
|
// Handle object format: {"type": "function", "function": {"name": "..."}}
|
|
<span class="cov8" title="1">if obj, ok := choice.(map[string]interface{}); ok </span><span class="cov8" title="1">{
|
|
if typeVal, ok := obj["type"].(string); ok && typeVal == "function" </span><span class="cov8" title="1">{
|
|
config.FunctionCallingConfig.Mode = genai.FunctionCallingConfigModeAny
|
|
if funcObj, ok := obj["function"].(map[string]interface{}); ok </span><span class="cov8" title="1">{
|
|
if name, ok := funcObj["name"].(string); ok </span><span class="cov8" title="1">{
|
|
config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
|
|
}</span>
|
|
}
|
|
<span class="cov8" title="1">return config, nil</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov8" title="1">return nil, fmt.Errorf("unsupported tool_choice format")</span>
|
|
}
|
|
|
|
// extractToolCalls extracts tool calls from Google's response format to generic api.ToolCall slice.
|
|
func extractToolCalls(resp *genai.GenerateContentResponse) []api.ToolCall <span class="cov8" title="1">{
|
|
var toolCalls []api.ToolCall
|
|
|
|
for _, candidate := range resp.Candidates </span><span class="cov8" title="1">{
|
|
if candidate.Content == nil </span><span class="cov0" title="0">{
|
|
continue</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">for _, part := range candidate.Content.Parts </span><span class="cov8" title="1">{
|
|
if part == nil || part.FunctionCall == nil </span><span class="cov0" title="0">{
|
|
continue</span>
|
|
}
|
|
|
|
// Extract function call details
|
|
<span class="cov8" title="1">fc := part.FunctionCall
|
|
|
|
// Marshal arguments to JSON string
|
|
var argsJSON string
|
|
if fc.Args != nil </span><span class="cov8" title="1">{
|
|
argsBytes, err := json.Marshal(fc.Args)
|
|
if err == nil </span><span class="cov8" title="1">{
|
|
argsJSON = string(argsBytes)
|
|
}</span> else<span class="cov0" title="0"> {
|
|
// Fallback to empty object
|
|
argsJSON = "{}"
|
|
}</span>
|
|
} else<span class="cov0" title="0"> {
|
|
argsJSON = "{}"
|
|
}</span>
|
|
|
|
// Generate ID if Google doesn't provide one
|
|
<span class="cov8" title="1">callID := fc.ID
|
|
if callID == "" </span><span class="cov8" title="1">{
|
|
callID = fmt.Sprintf("call_%s", generateRandomID())
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">toolCalls = append(toolCalls, api.ToolCall{
|
|
ID: callID,
|
|
Name: fc.Name,
|
|
Arguments: argsJSON,
|
|
})</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov8" title="1">return toolCalls</span>
|
|
}
|
|
|
|
// extractToolCallDelta extracts streaming tool call information from response parts.
|
|
func extractToolCallDelta(part *genai.Part, index int) *api.ToolCallDelta <span class="cov0" title="0">{
|
|
if part == nil || part.FunctionCall == nil </span><span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">fc := part.FunctionCall
|
|
|
|
// Marshal arguments to JSON string
|
|
var argsJSON string
|
|
if fc.Args != nil </span><span class="cov0" title="0">{
|
|
argsBytes, err := json.Marshal(fc.Args)
|
|
if err == nil </span><span class="cov0" title="0">{
|
|
argsJSON = string(argsBytes)
|
|
}</span> else<span class="cov0" title="0"> {
|
|
argsJSON = "{}"
|
|
}</span>
|
|
} else<span class="cov0" title="0"> {
|
|
argsJSON = "{}"
|
|
}</span>
|
|
|
|
// Generate ID if Google doesn't provide one
|
|
<span class="cov0" title="0">callID := fc.ID
|
|
if callID == "" </span><span class="cov0" title="0">{
|
|
callID = fmt.Sprintf("call_%s", generateRandomID())
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return &api.ToolCallDelta{
|
|
Index: index,
|
|
ID: callID,
|
|
Name: fc.Name,
|
|
Arguments: argsJSON,
|
|
}</span>
|
|
}
|
|
|
|
// generateRandomID generates a random alphanumeric ID
|
|
func generateRandomID() string <span class="cov8" title="1">{
|
|
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
const length = 24
|
|
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
b := make([]byte, length)
|
|
for i := range b </span><span class="cov8" title="1">{
|
|
b[i] = charset[rng.Intn(len(charset))]
|
|
}</span>
|
|
<span class="cov8" title="1">return string(b)</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file21" style="display: none">package google
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/google/uuid"
|
|
"google.golang.org/genai"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/ajac-zero/latticelm/internal/config"
|
|
)
|
|
|
|
const Name = "google"
|
|
|
|
// Provider implements the Google Generative AI integration.
|
|
type Provider struct {
|
|
cfg config.ProviderConfig
|
|
client *genai.Client
|
|
}
|
|
|
|
// New constructs a Provider using the Google AI API with API key authentication.
|
|
func New(cfg config.ProviderConfig) (*Provider, error) <span class="cov0" title="0">{
|
|
var client *genai.Client
|
|
if cfg.APIKey != "" </span><span class="cov0" title="0">{
|
|
var err error
|
|
client, err = genai.NewClient(context.Background(), &genai.ClientConfig{
|
|
APIKey: cfg.APIKey,
|
|
})
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("failed to create google client: %w", err)
|
|
}</span>
|
|
}
|
|
<span class="cov0" title="0">return &Provider{
|
|
cfg: cfg,
|
|
client: client,
|
|
}, nil</span>
|
|
}
|
|
|
|
// NewVertexAI constructs a Provider targeting Vertex AI.
|
|
// Vertex AI uses the same genai SDK but with GCP project/location configuration
|
|
// and Application Default Credentials (ADC) or service account authentication.
|
|
func NewVertexAI(vertexCfg config.VertexAIConfig) (*Provider, error) <span class="cov0" title="0">{
|
|
var client *genai.Client
|
|
if vertexCfg.Project != "" && vertexCfg.Location != "" </span><span class="cov0" title="0">{
|
|
var err error
|
|
client, err = genai.NewClient(context.Background(), &genai.ClientConfig{
|
|
Project: vertexCfg.Project,
|
|
Location: vertexCfg.Location,
|
|
Backend: genai.BackendVertexAI,
|
|
})
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("failed to create vertex ai client: %w", err)
|
|
}</span>
|
|
}
|
|
<span class="cov0" title="0">return &Provider{
|
|
cfg: config.ProviderConfig{
|
|
// Vertex AI doesn't use API key, but set empty for consistency
|
|
APIKey: "",
|
|
},
|
|
client: client,
|
|
}, nil</span>
|
|
}
|
|
|
|
func (p *Provider) Name() string <span class="cov0" title="0">{ return Name }</span>
|
|
|
|
// Generate routes the request to Gemini and returns a ProviderResult.
|
|
func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) <span class="cov0" title="0">{
|
|
if p.client == nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("google client not initialized")
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">model := req.Model
|
|
|
|
contents, systemText := convertMessages(messages)
|
|
|
|
// Parse tools if present
|
|
var tools []*genai.Tool
|
|
if req.Tools != nil && len(req.Tools) > 0 </span><span class="cov0" title="0">{
|
|
var err error
|
|
tools, err = parseTools(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("parse tools: %w", err)
|
|
}</span>
|
|
}
|
|
|
|
// Parse tool_choice if present
|
|
<span class="cov0" title="0">var toolConfig *genai.ToolConfig
|
|
if req.ToolChoice != nil && len(req.ToolChoice) > 0 </span><span class="cov0" title="0">{
|
|
var err error
|
|
toolConfig, err = parseToolChoice(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("parse tool_choice: %w", err)
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">config := buildConfig(systemText, req, tools, toolConfig)
|
|
|
|
resp, err := p.client.Models.GenerateContent(ctx, model, contents, config)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("google api error: %w", err)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">var text string
|
|
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil </span><span class="cov0" title="0">{
|
|
for _, part := range resp.Candidates[0].Content.Parts </span><span class="cov0" title="0">{
|
|
if part != nil </span><span class="cov0" title="0">{
|
|
text += part.Text
|
|
}</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">var toolCalls []api.ToolCall
|
|
if len(resp.Candidates) > 0 </span><span class="cov0" title="0">{
|
|
toolCalls = extractToolCalls(resp)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">var inputTokens, outputTokens int
|
|
if resp.UsageMetadata != nil </span><span class="cov0" title="0">{
|
|
inputTokens = int(resp.UsageMetadata.PromptTokenCount)
|
|
outputTokens = int(resp.UsageMetadata.CandidatesTokenCount)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return &api.ProviderResult{
|
|
ID: uuid.NewString(),
|
|
Model: model,
|
|
Text: text,
|
|
ToolCalls: toolCalls,
|
|
Usage: api.Usage{
|
|
InputTokens: inputTokens,
|
|
OutputTokens: outputTokens,
|
|
TotalTokens: inputTokens + outputTokens,
|
|
},
|
|
}, nil</span>
|
|
}
|
|
|
|
// GenerateStream handles streaming requests to Google.
|
|
func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) <span class="cov0" title="0">{
|
|
deltaChan := make(chan *api.ProviderStreamDelta)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() </span><span class="cov0" title="0">{
|
|
defer close(deltaChan)
|
|
defer close(errChan)
|
|
|
|
if p.client == nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("google client not initialized")
|
|
return
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">model := req.Model
|
|
|
|
contents, systemText := convertMessages(messages)
|
|
|
|
// Parse tools if present
|
|
var tools []*genai.Tool
|
|
if req.Tools != nil && len(req.Tools) > 0 </span><span class="cov0" title="0">{
|
|
var err error
|
|
tools, err = parseTools(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("parse tools: %w", err)
|
|
return
|
|
}</span>
|
|
}
|
|
|
|
// Parse tool_choice if present
|
|
<span class="cov0" title="0">var toolConfig *genai.ToolConfig
|
|
if req.ToolChoice != nil && len(req.ToolChoice) > 0 </span><span class="cov0" title="0">{
|
|
var err error
|
|
toolConfig, err = parseToolChoice(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("parse tool_choice: %w", err)
|
|
return
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">config := buildConfig(systemText, req, tools, toolConfig)
|
|
|
|
stream := p.client.Models.GenerateContentStream(ctx, model, contents, config)
|
|
|
|
for resp, err := range stream </span><span class="cov0" title="0">{
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("google stream error: %w", err)
|
|
return
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil </span><span class="cov0" title="0">{
|
|
for partIndex, part := range resp.Candidates[0].Content.Parts </span><span class="cov0" title="0">{
|
|
if part != nil </span><span class="cov0" title="0">{
|
|
// Handle text content
|
|
if part.Text != "" </span><span class="cov0" title="0">{
|
|
select </span>{
|
|
case deltaChan <- &api.ProviderStreamDelta{Text: part.Text}:<span class="cov0" title="0"></span>
|
|
case <-ctx.Done():<span class="cov0" title="0">
|
|
errChan <- ctx.Err()
|
|
return</span>
|
|
}
|
|
}
|
|
|
|
// Handle tool call content
|
|
<span class="cov0" title="0">if part.FunctionCall != nil </span><span class="cov0" title="0">{
|
|
delta := extractToolCallDelta(part, partIndex)
|
|
if delta != nil </span><span class="cov0" title="0">{
|
|
select </span>{
|
|
case deltaChan <- &api.ProviderStreamDelta{ToolCallDelta: delta}:<span class="cov0" title="0"></span>
|
|
case <-ctx.Done():<span class="cov0" title="0">
|
|
errChan <- ctx.Err()
|
|
return</span>
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">select </span>{
|
|
case deltaChan <- &api.ProviderStreamDelta{Done: true}:<span class="cov0" title="0"></span>
|
|
case <-ctx.Done():<span class="cov0" title="0">
|
|
errChan <- ctx.Err()</span>
|
|
}
|
|
}()
|
|
|
|
<span class="cov0" title="0">return deltaChan, errChan</span>
|
|
}
|
|
|
|
// convertMessages splits messages into Gemini contents and system text.
|
|
func convertMessages(messages []api.Message) ([]*genai.Content, string) <span class="cov0" title="0">{
|
|
var contents []*genai.Content
|
|
var systemText string
|
|
|
|
// Build a map of CallID -> Name from assistant tool calls
|
|
// This allows us to look up function names when processing tool results
|
|
callIDToName := make(map[string]string)
|
|
for _, msg := range messages </span><span class="cov0" title="0">{
|
|
if msg.Role == "assistant" || msg.Role == "model" </span><span class="cov0" title="0">{
|
|
for _, tc := range msg.ToolCalls </span><span class="cov0" title="0">{
|
|
if tc.ID != "" && tc.Name != "" </span><span class="cov0" title="0">{
|
|
callIDToName[tc.ID] = tc.Name
|
|
}</span>
|
|
}
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">for _, msg := range messages </span><span class="cov0" title="0">{
|
|
if msg.Role == "system" || msg.Role == "developer" </span><span class="cov0" title="0">{
|
|
for _, block := range msg.Content </span><span class="cov0" title="0">{
|
|
if block.Type == "input_text" || block.Type == "output_text" </span><span class="cov0" title="0">{
|
|
systemText += block.Text
|
|
}</span>
|
|
}
|
|
<span class="cov0" title="0">continue</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">if msg.Role == "tool" </span><span class="cov0" title="0">{
|
|
// Tool results are sent as FunctionResponse in user role message
|
|
var output string
|
|
for _, block := range msg.Content </span><span class="cov0" title="0">{
|
|
if block.Type == "input_text" || block.Type == "output_text" </span><span class="cov0" title="0">{
|
|
output += block.Text
|
|
}</span>
|
|
}
|
|
|
|
// Parse output as JSON map, or wrap in {"output": "..."} if not JSON
|
|
<span class="cov0" title="0">var responseMap map[string]any
|
|
if err := json.Unmarshal([]byte(output), &responseMap); err != nil </span><span class="cov0" title="0">{
|
|
// Not JSON, wrap it
|
|
responseMap = map[string]any{"output": output}
|
|
}</span>
|
|
|
|
// Get function name from message or look it up from CallID
|
|
<span class="cov0" title="0">name := msg.Name
|
|
if name == "" && msg.CallID != "" </span><span class="cov0" title="0">{
|
|
name = callIDToName[msg.CallID]
|
|
}</span>
|
|
|
|
// Create FunctionResponse part with CallID and Name from message
|
|
<span class="cov0" title="0">part := &genai.Part{
|
|
FunctionResponse: &genai.FunctionResponse{
|
|
ID: msg.CallID,
|
|
Name: name, // Name is required by Google
|
|
Response: responseMap,
|
|
},
|
|
}
|
|
|
|
// Add to user role message
|
|
contents = append(contents, &genai.Content{
|
|
Role: "user",
|
|
Parts: []*genai.Part{part},
|
|
})
|
|
continue</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">var parts []*genai.Part
|
|
for _, block := range msg.Content </span><span class="cov0" title="0">{
|
|
if block.Type == "input_text" || block.Type == "output_text" </span><span class="cov0" title="0">{
|
|
parts = append(parts, genai.NewPartFromText(block.Text))
|
|
}</span>
|
|
}
|
|
|
|
// Add tool calls for assistant messages
|
|
<span class="cov0" title="0">if msg.Role == "assistant" || msg.Role == "model" </span><span class="cov0" title="0">{
|
|
for _, tc := range msg.ToolCalls </span><span class="cov0" title="0">{
|
|
// Parse arguments JSON into map
|
|
var args map[string]any
|
|
if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil </span><span class="cov0" title="0">{
|
|
// If unmarshal fails, skip this tool call
|
|
continue</span>
|
|
}
|
|
|
|
// Create FunctionCall part
|
|
<span class="cov0" title="0">parts = append(parts, &genai.Part{
|
|
FunctionCall: &genai.FunctionCall{
|
|
ID: tc.ID,
|
|
Name: tc.Name,
|
|
Args: args,
|
|
},
|
|
})</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">role := "user"
|
|
if msg.Role == "assistant" || msg.Role == "model" </span><span class="cov0" title="0">{
|
|
role = "model"
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">contents = append(contents, &genai.Content{
|
|
Role: role,
|
|
Parts: parts,
|
|
})</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">return contents, systemText</span>
|
|
}
|
|
|
|
// buildConfig constructs a GenerateContentConfig from system text and request params.
|
|
func buildConfig(systemText string, req *api.ResponseRequest, tools []*genai.Tool, toolConfig *genai.ToolConfig) *genai.GenerateContentConfig <span class="cov0" title="0">{
|
|
var cfg *genai.GenerateContentConfig
|
|
|
|
needsCfg := systemText != "" || req.MaxOutputTokens != nil || req.Temperature != nil || req.TopP != nil || tools != nil || toolConfig != nil
|
|
if !needsCfg </span><span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">cfg = &genai.GenerateContentConfig{}
|
|
|
|
if systemText != "" </span><span class="cov0" title="0">{
|
|
cfg.SystemInstruction = &genai.Content{
|
|
Parts: []*genai.Part{genai.NewPartFromText(systemText)},
|
|
}
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">if req.MaxOutputTokens != nil </span><span class="cov0" title="0">{
|
|
cfg.MaxOutputTokens = int32(*req.MaxOutputTokens)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">if req.Temperature != nil </span><span class="cov0" title="0">{
|
|
t := float32(*req.Temperature)
|
|
cfg.Temperature = &t
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">if req.TopP != nil </span><span class="cov0" title="0">{
|
|
tp := float32(*req.TopP)
|
|
cfg.TopP = &tp
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">if tools != nil </span><span class="cov0" title="0">{
|
|
cfg.Tools = tools
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">if toolConfig != nil </span><span class="cov0" title="0">{
|
|
cfg.ToolConfig = toolConfig
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return cfg</span>
|
|
}
|
|
|
|
func chooseModel(requested, defaultModel string) string <span class="cov0" title="0">{
|
|
if requested != "" </span><span class="cov0" title="0">{
|
|
return requested
|
|
}</span>
|
|
<span class="cov0" title="0">if defaultModel != "" </span><span class="cov0" title="0">{
|
|
return defaultModel
|
|
}</span>
|
|
<span class="cov0" title="0">return "gemini-2.0-flash-exp"</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file22" style="display: none">package openai
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/openai/openai-go/v3"
|
|
"github.com/openai/openai-go/v3/shared"
|
|
)
|
|
|
|
// parseTools converts Open Responses tools to OpenAI format
|
|
func parseTools(req *api.ResponseRequest) ([]openai.ChatCompletionToolUnionParam, error) <span class="cov8" title="1">{
|
|
if req.Tools == nil || len(req.Tools) == 0 </span><span class="cov8" title="1">{
|
|
return nil, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var toolDefs []map[string]interface{}
|
|
if err := json.Unmarshal(req.Tools, &toolDefs); err != nil </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("unmarshal tools: %w", err)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var tools []openai.ChatCompletionToolUnionParam
|
|
for _, td := range toolDefs </span><span class="cov8" title="1">{
|
|
// Convert Open Responses tool to OpenAI ChatCompletionFunctionToolParam
|
|
// Extract: name, description, parameters
|
|
name, _ := td["name"].(string)
|
|
desc, _ := td["description"].(string)
|
|
params, _ := td["parameters"].(map[string]interface{})
|
|
|
|
funcDef := shared.FunctionDefinitionParam{
|
|
Name: name,
|
|
}
|
|
|
|
if desc != "" </span><span class="cov8" title="1">{
|
|
funcDef.Description = openai.String(desc)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">if params != nil </span><span class="cov8" title="1">{
|
|
funcDef.Parameters = shared.FunctionParameters(params)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">tools = append(tools, openai.ChatCompletionFunctionTool(funcDef))</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">return tools, nil</span>
|
|
}
|
|
|
|
// parseToolChoice converts Open Responses tool_choice to OpenAI format
|
|
func parseToolChoice(req *api.ResponseRequest) (openai.ChatCompletionToolChoiceOptionUnionParam, error) <span class="cov8" title="1">{
|
|
var result openai.ChatCompletionToolChoiceOptionUnionParam
|
|
|
|
if req.ToolChoice == nil || len(req.ToolChoice) == 0 </span><span class="cov8" title="1">{
|
|
return result, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var choice interface{}
|
|
if err := json.Unmarshal(req.ToolChoice, &choice); err != nil </span><span class="cov8" title="1">{
|
|
return result, fmt.Errorf("unmarshal tool_choice: %w", err)
|
|
}</span>
|
|
|
|
// Handle string values: "auto", "none", "required"
|
|
<span class="cov8" title="1">if str, ok := choice.(string); ok </span><span class="cov8" title="1">{
|
|
result.OfAuto = openai.String(str)
|
|
return result, nil
|
|
}</span>
|
|
|
|
// Handle specific function selection: {"type": "function", "function": {"name": "..."}}
|
|
<span class="cov8" title="1">if obj, ok := choice.(map[string]interface{}); ok </span><span class="cov8" title="1">{
|
|
funcObj, _ := obj["function"].(map[string]interface{})
|
|
name, _ := funcObj["name"].(string)
|
|
|
|
return openai.ToolChoiceOptionFunctionToolChoice(
|
|
openai.ChatCompletionNamedToolChoiceFunctionParam{
|
|
Name: name,
|
|
},
|
|
), nil
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return result, fmt.Errorf("invalid tool_choice format")</span>
|
|
}
|
|
|
|
// extractToolCalls converts OpenAI tool calls to api.ToolCall
|
|
func extractToolCalls(message openai.ChatCompletionMessage) []api.ToolCall <span class="cov0" title="0">{
|
|
if len(message.ToolCalls) == 0 </span><span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">var toolCalls []api.ToolCall
|
|
for _, tc := range message.ToolCalls </span><span class="cov0" title="0">{
|
|
toolCalls = append(toolCalls, api.ToolCall{
|
|
ID: tc.ID,
|
|
Name: tc.Function.Name,
|
|
Arguments: tc.Function.Arguments,
|
|
})
|
|
}</span>
|
|
<span class="cov0" title="0">return toolCalls</span>
|
|
}
|
|
|
|
// extractToolCallDelta extracts tool call delta from streaming chunk choice
|
|
func extractToolCallDelta(choice openai.ChatCompletionChunkChoice) *api.ToolCallDelta <span class="cov0" title="0">{
|
|
if len(choice.Delta.ToolCalls) == 0 </span><span class="cov0" title="0">{
|
|
return nil
|
|
}</span>
|
|
|
|
// OpenAI sends tool calls with index in the delta
|
|
<span class="cov0" title="0">for _, tc := range choice.Delta.ToolCalls </span><span class="cov0" title="0">{
|
|
return &api.ToolCallDelta{
|
|
Index: int(tc.Index),
|
|
ID: tc.ID,
|
|
Name: tc.Function.Name,
|
|
Arguments: tc.Function.Arguments,
|
|
}
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return nil</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file23" style="display: none">package openai
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/openai/openai-go/v3"
|
|
"github.com/openai/openai-go/v3/azure"
|
|
"github.com/openai/openai-go/v3/option"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/ajac-zero/latticelm/internal/config"
|
|
)
|
|
|
|
const Name = "openai"
|
|
|
|
// Provider implements the OpenAI SDK integration.
|
|
// It supports both direct OpenAI API and Azure-hosted endpoints.
|
|
type Provider struct {
|
|
cfg config.ProviderConfig
|
|
client *openai.Client
|
|
azure bool
|
|
}
|
|
|
|
// New constructs a Provider for the direct OpenAI API.
|
|
func New(cfg config.ProviderConfig) *Provider <span class="cov0" title="0">{
|
|
var client *openai.Client
|
|
if cfg.APIKey != "" </span><span class="cov0" title="0">{
|
|
c := openai.NewClient(option.WithAPIKey(cfg.APIKey))
|
|
client = &c
|
|
}</span>
|
|
<span class="cov0" title="0">return &Provider{
|
|
cfg: cfg,
|
|
client: client,
|
|
}</span>
|
|
}
|
|
|
|
// NewAzure constructs a Provider targeting Azure OpenAI.
|
|
// Azure OpenAI uses the OpenAI SDK with the azure subpackage for proper
|
|
// endpoint routing, api-version query parameter, and API key header.
|
|
func NewAzure(azureCfg config.AzureOpenAIConfig) *Provider <span class="cov0" title="0">{
|
|
var client *openai.Client
|
|
if azureCfg.APIKey != "" && azureCfg.Endpoint != "" </span><span class="cov0" title="0">{
|
|
apiVersion := azureCfg.APIVersion
|
|
if apiVersion == "" </span><span class="cov0" title="0">{
|
|
apiVersion = "2024-12-01-preview"
|
|
}</span>
|
|
<span class="cov0" title="0">c := openai.NewClient(
|
|
azure.WithEndpoint(azureCfg.Endpoint, apiVersion),
|
|
azure.WithAPIKey(azureCfg.APIKey),
|
|
)
|
|
client = &c</span>
|
|
}
|
|
<span class="cov0" title="0">return &Provider{
|
|
cfg: config.ProviderConfig{
|
|
APIKey: azureCfg.APIKey,
|
|
},
|
|
client: client,
|
|
azure: true,
|
|
}</span>
|
|
}
|
|
|
|
// Name returns the provider identifier.
|
|
func (p *Provider) Name() string <span class="cov0" title="0">{ return Name }</span>
|
|
|
|
// Generate routes the request to OpenAI and returns a ProviderResult.
|
|
func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) <span class="cov0" title="0">{
|
|
if p.cfg.APIKey == "" </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("openai api key missing")
|
|
}</span>
|
|
<span class="cov0" title="0">if p.client == nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("openai client not initialized")
|
|
}</span>
|
|
|
|
// Convert messages to OpenAI format
|
|
<span class="cov0" title="0">oaiMessages := make([]openai.ChatCompletionMessageParamUnion, 0, len(messages))
|
|
for _, msg := range messages </span><span class="cov0" title="0">{
|
|
var content string
|
|
for _, block := range msg.Content </span><span class="cov0" title="0">{
|
|
if block.Type == "input_text" || block.Type == "output_text" </span><span class="cov0" title="0">{
|
|
content += block.Text
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">switch msg.Role </span>{
|
|
case "user":<span class="cov0" title="0">
|
|
oaiMessages = append(oaiMessages, openai.UserMessage(content))</span>
|
|
case "assistant":<span class="cov0" title="0">
|
|
// If assistant message has tool calls, include them
|
|
if len(msg.ToolCalls) > 0 </span><span class="cov0" title="0">{
|
|
toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, len(msg.ToolCalls))
|
|
for i, tc := range msg.ToolCalls </span><span class="cov0" title="0">{
|
|
toolCalls[i] = openai.ChatCompletionMessageToolCallUnionParam{
|
|
OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
|
|
ID: tc.ID,
|
|
Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
|
|
Name: tc.Name,
|
|
Arguments: tc.Arguments,
|
|
},
|
|
},
|
|
}
|
|
}</span>
|
|
<span class="cov0" title="0">msgParam := openai.ChatCompletionAssistantMessageParam{
|
|
ToolCalls: toolCalls,
|
|
}
|
|
if content != "" </span><span class="cov0" title="0">{
|
|
msgParam.Content.OfString = openai.String(content)
|
|
}</span>
|
|
<span class="cov0" title="0">oaiMessages = append(oaiMessages, openai.ChatCompletionMessageParamUnion{
|
|
OfAssistant: &msgParam,
|
|
})</span>
|
|
} else<span class="cov0" title="0"> {
|
|
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
|
|
}</span>
|
|
case "system":<span class="cov0" title="0">
|
|
oaiMessages = append(oaiMessages, openai.SystemMessage(content))</span>
|
|
case "developer":<span class="cov0" title="0">
|
|
oaiMessages = append(oaiMessages, openai.SystemMessage(content))</span>
|
|
case "tool":<span class="cov0" title="0">
|
|
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">params := openai.ChatCompletionNewParams{
|
|
Model: openai.ChatModel(req.Model),
|
|
Messages: oaiMessages,
|
|
}
|
|
if req.MaxOutputTokens != nil </span><span class="cov0" title="0">{
|
|
params.MaxTokens = openai.Int(int64(*req.MaxOutputTokens))
|
|
}</span>
|
|
<span class="cov0" title="0">if req.Temperature != nil </span><span class="cov0" title="0">{
|
|
params.Temperature = openai.Float(*req.Temperature)
|
|
}</span>
|
|
<span class="cov0" title="0">if req.TopP != nil </span><span class="cov0" title="0">{
|
|
params.TopP = openai.Float(*req.TopP)
|
|
}</span>
|
|
|
|
// Add tools if present
|
|
<span class="cov0" title="0">if req.Tools != nil && len(req.Tools) > 0 </span><span class="cov0" title="0">{
|
|
tools, err := parseTools(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("parse tools: %w", err)
|
|
}</span>
|
|
<span class="cov0" title="0">params.Tools = tools</span>
|
|
}
|
|
|
|
// Add tool_choice if present
|
|
<span class="cov0" title="0">if req.ToolChoice != nil && len(req.ToolChoice) > 0 </span><span class="cov0" title="0">{
|
|
toolChoice, err := parseToolChoice(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("parse tool_choice: %w", err)
|
|
}</span>
|
|
<span class="cov0" title="0">params.ToolChoice = toolChoice</span>
|
|
}
|
|
|
|
// Add parallel_tool_calls if specified
|
|
<span class="cov0" title="0">if req.ParallelToolCalls != nil </span><span class="cov0" title="0">{
|
|
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
|
|
}</span>
|
|
|
|
// Call OpenAI API
|
|
<span class="cov0" title="0">resp, err := p.client.Chat.Completions.New(ctx, params)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
return nil, fmt.Errorf("openai api error: %w", err)
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">var combinedText string
|
|
var toolCalls []api.ToolCall
|
|
|
|
for _, choice := range resp.Choices </span><span class="cov0" title="0">{
|
|
combinedText += choice.Message.Content
|
|
if len(choice.Message.ToolCalls) > 0 </span><span class="cov0" title="0">{
|
|
toolCalls = append(toolCalls, extractToolCalls(choice.Message)...)
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">return &api.ProviderResult{
|
|
ID: resp.ID,
|
|
Model: resp.Model,
|
|
Text: combinedText,
|
|
ToolCalls: toolCalls,
|
|
Usage: api.Usage{
|
|
InputTokens: int(resp.Usage.PromptTokens),
|
|
OutputTokens: int(resp.Usage.CompletionTokens),
|
|
TotalTokens: int(resp.Usage.TotalTokens),
|
|
},
|
|
}, nil</span>
|
|
}
|
|
|
|
// GenerateStream handles streaming requests to OpenAI.
|
|
func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) <span class="cov0" title="0">{
|
|
deltaChan := make(chan *api.ProviderStreamDelta)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() </span><span class="cov0" title="0">{
|
|
defer close(deltaChan)
|
|
defer close(errChan)
|
|
|
|
if p.cfg.APIKey == "" </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("openai api key missing")
|
|
return
|
|
}</span>
|
|
<span class="cov0" title="0">if p.client == nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("openai client not initialized")
|
|
return
|
|
}</span>
|
|
|
|
// Convert messages to OpenAI format
|
|
<span class="cov0" title="0">oaiMessages := make([]openai.ChatCompletionMessageParamUnion, 0, len(messages))
|
|
for _, msg := range messages </span><span class="cov0" title="0">{
|
|
var content string
|
|
for _, block := range msg.Content </span><span class="cov0" title="0">{
|
|
if block.Type == "input_text" || block.Type == "output_text" </span><span class="cov0" title="0">{
|
|
content += block.Text
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov0" title="0">switch msg.Role </span>{
|
|
case "user":<span class="cov0" title="0">
|
|
oaiMessages = append(oaiMessages, openai.UserMessage(content))</span>
|
|
case "assistant":<span class="cov0" title="0">
|
|
// If assistant message has tool calls, include them
|
|
if len(msg.ToolCalls) > 0 </span><span class="cov0" title="0">{
|
|
toolCalls := make([]openai.ChatCompletionMessageToolCallUnionParam, len(msg.ToolCalls))
|
|
for i, tc := range msg.ToolCalls </span><span class="cov0" title="0">{
|
|
toolCalls[i] = openai.ChatCompletionMessageToolCallUnionParam{
|
|
OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
|
|
ID: tc.ID,
|
|
Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
|
|
Name: tc.Name,
|
|
Arguments: tc.Arguments,
|
|
},
|
|
},
|
|
}
|
|
}</span>
|
|
<span class="cov0" title="0">msgParam := openai.ChatCompletionAssistantMessageParam{
|
|
ToolCalls: toolCalls,
|
|
}
|
|
if content != "" </span><span class="cov0" title="0">{
|
|
msgParam.Content.OfString = openai.String(content)
|
|
}</span>
|
|
<span class="cov0" title="0">oaiMessages = append(oaiMessages, openai.ChatCompletionMessageParamUnion{
|
|
OfAssistant: &msgParam,
|
|
})</span>
|
|
} else<span class="cov0" title="0"> {
|
|
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
|
|
}</span>
|
|
case "system":<span class="cov0" title="0">
|
|
oaiMessages = append(oaiMessages, openai.SystemMessage(content))</span>
|
|
case "developer":<span class="cov0" title="0">
|
|
oaiMessages = append(oaiMessages, openai.SystemMessage(content))</span>
|
|
case "tool":<span class="cov0" title="0">
|
|
oaiMessages = append(oaiMessages, openai.ToolMessage(content, msg.CallID))</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">params := openai.ChatCompletionNewParams{
|
|
Model: openai.ChatModel(req.Model),
|
|
Messages: oaiMessages,
|
|
}
|
|
if req.MaxOutputTokens != nil </span><span class="cov0" title="0">{
|
|
params.MaxTokens = openai.Int(int64(*req.MaxOutputTokens))
|
|
}</span>
|
|
<span class="cov0" title="0">if req.Temperature != nil </span><span class="cov0" title="0">{
|
|
params.Temperature = openai.Float(*req.Temperature)
|
|
}</span>
|
|
<span class="cov0" title="0">if req.TopP != nil </span><span class="cov0" title="0">{
|
|
params.TopP = openai.Float(*req.TopP)
|
|
}</span>
|
|
|
|
// Add tools if present
|
|
<span class="cov0" title="0">if req.Tools != nil && len(req.Tools) > 0 </span><span class="cov0" title="0">{
|
|
tools, err := parseTools(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("parse tools: %w", err)
|
|
return
|
|
}</span>
|
|
<span class="cov0" title="0">params.Tools = tools</span>
|
|
}
|
|
|
|
// Add tool_choice if present
|
|
<span class="cov0" title="0">if req.ToolChoice != nil && len(req.ToolChoice) > 0 </span><span class="cov0" title="0">{
|
|
toolChoice, err := parseToolChoice(req)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("parse tool_choice: %w", err)
|
|
return
|
|
}</span>
|
|
<span class="cov0" title="0">params.ToolChoice = toolChoice</span>
|
|
}
|
|
|
|
// Add parallel_tool_calls if specified
|
|
<span class="cov0" title="0">if req.ParallelToolCalls != nil </span><span class="cov0" title="0">{
|
|
params.ParallelToolCalls = openai.Bool(*req.ParallelToolCalls)
|
|
}</span>
|
|
|
|
// Create streaming request
|
|
<span class="cov0" title="0">stream := p.client.Chat.Completions.NewStreaming(ctx, params)
|
|
|
|
// Process stream
|
|
for stream.Next() </span><span class="cov0" title="0">{
|
|
chunk := stream.Current()
|
|
|
|
for _, choice := range chunk.Choices </span><span class="cov0" title="0">{
|
|
// Handle text content
|
|
if choice.Delta.Content != "" </span><span class="cov0" title="0">{
|
|
select </span>{
|
|
case deltaChan <- &api.ProviderStreamDelta{
|
|
ID: chunk.ID,
|
|
Model: chunk.Model,
|
|
Text: choice.Delta.Content,
|
|
}:<span class="cov0" title="0"></span>
|
|
case <-ctx.Done():<span class="cov0" title="0">
|
|
errChan <- ctx.Err()
|
|
return</span>
|
|
}
|
|
}
|
|
|
|
// Handle tool call deltas
|
|
<span class="cov0" title="0">if len(choice.Delta.ToolCalls) > 0 </span><span class="cov0" title="0">{
|
|
delta := extractToolCallDelta(choice)
|
|
if delta != nil </span><span class="cov0" title="0">{
|
|
select </span>{
|
|
case deltaChan <- &api.ProviderStreamDelta{
|
|
ID: chunk.ID,
|
|
Model: chunk.Model,
|
|
ToolCallDelta: delta,
|
|
}:<span class="cov0" title="0"></span>
|
|
case <-ctx.Done():<span class="cov0" title="0">
|
|
errChan <- ctx.Err()
|
|
return</span>
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
<span class="cov0" title="0">if err := stream.Err(); err != nil </span><span class="cov0" title="0">{
|
|
errChan <- fmt.Errorf("openai stream error: %w", err)
|
|
return
|
|
}</span>
|
|
|
|
// Send final delta
|
|
<span class="cov0" title="0">select </span>{
|
|
case deltaChan <- &api.ProviderStreamDelta{Done: true}:<span class="cov0" title="0"></span>
|
|
case <-ctx.Done():<span class="cov0" title="0">
|
|
errChan <- ctx.Err()</span>
|
|
}
|
|
}()
|
|
|
|
<span class="cov0" title="0">return deltaChan, errChan</span>
|
|
}
|
|
|
|
func chooseModel(requested, defaultModel string) string <span class="cov0" title="0">{
|
|
if requested != "" </span><span class="cov0" title="0">{
|
|
return requested
|
|
}</span>
|
|
<span class="cov0" title="0">if defaultModel != "" </span><span class="cov0" title="0">{
|
|
return defaultModel
|
|
}</span>
|
|
<span class="cov0" title="0">return "gpt-4o-mini"</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file24" style="display: none">package providers
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/ajac-zero/latticelm/internal/config"
|
|
anthropicprovider "github.com/ajac-zero/latticelm/internal/providers/anthropic"
|
|
googleprovider "github.com/ajac-zero/latticelm/internal/providers/google"
|
|
openaiprovider "github.com/ajac-zero/latticelm/internal/providers/openai"
|
|
)
|
|
|
|
// Provider represents a unified interface that each LLM provider must implement.
|
|
type Provider interface {
|
|
Name() string
|
|
Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error)
|
|
GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error)
|
|
}
|
|
|
|
// Registry keeps track of registered providers and model-to-provider mappings.
|
|
type Registry struct {
|
|
providers map[string]Provider
|
|
models map[string]string // model name -> provider entry name
|
|
providerModelIDs map[string]string // model name -> provider model ID
|
|
modelList []config.ModelEntry
|
|
}
|
|
|
|
// NewRegistry constructs provider implementations from configuration.
|
|
func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelEntry) (*Registry, error) <span class="cov8" title="1">{
|
|
return NewRegistryWithCircuitBreaker(entries, models, nil)
|
|
}</span>
|
|
|
|
// NewRegistryWithCircuitBreaker constructs provider implementations with circuit breaker support.
|
|
// The onStateChange callback is invoked when circuit breaker state changes.
|
|
func NewRegistryWithCircuitBreaker(
|
|
entries map[string]config.ProviderEntry,
|
|
models []config.ModelEntry,
|
|
onStateChange func(provider, from, to string),
|
|
) (*Registry, error) <span class="cov8" title="1">{
|
|
reg := &Registry{
|
|
providers: make(map[string]Provider),
|
|
models: make(map[string]string),
|
|
providerModelIDs: make(map[string]string),
|
|
modelList: models,
|
|
}
|
|
|
|
// Use default circuit breaker configuration
|
|
cbConfig := DefaultCircuitBreakerConfig()
|
|
cbConfig.OnStateChange = onStateChange
|
|
|
|
for name, entry := range entries </span><span class="cov8" title="1">{
|
|
p, err := buildProvider(entry)
|
|
if err != nil </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("provider %q: %w", name, err)
|
|
}</span>
|
|
<span class="cov8" title="1">if p != nil </span><span class="cov8" title="1">{
|
|
// Wrap provider with circuit breaker
|
|
reg.providers[name] = NewCircuitBreakerProvider(p, cbConfig)
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">for _, m := range models </span><span class="cov8" title="1">{
|
|
reg.models[m.Name] = m.Provider
|
|
if m.ProviderModelID != "" </span><span class="cov8" title="1">{
|
|
reg.providerModelIDs[m.Name] = m.ProviderModelID
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">if len(reg.providers) == 0 </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("no providers configured")
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return reg, nil</span>
|
|
}
|
|
|
|
func buildProvider(entry config.ProviderEntry) (Provider, error) <span class="cov8" title="1">{
|
|
// Vertex AI doesn't require APIKey, so check for it separately
|
|
if entry.Type != "vertexai" && entry.APIKey == "" </span><span class="cov8" title="1">{
|
|
return nil, nil
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">switch entry.Type </span>{
|
|
case "openai":<span class="cov8" title="1">
|
|
return openaiprovider.New(config.ProviderConfig{
|
|
APIKey: entry.APIKey,
|
|
Endpoint: entry.Endpoint,
|
|
}), nil</span>
|
|
case "azureopenai":<span class="cov8" title="1">
|
|
if entry.Endpoint == "" </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("endpoint is required for azureopenai")
|
|
}</span>
|
|
<span class="cov8" title="1">return openaiprovider.NewAzure(config.AzureOpenAIConfig{
|
|
APIKey: entry.APIKey,
|
|
Endpoint: entry.Endpoint,
|
|
APIVersion: entry.APIVersion,
|
|
}), nil</span>
|
|
case "anthropic":<span class="cov8" title="1">
|
|
return anthropicprovider.New(config.ProviderConfig{
|
|
APIKey: entry.APIKey,
|
|
Endpoint: entry.Endpoint,
|
|
}), nil</span>
|
|
case "azureanthropic":<span class="cov8" title="1">
|
|
if entry.Endpoint == "" </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("endpoint is required for azureanthropic")
|
|
}</span>
|
|
<span class="cov8" title="1">return anthropicprovider.NewAzure(config.AzureAnthropicConfig{
|
|
APIKey: entry.APIKey,
|
|
Endpoint: entry.Endpoint,
|
|
}), nil</span>
|
|
case "google":<span class="cov8" title="1">
|
|
return googleprovider.New(config.ProviderConfig{
|
|
APIKey: entry.APIKey,
|
|
Endpoint: entry.Endpoint,
|
|
})</span>
|
|
case "vertexai":<span class="cov8" title="1">
|
|
if entry.Project == "" || entry.Location == "" </span><span class="cov8" title="1">{
|
|
return nil, fmt.Errorf("project and location are required for vertexai")
|
|
}</span>
|
|
<span class="cov8" title="1">return googleprovider.NewVertexAI(config.VertexAIConfig{
|
|
Project: entry.Project,
|
|
Location: entry.Location,
|
|
})</span>
|
|
default:<span class="cov8" title="1">
|
|
return nil, fmt.Errorf("unknown provider type %q", entry.Type)</span>
|
|
}
|
|
}
|
|
|
|
// Get returns provider by entry name.
|
|
func (r *Registry) Get(name string) (Provider, bool) <span class="cov8" title="1">{
|
|
p, ok := r.providers[name]
|
|
return p, ok
|
|
}</span>
|
|
|
|
// Models returns the list of configured models and their provider entry names.
|
|
func (r *Registry) Models() []struct{ Provider, Model string } <span class="cov8" title="1">{
|
|
var out []struct{ Provider, Model string }
|
|
for _, m := range r.modelList </span><span class="cov8" title="1">{
|
|
out = append(out, struct{ Provider, Model string }{Provider: m.Provider, Model: m.Name})
|
|
}</span>
|
|
<span class="cov8" title="1">return out</span>
|
|
}
|
|
|
|
// ResolveModelID returns the provider_model_id for a model, falling back to the model name itself.
|
|
func (r *Registry) ResolveModelID(model string) string <span class="cov8" title="1">{
|
|
if id, ok := r.providerModelIDs[model]; ok </span><span class="cov8" title="1">{
|
|
return id
|
|
}</span>
|
|
<span class="cov8" title="1">return model</span>
|
|
}
|
|
|
|
// Default returns the provider for the given model name.
|
|
func (r *Registry) Default(model string) (Provider, error) <span class="cov8" title="1">{
|
|
if model != "" </span><span class="cov8" title="1">{
|
|
if providerName, ok := r.models[model]; ok </span><span class="cov8" title="1">{
|
|
if p, ok := r.providers[providerName]; ok </span><span class="cov8" title="1">{
|
|
return p, nil
|
|
}</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov8" title="1">for _, p := range r.providers </span><span class="cov8" title="1">{
|
|
return p, nil
|
|
}</span>
|
|
|
|
<span class="cov0" title="0">return nil, fmt.Errorf("no providers available")</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file25" style="display: none">package ratelimit
|
|
|
|
import (
|
|
"log/slog"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// Config holds the settings for the per-client rate limiting middleware.
|
|
type Config struct {
|
|
// RequestsPerSecond is the number of requests allowed per second per IP.
// It is used as the rate.Limit of every limiter the middleware creates.
|
|
RequestsPerSecond float64
|
|
// Burst is the maximum burst size allowed; it is passed to rate.NewLimiter
// as the burst argument.
|
|
Burst int
|
|
// Enabled controls whether rate limiting is active. When false, Handler
// forwards every request untouched and New starts no cleanup goroutine.
|
|
Enabled bool
|
|
}
|
|
|
|
// Middleware provides per-IP rate limiting using token bucket algorithm.
// One rate.Limiter is kept per client key returned by getClientIP.
|
|
type Middleware struct {
|
|
// limiters maps a client key to its limiter; the whole map is replaced
// periodically by cleanupLimiters.
limiters map[string]*rate.Limiter
|
|
// mu guards limiters.
mu sync.RWMutex
|
|
// config is the configuration supplied to New.
config Config
|
|
// logger receives rate-limit warnings and cleanup debug messages.
logger *slog.Logger
|
|
}
|
|
|
|
// New creates a new rate limiting middleware.
|
|
func New(config Config, logger *slog.Logger) *Middleware <span class="cov8" title="1">{
|
|
m := &Middleware{
|
|
limiters: make(map[string]*rate.Limiter),
|
|
config: config,
|
|
logger: logger,
|
|
}
|
|
|
|
// Start cleanup goroutine to remove old limiters
|
|
if config.Enabled </span><span class="cov8" title="1">{
|
|
go m.cleanupLimiters()
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return m</span>
|
|
}
|
|
|
|
// Handler wraps an http.Handler with rate limiting.
|
|
func (m *Middleware) Handler(next http.Handler) http.Handler <span class="cov8" title="1">{
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) </span><span class="cov8" title="1">{
|
|
if !m.config.Enabled </span><span class="cov8" title="1">{
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}</span>
|
|
|
|
// Extract client IP (handle X-Forwarded-For for proxies)
|
|
<span class="cov8" title="1">ip := m.getClientIP(r)
|
|
|
|
limiter := m.getLimiter(ip)
|
|
if !limiter.Allow() </span><span class="cov8" title="1">{
|
|
m.logger.Warn("rate limit exceeded",
|
|
slog.String("ip", ip),
|
|
slog.String("path", r.URL.Path),
|
|
)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("Retry-After", "1")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
w.Write([]byte(`{"error":"rate limit exceeded","message":"too many requests"}`))
|
|
return
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">next.ServeHTTP(w, r)</span>
|
|
})
|
|
}
|
|
|
|
// getLimiter returns the rate limiter for a given IP, creating one if needed.
|
|
func (m *Middleware) getLimiter(ip string) *rate.Limiter <span class="cov8" title="1">{
|
|
m.mu.RLock()
|
|
limiter, exists := m.limiters[ip]
|
|
m.mu.RUnlock()
|
|
|
|
if exists </span><span class="cov8" title="1">{
|
|
return limiter
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Double-check after acquiring write lock
|
|
limiter, exists = m.limiters[ip]
|
|
if exists </span><span class="cov0" title="0">{
|
|
return limiter
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">limiter = rate.NewLimiter(rate.Limit(m.config.RequestsPerSecond), m.config.Burst)
|
|
m.limiters[ip] = limiter
|
|
return limiter</span>
|
|
}
|
|
|
|
// getClientIP extracts the client IP from the request.
|
|
func (m *Middleware) getClientIP(r *http.Request) string <span class="cov8" title="1">{
|
|
// Check X-Forwarded-For header (for proxies/load balancers)
|
|
xff := r.Header.Get("X-Forwarded-For")
|
|
if xff != "" </span><span class="cov8" title="1">{
|
|
// X-Forwarded-For can be a comma-separated list, use the first IP
|
|
for idx := 0; idx < len(xff); idx++ </span><span class="cov8" title="1">{
|
|
if xff[idx] == ',' </span><span class="cov8" title="1">{
|
|
return xff[:idx]
|
|
}</span>
|
|
}
|
|
<span class="cov0" title="0">return xff</span>
|
|
}
|
|
|
|
// Check X-Real-IP header
|
|
<span class="cov8" title="1">if xri := r.Header.Get("X-Real-IP"); xri != "" </span><span class="cov8" title="1">{
|
|
return xri
|
|
}</span>
|
|
|
|
// Fall back to RemoteAddr
|
|
<span class="cov8" title="1">return r.RemoteAddr</span>
|
|
}
|
|
|
|
// cleanupLimiters periodically removes unused limiters to prevent memory leaks.
|
|
func (m *Middleware) cleanupLimiters() <span class="cov8" title="1">{
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C </span><span class="cov0" title="0">{
|
|
m.mu.Lock()
|
|
// Clear all limiters periodically
|
|
// In production, you might want a more sophisticated LRU cache
|
|
m.limiters = make(map[string]*rate.Limiter)
|
|
m.mu.Unlock()
|
|
|
|
m.logger.Debug("cleaned up rate limiters")
|
|
}</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file26" style="display: none">package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"time"
|
|
)
|
|
|
|
// HealthStatus is the JSON body returned by the health and readiness handlers.
|
|
type HealthStatus struct {
|
|
// Status is "healthy" for the liveness handler and "ready" or "not_ready"
// for the readiness handler.
Status string `json:"status"`
|
|
// Timestamp is the Unix time, in seconds, at which the response was built.
Timestamp int64 `json:"timestamp"`
|
|
// Checks holds the per-dependency results of a readiness probe; it is
// omitted from the JSON when empty.
Checks map[string]string `json:"checks,omitempty"`
|
|
}
|
|
|
|
// handleHealth returns a basic health check endpoint.
|
|
// This is suitable for Kubernetes liveness probes.
|
|
func (s *GatewayServer) handleHealth(w http.ResponseWriter, r *http.Request) <span class="cov8" title="1">{
|
|
if r.Method != http.MethodGet </span><span class="cov8" title="1">{
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">status := HealthStatus{
|
|
Status: "healthy",
|
|
Timestamp: time.Now().Unix(),
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(status); err != nil </span><span class="cov0" title="0">{
|
|
s.logger.ErrorContext(r.Context(), "failed to encode health response", "error", err.Error())
|
|
}</span>
|
|
}
|
|
|
|
// handleReady returns a readiness check that verifies dependencies.
|
|
// This is suitable for Kubernetes readiness probes and load balancer health checks.
|
|
func (s *GatewayServer) handleReady(w http.ResponseWriter, r *http.Request) <span class="cov8" title="1">{
|
|
if r.Method != http.MethodGet </span><span class="cov8" title="1">{
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">checks := make(map[string]string)
|
|
allHealthy := true
|
|
|
|
// Check conversation store connectivity
|
|
ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
// Test conversation store by attempting a simple operation
|
|
testID := "health_check_test"
|
|
_, err := s.convs.Get(ctx, testID)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
checks["conversation_store"] = "unhealthy: " + err.Error()
|
|
allHealthy = false
|
|
}</span> else<span class="cov8" title="1"> {
|
|
checks["conversation_store"] = "healthy"
|
|
}</span>
|
|
|
|
// Check if at least one provider is configured
|
|
<span class="cov8" title="1">models := s.registry.Models()
|
|
if len(models) == 0 </span><span class="cov8" title="1">{
|
|
checks["providers"] = "unhealthy: no providers configured"
|
|
allHealthy = false
|
|
}</span> else<span class="cov8" title="1"> {
|
|
checks["providers"] = "healthy"
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">_ = ctx // Use context if needed
|
|
|
|
status := HealthStatus{
|
|
Timestamp: time.Now().Unix(),
|
|
Checks: checks,
|
|
}
|
|
|
|
if allHealthy </span><span class="cov8" title="1">{
|
|
status.Status = "ready"
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
}</span> else<span class="cov8" title="1"> {
|
|
status.Status = "not_ready"
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">if err := json.NewEncoder(w).Encode(status); err != nil </span><span class="cov0" title="0">{
|
|
s.logger.ErrorContext(r.Context(), "failed to encode ready response", "error", err.Error())
|
|
}</span>
|
|
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file27" style="display: none">package server
|
|
|
|
import (
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"runtime/debug"

	"github.com/ajac-zero/latticelm/internal/logger"
)
|
|
|
|
// MaxRequestBodyBytes is the maximum size allowed for request bodies:
// 10 MiB (10 * 1024 * 1024 bytes).
|
|
const MaxRequestBodyBytes = 10 * 1024 * 1024
|
|
|
|
// PanicRecoveryMiddleware recovers from panics in HTTP handlers and logs them
|
|
// instead of crashing the server. Returns 500 Internal Server Error to the client.
|
|
func PanicRecoveryMiddleware(next http.Handler, log *slog.Logger) http.Handler <span class="cov8" title="1">{
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) </span><span class="cov8" title="1">{
|
|
defer func() </span><span class="cov8" title="1">{
|
|
if err := recover(); err != nil </span><span class="cov8" title="1">{
|
|
// Capture stack trace
|
|
stack := debug.Stack()
|
|
|
|
// Log the panic with full context
|
|
log.ErrorContext(r.Context(), "panic recovered in HTTP handler",
|
|
logger.LogAttrsWithTrace(r.Context(),
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("method", r.Method),
|
|
slog.String("path", r.URL.Path),
|
|
slog.String("remote_addr", r.RemoteAddr),
|
|
slog.Any("panic", err),
|
|
slog.String("stack", string(stack)),
|
|
)...,
|
|
)
|
|
|
|
// Return 500 to client
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
}</span>
|
|
}()
|
|
|
|
<span class="cov8" title="1">next.ServeHTTP(w, r)</span>
|
|
})
|
|
}
|
|
|
|
// RequestSizeLimitMiddleware caps the request body of POST, PUT and PATCH
// requests at maxBytes by wrapping it in http.MaxBytesReader; requests using
// any other method are forwarded with their body untouched. The limit is
// enforced lazily: the wrapped body returns an error once a handler reads past
// maxBytes, and it is up to that handler to turn the error into a response.
func RequestSizeLimitMiddleware(next http.Handler, maxBytes int64) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodPost, http.MethodPut, http.MethodPatch:
			r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
		}

		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// ErrorRecoveryMiddleware is currently a pure pass-through kept as an
// extension point: it forwards every request to next and does nothing else.
// Oversized-body errors produced by http.MaxBytesReader surface where the body
// is read (for example during JSON decoding), so they are handled there rather
// than here. The log parameter is accepted but not used.
func ErrorRecoveryMiddleware(next http.Handler, log *slog.Logger) http.Handler {
	passthrough := func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(passthrough)
}
|
|
|
|
// WriteJSONError writes a JSON error response of the form
// {"error":{"message":"..."}} with the given status code.
//
// The message is JSON-encoded before being embedded, so quotes, backslashes,
// control characters and other special characters in it cannot break or alter
// the structure of the response body. For messages without such characters the
// output is byte-for-byte what a plain string substitution would produce.
//
// If the body cannot be written (typically because the client went away), the
// failure is reported through log together with the original message and
// status code.
func WriteJSONError(w http.ResponseWriter, log *slog.Logger, message string, statusCode int) {
	// Encode first so the response is only started once a valid body exists.
	encoded, err := json.Marshal(message)
	if err != nil {
		// Marshalling a string is not expected to fail; stay well-formed anyway.
		encoded = []byte(`"internal error"`)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)

	// encoded already carries its surrounding quotes, hence the bare %s.
	if _, err := fmt.Fprintf(w, `{"error":{"message":%s}}`, encoded); err != nil {
		log.Error("failed to write error response",
			slog.String("original_message", message),
			slog.Int("status_code", statusCode),
			slog.String("write_error", err.Error()),
		)
	}
}
|
|
</pre>
|
|
|
|
<pre class="file" id="file28" style="display: none">package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/sony/gobreaker"
|
|
|
|
"github.com/ajac-zero/latticelm/internal/api"
|
|
"github.com/ajac-zero/latticelm/internal/conversation"
|
|
"github.com/ajac-zero/latticelm/internal/logger"
|
|
"github.com/ajac-zero/latticelm/internal/providers"
|
|
)
|
|
|
|
// ProviderRegistry is the view of a provider registry that GatewayServer
// depends on.
|
|
type ProviderRegistry interface {
|
|
// Get returns the provider registered under name and whether it exists.
Get(name string) (providers.Provider, bool)
|
|
// Models lists the configured models with their provider entry names.
Models() []struct{ Provider, Model string }
|
|
// ResolveModelID returns the identifier to send to the provider for model.
ResolveModelID(model string) string
|
|
// Default returns the provider that should serve model.
Default(model string) (providers.Provider, error)
|
|
}
|
|
|
|
// GatewayServer hosts the Open Responses API for the gateway.
|
|
type GatewayServer struct {
|
|
// registry resolves models to providers and lists the configured models.
registry ProviderRegistry
|
|
// convs persists conversations so later requests can reference them via
// previous_response_id.
convs conversation.Store
|
|
// logger is used for all request-scoped logging.
logger *slog.Logger
|
|
}
|
|
|
|
// New creates a GatewayServer bound to the provider registry.
|
|
func New(registry ProviderRegistry, convs conversation.Store, logger *slog.Logger) *GatewayServer <span class="cov8" title="1">{
|
|
return &GatewayServer{
|
|
registry: registry,
|
|
convs: convs,
|
|
logger: logger,
|
|
}
|
|
}</span>
|
|
|
|
// isCircuitBreakerError checks if the error is from a circuit breaker.
|
|
func isCircuitBreakerError(err error) bool <span class="cov8" title="1">{
|
|
return errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests)
|
|
}</span>
|
|
|
|
// RegisterRoutes wires the HTTP handlers onto the provided mux.
|
|
func (s *GatewayServer) RegisterRoutes(mux *http.ServeMux) <span class="cov0" title="0">{
|
|
mux.HandleFunc("/v1/responses", s.handleResponses)
|
|
mux.HandleFunc("/v1/models", s.handleModels)
|
|
mux.HandleFunc("/health", s.handleHealth)
|
|
mux.HandleFunc("/ready", s.handleReady)
|
|
}</span>
|
|
|
|
func (s *GatewayServer) handleModels(w http.ResponseWriter, r *http.Request) <span class="cov8" title="1">{
|
|
if r.Method != http.MethodGet </span><span class="cov8" title="1">{
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">models := s.registry.Models()
|
|
var data []api.ModelInfo
|
|
for _, m := range models </span><span class="cov8" title="1">{
|
|
data = append(data, api.ModelInfo{
|
|
ID: m.Model,
|
|
Provider: m.Provider,
|
|
})
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">resp := api.ModelsResponse{
|
|
Object: "list",
|
|
Data: data,
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(resp); err != nil </span><span class="cov0" title="0">{
|
|
s.logger.ErrorContext(r.Context(), "failed to encode models response",
|
|
logger.LogAttrsWithTrace(r.Context(),
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("error", err.Error()),
|
|
)...,
|
|
)
|
|
}</span>
|
|
}
|
|
|
|
func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) <span class="cov8" title="1">{
|
|
if r.Method != http.MethodPost </span><span class="cov8" title="1">{
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">var req api.ResponseRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil </span><span class="cov8" title="1">{
|
|
// Check if error is due to request size limit
|
|
if err.Error() == "http: request body too large" </span><span class="cov0" title="0">{
|
|
http.Error(w, "request body too large", http.StatusRequestEntityTooLarge)
|
|
return
|
|
}</span>
|
|
<span class="cov8" title="1">http.Error(w, "invalid JSON payload", http.StatusBadRequest)
|
|
return</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">if err := req.Validate(); err != nil </span><span class="cov8" title="1">{
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}</span>
|
|
|
|
// Normalize input to internal messages
|
|
<span class="cov8" title="1">inputMsgs := req.NormalizeInput()
|
|
|
|
// Build full message history from previous conversation
|
|
var historyMsgs []api.Message
|
|
if req.PreviousResponseID != nil && *req.PreviousResponseID != "" </span><span class="cov8" title="1">{
|
|
conv, err := s.convs.Get(r.Context(), *req.PreviousResponseID)
|
|
if err != nil </span><span class="cov8" title="1">{
|
|
s.logger.ErrorContext(r.Context(), "failed to retrieve conversation",
|
|
logger.LogAttrsWithTrace(r.Context(),
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("conversation_id", *req.PreviousResponseID),
|
|
slog.String("error", err.Error()),
|
|
)...,
|
|
)
|
|
http.Error(w, "error retrieving conversation", http.StatusInternalServerError)
|
|
return
|
|
}</span>
|
|
<span class="cov8" title="1">if conv == nil </span><span class="cov8" title="1">{
|
|
s.logger.WarnContext(r.Context(), "conversation not found",
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("conversation_id", *req.PreviousResponseID),
|
|
)
|
|
http.Error(w, "conversation not found", http.StatusNotFound)
|
|
return
|
|
}</span>
|
|
<span class="cov8" title="1">historyMsgs = conv.Messages</span>
|
|
}
|
|
|
|
// Combined messages for conversation storage (history + new input, no instructions)
|
|
<span class="cov8" title="1">storeMsgs := make([]api.Message, 0, len(historyMsgs)+len(inputMsgs))
|
|
storeMsgs = append(storeMsgs, historyMsgs...)
|
|
storeMsgs = append(storeMsgs, inputMsgs...)
|
|
|
|
// Build provider messages: instructions + history + input
|
|
var providerMsgs []api.Message
|
|
if req.Instructions != nil && *req.Instructions != "" </span><span class="cov8" title="1">{
|
|
providerMsgs = append(providerMsgs, api.Message{
|
|
Role: "developer",
|
|
Content: []api.ContentBlock{{Type: "input_text", Text: *req.Instructions}},
|
|
})
|
|
}</span>
|
|
<span class="cov8" title="1">providerMsgs = append(providerMsgs, storeMsgs...)
|
|
|
|
provider, err := s.resolveProvider(&req)
|
|
if err != nil </span><span class="cov8" title="1">{
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
return
|
|
}</span>
|
|
|
|
// Resolve provider_model_id (e.g., Azure deployment name)
|
|
<span class="cov8" title="1">resolvedReq := req
|
|
resolvedReq.Model = s.registry.ResolveModelID(req.Model)
|
|
|
|
if req.Stream </span><span class="cov8" title="1">{
|
|
s.handleStreamingResponse(w, r, provider, providerMsgs, &resolvedReq, &req, storeMsgs)
|
|
}</span> else<span class="cov8" title="1"> {
|
|
s.handleSyncResponse(w, r, provider, providerMsgs, &resolvedReq, &req, storeMsgs)
|
|
}</span>
|
|
}
|
|
|
|
func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) <span class="cov8" title="1">{
|
|
result, err := provider.Generate(r.Context(), providerMsgs, resolvedReq)
|
|
if err != nil </span><span class="cov8" title="1">{
|
|
s.logger.ErrorContext(r.Context(), "provider generation failed",
|
|
logger.LogAttrsWithTrace(r.Context(),
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("provider", provider.Name()),
|
|
slog.String("model", resolvedReq.Model),
|
|
slog.String("error", err.Error()),
|
|
)...,
|
|
)
|
|
|
|
// Check if error is from circuit breaker
|
|
if isCircuitBreakerError(err) </span><span class="cov0" title="0">{
|
|
http.Error(w, "service temporarily unavailable - circuit breaker open", http.StatusServiceUnavailable)
|
|
}</span> else<span class="cov8" title="1"> {
|
|
http.Error(w, "provider error", http.StatusBadGateway)
|
|
}</span>
|
|
<span class="cov8" title="1">return</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">responseID := generateID("resp_")
|
|
|
|
// Build assistant message for conversation store
|
|
assistantMsg := api.Message{
|
|
Role: "assistant",
|
|
Content: []api.ContentBlock{{Type: "output_text", Text: result.Text}},
|
|
ToolCalls: result.ToolCalls,
|
|
}
|
|
allMsgs := append(storeMsgs, assistantMsg)
|
|
if _, err := s.convs.Create(r.Context(), responseID, result.Model, allMsgs); err != nil </span><span class="cov0" title="0">{
|
|
s.logger.ErrorContext(r.Context(), "failed to store conversation",
|
|
logger.LogAttrsWithTrace(r.Context(),
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("response_id", responseID),
|
|
slog.String("error", err.Error()),
|
|
)...,
|
|
)
|
|
// Don't fail the response if storage fails
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">s.logger.InfoContext(r.Context(), "response generated",
|
|
logger.LogAttrsWithTrace(r.Context(),
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("provider", provider.Name()),
|
|
slog.String("model", result.Model),
|
|
slog.String("response_id", responseID),
|
|
slog.Int("input_tokens", result.Usage.InputTokens),
|
|
slog.Int("output_tokens", result.Usage.OutputTokens),
|
|
slog.Bool("has_tool_calls", len(result.ToolCalls) > 0),
|
|
)...,
|
|
)
|
|
|
|
// Build spec-compliant response
|
|
resp := s.buildResponse(origReq, result, provider.Name(), responseID)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(resp); err != nil </span><span class="cov0" title="0">{
|
|
s.logger.ErrorContext(r.Context(), "failed to encode response",
|
|
logger.LogAttrsWithTrace(r.Context(),
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("response_id", responseID),
|
|
slog.String("error", err.Error()),
|
|
)...,
|
|
)
|
|
}</span>
|
|
}
|
|
|
|
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) <span class="cov8" title="1">{
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok </span><span class="cov0" title="0">{
|
|
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
|
return
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">responseID := generateID("resp_")
|
|
itemID := generateID("msg_")
|
|
seq := 0
|
|
outputIdx := 0
|
|
contentIdx := 0
|
|
|
|
// Build initial response snapshot (in_progress, no output yet)
|
|
initialResp := s.buildResponse(origReq, &api.ProviderResult{
|
|
Model: origReq.Model,
|
|
}, provider.Name(), responseID)
|
|
initialResp.Status = "in_progress"
|
|
initialResp.CompletedAt = nil
|
|
initialResp.Output = []api.OutputItem{}
|
|
initialResp.Usage = nil
|
|
|
|
// response.created
|
|
s.sendSSE(w, flusher, &seq, "response.created", &api.StreamEvent{
|
|
Type: "response.created",
|
|
Response: initialResp,
|
|
})
|
|
|
|
// response.in_progress
|
|
s.sendSSE(w, flusher, &seq, "response.in_progress", &api.StreamEvent{
|
|
Type: "response.in_progress",
|
|
Response: initialResp,
|
|
})
|
|
|
|
// response.output_item.added
|
|
inProgressItem := &api.OutputItem{
|
|
ID: itemID,
|
|
Type: "message",
|
|
Status: "in_progress",
|
|
Role: "assistant",
|
|
Content: []api.ContentPart{},
|
|
}
|
|
s.sendSSE(w, flusher, &seq, "response.output_item.added", &api.StreamEvent{
|
|
Type: "response.output_item.added",
|
|
OutputIndex: &outputIdx,
|
|
Item: inProgressItem,
|
|
})
|
|
|
|
// response.content_part.added
|
|
emptyPart := &api.ContentPart{
|
|
Type: "output_text",
|
|
Text: "",
|
|
Annotations: []api.Annotation{},
|
|
}
|
|
s.sendSSE(w, flusher, &seq, "response.content_part.added", &api.StreamEvent{
|
|
Type: "response.content_part.added",
|
|
ItemID: itemID,
|
|
OutputIndex: &outputIdx,
|
|
ContentIndex: &contentIdx,
|
|
Part: emptyPart,
|
|
})
|
|
|
|
// Start provider stream
|
|
deltaChan, errChan := provider.GenerateStream(r.Context(), providerMsgs, resolvedReq)
|
|
|
|
var fullText string
|
|
var streamErr error
|
|
var providerModel string
|
|
|
|
// Track tool calls being built
|
|
type toolCallBuilder struct {
|
|
itemID string
|
|
id string
|
|
name string
|
|
arguments string
|
|
}
|
|
toolCallsInProgress := make(map[int]*toolCallBuilder)
|
|
nextOutputIdx := 0
|
|
textItemAdded := false
|
|
|
|
loop:
|
|
for </span><span class="cov8" title="1">{
|
|
select </span>{
|
|
case delta, ok := <-deltaChan:<span class="cov8" title="1">
|
|
if !ok </span><span class="cov0" title="0">{
|
|
break loop</span>
|
|
}
|
|
<span class="cov8" title="1">if delta.Model != "" && providerModel == "" </span><span class="cov8" title="1">{
|
|
providerModel = delta.Model
|
|
}</span>
|
|
|
|
// Handle text content
|
|
<span class="cov8" title="1">if delta.Text != "" </span><span class="cov8" title="1">{
|
|
// Add text item on first text delta
|
|
if !textItemAdded </span><span class="cov8" title="1">{
|
|
textItemAdded = true
|
|
nextOutputIdx++
|
|
}</span>
|
|
<span class="cov8" title="1">fullText += delta.Text
|
|
s.sendSSE(w, flusher, &seq, "response.output_text.delta", &api.StreamEvent{
|
|
Type: "response.output_text.delta",
|
|
ItemID: itemID,
|
|
OutputIndex: &outputIdx,
|
|
ContentIndex: &contentIdx,
|
|
Delta: delta.Text,
|
|
})</span>
|
|
}
|
|
|
|
// Handle tool call delta
|
|
<span class="cov8" title="1">if delta.ToolCallDelta != nil </span><span class="cov8" title="1">{
|
|
tc := delta.ToolCallDelta
|
|
|
|
// First chunk for this tool call index
|
|
if _, exists := toolCallsInProgress[tc.Index]; !exists </span><span class="cov8" title="1">{
|
|
toolItemID := generateID("item_")
|
|
toolOutputIdx := nextOutputIdx
|
|
nextOutputIdx++
|
|
|
|
// Send response.output_item.added
|
|
s.sendSSE(w, flusher, &seq, "response.output_item.added", &api.StreamEvent{
|
|
Type: "response.output_item.added",
|
|
OutputIndex: &toolOutputIdx,
|
|
Item: &api.OutputItem{
|
|
ID: toolItemID,
|
|
Type: "function_call",
|
|
Status: "in_progress",
|
|
CallID: tc.ID,
|
|
Name: tc.Name,
|
|
},
|
|
})
|
|
|
|
toolCallsInProgress[tc.Index] = &toolCallBuilder{
|
|
itemID: toolItemID,
|
|
id: tc.ID,
|
|
name: tc.Name,
|
|
arguments: "",
|
|
}
|
|
}</span>
|
|
|
|
// Send function_call_arguments.delta
|
|
<span class="cov8" title="1">if tc.Arguments != "" </span><span class="cov8" title="1">{
|
|
builder := toolCallsInProgress[tc.Index]
|
|
builder.arguments += tc.Arguments
|
|
toolOutputIdx := outputIdx + 1 + tc.Index
|
|
|
|
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.delta", &api.StreamEvent{
|
|
Type: "response.function_call_arguments.delta",
|
|
ItemID: builder.itemID,
|
|
OutputIndex: &toolOutputIdx,
|
|
Delta: tc.Arguments,
|
|
})
|
|
}</span>
|
|
}
|
|
|
|
<span class="cov8" title="1">if delta.Done </span><span class="cov8" title="1">{
|
|
break loop</span>
|
|
}
|
|
case err := <-errChan:<span class="cov8" title="1">
|
|
if err != nil </span><span class="cov8" title="1">{
|
|
streamErr = err
|
|
}</span>
|
|
<span class="cov8" title="1">break loop</span>
|
|
case <-r.Context().Done():<span class="cov0" title="0">
|
|
s.logger.InfoContext(r.Context(), "client disconnected",
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
)
|
|
return</span>
|
|
}
|
|
}
|
|
|
|
<span class="cov8" title="1">if streamErr != nil </span><span class="cov8" title="1">{
|
|
s.logger.ErrorContext(r.Context(), "stream error",
|
|
logger.LogAttrsWithTrace(r.Context(),
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("provider", provider.Name()),
|
|
slog.String("model", origReq.Model),
|
|
slog.String("error", streamErr.Error()),
|
|
)...,
|
|
)
|
|
|
|
// Determine error type based on circuit breaker state
|
|
errorType := "server_error"
|
|
errorMessage := streamErr.Error()
|
|
if isCircuitBreakerError(streamErr) </span><span class="cov0" title="0">{
|
|
errorType = "circuit_breaker_open"
|
|
errorMessage = "service temporarily unavailable - circuit breaker open"
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">failedResp := s.buildResponse(origReq, &api.ProviderResult{
|
|
Model: origReq.Model,
|
|
}, provider.Name(), responseID)
|
|
failedResp.Status = "failed"
|
|
failedResp.CompletedAt = nil
|
|
failedResp.Output = []api.OutputItem{}
|
|
failedResp.Error = &api.ResponseError{
|
|
Type: errorType,
|
|
Message: errorMessage,
|
|
}
|
|
s.sendSSE(w, flusher, &seq, "response.failed", &api.StreamEvent{
|
|
Type: "response.failed",
|
|
Response: failedResp,
|
|
})
|
|
return</span>
|
|
}
|
|
|
|
// Send done events for text output if text was added
|
|
<span class="cov8" title="1">if textItemAdded && fullText != "" </span><span class="cov8" title="1">{
|
|
// response.output_text.done
|
|
s.sendSSE(w, flusher, &seq, "response.output_text.done", &api.StreamEvent{
|
|
Type: "response.output_text.done",
|
|
ItemID: itemID,
|
|
OutputIndex: &outputIdx,
|
|
ContentIndex: &contentIdx,
|
|
Text: fullText,
|
|
})
|
|
|
|
// response.content_part.done
|
|
completedPart := &api.ContentPart{
|
|
Type: "output_text",
|
|
Text: fullText,
|
|
Annotations: []api.Annotation{},
|
|
}
|
|
s.sendSSE(w, flusher, &seq, "response.content_part.done", &api.StreamEvent{
|
|
Type: "response.content_part.done",
|
|
ItemID: itemID,
|
|
OutputIndex: &outputIdx,
|
|
ContentIndex: &contentIdx,
|
|
Part: completedPart,
|
|
})
|
|
|
|
// response.output_item.done
|
|
completedItem := &api.OutputItem{
|
|
ID: itemID,
|
|
Type: "message",
|
|
Status: "completed",
|
|
Role: "assistant",
|
|
Content: []api.ContentPart{*completedPart},
|
|
}
|
|
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
|
|
Type: "response.output_item.done",
|
|
OutputIndex: &outputIdx,
|
|
Item: completedItem,
|
|
})
|
|
}</span>
|
|
|
|
// Send done events for each tool call
|
|
<span class="cov8" title="1">for idx, builder := range toolCallsInProgress </span><span class="cov8" title="1">{
|
|
toolOutputIdx := outputIdx + 1 + idx
|
|
|
|
s.sendSSE(w, flusher, &seq, "response.function_call_arguments.done", &api.StreamEvent{
|
|
Type: "response.function_call_arguments.done",
|
|
ItemID: builder.itemID,
|
|
OutputIndex: &toolOutputIdx,
|
|
Arguments: builder.arguments,
|
|
})
|
|
|
|
s.sendSSE(w, flusher, &seq, "response.output_item.done", &api.StreamEvent{
|
|
Type: "response.output_item.done",
|
|
OutputIndex: &toolOutputIdx,
|
|
Item: &api.OutputItem{
|
|
ID: builder.itemID,
|
|
Type: "function_call",
|
|
Status: "completed",
|
|
CallID: builder.id,
|
|
Name: builder.name,
|
|
Arguments: builder.arguments,
|
|
},
|
|
})
|
|
}</span>
|
|
|
|
// Build final completed response
|
|
<span class="cov8" title="1">model := origReq.Model
|
|
if providerModel != "" </span><span class="cov8" title="1">{
|
|
model = providerModel
|
|
}</span>
|
|
|
|
// Collect tool calls for result
|
|
<span class="cov8" title="1">var toolCalls []api.ToolCall
|
|
for _, builder := range toolCallsInProgress </span><span class="cov8" title="1">{
|
|
toolCalls = append(toolCalls, api.ToolCall{
|
|
ID: builder.id,
|
|
Name: builder.name,
|
|
Arguments: builder.arguments,
|
|
})
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">finalResult := &api.ProviderResult{
|
|
Model: model,
|
|
Text: fullText,
|
|
ToolCalls: toolCalls,
|
|
}
|
|
completedResp := s.buildResponse(origReq, finalResult, provider.Name(), responseID)
|
|
|
|
// Update item IDs to match what we sent during streaming
|
|
if textItemAdded && len(completedResp.Output) > 0 </span><span class="cov8" title="1">{
|
|
completedResp.Output[0].ID = itemID
|
|
}</span>
|
|
<span class="cov8" title="1">for idx, builder := range toolCallsInProgress </span><span class="cov8" title="1">{
|
|
// Find the corresponding output item
|
|
for i := range completedResp.Output </span><span class="cov8" title="1">{
|
|
if completedResp.Output[i].Type == "function_call" && completedResp.Output[i].CallID == builder.id </span><span class="cov8" title="1">{
|
|
completedResp.Output[i].ID = builder.itemID
|
|
break</span>
|
|
}
|
|
}
|
|
<span class="cov8" title="1">_ = idx</span> // unused
|
|
}
|
|
|
|
// response.completed
|
|
<span class="cov8" title="1">s.sendSSE(w, flusher, &seq, "response.completed", &api.StreamEvent{
|
|
Type: "response.completed",
|
|
Response: completedResp,
|
|
})
|
|
|
|
// Store conversation
|
|
if fullText != "" || len(toolCalls) > 0 </span><span class="cov8" title="1">{
|
|
assistantMsg := api.Message{
|
|
Role: "assistant",
|
|
Content: []api.ContentBlock{{Type: "output_text", Text: fullText}},
|
|
ToolCalls: toolCalls,
|
|
}
|
|
allMsgs := append(storeMsgs, assistantMsg)
|
|
if _, err := s.convs.Create(r.Context(), responseID, model, allMsgs); err != nil </span><span class="cov0" title="0">{
|
|
s.logger.ErrorContext(r.Context(), "failed to store conversation",
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("response_id", responseID),
|
|
slog.String("error", err.Error()),
|
|
)
|
|
// Don't fail the response if storage fails
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">s.logger.InfoContext(r.Context(), "streaming response completed",
|
|
slog.String("request_id", logger.FromContext(r.Context())),
|
|
slog.String("provider", provider.Name()),
|
|
slog.String("model", model),
|
|
slog.String("response_id", responseID),
|
|
slog.Bool("has_tool_calls", len(toolCalls) > 0),
|
|
)</span>
|
|
}
|
|
}
|
|
|
|
// sendSSE writes a single server-sent event to w and flushes it.
//
// The event is stamped with the current value of *seq, which is then
// incremented, and is serialized with json.Marshal. The wire format is
// "event: <eventType>\ndata: <json>\n\n". If marshaling fails the error is
// logged and nothing is written.
func (s *GatewayServer) sendSSE(w http.ResponseWriter, flusher http.Flusher, seq *int, eventType string, event *api.StreamEvent) <span class="cov8" title="1">{
|
|
event.SequenceNumber = *seq
|
|
// The counter advances before marshaling, so a marshal failure below still
// consumes a sequence number.
*seq++
|
|
data, err := json.Marshal(event)
|
|
if err != nil </span><span class="cov0" title="0">{
|
|
s.logger.Error("failed to marshal SSE event",
|
|
slog.String("event_type", eventType),
|
|
slog.String("error", err.Error()),
|
|
)
|
|
return
|
|
}</span>
|
|
// NOTE(review): the error returned by fmt.Fprintf is discarded, so a failed
// write to w is not detected here — confirm this is intended best-effort.
<span class="cov8" title="1">fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, data)
|
|
flusher.Flush()</span>
|
|
}
|
|
|
|
// buildResponse assembles the api.Response returned for a finished request.
//
// The response always has Status "completed", and CreatedAt and CompletedAt
// are both set to the same time.Now().Unix() value taken at call time. The
// model is result.Model, falling back to req.Model when that is empty.
//
// Output holds one "message" item when result.Text is non-empty, followed by
// one "function_call" item per entry in result.ToolCalls; every item gets a
// freshly generated ID.
//
// Request parameters are echoed back. When the request leaves one unset, the
// default used is: Tools `[]`, ToolChoice `"auto"`, Text
// `{"format":{"type":"text"}}`, Truncation "disabled", Temperature 1.0,
// TopP 1.0, PresencePenalty 0.0, FrequencyPenalty 0.0, TopLogprobs 0,
// ParallelToolCalls true, Store true, Background false, ServiceTier
// "default", and Metadata an empty map. Reasoning stays nil unless supplied.
func (s *GatewayServer) buildResponse(req *api.ResponseRequest, result *api.ProviderResult, providerName string, responseID string) *api.Response <span class="cov8" title="1">{
|
|
now := time.Now().Unix()
|
|
|
|
model := result.Model
|
|
if model == "" </span><span class="cov0" title="0">{
|
|
model = req.Model
|
|
}</span>
|
|
|
|
// Build output items array
// (starts as a non-nil empty slice; presumably so an empty output encodes as
// [] rather than null — confirm against api.Response's JSON tags)
|
|
<span class="cov8" title="1">outputItems := []api.OutputItem{}
|
|
|
|
// Add message item if there's text
|
|
if result.Text != "" </span><span class="cov8" title="1">{
|
|
outputItems = append(outputItems, api.OutputItem{
|
|
ID: generateID("msg_"),
|
|
Type: "message",
|
|
Status: "completed",
|
|
Role: "assistant",
|
|
Content: []api.ContentPart{{
|
|
Type: "output_text",
|
|
Text: result.Text,
|
|
Annotations: []api.Annotation{},
|
|
}},
|
|
})
|
|
}</span>
|
|
|
|
// Add function_call items
|
|
<span class="cov8" title="1">for _, tc := range result.ToolCalls </span><span class="cov8" title="1">{
|
|
outputItems = append(outputItems, api.OutputItem{
|
|
ID: generateID("item_"),
|
|
Type: "function_call",
|
|
Status: "completed",
|
|
CallID: tc.ID,
|
|
Name: tc.Name,
|
|
Arguments: tc.Arguments,
|
|
})
|
|
}</span>
|
|
|
|
// Echo back request params with defaults
|
|
<span class="cov8" title="1">tools := req.Tools
|
|
if tools == nil </span><span class="cov8" title="1">{
|
|
tools = json.RawMessage(`[]`)
|
|
}</span>
|
|
<span class="cov8" title="1">toolChoice := req.ToolChoice
|
|
if toolChoice == nil </span><span class="cov8" title="1">{
|
|
toolChoice = json.RawMessage(`"auto"`)
|
|
}</span>
|
|
<span class="cov8" title="1">text := req.Text
|
|
if text == nil </span><span class="cov8" title="1">{
|
|
text = json.RawMessage(`{"format":{"type":"text"}}`)
|
|
}</span>
|
|
<span class="cov8" title="1">truncation := "disabled"
|
|
if req.Truncation != nil </span><span class="cov8" title="1">{
|
|
truncation = *req.Truncation
|
|
}</span>
|
|
<span class="cov8" title="1">temperature := 1.0
|
|
if req.Temperature != nil </span><span class="cov8" title="1">{
|
|
temperature = *req.Temperature
|
|
}</span>
|
|
<span class="cov8" title="1">topP := 1.0
|
|
if req.TopP != nil </span><span class="cov8" title="1">{
|
|
topP = *req.TopP
|
|
}</span>
|
|
<span class="cov8" title="1">presencePenalty := 0.0
|
|
if req.PresencePenalty != nil </span><span class="cov8" title="1">{
|
|
presencePenalty = *req.PresencePenalty
|
|
}</span>
|
|
<span class="cov8" title="1">frequencyPenalty := 0.0
|
|
if req.FrequencyPenalty != nil </span><span class="cov8" title="1">{
|
|
frequencyPenalty = *req.FrequencyPenalty
|
|
}</span>
|
|
<span class="cov8" title="1">topLogprobs := 0
|
|
if req.TopLogprobs != nil </span><span class="cov8" title="1">{
|
|
topLogprobs = *req.TopLogprobs
|
|
}</span>
|
|
<span class="cov8" title="1">parallelToolCalls := true
|
|
if req.ParallelToolCalls != nil </span><span class="cov8" title="1">{
|
|
parallelToolCalls = *req.ParallelToolCalls
|
|
}</span>
|
|
<span class="cov8" title="1">store := true
|
|
if req.Store != nil </span><span class="cov8" title="1">{
|
|
store = *req.Store
|
|
}</span>
|
|
<span class="cov8" title="1">background := false
|
|
if req.Background != nil </span><span class="cov8" title="1">{
|
|
background = *req.Background
|
|
}</span>
|
|
<span class="cov8" title="1">serviceTier := "default"
|
|
if req.ServiceTier != nil </span><span class="cov8" title="1">{
|
|
serviceTier = *req.ServiceTier
|
|
}</span>
|
|
<span class="cov8" title="1">var reasoning json.RawMessage
|
|
if req.Reasoning != nil </span><span class="cov0" title="0">{
|
|
reasoning = req.Reasoning
|
|
}</span>
|
|
<span class="cov8" title="1">metadata := req.Metadata
|
|
if metadata == nil </span><span class="cov8" title="1">{
|
|
metadata = map[string]string{}
|
|
}</span>
|
|
|
|
// NOTE(review): usage is attached only when result.Text is non-empty, so a
// result that consists solely of tool calls yields a nil Usage — confirm
// this is intended rather than an oversight.
<span class="cov8" title="1">var usage *api.Usage
|
|
if result.Text != "" </span><span class="cov8" title="1">{
|
|
usage = &result.Usage
|
|
}</span>
|
|
|
|
<span class="cov8" title="1">return &api.Response{
|
|
ID: responseID,
|
|
Object: "response",
|
|
CreatedAt: now,
|
|
CompletedAt: &now,
|
|
Status: "completed",
|
|
IncompleteDetails: nil,
|
|
Model: model,
|
|
PreviousResponseID: req.PreviousResponseID,
|
|
Instructions: req.Instructions,
|
|
Output: outputItems,
|
|
Error: nil,
|
|
Tools: tools,
|
|
ToolChoice: toolChoice,
|
|
Truncation: truncation,
|
|
ParallelToolCalls: parallelToolCalls,
|
|
Text: text,
|
|
TopP: topP,
|
|
PresencePenalty: presencePenalty,
|
|
FrequencyPenalty: frequencyPenalty,
|
|
TopLogprobs: topLogprobs,
|
|
Temperature: temperature,
|
|
Reasoning: reasoning,
|
|
Usage: usage,
|
|
MaxOutputTokens: req.MaxOutputTokens,
|
|
MaxToolCalls: req.MaxToolCalls,
|
|
Store: store,
|
|
Background: background,
|
|
ServiceTier: serviceTier,
|
|
Metadata: metadata,
|
|
SafetyIdentifier: nil,
|
|
PromptCacheKey: nil,
|
|
Provider: providerName,
|
|
}</span>
|
|
}
|
|
|
|
// resolveProvider picks the provider that should serve req.
//
// When req.Provider is non-empty it is looked up in s.registry; a missing
// entry produces a "provider %s not configured" error. When req.Provider is
// empty, the choice is delegated to s.registry.Default(req.Model) and its
// result is returned unchanged.
func (s *GatewayServer) resolveProvider(req *api.ResponseRequest) (providers.Provider, error) <span class="cov8" title="1">{
|
|
if req.Provider != "" </span><span class="cov8" title="1">{
|
|
if provider, ok := s.registry.Get(req.Provider); ok </span><span class="cov8" title="1">{
|
|
return provider, nil
|
|
}</span>
|
|
<span class="cov8" title="1">return nil, fmt.Errorf("provider %s not configured", req.Provider)</span>
|
|
}
|
|
<span class="cov8" title="1">return s.registry.Default(req.Model)</span>
|
|
}
|
|
|
|
// generateID returns prefix followed by the first 24 characters of a newly
// generated UUID string with its dashes removed.
//
// NOTE(review): the fixed [:24] slice assumes the dash-stripped UUID string is
// at least 24 characters long (presumably 32 hex characters) — confirm against
// the uuid package in use.
func generateID(prefix string) string <span class="cov8" title="1">{
|
|
id := strings.ReplaceAll(uuid.NewString(), "-", "")
|
|
return prefix + id[:24]
|
|
}</span>
|
|
</pre>
|
|
|
|
</div>
|
|
</body>
|
|
<script>
|
|
// File switcher for the coverage report: keeps exactly one element visible at
// a time, chosen either through the #files drop-down or the URL fragment.
(function() {
	var files = document.getElementById('files');
	var visible; // element currently shown; undefined until the first select()

	files.addEventListener('change', onChange, false);

	// select hides the currently visible element, then shows the element whose
	// id is `part`. When no such element exists it returns early, leaving
	// `visible` falsy so the caller can fall back to a default.
	function select(part) {
		if (visible)
			visible.style.display = 'none';
		visible = document.getElementById(part);
		if (!visible)
			return;
		files.value = part;
		visible.style.display = 'block';
		location.hash = part;
	}

	// onChange shows the file picked in the drop-down and scrolls to the top.
	function onChange() {
		select(files.value);
		window.scrollTo(0, 0);
	}

	// Honor a fragment such as "#file3"; slice(1) drops the leading '#'.
	// (slice replaces the deprecated String.prototype.substr.)
	if (location.hash !== "") {
		select(location.hash.slice(1));
	}
	// No fragment, or a fragment that matched nothing: show the first file.
	if (!visible) {
		select("file0");
	}
})();
|
|
</script>
|
|
</html>
|