Add observability and monitoring

This commit is contained in:
2026-03-03 06:39:42 +00:00
parent 2edb290563
commit b56c78fa07
15 changed files with 1549 additions and 38 deletions

View File

@@ -0,0 +1,98 @@
package observability
import (
	"sync"

	"github.com/ajac-zero/latticelm/internal/conversation"
	"github.com/ajac-zero/latticelm/internal/providers"
	"github.com/prometheus/client_golang/prometheus"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
)
// ProviderRegistry defines the interface for provider registries.
// This matches the interface expected by the server.
type ProviderRegistry interface {
	// Get returns the provider registered under the given entry name and
	// whether it was found.
	Get(name string) (providers.Provider, bool)
	// Models lists the configured models together with their provider entry names.
	Models() []struct{ Provider, Model string }
	// ResolveModelID returns the provider_model_id for a model.
	ResolveModelID(model string) string
	// Default returns the provider responsible for the given model name.
	Default(model string) (providers.Provider, error)
}
// WrapProviderRegistry wraps all providers in a registry with observability.
// A nil registry is passed through unchanged so callers need no special case.
func WrapProviderRegistry(registry ProviderRegistry, metricsRegistry *prometheus.Registry, tp *sdktrace.TracerProvider) ProviderRegistry {
	if registry == nil {
		return nil
	}
	// The base registry's internal provider map is not accessible, so wrap
	// providers lazily as they are retrieved and cache the wrappers.
	return &InstrumentedRegistry{
		base:             registry,
		metrics:          metricsRegistry,
		tracer:           tp,
		wrappedProviders: make(map[string]providers.Provider),
	}
}

// InstrumentedRegistry wraps a provider registry to return instrumented
// providers, caching each wrapper by entry name so a provider is wrapped once.
type InstrumentedRegistry struct {
	base    ProviderRegistry
	metrics *prometheus.Registry
	tracer  *sdktrace.TracerProvider

	// mu guards wrappedProviders. Get and Default are invoked concurrently
	// from HTTP handlers; unsynchronized map reads/writes are a data race.
	mu               sync.Mutex
	wrappedProviders map[string]providers.Provider
}

// wrap returns the cached instrumented wrapper for p under name, creating and
// caching it on first use.
func (r *InstrumentedRegistry) wrap(name string, p providers.Provider) providers.Provider {
	r.mu.Lock()
	defer r.mu.Unlock()
	if wrapped, ok := r.wrappedProviders[name]; ok {
		return wrapped
	}
	wrapped := NewInstrumentedProvider(p, r.metrics, r.tracer)
	r.wrappedProviders[name] = wrapped
	return wrapped
}

// Get returns an instrumented provider by entry name.
func (r *InstrumentedRegistry) Get(name string) (providers.Provider, bool) {
	p, ok := r.base.Get(name)
	if !ok {
		return nil, false
	}
	return r.wrap(name, p), true
}

// Default returns the instrumented provider for the given model name.
func (r *InstrumentedRegistry) Default(model string) (providers.Provider, error) {
	p, err := r.base.Default(model)
	if err != nil {
		return nil, err
	}
	return r.wrap(p.Name(), p), nil
}

// Models returns the list of configured models and their provider entry names.
func (r *InstrumentedRegistry) Models() []struct{ Provider, Model string } {
	return r.base.Models()
}

// ResolveModelID returns the provider_model_id for a model.
func (r *InstrumentedRegistry) ResolveModelID(model string) string {
	return r.base.ResolveModelID(model)
}
// WrapConversationStore wraps a conversation store with observability.
// A nil store is returned unchanged so callers need no special case.
func WrapConversationStore(store conversation.Store, backend string, metricsRegistry *prometheus.Registry, tp *sdktrace.TracerProvider) conversation.Store {
	if store != nil {
		return NewInstrumentedStore(store, backend, metricsRegistry, tp)
	}
	return nil
}

View File

@@ -0,0 +1,147 @@
package observability
import (
"github.com/prometheus/client_golang/prometheus"
)
// Package-level Prometheus collectors. They are created here and registered
// into a registry by InitMetrics; every recording site guards on a nil
// registry before touching them.
// NOTE(review): "path" label values come from raw request URLs — watch for
// metric cardinality growth if paths embed identifiers; consider route templates.
var (
	// HTTP Metrics
	httpRequestsTotal = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "http_requests_total",
			Help: "Total number of HTTP requests",
		},
		[]string{"method", "path", "status"},
	)
	httpRequestDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "http_request_duration_seconds",
			Help:    "HTTP request latency in seconds",
			Buckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 10, 30},
		},
		[]string{"method", "path", "status"},
	)
	httpRequestSize = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "http_request_size_bytes",
			Help:    "HTTP request size in bytes",
			Buckets: prometheus.ExponentialBuckets(100, 10, 7), // 100B to 100MB
		},
		[]string{"method", "path"},
	)
	httpResponseSize = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "http_response_size_bytes",
			Help:    "HTTP response size in bytes",
			Buckets: prometheus.ExponentialBuckets(100, 10, 7), // 100B to 100MB
		},
		[]string{"method", "path"},
	)
	// Provider Metrics (labelled by provider name, model, and operation
	// "generate" or "generate_stream")
	providerRequestsTotal = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "provider_requests_total",
			Help: "Total number of provider requests",
		},
		[]string{"provider", "model", "operation", "status"},
	)
	providerRequestDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "provider_request_duration_seconds",
			Help:    "Provider request latency in seconds",
			Buckets: []float64{0.1, 0.5, 1, 2, 5, 10, 20, 30, 60},
		},
		[]string{"provider", "model", "operation"},
	)
	providerTokensTotal = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "provider_tokens_total",
			Help: "Total number of tokens processed",
		},
		[]string{"provider", "model", "type"}, // type: input, output
	)
	providerStreamTTFB = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "provider_stream_ttfb_seconds",
			Help:    "Time to first byte for streaming requests in seconds",
			Buckets: []float64{0.05, 0.1, 0.5, 1, 2, 5, 10},
		},
		[]string{"provider", "model"},
	)
	providerStreamChunks = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "provider_stream_chunks_total",
			Help: "Total number of stream chunks received",
		},
		[]string{"provider", "model"},
	)
	providerStreamDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "provider_stream_duration_seconds",
			Help:    "Total duration of streaming requests in seconds",
			Buckets: []float64{0.1, 0.5, 1, 2, 5, 10, 20, 30, 60},
		},
		[]string{"provider", "model"},
	)
	// Conversation Store Metrics (labelled by operation get/create/append/delete
	// and the configured storage backend)
	conversationOperationsTotal = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "conversation_operations_total",
			Help: "Total number of conversation store operations",
		},
		[]string{"operation", "backend", "status"},
	)
	conversationOperationDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "conversation_operation_duration_seconds",
			Help:    "Conversation store operation latency in seconds",
			Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1},
		},
		[]string{"operation", "backend"},
	)
	conversationActiveCount = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "conversation_active_count",
			Help: "Number of active conversations",
		},
		[]string{"backend"},
	)
)
// InitMetrics registers all metrics with a new Prometheus registry.
// Each call returns a fresh registry populated with the package-level collectors.
func InitMetrics() *prometheus.Registry {
	registry := prometheus.NewRegistry()
	registry.MustRegister(
		// HTTP metrics
		httpRequestsTotal,
		httpRequestDuration,
		httpRequestSize,
		httpResponseSize,
		// Provider metrics
		providerRequestsTotal,
		providerRequestDuration,
		providerTokensTotal,
		providerStreamTTFB,
		providerStreamChunks,
		providerStreamDuration,
		// Conversation store metrics
		conversationOperationsTotal,
		conversationOperationDuration,
		conversationActiveCount,
	)
	return registry
}

View File

@@ -0,0 +1,62 @@
package observability
import (
"net/http"
"strconv"
"time"
"github.com/prometheus/client_golang/prometheus"
)
// MetricsMiddleware creates a middleware that records HTTP metrics.
// When registry is nil, metrics are disabled and next is returned unchanged.
// The third parameter is unused and kept only for signature compatibility.
func MetricsMiddleware(next http.Handler, registry *prometheus.Registry, _ interface{}) http.Handler {
	if registry == nil {
		return next
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		method, path := r.Method, r.URL.Path
		// NOTE(review): raw URL paths as label values can explode metric
		// cardinality when paths embed IDs — consider route templates.
		if r.ContentLength > 0 {
			httpRequestSize.WithLabelValues(method, path).Observe(float64(r.ContentLength))
		}
		// Capture status code and response size via a wrapping writer.
		mw := &metricsResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		next.ServeHTTP(mw, r)
		// Record metrics once the request has completed.
		code := strconv.Itoa(mw.statusCode)
		elapsed := time.Since(began).Seconds()
		httpRequestsTotal.WithLabelValues(method, path, code).Inc()
		httpRequestDuration.WithLabelValues(method, path, code).Observe(elapsed)
		httpResponseSize.WithLabelValues(method, path).Observe(float64(mw.bytesWritten))
	})
}
// metricsResponseWriter wraps http.ResponseWriter to capture status code and
// bytes written. It also forwards Flush so streaming responses keep working.
type metricsResponseWriter struct {
	http.ResponseWriter
	statusCode   int // last code passed to WriteHeader; caller seeds http.StatusOK
	bytesWritten int // running total of bytes written through Write
}

// WriteHeader records the status code before delegating to the wrapped writer.
func (w *metricsResponseWriter) WriteHeader(statusCode int) {
	w.statusCode = statusCode
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write delegates to the wrapped writer and accumulates the bytes written.
func (w *metricsResponseWriter) Write(b []byte) (int, error) {
	n, err := w.ResponseWriter.Write(b)
	w.bytesWritten += n
	return n, err
}

// Flush forwards to the underlying writer when it supports http.Flusher.
// Without this the wrapper hides the Flusher interface from handlers, which
// silently breaks streaming (e.g. SSE) responses behind this middleware.
func (w *metricsResponseWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

View File

@@ -0,0 +1,208 @@
package observability
import (
"context"
"time"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/ajac-zero/latticelm/internal/providers"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
)
// InstrumentedProvider wraps a provider with metrics and tracing.
// registry and tracer may each be nil; every method checks before use, so
// metrics and tracing can be enabled independently.
type InstrumentedProvider struct {
	base     providers.Provider   // wrapped provider; all calls delegate here
	registry *prometheus.Registry // nil disables metric recording
	tracer   trace.Tracer         // nil disables span creation
}
// NewInstrumentedProvider wraps a provider with observability.
// registry or tp may be nil, in which case the corresponding instrumentation
// is skipped at call time.
func NewInstrumentedProvider(p providers.Provider, registry *prometheus.Registry, tp *sdktrace.TracerProvider) providers.Provider {
	ip := &InstrumentedProvider{base: p, registry: registry}
	if tp != nil {
		ip.tracer = tp.Tracer("llm-gateway")
	}
	return ip
}
// Name reports the name of the underlying provider.
func (p *InstrumentedProvider) Name() string {
	return p.base.Name()
}
// Generate wraps the provider's Generate method with metrics and tracing.
// The call is delegated unchanged; on success token usage is recorded, on
// failure the error is attached to the span. Result and error pass through.
func (p *InstrumentedProvider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
	// span stays nil when tracing is disabled; every use below is guarded.
	var span trace.Span
	if p.tracer != nil {
		ctx, span = p.tracer.Start(ctx, "provider.generate",
			trace.WithSpanKind(trace.SpanKindClient),
			trace.WithAttributes(
				attribute.String("provider.name", p.base.Name()),
				attribute.String("provider.model", req.Model),
			),
		)
		defer span.End()
	}

	start := time.Now()
	result, err := p.base.Generate(ctx, messages, req)
	elapsed := time.Since(start).Seconds()

	status := "success"
	switch {
	case err != nil:
		status = "error"
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, err.Error())
		}
	case result != nil:
		if span != nil {
			span.SetAttributes(
				attribute.Int64("provider.input_tokens", int64(result.Usage.InputTokens)),
				attribute.Int64("provider.output_tokens", int64(result.Usage.OutputTokens)),
				attribute.Int64("provider.total_tokens", int64(result.Usage.TotalTokens)),
			)
			span.SetStatus(codes.Ok, "")
		}
		if p.registry != nil {
			providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "input").Add(float64(result.Usage.InputTokens))
			providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "output").Add(float64(result.Usage.OutputTokens))
		}
	}

	if p.registry != nil {
		providerRequestsTotal.WithLabelValues(p.base.Name(), req.Model, "generate", status).Inc()
		providerRequestDuration.WithLabelValues(p.base.Name(), req.Model, "generate").Observe(elapsed)
	}
	return result, err
}
// GenerateStream wraps the provider's GenerateStream method with metrics and
// tracing. Deltas and errors from the base provider are forwarded on fresh
// channels; final span attributes and metrics are recorded when the stream
// completes (close of the delta channel), fails, or the context is canceled.
//
// Note: the span deliberately is NOT ended with defer in this function — the
// function returns immediately while the stream is still running, which would
// end the span before any data arrived and drop all attributes set later.
// The forwarding goroutine ends it instead.
func (p *InstrumentedProvider) GenerateStream(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (<-chan *api.ProviderStreamDelta, <-chan error) {
	var span trace.Span
	if p.tracer != nil {
		ctx, span = p.tracer.Start(ctx, "provider.generate_stream",
			trace.WithSpanKind(trace.SpanKindClient),
			trace.WithAttributes(
				attribute.String("provider.name", p.base.Name()),
				attribute.String("provider.model", req.Model),
			),
		)
	}
	start := time.Now()

	baseChan, baseErrChan := p.base.GenerateStream(ctx, messages, req)
	outChan := make(chan *api.ProviderStreamDelta)
	outErrChan := make(chan error, 1)

	go func() {
		defer close(outChan)
		defer close(outErrChan)

		var (
			ttfb              time.Duration
			chunkCount        int64
			totalInputTokens  int64
			totalOutputTokens int64
		)

		// finish records the final span state and metrics exactly once; it is
		// called on every exit path (including errors, which the previous
		// version never recorded metrics for).
		finish := func(streamErr error) {
			duration := time.Since(start).Seconds()
			status := "success"
			if streamErr != nil {
				status = "error"
			}
			if span != nil {
				if streamErr != nil {
					span.RecordError(streamErr)
					span.SetStatus(codes.Error, streamErr.Error())
				} else {
					span.SetAttributes(
						attribute.Int64("provider.input_tokens", totalInputTokens),
						attribute.Int64("provider.output_tokens", totalOutputTokens),
						attribute.Int64("provider.chunk_count", chunkCount),
						attribute.Float64("provider.ttfb_seconds", ttfb.Seconds()),
					)
					span.SetStatus(codes.Ok, "")
				}
				span.End()
			}
			if p.registry != nil {
				if streamErr == nil && (totalInputTokens > 0 || totalOutputTokens > 0) {
					providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "input").Add(float64(totalInputTokens))
					providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "output").Add(float64(totalOutputTokens))
				}
				providerRequestsTotal.WithLabelValues(p.base.Name(), req.Model, "generate_stream", status).Inc()
				providerStreamDuration.WithLabelValues(p.base.Name(), req.Model).Observe(duration)
				providerStreamChunks.WithLabelValues(p.base.Name(), req.Model).Add(float64(chunkCount))
				if ttfb > 0 {
					providerStreamTTFB.WithLabelValues(p.base.Name(), req.Model).Observe(ttfb.Seconds())
				}
			}
		}

		for {
			select {
			case delta, ok := <-baseChan:
				if !ok {
					// Stream finished cleanly.
					finish(nil)
					return
				}
				if chunkCount == 0 {
					ttfb = time.Since(start)
				}
				chunkCount++
				// Providers report cumulative usage; keep the latest values.
				if delta.Usage != nil {
					totalInputTokens = int64(delta.Usage.InputTokens)
					totalOutputTokens = int64(delta.Usage.OutputTokens)
				}
				// Forward without blocking forever if the consumer went away.
				select {
				case outChan <- delta:
				case <-ctx.Done():
					finish(ctx.Err())
					return
				}
			case err, ok := <-baseErrChan:
				if !ok {
					// Error channel closed with no error: disable this case
					// and keep draining the delta channel instead of
					// abandoning in-flight deltas.
					baseErrChan = nil
					continue
				}
				if err != nil {
					outErrChan <- err
					finish(err)
					return
				}
			}
		}
	}()
	return outChan, outErrChan
}

View File

@@ -0,0 +1,258 @@
package observability
import (
"context"
"time"
"github.com/ajac-zero/latticelm/internal/api"
"github.com/ajac-zero/latticelm/internal/conversation"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
)
// InstrumentedStore wraps a conversation store with metrics and tracing.
// registry and tracer may each be nil to disable that instrumentation.
type InstrumentedStore struct {
	base     conversation.Store   // wrapped store; all operations delegate here
	registry *prometheus.Registry // nil disables metric recording
	tracer   trace.Tracer         // nil disables span creation
	backend  string               // metric/span label identifying the store backend
}
// NewInstrumentedStore wraps a conversation store with observability.
// backend labels all emitted metrics and spans; registry and tp may each be
// nil to disable metrics or tracing respectively.
func NewInstrumentedStore(s conversation.Store, backend string, registry *prometheus.Registry, tp *sdktrace.TracerProvider) conversation.Store {
	is := &InstrumentedStore{base: s, registry: registry, backend: backend}
	if tp != nil {
		is.tracer = tp.Tracer("llm-gateway")
	}
	if registry != nil {
		// Seed the gauge so it reflects conversations that already exist.
		conversationActiveCount.WithLabelValues(backend).Set(float64(s.Size()))
	}
	return is
}
// observeOp wraps a single store operation with tracing and metrics,
// factoring out the span/metric boilerplate previously duplicated across
// Get, Create, Append, and Delete.
//
// spanName names the trace span and op labels the metrics. attrs are span
// attributes known before the call; fn performs the underlying store call
// and returns its error; onOK, when non-nil, may attach result attributes to
// the span after a successful call. The operation's error is returned unchanged.
func (s *InstrumentedStore) observeOp(spanName, op string, attrs []attribute.KeyValue, fn func() error, onOK func(span trace.Span)) error {
	// span stays nil when tracing is disabled; every use below is guarded.
	var span trace.Span
	if s.tracer != nil {
		// The store interface carries no caller context, so spans are rooted
		// in a fresh background context.
		_, span = s.tracer.Start(context.Background(), spanName, trace.WithAttributes(attrs...))
		defer span.End()
	}
	start := time.Now()
	err := fn()
	elapsed := time.Since(start).Seconds()
	status := "success"
	if err != nil {
		status = "error"
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, err.Error())
		}
	} else if span != nil {
		if onOK != nil {
			onOK(span)
		}
		span.SetStatus(codes.Ok, "")
	}
	if s.registry != nil {
		conversationOperationsTotal.WithLabelValues(op, s.backend, status).Inc()
		conversationOperationDuration.WithLabelValues(op, s.backend).Observe(elapsed)
	}
	return err
}

// refreshActiveCount re-reads the store size into the active-conversation
// gauge; called after operations that change the number of conversations.
func (s *InstrumentedStore) refreshActiveCount() {
	if s.registry != nil {
		conversationActiveCount.WithLabelValues(s.backend).Set(float64(s.base.Size()))
	}
}

// Get wraps the store's Get method with metrics and tracing.
func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) {
	var conv *conversation.Conversation
	err := s.observeOp("conversation.get", "get",
		[]attribute.KeyValue{
			attribute.String("conversation.id", id),
			attribute.String("conversation.backend", s.backend),
		},
		func() error {
			var opErr error
			conv, opErr = s.base.Get(id)
			return opErr
		},
		func(span trace.Span) {
			if conv != nil {
				span.SetAttributes(
					attribute.Int("conversation.message_count", len(conv.Messages)),
					attribute.String("conversation.model", conv.Model),
				)
			}
		},
	)
	return conv, err
}

// Create wraps the store's Create method with metrics and tracing, and
// refreshes the active-conversation gauge afterwards.
func (s *InstrumentedStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) {
	var conv *conversation.Conversation
	err := s.observeOp("conversation.create", "create",
		[]attribute.KeyValue{
			attribute.String("conversation.id", id),
			attribute.String("conversation.backend", s.backend),
			attribute.String("conversation.model", model),
			attribute.Int("conversation.initial_messages", len(messages)),
		},
		func() error {
			var opErr error
			conv, opErr = s.base.Create(id, model, messages)
			return opErr
		},
		nil,
	)
	s.refreshActiveCount()
	return conv, err
}

// Append wraps the store's Append method with metrics and tracing.
func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) {
	var conv *conversation.Conversation
	err := s.observeOp("conversation.append", "append",
		[]attribute.KeyValue{
			attribute.String("conversation.id", id),
			attribute.String("conversation.backend", s.backend),
			attribute.Int("conversation.appended_messages", len(messages)),
		},
		func() error {
			var opErr error
			conv, opErr = s.base.Append(id, messages...)
			return opErr
		},
		func(span trace.Span) {
			if conv != nil {
				span.SetAttributes(
					attribute.Int("conversation.total_messages", len(conv.Messages)),
				)
			}
		},
	)
	return conv, err
}

// Delete wraps the store's Delete method with metrics and tracing, and
// refreshes the active-conversation gauge afterwards.
func (s *InstrumentedStore) Delete(id string) error {
	err := s.observeOp("conversation.delete", "delete",
		[]attribute.KeyValue{
			attribute.String("conversation.id", id),
			attribute.String("conversation.backend", s.backend),
		},
		func() error { return s.base.Delete(id) },
		nil,
	)
	s.refreshActiveCount()
	return err
}
// Size returns the size of the underlying store.
// Delegates directly; intentionally not instrumented.
func (s *InstrumentedStore) Size() int {
	return s.base.Size()
}
// Close wraps the store's Close method.
// Delegates directly; no metrics or spans are recorded for shutdown.
func (s *InstrumentedStore) Close() error {
	return s.base.Close()
}

View File

@@ -0,0 +1,104 @@
package observability
import (
	"context"
	"fmt"
	"time"

	"github.com/ajac-zero/latticelm/internal/config"
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
	"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
	"go.opentelemetry.io/otel/sdk/resource"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	semconv "go.opentelemetry.io/otel/semconv/v1.24.0"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)
// InitTracer initializes the OpenTelemetry tracer provider from cfg.
// It assembles a resource describing the service, a span exporter selected by
// cfg.Exporter.Type ("otlp" or "stdout"), and a sampler, and returns the
// configured provider. Unknown exporter types are an error.
func InitTracer(cfg config.TracingConfig) (*sdktrace.TracerProvider, error) {
	// Describe this service on every exported span.
	res, err := resource.Merge(
		resource.Default(),
		resource.NewWithAttributes(
			semconv.SchemaURL,
			semconv.ServiceName(cfg.ServiceName),
		),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create resource: %w", err)
	}

	var exporter sdktrace.SpanExporter
	switch cfg.Exporter.Type {
	case "otlp":
		if exporter, err = createOTLPExporter(cfg.Exporter); err != nil {
			return nil, fmt.Errorf("failed to create OTLP exporter: %w", err)
		}
	case "stdout":
		if exporter, err = stdouttrace.New(stdouttrace.WithPrettyPrint()); err != nil {
			return nil, fmt.Errorf("failed to create stdout exporter: %w", err)
		}
	default:
		return nil, fmt.Errorf("unsupported exporter type: %s", cfg.Exporter.Type)
	}

	return sdktrace.NewTracerProvider(
		sdktrace.WithBatcher(exporter),
		sdktrace.WithResource(res),
		sdktrace.WithSampler(createSampler(cfg.Sampler)),
	), nil
}
// createOTLPExporter creates an OTLP gRPC exporter.
//
// The exporter dials with grpc.WithBlock so configuration errors surface at
// startup rather than on first export; without a deadline, though, an
// unreachable collector would hang startup forever. The connection attempt is
// therefore bounded to 10 seconds.
func createOTLPExporter(cfg config.ExporterConfig) (sdktrace.SpanExporter, error) {
	opts := []otlptracegrpc.Option{
		otlptracegrpc.WithEndpoint(cfg.Endpoint),
	}
	if cfg.Insecure {
		opts = append(opts, otlptracegrpc.WithTLSCredentials(insecure.NewCredentials()))
	}
	if len(cfg.Headers) > 0 {
		opts = append(opts, otlptracegrpc.WithHeaders(cfg.Headers))
	}
	// Add dial options to ensure connection
	opts = append(opts, otlptracegrpc.WithDialOption(grpc.WithBlock()))
	// Bound the blocking dial so startup cannot hang indefinitely.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	return otlptracegrpc.New(ctx, opts...)
}
// createSampler builds the trace sampler selected by cfg.Type.
// Recognized types: "always", "never", "probability" (using cfg.Rate).
// Anything else falls back to 10% probabilistic sampling.
func createSampler(cfg config.SamplerConfig) sdktrace.Sampler {
	switch cfg.Type {
	case "always":
		return sdktrace.AlwaysSample()
	case "never":
		return sdktrace.NeverSample()
	case "probability":
		return sdktrace.TraceIDRatioBased(cfg.Rate)
	}
	// Unrecognized (or empty) sampler type: default to sampling 10% of traces.
	return sdktrace.TraceIDRatioBased(0.1)
}
// Shutdown gracefully shuts down the tracer provider, flushing any pending
// spans. A nil provider is a no-op.
func Shutdown(ctx context.Context, tp *sdktrace.TracerProvider) error {
	if tp != nil {
		return tp.Shutdown(ctx)
	}
	return nil
}

View File

@@ -0,0 +1,85 @@
package observability
import (
"net/http"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/propagation"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
)
// TracingMiddleware creates a middleware that adds OpenTelemetry server spans
// to HTTP requests. When tp is nil the handler is returned unchanged.
func TracingMiddleware(next http.Handler, tp *sdktrace.TracerProvider) http.Handler {
	if tp == nil {
		return next
	}
	// Register W3C Trace Context as the global propagator so incoming
	// traceparent headers are honored.
	otel.SetTextMapPropagator(propagation.TraceContext{})
	tracer := tp.Tracer("llm-gateway")
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Continue any trace carried in the incoming request headers.
		ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header))
		attrs := []attribute.KeyValue{
			attribute.String("http.method", r.Method),
			attribute.String("http.route", r.URL.Path),
			attribute.String("http.scheme", r.URL.Scheme),
			attribute.String("http.host", r.Host),
			attribute.String("http.user_agent", r.Header.Get("User-Agent")),
		}
		ctx, span := tracer.Start(ctx, "HTTP "+r.Method+" "+r.URL.Path,
			trace.WithSpanKind(trace.SpanKindServer),
			trace.WithAttributes(attrs...),
		)
		defer span.End()
		if reqID := r.Header.Get("X-Request-ID"); reqID != "" {
			span.SetAttributes(attribute.String("http.request_id", reqID))
		}
		// Capture the status code and propagate the span context downstream.
		srw := &statusResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		next.ServeHTTP(srw, r.WithContext(ctx))
		span.SetAttributes(attribute.Int("http.status_code", srw.statusCode))
		if srw.statusCode >= 400 {
			span.SetStatus(codes.Error, http.StatusText(srw.statusCode))
			return
		}
		span.SetStatus(codes.Ok, "")
	})
}
// statusResponseWriter wraps http.ResponseWriter to capture the status code.
// It also forwards Flush so streaming responses keep working through the wrapper.
type statusResponseWriter struct {
	http.ResponseWriter
	statusCode int // last code passed to WriteHeader; caller seeds http.StatusOK
}

// WriteHeader records the status code before delegating to the wrapped writer.
func (w *statusResponseWriter) WriteHeader(statusCode int) {
	w.statusCode = statusCode
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write delegates to the wrapped writer unchanged.
func (w *statusResponseWriter) Write(b []byte) (int, error) {
	return w.ResponseWriter.Write(b)
}

// Flush forwards to the underlying writer when it supports http.Flusher.
// Without this the wrapper hides the Flusher interface from handlers, which
// silently breaks streaming (e.g. SSE) responses behind this middleware.
func (w *statusResponseWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}