package observability

import (
	"net/http"
	"strconv"
	"time"

	"github.com/prometheus/client_golang/prometheus"
)

// MetricsMiddleware wraps next so that every request it serves is recorded in
// the package-level HTTP collectors httpRequestsTotal, httpRequestDuration,
// httpRequestSize and httpResponseSize.
//
// registry is used only as an on/off switch here: when it is nil, next is
// returned unchanged and nothing is recorded. NOTE(review): the collectors are
// package-level variables declared elsewhere; presumably they are registered
// with registry by other code — confirm. The third parameter is ignored.
//
// NOTE(review): r.URL.Path is used verbatim as a label value, so paths that
// embed IDs create one time series per distinct path (unbounded cardinality).
// Consider labelling by route template instead.
func MetricsMiddleware(next http.Handler, registry *prometheus.Registry, _ interface{}) http.Handler {
	if registry == nil {
		// Metrics are disabled: pass through without modification.
		return next
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()

		// Only a declared, non-empty body size is observed; an unknown length
		// (-1, e.g. chunked encoding) or an empty body is skipped.
		if r.ContentLength > 0 {
			httpRequestSize.WithLabelValues(r.Method, r.URL.Path).Observe(float64(r.ContentLength))
		}

		// Default to 200: net/http sends that status implicitly when the
		// handler writes a body (or nothing at all) without calling WriteHeader.
		wrapped := &metricsResponseWriter{
			ResponseWriter: w,
			statusCode:     http.StatusOK,
		}

		next.ServeHTTP(wrapped, r)

		// Record metrics after the handler has returned.
		duration := time.Since(start).Seconds()
		status := strconv.Itoa(wrapped.statusCode)

		httpRequestsTotal.WithLabelValues(r.Method, r.URL.Path, status).Inc()
		httpRequestDuration.WithLabelValues(r.Method, r.URL.Path, status).Observe(duration)
		httpResponseSize.WithLabelValues(r.Method, r.URL.Path).Observe(float64(wrapped.bytesWritten))
	})
}

// metricsResponseWriter wraps http.ResponseWriter to capture the final status
// code and the number of body bytes written.
type metricsResponseWriter struct {
	http.ResponseWriter
	statusCode   int  // final (non-1xx) status sent, or the 200 default
	bytesWritten int  // bytes reported written by the underlying Write
	wroteHeader  bool // whether the final response header has been committed
}

// WriteHeader records and forwards the response status. Only the first final
// status is forwarded; later calls are dropped.
func (w *metricsResponseWriter) WriteHeader(statusCode int) {
	if w.wroteHeader {
		return
	}
	// 1xx interim responses (e.g. 103 Early Hints) are forwarded but do not
	// finalize the response: net/http permits WriteHeader to be called again
	// with the real status, which must not be swallowed here. 101 Switching
	// Protocols is final, matching net/http's own handling.
	if statusCode >= 100 && statusCode <= 199 && statusCode != http.StatusSwitchingProtocols {
		w.ResponseWriter.WriteHeader(statusCode)
		return
	}
	w.wroteHeader = true
	w.statusCode = statusCode
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards b to the underlying writer and accumulates the byte count.
// A Write without a prior final WriteHeader implies status 200.
func (w *metricsResponseWriter) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.wroteHeader = true
		w.statusCode = http.StatusOK
	}
	n, err := w.ResponseWriter.Write(b)
	w.bytesWritten += n
	return n, err
}

// Flush forwards to the underlying writer when it implements http.Flusher;
// otherwise it is a no-op.
func (w *metricsResponseWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		// Flushing commits an implicit 200 if no final header was written yet;
		// record that so a later WriteHeader call is not misreported.
		if !w.wroteHeader {
			w.wroteHeader = true
			w.statusCode = http.StatusOK
		}
		flusher.Flush()
	}
}

// Unwrap returns the wrapped ResponseWriter so that http.ResponseController
// can reach optional capabilities (hijacking, deadlines, ...) that this
// wrapper does not implement itself.
func (w *metricsResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}